Initial commit
This commit is contained in:
24
src/__init__.py
Normal file
24
src/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
OpenAI/Codex CLI 自动注册系统
|
||||
"""
|
||||
|
||||
from .config import get_settings, EmailServiceType
|
||||
from .database import get_db, Account, EmailService, RegistrationTask
|
||||
from .core import RegistrationEngine, RegistrationResult
|
||||
from .services import EmailServiceFactory, BaseEmailService
|
||||
|
||||
__version__ = "2.0.0"
|
||||
__author__ = "Yasal"
|
||||
|
||||
__all__ = [
|
||||
'get_settings',
|
||||
'EmailServiceType',
|
||||
'get_db',
|
||||
'Account',
|
||||
'EmailService',
|
||||
'RegistrationTask',
|
||||
'RegistrationEngine',
|
||||
'RegistrationResult',
|
||||
'EmailServiceFactory',
|
||||
'BaseEmailService',
|
||||
]
|
||||
53
src/config/__init__.py
Normal file
53
src/config/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
配置模块
|
||||
"""
|
||||
|
||||
from .settings import (
|
||||
Settings,
|
||||
get_settings,
|
||||
update_settings,
|
||||
get_database_url,
|
||||
init_default_settings,
|
||||
get_setting_definition,
|
||||
get_all_setting_definitions,
|
||||
SETTING_DEFINITIONS,
|
||||
SettingCategory,
|
||||
SettingDefinition,
|
||||
)
|
||||
from .constants import (
|
||||
AccountStatus,
|
||||
TaskStatus,
|
||||
EmailServiceType,
|
||||
APP_NAME,
|
||||
APP_VERSION,
|
||||
OTP_CODE_PATTERN,
|
||||
DEFAULT_PASSWORD_LENGTH,
|
||||
PASSWORD_CHARSET,
|
||||
DEFAULT_USER_INFO,
|
||||
generate_random_user_info,
|
||||
OPENAI_API_ENDPOINTS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Settings',
|
||||
'get_settings',
|
||||
'update_settings',
|
||||
'get_database_url',
|
||||
'init_default_settings',
|
||||
'get_setting_definition',
|
||||
'get_all_setting_definitions',
|
||||
'SETTING_DEFINITIONS',
|
||||
'SettingCategory',
|
||||
'SettingDefinition',
|
||||
'AccountStatus',
|
||||
'TaskStatus',
|
||||
'EmailServiceType',
|
||||
'APP_NAME',
|
||||
'APP_VERSION',
|
||||
'OTP_CODE_PATTERN',
|
||||
'DEFAULT_PASSWORD_LENGTH',
|
||||
'PASSWORD_CHARSET',
|
||||
'DEFAULT_USER_INFO',
|
||||
'generate_random_user_info',
|
||||
'OPENAI_API_ENDPOINTS',
|
||||
]
|
||||
399
src/config/constants.py
Normal file
399
src/config/constants.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
常量定义
|
||||
"""
|
||||
|
||||
import random
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 枚举类型
|
||||
# ============================================================================
|
||||
|
||||
class AccountStatus(str, Enum):
|
||||
"""账户状态"""
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
BANNED = "banned"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务状态"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class EmailServiceType(str, Enum):
|
||||
"""邮箱服务类型"""
|
||||
TEMPMAIL = "tempmail"
|
||||
OUTLOOK = "outlook"
|
||||
MOE_MAIL = "moe_mail"
|
||||
TEMP_MAIL = "temp_mail"
|
||||
DUCK_MAIL = "duck_mail"
|
||||
FREEMAIL = "freemail"
|
||||
IMAP_MAIL = "imap_mail"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 应用常量
|
||||
# ============================================================================
|
||||
|
||||
APP_NAME = "OpenAI/Codex CLI 自动注册系统"
|
||||
APP_VERSION = "2.0.0"
|
||||
APP_DESCRIPTION = "自动注册 OpenAI/Codex CLI 账号的系统"
|
||||
|
||||
# ============================================================================
|
||||
# OpenAI OAuth 相关常量
|
||||
# ============================================================================
|
||||
|
||||
# OAuth 参数
|
||||
OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
OAUTH_AUTH_URL = "https://auth.openai.com/oauth/authorize"
|
||||
OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
OAUTH_REDIRECT_URI = "http://localhost:1455/auth/callback"
|
||||
OAUTH_SCOPE = "openid email profile offline_access"
|
||||
|
||||
# OpenAI API 端点
|
||||
OPENAI_API_ENDPOINTS = {
|
||||
"sentinel": "https://sentinel.openai.com/backend-api/sentinel/req",
|
||||
"signup": "https://auth.openai.com/api/accounts/authorize/continue",
|
||||
"register": "https://auth.openai.com/api/accounts/user/register",
|
||||
"password_verify": "https://auth.openai.com/api/accounts/password/verify",
|
||||
"send_otp": "https://auth.openai.com/api/accounts/email-otp/send",
|
||||
"validate_otp": "https://auth.openai.com/api/accounts/email-otp/validate",
|
||||
"create_account": "https://auth.openai.com/api/accounts/create_account",
|
||||
"select_workspace": "https://auth.openai.com/api/accounts/workspace/select",
|
||||
}
|
||||
|
||||
# OpenAI 页面类型(用于判断账号状态)
|
||||
OPENAI_PAGE_TYPES = {
|
||||
"EMAIL_OTP_VERIFICATION": "email_otp_verification", # 已注册账号,需要 OTP 验证
|
||||
"PASSWORD_REGISTRATION": "create_account_password", # 新账号,需要设置密码
|
||||
"LOGIN_PASSWORD": "login_password", # 登录流程,需要输入密码
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 邮箱服务相关常量
|
||||
# ============================================================================
|
||||
|
||||
# Tempmail.lol API 端点
|
||||
TEMPMAIL_API_ENDPOINTS = {
|
||||
"create_inbox": "/inbox/create",
|
||||
"get_inbox": "/inbox",
|
||||
}
|
||||
|
||||
# 自定义域名邮箱 API 端点
|
||||
CUSTOM_DOMAIN_API_ENDPOINTS = {
|
||||
"get_config": "/api/config",
|
||||
"create_email": "/api/emails/generate",
|
||||
"list_emails": "/api/emails",
|
||||
"get_email_messages": "/api/emails/{emailId}",
|
||||
"delete_email": "/api/emails/{emailId}",
|
||||
"get_message": "/api/emails/{emailId}/{messageId}",
|
||||
}
|
||||
|
||||
# 邮箱服务默认配置
|
||||
EMAIL_SERVICE_DEFAULTS = {
|
||||
"tempmail": {
|
||||
"base_url": "https://api.tempmail.lol/v2",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
},
|
||||
"outlook": {
|
||||
"imap_server": "outlook.office365.com",
|
||||
"imap_port": 993,
|
||||
"smtp_server": "smtp.office365.com",
|
||||
"smtp_port": 587,
|
||||
"timeout": 30,
|
||||
},
|
||||
"moe_mail": {
|
||||
"base_url": "", # 需要用户配置
|
||||
"api_key_header": "X-API-Key",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
},
|
||||
"duck_mail": {
|
||||
"base_url": "",
|
||||
"default_domain": "",
|
||||
"password_length": 12,
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
},
|
||||
"freemail": {
|
||||
"base_url": "",
|
||||
"admin_token": "",
|
||||
"domain": "",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
},
|
||||
"imap_mail": {
|
||||
"host": "",
|
||||
"port": 993,
|
||||
"use_ssl": True,
|
||||
"email": "",
|
||||
"password": "",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
}
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 注册流程相关常量
|
||||
# ============================================================================
|
||||
|
||||
# 验证码相关
|
||||
OTP_CODE_PATTERN = r"(?<!\d)(\d{6})(?!\d)"
|
||||
OTP_MAX_ATTEMPTS = 40 # 最大轮询次数
|
||||
|
||||
# 验证码提取正则(增强版)
|
||||
# 简单匹配:任意 6 位数字
|
||||
OTP_CODE_SIMPLE_PATTERN = r"(?<!\d)(\d{6})(?!\d)"
|
||||
# 语义匹配:带上下文的验证码(如 "code is 123456", "验证码 123456")
|
||||
OTP_CODE_SEMANTIC_PATTERN = r'(?:code\s+is|验证码[是为]?\s*[::]?\s*)(\d{6})'
|
||||
|
||||
# OpenAI 验证邮件发件人
|
||||
OPENAI_EMAIL_SENDERS = [
|
||||
"noreply@openai.com",
|
||||
"no-reply@openai.com",
|
||||
"@openai.com", # 精确域名匹配
|
||||
".openai.com", # 子域名匹配(如 otp@tm1.openai.com)
|
||||
]
|
||||
|
||||
# OpenAI 验证邮件关键词
|
||||
OPENAI_VERIFICATION_KEYWORDS = [
|
||||
"verify your email",
|
||||
"verification code",
|
||||
"验证码",
|
||||
"your openai code",
|
||||
"code is",
|
||||
"one-time code",
|
||||
]
|
||||
|
||||
# 密码生成
|
||||
PASSWORD_CHARSET = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
DEFAULT_PASSWORD_LENGTH = 12
|
||||
|
||||
# 用户信息生成(用于注册)
|
||||
|
||||
# 常用英文名
|
||||
FIRST_NAMES = [
|
||||
"James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph", "Thomas", "Charles",
|
||||
"Emma", "Olivia", "Ava", "Isabella", "Sophia", "Mia", "Charlotte", "Amelia", "Harper", "Evelyn",
|
||||
"Alex", "Jordan", "Taylor", "Morgan", "Casey", "Riley", "Jamie", "Avery", "Quinn", "Skyler",
|
||||
"Liam", "Noah", "Ethan", "Lucas", "Mason", "Oliver", "Elijah", "Aiden", "Henry", "Sebastian",
|
||||
"Grace", "Lily", "Chloe", "Zoey", "Nora", "Aria", "Hazel", "Aurora", "Stella", "Ivy"
|
||||
]
|
||||
|
||||
def generate_random_user_info() -> dict:
|
||||
"""
|
||||
生成随机用户信息
|
||||
|
||||
Returns:
|
||||
包含 name 和 birthdate 的字典
|
||||
"""
|
||||
# 随机选择名字
|
||||
name = random.choice(FIRST_NAMES)
|
||||
|
||||
# 生成随机生日(18-45岁)
|
||||
current_year = datetime.now().year
|
||||
birth_year = random.randint(current_year - 45, current_year - 18)
|
||||
birth_month = random.randint(1, 12)
|
||||
# 根据月份确定天数
|
||||
if birth_month in [1, 3, 5, 7, 8, 10, 12]:
|
||||
birth_day = random.randint(1, 31)
|
||||
elif birth_month in [4, 6, 9, 11]:
|
||||
birth_day = random.randint(1, 30)
|
||||
else:
|
||||
# 2月,简化处理
|
||||
birth_day = random.randint(1, 28)
|
||||
|
||||
birthdate = f"{birth_year}-{birth_month:02d}-{birth_day:02d}"
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"birthdate": birthdate
|
||||
}
|
||||
|
||||
# 保留默认值供兼容
|
||||
DEFAULT_USER_INFO = {
|
||||
"name": "Neo",
|
||||
"birthdate": "2000-02-20",
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 代理相关常量
|
||||
# ============================================================================
|
||||
|
||||
PROXY_TYPES = ["http", "socks5", "socks5h"]
|
||||
DEFAULT_PROXY_CONFIG = {
|
||||
"enabled": False,
|
||||
"type": "http",
|
||||
"host": "127.0.0.1",
|
||||
"port": 7890,
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 数据库相关常量
|
||||
# ============================================================================
|
||||
|
||||
# 数据库表名
|
||||
DB_TABLE_NAMES = {
|
||||
"accounts": "accounts",
|
||||
"email_services": "email_services",
|
||||
"registration_tasks": "registration_tasks",
|
||||
"settings": "settings",
|
||||
}
|
||||
|
||||
# 默认设置
|
||||
DEFAULT_SETTINGS = [
|
||||
# (key, value, description, category)
|
||||
("system.name", APP_NAME, "系统名称", "general"),
|
||||
("system.version", APP_VERSION, "系统版本", "general"),
|
||||
("logs.retention_days", "30", "日志保留天数", "general"),
|
||||
("openai.client_id", OAUTH_CLIENT_ID, "OpenAI OAuth Client ID", "openai"),
|
||||
("openai.auth_url", OAUTH_AUTH_URL, "OpenAI 认证地址", "openai"),
|
||||
("openai.token_url", OAUTH_TOKEN_URL, "OpenAI Token 地址", "openai"),
|
||||
("openai.redirect_uri", OAUTH_REDIRECT_URI, "OpenAI 回调地址", "openai"),
|
||||
("openai.scope", OAUTH_SCOPE, "OpenAI 权限范围", "openai"),
|
||||
("proxy.enabled", "false", "是否启用代理", "proxy"),
|
||||
("proxy.type", "http", "代理类型 (http/socks5)", "proxy"),
|
||||
("proxy.host", "127.0.0.1", "代理主机", "proxy"),
|
||||
("proxy.port", "7890", "代理端口", "proxy"),
|
||||
("registration.max_retries", "3", "最大重试次数", "registration"),
|
||||
("registration.timeout", "120", "超时时间(秒)", "registration"),
|
||||
("registration.default_password_length", "12", "默认密码长度", "registration"),
|
||||
("webui.host", "0.0.0.0", "Web UI 监听主机", "webui"),
|
||||
("webui.port", "8000", "Web UI 监听端口", "webui"),
|
||||
("webui.debug", "true", "调试模式", "webui"),
|
||||
]
|
||||
|
||||
# ============================================================================
|
||||
# Web UI 相关常量
|
||||
# ============================================================================
|
||||
|
||||
# WebSocket 事件
|
||||
WEBSOCKET_EVENTS = {
|
||||
"CONNECT": "connect",
|
||||
"DISCONNECT": "disconnect",
|
||||
"LOG": "log",
|
||||
"STATUS": "status",
|
||||
"ERROR": "error",
|
||||
"COMPLETE": "complete",
|
||||
}
|
||||
|
||||
# API 响应状态码
|
||||
API_STATUS_CODES = {
|
||||
"SUCCESS": 200,
|
||||
"CREATED": 201,
|
||||
"BAD_REQUEST": 400,
|
||||
"UNAUTHORIZED": 401,
|
||||
"FORBIDDEN": 403,
|
||||
"NOT_FOUND": 404,
|
||||
"CONFLICT": 409,
|
||||
"INTERNAL_ERROR": 500,
|
||||
}
|
||||
|
||||
# 分页
|
||||
DEFAULT_PAGE_SIZE = 20
|
||||
MAX_PAGE_SIZE = 100
|
||||
|
||||
# ============================================================================
|
||||
# 错误消息
|
||||
# ============================================================================
|
||||
|
||||
ERROR_MESSAGES = {
|
||||
# 通用错误
|
||||
"DATABASE_ERROR": "数据库操作失败",
|
||||
"CONFIG_ERROR": "配置错误",
|
||||
"NETWORK_ERROR": "网络连接失败",
|
||||
"TIMEOUT": "操作超时",
|
||||
"VALIDATION_ERROR": "参数验证失败",
|
||||
|
||||
# 邮箱服务错误
|
||||
"EMAIL_SERVICE_UNAVAILABLE": "邮箱服务不可用",
|
||||
"EMAIL_CREATION_FAILED": "创建邮箱失败",
|
||||
"OTP_NOT_RECEIVED": "未收到验证码",
|
||||
"OTP_INVALID": "验证码无效",
|
||||
|
||||
# OpenAI 相关错误
|
||||
"OPENAI_AUTH_FAILED": "OpenAI 认证失败",
|
||||
"OPENAI_RATE_LIMIT": "OpenAI 接口限流",
|
||||
"OPENAI_CAPTCHA": "遇到验证码",
|
||||
|
||||
# 代理错误
|
||||
"PROXY_FAILED": "代理连接失败",
|
||||
"PROXY_AUTH_FAILED": "代理认证失败",
|
||||
|
||||
# 账户错误
|
||||
"ACCOUNT_NOT_FOUND": "账户不存在",
|
||||
"ACCOUNT_ALREADY_EXISTS": "账户已存在",
|
||||
"ACCOUNT_INVALID": "账户无效",
|
||||
|
||||
# 任务错误
|
||||
"TASK_NOT_FOUND": "任务不存在",
|
||||
"TASK_ALREADY_RUNNING": "任务已在运行中",
|
||||
"TASK_CANCELLED": "任务已取消",
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 正则表达式
|
||||
# ============================================================================
|
||||
|
||||
REGEX_PATTERNS = {
|
||||
"EMAIL": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
|
||||
"URL": r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+",
|
||||
"IP_ADDRESS": r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
|
||||
"OTP_CODE": OTP_CODE_PATTERN,
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# 时间常量
|
||||
# ============================================================================
|
||||
|
||||
TIME_CONSTANTS = {
|
||||
"SECOND": 1,
|
||||
"MINUTE": 60,
|
||||
"HOUR": 3600,
|
||||
"DAY": 86400,
|
||||
"WEEK": 604800,
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Microsoft/Outlook 相关常量
|
||||
# ============================================================================
|
||||
|
||||
# Microsoft OAuth2 Token 端点
|
||||
MICROSOFT_TOKEN_ENDPOINTS = {
|
||||
# 旧版 IMAP 使用的端点
|
||||
"LIVE": "https://login.live.com/oauth20_token.srf",
|
||||
# 新版 IMAP 使用的端点(需要特定 scope)
|
||||
"CONSUMERS": "https://login.microsoftonline.com/consumers/oauth2/v2.0/token",
|
||||
# Graph API 使用的端点
|
||||
"COMMON": "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
}
|
||||
|
||||
# IMAP 服务器配置
|
||||
OUTLOOK_IMAP_SERVERS = {
|
||||
"OLD": "outlook.office365.com", # 旧版 IMAP
|
||||
"NEW": "outlook.live.com", # 新版 IMAP
|
||||
}
|
||||
|
||||
# Microsoft OAuth2 Scopes
|
||||
MICROSOFT_SCOPES = {
|
||||
# 旧版 IMAP 不需要特定 scope
|
||||
"IMAP_OLD": "",
|
||||
# 新版 IMAP 需要的 scope
|
||||
"IMAP_NEW": "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
|
||||
# Graph API 需要的 scope
|
||||
"GRAPH_API": "https://graph.microsoft.com/.default",
|
||||
}
|
||||
|
||||
# Outlook 提供者默认优先级
|
||||
OUTLOOK_PROVIDER_PRIORITY = ["imap_new", "imap_old", "graph_api"]
|
||||
767
src/config/settings.py
Normal file
767
src/config/settings.py
Normal file
@@ -0,0 +1,767 @@
|
||||
"""
|
||||
配置管理 - 完全基于数据库存储
|
||||
所有配置都从数据库读取,不再使用环境变量或 .env 文件
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any, Type, List
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic.types import SecretStr
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class SettingCategory(str, Enum):
|
||||
"""设置分类"""
|
||||
GENERAL = "general"
|
||||
DATABASE = "database"
|
||||
WEBUI = "webui"
|
||||
LOG = "log"
|
||||
OPENAI = "openai"
|
||||
PROXY = "proxy"
|
||||
REGISTRATION = "registration"
|
||||
EMAIL = "email"
|
||||
TEMPMAIL = "tempmail"
|
||||
CUSTOM_DOMAIN = "moe_mail"
|
||||
SECURITY = "security"
|
||||
CPA = "cpa"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SettingDefinition:
|
||||
"""设置定义"""
|
||||
db_key: str
|
||||
default_value: Any
|
||||
category: SettingCategory
|
||||
description: str = ""
|
||||
is_secret: bool = False
|
||||
|
||||
|
||||
# 所有配置项定义(包含数据库键名、默认值、分类、描述)
|
||||
SETTING_DEFINITIONS: Dict[str, SettingDefinition] = {
|
||||
# 应用信息
|
||||
"app_name": SettingDefinition(
|
||||
db_key="app.name",
|
||||
default_value="OpenAI/Codex CLI 自动注册系统",
|
||||
category=SettingCategory.GENERAL,
|
||||
description="应用名称"
|
||||
),
|
||||
"app_version": SettingDefinition(
|
||||
db_key="app.version",
|
||||
default_value="2.0.0",
|
||||
category=SettingCategory.GENERAL,
|
||||
description="应用版本"
|
||||
),
|
||||
"debug": SettingDefinition(
|
||||
db_key="app.debug",
|
||||
default_value=False,
|
||||
category=SettingCategory.GENERAL,
|
||||
description="调试模式"
|
||||
),
|
||||
|
||||
# 数据库配置
|
||||
"database_url": SettingDefinition(
|
||||
db_key="database.url",
|
||||
default_value="data/database.db",
|
||||
category=SettingCategory.DATABASE,
|
||||
description="数据库路径或连接字符串"
|
||||
),
|
||||
|
||||
# Web UI 配置
|
||||
"webui_host": SettingDefinition(
|
||||
db_key="webui.host",
|
||||
default_value="0.0.0.0",
|
||||
category=SettingCategory.WEBUI,
|
||||
description="Web UI 监听地址"
|
||||
),
|
||||
"webui_port": SettingDefinition(
|
||||
db_key="webui.port",
|
||||
default_value=8000,
|
||||
category=SettingCategory.WEBUI,
|
||||
description="Web UI 监听端口"
|
||||
),
|
||||
"webui_secret_key": SettingDefinition(
|
||||
db_key="webui.secret_key",
|
||||
default_value="your-secret-key-change-in-production",
|
||||
category=SettingCategory.WEBUI,
|
||||
description="Web UI 密钥",
|
||||
is_secret=True
|
||||
),
|
||||
"webui_access_password": SettingDefinition(
|
||||
db_key="webui.access_password",
|
||||
default_value="admin123",
|
||||
category=SettingCategory.WEBUI,
|
||||
description="Web UI 访问密码",
|
||||
is_secret=True
|
||||
),
|
||||
|
||||
# 日志配置
|
||||
"log_level": SettingDefinition(
|
||||
db_key="log.level",
|
||||
default_value="INFO",
|
||||
category=SettingCategory.LOG,
|
||||
description="日志级别"
|
||||
),
|
||||
"log_file": SettingDefinition(
|
||||
db_key="log.file",
|
||||
default_value="logs/app.log",
|
||||
category=SettingCategory.LOG,
|
||||
description="日志文件路径"
|
||||
),
|
||||
"log_retention_days": SettingDefinition(
|
||||
db_key="log.retention_days",
|
||||
default_value=30,
|
||||
category=SettingCategory.LOG,
|
||||
description="日志保留天数"
|
||||
),
|
||||
|
||||
# OpenAI 配置
|
||||
"openai_client_id": SettingDefinition(
|
||||
db_key="openai.client_id",
|
||||
default_value="app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
category=SettingCategory.OPENAI,
|
||||
description="OpenAI OAuth 客户端 ID"
|
||||
),
|
||||
"openai_auth_url": SettingDefinition(
|
||||
db_key="openai.auth_url",
|
||||
default_value="https://auth.openai.com/oauth/authorize",
|
||||
category=SettingCategory.OPENAI,
|
||||
description="OpenAI OAuth 授权 URL"
|
||||
),
|
||||
"openai_token_url": SettingDefinition(
|
||||
db_key="openai.token_url",
|
||||
default_value="https://auth.openai.com/oauth/token",
|
||||
category=SettingCategory.OPENAI,
|
||||
description="OpenAI OAuth Token URL"
|
||||
),
|
||||
"openai_redirect_uri": SettingDefinition(
|
||||
db_key="openai.redirect_uri",
|
||||
default_value="http://localhost:1455/auth/callback",
|
||||
category=SettingCategory.OPENAI,
|
||||
description="OpenAI OAuth 回调 URI"
|
||||
),
|
||||
"openai_scope": SettingDefinition(
|
||||
db_key="openai.scope",
|
||||
default_value="openid email profile offline_access",
|
||||
category=SettingCategory.OPENAI,
|
||||
description="OpenAI OAuth 权限范围"
|
||||
),
|
||||
|
||||
# 代理配置
|
||||
"proxy_enabled": SettingDefinition(
|
||||
db_key="proxy.enabled",
|
||||
default_value=False,
|
||||
category=SettingCategory.PROXY,
|
||||
description="是否启用代理"
|
||||
),
|
||||
"proxy_type": SettingDefinition(
|
||||
db_key="proxy.type",
|
||||
default_value="http",
|
||||
category=SettingCategory.PROXY,
|
||||
description="代理类型 (http/socks5)"
|
||||
),
|
||||
"proxy_host": SettingDefinition(
|
||||
db_key="proxy.host",
|
||||
default_value="127.0.0.1",
|
||||
category=SettingCategory.PROXY,
|
||||
description="代理服务器地址"
|
||||
),
|
||||
"proxy_port": SettingDefinition(
|
||||
db_key="proxy.port",
|
||||
default_value=7890,
|
||||
category=SettingCategory.PROXY,
|
||||
description="代理服务器端口"
|
||||
),
|
||||
"proxy_username": SettingDefinition(
|
||||
db_key="proxy.username",
|
||||
default_value="",
|
||||
category=SettingCategory.PROXY,
|
||||
description="代理用户名"
|
||||
),
|
||||
"proxy_password": SettingDefinition(
|
||||
db_key="proxy.password",
|
||||
default_value="",
|
||||
category=SettingCategory.PROXY,
|
||||
description="代理密码",
|
||||
is_secret=True
|
||||
),
|
||||
"proxy_dynamic_enabled": SettingDefinition(
|
||||
db_key="proxy.dynamic_enabled",
|
||||
default_value=False,
|
||||
category=SettingCategory.PROXY,
|
||||
description="是否启用动态代理"
|
||||
),
|
||||
"proxy_dynamic_api_url": SettingDefinition(
|
||||
db_key="proxy.dynamic_api_url",
|
||||
default_value="",
|
||||
category=SettingCategory.PROXY,
|
||||
description="动态代理 API 地址,返回代理 URL 字符串"
|
||||
),
|
||||
"proxy_dynamic_api_key": SettingDefinition(
|
||||
db_key="proxy.dynamic_api_key",
|
||||
default_value="",
|
||||
category=SettingCategory.PROXY,
|
||||
description="动态代理 API 密钥(可选)",
|
||||
is_secret=True
|
||||
),
|
||||
"proxy_dynamic_api_key_header": SettingDefinition(
|
||||
db_key="proxy.dynamic_api_key_header",
|
||||
default_value="X-API-Key",
|
||||
category=SettingCategory.PROXY,
|
||||
description="动态代理 API 密钥请求头名称"
|
||||
),
|
||||
"proxy_dynamic_result_field": SettingDefinition(
|
||||
db_key="proxy.dynamic_result_field",
|
||||
default_value="",
|
||||
category=SettingCategory.PROXY,
|
||||
description="从 JSON 响应中提取代理 URL 的字段路径(留空则使用响应原文)"
|
||||
),
|
||||
|
||||
# 注册配置
|
||||
"registration_max_retries": SettingDefinition(
|
||||
db_key="registration.max_retries",
|
||||
default_value=3,
|
||||
category=SettingCategory.REGISTRATION,
|
||||
description="注册最大重试次数"
|
||||
),
|
||||
"registration_timeout": SettingDefinition(
|
||||
db_key="registration.timeout",
|
||||
default_value=120,
|
||||
category=SettingCategory.REGISTRATION,
|
||||
description="注册超时时间(秒)"
|
||||
),
|
||||
"registration_default_password_length": SettingDefinition(
|
||||
db_key="registration.default_password_length",
|
||||
default_value=12,
|
||||
category=SettingCategory.REGISTRATION,
|
||||
description="默认密码长度"
|
||||
),
|
||||
"registration_sleep_min": SettingDefinition(
|
||||
db_key="registration.sleep_min",
|
||||
default_value=5,
|
||||
category=SettingCategory.REGISTRATION,
|
||||
description="注册间隔最小值(秒)"
|
||||
),
|
||||
"registration_sleep_max": SettingDefinition(
|
||||
db_key="registration.sleep_max",
|
||||
default_value=30,
|
||||
category=SettingCategory.REGISTRATION,
|
||||
description="注册间隔最大值(秒)"
|
||||
),
|
||||
|
||||
# 邮箱服务配置
|
||||
"email_service_priority": SettingDefinition(
|
||||
db_key="email.service_priority",
|
||||
default_value={"tempmail": 0, "outlook": 1, "moe_mail": 2},
|
||||
category=SettingCategory.EMAIL,
|
||||
description="邮箱服务优先级"
|
||||
),
|
||||
|
||||
# Tempmail.lol 配置
|
||||
"tempmail_base_url": SettingDefinition(
|
||||
db_key="tempmail.base_url",
|
||||
default_value="https://api.tempmail.lol/v2",
|
||||
category=SettingCategory.TEMPMAIL,
|
||||
description="Tempmail API 地址"
|
||||
),
|
||||
"tempmail_timeout": SettingDefinition(
|
||||
db_key="tempmail.timeout",
|
||||
default_value=30,
|
||||
category=SettingCategory.TEMPMAIL,
|
||||
description="Tempmail 超时时间(秒)"
|
||||
),
|
||||
"tempmail_max_retries": SettingDefinition(
|
||||
db_key="tempmail.max_retries",
|
||||
default_value=3,
|
||||
category=SettingCategory.TEMPMAIL,
|
||||
description="Tempmail 最大重试次数"
|
||||
),
|
||||
|
||||
# 自定义域名邮箱配置
|
||||
"custom_domain_base_url": SettingDefinition(
|
||||
db_key="custom_domain.base_url",
|
||||
default_value="",
|
||||
category=SettingCategory.CUSTOM_DOMAIN,
|
||||
description="自定义域名 API 地址"
|
||||
),
|
||||
"custom_domain_api_key": SettingDefinition(
|
||||
db_key="custom_domain.api_key",
|
||||
default_value="",
|
||||
category=SettingCategory.CUSTOM_DOMAIN,
|
||||
description="自定义域名 API 密钥",
|
||||
is_secret=True
|
||||
),
|
||||
|
||||
# 安全配置
|
||||
"encryption_key": SettingDefinition(
|
||||
db_key="security.encryption_key",
|
||||
default_value="your-encryption-key-change-in-production",
|
||||
category=SettingCategory.SECURITY,
|
||||
description="加密密钥",
|
||||
is_secret=True
|
||||
),
|
||||
|
||||
# Team Manager 配置
|
||||
"tm_enabled": SettingDefinition(
|
||||
db_key="tm.enabled",
|
||||
default_value=False,
|
||||
category=SettingCategory.GENERAL,
|
||||
description="是否启用 Team Manager 上传"
|
||||
),
|
||||
"tm_api_url": SettingDefinition(
|
||||
db_key="tm.api_url",
|
||||
default_value="",
|
||||
category=SettingCategory.GENERAL,
|
||||
description="Team Manager API 地址"
|
||||
),
|
||||
"tm_api_key": SettingDefinition(
|
||||
db_key="tm.api_key",
|
||||
default_value="",
|
||||
category=SettingCategory.GENERAL,
|
||||
description="Team Manager API Key",
|
||||
is_secret=True
|
||||
),
|
||||
|
||||
# CPA 上传配置
|
||||
"cpa_enabled": SettingDefinition(
|
||||
db_key="cpa.enabled",
|
||||
default_value=False,
|
||||
category=SettingCategory.CPA,
|
||||
description="是否启用 CPA 上传"
|
||||
),
|
||||
"cpa_api_url": SettingDefinition(
|
||||
db_key="cpa.api_url",
|
||||
default_value="",
|
||||
category=SettingCategory.CPA,
|
||||
description="CPA API 地址"
|
||||
),
|
||||
"cpa_api_token": SettingDefinition(
|
||||
db_key="cpa.api_token",
|
||||
default_value="",
|
||||
category=SettingCategory.CPA,
|
||||
description="CPA API Token",
|
||||
is_secret=True
|
||||
),
|
||||
|
||||
# 验证码配置
|
||||
"email_code_timeout": SettingDefinition(
|
||||
db_key="email_code.timeout",
|
||||
default_value=120,
|
||||
category=SettingCategory.EMAIL,
|
||||
description="验证码等待超时时间(秒)"
|
||||
),
|
||||
"email_code_poll_interval": SettingDefinition(
|
||||
db_key="email_code.poll_interval",
|
||||
default_value=3,
|
||||
category=SettingCategory.EMAIL,
|
||||
description="验证码轮询间隔(秒)"
|
||||
),
|
||||
|
||||
# Outlook 配置
|
||||
"outlook_provider_priority": SettingDefinition(
|
||||
db_key="outlook.provider_priority",
|
||||
default_value=["imap_old", "imap_new", "graph_api"],
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook 提供者优先级"
|
||||
),
|
||||
"outlook_health_failure_threshold": SettingDefinition(
|
||||
db_key="outlook.health_failure_threshold",
|
||||
default_value=5,
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook 提供者连续失败次数阈值"
|
||||
),
|
||||
"outlook_health_disable_duration": SettingDefinition(
|
||||
db_key="outlook.health_disable_duration",
|
||||
default_value=60,
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook 提供者禁用时长(秒)"
|
||||
),
|
||||
"outlook_default_client_id": SettingDefinition(
|
||||
db_key="outlook.default_client_id",
|
||||
default_value="24d9a0ed-8787-4584-883c-2fd79308940a",
|
||||
category=SettingCategory.EMAIL,
|
||||
description="Outlook OAuth 默认 Client ID"
|
||||
),
|
||||
}
|
||||
|
||||
# 属性名到数据库键名的映射(用于向后兼容)
|
||||
DB_SETTING_KEYS = {name: defn.db_key for name, defn in SETTING_DEFINITIONS.items()}
|
||||
|
||||
# 类型定义映射
|
||||
SETTING_TYPES: Dict[str, Type] = {
|
||||
"debug": bool,
|
||||
"webui_port": int,
|
||||
"log_retention_days": int,
|
||||
"proxy_enabled": bool,
|
||||
"proxy_port": int,
|
||||
"proxy_dynamic_enabled": bool,
|
||||
"registration_max_retries": int,
|
||||
"registration_timeout": int,
|
||||
"registration_default_password_length": int,
|
||||
"registration_sleep_min": int,
|
||||
"registration_sleep_max": int,
|
||||
"email_service_priority": dict,
|
||||
"tempmail_timeout": int,
|
||||
"tempmail_max_retries": int,
|
||||
"tm_enabled": bool,
|
||||
"cpa_enabled": bool,
|
||||
"email_code_timeout": int,
|
||||
"email_code_poll_interval": int,
|
||||
"outlook_provider_priority": list,
|
||||
"outlook_health_failure_threshold": int,
|
||||
"outlook_health_disable_duration": int,
|
||||
}
|
||||
|
||||
# 需要作为 SecretStr 处理的字段
|
||||
SECRET_FIELDS = {name for name, defn in SETTING_DEFINITIONS.items() if defn.is_secret}
|
||||
|
||||
|
||||
def _convert_value(attr_name: str, value: str) -> Any:
|
||||
"""将数据库字符串值转换为正确的类型"""
|
||||
if attr_name in SECRET_FIELDS:
|
||||
return SecretStr(value) if value else SecretStr("")
|
||||
|
||||
target_type = SETTING_TYPES.get(attr_name, str)
|
||||
|
||||
if target_type == bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return str(value).lower() in ("true", "1", "yes", "on")
|
||||
elif target_type == int:
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return int(value) if value else 0
|
||||
elif target_type == dict:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if not value:
|
||||
return {}
|
||||
import json
|
||||
import ast
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except Exception:
|
||||
return {}
|
||||
elif target_type == list:
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if not value:
|
||||
return []
|
||||
import json
|
||||
import ast
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except Exception:
|
||||
return []
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def _normalize_database_url(url: str) -> str:
|
||||
if url.startswith("postgres://"):
|
||||
return "postgresql+psycopg://" + url[len("postgres://"):]
|
||||
if url.startswith("postgresql://"):
|
||||
return "postgresql+psycopg://" + url[len("postgresql://"):]
|
||||
return url
|
||||
|
||||
|
||||
def _value_to_string(value: Any) -> str:
|
||||
"""将值转换为数据库存储的字符串"""
|
||||
if isinstance(value, SecretStr):
|
||||
return value.get_secret_value()
|
||||
elif isinstance(value, bool):
|
||||
return "true" if value else "false"
|
||||
elif isinstance(value, (dict, list)):
|
||||
import json
|
||||
return json.dumps(value)
|
||||
elif value is None:
|
||||
return ""
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
|
||||
def init_default_settings() -> None:
|
||||
"""
|
||||
初始化数据库中的默认设置
|
||||
如果设置项不存在,则创建并设置默认值
|
||||
"""
|
||||
try:
|
||||
from ..database.session import get_db
|
||||
from ..database.crud import get_setting, set_setting
|
||||
|
||||
with get_db() as db:
|
||||
for attr_name, defn in SETTING_DEFINITIONS.items():
|
||||
existing = get_setting(db, defn.db_key)
|
||||
if not existing:
|
||||
default_value = defn.default_value
|
||||
if attr_name == "database_url":
|
||||
env_url = os.environ.get("APP_DATABASE_URL") or os.environ.get("DATABASE_URL")
|
||||
if env_url:
|
||||
default_value = _normalize_database_url(env_url)
|
||||
default_value = _value_to_string(default_value)
|
||||
set_setting(
|
||||
db,
|
||||
defn.db_key,
|
||||
default_value,
|
||||
category=defn.category.value,
|
||||
description=defn.description
|
||||
)
|
||||
print(f"[Settings] 初始化默认设置: {defn.db_key} = {default_value if not defn.is_secret else '***'}")
|
||||
except Exception as e:
|
||||
if "未初始化" not in str(e):
|
||||
print(f"[Settings] 初始化默认设置失败: {e}")
|
||||
|
||||
|
||||
def _load_settings_from_db() -> Dict[str, Any]:
|
||||
"""从数据库加载所有设置"""
|
||||
try:
|
||||
from ..database.session import get_db
|
||||
from ..database.crud import get_setting
|
||||
|
||||
settings_dict = {}
|
||||
with get_db() as db:
|
||||
for attr_name, defn in SETTING_DEFINITIONS.items():
|
||||
db_setting = get_setting(db, defn.db_key)
|
||||
if db_setting:
|
||||
settings_dict[attr_name] = _convert_value(attr_name, db_setting.value)
|
||||
else:
|
||||
# 数据库中没有此设置,使用默认值
|
||||
settings_dict[attr_name] = _convert_value(attr_name, _value_to_string(defn.default_value))
|
||||
env_url = os.environ.get("APP_DATABASE_URL") or os.environ.get("DATABASE_URL")
|
||||
if env_url:
|
||||
settings_dict["database_url"] = _normalize_database_url(env_url)
|
||||
env_host = os.environ.get("APP_HOST")
|
||||
if env_host:
|
||||
settings_dict["webui_host"] = env_host
|
||||
env_port = os.environ.get("APP_PORT")
|
||||
if env_port:
|
||||
try:
|
||||
settings_dict["webui_port"] = int(env_port)
|
||||
except ValueError:
|
||||
pass
|
||||
env_password = os.environ.get("APP_ACCESS_PASSWORD")
|
||||
if env_password:
|
||||
settings_dict["webui_access_password"] = env_password
|
||||
return settings_dict
|
||||
except Exception as e:
|
||||
if "未初始化" not in str(e):
|
||||
print(f"[Settings] 从数据库加载设置失败: {e},使用默认值")
|
||||
return {name: defn.default_value for name, defn in SETTING_DEFINITIONS.items()}
|
||||
|
||||
|
||||
def _save_settings_to_db(**kwargs) -> None:
|
||||
"""保存设置到数据库"""
|
||||
try:
|
||||
from ..database.session import get_db
|
||||
from ..database.crud import set_setting
|
||||
|
||||
with get_db() as db:
|
||||
for attr_name, value in kwargs.items():
|
||||
if attr_name in SETTING_DEFINITIONS:
|
||||
defn = SETTING_DEFINITIONS[attr_name]
|
||||
str_value = _value_to_string(value)
|
||||
set_setting(
|
||||
db,
|
||||
defn.db_key,
|
||||
str_value,
|
||||
category=defn.category.value,
|
||||
description=defn.description
|
||||
)
|
||||
except Exception as e:
|
||||
if "未初始化" not in str(e):
|
||||
print(f"[Settings] 保存设置到数据库失败: {e}")
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""
|
||||
应用配置 - 完全基于数据库存储
|
||||
"""
|
||||
|
||||
# 应用信息
|
||||
app_name: str = "OpenAI/Codex CLI 自动注册系统"
|
||||
app_version: str = "2.0.0"
|
||||
debug: bool = False
|
||||
|
||||
# 数据库配置
|
||||
database_url: str = "data/database.db"
|
||||
|
||||
@field_validator('database_url', mode='before')
|
||||
@classmethod
|
||||
def validate_database_url(cls, v):
|
||||
if isinstance(v, str):
|
||||
if v.startswith(("postgres://", "postgresql://")):
|
||||
return _normalize_database_url(v)
|
||||
if v.startswith(("postgresql+psycopg://", "postgresql+psycopg2://")):
|
||||
return v
|
||||
if isinstance(v, str) and v.startswith("sqlite:///"):
|
||||
return v
|
||||
if isinstance(v, str) and not v.startswith(("sqlite:///", "postgresql://", "postgresql+psycopg://", "postgresql+psycopg2://", "mysql://")):
|
||||
# 如果是文件路径,转换为 SQLite URL
|
||||
if os.path.isabs(v) or ":/" not in v:
|
||||
return f"sqlite:///{v}"
|
||||
return v
|
||||
|
||||
# Web UI 配置
|
||||
webui_host: str = "0.0.0.0"
|
||||
webui_port: int = 8000
|
||||
webui_secret_key: SecretStr = SecretStr("your-secret-key-change-in-production")
|
||||
webui_access_password: SecretStr = SecretStr("admin123")
|
||||
|
||||
# 日志配置
|
||||
log_level: str = "INFO"
|
||||
log_file: str = "logs/app.log"
|
||||
log_retention_days: int = 30
|
||||
|
||||
# OpenAI 配置
|
||||
openai_client_id: str = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
openai_auth_url: str = "https://auth.openai.com/oauth/authorize"
|
||||
openai_token_url: str = "https://auth.openai.com/oauth/token"
|
||||
openai_redirect_uri: str = "http://localhost:1455/auth/callback"
|
||||
openai_scope: str = "openid email profile offline_access"
|
||||
|
||||
# 代理配置
|
||||
proxy_enabled: bool = False
|
||||
proxy_type: str = "http"
|
||||
proxy_host: str = "127.0.0.1"
|
||||
proxy_port: int = 7890
|
||||
proxy_username: Optional[str] = None
|
||||
proxy_password: Optional[SecretStr] = None
|
||||
proxy_dynamic_enabled: bool = False
|
||||
proxy_dynamic_api_url: str = ""
|
||||
proxy_dynamic_api_key: Optional[SecretStr] = None
|
||||
proxy_dynamic_api_key_header: str = "X-API-Key"
|
||||
proxy_dynamic_result_field: str = ""
|
||||
|
||||
@property
|
||||
def proxy_url(self) -> Optional[str]:
|
||||
"""获取完整的代理 URL"""
|
||||
if not self.proxy_enabled:
|
||||
return None
|
||||
|
||||
if self.proxy_type == "http":
|
||||
scheme = "http"
|
||||
elif self.proxy_type == "socks5":
|
||||
scheme = "socks5"
|
||||
else:
|
||||
return None
|
||||
|
||||
auth = ""
|
||||
if self.proxy_username and self.proxy_password:
|
||||
auth = f"{self.proxy_username}:{self.proxy_password.get_secret_value()}@"
|
||||
|
||||
return f"{scheme}://{auth}{self.proxy_host}:{self.proxy_port}"
|
||||
|
||||
# 注册配置
|
||||
registration_max_retries: int = 3
|
||||
registration_timeout: int = 120
|
||||
registration_default_password_length: int = 12
|
||||
registration_sleep_min: int = 5
|
||||
registration_sleep_max: int = 30
|
||||
|
||||
# 邮箱服务配置
|
||||
email_service_priority: Dict[str, int] = {"tempmail": 0, "outlook": 1, "moe_mail": 2}
|
||||
|
||||
# Tempmail.lol 配置
|
||||
tempmail_base_url: str = "https://api.tempmail.lol/v2"
|
||||
tempmail_timeout: int = 30
|
||||
tempmail_max_retries: int = 3
|
||||
|
||||
# 自定义域名邮箱配置
|
||||
custom_domain_base_url: str = ""
|
||||
custom_domain_api_key: Optional[SecretStr] = None
|
||||
|
||||
# 安全配置
|
||||
encryption_key: SecretStr = SecretStr("your-encryption-key-change-in-production")
|
||||
|
||||
# Team Manager 配置
|
||||
tm_enabled: bool = False
|
||||
tm_api_url: str = ""
|
||||
tm_api_key: Optional[SecretStr] = None
|
||||
|
||||
# CPA 上传配置
|
||||
cpa_enabled: bool = False
|
||||
cpa_api_url: str = ""
|
||||
cpa_api_token: SecretStr = SecretStr("")
|
||||
|
||||
# 验证码配置
|
||||
email_code_timeout: int = 120
|
||||
email_code_poll_interval: int = 3
|
||||
|
||||
# Outlook 配置
|
||||
outlook_provider_priority: List[str] = ["imap_old", "imap_new", "graph_api"]
|
||||
outlook_health_failure_threshold: int = 5
|
||||
outlook_health_disable_duration: int = 60
|
||||
outlook_default_client_id: str = "24d9a0ed-8787-4584-883c-2fd79308940a"
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
_settings: Optional[Settings] = None
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""
|
||||
获取全局配置实例(单例模式)
|
||||
完全从数据库加载配置
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
# 先初始化默认设置(如果数据库中没有的话)
|
||||
init_default_settings()
|
||||
# 从数据库加载所有设置
|
||||
settings_dict = _load_settings_from_db()
|
||||
_settings = Settings(**settings_dict)
|
||||
return _settings
|
||||
|
||||
|
||||
def update_settings(**kwargs) -> Settings:
|
||||
"""
|
||||
更新配置并保存到数据库
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = get_settings()
|
||||
|
||||
# 创建新的配置实例
|
||||
updated_data = _settings.model_dump()
|
||||
updated_data.update(kwargs)
|
||||
_settings = Settings(**updated_data)
|
||||
|
||||
# 保存到数据库
|
||||
_save_settings_to_db(**kwargs)
|
||||
|
||||
return _settings
|
||||
|
||||
|
||||
def get_database_url() -> str:
|
||||
"""
|
||||
获取数据库 URL(处理相对路径)
|
||||
"""
|
||||
settings = get_settings()
|
||||
url = settings.database_url
|
||||
|
||||
# 如果 URL 是相对路径,转换为绝对路径
|
||||
if url.startswith("sqlite:///"):
|
||||
path = url[10:] # 移除 "sqlite:///"
|
||||
if not os.path.isabs(path):
|
||||
# 转换为相对于项目根目录的路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
abs_path = os.path.join(project_root, path)
|
||||
return f"sqlite:///{abs_path}"
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def get_setting_definition(attr_name: str) -> Optional[SettingDefinition]:
|
||||
"""获取设置项的定义信息"""
|
||||
return SETTING_DEFINITIONS.get(attr_name)
|
||||
|
||||
|
||||
def get_all_setting_definitions() -> Dict[str, SettingDefinition]:
|
||||
"""获取所有设置项的定义"""
|
||||
return SETTING_DEFINITIONS.copy()
|
||||
32
src/core/__init__.py
Normal file
32
src/core/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
核心功能模块
|
||||
"""
|
||||
|
||||
from .openai.oauth import OAuthManager, OAuthStart, generate_oauth_url, submit_callback_url
|
||||
from .http_client import (
|
||||
OpenAIHTTPClient,
|
||||
HTTPClient,
|
||||
HTTPClientError,
|
||||
RequestConfig,
|
||||
create_http_client,
|
||||
create_openai_client,
|
||||
)
|
||||
from .register import RegistrationEngine, RegistrationResult
|
||||
from .utils import setup_logging, get_data_dir
|
||||
|
||||
__all__ = [
|
||||
'OAuthManager',
|
||||
'OAuthStart',
|
||||
'generate_oauth_url',
|
||||
'submit_callback_url',
|
||||
'OpenAIHTTPClient',
|
||||
'HTTPClient',
|
||||
'HTTPClientError',
|
||||
'RequestConfig',
|
||||
'create_http_client',
|
||||
'create_openai_client',
|
||||
'RegistrationEngine',
|
||||
'RegistrationResult',
|
||||
'setup_logging',
|
||||
'get_data_dir',
|
||||
]
|
||||
118
src/core/dynamic_proxy.py
Normal file
118
src/core/dynamic_proxy.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
动态代理获取模块
|
||||
支持通过外部 API 获取动态代理 URL
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fetch_dynamic_proxy(api_url: str, api_key: str = "", api_key_header: str = "X-API-Key", result_field: str = "") -> Optional[str]:
|
||||
"""
|
||||
从代理 API 获取代理 URL
|
||||
|
||||
Args:
|
||||
api_url: 代理 API 地址,响应应为代理 URL 字符串或含代理 URL 的 JSON
|
||||
api_key: API 密钥(可选)
|
||||
api_key_header: API 密钥请求头名称
|
||||
result_field: 从 JSON 响应中提取代理 URL 的字段路径,支持点号分隔(如 "data.proxy"),留空则使用响应原文
|
||||
|
||||
Returns:
|
||||
代理 URL 字符串(如 http://user:pass@host:port),失败返回 None
|
||||
"""
|
||||
try:
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers[api_key_header] = api_key
|
||||
|
||||
response = cffi_requests.get(
|
||||
api_url,
|
||||
headers=headers,
|
||||
timeout=10,
|
||||
impersonate="chrome110"
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"动态代理 API 返回错误状态码: {response.status_code}")
|
||||
return None
|
||||
|
||||
text = response.text.strip()
|
||||
|
||||
# 尝试解析 JSON
|
||||
if result_field or text.startswith("{") or text.startswith("["):
|
||||
try:
|
||||
import json
|
||||
data = json.loads(text)
|
||||
if result_field:
|
||||
# 按点号路径逐层提取
|
||||
for key in result_field.split("."):
|
||||
if isinstance(data, dict):
|
||||
data = data.get(key)
|
||||
elif isinstance(data, list) and key.isdigit():
|
||||
data = data[int(key)]
|
||||
else:
|
||||
data = None
|
||||
if data is None:
|
||||
break
|
||||
proxy_url = str(data).strip() if data is not None else None
|
||||
else:
|
||||
# 无指定字段,尝试常见键名
|
||||
for key in ("proxy", "url", "proxy_url", "data", "ip"):
|
||||
val = data.get(key) if isinstance(data, dict) else None
|
||||
if val:
|
||||
proxy_url = str(val).strip()
|
||||
break
|
||||
else:
|
||||
proxy_url = text
|
||||
except (ValueError, AttributeError):
|
||||
proxy_url = text
|
||||
else:
|
||||
proxy_url = text
|
||||
|
||||
if not proxy_url:
|
||||
logger.warning("动态代理 API 返回空代理 URL")
|
||||
return None
|
||||
|
||||
# 若未包含协议头,默认加 http://
|
||||
if not re.match(r'^(http|socks5)://', proxy_url):
|
||||
proxy_url = "http://" + proxy_url
|
||||
|
||||
logger.info(f"动态代理获取成功: {proxy_url[:40]}..." if len(proxy_url) > 40 else f"动态代理获取成功: {proxy_url}")
|
||||
return proxy_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取动态代理失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_proxy_url_for_task() -> Optional[str]:
|
||||
"""
|
||||
为注册任务获取代理 URL。
|
||||
优先使用动态代理(若启用),否则使用静态代理配置。
|
||||
|
||||
Returns:
|
||||
代理 URL 或 None
|
||||
"""
|
||||
from ..config.settings import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
# 优先使用动态代理
|
||||
if settings.proxy_dynamic_enabled and settings.proxy_dynamic_api_url:
|
||||
api_key = settings.proxy_dynamic_api_key.get_secret_value() if settings.proxy_dynamic_api_key else ""
|
||||
proxy_url = fetch_dynamic_proxy(
|
||||
api_url=settings.proxy_dynamic_api_url,
|
||||
api_key=api_key,
|
||||
api_key_header=settings.proxy_dynamic_api_key_header,
|
||||
result_field=settings.proxy_dynamic_result_field,
|
||||
)
|
||||
if proxy_url:
|
||||
return proxy_url
|
||||
logger.warning("动态代理获取失败,回退到静态代理")
|
||||
|
||||
# 使用静态代理
|
||||
return settings.proxy_url
|
||||
429
src/core/http_client.py
Normal file
429
src/core/http_client.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
HTTP 客户端封装
|
||||
基于 curl_cffi 的 HTTP 请求封装,支持代理和错误处理
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Optional, Dict, Any, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
from curl_cffi.requests import Session, Response
|
||||
|
||||
from ..config.constants import ERROR_MESSAGES
|
||||
from ..config.settings import get_settings
|
||||
from .openai.sentinel import SentinelPOWError, build_sentinel_pow_token
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestConfig:
|
||||
"""HTTP 请求配置"""
|
||||
timeout: int = 30
|
||||
max_retries: int = 3
|
||||
retry_delay: float = 1.0
|
||||
impersonate: str = "chrome"
|
||||
verify_ssl: bool = True
|
||||
follow_redirects: bool = True
|
||||
|
||||
|
||||
class HTTPClientError(Exception):
|
||||
"""HTTP 客户端异常"""
|
||||
pass
|
||||
|
||||
|
||||
class HTTPClient:
|
||||
"""
|
||||
HTTP 客户端封装
|
||||
支持代理、重试、错误处理和会话管理
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None,
|
||||
session: Optional[Session] = None
|
||||
):
|
||||
"""
|
||||
初始化 HTTP 客户端
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL,如 "http://127.0.0.1:7890"
|
||||
config: 请求配置
|
||||
session: 可重用的会话对象
|
||||
"""
|
||||
self.proxy_url = proxy_url
|
||||
self.config = config or RequestConfig()
|
||||
self._session = session
|
||||
|
||||
@property
|
||||
def proxies(self) -> Optional[Dict[str, str]]:
|
||||
"""获取代理配置"""
|
||||
if not self.proxy_url:
|
||||
return None
|
||||
return {
|
||||
"http": self.proxy_url,
|
||||
"https": self.proxy_url,
|
||||
}
|
||||
|
||||
@property
|
||||
def session(self) -> Session:
|
||||
"""获取会话对象(单例)"""
|
||||
if self._session is None:
|
||||
self._session = Session(
|
||||
proxies=self.proxies,
|
||||
impersonate=self.config.impersonate,
|
||||
verify=self.config.verify_ssl,
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
return self._session
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
**kwargs
|
||||
) -> Response:
|
||||
"""
|
||||
发送 HTTP 请求
|
||||
|
||||
Args:
|
||||
method: HTTP 方法 (GET, POST, PUT, DELETE, etc.)
|
||||
url: 请求 URL
|
||||
**kwargs: 其他请求参数
|
||||
|
||||
Returns:
|
||||
Response 对象
|
||||
|
||||
Raises:
|
||||
HTTPClientError: 请求失败
|
||||
"""
|
||||
# 设置默认参数
|
||||
kwargs.setdefault("timeout", self.config.timeout)
|
||||
kwargs.setdefault("allow_redirects", self.config.follow_redirects)
|
||||
|
||||
# 添加代理配置
|
||||
if self.proxies and "proxies" not in kwargs:
|
||||
kwargs["proxies"] = self.proxies
|
||||
|
||||
last_exception = None
|
||||
for attempt in range(self.config.max_retries):
|
||||
try:
|
||||
response = self.session.request(method, url, **kwargs)
|
||||
|
||||
# 检查响应状态码
|
||||
if response.status_code >= 400:
|
||||
logger.warning(
|
||||
f"HTTP {response.status_code} for {method} {url}"
|
||||
f" (attempt {attempt + 1}/{self.config.max_retries})"
|
||||
)
|
||||
|
||||
# 如果是服务器错误,重试
|
||||
if response.status_code >= 500 and attempt < self.config.max_retries - 1:
|
||||
time.sleep(self.config.retry_delay * (attempt + 1))
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
except (cffi_requests.RequestsError, ConnectionError, TimeoutError) as e:
|
||||
last_exception = e
|
||||
logger.warning(
|
||||
f"请求失败: {method} {url} (attempt {attempt + 1}/{self.config.max_retries}): {e}"
|
||||
)
|
||||
|
||||
if attempt < self.config.max_retries - 1:
|
||||
time.sleep(self.config.retry_delay * (attempt + 1))
|
||||
else:
|
||||
break
|
||||
|
||||
raise HTTPClientError(
|
||||
f"请求失败,最大重试次数已达: {method} {url} - {last_exception}"
|
||||
)
|
||||
|
||||
def get(self, url: str, **kwargs) -> Response:
|
||||
"""发送 GET 请求"""
|
||||
return self.request("GET", url, **kwargs)
|
||||
|
||||
def post(self, url: str, data: Any = None, json: Any = None, **kwargs) -> Response:
|
||||
"""发送 POST 请求"""
|
||||
return self.request("POST", url, data=data, json=json, **kwargs)
|
||||
|
||||
def put(self, url: str, data: Any = None, json: Any = None, **kwargs) -> Response:
|
||||
"""发送 PUT 请求"""
|
||||
return self.request("PUT", url, data=data, json=json, **kwargs)
|
||||
|
||||
def delete(self, url: str, **kwargs) -> Response:
|
||||
"""发送 DELETE 请求"""
|
||||
return self.request("DELETE", url, **kwargs)
|
||||
|
||||
def head(self, url: str, **kwargs) -> Response:
|
||||
"""发送 HEAD 请求"""
|
||||
return self.request("HEAD", url, **kwargs)
|
||||
|
||||
def options(self, url: str, **kwargs) -> Response:
|
||||
"""发送 OPTIONS 请求"""
|
||||
return self.request("OPTIONS", url, **kwargs)
|
||||
|
||||
def patch(self, url: str, data: Any = None, json: Any = None, **kwargs) -> Response:
|
||||
"""发送 PATCH 请求"""
|
||||
return self.request("PATCH", url, data=data, json=json, **kwargs)
|
||||
|
||||
def download_file(self, url: str, filepath: str, chunk_size: int = 8192) -> None:
|
||||
"""
|
||||
下载文件
|
||||
|
||||
Args:
|
||||
url: 文件 URL
|
||||
filepath: 保存路径
|
||||
chunk_size: 块大小
|
||||
|
||||
Raises:
|
||||
HTTPClientError: 下载失败
|
||||
"""
|
||||
try:
|
||||
response = self.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(filepath, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=chunk_size):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPClientError(f"下载文件失败: {url} - {e}")
|
||||
|
||||
def check_proxy(self, test_url: str = "https://httpbin.org/ip") -> bool:
|
||||
"""
|
||||
检查代理是否可用
|
||||
|
||||
Args:
|
||||
test_url: 测试 URL
|
||||
|
||||
Returns:
|
||||
bool: 代理是否可用
|
||||
"""
|
||||
if not self.proxy_url:
|
||||
return False
|
||||
|
||||
try:
|
||||
response = self.get(test_url, timeout=10)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""关闭会话"""
|
||||
if self._session:
|
||||
self._session.close()
|
||||
self._session = None
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class OpenAIHTTPClient(HTTPClient):
|
||||
"""
|
||||
OpenAI 专用 HTTP 客户端
|
||||
包含 OpenAI API 特定的请求方法
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None
|
||||
):
|
||||
"""
|
||||
初始化 OpenAI HTTP 客户端
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL
|
||||
config: 请求配置
|
||||
"""
|
||||
super().__init__(proxy_url, config)
|
||||
|
||||
# OpenAI 特定的默认配置
|
||||
if config is None:
|
||||
self.config.timeout = 30
|
||||
self.config.max_retries = 3
|
||||
|
||||
# 默认请求头
|
||||
self.default_headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
"Accept": "application/json",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
"Connection": "keep-alive",
|
||||
"Sec-Fetch-Dest": "empty",
|
||||
"Sec-Fetch-Mode": "cors",
|
||||
"Sec-Fetch-Site": "same-site",
|
||||
}
|
||||
|
||||
def check_ip_location(self) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检查 IP 地理位置
|
||||
|
||||
Returns:
|
||||
Tuple[是否支持, 位置信息]
|
||||
"""
|
||||
try:
|
||||
response = self.get("https://cloudflare.com/cdn-cgi/trace", timeout=10)
|
||||
trace_text = response.text
|
||||
|
||||
# 解析位置信息
|
||||
import re
|
||||
loc_match = re.search(r"loc=([A-Z]+)", trace_text)
|
||||
loc = loc_match.group(1) if loc_match else None
|
||||
|
||||
# 检查是否支持
|
||||
if loc in ["CN", "HK", "MO", "TW"]:
|
||||
return False, loc
|
||||
return True, loc
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查 IP 地理位置失败: {e}")
|
||||
return False, None
|
||||
|
||||
def send_openai_request(
|
||||
self,
|
||||
endpoint: str,
|
||||
method: str = "POST",
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
json_data: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送 OpenAI API 请求
|
||||
|
||||
Args:
|
||||
endpoint: API 端点
|
||||
method: HTTP 方法
|
||||
data: 表单数据
|
||||
json_data: JSON 数据
|
||||
headers: 请求头
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
响应 JSON 数据
|
||||
|
||||
Raises:
|
||||
HTTPClientError: 请求失败
|
||||
"""
|
||||
# 合并请求头
|
||||
request_headers = self.default_headers.copy()
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
# 设置 Content-Type
|
||||
if json_data is not None and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
elif data is not None and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
||||
try:
|
||||
response = self.request(
|
||||
method,
|
||||
endpoint,
|
||||
data=data,
|
||||
json=json_data,
|
||||
headers=request_headers,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 检查响应状态码
|
||||
response.raise_for_status()
|
||||
|
||||
# 尝试解析 JSON
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_response": response.text}
|
||||
|
||||
except cffi_requests.RequestsError as e:
|
||||
raise HTTPClientError(f"OpenAI 请求失败: {endpoint} - {e}")
|
||||
|
||||
def check_sentinel(self, did: str, proxies: Optional[Dict] = None) -> Optional[str]:
|
||||
"""
|
||||
检查 Sentinel 拦截
|
||||
|
||||
Args:
|
||||
did: Device ID
|
||||
proxies: 代理配置
|
||||
|
||||
Returns:
|
||||
Sentinel token 或 None
|
||||
"""
|
||||
from ..config.constants import OPENAI_API_ENDPOINTS
|
||||
|
||||
try:
|
||||
pow_token = build_sentinel_pow_token(self.default_headers.get("User-Agent", ""))
|
||||
sen_req_body = json.dumps({
|
||||
"p": pow_token,
|
||||
"id": did,
|
||||
"flow": "authorize_continue",
|
||||
}, separators=(",", ":"))
|
||||
|
||||
response = self.post(
|
||||
OPENAI_API_ENDPOINTS["sentinel"],
|
||||
headers={
|
||||
"origin": "https://sentinel.openai.com",
|
||||
"referer": "https://sentinel.openai.com/backend-api/sentinel/frame.html?sv=20260219f9f6",
|
||||
"content-type": "text/plain;charset=UTF-8",
|
||||
},
|
||||
data=sen_req_body,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json().get("token")
|
||||
else:
|
||||
logger.warning(f"Sentinel 检查失败: {response.status_code}")
|
||||
return None
|
||||
|
||||
except SentinelPOWError as e:
|
||||
logger.error(f"Sentinel POW 求解失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Sentinel 检查异常: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def create_http_client(
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None
|
||||
) -> HTTPClient:
|
||||
"""
|
||||
创建 HTTP 客户端工厂函数
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL
|
||||
config: 请求配置
|
||||
|
||||
Returns:
|
||||
HTTPClient 实例
|
||||
"""
|
||||
return HTTPClient(proxy_url, config)
|
||||
|
||||
|
||||
def create_openai_client(
|
||||
proxy_url: Optional[str] = None,
|
||||
config: Optional[RequestConfig] = None
|
||||
) -> OpenAIHTTPClient:
|
||||
"""
|
||||
创建 OpenAI HTTP 客户端工厂函数
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL
|
||||
config: 请求配置
|
||||
|
||||
Returns:
|
||||
OpenAIHTTPClient 实例
|
||||
"""
|
||||
return OpenAIHTTPClient(proxy_url, config)
|
||||
3
src/core/openai/__init__.py
Normal file
3
src/core/openai/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2026/3/18 19:55
|
||||
370
src/core/openai/oauth.py
Normal file
370
src/core/openai/oauth.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
OpenAI OAuth 授权模块
|
||||
从 main.py 中提取的 OAuth 相关函数
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
from ...config.constants import (
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_AUTH_URL,
|
||||
OAUTH_TOKEN_URL,
|
||||
OAUTH_REDIRECT_URI,
|
||||
OAUTH_SCOPE,
|
||||
)
|
||||
|
||||
|
||||
def _b64url_no_pad(raw: bytes) -> str:
|
||||
"""Base64 URL 编码(无填充)"""
|
||||
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def _sha256_b64url_no_pad(s: str) -> str:
|
||||
"""SHA256 哈希后 Base64 URL 编码"""
|
||||
return _b64url_no_pad(hashlib.sha256(s.encode("ascii")).digest())
|
||||
|
||||
|
||||
def _random_state(nbytes: int = 16) -> str:
|
||||
"""生成随机 state"""
|
||||
return secrets.token_urlsafe(nbytes)
|
||||
|
||||
|
||||
def _pkce_verifier() -> str:
|
||||
"""生成 PKCE code_verifier"""
|
||||
return secrets.token_urlsafe(64)
|
||||
|
||||
|
||||
def _parse_callback_url(callback_url: str) -> Dict[str, str]:
|
||||
"""解析回调 URL"""
|
||||
candidate = callback_url.strip()
|
||||
if not candidate:
|
||||
return {"code": "", "state": "", "error": "", "error_description": ""}
|
||||
|
||||
if "://" not in candidate:
|
||||
if candidate.startswith("?"):
|
||||
candidate = f"http://localhost{candidate}"
|
||||
elif any(ch in candidate for ch in "/?#") or ":" in candidate:
|
||||
candidate = f"http://{candidate}"
|
||||
elif "=" in candidate:
|
||||
candidate = f"http://localhost/?{candidate}"
|
||||
|
||||
parsed = urllib.parse.urlparse(candidate)
|
||||
query = urllib.parse.parse_qs(parsed.query, keep_blank_values=True)
|
||||
fragment = urllib.parse.parse_qs(parsed.fragment, keep_blank_values=True)
|
||||
|
||||
for key, values in fragment.items():
|
||||
if key not in query or not query[key] or not (query[key][0] or "").strip():
|
||||
query[key] = values
|
||||
|
||||
def get1(k: str) -> str:
|
||||
v = query.get(k, [""])
|
||||
return (v[0] or "").strip()
|
||||
|
||||
code = get1("code")
|
||||
state = get1("state")
|
||||
error = get1("error")
|
||||
error_description = get1("error_description")
|
||||
|
||||
if code and not state and "#" in code:
|
||||
code, state = code.split("#", 1)
|
||||
|
||||
if not error and error_description:
|
||||
error, error_description = error_description, ""
|
||||
|
||||
return {
|
||||
"code": code,
|
||||
"state": state,
|
||||
"error": error,
|
||||
"error_description": error_description,
|
||||
}
|
||||
|
||||
|
||||
def _jwt_claims_no_verify(id_token: str) -> Dict[str, Any]:
|
||||
"""解析 JWT ID Token(不验证签名)"""
|
||||
if not id_token or id_token.count(".") < 2:
|
||||
return {}
|
||||
payload_b64 = id_token.split(".")[1]
|
||||
pad = "=" * ((4 - (len(payload_b64) % 4)) % 4)
|
||||
try:
|
||||
payload = base64.urlsafe_b64decode((payload_b64 + pad).encode("ascii"))
|
||||
return json.loads(payload.decode("utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _decode_jwt_segment(seg: str) -> Dict[str, Any]:
|
||||
"""解码 JWT 片段"""
|
||||
raw = (seg or "").strip()
|
||||
if not raw:
|
||||
return {}
|
||||
pad = "=" * ((4 - (len(raw) % 4)) % 4)
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode((raw + pad).encode("ascii"))
|
||||
return json.loads(decoded.decode("utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _to_int(v: Any) -> int:
|
||||
"""转换为整数"""
|
||||
try:
|
||||
return int(v)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
def _post_form(
|
||||
url: str,
|
||||
data: Dict[str, str],
|
||||
timeout: int = 30,
|
||||
proxy_url: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送 POST 表单请求
|
||||
|
||||
Args:
|
||||
url: 请求 URL
|
||||
data: 表单数据
|
||||
timeout: 超时时间
|
||||
proxy_url: 代理 URL
|
||||
|
||||
Returns:
|
||||
响应 JSON 数据
|
||||
"""
|
||||
# 构建代理配置
|
||||
proxies = None
|
||||
if proxy_url:
|
||||
proxies = {
|
||||
"http": proxy_url,
|
||||
"https": proxy_url,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 curl_cffi 发送请求,支持代理和浏览器指纹
|
||||
response = cffi_requests.post(
|
||||
url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
proxies=proxies,
|
||||
impersonate="chrome"
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"token exchange failed: {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
except cffi_requests.RequestsError as e:
|
||||
raise RuntimeError(f"token exchange failed: network error: {e}") from e
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OAuthStart:
|
||||
"""OAuth 开始信息"""
|
||||
auth_url: str
|
||||
state: str
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_oauth_url(
|
||||
*,
|
||||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||||
scope: str = OAUTH_SCOPE,
|
||||
client_id: str = OAUTH_CLIENT_ID
|
||||
) -> OAuthStart:
|
||||
"""
|
||||
生成 OAuth 授权 URL
|
||||
|
||||
Args:
|
||||
redirect_uri: 回调地址
|
||||
scope: 权限范围
|
||||
client_id: OpenAI Client ID
|
||||
|
||||
Returns:
|
||||
OAuthStart 对象,包含授权 URL 和必要参数
|
||||
"""
|
||||
state = _random_state()
|
||||
code_verifier = _pkce_verifier()
|
||||
code_challenge = _sha256_b64url_no_pad(code_verifier)
|
||||
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"prompt": "login",
|
||||
"id_token_add_organizations": "true",
|
||||
"codex_cli_simplified_flow": "true",
|
||||
}
|
||||
auth_url = f"{OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
return OAuthStart(
|
||||
auth_url=auth_url,
|
||||
state=state,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
|
||||
def submit_callback_url(
|
||||
*,
|
||||
callback_url: str,
|
||||
expected_state: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||||
client_id: str = OAUTH_CLIENT_ID,
|
||||
token_url: str = OAUTH_TOKEN_URL,
|
||||
proxy_url: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
处理 OAuth 回调 URL,获取访问令牌
|
||||
|
||||
Args:
|
||||
callback_url: 回调 URL
|
||||
expected_state: 预期的 state 值
|
||||
code_verifier: PKCE code_verifier
|
||||
redirect_uri: 回调地址
|
||||
client_id: OpenAI Client ID
|
||||
token_url: Token 交换地址
|
||||
proxy_url: 代理 URL
|
||||
|
||||
Returns:
|
||||
包含访问令牌等信息的 JSON 字符串
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAuth 错误
|
||||
ValueError: 缺少必要参数或 state 不匹配
|
||||
"""
|
||||
cb = _parse_callback_url(callback_url)
|
||||
if cb["error"]:
|
||||
desc = cb["error_description"]
|
||||
raise RuntimeError(f"oauth error: {cb['error']}: {desc}".strip())
|
||||
|
||||
if not cb["code"]:
|
||||
raise ValueError("callback url missing ?code=")
|
||||
if not cb["state"]:
|
||||
raise ValueError("callback url missing ?state=")
|
||||
if cb["state"] != expected_state:
|
||||
raise ValueError("state mismatch")
|
||||
|
||||
token_resp = _post_form(
|
||||
token_url,
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": client_id,
|
||||
"code": cb["code"],
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
proxy_url=proxy_url
|
||||
)
|
||||
|
||||
access_token = (token_resp.get("access_token") or "").strip()
|
||||
refresh_token = (token_resp.get("refresh_token") or "").strip()
|
||||
id_token = (token_resp.get("id_token") or "").strip()
|
||||
expires_in = _to_int(token_resp.get("expires_in"))
|
||||
|
||||
claims = _jwt_claims_no_verify(id_token)
|
||||
email = str(claims.get("email") or "").strip()
|
||||
auth_claims = claims.get("https://api.openai.com/auth") or {}
|
||||
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
|
||||
|
||||
now = int(time.time())
|
||||
expired_rfc3339 = time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ", time.gmtime(now + max(expires_in, 0))
|
||||
)
|
||||
now_rfc3339 = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now))
|
||||
|
||||
config = {
|
||||
"id_token": id_token,
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"account_id": account_id,
|
||||
"last_refresh": now_rfc3339,
|
||||
"email": email,
|
||||
"type": "codex",
|
||||
"expired": expired_rfc3339,
|
||||
}
|
||||
|
||||
return json.dumps(config, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
|
||||
class OAuthManager:
|
||||
"""OAuth 管理器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str = OAUTH_CLIENT_ID,
|
||||
auth_url: str = OAUTH_AUTH_URL,
|
||||
token_url: str = OAUTH_TOKEN_URL,
|
||||
redirect_uri: str = OAUTH_REDIRECT_URI,
|
||||
scope: str = OAUTH_SCOPE,
|
||||
proxy_url: Optional[str] = None
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.auth_url = auth_url
|
||||
self.token_url = token_url
|
||||
self.redirect_uri = redirect_uri
|
||||
self.scope = scope
|
||||
self.proxy_url = proxy_url
|
||||
|
||||
def start_oauth(self) -> OAuthStart:
|
||||
"""开始 OAuth 流程"""
|
||||
return generate_oauth_url(
|
||||
redirect_uri=self.redirect_uri,
|
||||
scope=self.scope,
|
||||
client_id=self.client_id
|
||||
)
|
||||
|
||||
def handle_callback(
|
||||
self,
|
||||
callback_url: str,
|
||||
expected_state: str,
|
||||
code_verifier: str
|
||||
) -> Dict[str, Any]:
|
||||
"""处理 OAuth 回调"""
|
||||
result_json = submit_callback_url(
|
||||
callback_url=callback_url,
|
||||
expected_state=expected_state,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=self.redirect_uri,
|
||||
client_id=self.client_id,
|
||||
token_url=self.token_url,
|
||||
proxy_url=self.proxy_url
|
||||
)
|
||||
return json.loads(result_json)
|
||||
|
||||
def extract_account_info(self, id_token: str) -> Dict[str, Any]:
|
||||
"""从 ID Token 中提取账户信息"""
|
||||
claims = _jwt_claims_no_verify(id_token)
|
||||
email = str(claims.get("email") or "").strip()
|
||||
auth_claims = claims.get("https://api.openai.com/auth") or {}
|
||||
account_id = str(auth_claims.get("chatgpt_account_id") or "").strip()
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"account_id": account_id,
|
||||
"claims": claims
|
||||
}
|
||||
261
src/core/openai/payment.py
Normal file
261
src/core/openai/payment.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
支付核心逻辑 — 生成 Plus/Team 支付链接、无痕打开浏览器、检测订阅状态
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
from ...database.models import Account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PAYMENT_CHECKOUT_URL = "https://chatgpt.com/backend-api/payments/checkout"
|
||||
TEAM_CHECKOUT_BASE_URL = "https://chatgpt.com/checkout/openai_llc/"
|
||||
|
||||
|
||||
def _build_proxies(proxy: Optional[str]) -> Optional[dict]:
|
||||
if proxy:
|
||||
return {"http": proxy, "https": proxy}
|
||||
return None
|
||||
|
||||
|
||||
_COUNTRY_CURRENCY_MAP = {
|
||||
"SG": "SGD",
|
||||
"US": "USD",
|
||||
"TR": "TRY",
|
||||
"JP": "JPY",
|
||||
"HK": "HKD",
|
||||
"GB": "GBP",
|
||||
"EU": "EUR",
|
||||
"AU": "AUD",
|
||||
"CA": "CAD",
|
||||
"IN": "INR",
|
||||
"BR": "BRL",
|
||||
"MX": "MXN",
|
||||
}
|
||||
|
||||
|
||||
def _extract_oai_did(cookies_str: str) -> Optional[str]:
|
||||
"""从 cookie 字符串中提取 oai-device-id"""
|
||||
for part in cookies_str.split(";"):
|
||||
part = part.strip()
|
||||
if part.startswith("oai-did="):
|
||||
return part[len("oai-did="):].strip()
|
||||
return None
|
||||
|
||||
|
||||
def _parse_cookie_str(cookies_str: str, domain: str) -> list:
|
||||
"""将 'key=val; key2=val2' 格式解析为 Playwright cookie 列表"""
|
||||
cookies = []
|
||||
for part in cookies_str.split(";"):
|
||||
part = part.strip()
|
||||
if "=" not in part:
|
||||
continue
|
||||
name, _, value = part.partition("=")
|
||||
cookies.append({
|
||||
"name": name.strip(),
|
||||
"value": value.strip(),
|
||||
"domain": domain,
|
||||
"path": "/",
|
||||
})
|
||||
return cookies
|
||||
|
||||
|
||||
def _open_url_system_browser(url: str) -> bool:
|
||||
"""回退方案:调用系统浏览器以无痕模式打开"""
|
||||
platform = sys.platform
|
||||
try:
|
||||
if platform == "win32":
|
||||
for browser, flag in [("chrome", "--incognito"), ("msedge", "--inprivate")]:
|
||||
try:
|
||||
subprocess.Popen(f'start {browser} {flag} "{url}"', shell=True)
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
elif platform == "darwin":
|
||||
subprocess.Popen(["open", "-a", "Google Chrome", "--args", "--incognito", url])
|
||||
return True
|
||||
else:
|
||||
for binary in ["google-chrome", "chromium-browser", "chromium"]:
|
||||
try:
|
||||
subprocess.Popen([binary, "--incognito", url])
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"系统浏览器无痕打开失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def generate_plus_link(
|
||||
account: Account,
|
||||
proxy: Optional[str] = None,
|
||||
country: str = "SG",
|
||||
) -> str:
|
||||
"""生成 Plus 支付链接(后端携带账号 cookie 发请求)"""
|
||||
if not account.access_token:
|
||||
raise ValueError("账号缺少 access_token")
|
||||
|
||||
currency = _COUNTRY_CURRENCY_MAP.get(country, "USD")
|
||||
headers = {
|
||||
"Authorization": f"Bearer {account.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"oai-language": "zh-CN",
|
||||
}
|
||||
if account.cookies:
|
||||
headers["cookie"] = account.cookies
|
||||
oai_did = _extract_oai_did(account.cookies)
|
||||
if oai_did:
|
||||
headers["oai-device-id"] = oai_did
|
||||
|
||||
payload = {
|
||||
"plan_name": "chatgptplusplan",
|
||||
"billing_details": {"country": country, "currency": currency},
|
||||
"promo_campaign": {
|
||||
"promo_campaign_id": "plus-1-month-free",
|
||||
"is_coupon_from_query_param": False,
|
||||
},
|
||||
"checkout_ui_mode": "custom",
|
||||
}
|
||||
|
||||
resp = cffi_requests.post(
|
||||
PAYMENT_CHECKOUT_URL,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
proxies=_build_proxies(proxy),
|
||||
timeout=30,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if "checkout_session_id" in data:
|
||||
return TEAM_CHECKOUT_BASE_URL + data["checkout_session_id"]
|
||||
raise ValueError(data.get("detail", "API 未返回 checkout_session_id"))
|
||||
|
||||
|
||||
def generate_team_link(
|
||||
account: Account,
|
||||
workspace_name: str = "MyTeam",
|
||||
price_interval: str = "month",
|
||||
seat_quantity: int = 5,
|
||||
proxy: Optional[str] = None,
|
||||
country: str = "SG",
|
||||
) -> str:
|
||||
"""生成 Team 支付链接(后端携带账号 cookie 发请求)"""
|
||||
if not account.access_token:
|
||||
raise ValueError("账号缺少 access_token")
|
||||
|
||||
currency = _COUNTRY_CURRENCY_MAP.get(country, "USD")
|
||||
headers = {
|
||||
"Authorization": f"Bearer {account.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
"oai-language": "zh-CN",
|
||||
}
|
||||
if account.cookies:
|
||||
headers["cookie"] = account.cookies
|
||||
oai_did = _extract_oai_did(account.cookies)
|
||||
if oai_did:
|
||||
headers["oai-device-id"] = oai_did
|
||||
|
||||
payload = {
|
||||
"plan_name": "chatgptteamplan",
|
||||
"team_plan_data": {
|
||||
"workspace_name": workspace_name,
|
||||
"price_interval": price_interval,
|
||||
"seat_quantity": seat_quantity,
|
||||
},
|
||||
"billing_details": {"country": country, "currency": currency},
|
||||
"promo_campaign": {
|
||||
"promo_campaign_id": "team-1-month-free",
|
||||
"is_coupon_from_query_param": True,
|
||||
},
|
||||
"cancel_url": "https://chatgpt.com/#pricing",
|
||||
"checkout_ui_mode": "custom",
|
||||
}
|
||||
|
||||
resp = cffi_requests.post(
|
||||
PAYMENT_CHECKOUT_URL,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
proxies=_build_proxies(proxy),
|
||||
timeout=30,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
if "checkout_session_id" in data:
|
||||
return TEAM_CHECKOUT_BASE_URL + data["checkout_session_id"]
|
||||
raise ValueError(data.get("detail", "API 未返回 checkout_session_id"))
|
||||
|
||||
|
||||
def open_url_incognito(url: str, cookies_str: Optional[str] = None) -> bool:
|
||||
"""用 Playwright 以无痕模式打开 URL,可注入 cookie"""
|
||||
import threading
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright
|
||||
except ImportError:
|
||||
logger.warning("playwright 未安装,回退到系统浏览器")
|
||||
return _open_url_system_browser(url)
|
||||
|
||||
def _launch():
|
||||
try:
|
||||
with sync_playwright() as p:
|
||||
browser = p.chromium.launch(headless=False, args=["--incognito"])
|
||||
ctx = browser.new_context()
|
||||
if cookies_str:
|
||||
ctx.add_cookies(_parse_cookie_str(cookies_str, "chatgpt.com"))
|
||||
page = ctx.new_page()
|
||||
page.goto(url)
|
||||
# 保持窗口打开直到用户关闭
|
||||
page.wait_for_timeout(300_000) # 最多等待 5 分钟
|
||||
except Exception as e:
|
||||
logger.warning(f"Playwright 无痕打开失败: {e}")
|
||||
|
||||
threading.Thread(target=_launch, daemon=True).start()
|
||||
return True
|
||||
|
||||
|
||||
def check_subscription_status(account: Account, proxy: Optional[str] = None) -> str:
|
||||
"""
|
||||
检测账号当前订阅状态。
|
||||
|
||||
Returns:
|
||||
'free' / 'plus' / 'team'
|
||||
"""
|
||||
if not account.access_token:
|
||||
raise ValueError("账号缺少 access_token")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {account.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
resp = cffi_requests.get(
|
||||
"https://chatgpt.com/backend-api/me",
|
||||
headers=headers,
|
||||
proxies=_build_proxies(proxy),
|
||||
timeout=20,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# 解析订阅类型
|
||||
plan = data.get("plan_type") or ""
|
||||
if "team" in plan.lower():
|
||||
return "team"
|
||||
if "plus" in plan.lower():
|
||||
return "plus"
|
||||
|
||||
# 尝试从 orgs 或 workspace 信息判断
|
||||
orgs = data.get("orgs", {}).get("data", [])
|
||||
for org in orgs:
|
||||
settings_ = org.get("settings", {})
|
||||
if settings_.get("workspace_plan_type") in ("team", "enterprise"):
|
||||
return "team"
|
||||
|
||||
return "free"
|
||||
98
src/core/openai/sentinel.py
Normal file
98
src/core/openai/sentinel.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Helpers for OpenAI Sentinel proof-of-work tokens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
DEFAULT_SENTINEL_DIFF = "0fffff"
|
||||
DEFAULT_MAX_ITERATIONS = 500_000
|
||||
_SCREEN_SIGNATURES = (3000, 3120, 4000, 4160)
|
||||
_LANGUAGE_SIGNATURE = "en-US,es-US,en,es"
|
||||
_NAVIGATOR_KEYS = ("location", "ontransitionend", "onprogress")
|
||||
_WINDOW_KEYS = ("window", "document", "navigator")
|
||||
|
||||
|
||||
class SentinelPOWError(RuntimeError):
|
||||
"""Raised when a Sentinel proof-of-work token cannot be solved."""
|
||||
|
||||
|
||||
def _format_browser_time() -> str:
|
||||
"""Match the browser-style timestamp used by public Sentinel solvers."""
|
||||
browser_now = datetime.now(timezone(timedelta(hours=-5)))
|
||||
return browser_now.strftime("%a %b %d %Y %H:%M:%S") + " GMT-0500 (Eastern Standard Time)"
|
||||
|
||||
|
||||
def build_sentinel_config(user_agent: str) -> list:
|
||||
"""Build a browser-like fingerprint payload for the Sentinel PoW solver."""
|
||||
perf_ms = time.perf_counter() * 1000
|
||||
epoch_ms = (time.time() * 1000) - perf_ms
|
||||
return [
|
||||
random.choice(_SCREEN_SIGNATURES),
|
||||
_format_browser_time(),
|
||||
4294705152,
|
||||
0,
|
||||
user_agent,
|
||||
"",
|
||||
"",
|
||||
"en-US",
|
||||
_LANGUAGE_SIGNATURE,
|
||||
0,
|
||||
random.choice(_NAVIGATOR_KEYS),
|
||||
"location",
|
||||
random.choice(_WINDOW_KEYS),
|
||||
perf_ms,
|
||||
str(uuid.uuid4()),
|
||||
"",
|
||||
8,
|
||||
epoch_ms,
|
||||
]
|
||||
|
||||
|
||||
def _encode_pow_payload(config: Sequence[object], nonce: int) -> bytes:
|
||||
prefix = (json.dumps(config[:3], separators=(",", ":"), ensure_ascii=False)[:-1] + ",").encode("utf-8")
|
||||
middle = (
|
||||
"," + json.dumps(config[4:9], separators=(",", ":"), ensure_ascii=False)[1:-1] + ","
|
||||
).encode("utf-8")
|
||||
suffix = ("," + json.dumps(config[10:], separators=(",", ":"), ensure_ascii=False)[1:]).encode("utf-8")
|
||||
body = prefix + str(nonce).encode("ascii") + middle + str(nonce >> 1).encode("ascii") + suffix
|
||||
return base64.b64encode(body)
|
||||
|
||||
|
||||
def solve_sentinel_pow(
|
||||
seed: str,
|
||||
difficulty: str,
|
||||
config: Sequence[object],
|
||||
max_iterations: int = DEFAULT_MAX_ITERATIONS,
|
||||
) -> str:
|
||||
"""Solve the Sentinel PoW challenge and return the base64 payload."""
|
||||
seed_bytes = seed.encode("utf-8")
|
||||
target = bytes.fromhex(difficulty)
|
||||
prefix_length = len(target)
|
||||
|
||||
for nonce in range(max_iterations):
|
||||
encoded = _encode_pow_payload(config, nonce)
|
||||
digest = hashlib.sha3_512(seed_bytes + encoded).digest()
|
||||
if digest[:prefix_length] <= target:
|
||||
return encoded.decode("ascii")
|
||||
|
||||
raise SentinelPOWError(f"failed to solve sentinel pow after {max_iterations} attempts")
|
||||
|
||||
|
||||
def build_sentinel_pow_token(
|
||||
user_agent: str,
|
||||
difficulty: str = DEFAULT_SENTINEL_DIFF,
|
||||
max_iterations: int = DEFAULT_MAX_ITERATIONS,
|
||||
) -> str:
|
||||
"""Build the `p` token required by the Sentinel request endpoint."""
|
||||
config = build_sentinel_config(user_agent)
|
||||
seed = format(random.random())
|
||||
solution = solve_sentinel_pow(seed, difficulty, config, max_iterations=max_iterations)
|
||||
return f"gAAAAAC{solution}"
|
||||
332
src/core/openai/token_refresh.py
Normal file
332
src/core/openai/token_refresh.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
Token 刷新模块
|
||||
支持 Session Token 和 OAuth Refresh Token 两种刷新方式
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
from ...config.settings import get_settings
|
||||
from ...database.session import get_db
|
||||
from ...database import crud
|
||||
from ...database.models import Account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenRefreshResult:
|
||||
"""Token 刷新结果"""
|
||||
success: bool
|
||||
access_token: str = ""
|
||||
refresh_token: str = ""
|
||||
expires_at: Optional[datetime] = None
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
class TokenRefreshManager:
|
||||
"""
|
||||
Token 刷新管理器
|
||||
支持两种刷新方式:
|
||||
1. Session Token 刷新(优先)
|
||||
2. OAuth Refresh Token 刷新
|
||||
"""
|
||||
|
||||
# OpenAI OAuth 端点
|
||||
SESSION_URL = "https://chatgpt.com/api/auth/session"
|
||||
TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
def __init__(self, proxy_url: Optional[str] = None):
|
||||
"""
|
||||
初始化 Token 刷新管理器
|
||||
|
||||
Args:
|
||||
proxy_url: 代理 URL
|
||||
"""
|
||||
self.proxy_url = proxy_url
|
||||
self.settings = get_settings()
|
||||
|
||||
def _create_session(self) -> cffi_requests.Session:
|
||||
"""创建 HTTP 会话"""
|
||||
session = cffi_requests.Session(impersonate="chrome120", proxy=self.proxy_url)
|
||||
return session
|
||||
|
||||
def refresh_by_session_token(self, session_token: str) -> TokenRefreshResult:
|
||||
"""
|
||||
使用 Session Token 刷新
|
||||
|
||||
Args:
|
||||
session_token: 会话令牌
|
||||
|
||||
Returns:
|
||||
TokenRefreshResult: 刷新结果
|
||||
"""
|
||||
result = TokenRefreshResult(success=False)
|
||||
|
||||
try:
|
||||
session = self._create_session()
|
||||
|
||||
# 设置会话 Cookie
|
||||
session.cookies.set(
|
||||
"__Secure-next-auth.session-token",
|
||||
session_token,
|
||||
domain=".chatgpt.com",
|
||||
path="/"
|
||||
)
|
||||
|
||||
# 请求会话端点
|
||||
response = session.get(
|
||||
self.SESSION_URL,
|
||||
headers={
|
||||
"accept": "application/json",
|
||||
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
result.error_message = f"Session token 刷新失败: HTTP {response.status_code}"
|
||||
logger.warning(result.error_message)
|
||||
return result
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 提取 access_token
|
||||
access_token = data.get("accessToken")
|
||||
if not access_token:
|
||||
result.error_message = "Session token 刷新失败: 未找到 accessToken"
|
||||
logger.warning(result.error_message)
|
||||
return result
|
||||
|
||||
# 提取过期时间
|
||||
expires_at = None
|
||||
expires_str = data.get("expires")
|
||||
if expires_str:
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(expires_str.replace("Z", "+00:00"))
|
||||
except:
|
||||
pass
|
||||
|
||||
result.success = True
|
||||
result.access_token = access_token
|
||||
result.expires_at = expires_at
|
||||
|
||||
logger.info(f"Session token 刷新成功,过期时间: {expires_at}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
result.error_message = f"Session token 刷新异常: {str(e)}"
|
||||
logger.error(result.error_message)
|
||||
return result
|
||||
|
||||
def refresh_by_oauth_token(
|
||||
self,
|
||||
refresh_token: str,
|
||||
client_id: Optional[str] = None
|
||||
) -> TokenRefreshResult:
|
||||
"""
|
||||
使用 OAuth Refresh Token 刷新
|
||||
|
||||
Args:
|
||||
refresh_token: OAuth 刷新令牌
|
||||
client_id: OAuth Client ID
|
||||
|
||||
Returns:
|
||||
TokenRefreshResult: 刷新结果
|
||||
"""
|
||||
result = TokenRefreshResult(success=False)
|
||||
|
||||
try:
|
||||
session = self._create_session()
|
||||
|
||||
# 使用配置的 client_id 或默认值
|
||||
client_id = client_id or self.settings.openai_client_id
|
||||
|
||||
# 构建请求体
|
||||
token_data = {
|
||||
"client_id": client_id,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"redirect_uri": self.settings.openai_redirect_uri
|
||||
}
|
||||
|
||||
response = session.post(
|
||||
self.TOKEN_URL,
|
||||
headers={
|
||||
"content-type": "application/x-www-form-urlencoded",
|
||||
"accept": "application/json"
|
||||
},
|
||||
data=token_data,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
result.error_message = f"OAuth token 刷新失败: HTTP {response.status_code}"
|
||||
logger.warning(f"{result.error_message}, 响应: {response.text[:200]}")
|
||||
return result
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 提取令牌
|
||||
access_token = data.get("access_token")
|
||||
new_refresh_token = data.get("refresh_token", refresh_token)
|
||||
expires_in = data.get("expires_in", 3600)
|
||||
|
||||
if not access_token:
|
||||
result.error_message = "OAuth token 刷新失败: 未找到 access_token"
|
||||
logger.warning(result.error_message)
|
||||
return result
|
||||
|
||||
# 计算过期时间
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
|
||||
result.success = True
|
||||
result.access_token = access_token
|
||||
result.refresh_token = new_refresh_token
|
||||
result.expires_at = expires_at
|
||||
|
||||
logger.info(f"OAuth token 刷新成功,过期时间: {expires_at}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
result.error_message = f"OAuth token 刷新异常: {str(e)}"
|
||||
logger.error(result.error_message)
|
||||
return result
|
||||
|
||||
def refresh_account(self, account: Account) -> TokenRefreshResult:
|
||||
"""
|
||||
刷新账号的 Token
|
||||
|
||||
优先级:
|
||||
1. Session Token 刷新
|
||||
2. OAuth Refresh Token 刷新
|
||||
|
||||
Args:
|
||||
account: 账号对象
|
||||
|
||||
Returns:
|
||||
TokenRefreshResult: 刷新结果
|
||||
"""
|
||||
# 优先尝试 Session Token
|
||||
if account.session_token:
|
||||
logger.info(f"尝试使用 Session Token 刷新账号 {account.email}")
|
||||
result = self.refresh_by_session_token(account.session_token)
|
||||
if result.success:
|
||||
return result
|
||||
logger.warning(f"Session Token 刷新失败,尝试 OAuth 刷新")
|
||||
|
||||
# 尝试 OAuth Refresh Token
|
||||
if account.refresh_token:
|
||||
logger.info(f"尝试使用 OAuth Refresh Token 刷新账号 {account.email}")
|
||||
result = self.refresh_by_oauth_token(
|
||||
refresh_token=account.refresh_token,
|
||||
client_id=account.client_id
|
||||
)
|
||||
return result
|
||||
|
||||
# 无可用刷新方式
|
||||
return TokenRefreshResult(
|
||||
success=False,
|
||||
error_message="账号没有可用的刷新方式(缺少 session_token 和 refresh_token)"
|
||||
)
|
||||
|
||||
def validate_token(self, access_token: str) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
验证 Access Token 是否有效
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
session = self._create_session()
|
||||
|
||||
# 调用 OpenAI API 验证 token
|
||||
response = session.get(
|
||||
"https://chatgpt.com/backend-api/me",
|
||||
headers={
|
||||
"authorization": f"Bearer {access_token}",
|
||||
"accept": "application/json"
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True, None
|
||||
elif response.status_code == 401:
|
||||
return False, "Token 无效或已过期"
|
||||
elif response.status_code == 403:
|
||||
return False, "账号可能被封禁"
|
||||
else:
|
||||
return False, f"验证失败: HTTP {response.status_code}"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"验证异常: {str(e)}"
|
||||
|
||||
|
||||
def refresh_account_token(account_id: int, proxy_url: Optional[str] = None) -> TokenRefreshResult:
|
||||
"""
|
||||
刷新指定账号的 Token 并更新数据库
|
||||
|
||||
Args:
|
||||
account_id: 账号 ID
|
||||
proxy_url: 代理 URL
|
||||
|
||||
Returns:
|
||||
TokenRefreshResult: 刷新结果
|
||||
"""
|
||||
with get_db() as db:
|
||||
account = crud.get_account_by_id(db, account_id)
|
||||
if not account:
|
||||
return TokenRefreshResult(success=False, error_message="账号不存在")
|
||||
|
||||
manager = TokenRefreshManager(proxy_url=proxy_url)
|
||||
result = manager.refresh_account(account)
|
||||
|
||||
if result.success:
|
||||
# 更新数据库
|
||||
update_data = {
|
||||
"access_token": result.access_token,
|
||||
"last_refresh": datetime.utcnow()
|
||||
}
|
||||
|
||||
if result.refresh_token:
|
||||
update_data["refresh_token"] = result.refresh_token
|
||||
|
||||
if result.expires_at:
|
||||
update_data["expires_at"] = result.expires_at
|
||||
|
||||
crud.update_account(db, account_id, **update_data)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_account_token(account_id: int, proxy_url: Optional[str] = None) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
验证指定账号的 Token 是否有效
|
||||
|
||||
Args:
|
||||
account_id: 账号 ID
|
||||
proxy_url: 代理 URL
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否有效, 错误信息)
|
||||
"""
|
||||
with get_db() as db:
|
||||
account = crud.get_account_by_id(db, account_id)
|
||||
if not account:
|
||||
return False, "账号不存在"
|
||||
|
||||
if not account.access_token:
|
||||
return False, "账号没有 access_token"
|
||||
|
||||
manager = TokenRefreshManager(proxy_url=proxy_url)
|
||||
return manager.validate_token(account.access_token)
|
||||
1009
src/core/register.py
Normal file
1009
src/core/register.py
Normal file
File diff suppressed because it is too large
Load Diff
3
src/core/upload/__init__.py
Normal file
3
src/core/upload/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2026/3/18 19:54
|
||||
312
src/core/upload/cpa_upload.py
Normal file
312
src/core/upload/cpa_upload.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
CPA (Codex Protocol API) 上传功能
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from datetime import datetime
|
||||
from urllib.parse import quote
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
from curl_cffi import CurlMime
|
||||
|
||||
from ...database.session import get_db
|
||||
from ...database.models import Account
|
||||
from ...config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_cpa_auth_files_url(api_url: str) -> str:
|
||||
"""将用户填写的 CPA 地址规范化为 auth-files 接口地址。"""
|
||||
normalized = (api_url or "").strip().rstrip("/")
|
||||
lower_url = normalized.lower()
|
||||
|
||||
if not normalized:
|
||||
return ""
|
||||
|
||||
if lower_url.endswith("/auth-files"):
|
||||
return normalized
|
||||
|
||||
if lower_url.endswith("/v0/management") or lower_url.endswith("/management"):
|
||||
return f"{normalized}/auth-files"
|
||||
|
||||
if lower_url.endswith("/v0"):
|
||||
return f"{normalized}/management/auth-files"
|
||||
|
||||
return f"{normalized}/v0/management/auth-files"
|
||||
|
||||
|
||||
def _build_cpa_headers(api_token: str, content_type: Optional[str] = None) -> dict:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
return headers
|
||||
|
||||
|
||||
def _extract_cpa_error(response) -> str:
|
||||
error_msg = f"上传失败: HTTP {response.status_code}"
|
||||
try:
|
||||
error_detail = response.json()
|
||||
if isinstance(error_detail, dict):
|
||||
error_msg = error_detail.get("message", error_msg)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
return error_msg
|
||||
|
||||
|
||||
def _post_cpa_auth_file_multipart(upload_url: str, filename: str, file_content: bytes, api_token: str):
|
||||
mime = CurlMime()
|
||||
mime.addpart(
|
||||
name="file",
|
||||
data=file_content,
|
||||
filename=filename,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
return cffi_requests.post(
|
||||
upload_url,
|
||||
multipart=mime,
|
||||
headers=_build_cpa_headers(api_token),
|
||||
proxies=None,
|
||||
timeout=30,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
|
||||
def _post_cpa_auth_file_raw_json(upload_url: str, filename: str, file_content: bytes, api_token: str):
|
||||
raw_upload_url = f"{upload_url}?name={quote(filename)}"
|
||||
return cffi_requests.post(
|
||||
raw_upload_url,
|
||||
data=file_content,
|
||||
headers=_build_cpa_headers(api_token, content_type="application/json"),
|
||||
proxies=None,
|
||||
timeout=30,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
|
||||
def generate_token_json(account: Account) -> dict:
|
||||
"""
|
||||
生成 CPA 格式的 Token JSON
|
||||
|
||||
Args:
|
||||
account: 账号模型实例
|
||||
|
||||
Returns:
|
||||
CPA 格式的 Token 字典
|
||||
"""
|
||||
return {
|
||||
"type": "codex",
|
||||
"email": account.email,
|
||||
"expired": account.expires_at.strftime("%Y-%m-%dT%H:%M:%S+08:00") if account.expires_at else "",
|
||||
"id_token": account.id_token or "",
|
||||
"account_id": account.account_id or "",
|
||||
"access_token": account.access_token or "",
|
||||
"last_refresh": account.last_refresh.strftime("%Y-%m-%dT%H:%M:%S+08:00") if account.last_refresh else "",
|
||||
"refresh_token": account.refresh_token or "",
|
||||
}
|
||||
|
||||
|
||||
def upload_to_cpa(
|
||||
token_data: dict,
|
||||
proxy: str = None,
|
||||
api_url: str = None,
|
||||
api_token: str = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
上传单个账号到 CPA 管理平台(不走代理)
|
||||
|
||||
Args:
|
||||
token_data: Token JSON 数据
|
||||
proxy: 保留参数,不使用(CPA 上传始终直连)
|
||||
api_url: 指定 CPA API URL(优先于全局配置)
|
||||
api_token: 指定 CPA API Token(优先于全局配置)
|
||||
|
||||
Returns:
|
||||
(成功标志, 消息或错误信息)
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
# 优先使用传入的参数,否则退回全局配置
|
||||
effective_url = api_url or settings.cpa_api_url
|
||||
effective_token = api_token or (settings.cpa_api_token.get_secret_value() if settings.cpa_api_token else "")
|
||||
|
||||
# 仅当未指定服务时才检查全局启用开关
|
||||
if not api_url and not settings.cpa_enabled:
|
||||
return False, "CPA 上传未启用"
|
||||
|
||||
if not effective_url:
|
||||
return False, "CPA API URL 未配置"
|
||||
|
||||
if not effective_token:
|
||||
return False, "CPA API Token 未配置"
|
||||
|
||||
upload_url = _normalize_cpa_auth_files_url(effective_url)
|
||||
|
||||
filename = f"{token_data['email']}.json"
|
||||
file_content = json.dumps(token_data, ensure_ascii=False, indent=2).encode("utf-8")
|
||||
|
||||
try:
|
||||
response = _post_cpa_auth_file_multipart(
|
||||
upload_url,
|
||||
filename,
|
||||
file_content,
|
||||
effective_token,
|
||||
)
|
||||
|
||||
if response.status_code in (200, 201):
|
||||
return True, "上传成功"
|
||||
|
||||
if response.status_code in (404, 405, 415):
|
||||
logger.warning("CPA multipart 上传失败,尝试原始 JSON 回退: %s", response.status_code)
|
||||
fallback_response = _post_cpa_auth_file_raw_json(
|
||||
upload_url,
|
||||
filename,
|
||||
file_content,
|
||||
effective_token,
|
||||
)
|
||||
if fallback_response.status_code in (200, 201):
|
||||
return True, "上传成功"
|
||||
response = fallback_response
|
||||
|
||||
return False, _extract_cpa_error(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CPA 上传异常: {e}")
|
||||
return False, f"上传异常: {str(e)}"
|
||||
|
||||
|
||||
def batch_upload_to_cpa(
|
||||
account_ids: List[int],
|
||||
proxy: str = None,
|
||||
api_url: str = None,
|
||||
api_token: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
批量上传账号到 CPA 管理平台
|
||||
|
||||
Args:
|
||||
account_ids: 账号 ID 列表
|
||||
proxy: 可选的代理 URL
|
||||
api_url: 指定 CPA API URL(优先于全局配置)
|
||||
api_token: 指定 CPA API Token(优先于全局配置)
|
||||
|
||||
Returns:
|
||||
包含成功/失败统计和详情的字典
|
||||
"""
|
||||
results = {
|
||||
"success_count": 0,
|
||||
"failed_count": 0,
|
||||
"skipped_count": 0,
|
||||
"details": []
|
||||
}
|
||||
|
||||
with get_db() as db:
|
||||
for account_id in account_ids:
|
||||
account = db.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if not account:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append({
|
||||
"id": account_id,
|
||||
"email": None,
|
||||
"success": False,
|
||||
"error": "账号不存在"
|
||||
})
|
||||
continue
|
||||
|
||||
# 检查是否已有 Token
|
||||
if not account.access_token:
|
||||
results["skipped_count"] += 1
|
||||
results["details"].append({
|
||||
"id": account_id,
|
||||
"email": account.email,
|
||||
"success": False,
|
||||
"error": "缺少 Token"
|
||||
})
|
||||
continue
|
||||
|
||||
# 生成 Token JSON
|
||||
token_data = generate_token_json(account)
|
||||
|
||||
# 上传
|
||||
success, message = upload_to_cpa(token_data, proxy, api_url=api_url, api_token=api_token)
|
||||
|
||||
if success:
|
||||
# 更新数据库状态
|
||||
account.cpa_uploaded = True
|
||||
account.cpa_uploaded_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
results["success_count"] += 1
|
||||
results["details"].append({
|
||||
"id": account_id,
|
||||
"email": account.email,
|
||||
"success": True,
|
||||
"message": message
|
||||
})
|
||||
else:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append({
|
||||
"id": account_id,
|
||||
"email": account.email,
|
||||
"success": False,
|
||||
"error": message
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_cpa_connection(api_url: str, api_token: str, proxy: str = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 CPA 连接(不走代理)
|
||||
|
||||
Args:
|
||||
api_url: CPA API URL
|
||||
api_token: CPA API Token
|
||||
proxy: 保留参数,不使用(CPA 始终直连)
|
||||
|
||||
Returns:
|
||||
(成功标志, 消息)
|
||||
"""
|
||||
if not api_url:
|
||||
return False, "API URL 不能为空"
|
||||
|
||||
if not api_token:
|
||||
return False, "API Token 不能为空"
|
||||
|
||||
test_url = _normalize_cpa_auth_files_url(api_url)
|
||||
headers = _build_cpa_headers(api_token)
|
||||
|
||||
try:
|
||||
response = cffi_requests.get(
|
||||
test_url,
|
||||
headers=headers,
|
||||
proxies=None,
|
||||
timeout=10,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return True, "CPA 连接测试成功"
|
||||
if response.status_code == 401:
|
||||
return False, "连接成功,但 API Token 无效"
|
||||
if response.status_code == 403:
|
||||
return False, "连接成功,但服务端未启用远程管理或当前 Token 无权限"
|
||||
if response.status_code == 404:
|
||||
return False, "未找到 CPA auth-files 接口,请检查 API URL 是否填写为根地址、/v0/management 或完整 auth-files 地址"
|
||||
if response.status_code == 503:
|
||||
return False, "连接成功,但服务端认证管理器不可用"
|
||||
|
||||
return False, f"服务器返回异常状态码: {response.status_code}"
|
||||
|
||||
except cffi_requests.exceptions.ConnectionError as e:
|
||||
return False, f"无法连接到服务器: {str(e)}"
|
||||
except cffi_requests.exceptions.Timeout:
|
||||
return False, "连接超时,请检查网络配置"
|
||||
except Exception as e:
|
||||
return False, f"连接测试失败: {str(e)}"
|
||||
224
src/core/upload/sub2api_upload.py
Normal file
224
src/core/upload/sub2api_upload.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Sub2API 账号上传功能
|
||||
将账号以 sub2api-data 格式批量导入到 Sub2API 平台
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
from ...database.session import get_db
|
||||
from ...database.models import Account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upload_to_sub2api(
|
||||
accounts: List[Account],
|
||||
api_url: str,
|
||||
api_key: str,
|
||||
concurrency: int = 3,
|
||||
priority: int = 50,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
上传账号列表到 Sub2API 平台(不走代理)
|
||||
|
||||
Args:
|
||||
accounts: 账号模型实例列表
|
||||
api_url: Sub2API 地址,如 http://host
|
||||
api_key: Admin API Key(x-api-key header)
|
||||
concurrency: 账号并发数,默认 3
|
||||
priority: 账号优先级,默认 50
|
||||
|
||||
Returns:
|
||||
(成功标志, 消息)
|
||||
"""
|
||||
if not accounts:
|
||||
return False, "无可上传的账号"
|
||||
|
||||
if not api_url:
|
||||
return False, "Sub2API URL 未配置"
|
||||
|
||||
if not api_key:
|
||||
return False, "Sub2API API Key 未配置"
|
||||
|
||||
exported_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
account_items = []
|
||||
for acc in accounts:
|
||||
if not acc.access_token:
|
||||
continue
|
||||
expires_at = int(acc.expires_at.timestamp()) if acc.expires_at else 0
|
||||
account_items.append({
|
||||
"name": acc.email,
|
||||
"platform": "openai",
|
||||
"type": "oauth",
|
||||
"credentials": {
|
||||
"access_token": acc.access_token,
|
||||
"chatgpt_account_id": acc.account_id or "",
|
||||
"chatgpt_user_id": "",
|
||||
"client_id": acc.client_id or "",
|
||||
"expires_at": expires_at,
|
||||
"expires_in": 863999,
|
||||
"model_mapping": {
|
||||
"gpt-5.1": "gpt-5.1",
|
||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
"gpt-5.2-codex": "gpt-5.2-codex",
|
||||
"gpt-5.3": "gpt-5.3",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.4": "gpt-5.4"
|
||||
},
|
||||
"organization_id": acc.workspace_id or "",
|
||||
"refresh_token": acc.refresh_token or "",
|
||||
},
|
||||
"extra": {},
|
||||
"concurrency": concurrency,
|
||||
"priority": priority,
|
||||
"rate_multiplier": 1,
|
||||
"auto_pause_on_expired": True,
|
||||
})
|
||||
|
||||
if not account_items:
|
||||
return False, "所有账号均缺少 access_token,无法上传"
|
||||
|
||||
payload = {
|
||||
"data": {
|
||||
"type": "sub2api-data",
|
||||
"version": 1,
|
||||
"exported_at": exported_at,
|
||||
"proxies": [],
|
||||
"accounts": account_items,
|
||||
},
|
||||
"skip_default_group_bind": True,
|
||||
}
|
||||
|
||||
url = api_url.rstrip("/") + "/api/v1/admin/accounts/data"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": api_key,
|
||||
"Idempotency-Key": f"import-{exported_at}",
|
||||
}
|
||||
|
||||
try:
|
||||
response = cffi_requests.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
proxies=None,
|
||||
timeout=30,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
if response.status_code in (200, 201):
|
||||
return True, f"成功上传 {len(account_items)} 个账号"
|
||||
|
||||
error_msg = f"上传失败: HTTP {response.status_code}"
|
||||
try:
|
||||
detail = response.json()
|
||||
if isinstance(detail, dict):
|
||||
error_msg = detail.get("message", error_msg)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
return False, error_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Sub2API 上传异常: {e}")
|
||||
return False, f"上传异常: {str(e)}"
|
||||
|
||||
|
||||
def batch_upload_to_sub2api(
|
||||
account_ids: List[int],
|
||||
api_url: str,
|
||||
api_key: str,
|
||||
concurrency: int = 3,
|
||||
priority: int = 50,
|
||||
) -> dict:
|
||||
"""
|
||||
批量上传指定 ID 的账号到 Sub2API 平台
|
||||
|
||||
Returns:
|
||||
包含成功/失败/跳过统计和详情的字典
|
||||
"""
|
||||
results = {
|
||||
"success_count": 0,
|
||||
"failed_count": 0,
|
||||
"skipped_count": 0,
|
||||
"details": []
|
||||
}
|
||||
|
||||
with get_db() as db:
|
||||
accounts = []
|
||||
for account_id in account_ids:
|
||||
acc = db.query(Account).filter(Account.id == account_id).first()
|
||||
if not acc:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append({"id": account_id, "email": None, "success": False, "error": "账号不存在"})
|
||||
continue
|
||||
if not acc.access_token:
|
||||
results["skipped_count"] += 1
|
||||
results["details"].append({"id": account_id, "email": acc.email, "success": False, "error": "缺少 access_token"})
|
||||
continue
|
||||
accounts.append(acc)
|
||||
|
||||
if not accounts:
|
||||
return results
|
||||
|
||||
success, message = upload_to_sub2api(accounts, api_url, api_key, concurrency, priority)
|
||||
|
||||
if success:
|
||||
for acc in accounts:
|
||||
results["success_count"] += 1
|
||||
results["details"].append({"id": acc.id, "email": acc.email, "success": True, "message": message})
|
||||
else:
|
||||
for acc in accounts:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append({"id": acc.id, "email": acc.email, "success": False, "error": message})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_sub2api_connection(api_url: str, api_key: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 Sub2API 连接(GET /api/v1/admin/accounts/data 探活)
|
||||
|
||||
Returns:
|
||||
(成功标志, 消息)
|
||||
"""
|
||||
if not api_url:
|
||||
return False, "API URL 不能为空"
|
||||
if not api_key:
|
||||
return False, "API Key 不能为空"
|
||||
|
||||
url = api_url.rstrip("/") + "/api/v1/admin/accounts/data"
|
||||
headers = {"x-api-key": api_key}
|
||||
|
||||
try:
|
||||
response = cffi_requests.get(
|
||||
url,
|
||||
headers=headers,
|
||||
proxies=None,
|
||||
timeout=10,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
if response.status_code in (200, 201, 204, 405):
|
||||
return True, "Sub2API 连接测试成功"
|
||||
if response.status_code == 401:
|
||||
return False, "连接成功,但 API Key 无效"
|
||||
if response.status_code == 403:
|
||||
return False, "连接成功,但权限不足"
|
||||
|
||||
return False, f"服务器返回异常状态码: {response.status_code}"
|
||||
|
||||
except cffi_requests.exceptions.ConnectionError as e:
|
||||
return False, f"无法连接到服务器: {str(e)}"
|
||||
except cffi_requests.exceptions.Timeout:
|
||||
return False, "连接超时,请检查网络配置"
|
||||
except Exception as e:
|
||||
return False, f"连接测试失败: {str(e)}"
|
||||
204
src/core/upload/team_manager_upload.py
Normal file
204
src/core/upload/team_manager_upload.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Team Manager 上传功能
|
||||
参照 CPA 上传模式,直连不走代理
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
from ...database.models import Account
|
||||
from ...database.session import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upload_to_team_manager(
|
||||
account: Account,
|
||||
api_url: str,
|
||||
api_key: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
上传单账号到 Team Manager(直连,不走代理)
|
||||
|
||||
Returns:
|
||||
(成功标志, 消息)
|
||||
"""
|
||||
if not api_url:
|
||||
return False, "Team Manager API URL 未配置"
|
||||
if not api_key:
|
||||
return False, "Team Manager API Key 未配置"
|
||||
if not account.access_token:
|
||||
return False, "账号缺少 access_token"
|
||||
|
||||
url = api_url.rstrip("/") + "/admin/teams/import"
|
||||
headers = {
|
||||
"X-API-Key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"import_type": "single",
|
||||
"email": account.email,
|
||||
"access_token": account.access_token or "",
|
||||
"session_token": account.session_token or "",
|
||||
"refresh_token": account.refresh_token or "",
|
||||
"client_id": account.client_id or "",
|
||||
"account_id": account.account_id or "",
|
||||
}
|
||||
|
||||
try:
|
||||
resp = cffi_requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
proxies=None,
|
||||
timeout=30
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
return True, "上传成功"
|
||||
error_msg = f"上传失败: HTTP {resp.status_code}"
|
||||
try:
|
||||
detail = resp.json()
|
||||
if isinstance(detail, dict):
|
||||
error_msg = detail.get("message", error_msg)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {resp.text[:200]}"
|
||||
return False, error_msg
|
||||
except Exception as e:
|
||||
logger.error(f"Team Manager 上传异常: {e}")
|
||||
return False, f"上传异常: {str(e)}"
|
||||
|
||||
|
||||
def batch_upload_to_team_manager(
|
||||
account_ids: List[int],
|
||||
api_url: str,
|
||||
api_key: str,
|
||||
) -> dict:
|
||||
"""
|
||||
批量上传账号到 Team Manager(使用 batch 模式,一次请求提交所有账号)
|
||||
|
||||
Returns:
|
||||
包含成功/失败统计和详情的字典
|
||||
"""
|
||||
results = {
|
||||
"success_count": 0,
|
||||
"failed_count": 0,
|
||||
"skipped_count": 0,
|
||||
"details": [],
|
||||
}
|
||||
|
||||
with get_db() as db:
|
||||
lines = []
|
||||
valid_accounts = []
|
||||
for account_id in account_ids:
|
||||
account = db.query(Account).filter(Account.id == account_id).first()
|
||||
if not account:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account_id, "email": None, "success": False, "error": "账号不存在"}
|
||||
)
|
||||
continue
|
||||
if not account.access_token:
|
||||
results["skipped_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account_id, "email": account.email, "success": False, "error": "缺少 Token"}
|
||||
)
|
||||
continue
|
||||
# 格式:邮箱,AT,RT,ST,ClientID
|
||||
lines.append(",".join([
|
||||
account.email or "",
|
||||
account.access_token or "",
|
||||
account.refresh_token or "",
|
||||
account.session_token or "",
|
||||
account.client_id or "",
|
||||
]))
|
||||
valid_accounts.append(account)
|
||||
|
||||
if not valid_accounts:
|
||||
return results
|
||||
|
||||
url = api_url.rstrip("/") + "/admin/teams/import"
|
||||
headers = {
|
||||
"X-API-Key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"import_type": "batch",
|
||||
"content": "\n".join(lines),
|
||||
}
|
||||
|
||||
try:
|
||||
resp = cffi_requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
proxies=None,
|
||||
timeout=60,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
for account in valid_accounts:
|
||||
results["success_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account.id, "email": account.email, "success": True, "message": "批量上传成功"}
|
||||
)
|
||||
else:
|
||||
error_msg = f"批量上传失败: HTTP {resp.status_code}"
|
||||
try:
|
||||
detail = resp.json()
|
||||
if isinstance(detail, dict):
|
||||
error_msg = detail.get("message", error_msg)
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {resp.text[:200]}"
|
||||
for account in valid_accounts:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account.id, "email": account.email, "success": False, "error": error_msg}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Team Manager 批量上传异常: {e}")
|
||||
error_msg = f"上传异常: {str(e)}"
|
||||
for account in valid_accounts:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account.id, "email": account.email, "success": False, "error": error_msg}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_team_manager_connection(api_url: str, api_key: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 Team Manager 连接(直连)
|
||||
|
||||
Returns:
|
||||
(成功标志, 消息)
|
||||
"""
|
||||
if not api_url:
|
||||
return False, "API URL 不能为空"
|
||||
if not api_key:
|
||||
return False, "API Key 不能为空"
|
||||
|
||||
url = api_url.rstrip("/") + "/admin/teams/import"
|
||||
headers = {"X-API-Key": api_key}
|
||||
|
||||
try:
|
||||
resp = cffi_requests.options(
|
||||
url,
|
||||
headers=headers,
|
||||
proxies=None,
|
||||
timeout=10,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
if resp.status_code in (200, 204, 401, 403, 405):
|
||||
if resp.status_code == 401:
|
||||
return False, "连接成功,但 API Key 无效"
|
||||
return True, "Team Manager 连接测试成功"
|
||||
return False, f"服务器返回异常状态码: {resp.status_code}"
|
||||
except cffi_requests.exceptions.ConnectionError as e:
|
||||
return False, f"无法连接到服务器: {str(e)}"
|
||||
except cffi_requests.exceptions.Timeout:
|
||||
return False, "连接超时,请检查网络配置"
|
||||
except Exception as e:
|
||||
return False, f"连接测试失败: {str(e)}"
|
||||
570
src/core/utils.py
Normal file
570
src/core/utils.py
Normal file
@@ -0,0 +1,570 @@
|
||||
"""
|
||||
通用工具函数
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
import secrets
|
||||
import hashlib
|
||||
import logging
|
||||
import base64
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Union, Callable
|
||||
from pathlib import Path
|
||||
|
||||
from ..config.constants import PASSWORD_CHARSET, DEFAULT_PASSWORD_LENGTH
|
||||
from ..config.settings import get_settings
|
||||
|
||||
|
||||
def setup_logging(
|
||||
log_level: str = "INFO",
|
||||
log_file: Optional[str] = None,
|
||||
log_format: str = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
配置日志系统
|
||||
|
||||
Args:
|
||||
log_level: 日志级别 (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
log_file: 日志文件路径,如果不指定则只输出到控制台
|
||||
log_format: 日志格式
|
||||
|
||||
Returns:
|
||||
根日志记录器
|
||||
"""
|
||||
# 设置日志级别
|
||||
numeric_level = getattr(logging, log_level.upper(), None)
|
||||
if not isinstance(numeric_level, int):
|
||||
numeric_level = logging.INFO
|
||||
|
||||
# 配置根日志记录器
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(numeric_level)
|
||||
|
||||
# 清除现有的处理器
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# 创建格式化器
|
||||
formatter = logging.Formatter(log_format)
|
||||
|
||||
# 控制台处理器
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(numeric_level)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器(如果指定了日志文件)
|
||||
if log_file:
|
||||
# 确保日志目录存在
|
||||
log_dir = os.path.dirname(log_file)
|
||||
if log_dir:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(numeric_level)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
return root_logger
|
||||
|
||||
|
||||
def generate_password(length: int = DEFAULT_PASSWORD_LENGTH) -> str:
|
||||
"""
|
||||
生成随机密码
|
||||
|
||||
Args:
|
||||
length: 密码长度
|
||||
|
||||
Returns:
|
||||
随机密码字符串
|
||||
"""
|
||||
if length < 4:
|
||||
length = 4
|
||||
|
||||
# 确保密码包含至少一个大写字母、一个小写字母和一个数字
|
||||
password = [
|
||||
secrets.choice(string.ascii_lowercase),
|
||||
secrets.choice(string.ascii_uppercase),
|
||||
secrets.choice(string.digits),
|
||||
]
|
||||
|
||||
# 添加剩余字符
|
||||
password.extend(secrets.choice(PASSWORD_CHARSET) for _ in range(length - 3))
|
||||
|
||||
# 随机打乱
|
||||
secrets.SystemRandom().shuffle(password)
|
||||
|
||||
return ''.join(password)
|
||||
|
||||
|
||||
def generate_random_string(length: int = 8) -> str:
|
||||
"""
|
||||
生成随机字符串(仅字母)
|
||||
|
||||
Args:
|
||||
length: 字符串长度
|
||||
|
||||
Returns:
|
||||
随机字符串
|
||||
"""
|
||||
chars = string.ascii_letters
|
||||
return ''.join(secrets.choice(chars) for _ in range(length))
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""生成 UUID 字符串"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def get_timestamp() -> int:
|
||||
"""获取当前时间戳(秒)"""
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def format_datetime(dt: Optional[datetime] = None, fmt: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""
|
||||
格式化日期时间
|
||||
|
||||
Args:
|
||||
dt: 日期时间对象,如果为 None 则使用当前时间
|
||||
fmt: 格式字符串
|
||||
|
||||
Returns:
|
||||
格式化后的字符串
|
||||
"""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return dt.strftime(fmt)
|
||||
|
||||
|
||||
def parse_datetime(dt_str: str, fmt: str = "%Y-%m-%d %H:%M:%S") -> Optional[datetime]:
|
||||
"""
|
||||
解析日期时间字符串
|
||||
|
||||
Args:
|
||||
dt_str: 日期时间字符串
|
||||
fmt: 格式字符串
|
||||
|
||||
Returns:
|
||||
日期时间对象,如果解析失败返回 None
|
||||
"""
|
||||
try:
|
||||
return datetime.strptime(dt_str, fmt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def human_readable_size(size_bytes: int) -> str:
|
||||
"""
|
||||
将字节大小转换为人类可读的格式
|
||||
|
||||
Args:
|
||||
size_bytes: 字节大小
|
||||
|
||||
Returns:
|
||||
人类可读的字符串
|
||||
"""
|
||||
if size_bytes < 0:
|
||||
return "0 B"
|
||||
|
||||
units = ["B", "KB", "MB", "GB", "TB", "PB"]
|
||||
unit_index = 0
|
||||
|
||||
while size_bytes >= 1024 and unit_index < len(units) - 1:
|
||||
size_bytes /= 1024
|
||||
unit_index += 1
|
||||
|
||||
return f"{size_bytes:.2f} {units[unit_index]}"
|
||||
|
||||
|
||||
def retry_with_backoff(
|
||||
func: Callable,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 1.0,
|
||||
max_delay: float = 30.0,
|
||||
backoff_factor: float = 2.0,
|
||||
exceptions: tuple = (Exception,)
|
||||
) -> Any:
|
||||
"""
|
||||
带有指数退避的重试装饰器/函数
|
||||
|
||||
Args:
|
||||
func: 要重试的函数
|
||||
max_retries: 最大重试次数
|
||||
base_delay: 基础延迟(秒)
|
||||
max_delay: 最大延迟(秒)
|
||||
backoff_factor: 退避因子
|
||||
exceptions: 要捕获的异常类型
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
|
||||
Raises:
|
||||
最后一次尝试的异常
|
||||
"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return func()
|
||||
except exceptions as e:
|
||||
last_exception = e
|
||||
|
||||
# 如果是最后一次尝试,直接抛出异常
|
||||
if attempt == max_retries:
|
||||
break
|
||||
|
||||
# 计算延迟时间
|
||||
delay = min(base_delay * (backoff_factor ** attempt), max_delay)
|
||||
|
||||
# 添加随机抖动
|
||||
delay *= (0.5 + random.random())
|
||||
|
||||
# 记录日志
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"尝试 {func.__name__} 失败 (attempt {attempt + 1}/{max_retries + 1}): {e}. "
|
||||
f"等待 {delay:.2f} 秒后重试..."
|
||||
)
|
||||
|
||||
time.sleep(delay)
|
||||
|
||||
# 所有重试都失败,抛出最后一个异常
|
||||
raise last_exception
|
||||
|
||||
|
||||
class RetryDecorator:
|
||||
"""重试装饰器类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 1.0,
|
||||
max_delay: float = 30.0,
|
||||
backoff_factor: float = 2.0,
|
||||
exceptions: tuple = (Exception,)
|
||||
):
|
||||
self.max_retries = max_retries
|
||||
self.base_delay = base_delay
|
||||
self.max_delay = max_delay
|
||||
self.backoff_factor = backoff_factor
|
||||
self.exceptions = exceptions
|
||||
|
||||
def __call__(self, func: Callable) -> Callable:
|
||||
"""装饰器调用"""
|
||||
def wrapper(*args, **kwargs):
|
||||
def func_to_retry():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return retry_with_backoff(
|
||||
func_to_retry,
|
||||
max_retries=self.max_retries,
|
||||
base_delay=self.base_delay,
|
||||
max_delay=self.max_delay,
|
||||
backoff_factor=self.backoff_factor,
|
||||
exceptions=self.exceptions
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def validate_email(email: str) -> bool:
|
||||
"""
|
||||
验证邮箱地址格式
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
|
||||
"""
|
||||
验证 URL 格式
|
||||
|
||||
Args:
|
||||
url: URL
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
|
||||
return bool(re.match(pattern, url))
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
清理文件名,移除不安全的字符
|
||||
|
||||
Args:
|
||||
filename: 原始文件名
|
||||
|
||||
Returns:
|
||||
清理后的文件名
|
||||
"""
|
||||
# 移除危险字符
|
||||
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
|
||||
# 移除控制字符
|
||||
filename = ''.join(char for char in filename if ord(char) >= 32)
|
||||
# 限制长度
|
||||
if len(filename) > 255:
|
||||
name, ext = os.path.splitext(filename)
|
||||
filename = name[:255 - len(ext)] + ext
|
||||
return filename
|
||||
|
||||
|
||||
def read_json_file(filepath: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取 JSON 文件
|
||||
|
||||
Args:
|
||||
filepath: 文件路径
|
||||
|
||||
Returns:
|
||||
JSON 数据,如果读取失败返回 None
|
||||
"""
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError, IOError) as e:
|
||||
logging.getLogger(__name__).warning(f"读取 JSON 文件失败: {filepath} - {e}")
|
||||
return None
|
||||
|
||||
|
||||
def write_json_file(filepath: str, data: Dict[str, Any], indent: int = 2) -> bool:
|
||||
"""
|
||||
写入 JSON 文件
|
||||
|
||||
Args:
|
||||
filepath: 文件路径
|
||||
data: 要写入的数据
|
||||
indent: 缩进空格数
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=indent)
|
||||
|
||||
return True
|
||||
except (IOError, TypeError) as e:
|
||||
logging.getLogger(__name__).error(f"写入 JSON 文件失败: {filepath} - {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""
|
||||
获取项目根目录
|
||||
|
||||
Returns:
|
||||
项目根目录 Path 对象
|
||||
"""
|
||||
# 当前文件所在目录
|
||||
current_dir = Path(__file__).parent
|
||||
|
||||
# 向上查找直到找到项目根目录(包含 pyproject.toml 或 setup.py)
|
||||
for parent in [current_dir] + list(current_dir.parents):
|
||||
if (parent / "pyproject.toml").exists() or (parent / "setup.py").exists():
|
||||
return parent
|
||||
|
||||
# 如果找不到,返回当前目录的父目录
|
||||
return current_dir.parent
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""
|
||||
获取数据目录
|
||||
|
||||
Returns:
|
||||
数据目录 Path 对象
|
||||
"""
|
||||
settings = get_settings()
|
||||
if not settings.database_url.startswith("sqlite"):
|
||||
data_dir = Path(os.environ.get("APP_DATA_DIR", "data"))
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
return data_dir
|
||||
data_dir = Path(settings.database_url).parent
|
||||
|
||||
# 如果 database_url 是 SQLite URL,提取路径
|
||||
if settings.database_url.startswith("sqlite:///"):
|
||||
db_path = settings.database_url[10:] # 移除 "sqlite:///"
|
||||
data_dir = Path(db_path).parent
|
||||
|
||||
# 确保目录存在
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return data_dir
|
||||
|
||||
|
||||
def get_logs_dir() -> Path:
|
||||
"""
|
||||
获取日志目录
|
||||
|
||||
Returns:
|
||||
日志目录 Path 对象
|
||||
"""
|
||||
settings = get_settings()
|
||||
log_file = Path(settings.log_file)
|
||||
log_dir = log_file.parent
|
||||
|
||||
# 确保目录存在
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return log_dir
|
||||
|
||||
|
||||
def format_duration(seconds: int) -> str:
|
||||
"""
|
||||
格式化持续时间
|
||||
|
||||
Args:
|
||||
seconds: 秒数
|
||||
|
||||
Returns:
|
||||
格式化的持续时间字符串
|
||||
"""
|
||||
if seconds < 60:
|
||||
return f"{seconds}秒"
|
||||
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
if minutes < 60:
|
||||
return f"{minutes}分{seconds}秒"
|
||||
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
if hours < 24:
|
||||
return f"{hours}小时{minutes}分"
|
||||
|
||||
days, hours = divmod(hours, 24)
|
||||
return f"{days}天{hours}小时"
|
||||
|
||||
|
||||
def mask_sensitive_data(data: Union[str, Dict, List], mask_char: str = "*") -> Union[str, Dict, List]:
|
||||
"""
|
||||
掩码敏感数据
|
||||
|
||||
Args:
|
||||
data: 要掩码的数据
|
||||
mask_char: 掩码字符
|
||||
|
||||
Returns:
|
||||
掩码后的数据
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
# 如果是邮箱,掩码中间部分
|
||||
if "@" in data:
|
||||
local, domain = data.split("@", 1)
|
||||
if len(local) > 2:
|
||||
masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]
|
||||
else:
|
||||
masked_local = mask_char * len(local)
|
||||
return f"{masked_local}@{domain}"
|
||||
|
||||
# 如果是 token 或密钥,掩码大部分内容
|
||||
if len(data) > 10:
|
||||
return data[:4] + mask_char * (len(data) - 8) + data[-4:]
|
||||
return mask_char * len(data)
|
||||
|
||||
elif isinstance(data, dict):
|
||||
masked_dict = {}
|
||||
for key, value in data.items():
|
||||
# 敏感字段名
|
||||
sensitive_keys = ["password", "token", "secret", "key", "auth", "credential"]
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys):
|
||||
masked_dict[key] = mask_sensitive_data(value, mask_char)
|
||||
else:
|
||||
masked_dict[key] = value
|
||||
return masked_dict
|
||||
|
||||
elif isinstance(data, list):
|
||||
return [mask_sensitive_data(item, mask_char) for item in data]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def calculate_md5(data: Union[str, bytes]) -> str:
|
||||
"""
|
||||
计算 MD5 哈希
|
||||
|
||||
Args:
|
||||
data: 要哈希的数据
|
||||
|
||||
Returns:
|
||||
MD5 哈希字符串
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
|
||||
def calculate_sha256(data: Union[str, bytes]) -> str:
|
||||
"""
|
||||
计算 SHA256 哈希
|
||||
|
||||
Args:
|
||||
data: 要哈希的数据
|
||||
|
||||
Returns:
|
||||
SHA256 哈希字符串
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
return hashlib.sha256(data).hexdigest()
|
||||
|
||||
|
||||
def base64_encode(data: Union[str, bytes]) -> str:
|
||||
"""Base64 编码"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
return base64.b64encode(data).decode('utf-8')
|
||||
|
||||
|
||||
def base64_decode(data: str) -> str:
|
||||
"""Base64 解码"""
|
||||
try:
|
||||
decoded = base64.b64decode(data)
|
||||
return decoded.decode('utf-8')
|
||||
except (base64.binascii.Error, UnicodeDecodeError):
|
||||
return ""
|
||||
|
||||
|
||||
class Timer:
|
||||
"""计时器上下文管理器"""
|
||||
|
||||
def __init__(self, name: str = "操作"):
|
||||
self.name = name
|
||||
self.start_time = None
|
||||
self.elapsed = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.elapsed = time.time() - self.start_time
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.debug(f"{self.name} 耗时: {self.elapsed:.2f} 秒")
|
||||
|
||||
def get_elapsed(self) -> float:
|
||||
"""获取经过的时间(秒)"""
|
||||
if self.elapsed is not None:
|
||||
return self.elapsed
|
||||
if self.start_time is not None:
|
||||
return time.time() - self.start_time
|
||||
return 0.0
|
||||
20
src/database/__init__.py
Normal file
20
src/database/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
数据库模块
|
||||
"""
|
||||
|
||||
from .models import Base, Account, EmailService, RegistrationTask, Setting
|
||||
from .session import get_db, init_database, get_session_manager, DatabaseSessionManager
|
||||
from . import crud
|
||||
|
||||
__all__ = [
|
||||
'Base',
|
||||
'Account',
|
||||
'EmailService',
|
||||
'RegistrationTask',
|
||||
'Setting',
|
||||
'get_db',
|
||||
'init_database',
|
||||
'get_session_manager',
|
||||
'DatabaseSessionManager',
|
||||
'crud',
|
||||
]
|
||||
714
src/database/crud.py
Normal file
714
src/database/crud.py
Normal file
@@ -0,0 +1,714 @@
|
||||
"""
|
||||
数据库 CRUD 操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_, desc, asc, func
|
||||
|
||||
from .models import Account, EmailService, RegistrationTask, Setting, Proxy, CpaService, Sub2ApiService
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 账户 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_account(
|
||||
db: Session,
|
||||
email: str,
|
||||
email_service: str,
|
||||
password: Optional[str] = None,
|
||||
client_id: Optional[str] = None,
|
||||
session_token: Optional[str] = None,
|
||||
email_service_id: Optional[str] = None,
|
||||
account_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
id_token: Optional[str] = None,
|
||||
proxy_used: Optional[str] = None,
|
||||
expires_at: Optional['datetime'] = None,
|
||||
extra_data: Optional[Dict[str, Any]] = None,
|
||||
status: Optional[str] = None,
|
||||
source: Optional[str] = None
|
||||
) -> Account:
|
||||
"""创建新账户"""
|
||||
db_account = Account(
|
||||
email=email,
|
||||
password=password,
|
||||
client_id=client_id,
|
||||
session_token=session_token,
|
||||
email_service=email_service,
|
||||
email_service_id=email_service_id,
|
||||
account_id=account_id,
|
||||
workspace_id=workspace_id,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
id_token=id_token,
|
||||
proxy_used=proxy_used,
|
||||
expires_at=expires_at,
|
||||
extra_data=extra_data or {},
|
||||
status=status or 'active',
|
||||
source=source or 'register',
|
||||
registered_at=datetime.utcnow()
|
||||
)
|
||||
db.add(db_account)
|
||||
db.commit()
|
||||
db.refresh(db_account)
|
||||
return db_account
|
||||
|
||||
|
||||
def get_account_by_id(db: Session, account_id: int) -> Optional[Account]:
|
||||
"""根据 ID 获取账户"""
|
||||
return db.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
|
||||
def get_account_by_email(db: Session, email: str) -> Optional[Account]:
|
||||
"""根据邮箱获取账户"""
|
||||
return db.query(Account).filter(Account.email == email).first()
|
||||
|
||||
|
||||
def get_accounts(
|
||||
db: Session,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
email_service: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[Account]:
|
||||
"""获取账户列表(支持分页、筛选)"""
|
||||
query = db.query(Account)
|
||||
|
||||
if email_service:
|
||||
query = query.filter(Account.email_service == email_service)
|
||||
|
||||
if status:
|
||||
query = query.filter(Account.status == status)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Account.email.ilike(f"%{search}%"),
|
||||
Account.account_id.ilike(f"%{search}%"),
|
||||
Account.workspace_id.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.filter(search_filter)
|
||||
|
||||
query = query.order_by(desc(Account.created_at)).offset(skip).limit(limit)
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_account(
|
||||
db: Session,
|
||||
account_id: int,
|
||||
**kwargs
|
||||
) -> Optional[Account]:
|
||||
"""更新账户信息"""
|
||||
db_account = get_account_by_id(db, account_id)
|
||||
if not db_account:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_account, key) and value is not None:
|
||||
setattr(db_account, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_account)
|
||||
return db_account
|
||||
|
||||
|
||||
def delete_account(db: Session, account_id: int) -> bool:
|
||||
"""删除账户"""
|
||||
db_account = get_account_by_id(db, account_id)
|
||||
if not db_account:
|
||||
return False
|
||||
|
||||
db.delete(db_account)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def delete_accounts_batch(db: Session, account_ids: List[int]) -> int:
|
||||
"""批量删除账户"""
|
||||
result = db.query(Account).filter(Account.id.in_(account_ids)).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
return result
|
||||
|
||||
|
||||
def get_accounts_count(
|
||||
db: Session,
|
||||
email_service: Optional[str] = None,
|
||||
status: Optional[str] = None
|
||||
) -> int:
|
||||
"""获取账户数量"""
|
||||
query = db.query(func.count(Account.id))
|
||||
|
||||
if email_service:
|
||||
query = query.filter(Account.email_service == email_service)
|
||||
|
||||
if status:
|
||||
query = query.filter(Account.status == status)
|
||||
|
||||
return query.scalar()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 邮箱服务 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_email_service(
|
||||
db: Session,
|
||||
service_type: str,
|
||||
name: str,
|
||||
config: Dict[str, Any],
|
||||
enabled: bool = True,
|
||||
priority: int = 0
|
||||
) -> EmailService:
|
||||
"""创建邮箱服务配置"""
|
||||
db_service = EmailService(
|
||||
service_type=service_type,
|
||||
name=name,
|
||||
config=config,
|
||||
enabled=enabled,
|
||||
priority=priority
|
||||
)
|
||||
db.add(db_service)
|
||||
db.commit()
|
||||
db.refresh(db_service)
|
||||
return db_service
|
||||
|
||||
|
||||
def get_email_service_by_id(db: Session, service_id: int) -> Optional[EmailService]:
|
||||
"""根据 ID 获取邮箱服务"""
|
||||
return db.query(EmailService).filter(EmailService.id == service_id).first()
|
||||
|
||||
|
||||
def get_email_services(
|
||||
db: Session,
|
||||
service_type: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[EmailService]:
|
||||
"""获取邮箱服务列表"""
|
||||
query = db.query(EmailService)
|
||||
|
||||
if service_type:
|
||||
query = query.filter(EmailService.service_type == service_type)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.filter(EmailService.enabled == enabled)
|
||||
|
||||
query = query.order_by(
|
||||
asc(EmailService.priority),
|
||||
desc(EmailService.last_used)
|
||||
).offset(skip).limit(limit)
|
||||
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_email_service(
|
||||
db: Session,
|
||||
service_id: int,
|
||||
**kwargs
|
||||
) -> Optional[EmailService]:
|
||||
"""更新邮箱服务配置"""
|
||||
db_service = get_email_service_by_id(db, service_id)
|
||||
if not db_service:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_service, key) and value is not None:
|
||||
setattr(db_service, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_service)
|
||||
return db_service
|
||||
|
||||
|
||||
def delete_email_service(db: Session, service_id: int) -> bool:
|
||||
"""删除邮箱服务配置"""
|
||||
db_service = get_email_service_by_id(db, service_id)
|
||||
if not db_service:
|
||||
return False
|
||||
|
||||
db.delete(db_service)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 注册任务 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_registration_task(
|
||||
db: Session,
|
||||
task_uuid: str,
|
||||
email_service_id: Optional[int] = None,
|
||||
proxy: Optional[str] = None
|
||||
) -> RegistrationTask:
|
||||
"""创建注册任务"""
|
||||
db_task = RegistrationTask(
|
||||
task_uuid=task_uuid,
|
||||
email_service_id=email_service_id,
|
||||
proxy=proxy,
|
||||
status='pending'
|
||||
)
|
||||
db.add(db_task)
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
|
||||
|
||||
def get_registration_task_by_uuid(db: Session, task_uuid: str) -> Optional[RegistrationTask]:
|
||||
"""根据 UUID 获取注册任务"""
|
||||
return db.query(RegistrationTask).filter(RegistrationTask.task_uuid == task_uuid).first()
|
||||
|
||||
|
||||
def get_registration_tasks(
|
||||
db: Session,
|
||||
status: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[RegistrationTask]:
|
||||
"""获取注册任务列表"""
|
||||
query = db.query(RegistrationTask)
|
||||
|
||||
if status:
|
||||
query = query.filter(RegistrationTask.status == status)
|
||||
|
||||
query = query.order_by(desc(RegistrationTask.created_at)).offset(skip).limit(limit)
|
||||
return query.all()
|
||||
|
||||
|
||||
def update_registration_task(
|
||||
db: Session,
|
||||
task_uuid: str,
|
||||
**kwargs
|
||||
) -> Optional[RegistrationTask]:
|
||||
"""更新注册任务状态"""
|
||||
db_task = get_registration_task_by_uuid(db, task_uuid)
|
||||
if not db_task:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_task, key):
|
||||
setattr(db_task, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_task)
|
||||
return db_task
|
||||
|
||||
|
||||
def append_task_log(db: Session, task_uuid: str, log_message: str) -> bool:
|
||||
"""追加任务日志"""
|
||||
db_task = get_registration_task_by_uuid(db, task_uuid)
|
||||
if not db_task:
|
||||
return False
|
||||
|
||||
if db_task.logs:
|
||||
db_task.logs += f"\n{log_message}"
|
||||
else:
|
||||
db_task.logs = log_message
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def delete_registration_task(db: Session, task_uuid: str) -> bool:
|
||||
"""删除注册任务"""
|
||||
db_task = get_registration_task_by_uuid(db, task_uuid)
|
||||
if not db_task:
|
||||
return False
|
||||
|
||||
db.delete(db_task)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
# 为 API 路由添加别名
|
||||
get_account = get_account_by_id
|
||||
get_registration_task = get_registration_task_by_uuid
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 设置 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def get_setting(db: Session, key: str) -> Optional[Setting]:
|
||||
"""获取设置"""
|
||||
return db.query(Setting).filter(Setting.key == key).first()
|
||||
|
||||
|
||||
def get_settings_by_category(db: Session, category: str) -> List[Setting]:
|
||||
"""根据分类获取设置"""
|
||||
return db.query(Setting).filter(Setting.category == category).all()
|
||||
|
||||
|
||||
def set_setting(
|
||||
db: Session,
|
||||
key: str,
|
||||
value: str,
|
||||
description: Optional[str] = None,
|
||||
category: str = 'general'
|
||||
) -> Setting:
|
||||
"""设置或更新配置项"""
|
||||
db_setting = get_setting(db, key)
|
||||
if db_setting:
|
||||
db_setting.value = value
|
||||
db_setting.description = description or db_setting.description
|
||||
db_setting.category = category
|
||||
db_setting.updated_at = datetime.utcnow()
|
||||
else:
|
||||
db_setting = Setting(
|
||||
key=key,
|
||||
value=value,
|
||||
description=description,
|
||||
category=category
|
||||
)
|
||||
db.add(db_setting)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_setting)
|
||||
return db_setting
|
||||
|
||||
|
||||
def delete_setting(db: Session, key: str) -> bool:
|
||||
"""删除设置"""
|
||||
db_setting = get_setting(db, key)
|
||||
if not db_setting:
|
||||
return False
|
||||
|
||||
db.delete(db_setting)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 代理 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_proxy(
|
||||
db: Session,
|
||||
name: str,
|
||||
type: str,
|
||||
host: str,
|
||||
port: int,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
priority: int = 0
|
||||
) -> Proxy:
|
||||
"""创建代理配置"""
|
||||
db_proxy = Proxy(
|
||||
name=name,
|
||||
type=type,
|
||||
host=host,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
enabled=enabled,
|
||||
priority=priority
|
||||
)
|
||||
db.add(db_proxy)
|
||||
db.commit()
|
||||
db.refresh(db_proxy)
|
||||
return db_proxy
|
||||
|
||||
|
||||
def get_proxy_by_id(db: Session, proxy_id: int) -> Optional[Proxy]:
|
||||
"""根据 ID 获取代理"""
|
||||
return db.query(Proxy).filter(Proxy.id == proxy_id).first()
|
||||
|
||||
|
||||
def get_proxies(
|
||||
db: Session,
|
||||
enabled: Optional[bool] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[Proxy]:
|
||||
"""获取代理列表"""
|
||||
query = db.query(Proxy)
|
||||
|
||||
if enabled is not None:
|
||||
query = query.filter(Proxy.enabled == enabled)
|
||||
|
||||
query = query.order_by(desc(Proxy.created_at)).offset(skip).limit(limit)
|
||||
return query.all()
|
||||
|
||||
|
||||
def get_enabled_proxies(db: Session) -> List[Proxy]:
|
||||
"""获取所有启用的代理"""
|
||||
return db.query(Proxy).filter(Proxy.enabled == True).all()
|
||||
|
||||
|
||||
def update_proxy(
|
||||
db: Session,
|
||||
proxy_id: int,
|
||||
**kwargs
|
||||
) -> Optional[Proxy]:
|
||||
"""更新代理配置"""
|
||||
db_proxy = get_proxy_by_id(db, proxy_id)
|
||||
if not db_proxy:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_proxy, key):
|
||||
setattr(db_proxy, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_proxy)
|
||||
return db_proxy
|
||||
|
||||
|
||||
def delete_proxy(db: Session, proxy_id: int) -> bool:
|
||||
"""删除代理配置"""
|
||||
db_proxy = get_proxy_by_id(db, proxy_id)
|
||||
if not db_proxy:
|
||||
return False
|
||||
|
||||
db.delete(db_proxy)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def update_proxy_last_used(db: Session, proxy_id: int) -> bool:
|
||||
"""更新代理最后使用时间"""
|
||||
db_proxy = get_proxy_by_id(db, proxy_id)
|
||||
if not db_proxy:
|
||||
return False
|
||||
|
||||
db_proxy.last_used = datetime.utcnow()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
def get_random_proxy(db: Session) -> Optional[Proxy]:
|
||||
"""随机获取一个启用的代理,优先返回 is_default=True 的代理"""
|
||||
import random
|
||||
# 优先返回默认代理
|
||||
default_proxy = db.query(Proxy).filter(Proxy.enabled == True, Proxy.is_default == True).first()
|
||||
if default_proxy:
|
||||
return default_proxy
|
||||
proxies = get_enabled_proxies(db)
|
||||
if not proxies:
|
||||
return None
|
||||
return random.choice(proxies)
|
||||
|
||||
|
||||
def set_proxy_default(db: Session, proxy_id: int) -> Optional[Proxy]:
|
||||
"""将指定代理设为默认,同时清除其他代理的默认标记"""
|
||||
# 清除所有默认标记
|
||||
db.query(Proxy).filter(Proxy.is_default == True).update({"is_default": False})
|
||||
# 设置新的默认代理
|
||||
proxy = db.query(Proxy).filter(Proxy.id == proxy_id).first()
|
||||
if proxy:
|
||||
proxy.is_default = True
|
||||
db.commit()
|
||||
db.refresh(proxy)
|
||||
return proxy
|
||||
|
||||
|
||||
def get_proxies_count(db: Session, enabled: Optional[bool] = None) -> int:
|
||||
"""获取代理数量"""
|
||||
query = db.query(func.count(Proxy.id))
|
||||
if enabled is not None:
|
||||
query = query.filter(Proxy.enabled == enabled)
|
||||
return query.scalar()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CPA 服务 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_cpa_service(
|
||||
db: Session,
|
||||
name: str,
|
||||
api_url: str,
|
||||
api_token: str,
|
||||
enabled: bool = True,
|
||||
priority: int = 0
|
||||
) -> CpaService:
|
||||
"""创建 CPA 服务配置"""
|
||||
db_service = CpaService(
|
||||
name=name,
|
||||
api_url=api_url,
|
||||
api_token=api_token,
|
||||
enabled=enabled,
|
||||
priority=priority
|
||||
)
|
||||
db.add(db_service)
|
||||
db.commit()
|
||||
db.refresh(db_service)
|
||||
return db_service
|
||||
|
||||
|
||||
def get_cpa_service_by_id(db: Session, service_id: int) -> Optional[CpaService]:
|
||||
"""根据 ID 获取 CPA 服务"""
|
||||
return db.query(CpaService).filter(CpaService.id == service_id).first()
|
||||
|
||||
|
||||
def get_cpa_services(
|
||||
db: Session,
|
||||
enabled: Optional[bool] = None
|
||||
) -> List[CpaService]:
|
||||
"""获取 CPA 服务列表"""
|
||||
query = db.query(CpaService)
|
||||
if enabled is not None:
|
||||
query = query.filter(CpaService.enabled == enabled)
|
||||
return query.order_by(asc(CpaService.priority), asc(CpaService.id)).all()
|
||||
|
||||
|
||||
def update_cpa_service(
|
||||
db: Session,
|
||||
service_id: int,
|
||||
**kwargs
|
||||
) -> Optional[CpaService]:
|
||||
"""更新 CPA 服务配置"""
|
||||
db_service = get_cpa_service_by_id(db, service_id)
|
||||
if not db_service:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(db_service, key):
|
||||
setattr(db_service, key, value)
|
||||
db.commit()
|
||||
db.refresh(db_service)
|
||||
return db_service
|
||||
|
||||
|
||||
def delete_cpa_service(db: Session, service_id: int) -> bool:
|
||||
"""删除 CPA 服务配置"""
|
||||
db_service = get_cpa_service_by_id(db, service_id)
|
||||
if not db_service:
|
||||
return False
|
||||
db.delete(db_service)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Sub2API 服务 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_sub2api_service(
|
||||
db: Session,
|
||||
name: str,
|
||||
api_url: str,
|
||||
api_key: str,
|
||||
enabled: bool = True,
|
||||
priority: int = 0
|
||||
) -> Sub2ApiService:
|
||||
"""创建 Sub2API 服务配置"""
|
||||
svc = Sub2ApiService(
|
||||
name=name,
|
||||
api_url=api_url,
|
||||
api_key=api_key,
|
||||
enabled=enabled,
|
||||
priority=priority,
|
||||
)
|
||||
db.add(svc)
|
||||
db.commit()
|
||||
db.refresh(svc)
|
||||
return svc
|
||||
|
||||
|
||||
def get_sub2api_service_by_id(db: Session, service_id: int) -> Optional[Sub2ApiService]:
|
||||
"""按 ID 获取 Sub2API 服务"""
|
||||
return db.query(Sub2ApiService).filter(Sub2ApiService.id == service_id).first()
|
||||
|
||||
|
||||
def get_sub2api_services(
|
||||
db: Session,
|
||||
enabled: Optional[bool] = None
|
||||
) -> List[Sub2ApiService]:
|
||||
"""获取 Sub2API 服务列表"""
|
||||
query = db.query(Sub2ApiService)
|
||||
if enabled is not None:
|
||||
query = query.filter(Sub2ApiService.enabled == enabled)
|
||||
return query.order_by(asc(Sub2ApiService.priority), asc(Sub2ApiService.id)).all()
|
||||
|
||||
|
||||
def update_sub2api_service(db: Session, service_id: int, **kwargs) -> Optional[Sub2ApiService]:
|
||||
"""更新 Sub2API 服务配置"""
|
||||
svc = get_sub2api_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
setattr(svc, key, value)
|
||||
db.commit()
|
||||
db.refresh(svc)
|
||||
return svc
|
||||
|
||||
|
||||
def delete_sub2api_service(db: Session, service_id: int) -> bool:
|
||||
"""删除 Sub2API 服务配置"""
|
||||
svc = get_sub2api_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
return False
|
||||
db.delete(svc)
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Team Manager 服务 CRUD
|
||||
# ============================================================================
|
||||
|
||||
def create_tm_service(
|
||||
db: Session,
|
||||
name: str,
|
||||
api_url: str,
|
||||
api_key: str,
|
||||
enabled: bool = True,
|
||||
priority: int = 0,
|
||||
):
|
||||
"""创建 Team Manager 服务配置"""
|
||||
from .models import TeamManagerService
|
||||
svc = TeamManagerService(
|
||||
name=name,
|
||||
api_url=api_url,
|
||||
api_key=api_key,
|
||||
enabled=enabled,
|
||||
priority=priority,
|
||||
)
|
||||
db.add(svc)
|
||||
db.commit()
|
||||
db.refresh(svc)
|
||||
return svc
|
||||
|
||||
|
||||
def get_tm_service_by_id(db: Session, service_id: int):
|
||||
"""按 ID 获取 Team Manager 服务"""
|
||||
from .models import TeamManagerService
|
||||
return db.query(TeamManagerService).filter(TeamManagerService.id == service_id).first()
|
||||
|
||||
|
||||
def get_tm_services(db: Session, enabled=None):
|
||||
"""获取 Team Manager 服务列表"""
|
||||
from .models import TeamManagerService
|
||||
q = db.query(TeamManagerService)
|
||||
if enabled is not None:
|
||||
q = q.filter(TeamManagerService.enabled == enabled)
|
||||
return q.order_by(TeamManagerService.priority.asc(), TeamManagerService.id.asc()).all()
|
||||
|
||||
|
||||
def update_tm_service(db: Session, service_id: int, **kwargs):
|
||||
"""更新 Team Manager 服务配置"""
|
||||
svc = get_tm_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
return None
|
||||
for k, v in kwargs.items():
|
||||
setattr(svc, k, v)
|
||||
db.commit()
|
||||
db.refresh(svc)
|
||||
return svc
|
||||
|
||||
|
||||
def delete_tm_service(db: Session, service_id: int) -> bool:
|
||||
"""删除 Team Manager 服务配置"""
|
||||
svc = get_tm_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
return False
|
||||
db.delete(svc)
|
||||
db.commit()
|
||||
return True
|
||||
86
src/database/init_db.py
Normal file
86
src/database/init_db.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
数据库初始化和初始化数据
|
||||
"""
|
||||
|
||||
from .session import init_database
|
||||
from .models import Base
|
||||
|
||||
|
||||
def initialize_database(database_url: str = None):
|
||||
"""
|
||||
初始化数据库
|
||||
创建所有表并设置默认配置
|
||||
"""
|
||||
# 初始化数据库连接和表
|
||||
db_manager = init_database(database_url)
|
||||
|
||||
# 创建表
|
||||
db_manager.create_tables()
|
||||
|
||||
# 初始化默认设置(从 settings 模块导入以避免循环导入)
|
||||
from ..config.settings import init_default_settings
|
||||
init_default_settings()
|
||||
|
||||
return db_manager
|
||||
|
||||
|
||||
def reset_database(database_url: str = None):
|
||||
"""
|
||||
重置数据库(删除所有表并重新创建)
|
||||
警告:会丢失所有数据!
|
||||
"""
|
||||
db_manager = init_database(database_url)
|
||||
|
||||
# 删除所有表
|
||||
db_manager.drop_tables()
|
||||
print("已删除所有表")
|
||||
|
||||
# 重新创建所有表
|
||||
db_manager.create_tables()
|
||||
print("已重新创建所有表")
|
||||
|
||||
# 初始化默认设置
|
||||
from ..config.settings import init_default_settings
|
||||
init_default_settings()
|
||||
|
||||
print("数据库重置完成")
|
||||
return db_manager
|
||||
|
||||
|
||||
def check_database_connection(database_url: str = None) -> bool:
|
||||
"""
|
||||
检查数据库连接是否正常
|
||||
"""
|
||||
try:
|
||||
db_manager = init_database(database_url)
|
||||
with db_manager.get_db() as db:
|
||||
# 尝试执行一个简单的查询
|
||||
db.execute("SELECT 1")
|
||||
print("数据库连接正常")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"数据库连接失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 当直接运行此脚本时,初始化数据库
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="数据库初始化脚本")
|
||||
parser.add_argument("--reset", action="store_true", help="重置数据库(删除所有数据)")
|
||||
parser.add_argument("--check", action="store_true", help="检查数据库连接")
|
||||
parser.add_argument("--url", help="数据库连接字符串")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.check:
|
||||
check_database_connection(args.url)
|
||||
elif args.reset:
|
||||
confirm = input("警告:这将删除所有数据!确认重置?(y/N): ")
|
||||
if confirm.lower() == 'y':
|
||||
reset_database(args.url)
|
||||
else:
|
||||
print("操作已取消")
|
||||
else:
|
||||
initialize_database(args.url)
|
||||
229
src/database/models.py
Normal file
229
src/database/models.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
SQLAlchemy ORM 模型定义
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
import json
|
||||
from sqlalchemy import Column, Integer, String, Text, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class JSONEncodedDict(TypeDecorator):
|
||||
"""JSON 编码字典类型"""
|
||||
impl = Text
|
||||
|
||||
def process_bind_param(self, value: Optional[Dict[str, Any]], dialect):
|
||||
if value is None:
|
||||
return None
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def process_result_value(self, value: Optional[str], dialect):
|
||||
if value is None:
|
||||
return None
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
class Account(Base):
|
||||
"""已注册账号表"""
|
||||
__tablename__ = 'accounts'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
email = Column(String(255), nullable=False, unique=True, index=True)
|
||||
password = Column(String(255)) # 注册密码(明文存储)
|
||||
access_token = Column(Text)
|
||||
refresh_token = Column(Text)
|
||||
id_token = Column(Text)
|
||||
session_token = Column(Text) # 会话令牌(优先刷新方式)
|
||||
client_id = Column(String(255)) # OAuth Client ID
|
||||
account_id = Column(String(255))
|
||||
workspace_id = Column(String(255))
|
||||
email_service = Column(String(50), nullable=False) # 'tempmail', 'outlook', 'moe_mail'
|
||||
email_service_id = Column(String(255)) # 邮箱服务中的ID
|
||||
proxy_used = Column(String(255))
|
||||
registered_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_refresh = Column(DateTime) # 最后刷新时间
|
||||
expires_at = Column(DateTime) # Token 过期时间
|
||||
status = Column(String(20), default='active') # 'active', 'expired', 'banned', 'failed'
|
||||
extra_data = Column(JSONEncodedDict) # 额外信息存储
|
||||
cpa_uploaded = Column(Boolean, default=False) # 是否已上传到 CPA
|
||||
cpa_uploaded_at = Column(DateTime) # 上传时间
|
||||
source = Column(String(20), default='register') # 'register' 或 'login',区分账号来源
|
||||
subscription_type = Column(String(20)) # None / 'plus' / 'team'
|
||||
subscription_at = Column(DateTime) # 订阅开通时间
|
||||
cookies = Column(Text) # 完整 cookie 字符串,用于支付请求
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
'id': self.id,
|
||||
'email': self.email,
|
||||
'password': self.password,
|
||||
'client_id': self.client_id,
|
||||
'email_service': self.email_service,
|
||||
'account_id': self.account_id,
|
||||
'workspace_id': self.workspace_id,
|
||||
'registered_at': self.registered_at.isoformat() if self.registered_at else None,
|
||||
'last_refresh': self.last_refresh.isoformat() if self.last_refresh else None,
|
||||
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
|
||||
'status': self.status,
|
||||
'proxy_used': self.proxy_used,
|
||||
'cpa_uploaded': self.cpa_uploaded,
|
||||
'cpa_uploaded_at': self.cpa_uploaded_at.isoformat() if self.cpa_uploaded_at else None,
|
||||
'source': self.source,
|
||||
'subscription_type': self.subscription_type,
|
||||
'subscription_at': self.subscription_at.isoformat() if self.subscription_at else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
|
||||
|
||||
class EmailService(Base):
|
||||
"""邮箱服务配置表"""
|
||||
__tablename__ = 'email_services'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
service_type = Column(String(50), nullable=False) # 'outlook', 'moe_mail'
|
||||
name = Column(String(100), nullable=False)
|
||||
config = Column(JSONEncodedDict, nullable=False) # 服务配置(加密存储)
|
||||
enabled = Column(Boolean, default=True)
|
||||
priority = Column(Integer, default=0) # 使用优先级
|
||||
last_used = Column(DateTime)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class RegistrationTask(Base):
|
||||
"""注册任务表"""
|
||||
__tablename__ = 'registration_tasks'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
task_uuid = Column(String(36), unique=True, nullable=False, index=True) # 任务唯一标识
|
||||
status = Column(String(20), default='pending') # 'pending', 'running', 'completed', 'failed', 'cancelled'
|
||||
email_service_id = Column(Integer, ForeignKey('email_services.id'), index=True) # 使用的邮箱服务
|
||||
proxy = Column(String(255)) # 使用的代理
|
||||
logs = Column(Text) # 注册过程日志
|
||||
result = Column(JSONEncodedDict) # 注册结果
|
||||
error_message = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
started_at = Column(DateTime)
|
||||
completed_at = Column(DateTime)
|
||||
|
||||
# 关系
|
||||
email_service = relationship('EmailService')
|
||||
|
||||
|
||||
class Setting(Base):
|
||||
"""系统设置表"""
|
||||
__tablename__ = 'settings'
|
||||
|
||||
key = Column(String(100), primary_key=True)
|
||||
value = Column(Text)
|
||||
description = Column(Text)
|
||||
category = Column(String(50), default='general') # 'general', 'email', 'proxy', 'openai'
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class CpaService(Base):
|
||||
"""CPA 服务配置表"""
|
||||
__tablename__ = 'cpa_services'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(100), nullable=False) # 服务名称
|
||||
api_url = Column(String(500), nullable=False) # API URL
|
||||
api_token = Column(Text, nullable=False) # API Token
|
||||
enabled = Column(Boolean, default=True)
|
||||
priority = Column(Integer, default=0) # 优先级
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class Sub2ApiService(Base):
|
||||
"""Sub2API 服务配置表"""
|
||||
__tablename__ = 'sub2api_services'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(100), nullable=False) # 服务名称
|
||||
api_url = Column(String(500), nullable=False) # API URL (host)
|
||||
api_key = Column(Text, nullable=False) # x-api-key
|
||||
enabled = Column(Boolean, default=True)
|
||||
priority = Column(Integer, default=0) # 优先级
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class TeamManagerService(Base):
|
||||
"""Team Manager 服务配置表"""
|
||||
__tablename__ = 'tm_services'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(100), nullable=False) # 服务名称
|
||||
api_url = Column(String(500), nullable=False) # API URL
|
||||
api_key = Column(Text, nullable=False) # X-API-Key
|
||||
enabled = Column(Boolean, default=True)
|
||||
priority = Column(Integer, default=0) # 优先级
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class Proxy(Base):
|
||||
"""代理列表表"""
|
||||
__tablename__ = 'proxies'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(100), nullable=False) # 代理名称
|
||||
type = Column(String(20), nullable=False, default='http') # http, socks5
|
||||
host = Column(String(255), nullable=False)
|
||||
port = Column(Integer, nullable=False)
|
||||
username = Column(String(100))
|
||||
password = Column(String(255))
|
||||
enabled = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False) # 是否为默认代理
|
||||
priority = Column(Integer, default=0) # 优先级(保留字段)
|
||||
last_used = Column(DateTime) # 最后使用时间
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
def to_dict(self, include_password: bool = False) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'type': self.type,
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'username': self.username,
|
||||
'enabled': self.enabled,
|
||||
'is_default': self.is_default or False,
|
||||
'priority': self.priority,
|
||||
'last_used': self.last_used.isoformat() if self.last_used else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
if include_password:
|
||||
result['password'] = self.password
|
||||
else:
|
||||
result['has_password'] = bool(self.password)
|
||||
return result
|
||||
|
||||
@property
|
||||
def proxy_url(self) -> str:
|
||||
"""获取完整的代理 URL"""
|
||||
if self.type == "http":
|
||||
scheme = "http"
|
||||
elif self.type == "socks5":
|
||||
scheme = "socks5"
|
||||
else:
|
||||
scheme = self.type
|
||||
|
||||
auth = ""
|
||||
if self.username and self.password:
|
||||
auth = f"{self.username}:{self.password}@"
|
||||
|
||||
return f"{scheme}://{auth}{self.host}:{self.port}"
|
||||
182
src/database/session.py
Normal file
182
src/database/session.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
数据库会话管理
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
import os
|
||||
import logging
|
||||
|
||||
from .models import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_sqlalchemy_url(database_url: str) -> str:
|
||||
if database_url.startswith("postgresql://"):
|
||||
return "postgresql+psycopg://" + database_url[len("postgresql://"):]
|
||||
if database_url.startswith("postgres://"):
|
||||
return "postgresql+psycopg://" + database_url[len("postgres://"):]
|
||||
return database_url
|
||||
|
||||
|
||||
class DatabaseSessionManager:
|
||||
"""数据库会话管理器"""
|
||||
|
||||
def __init__(self, database_url: str = None):
|
||||
if database_url is None:
|
||||
env_url = os.environ.get("APP_DATABASE_URL") or os.environ.get("DATABASE_URL")
|
||||
if env_url:
|
||||
database_url = env_url
|
||||
else:
|
||||
# 优先使用 APP_DATA_DIR 环境变量(PyInstaller 打包后由 webui.py 设置)
|
||||
data_dir = os.environ.get('APP_DATA_DIR') or os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'data'
|
||||
)
|
||||
db_path = os.path.join(data_dir, 'database.db')
|
||||
# 确保目录存在
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
database_url = f"sqlite:///{db_path}"
|
||||
|
||||
self.database_url = _build_sqlalchemy_url(database_url)
|
||||
self.engine = create_engine(
|
||||
self.database_url,
|
||||
connect_args={"check_same_thread": False} if self.database_url.startswith("sqlite") else {},
|
||||
echo=False, # 设置为 True 可以查看所有 SQL 语句
|
||||
pool_pre_ping=True # 连接池预检查
|
||||
)
|
||||
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
|
||||
|
||||
def get_db(self) -> Generator[Session, None, None]:
|
||||
"""
|
||||
获取数据库会话的上下文管理器
|
||||
使用示例:
|
||||
with get_db() as db:
|
||||
# 使用 db 进行数据库操作
|
||||
pass
|
||||
"""
|
||||
db = self.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@contextmanager
|
||||
def session_scope(self) -> Generator[Session, None, None]:
|
||||
"""
|
||||
事务作用域上下文管理器
|
||||
使用示例:
|
||||
with session_scope() as session:
|
||||
# 数据库操作
|
||||
pass
|
||||
"""
|
||||
session = self.SessionLocal()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def create_tables(self):
|
||||
"""创建所有表"""
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
|
||||
def drop_tables(self):
|
||||
"""删除所有表(谨慎使用)"""
|
||||
Base.metadata.drop_all(bind=self.engine)
|
||||
|
||||
def migrate_tables(self):
|
||||
"""
|
||||
数据库迁移 - 添加缺失的列
|
||||
用于在不删除数据的情况下更新表结构
|
||||
"""
|
||||
if not self.database_url.startswith("sqlite"):
|
||||
logger.info("非 SQLite 数据库,跳过自动迁移")
|
||||
return
|
||||
|
||||
# 需要检查和添加的新列
|
||||
migrations = [
|
||||
# (表名, 列名, 列类型)
|
||||
("accounts", "cpa_uploaded", "BOOLEAN DEFAULT 0"),
|
||||
("accounts", "cpa_uploaded_at", "DATETIME"),
|
||||
("accounts", "source", "VARCHAR(20) DEFAULT 'register'"),
|
||||
("accounts", "subscription_type", "VARCHAR(20)"),
|
||||
("accounts", "subscription_at", "DATETIME"),
|
||||
("accounts", "cookies", "TEXT"),
|
||||
("proxies", "is_default", "BOOLEAN DEFAULT 0"),
|
||||
]
|
||||
|
||||
# 确保新表存在(create_tables 已处理,此处兜底)
|
||||
Base.metadata.create_all(bind=self.engine)
|
||||
|
||||
with self.engine.connect() as conn:
|
||||
# 数据迁移:将旧的 custom_domain 记录统一为 moe_mail
|
||||
try:
|
||||
conn.execute(text("UPDATE email_services SET service_type='moe_mail' WHERE service_type='custom_domain'"))
|
||||
conn.execute(text("UPDATE accounts SET email_service='moe_mail' WHERE email_service='custom_domain'"))
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"迁移 custom_domain -> moe_mail 时出错: {e}")
|
||||
|
||||
for table_name, column_name, column_type in migrations:
|
||||
try:
|
||||
# 检查列是否存在
|
||||
result = conn.execute(text(
|
||||
f"SELECT * FROM pragma_table_info('{table_name}') WHERE name='{column_name}'"
|
||||
))
|
||||
if result.fetchone() is None:
|
||||
# 列不存在,添加它
|
||||
logger.info(f"添加列 {table_name}.{column_name}")
|
||||
conn.execute(text(
|
||||
f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
|
||||
))
|
||||
conn.commit()
|
||||
logger.info(f"成功添加列 {table_name}.{column_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"迁移列 {table_name}.{column_name} 时出错: {e}")
|
||||
|
||||
|
||||
# 全局数据库会话管理器实例
|
||||
_db_manager: DatabaseSessionManager = None
|
||||
|
||||
|
||||
def init_database(database_url: str = None) -> DatabaseSessionManager:
|
||||
"""
|
||||
初始化数据库会话管理器
|
||||
"""
|
||||
global _db_manager
|
||||
if _db_manager is None:
|
||||
_db_manager = DatabaseSessionManager(database_url)
|
||||
_db_manager.create_tables()
|
||||
# 执行数据库迁移
|
||||
_db_manager.migrate_tables()
|
||||
return _db_manager
|
||||
|
||||
|
||||
def get_session_manager() -> DatabaseSessionManager:
|
||||
"""
|
||||
获取数据库会话管理器
|
||||
"""
|
||||
if _db_manager is None:
|
||||
raise RuntimeError("数据库未初始化,请先调用 init_database()")
|
||||
return _db_manager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
获取数据库会话的快捷函数
|
||||
"""
|
||||
manager = get_session_manager()
|
||||
db = manager.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
73
src/services/__init__.py
Normal file
73
src/services/__init__.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
邮箱服务模块
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
BaseEmailService,
|
||||
EmailServiceError,
|
||||
EmailServiceStatus,
|
||||
EmailServiceFactory,
|
||||
create_email_service,
|
||||
EmailServiceType
|
||||
)
|
||||
from .tempmail import TempmailService
|
||||
from .outlook import OutlookService
|
||||
from .moe_mail import MeoMailEmailService
|
||||
from .temp_mail import TempMailService
|
||||
from .duck_mail import DuckMailService
|
||||
from .freemail import FreemailService
|
||||
from .imap_mail import ImapMailService
|
||||
|
||||
# 注册服务
|
||||
EmailServiceFactory.register(EmailServiceType.TEMPMAIL, TempmailService)
|
||||
EmailServiceFactory.register(EmailServiceType.OUTLOOK, OutlookService)
|
||||
EmailServiceFactory.register(EmailServiceType.MOE_MAIL, MeoMailEmailService)
|
||||
EmailServiceFactory.register(EmailServiceType.TEMP_MAIL, TempMailService)
|
||||
EmailServiceFactory.register(EmailServiceType.DUCK_MAIL, DuckMailService)
|
||||
EmailServiceFactory.register(EmailServiceType.FREEMAIL, FreemailService)
|
||||
EmailServiceFactory.register(EmailServiceType.IMAP_MAIL, ImapMailService)
|
||||
|
||||
# 导出 Outlook 模块的额外内容
|
||||
from .outlook.base import (
|
||||
ProviderType,
|
||||
EmailMessage,
|
||||
TokenInfo,
|
||||
ProviderHealth,
|
||||
ProviderStatus,
|
||||
)
|
||||
from .outlook.account import OutlookAccount
|
||||
from .outlook.providers import (
|
||||
OutlookProvider,
|
||||
IMAPOldProvider,
|
||||
IMAPNewProvider,
|
||||
GraphAPIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 基类
|
||||
'BaseEmailService',
|
||||
'EmailServiceError',
|
||||
'EmailServiceStatus',
|
||||
'EmailServiceFactory',
|
||||
'create_email_service',
|
||||
'EmailServiceType',
|
||||
# 服务类
|
||||
'TempmailService',
|
||||
'OutlookService',
|
||||
'MeoMailEmailService',
|
||||
'TempMailService',
|
||||
'DuckMailService',
|
||||
'FreemailService',
|
||||
'ImapMailService',
|
||||
# Outlook 模块
|
||||
'ProviderType',
|
||||
'EmailMessage',
|
||||
'TokenInfo',
|
||||
'ProviderHealth',
|
||||
'ProviderStatus',
|
||||
'OutlookAccount',
|
||||
'OutlookProvider',
|
||||
'IMAPOldProvider',
|
||||
'IMAPNewProvider',
|
||||
'GraphAPIProvider',
|
||||
]
|
||||
386
src/services/base.py
Normal file
386
src/services/base.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
邮箱服务抽象基类
|
||||
所有邮箱服务实现的基类
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from enum import Enum
|
||||
|
||||
from ..config.constants import EmailServiceType
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmailServiceError(Exception):
|
||||
"""邮箱服务异常"""
|
||||
pass
|
||||
|
||||
|
||||
class EmailServiceStatus(Enum):
|
||||
"""邮箱服务状态"""
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
UNAVAILABLE = "unavailable"
|
||||
|
||||
|
||||
class BaseEmailService(abc.ABC):
|
||||
"""
|
||||
邮箱服务抽象基类
|
||||
|
||||
所有邮箱服务必须实现此接口
|
||||
"""
|
||||
|
||||
def __init__(self, service_type: EmailServiceType, name: str = None):
|
||||
"""
|
||||
初始化邮箱服务
|
||||
|
||||
Args:
|
||||
service_type: 服务类型
|
||||
name: 服务名称
|
||||
"""
|
||||
self.service_type = service_type
|
||||
self.name = name or f"{service_type.value}_service"
|
||||
self._status = EmailServiceStatus.HEALTHY
|
||||
self._last_error = None
|
||||
|
||||
@property
|
||||
def status(self) -> EmailServiceStatus:
|
||||
"""获取服务状态"""
|
||||
return self._status
|
||||
|
||||
@property
|
||||
def last_error(self) -> Optional[str]:
|
||||
"""获取最后一次错误信息"""
|
||||
return self._last_error
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
创建新邮箱地址
|
||||
|
||||
Args:
|
||||
config: 配置参数,如邮箱前缀、域名等
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典,至少包含:
|
||||
- email: 邮箱地址
|
||||
- service_id: 邮箱服务中的 ID
|
||||
- token/credentials: 访问凭证(如果需要)
|
||||
|
||||
Raises:
|
||||
EmailServiceError: 创建失败
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
pattern: str = r"(?<!\d)(\d{6})(?!\d)",
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 邮箱服务中的 ID(如果需要)
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则表达式
|
||||
otp_sent_at: OTP 发送时间戳,用于过滤旧邮件
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果超时或未找到返回 None
|
||||
|
||||
Raises:
|
||||
EmailServiceError: 服务错误
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出所有邮箱(如果服务支持)
|
||||
|
||||
Args:
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
邮箱列表
|
||||
|
||||
Raises:
|
||||
EmailServiceError: 服务错误
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""
|
||||
删除邮箱
|
||||
|
||||
Args:
|
||||
email_id: 邮箱服务中的 ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
|
||||
Raises:
|
||||
EmailServiceError: 服务错误
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_health(self) -> bool:
|
||||
"""
|
||||
检查服务健康状态
|
||||
|
||||
Returns:
|
||||
服务是否健康
|
||||
|
||||
Note:
|
||||
此方法不应抛出异常,应捕获异常并返回 False
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_email_info(self, email_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取邮箱信息(可选实现)
|
||||
|
||||
Args:
|
||||
email_id: 邮箱服务中的 ID
|
||||
|
||||
Returns:
|
||||
邮箱信息字典,如果不存在返回 None
|
||||
"""
|
||||
# 默认实现:遍历列表查找
|
||||
for email_info in self.list_emails():
|
||||
if email_info.get("id") == email_id:
|
||||
return email_info
|
||||
return None
|
||||
|
||||
def wait_for_email(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
check_interval: int = 3,
|
||||
expected_sender: str = None,
|
||||
expected_subject: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
等待并获取邮件(可选实现)
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 邮箱服务中的 ID
|
||||
timeout: 超时时间(秒)
|
||||
check_interval: 检查间隔(秒)
|
||||
expected_sender: 期望的发件人(包含检查)
|
||||
expected_subject: 期望的主题(包含检查)
|
||||
|
||||
Returns:
|
||||
邮件信息字典,如果超时返回 None
|
||||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
start_time = time.time()
|
||||
last_email_id = None
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
emails = self.list_emails()
|
||||
for email_info in emails:
|
||||
email_data = email_info.get("email", {})
|
||||
current_email_id = email_info.get("id")
|
||||
|
||||
# 检查是否是新的邮件
|
||||
if last_email_id and current_email_id == last_email_id:
|
||||
continue
|
||||
|
||||
# 检查邮箱地址
|
||||
if email_data.get("address") != email:
|
||||
continue
|
||||
|
||||
# 获取邮件列表
|
||||
messages = self.get_email_messages(email_id or current_email_id)
|
||||
for message in messages:
|
||||
# 检查发件人
|
||||
if expected_sender and expected_sender not in message.get("from", ""):
|
||||
continue
|
||||
|
||||
# 检查主题
|
||||
if expected_subject and expected_subject not in message.get("subject", ""):
|
||||
continue
|
||||
|
||||
# 返回邮件信息
|
||||
return {
|
||||
"id": message.get("id"),
|
||||
"from": message.get("from"),
|
||||
"subject": message.get("subject"),
|
||||
"content": message.get("content"),
|
||||
"received_at": message.get("received_at"),
|
||||
"email_info": email_info
|
||||
}
|
||||
|
||||
# 更新最后检查的邮件 ID
|
||||
if messages:
|
||||
last_email_id = current_email_id
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"等待邮件时出错: {e}")
|
||||
|
||||
time.sleep(check_interval)
|
||||
|
||||
return None
|
||||
|
||||
def get_email_messages(self, email_id: str, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取邮箱中的邮件列表(可选实现)
|
||||
|
||||
Args:
|
||||
email_id: 邮箱服务中的 ID
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
|
||||
Note:
|
||||
这是可选方法,某些服务可能不支持
|
||||
"""
|
||||
raise NotImplementedError("此邮箱服务不支持获取邮件列表")
|
||||
|
||||
def get_message_content(self, email_id: str, message_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取邮件内容(可选实现)
|
||||
|
||||
Args:
|
||||
email_id: 邮箱服务中的 ID
|
||||
message_id: 邮件 ID
|
||||
|
||||
Returns:
|
||||
邮件内容字典
|
||||
|
||||
Note:
|
||||
这是可选方法,某些服务可能不支持
|
||||
"""
|
||||
raise NotImplementedError("此邮箱服务不支持获取邮件内容")
|
||||
|
||||
def update_status(self, success: bool, error: Exception = None):
|
||||
"""
|
||||
更新服务状态
|
||||
|
||||
Args:
|
||||
success: 操作是否成功
|
||||
error: 错误信息
|
||||
"""
|
||||
if success:
|
||||
self._status = EmailServiceStatus.HEALTHY
|
||||
self._last_error = None
|
||||
else:
|
||||
self._status = EmailServiceStatus.DEGRADED
|
||||
if error:
|
||||
self._last_error = str(error)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return f"{self.name} ({self.service_type.value})"
|
||||
|
||||
|
||||
class EmailServiceFactory:
|
||||
"""邮箱服务工厂"""
|
||||
|
||||
_registry: Dict[EmailServiceType, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, service_type: EmailServiceType, service_class: type):
|
||||
"""
|
||||
注册邮箱服务类
|
||||
|
||||
Args:
|
||||
service_type: 服务类型
|
||||
service_class: 服务类
|
||||
"""
|
||||
if not issubclass(service_class, BaseEmailService):
|
||||
raise TypeError(f"{service_class} 必须是 BaseEmailService 的子类")
|
||||
cls._registry[service_type] = service_class
|
||||
logger.info(f"注册邮箱服务: {service_type.value} -> {service_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
service_type: EmailServiceType,
|
||||
config: Dict[str, Any],
|
||||
name: str = None
|
||||
) -> BaseEmailService:
|
||||
"""
|
||||
创建邮箱服务实例
|
||||
|
||||
Args:
|
||||
service_type: 服务类型
|
||||
config: 服务配置
|
||||
name: 服务名称
|
||||
|
||||
Returns:
|
||||
邮箱服务实例
|
||||
|
||||
Raises:
|
||||
ValueError: 服务类型未注册或配置无效
|
||||
"""
|
||||
if service_type not in cls._registry:
|
||||
raise ValueError(f"未注册的服务类型: {service_type.value}")
|
||||
|
||||
service_class = cls._registry[service_type]
|
||||
try:
|
||||
instance = service_class(config, name)
|
||||
return instance
|
||||
except Exception as e:
|
||||
raise ValueError(f"创建邮箱服务失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def get_available_services(cls) -> List[EmailServiceType]:
|
||||
"""
|
||||
获取所有已注册的服务类型
|
||||
|
||||
Returns:
|
||||
已注册的服务类型列表
|
||||
"""
|
||||
return list(cls._registry.keys())
|
||||
|
||||
@classmethod
|
||||
def get_service_class(cls, service_type: EmailServiceType) -> Optional[type]:
|
||||
"""
|
||||
获取服务类
|
||||
|
||||
Args:
|
||||
service_type: 服务类型
|
||||
|
||||
Returns:
|
||||
服务类,如果未注册返回 None
|
||||
"""
|
||||
return cls._registry.get(service_type)
|
||||
|
||||
|
||||
# 简化的工厂函数
|
||||
def create_email_service(
|
||||
service_type: EmailServiceType,
|
||||
config: Dict[str, Any],
|
||||
name: str = None
|
||||
) -> BaseEmailService:
|
||||
"""
|
||||
创建邮箱服务(简化工厂函数)
|
||||
|
||||
Args:
|
||||
service_type: 服务类型
|
||||
config: 服务配置
|
||||
name: 服务名称
|
||||
|
||||
Returns:
|
||||
邮箱服务实例
|
||||
"""
|
||||
return EmailServiceFactory.create(service_type, config, name)
|
||||
366
src/services/duck_mail.py
Normal file
366
src/services/duck_mail.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
DuckMail 邮箱服务实现
|
||||
兼容 DuckMail 的 accounts/token/messages 接口模型
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from html import unescape
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DuckMailService(BaseEmailService):
|
||||
"""DuckMail 邮箱服务"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
super().__init__(EmailServiceType.DUCK_MAIL, name)
|
||||
|
||||
required_keys = ["base_url", "default_domain"]
|
||||
missing_keys = [key for key in required_keys if not (config or {}).get(key)]
|
||||
if missing_keys:
|
||||
raise ValueError(f"缺少必需配置: {missing_keys}")
|
||||
|
||||
default_config = {
|
||||
"api_key": "",
|
||||
"password_length": 12,
|
||||
"expires_in": None,
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"proxy_url": None,
|
||||
}
|
||||
self.config = {**default_config, **(config or {})}
|
||||
self.config["base_url"] = str(self.config["base_url"]).rstrip("/")
|
||||
self.config["default_domain"] = str(self.config["default_domain"]).strip().lstrip("@")
|
||||
|
||||
http_config = RequestConfig(
|
||||
timeout=self.config["timeout"],
|
||||
max_retries=self.config["max_retries"],
|
||||
)
|
||||
self.http_client = HTTPClient(
|
||||
proxy_url=self.config.get("proxy_url"),
|
||||
config=http_config,
|
||||
)
|
||||
|
||||
self._accounts_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
self._accounts_by_email: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def _build_headers(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
use_api_key: bool = False,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
auth_token = token
|
||||
if not auth_token and use_api_key and self.config.get("api_key"):
|
||||
auth_token = self.config["api_key"]
|
||||
|
||||
if auth_token:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
return headers
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
token: Optional[str] = None,
|
||||
use_api_key: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.config['base_url']}{path}"
|
||||
kwargs["headers"] = self._build_headers(
|
||||
token=token,
|
||||
use_api_key=use_api_key,
|
||||
extra_headers=kwargs.get("headers"),
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.http_client.request(method, url, **kwargs)
|
||||
if response.status_code >= 400:
|
||||
error_message = f"API 请求失败: {response.status_code}"
|
||||
try:
|
||||
error_payload = response.json()
|
||||
error_message = f"{error_message} - {error_payload}"
|
||||
except Exception:
|
||||
error_message = f"{error_message} - {response.text[:200]}"
|
||||
raise EmailServiceError(error_message)
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except Exception:
|
||||
return {"raw_response": response.text}
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"请求失败: {method} {path} - {e}")
|
||||
|
||||
def _generate_local_part(self) -> str:
|
||||
first = random.choice(string.ascii_lowercase)
|
||||
rest = "".join(random.choices(string.ascii_lowercase + string.digits, k=7))
|
||||
return f"{first}{rest}"
|
||||
|
||||
def _generate_password(self) -> str:
|
||||
length = max(6, int(self.config.get("password_length") or 12))
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return "".join(random.choices(alphabet, k=length))
|
||||
|
||||
def _cache_account(self, account_info: Dict[str, Any]) -> None:
|
||||
account_id = str(account_info.get("account_id") or account_info.get("service_id") or "").strip()
|
||||
email = str(account_info.get("email") or "").strip().lower()
|
||||
|
||||
if account_id:
|
||||
self._accounts_by_id[account_id] = account_info
|
||||
if email:
|
||||
self._accounts_by_email[email] = account_info
|
||||
|
||||
def _get_account_info(self, email: Optional[str] = None, email_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
|
||||
if email_id:
|
||||
cached = self._accounts_by_id.get(str(email_id))
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
if email:
|
||||
cached = self._accounts_by_email.get(str(email).strip().lower())
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
return None
|
||||
|
||||
def _strip_html(self, html_content: Any) -> str:
|
||||
if isinstance(html_content, list):
|
||||
html_content = "\n".join(str(item) for item in html_content if item)
|
||||
text = str(html_content or "")
|
||||
return unescape(re.sub(r"<[^>]+>", " ", text))
|
||||
|
||||
def _parse_message_time(self, value: Optional[str]) -> Optional[float]:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
normalized = value.replace("Z", "+00:00")
|
||||
return datetime.fromisoformat(normalized).astimezone(timezone.utc).timestamp()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _message_search_text(self, summary: Dict[str, Any], detail: Dict[str, Any]) -> str:
|
||||
sender = summary.get("from") or detail.get("from") or {}
|
||||
if isinstance(sender, dict):
|
||||
sender_text = " ".join(
|
||||
str(sender.get(key) or "") for key in ("name", "address")
|
||||
).strip()
|
||||
else:
|
||||
sender_text = str(sender)
|
||||
|
||||
subject = str(summary.get("subject") or detail.get("subject") or "")
|
||||
text_body = str(detail.get("text") or "")
|
||||
html_body = self._strip_html(detail.get("html"))
|
||||
return "\n".join(part for part in [sender_text, subject, text_body, html_body] if part).strip()
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
request_config = config or {}
|
||||
local_part = str(request_config.get("name") or self._generate_local_part()).strip()
|
||||
domain = str(request_config.get("default_domain") or request_config.get("domain") or self.config["default_domain"]).strip().lstrip("@")
|
||||
address = f"{local_part}@{domain}"
|
||||
password = self._generate_password()
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"address": address,
|
||||
"password": password,
|
||||
}
|
||||
|
||||
expires_in = request_config.get("expiresIn", request_config.get("expires_in", self.config.get("expires_in")))
|
||||
if expires_in is not None:
|
||||
payload["expiresIn"] = expires_in
|
||||
|
||||
account_response = self._make_request(
|
||||
"POST",
|
||||
"/accounts",
|
||||
json=payload,
|
||||
use_api_key=bool(self.config.get("api_key")),
|
||||
)
|
||||
token_response = self._make_request(
|
||||
"POST",
|
||||
"/token",
|
||||
json={
|
||||
"address": account_response.get("address", address),
|
||||
"password": password,
|
||||
},
|
||||
)
|
||||
|
||||
account_id = str(account_response.get("id") or token_response.get("id") or "").strip()
|
||||
resolved_address = str(account_response.get("address") or address).strip()
|
||||
token = str(token_response.get("token") or "").strip()
|
||||
|
||||
if not account_id or not resolved_address or not token:
|
||||
raise EmailServiceError("DuckMail 返回数据不完整")
|
||||
|
||||
email_info = {
|
||||
"email": resolved_address,
|
||||
"service_id": account_id,
|
||||
"id": account_id,
|
||||
"account_id": account_id,
|
||||
"token": token,
|
||||
"password": password,
|
||||
"created_at": time.time(),
|
||||
"raw_account": account_response,
|
||||
}
|
||||
|
||||
self._cache_account(email_info)
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
pattern: str = OTP_CODE_PATTERN,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
account_info = self._get_account_info(email=email, email_id=email_id)
|
||||
if not account_info:
|
||||
logger.warning(f"DuckMail 未找到邮箱缓存: {email}, {email_id}")
|
||||
return None
|
||||
|
||||
token = account_info.get("token")
|
||||
if not token:
|
||||
logger.warning(f"DuckMail 邮箱缺少访问 token: {email}")
|
||||
return None
|
||||
|
||||
start_time = time.time()
|
||||
seen_message_ids = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = self._make_request(
|
||||
"GET",
|
||||
"/messages",
|
||||
token=token,
|
||||
params={"page": 1},
|
||||
)
|
||||
messages = response.get("hydra:member", [])
|
||||
|
||||
for message in messages:
|
||||
message_id = str(message.get("id") or "").strip()
|
||||
if not message_id or message_id in seen_message_ids:
|
||||
continue
|
||||
|
||||
created_at = self._parse_message_time(message.get("createdAt"))
|
||||
if otp_sent_at and created_at and created_at + 1 < otp_sent_at:
|
||||
continue
|
||||
|
||||
seen_message_ids.add(message_id)
|
||||
detail = self._make_request(
|
||||
"GET",
|
||||
f"/messages/{message_id}",
|
||||
token=token,
|
||||
)
|
||||
|
||||
content = self._message_search_text(message, detail)
|
||||
if "openai" not in content.lower():
|
||||
continue
|
||||
|
||||
match = re.search(pattern, content)
|
||||
if match:
|
||||
self.update_status(True)
|
||||
return match.group(1)
|
||||
except Exception as e:
|
||||
logger.debug(f"DuckMail 轮询验证码失败: {e}")
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
return list(self._accounts_by_email.values())
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
account_info = self._get_account_info(email_id=email_id) or self._get_account_info(email=email_id)
|
||||
if not account_info:
|
||||
return False
|
||||
|
||||
token = account_info.get("token")
|
||||
account_id = account_info.get("account_id") or account_info.get("service_id")
|
||||
if not token or not account_id:
|
||||
return False
|
||||
|
||||
try:
|
||||
self._make_request(
|
||||
"DELETE",
|
||||
f"/accounts/{account_id}",
|
||||
token=token,
|
||||
)
|
||||
self._accounts_by_id.pop(str(account_id), None)
|
||||
self._accounts_by_email.pop(str(account_info.get("email") or "").lower(), None)
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckMail 删除邮箱失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
|
||||
def check_health(self) -> bool:
|
||||
try:
|
||||
self._make_request(
|
||||
"GET",
|
||||
"/domains",
|
||||
params={"page": 1},
|
||||
use_api_key=bool(self.config.get("api_key")),
|
||||
)
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckMail 健康检查失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
|
||||
def get_email_messages(self, email_id: str, **kwargs) -> List[Dict[str, Any]]:
|
||||
account_info = self._get_account_info(email_id=email_id) or self._get_account_info(email=email_id)
|
||||
if not account_info or not account_info.get("token"):
|
||||
return []
|
||||
response = self._make_request(
|
||||
"GET",
|
||||
"/messages",
|
||||
token=account_info["token"],
|
||||
params={"page": kwargs.get("page", 1)},
|
||||
)
|
||||
return response.get("hydra:member", [])
|
||||
|
||||
def get_message_detail(self, email_id: str, message_id: str) -> Optional[Dict[str, Any]]:
|
||||
account_info = self._get_account_info(email_id=email_id) or self._get_account_info(email=email_id)
|
||||
if not account_info or not account_info.get("token"):
|
||||
return None
|
||||
return self._make_request(
|
||||
"GET",
|
||||
f"/messages/{message_id}",
|
||||
token=account_info["token"],
|
||||
)
|
||||
|
||||
def get_service_info(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"service_type": self.service_type.value,
|
||||
"name": self.name,
|
||||
"base_url": self.config["base_url"],
|
||||
"default_domain": self.config["default_domain"],
|
||||
"cached_accounts": len(self._accounts_by_email),
|
||||
"status": self.status.value,
|
||||
}
|
||||
324
src/services/freemail.py
Normal file
324
src/services/freemail.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Freemail 邮箱服务实现
|
||||
基于自部署 Cloudflare Worker 临时邮箱服务 (https://github.com/idinging/freemail)
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FreemailService(BaseEmailService):
|
||||
"""
|
||||
Freemail 邮箱服务
|
||||
基于自部署 Cloudflare Worker 的临时邮箱
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
"""
|
||||
初始化 Freemail 服务
|
||||
|
||||
Args:
|
||||
config: 配置字典,支持以下键:
|
||||
- base_url: Worker 域名地址 (必需)
|
||||
- admin_token: Admin Token,对应 JWT_TOKEN (必需)
|
||||
- domain: 邮箱域名,如 example.com
|
||||
- timeout: 请求超时时间,默认 30
|
||||
- max_retries: 最大重试次数,默认 3
|
||||
name: 服务名称
|
||||
"""
|
||||
super().__init__(EmailServiceType.FREEMAIL, name)
|
||||
|
||||
required_keys = ["base_url", "admin_token"]
|
||||
missing_keys = [key for key in required_keys if not (config or {}).get(key)]
|
||||
if missing_keys:
|
||||
raise ValueError(f"缺少必需配置: {missing_keys}")
|
||||
|
||||
default_config = {
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
}
|
||||
self.config = {**default_config, **(config or {})}
|
||||
self.config["base_url"] = self.config["base_url"].rstrip("/")
|
||||
|
||||
http_config = RequestConfig(
|
||||
timeout=self.config["timeout"],
|
||||
max_retries=self.config["max_retries"],
|
||||
)
|
||||
self.http_client = HTTPClient(proxy_url=None, config=http_config)
|
||||
|
||||
# 缓存 domain 列表
|
||||
self._domains = []
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""构造 admin 请求头"""
|
||||
return {
|
||||
"Authorization": f"Bearer {self.config['admin_token']}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def _make_request(self, method: str, path: str, **kwargs) -> Any:
|
||||
"""
|
||||
发送请求并返回 JSON 数据
|
||||
|
||||
Args:
|
||||
method: HTTP 方法
|
||||
path: 请求路径(以 / 开头)
|
||||
**kwargs: 传递给 http_client.request 的额外参数
|
||||
|
||||
Returns:
|
||||
响应 JSON 数据
|
||||
|
||||
Raises:
|
||||
EmailServiceError: 请求失败
|
||||
"""
|
||||
url = f"{self.config['base_url']}{path}"
|
||||
kwargs.setdefault("headers", {})
|
||||
kwargs["headers"].update(self._get_headers())
|
||||
|
||||
try:
|
||||
response = self.http_client.request(method, url, **kwargs)
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_msg = f"请求失败: {response.status_code}"
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = f"{error_msg} - {error_data}"
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
self.update_status(False, EmailServiceError(error_msg))
|
||||
raise EmailServiceError(error_msg)
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except Exception:
|
||||
return {"raw_response": response.text}
|
||||
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"请求失败: {method} {path} - {e}")
|
||||
|
||||
def _ensure_domains(self):
|
||||
"""获取并缓存可用域名列表"""
|
||||
if not self._domains:
|
||||
try:
|
||||
domains = self._make_request("GET", "/api/domains")
|
||||
if isinstance(domains, list):
|
||||
self._domains = domains
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 Freemail 域名列表失败: {e}")
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
通过 API 创建临时邮箱
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典:
|
||||
- email: 邮箱地址
|
||||
- service_id: 同 email(用作标识)
|
||||
"""
|
||||
self._ensure_domains()
|
||||
|
||||
req_config = config or {}
|
||||
domain_index = 0
|
||||
target_domain = req_config.get("domain") or self.config.get("domain")
|
||||
|
||||
if target_domain and self._domains:
|
||||
for i, d in enumerate(self._domains):
|
||||
if d == target_domain:
|
||||
domain_index = i
|
||||
break
|
||||
|
||||
prefix = req_config.get("name")
|
||||
try:
|
||||
if prefix:
|
||||
body = {
|
||||
"local": prefix,
|
||||
"domainIndex": domain_index
|
||||
}
|
||||
resp = self._make_request("POST", "/api/create", json=body)
|
||||
else:
|
||||
params = {"domainIndex": domain_index}
|
||||
length = req_config.get("length")
|
||||
if length:
|
||||
params["length"] = length
|
||||
resp = self._make_request("GET", "/api/generate", params=params)
|
||||
|
||||
email = resp.get("email")
|
||||
if not email:
|
||||
raise EmailServiceError(f"创建邮箱失败,未返回邮箱地址: {resp}")
|
||||
|
||||
email_info = {
|
||||
"email": email,
|
||||
"service_id": email,
|
||||
"id": email,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
logger.info(f"成功创建 Freemail 邮箱: {email}")
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"创建邮箱失败: {e}")
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
pattern: str = OTP_CODE_PATTERN,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从 Freemail 邮箱获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 未使用,保留接口兼容
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则
|
||||
otp_sent_at: OTP 发送时间戳(暂未使用)
|
||||
|
||||
Returns:
|
||||
验证码字符串,超时返回 None
|
||||
"""
|
||||
logger.info(f"正在从 Freemail 邮箱 {email} 获取验证码...")
|
||||
|
||||
start_time = time.time()
|
||||
seen_mail_ids: set = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
mails = self._make_request("GET", "/api/emails", params={"mailbox": email, "limit": 20})
|
||||
if not isinstance(mails, list):
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
for mail in mails:
|
||||
mail_id = mail.get("id")
|
||||
if not mail_id or mail_id in seen_mail_ids:
|
||||
continue
|
||||
|
||||
seen_mail_ids.add(mail_id)
|
||||
|
||||
sender = str(mail.get("sender", "")).lower()
|
||||
subject = str(mail.get("subject", ""))
|
||||
preview = str(mail.get("preview", ""))
|
||||
|
||||
content = f"{sender}\n{subject}\n{preview}"
|
||||
|
||||
if "openai" not in content.lower():
|
||||
continue
|
||||
|
||||
# 尝试直接使用 Freemail 提取的验证码
|
||||
v_code = mail.get("verification_code")
|
||||
if v_code:
|
||||
logger.info(f"从 Freemail 邮箱 {email} 找到验证码: {v_code}")
|
||||
self.update_status(True)
|
||||
return v_code
|
||||
|
||||
# 如果没有直接提供,通过正则匹配 preview
|
||||
match = re.search(pattern, content)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.info(f"从 Freemail 邮箱 {email} 找到验证码: {code}")
|
||||
self.update_status(True)
|
||||
return code
|
||||
|
||||
# 如果依然未找到,获取邮件详情进行匹配
|
||||
try:
|
||||
detail = self._make_request("GET", f"/api/email/{mail_id}")
|
||||
full_content = str(detail.get("content", "")) + "\n" + str(detail.get("html_content", ""))
|
||||
match = re.search(pattern, full_content)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.info(f"从 Freemail 邮箱 {email} 找到验证码: {code}")
|
||||
self.update_status(True)
|
||||
return code
|
||||
except Exception as e:
|
||||
logger.debug(f"获取 Freemail 邮件详情失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 Freemail 邮件时出错: {e}")
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
logger.warning(f"等待 Freemail 验证码超时: {email}")
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出邮箱
|
||||
|
||||
Args:
|
||||
**kwargs: 额外查询参数
|
||||
|
||||
Returns:
|
||||
邮箱列表
|
||||
"""
|
||||
try:
|
||||
params = {
|
||||
"limit": kwargs.get("limit", 100),
|
||||
"offset": kwargs.get("offset", 0)
|
||||
}
|
||||
resp = self._make_request("GET", "/api/mailboxes", params=params)
|
||||
|
||||
emails = []
|
||||
if isinstance(resp, list):
|
||||
for mail in resp:
|
||||
address = mail.get("address")
|
||||
if address:
|
||||
emails.append({
|
||||
"id": address,
|
||||
"service_id": address,
|
||||
"email": address,
|
||||
"created_at": mail.get("created_at"),
|
||||
"raw_data": mail
|
||||
})
|
||||
self.update_status(True)
|
||||
return emails
|
||||
except Exception as e:
|
||||
logger.warning(f"列出 Freemail 邮箱失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""
|
||||
删除邮箱
|
||||
"""
|
||||
try:
|
||||
self._make_request("DELETE", "/api/mailboxes", params={"address": email_id})
|
||||
logger.info(f"已删除 Freemail 邮箱: {email_id}")
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"删除 Freemail 邮箱失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""检查服务健康状态"""
|
||||
try:
|
||||
self._make_request("GET", "/api/domains")
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Freemail 健康检查失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
217
src/services/imap_mail.py
Normal file
217
src/services/imap_mail.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
IMAP 邮箱服务
|
||||
支持 Gmail / QQ / 163 / Yahoo / Outlook 等标准 IMAP 协议邮件服务商。
|
||||
仅用于接收验证码,强制直连(imaplib 不支持代理)。
|
||||
"""
|
||||
|
||||
import imaplib
|
||||
import email
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
from email.header import decode_header
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError
|
||||
from ..config.constants import (
|
||||
EmailServiceType,
|
||||
OPENAI_EMAIL_SENDERS,
|
||||
OTP_CODE_SEMANTIC_PATTERN,
|
||||
OTP_CODE_PATTERN,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImapMailService(BaseEmailService):
|
||||
"""标准 IMAP 邮箱服务(仅接收验证码,强制直连)"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
super().__init__(EmailServiceType.IMAP_MAIL, name)
|
||||
|
||||
cfg = config or {}
|
||||
required_keys = ["host", "email", "password"]
|
||||
missing_keys = [k for k in required_keys if not cfg.get(k)]
|
||||
if missing_keys:
|
||||
raise ValueError(f"缺少必需配置: {missing_keys}")
|
||||
|
||||
self.host: str = str(cfg["host"]).strip()
|
||||
self.port: int = int(cfg.get("port", 993))
|
||||
self.use_ssl: bool = bool(cfg.get("use_ssl", True))
|
||||
self.email_addr: str = str(cfg["email"]).strip()
|
||||
self.password: str = str(cfg["password"])
|
||||
self.timeout: int = int(cfg.get("timeout", 30))
|
||||
self.max_retries: int = int(cfg.get("max_retries", 3))
|
||||
|
||||
def _connect(self) -> imaplib.IMAP4:
|
||||
"""建立 IMAP 连接并登录,返回 mail 对象"""
|
||||
if self.use_ssl:
|
||||
mail = imaplib.IMAP4_SSL(self.host, self.port)
|
||||
else:
|
||||
mail = imaplib.IMAP4(self.host, self.port)
|
||||
mail.starttls()
|
||||
mail.login(self.email_addr, self.password)
|
||||
return mail
|
||||
|
||||
def _decode_str(self, value) -> str:
|
||||
"""解码邮件头部字段"""
|
||||
if value is None:
|
||||
return ""
|
||||
parts = decode_header(value)
|
||||
decoded = []
|
||||
for part, charset in parts:
|
||||
if isinstance(part, bytes):
|
||||
decoded.append(part.decode(charset or "utf-8", errors="replace"))
|
||||
else:
|
||||
decoded.append(str(part))
|
||||
return " ".join(decoded)
|
||||
|
||||
def _get_text_body(self, msg) -> str:
|
||||
"""提取邮件纯文本内容"""
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body += payload.decode(charset, errors="replace")
|
||||
else:
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode(charset, errors="replace")
|
||||
return body
|
||||
|
||||
def _is_openai_sender(self, from_addr: str) -> bool:
|
||||
"""判断发件人是否为 OpenAI"""
|
||||
from_lower = from_addr.lower()
|
||||
for sender in OPENAI_EMAIL_SENDERS:
|
||||
if sender.startswith("@") or sender.startswith("."):
|
||||
if sender in from_lower:
|
||||
return True
|
||||
else:
|
||||
if sender in from_lower:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_otp(self, text: str) -> Optional[str]:
|
||||
"""从文本中提取 6 位验证码,优先语义匹配,回退简单匹配"""
|
||||
match = re.search(OTP_CODE_SEMANTIC_PATTERN, text, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
match = re.search(OTP_CODE_PATTERN, text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""IMAP 模式不创建新邮箱,直接返回配置中的固定地址"""
|
||||
self.update_status(True)
|
||||
return {
|
||||
"email": self.email_addr,
|
||||
"service_id": self.email_addr,
|
||||
"id": self.email_addr,
|
||||
}
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 60,
|
||||
pattern: str = None,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""轮询 IMAP 收件箱,获取 OpenAI 验证码"""
|
||||
start_time = time.time()
|
||||
seen_ids: set = set()
|
||||
mail = None
|
||||
|
||||
try:
|
||||
mail = self._connect()
|
||||
mail.select("INBOX")
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
# 搜索所有未读邮件
|
||||
status, data = mail.search(None, "UNSEEN")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
msg_ids = data[0].split()
|
||||
for msg_id in reversed(msg_ids): # 最新的优先
|
||||
id_str = msg_id.decode()
|
||||
if id_str in seen_ids:
|
||||
continue
|
||||
seen_ids.add(id_str)
|
||||
|
||||
# 获取邮件
|
||||
status, msg_data = mail.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK" or not msg_data:
|
||||
continue
|
||||
|
||||
raw = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw)
|
||||
|
||||
# 检查发件人
|
||||
from_addr = self._decode_str(msg.get("From", ""))
|
||||
if not self._is_openai_sender(from_addr):
|
||||
continue
|
||||
|
||||
# 提取验证码
|
||||
body = self._get_text_body(msg)
|
||||
code = self._extract_otp(body)
|
||||
if code:
|
||||
# 标记已读
|
||||
mail.store(msg_id, "+FLAGS", "\\Seen")
|
||||
self.update_status(True)
|
||||
logger.info(f"IMAP 获取验证码成功: {code}")
|
||||
return code
|
||||
|
||||
except imaplib.IMAP4.error as e:
|
||||
logger.debug(f"IMAP 搜索邮件失败: {e}")
|
||||
# 尝试重新连接
|
||||
try:
|
||||
mail.select("INBOX")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"IMAP 连接/轮询失败: {e}")
|
||||
self.update_status(False, str(e))
|
||||
finally:
|
||||
if mail:
|
||||
try:
|
||||
mail.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""尝试 IMAP 登录并选择收件箱"""
|
||||
mail = None
|
||||
try:
|
||||
mail = self._connect()
|
||||
status, _ = mail.select("INBOX")
|
||||
return status == "OK"
|
||||
except Exception as e:
|
||||
logger.warning(f"IMAP 健康检查失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
if mail:
|
||||
try:
|
||||
mail.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def list_emails(self, **kwargs) -> list:
|
||||
"""IMAP 单账号模式,返回固定地址"""
|
||||
return [{"email": self.email_addr, "id": self.email_addr}]
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""IMAP 模式无需删除逻辑"""
|
||||
return True
|
||||
556
src/services/moe_mail.py
Normal file
556
src/services/moe_mail.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
自定义域名邮箱服务实现
|
||||
基于 email.md 中的 REST API 接口
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MeoMailEmailService(BaseEmailService):
|
||||
"""
|
||||
自定义域名邮箱服务
|
||||
基于 REST API 接口
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
"""
|
||||
初始化自定义域名邮箱服务
|
||||
|
||||
Args:
|
||||
config: 配置字典,支持以下键:
|
||||
- base_url: API 基础地址 (必需)
|
||||
- api_key: API 密钥 (必需)
|
||||
- api_key_header: API 密钥请求头名称 (默认: X-API-Key)
|
||||
- timeout: 请求超时时间 (默认: 30)
|
||||
- max_retries: 最大重试次数 (默认: 3)
|
||||
- proxy_url: 代理 URL
|
||||
- default_domain: 默认域名
|
||||
- default_expiry: 默认过期时间(毫秒)
|
||||
name: 服务名称
|
||||
"""
|
||||
super().__init__(EmailServiceType.MOE_MAIL, name)
|
||||
|
||||
# 必需配置检查
|
||||
required_keys = ["base_url", "api_key"]
|
||||
missing_keys = [key for key in required_keys if key not in (config or {})]
|
||||
|
||||
if missing_keys:
|
||||
raise ValueError(f"缺少必需配置: {missing_keys}")
|
||||
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
"api_key_header": "X-API-Key",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"proxy_url": None,
|
||||
"default_domain": None,
|
||||
"default_expiry": 3600000, # 1小时
|
||||
}
|
||||
|
||||
self.config = {**default_config, **(config or {})}
|
||||
|
||||
# 创建 HTTP 客户端
|
||||
http_config = RequestConfig(
|
||||
timeout=self.config["timeout"],
|
||||
max_retries=self.config["max_retries"],
|
||||
)
|
||||
self.http_client = HTTPClient(
|
||||
proxy_url=self.config.get("proxy_url"),
|
||||
config=http_config
|
||||
)
|
||||
|
||||
# 状态变量
|
||||
self._emails_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self._last_config_check: float = 0
|
||||
self._cached_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""获取 API 请求头"""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# 添加 API 密钥
|
||||
api_key_header = self.config.get("api_key_header", "X-API-Key")
|
||||
headers[api_key_header] = self.config["api_key"]
|
||||
|
||||
return headers
|
||||
|
||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
发送 API 请求
|
||||
|
||||
Args:
|
||||
method: HTTP 方法
|
||||
endpoint: API 端点
|
||||
**kwargs: 请求参数
|
||||
|
||||
Returns:
|
||||
响应 JSON 数据
|
||||
|
||||
Raises:
|
||||
EmailServiceError: 请求失败
|
||||
"""
|
||||
url = urljoin(self.config["base_url"], endpoint)
|
||||
|
||||
# 添加默认请求头
|
||||
kwargs.setdefault("headers", {})
|
||||
kwargs["headers"].update(self._get_headers())
|
||||
|
||||
try:
|
||||
# POST 请求禁用自动重定向,手动处理以保持 POST 方法(避免 HTTP→HTTPS 重定向时被转为 GET)
|
||||
if method.upper() == "POST":
|
||||
kwargs["allow_redirects"] = False
|
||||
response = self.http_client.request(method, url, **kwargs)
|
||||
# 处理重定向
|
||||
max_redirects = 5
|
||||
redirect_count = 0
|
||||
while response.status_code in (301, 302, 303, 307, 308) and redirect_count < max_redirects:
|
||||
location = response.headers.get("Location", "")
|
||||
if not location:
|
||||
break
|
||||
import urllib.parse as _urlparse
|
||||
redirect_url = _urlparse.urljoin(url, location)
|
||||
# 307/308 保持 POST,其余(301/302/303)转为 GET
|
||||
if response.status_code in (307, 308):
|
||||
redirect_method = method
|
||||
redirect_kwargs = kwargs
|
||||
else:
|
||||
redirect_method = "GET"
|
||||
# GET 不传 body
|
||||
redirect_kwargs = {k: v for k, v in kwargs.items() if k not in ("json", "data")}
|
||||
response = self.http_client.request(redirect_method, redirect_url, **redirect_kwargs)
|
||||
url = redirect_url
|
||||
redirect_count += 1
|
||||
else:
|
||||
response = self.http_client.request(method, url, **kwargs)
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_msg = f"API 请求失败: {response.status_code}"
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = f"{error_msg} - {error_data}"
|
||||
except:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
|
||||
self.update_status(False, EmailServiceError(error_msg))
|
||||
raise EmailServiceError(error_msg)
|
||||
|
||||
# 解析响应
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_response": response.text}
|
||||
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"API 请求失败: {method} {endpoint} - {e}")
|
||||
|
||||
def get_config(self, force_refresh: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
获取系统配置
|
||||
|
||||
Args:
|
||||
force_refresh: 是否强制刷新缓存
|
||||
|
||||
Returns:
|
||||
配置信息
|
||||
"""
|
||||
# 检查缓存
|
||||
if not force_refresh and self._cached_config and time.time() - self._last_config_check < 300:
|
||||
return self._cached_config
|
||||
|
||||
try:
|
||||
response = self._make_request("GET", "/api/config")
|
||||
self._cached_config = response
|
||||
self._last_config_check = time.time()
|
||||
self.update_status(True)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.warning(f"获取配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
创建临时邮箱
|
||||
|
||||
Args:
|
||||
config: 配置参数:
|
||||
- name: 邮箱前缀(可选)
|
||||
- expiryTime: 有效期(毫秒)(可选)
|
||||
- domain: 邮箱域名(可选)
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典:
|
||||
- email: 邮箱地址
|
||||
- service_id: 邮箱 ID
|
||||
- id: 邮箱 ID(同 service_id)
|
||||
- expiry: 过期时间信息
|
||||
"""
|
||||
# 获取默认配置
|
||||
sys_config = self.get_config()
|
||||
default_domain = self.config.get("default_domain")
|
||||
if not default_domain and sys_config.get("emailDomains"):
|
||||
# 使用系统配置的第一个域名
|
||||
domains = sys_config["emailDomains"].split(",")
|
||||
default_domain = domains[0].strip() if domains else None
|
||||
|
||||
# 构建请求参数
|
||||
request_config = config or {}
|
||||
create_data = {
|
||||
"name": request_config.get("name", ""),
|
||||
"expiryTime": request_config.get("expiryTime", self.config.get("default_expiry", 3600000)),
|
||||
"domain": request_config.get("domain", default_domain),
|
||||
}
|
||||
|
||||
# 移除空值
|
||||
create_data = {k: v for k, v in create_data.items() if v is not None and v != ""}
|
||||
|
||||
try:
|
||||
response = self._make_request("POST", "/api/emails/generate", json=create_data)
|
||||
|
||||
email = response.get("email", "").strip()
|
||||
email_id = response.get("id", "").strip()
|
||||
|
||||
if not email or not email_id:
|
||||
raise EmailServiceError("API 返回数据不完整")
|
||||
|
||||
email_info = {
|
||||
"email": email,
|
||||
"service_id": email_id,
|
||||
"id": email_id,
|
||||
"created_at": time.time(),
|
||||
"expiry": create_data.get("expiryTime"),
|
||||
"domain": create_data.get("domain"),
|
||||
"raw_response": response,
|
||||
}
|
||||
|
||||
# 缓存邮箱信息
|
||||
self._emails_cache[email_id] = email_info
|
||||
|
||||
logger.info(f"成功创建自定义域名邮箱: {email} (ID: {email_id})")
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"创建邮箱失败: {e}")
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
pattern: str = OTP_CODE_PATTERN,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从自定义域名邮箱获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 邮箱 ID(如果不提供,从缓存中查找)
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则表达式
|
||||
otp_sent_at: OTP 发送时间戳(自定义域名服务暂不使用此参数)
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果超时或未找到返回 None
|
||||
"""
|
||||
# 查找邮箱 ID
|
||||
target_email_id = email_id
|
||||
if not target_email_id:
|
||||
# 从缓存中查找
|
||||
for eid, info in self._emails_cache.items():
|
||||
if info.get("email") == email:
|
||||
target_email_id = eid
|
||||
break
|
||||
|
||||
if not target_email_id:
|
||||
logger.warning(f"未找到邮箱 {email} 的 ID,无法获取验证码")
|
||||
return None
|
||||
|
||||
logger.info(f"正在从自定义域名邮箱 {email} 获取验证码...")
|
||||
|
||||
start_time = time.time()
|
||||
seen_message_ids = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
# 获取邮件列表
|
||||
response = self._make_request("GET", f"/api/emails/{target_email_id}")
|
||||
|
||||
messages = response.get("messages", [])
|
||||
if not isinstance(messages, list):
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
for message in messages:
|
||||
message_id = message.get("id")
|
||||
if not message_id or message_id in seen_message_ids:
|
||||
continue
|
||||
|
||||
seen_message_ids.add(message_id)
|
||||
|
||||
# 检查是否是目标邮件
|
||||
sender = str(message.get("from_address", "")).lower()
|
||||
subject = str(message.get("subject", ""))
|
||||
|
||||
# 获取邮件内容
|
||||
message_content = self._get_message_content(target_email_id, message_id)
|
||||
if not message_content:
|
||||
continue
|
||||
|
||||
content = f"{sender} {subject} {message_content}"
|
||||
|
||||
# 检查是否是 OpenAI 邮件
|
||||
if "openai" not in sender and "openai" not in content.lower():
|
||||
continue
|
||||
|
||||
# 提取验证码 过滤掉邮箱
|
||||
email_pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
match = re.search(pattern, re.sub(email_pattern, "", content))
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.info(f"从自定义域名邮箱 {email} 找到验证码: {code}")
|
||||
self.update_status(True)
|
||||
return code
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"检查邮件时出错: {e}")
|
||||
|
||||
# 等待一段时间再检查
|
||||
time.sleep(3)
|
||||
|
||||
logger.warning(f"等待验证码超时: {email}")
|
||||
return None
|
||||
|
||||
def _get_message_content(self, email_id: str, message_id: str) -> Optional[str]:
|
||||
"""获取邮件内容"""
|
||||
try:
|
||||
response = self._make_request("GET", f"/api/emails/{email_id}/{message_id}")
|
||||
message = response.get("message", {})
|
||||
|
||||
# 优先使用纯文本内容,其次使用 HTML 内容
|
||||
content = message.get("content", "")
|
||||
if not content:
|
||||
html = message.get("html", "")
|
||||
if html:
|
||||
# 简单去除 HTML 标签
|
||||
content = re.sub(r"<[^>]+>", " ", html)
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"获取邮件内容失败: {e}")
|
||||
return None
|
||||
|
||||
def list_emails(self, cursor: str = None, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出所有邮箱
|
||||
|
||||
Args:
|
||||
cursor: 分页游标
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
邮箱列表
|
||||
"""
|
||||
params = {}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
try:
|
||||
response = self._make_request("GET", "/api/emails", params=params)
|
||||
emails = response.get("emails", [])
|
||||
|
||||
# 更新缓存
|
||||
for email_info in emails:
|
||||
email_id = email_info.get("id")
|
||||
if email_id:
|
||||
self._emails_cache[email_id] = email_info
|
||||
|
||||
self.update_status(True)
|
||||
return emails
|
||||
except Exception as e:
|
||||
logger.warning(f"列出邮箱失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return []
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""
|
||||
删除邮箱
|
||||
|
||||
Args:
|
||||
email_id: 邮箱 ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
response = self._make_request("DELETE", f"/api/emails/{email_id}")
|
||||
success = response.get("success", False)
|
||||
|
||||
if success:
|
||||
# 从缓存中移除
|
||||
self._emails_cache.pop(email_id, None)
|
||||
logger.info(f"成功删除邮箱: {email_id}")
|
||||
else:
|
||||
logger.warning(f"删除邮箱失败: {email_id}")
|
||||
|
||||
self.update_status(success)
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除邮箱失败: {email_id} - {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""检查自定义域名邮箱服务是否可用"""
|
||||
try:
|
||||
# 尝试获取配置
|
||||
config = self.get_config(force_refresh=True)
|
||||
if config:
|
||||
logger.debug(f"自定义域名邮箱服务健康检查通过,配置: {config.get('defaultRole', 'N/A')}")
|
||||
self.update_status(True)
|
||||
return True
|
||||
else:
|
||||
logger.warning("自定义域名邮箱服务健康检查失败:获取配置为空")
|
||||
self.update_status(False, EmailServiceError("获取配置为空"))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"自定义域名邮箱服务健康检查失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
|
||||
def get_email_messages(self, email_id: str, cursor: str = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取邮箱中的邮件列表
|
||||
|
||||
Args:
|
||||
email_id: 邮箱 ID
|
||||
cursor: 分页游标
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
params = {}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
try:
|
||||
response = self._make_request("GET", f"/api/emails/{email_id}", params=params)
|
||||
messages = response.get("messages", [])
|
||||
self.update_status(True)
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.error(f"获取邮件列表失败: {email_id} - {e}")
|
||||
self.update_status(False, e)
|
||||
return []
|
||||
|
||||
def get_message_detail(self, email_id: str, message_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取邮件详情
|
||||
|
||||
Args:
|
||||
email_id: 邮箱 ID
|
||||
message_id: 邮件 ID
|
||||
|
||||
Returns:
|
||||
邮件详情
|
||||
"""
|
||||
try:
|
||||
response = self._make_request("GET", f"/api/emails/{email_id}/{message_id}")
|
||||
message = response.get("message")
|
||||
self.update_status(True)
|
||||
return message
|
||||
except Exception as e:
|
||||
logger.error(f"获取邮件详情失败: {email_id}/{message_id} - {e}")
|
||||
self.update_status(False, e)
|
||||
return None
|
||||
|
||||
def create_email_share(self, email_id: str, expires_in: int = 86400000) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
创建邮箱分享链接
|
||||
|
||||
Args:
|
||||
email_id: 邮箱 ID
|
||||
expires_in: 有效期(毫秒)
|
||||
|
||||
Returns:
|
||||
分享信息
|
||||
"""
|
||||
try:
|
||||
response = self._make_request(
|
||||
"POST",
|
||||
f"/api/emails/{email_id}/share",
|
||||
json={"expiresIn": expires_in}
|
||||
)
|
||||
self.update_status(True)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"创建邮箱分享链接失败: {email_id} - {e}")
|
||||
self.update_status(False, e)
|
||||
return None
|
||||
|
||||
def create_message_share(
|
||||
self,
|
||||
email_id: str,
|
||||
message_id: str,
|
||||
expires_in: int = 86400000
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
创建邮件分享链接
|
||||
|
||||
Args:
|
||||
email_id: 邮箱 ID
|
||||
message_id: 邮件 ID
|
||||
expires_in: 有效期(毫秒)
|
||||
|
||||
Returns:
|
||||
分享信息
|
||||
"""
|
||||
try:
|
||||
response = self._make_request(
|
||||
"POST",
|
||||
f"/api/emails/{email_id}/messages/{message_id}/share",
|
||||
json={"expiresIn": expires_in}
|
||||
)
|
||||
self.update_status(True)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"创建邮件分享链接失败: {email_id}/{message_id} - {e}")
|
||||
self.update_status(False, e)
|
||||
return None
|
||||
|
||||
def get_service_info(self) -> Dict[str, Any]:
|
||||
"""获取服务信息"""
|
||||
config = self.get_config()
|
||||
return {
|
||||
"service_type": self.service_type.value,
|
||||
"name": self.name,
|
||||
"base_url": self.config["base_url"],
|
||||
"default_domain": self.config.get("default_domain"),
|
||||
"system_config": config,
|
||||
"cached_emails_count": len(self._emails_cache),
|
||||
"status": self.status.value,
|
||||
}
|
||||
8
src/services/outlook/__init__.py
Normal file
8
src/services/outlook/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Outlook 邮箱服务模块
|
||||
支持多种 IMAP/API 连接方式,自动故障切换
|
||||
"""
|
||||
|
||||
from .service import OutlookService
|
||||
|
||||
__all__ = ['OutlookService']
|
||||
51
src/services/outlook/account.py
Normal file
51
src/services/outlook/account.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Outlook 账户数据类
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutlookAccount:
|
||||
"""Outlook 账户信息"""
|
||||
email: str
|
||||
password: str = ""
|
||||
client_id: str = ""
|
||||
refresh_token: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "OutlookAccount":
|
||||
"""从配置创建账户"""
|
||||
return cls(
|
||||
email=config.get("email", ""),
|
||||
password=config.get("password", ""),
|
||||
client_id=config.get("client_id", ""),
|
||||
refresh_token=config.get("refresh_token", "")
|
||||
)
|
||||
|
||||
def has_oauth(self) -> bool:
|
||||
"""是否支持 OAuth2"""
|
||||
return bool(self.client_id and self.refresh_token)
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证账户信息是否有效"""
|
||||
return bool(self.email and self.password) or self.has_oauth()
|
||||
|
||||
def to_dict(self, include_sensitive: bool = False) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
"email": self.email,
|
||||
"has_oauth": self.has_oauth(),
|
||||
}
|
||||
if include_sensitive:
|
||||
result.update({
|
||||
"password": self.password,
|
||||
"client_id": self.client_id,
|
||||
"refresh_token": self.refresh_token[:20] + "..." if self.refresh_token else "",
|
||||
})
|
||||
return result
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return f"OutlookAccount({self.email})"
|
||||
153
src/services/outlook/base.py
Normal file
153
src/services/outlook/base.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Outlook 服务基础定义
|
||||
包含枚举类型和数据类
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
"""Outlook 提供者类型"""
|
||||
IMAP_OLD = "imap_old" # 旧版 IMAP (outlook.office365.com)
|
||||
IMAP_NEW = "imap_new" # 新版 IMAP (outlook.live.com)
|
||||
GRAPH_API = "graph_api" # Microsoft Graph API
|
||||
|
||||
|
||||
class TokenEndpoint(str, Enum):
|
||||
"""Token 端点"""
|
||||
LIVE = "https://login.live.com/oauth20_token.srf"
|
||||
CONSUMERS = "https://login.microsoftonline.com/consumers/oauth2/v2.0/token"
|
||||
COMMON = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
|
||||
|
||||
class IMAPServer(str, Enum):
|
||||
"""IMAP 服务器"""
|
||||
OLD = "outlook.office365.com"
|
||||
NEW = "outlook.live.com"
|
||||
|
||||
|
||||
class ProviderStatus(str, Enum):
|
||||
"""提供者状态"""
|
||||
HEALTHY = "healthy" # 健康
|
||||
DEGRADED = "degraded" # 降级
|
||||
DISABLED = "disabled" # 禁用
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailMessage:
|
||||
"""邮件消息数据类"""
|
||||
id: str # 消息 ID
|
||||
subject: str # 主题
|
||||
sender: str # 发件人
|
||||
recipients: List[str] = field(default_factory=list) # 收件人列表
|
||||
body: str = "" # 正文内容
|
||||
body_preview: str = "" # 正文预览
|
||||
received_at: Optional[datetime] = None # 接收时间
|
||||
received_timestamp: int = 0 # 接收时间戳
|
||||
is_read: bool = False # 是否已读
|
||||
has_attachments: bool = False # 是否有附件
|
||||
raw_data: Optional[bytes] = None # 原始数据(用于调试)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"subject": self.subject,
|
||||
"sender": self.sender,
|
||||
"recipients": self.recipients,
|
||||
"body": self.body,
|
||||
"body_preview": self.body_preview,
|
||||
"received_at": self.received_at.isoformat() if self.received_at else None,
|
||||
"received_timestamp": self.received_timestamp,
|
||||
"is_read": self.is_read,
|
||||
"has_attachments": self.has_attachments,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenInfo:
|
||||
"""Token 信息数据类"""
|
||||
access_token: str
|
||||
expires_at: float # 过期时间戳
|
||||
token_type: str = "Bearer"
|
||||
scope: str = ""
|
||||
refresh_token: Optional[str] = None
|
||||
|
||||
def is_expired(self, buffer_seconds: int = 120) -> bool:
|
||||
"""检查 Token 是否已过期"""
|
||||
import time
|
||||
return time.time() >= (self.expires_at - buffer_seconds)
|
||||
|
||||
@classmethod
|
||||
def from_response(cls, data: Dict[str, Any], scope: str = "") -> "TokenInfo":
|
||||
"""从 API 响应创建"""
|
||||
import time
|
||||
return cls(
|
||||
access_token=data.get("access_token", ""),
|
||||
expires_at=time.time() + data.get("expires_in", 3600),
|
||||
token_type=data.get("token_type", "Bearer"),
|
||||
scope=scope or data.get("scope", ""),
|
||||
refresh_token=data.get("refresh_token"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderHealth:
|
||||
"""提供者健康状态"""
|
||||
provider_type: ProviderType
|
||||
status: ProviderStatus = ProviderStatus.HEALTHY
|
||||
failure_count: int = 0 # 连续失败次数
|
||||
last_success: Optional[datetime] = None # 最后成功时间
|
||||
last_failure: Optional[datetime] = None # 最后失败时间
|
||||
last_error: str = "" # 最后错误信息
|
||||
disabled_until: Optional[datetime] = None # 禁用截止时间
|
||||
|
||||
def record_success(self):
|
||||
"""记录成功"""
|
||||
self.status = ProviderStatus.HEALTHY
|
||||
self.failure_count = 0
|
||||
self.last_success = datetime.now()
|
||||
self.disabled_until = None
|
||||
|
||||
def record_failure(self, error: str):
|
||||
"""记录失败"""
|
||||
self.failure_count += 1
|
||||
self.last_failure = datetime.now()
|
||||
self.last_error = error
|
||||
|
||||
def should_disable(self, threshold: int = 3) -> bool:
|
||||
"""判断是否应该禁用"""
|
||||
return self.failure_count >= threshold
|
||||
|
||||
def is_disabled(self) -> bool:
|
||||
"""检查是否被禁用"""
|
||||
if self.disabled_until and datetime.now() < self.disabled_until:
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable(self, duration_seconds: int = 300):
|
||||
"""禁用提供者"""
|
||||
from datetime import timedelta
|
||||
self.status = ProviderStatus.DISABLED
|
||||
self.disabled_until = datetime.now() + timedelta(seconds=duration_seconds)
|
||||
|
||||
def enable(self):
|
||||
"""启用提供者"""
|
||||
self.status = ProviderStatus.HEALTHY
|
||||
self.disabled_until = None
|
||||
self.failure_count = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"provider_type": self.provider_type.value,
|
||||
"status": self.status.value,
|
||||
"failure_count": self.failure_count,
|
||||
"last_success": self.last_success.isoformat() if self.last_success else None,
|
||||
"last_failure": self.last_failure.isoformat() if self.last_failure else None,
|
||||
"last_error": self.last_error,
|
||||
"disabled_until": self.disabled_until.isoformat() if self.disabled_until else None,
|
||||
}
|
||||
228
src/services/outlook/email_parser.py
Normal file
228
src/services/outlook/email_parser.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
邮件解析和验证码提取
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from ...config.constants import (
|
||||
OTP_CODE_SIMPLE_PATTERN,
|
||||
OTP_CODE_SEMANTIC_PATTERN,
|
||||
OPENAI_EMAIL_SENDERS,
|
||||
OPENAI_VERIFICATION_KEYWORDS,
|
||||
)
|
||||
from .base import EmailMessage
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmailParser:
|
||||
"""
|
||||
邮件解析器
|
||||
用于识别 OpenAI 验证邮件并提取验证码
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 编译正则表达式
|
||||
self._simple_pattern = re.compile(OTP_CODE_SIMPLE_PATTERN)
|
||||
self._semantic_pattern = re.compile(OTP_CODE_SEMANTIC_PATTERN, re.IGNORECASE)
|
||||
|
||||
def is_openai_verification_email(
|
||||
self,
|
||||
email: EmailMessage,
|
||||
target_email: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
判断是否为 OpenAI 验证邮件
|
||||
|
||||
Args:
|
||||
email: 邮件对象
|
||||
target_email: 目标邮箱地址(用于验证收件人)
|
||||
|
||||
Returns:
|
||||
是否为 OpenAI 验证邮件
|
||||
"""
|
||||
sender = email.sender.lower()
|
||||
|
||||
# 1. 发件人必须是 OpenAI
|
||||
if not any(s in sender for s in OPENAI_EMAIL_SENDERS):
|
||||
logger.debug(f"邮件发件人非 OpenAI: {sender}")
|
||||
return False
|
||||
|
||||
# 2. 主题或正文包含验证关键词
|
||||
subject = email.subject.lower()
|
||||
body = email.body.lower()
|
||||
combined = f"{subject} {body}"
|
||||
|
||||
if not any(kw in combined for kw in OPENAI_VERIFICATION_KEYWORDS):
|
||||
logger.debug(f"邮件未包含验证关键词: {subject[:50]}")
|
||||
return False
|
||||
|
||||
# 3. 收件人检查已移除:别名邮件的 IMAP 头中收件人可能不匹配,只靠发件人+关键词判断
|
||||
logger.debug(f"识别为 OpenAI 验证邮件: {subject[:50]}")
|
||||
return True
|
||||
|
||||
def extract_verification_code(
|
||||
self,
|
||||
email: EmailMessage,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从邮件中提取验证码
|
||||
|
||||
优先级:
|
||||
1. 从主题提取(6位数字)
|
||||
2. 从正文用语义正则提取(如 "code is 123456")
|
||||
3. 兜底:任意 6 位数字
|
||||
|
||||
Args:
|
||||
email: 邮件对象
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果未找到返回 None
|
||||
"""
|
||||
# 1. 主题优先
|
||||
code = self._extract_from_subject(email.subject)
|
||||
if code:
|
||||
logger.debug(f"从主题提取验证码: {code}")
|
||||
return code
|
||||
|
||||
# 2. 正文语义匹配
|
||||
code = self._extract_semantic(email.body)
|
||||
if code:
|
||||
logger.debug(f"从正文语义提取验证码: {code}")
|
||||
return code
|
||||
|
||||
# 3. 兜底:正文任意 6 位数字
|
||||
code = self._extract_simple(email.body)
|
||||
if code:
|
||||
logger.debug(f"从正文兜底提取验证码: {code}")
|
||||
return code
|
||||
|
||||
return None
|
||||
|
||||
def _extract_from_subject(self, subject: str) -> Optional[str]:
|
||||
"""从主题提取验证码"""
|
||||
match = self._simple_pattern.search(subject)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def _extract_semantic(self, body: str) -> Optional[str]:
|
||||
"""语义匹配提取验证码"""
|
||||
match = self._semantic_pattern.search(body)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def _extract_simple(self, body: str) -> Optional[str]:
|
||||
"""简单匹配提取验证码"""
|
||||
match = self._simple_pattern.search(body)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def find_verification_code_in_emails(
|
||||
self,
|
||||
emails: List[EmailMessage],
|
||||
target_email: Optional[str] = None,
|
||||
min_timestamp: int = 0,
|
||||
used_codes: Optional[set] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从邮件列表中查找验证码
|
||||
|
||||
Args:
|
||||
emails: 邮件列表
|
||||
target_email: 目标邮箱地址
|
||||
min_timestamp: 最小时间戳(用于过滤旧邮件)
|
||||
used_codes: 已使用的验证码集合(用于去重)
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果未找到返回 None
|
||||
"""
|
||||
used_codes = used_codes or set()
|
||||
|
||||
for email in emails:
|
||||
# 时间戳过滤
|
||||
if min_timestamp > 0 and email.received_timestamp > 0:
|
||||
if email.received_timestamp < min_timestamp:
|
||||
logger.debug(f"跳过旧邮件: {email.subject[:50]}")
|
||||
continue
|
||||
|
||||
# 检查是否是 OpenAI 验证邮件
|
||||
if not self.is_openai_verification_email(email, target_email):
|
||||
continue
|
||||
|
||||
# 提取验证码
|
||||
code = self.extract_verification_code(email)
|
||||
if code:
|
||||
# 去重检查
|
||||
if code in used_codes:
|
||||
logger.debug(f"跳过已使用的验证码: {code}")
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[{target_email or 'unknown'}] 找到验证码: {code}, "
|
||||
f"邮件主题: {email.subject[:30]}"
|
||||
)
|
||||
return code
|
||||
|
||||
return None
|
||||
|
||||
def filter_emails_by_sender(
|
||||
self,
|
||||
emails: List[EmailMessage],
|
||||
sender_patterns: List[str],
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
按发件人过滤邮件
|
||||
|
||||
Args:
|
||||
emails: 邮件列表
|
||||
sender_patterns: 发件人匹配模式列表
|
||||
|
||||
Returns:
|
||||
过滤后的邮件列表
|
||||
"""
|
||||
filtered = []
|
||||
for email in emails:
|
||||
sender = email.sender.lower()
|
||||
if any(pattern.lower() in sender for pattern in sender_patterns):
|
||||
filtered.append(email)
|
||||
return filtered
|
||||
|
||||
def filter_emails_by_subject(
|
||||
self,
|
||||
emails: List[EmailMessage],
|
||||
keywords: List[str],
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
按主题关键词过滤邮件
|
||||
|
||||
Args:
|
||||
emails: 邮件列表
|
||||
keywords: 关键词列表
|
||||
|
||||
Returns:
|
||||
过滤后的邮件列表
|
||||
"""
|
||||
filtered = []
|
||||
for email in emails:
|
||||
subject = email.subject.lower()
|
||||
if any(kw.lower() in subject for kw in keywords):
|
||||
filtered.append(email)
|
||||
return filtered
|
||||
|
||||
|
||||
# 全局解析器实例
|
||||
_parser: Optional[EmailParser] = None
|
||||
|
||||
|
||||
def get_email_parser() -> EmailParser:
|
||||
"""获取全局邮件解析器实例"""
|
||||
global _parser
|
||||
if _parser is None:
|
||||
_parser = EmailParser()
|
||||
return _parser
|
||||
312
src/services/outlook/health_checker.py
Normal file
312
src/services/outlook/health_checker.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
健康检查和故障切换管理
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from .base import ProviderType, ProviderHealth, ProviderStatus
|
||||
from .providers.base import OutlookProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HealthChecker:
|
||||
"""
|
||||
健康检查管理器
|
||||
跟踪各提供者的健康状态,管理故障切换
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
failure_threshold: int = 3,
|
||||
disable_duration: int = 300,
|
||||
recovery_check_interval: int = 60,
|
||||
):
|
||||
"""
|
||||
初始化健康检查器
|
||||
|
||||
Args:
|
||||
failure_threshold: 连续失败次数阈值,超过后禁用
|
||||
disable_duration: 禁用时长(秒)
|
||||
recovery_check_interval: 恢复检查间隔(秒)
|
||||
"""
|
||||
self.failure_threshold = failure_threshold
|
||||
self.disable_duration = disable_duration
|
||||
self.recovery_check_interval = recovery_check_interval
|
||||
|
||||
# 提供者健康状态: ProviderType -> ProviderHealth
|
||||
self._health_status: Dict[ProviderType, ProviderHealth] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 初始化所有提供者的健康状态
|
||||
for provider_type in ProviderType:
|
||||
self._health_status[provider_type] = ProviderHealth(
|
||||
provider_type=provider_type
|
||||
)
|
||||
|
||||
def get_health(self, provider_type: ProviderType) -> ProviderHealth:
|
||||
"""获取提供者的健康状态"""
|
||||
with self._lock:
|
||||
return self._health_status.get(provider_type, ProviderHealth(provider_type=provider_type))
|
||||
|
||||
def record_success(self, provider_type: ProviderType):
|
||||
"""记录成功操作"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.record_success()
|
||||
logger.debug(f"{provider_type.value} 记录成功")
|
||||
|
||||
def record_failure(self, provider_type: ProviderType, error: str):
|
||||
"""记录失败操作"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.record_failure(error)
|
||||
|
||||
# 检查是否需要禁用
|
||||
if health.should_disable(self.failure_threshold):
|
||||
health.disable(self.disable_duration)
|
||||
logger.warning(
|
||||
f"{provider_type.value} 已禁用 {self.disable_duration} 秒,"
|
||||
f"原因: {error}"
|
||||
)
|
||||
|
||||
def is_available(self, provider_type: ProviderType) -> bool:
|
||||
"""
|
||||
检查提供者是否可用
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
|
||||
Returns:
|
||||
是否可用
|
||||
"""
|
||||
health = self.get_health(provider_type)
|
||||
|
||||
# 检查是否被禁用
|
||||
if health.is_disabled():
|
||||
remaining = (health.disabled_until - datetime.now()).total_seconds()
|
||||
logger.debug(
|
||||
f"{provider_type.value} 已被禁用,剩余 {int(remaining)} 秒"
|
||||
)
|
||||
return False
|
||||
|
||||
return health.status != ProviderStatus.DISABLED
|
||||
|
||||
def get_available_providers(
|
||||
self,
|
||||
priority_order: Optional[List[ProviderType]] = None,
|
||||
) -> List[ProviderType]:
|
||||
"""
|
||||
获取可用的提供者列表
|
||||
|
||||
Args:
|
||||
priority_order: 优先级顺序,默认为 [IMAP_NEW, IMAP_OLD, GRAPH_API]
|
||||
|
||||
Returns:
|
||||
可用的提供者列表
|
||||
"""
|
||||
if priority_order is None:
|
||||
priority_order = [
|
||||
ProviderType.IMAP_NEW,
|
||||
ProviderType.IMAP_OLD,
|
||||
ProviderType.GRAPH_API,
|
||||
]
|
||||
|
||||
available = []
|
||||
for provider_type in priority_order:
|
||||
if self.is_available(provider_type):
|
||||
available.append(provider_type)
|
||||
|
||||
return available
|
||||
|
||||
def get_next_available_provider(
|
||||
self,
|
||||
priority_order: Optional[List[ProviderType]] = None,
|
||||
) -> Optional[ProviderType]:
|
||||
"""
|
||||
获取下一个可用的提供者
|
||||
|
||||
Args:
|
||||
priority_order: 优先级顺序
|
||||
|
||||
Returns:
|
||||
可用的提供者类型,如果没有返回 None
|
||||
"""
|
||||
available = self.get_available_providers(priority_order)
|
||||
return available[0] if available else None
|
||||
|
||||
def force_disable(self, provider_type: ProviderType, duration: Optional[int] = None):
|
||||
"""
|
||||
强制禁用提供者
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
duration: 禁用时长(秒),默认使用配置值
|
||||
"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.disable(duration or self.disable_duration)
|
||||
logger.warning(f"{provider_type.value} 已强制禁用")
|
||||
|
||||
def force_enable(self, provider_type: ProviderType):
|
||||
"""
|
||||
强制启用提供者
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
"""
|
||||
with self._lock:
|
||||
health = self._health_status.get(provider_type)
|
||||
if health:
|
||||
health.enable()
|
||||
logger.info(f"{provider_type.value} 已启用")
|
||||
|
||||
def get_all_health_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取所有提供者的健康状态
|
||||
|
||||
Returns:
|
||||
健康状态字典
|
||||
"""
|
||||
with self._lock:
|
||||
return {
|
||||
provider_type.value: health.to_dict()
|
||||
for provider_type, health in self._health_status.items()
|
||||
}
|
||||
|
||||
def check_and_recover(self):
|
||||
"""
|
||||
检查并恢复被禁用的提供者
|
||||
|
||||
如果禁用时间已过,自动恢复提供者
|
||||
"""
|
||||
with self._lock:
|
||||
for provider_type, health in self._health_status.items():
|
||||
if health.is_disabled():
|
||||
# 检查是否可以恢复
|
||||
if health.disabled_until and datetime.now() >= health.disabled_until:
|
||||
health.enable()
|
||||
logger.info(f"{provider_type.value} 已自动恢复")
|
||||
|
||||
def reset_all(self):
|
||||
"""重置所有提供者的健康状态"""
|
||||
with self._lock:
|
||||
for provider_type in ProviderType:
|
||||
self._health_status[provider_type] = ProviderHealth(
|
||||
provider_type=provider_type
|
||||
)
|
||||
logger.info("已重置所有提供者的健康状态")
|
||||
|
||||
|
||||
class FailoverManager:
|
||||
"""
|
||||
故障切换管理器
|
||||
管理提供者之间的自动切换
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
health_checker: HealthChecker,
|
||||
priority_order: Optional[List[ProviderType]] = None,
|
||||
):
|
||||
"""
|
||||
初始化故障切换管理器
|
||||
|
||||
Args:
|
||||
health_checker: 健康检查器
|
||||
priority_order: 提供者优先级顺序
|
||||
"""
|
||||
self.health_checker = health_checker
|
||||
self.priority_order = priority_order or [
|
||||
ProviderType.IMAP_NEW,
|
||||
ProviderType.IMAP_OLD,
|
||||
ProviderType.GRAPH_API,
|
||||
]
|
||||
|
||||
# 当前使用的提供者索引
|
||||
self._current_index = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_current_provider(self) -> Optional[ProviderType]:
|
||||
"""
|
||||
获取当前提供者
|
||||
|
||||
Returns:
|
||||
当前提供者类型,如果没有可用的返回 None
|
||||
"""
|
||||
available = self.health_checker.get_available_providers(self.priority_order)
|
||||
if not available:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
# 尝试使用当前索引
|
||||
if self._current_index < len(available):
|
||||
return available[self._current_index]
|
||||
return available[0]
|
||||
|
||||
def switch_to_next(self) -> Optional[ProviderType]:
|
||||
"""
|
||||
切换到下一个提供者
|
||||
|
||||
Returns:
|
||||
下一个提供者类型,如果没有可用的返回 None
|
||||
"""
|
||||
available = self.health_checker.get_available_providers(self.priority_order)
|
||||
if not available:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
self._current_index = (self._current_index + 1) % len(available)
|
||||
next_provider = available[self._current_index]
|
||||
logger.info(f"切换到提供者: {next_provider.value}")
|
||||
return next_provider
|
||||
|
||||
def on_provider_success(self, provider_type: ProviderType):
|
||||
"""
|
||||
提供者成功时调用
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
"""
|
||||
self.health_checker.record_success(provider_type)
|
||||
|
||||
# 重置索引到成功的提供者
|
||||
with self._lock:
|
||||
available = self.health_checker.get_available_providers(self.priority_order)
|
||||
if provider_type in available:
|
||||
self._current_index = available.index(provider_type)
|
||||
|
||||
def on_provider_failure(self, provider_type: ProviderType, error: str):
|
||||
"""
|
||||
提供者失败时调用
|
||||
|
||||
Args:
|
||||
provider_type: 提供者类型
|
||||
error: 错误信息
|
||||
"""
|
||||
self.health_checker.record_failure(provider_type, error)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取故障切换状态
|
||||
|
||||
Returns:
|
||||
状态字典
|
||||
"""
|
||||
current = self.get_current_provider()
|
||||
return {
|
||||
"current_provider": current.value if current else None,
|
||||
"priority_order": [p.value for p in self.priority_order],
|
||||
"available_providers": [
|
||||
p.value for p in self.health_checker.get_available_providers(self.priority_order)
|
||||
],
|
||||
"health_status": self.health_checker.get_all_health_status(),
|
||||
}
|
||||
29
src/services/outlook/providers/__init__.py
Normal file
29
src/services/outlook/providers/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Outlook 提供者模块
|
||||
"""
|
||||
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
from .imap_old import IMAPOldProvider
|
||||
from .imap_new import IMAPNewProvider
|
||||
from .graph_api import GraphAPIProvider
|
||||
|
||||
__all__ = [
|
||||
'OutlookProvider',
|
||||
'ProviderConfig',
|
||||
'IMAPOldProvider',
|
||||
'IMAPNewProvider',
|
||||
'GraphAPIProvider',
|
||||
]
|
||||
|
||||
|
||||
# 提供者注册表
|
||||
PROVIDER_REGISTRY = {
|
||||
'imap_old': IMAPOldProvider,
|
||||
'imap_new': IMAPNewProvider,
|
||||
'graph_api': GraphAPIProvider,
|
||||
}
|
||||
|
||||
|
||||
def get_provider_class(provider_type: str):
|
||||
"""获取提供者类"""
|
||||
return PROVIDER_REGISTRY.get(provider_type)
|
||||
180
src/services/outlook/providers/base.py
Normal file
180
src/services/outlook/providers/base.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Outlook 提供者抽象基类
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from ..base import ProviderType, EmailMessage, ProviderHealth, ProviderStatus
|
||||
from ..account import OutlookAccount
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""提供者配置"""
|
||||
timeout: int = 30
|
||||
max_retries: int = 3
|
||||
proxy_url: Optional[str] = None
|
||||
|
||||
# 健康检查配置
|
||||
health_failure_threshold: int = 3
|
||||
health_disable_duration: int = 300 # 秒
|
||||
|
||||
|
||||
class OutlookProvider(abc.ABC):
|
||||
"""
|
||||
Outlook 提供者抽象基类
|
||||
定义所有提供者必须实现的接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
"""
|
||||
初始化提供者
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
config: 提供者配置
|
||||
"""
|
||||
self.account = account
|
||||
self.config = config or ProviderConfig()
|
||||
|
||||
# 健康状态
|
||||
self._health = ProviderHealth(provider_type=self.provider_type)
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._last_error: Optional[str] = None
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def provider_type(self) -> ProviderType:
|
||||
"""获取提供者类型"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def health(self) -> ProviderHealth:
|
||||
"""获取健康状态"""
|
||||
return self._health
|
||||
|
||||
@property
|
||||
def is_healthy(self) -> bool:
|
||||
"""检查是否健康"""
|
||||
return (
|
||||
self._health.status == ProviderStatus.HEALTHY
|
||||
and not self._health.is_disabled()
|
||||
)
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
连接到服务
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self):
|
||||
"""断开连接"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
测试连接是否正常
|
||||
|
||||
Returns:
|
||||
连接是否正常
|
||||
"""
|
||||
pass
|
||||
|
||||
def record_success(self):
|
||||
"""记录成功操作"""
|
||||
self._health.record_success()
|
||||
self._last_error = None
|
||||
logger.debug(f"[{self.account.email}] {self.provider_type.value} 操作成功")
|
||||
|
||||
def record_failure(self, error: str):
|
||||
"""记录失败操作"""
|
||||
self._health.record_failure(error)
|
||||
self._last_error = error
|
||||
|
||||
# 检查是否需要禁用
|
||||
if self._health.should_disable(self.config.health_failure_threshold):
|
||||
self._health.disable(self.config.health_disable_duration)
|
||||
logger.warning(
|
||||
f"[{self.account.email}] {self.provider_type.value} 已禁用 "
|
||||
f"{self.config.health_disable_duration} 秒,原因: {error}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[{self.account.email}] {self.provider_type.value} 操作失败 "
|
||||
f"({self._health.failure_count}/{self.config.health_failure_threshold}): {error}"
|
||||
)
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""
|
||||
检查健康状态
|
||||
|
||||
Returns:
|
||||
是否健康可用
|
||||
"""
|
||||
# 检查是否被禁用
|
||||
if self._health.is_disabled():
|
||||
logger.debug(
|
||||
f"[{self.account.email}] {self.provider_type.value} 已被禁用,"
|
||||
f"将在 {self._health.disabled_until} 后恢复"
|
||||
)
|
||||
return False
|
||||
|
||||
return self._health.status in (ProviderStatus.HEALTHY, ProviderStatus.DEGRADED)
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""上下文管理器出口"""
|
||||
self.disconnect()
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return f"{self.__class__.__name__}({self.account.email})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
250
src/services/outlook/providers/graph_api.py
Normal file
250
src/services/outlook/providers/graph_api.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
Graph API 提供者
|
||||
使用 Microsoft Graph REST API
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from curl_cffi import requests as _requests
|
||||
|
||||
from ..base import ProviderType, EmailMessage
|
||||
from ..account import OutlookAccount
|
||||
from ..token_manager import TokenManager
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphAPIProvider(OutlookProvider):
|
||||
"""
|
||||
Graph API 提供者
|
||||
使用 Microsoft Graph REST API 获取邮件
|
||||
需要 graph.microsoft.com/.default scope
|
||||
"""
|
||||
|
||||
# Graph API 端点
|
||||
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
|
||||
MESSAGES_ENDPOINT = "/me/mailFolders/inbox/messages"
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ProviderType:
|
||||
return ProviderType.GRAPH_API
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
super().__init__(account, config)
|
||||
|
||||
# Token 管理器
|
||||
self._token_manager: Optional[TokenManager] = None
|
||||
|
||||
# 注意:Graph API 必须使用 OAuth2
|
||||
if not account.has_oauth():
|
||||
logger.warning(
|
||||
f"[{self.account.email}] Graph API 提供者需要 OAuth2 配置 "
|
||||
f"(client_id + refresh_token)"
|
||||
)
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
验证连接(获取 Token)
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if not self.account.has_oauth():
|
||||
error = "Graph API 需要 OAuth2 配置"
|
||||
self.record_failure(error)
|
||||
logger.error(f"[{self.account.email}] {error}")
|
||||
return False
|
||||
|
||||
if not self._token_manager:
|
||||
self._token_manager = TokenManager(
|
||||
self.account,
|
||||
ProviderType.GRAPH_API,
|
||||
self.config.proxy_url,
|
||||
self.config.timeout,
|
||||
)
|
||||
|
||||
# 尝试获取 Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if token:
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] Graph API 连接成功")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开连接(清除状态)"""
|
||||
self._connected = False
|
||||
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
if not self._connected:
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 获取 Access Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if not token:
|
||||
self.record_failure("无法获取 Access Token")
|
||||
return []
|
||||
|
||||
# 构建 API 请求
|
||||
url = f"{self.GRAPH_API_BASE}{self.MESSAGES_ENDPOINT}"
|
||||
|
||||
params = {
|
||||
"$top": count,
|
||||
"$select": "id,subject,from,toRecipients,receivedDateTime,isRead,hasAttachments,bodyPreview,body",
|
||||
"$orderby": "receivedDateTime desc",
|
||||
}
|
||||
|
||||
# 只获取未读邮件
|
||||
if only_unseen:
|
||||
params["$filter"] = "isRead eq false"
|
||||
|
||||
# 构建代理配置
|
||||
proxies = None
|
||||
if self.config.proxy_url:
|
||||
proxies = {"http": self.config.proxy_url, "https": self.config.proxy_url}
|
||||
|
||||
# 发送请求(curl_cffi 自动对 params 进行 URL 编码)
|
||||
resp = _requests.get(
|
||||
url,
|
||||
params=params,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/json",
|
||||
"Prefer": "outlook.body-content-type='text'",
|
||||
},
|
||||
proxies=proxies,
|
||||
timeout=self.config.timeout,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
# Token 无 Graph 权限(client_id 未授权),清除缓存但不记录健康失败
|
||||
# 避免因权限不足导致健康检查器禁用该提供者,影响其他账户
|
||||
if self._token_manager:
|
||||
self._token_manager.clear_cache()
|
||||
self._connected = False
|
||||
logger.warning(f"[{self.account.email}] Graph API 返回 401,client_id 可能无 Graph 权限,跳过")
|
||||
return []
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_body = resp.text[:200]
|
||||
self.record_failure(f"HTTP {resp.status_code}: {error_body}")
|
||||
logger.error(f"[{self.account.email}] Graph API 请求失败: HTTP {resp.status_code}")
|
||||
return []
|
||||
|
||||
data = resp.json()
|
||||
|
||||
# 解析邮件
|
||||
messages = data.get("value", [])
|
||||
emails = []
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
email_msg = self._parse_graph_message(msg)
|
||||
if email_msg:
|
||||
emails.append(email_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 解析 Graph API 邮件失败: {e}")
|
||||
|
||||
self.record_success()
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] Graph API 获取邮件失败: {e}")
|
||||
return []
|
||||
|
||||
def _parse_graph_message(self, msg: dict) -> Optional[EmailMessage]:
|
||||
"""
|
||||
解析 Graph API 消息
|
||||
|
||||
Args:
|
||||
msg: Graph API 消息对象
|
||||
|
||||
Returns:
|
||||
EmailMessage 对象
|
||||
"""
|
||||
# 解析发件人
|
||||
from_info = msg.get("from", {})
|
||||
sender_info = from_info.get("emailAddress", {})
|
||||
sender = sender_info.get("address", "")
|
||||
|
||||
# 解析收件人
|
||||
recipients = []
|
||||
for recipient in msg.get("toRecipients", []):
|
||||
addr_info = recipient.get("emailAddress", {})
|
||||
addr = addr_info.get("address", "")
|
||||
if addr:
|
||||
recipients.append(addr)
|
||||
|
||||
# 解析日期
|
||||
received_at = None
|
||||
received_timestamp = 0
|
||||
try:
|
||||
date_str = msg.get("receivedDateTime", "")
|
||||
if date_str:
|
||||
# ISO 8601 格式
|
||||
received_at = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
received_timestamp = int(received_at.timestamp())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 获取正文
|
||||
body_info = msg.get("body", {})
|
||||
body = body_info.get("content", "")
|
||||
body_preview = msg.get("bodyPreview", "")
|
||||
|
||||
return EmailMessage(
|
||||
id=msg.get("id", ""),
|
||||
subject=msg.get("subject", ""),
|
||||
sender=sender,
|
||||
recipients=recipients,
|
||||
body=body,
|
||||
body_preview=body_preview,
|
||||
received_at=received_at,
|
||||
received_timestamp=received_timestamp,
|
||||
is_read=msg.get("isRead", False),
|
||||
has_attachments=msg.get("hasAttachments", False),
|
||||
)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
测试 Graph API 连接
|
||||
|
||||
Returns:
|
||||
连接是否正常
|
||||
"""
|
||||
try:
|
||||
# 尝试获取一封邮件来测试连接
|
||||
emails = self.get_recent_emails(count=1, only_unseen=False)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] Graph API 连接测试失败: {e}")
|
||||
return False
|
||||
231
src/services/outlook/providers/imap_new.py
Normal file
231
src/services/outlook/providers/imap_new.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
新版 IMAP 提供者
|
||||
使用 outlook.live.com 服务器和 login.microsoftonline.com/consumers Token 端点
|
||||
"""
|
||||
|
||||
import email
|
||||
import imaplib
|
||||
import logging
|
||||
from email.header import decode_header
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from ..base import ProviderType, EmailMessage
|
||||
from ..account import OutlookAccount
|
||||
from ..token_manager import TokenManager
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
from .imap_old import IMAPOldProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMAPNewProvider(OutlookProvider):
|
||||
"""
|
||||
新版 IMAP 提供者
|
||||
使用 outlook.live.com:993 和 login.microsoftonline.com/consumers Token 端点
|
||||
需要 IMAP.AccessAsUser.All scope
|
||||
"""
|
||||
|
||||
# IMAP 服务器配置
|
||||
IMAP_HOST = "outlook.live.com"
|
||||
IMAP_PORT = 993
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ProviderType:
|
||||
return ProviderType.IMAP_NEW
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
super().__init__(account, config)
|
||||
|
||||
# IMAP 连接
|
||||
self._conn: Optional[imaplib.IMAP4_SSL] = None
|
||||
|
||||
# Token 管理器
|
||||
self._token_manager: Optional[TokenManager] = None
|
||||
|
||||
# 注意:新版 IMAP 必须使用 OAuth2
|
||||
if not account.has_oauth():
|
||||
logger.warning(
|
||||
f"[{self.account.email}] 新版 IMAP 提供者需要 OAuth2 配置 "
|
||||
f"(client_id + refresh_token)"
|
||||
)
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
连接到 IMAP 服务器
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if self._connected and self._conn:
|
||||
try:
|
||||
self._conn.noop()
|
||||
return True
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
|
||||
# 新版 IMAP 必须使用 OAuth2,无 OAuth 时静默跳过,不记录健康失败
|
||||
if not self.account.has_oauth():
|
||||
logger.debug(f"[{self.account.email}] 跳过 IMAP_NEW(无 OAuth)")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
|
||||
|
||||
# 创建连接
|
||||
self._conn = imaplib.IMAP4_SSL(
|
||||
self.IMAP_HOST,
|
||||
self.IMAP_PORT,
|
||||
timeout=self.config.timeout,
|
||||
)
|
||||
|
||||
# XOAUTH2 认证
|
||||
if self._authenticate_xoauth2():
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] 新版 IMAP 连接成功 (XOAUTH2)")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.disconnect()
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] 新版 IMAP 连接失败: {e}")
|
||||
return False
|
||||
|
||||
def _authenticate_xoauth2(self) -> bool:
|
||||
"""
|
||||
使用 XOAUTH2 认证
|
||||
|
||||
Returns:
|
||||
是否认证成功
|
||||
"""
|
||||
if not self._token_manager:
|
||||
self._token_manager = TokenManager(
|
||||
self.account,
|
||||
ProviderType.IMAP_NEW,
|
||||
self.config.proxy_url,
|
||||
self.config.timeout,
|
||||
)
|
||||
|
||||
# 获取 Access Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if not token:
|
||||
logger.error(f"[{self.account.email}] 获取 IMAP Token 失败")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 构建 XOAUTH2 认证字符串
|
||||
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
|
||||
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
|
||||
# 清除缓存的 Token
|
||||
self._token_manager.clear_cache()
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开 IMAP 连接"""
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn = None
|
||||
|
||||
self._connected = False
|
||||
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
if not self._connected:
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 选择收件箱
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
|
||||
# 搜索邮件
|
||||
flag = "UNSEEN" if only_unseen else "ALL"
|
||||
status, data = self._conn.search(None, flag)
|
||||
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return []
|
||||
|
||||
# 获取最新的邮件 ID
|
||||
ids = data[0].split()
|
||||
recent_ids = ids[-count:][::-1]
|
||||
|
||||
emails = []
|
||||
for msg_id in recent_ids:
|
||||
try:
|
||||
email_msg = self._fetch_email(msg_id)
|
||||
if email_msg:
|
||||
emails.append(email_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
|
||||
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
|
||||
return []
|
||||
|
||||
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
|
||||
"""获取并解析单封邮件"""
|
||||
status, data = self._conn.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return None
|
||||
|
||||
raw = b""
|
||||
for part in data:
|
||||
if isinstance(part, tuple) and len(part) > 1:
|
||||
raw = part[1]
|
||||
break
|
||||
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
return self._parse_email(raw)
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(raw: bytes) -> EmailMessage:
|
||||
"""解析原始邮件"""
|
||||
# 使用旧版提供者的解析方法
|
||||
return IMAPOldProvider._parse_email(raw)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""测试 IMAP 连接"""
|
||||
try:
|
||||
with self:
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
self._conn.search(None, "ALL")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 新版 IMAP 连接测试失败: {e}")
|
||||
return False
|
||||
345
src/services/outlook/providers/imap_old.py
Normal file
345
src/services/outlook/providers/imap_old.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
旧版 IMAP 提供者
|
||||
使用 outlook.office365.com 服务器和 login.live.com Token 端点
|
||||
"""
|
||||
|
||||
import email
|
||||
import imaplib
|
||||
import logging
|
||||
from email.header import decode_header
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from ..base import ProviderType, EmailMessage
|
||||
from ..account import OutlookAccount
|
||||
from ..token_manager import TokenManager
|
||||
from .base import OutlookProvider, ProviderConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IMAPOldProvider(OutlookProvider):
|
||||
"""
|
||||
旧版 IMAP 提供者
|
||||
使用 outlook.office365.com:993 和 login.live.com Token 端点
|
||||
"""
|
||||
|
||||
# IMAP 服务器配置
|
||||
IMAP_HOST = "outlook.office365.com"
|
||||
IMAP_PORT = 993
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ProviderType:
|
||||
return ProviderType.IMAP_OLD
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
config: Optional[ProviderConfig] = None,
|
||||
):
|
||||
super().__init__(account, config)
|
||||
|
||||
# IMAP 连接
|
||||
self._conn: Optional[imaplib.IMAP4_SSL] = None
|
||||
|
||||
# Token 管理器
|
||||
self._token_manager: Optional[TokenManager] = None
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
连接到 IMAP 服务器
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
if self._connected and self._conn:
|
||||
# 检查现有连接
|
||||
try:
|
||||
self._conn.noop()
|
||||
return True
|
||||
except Exception:
|
||||
self.disconnect()
|
||||
|
||||
try:
|
||||
logger.debug(f"[{self.account.email}] 正在连接 IMAP ({self.IMAP_HOST})...")
|
||||
|
||||
# 创建连接
|
||||
self._conn = imaplib.IMAP4_SSL(
|
||||
self.IMAP_HOST,
|
||||
self.IMAP_PORT,
|
||||
timeout=self.config.timeout,
|
||||
)
|
||||
|
||||
# 尝试 XOAUTH2 认证
|
||||
if self.account.has_oauth():
|
||||
if self._authenticate_xoauth2():
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] IMAP 连接成功 (XOAUTH2)")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"[{self.account.email}] XOAUTH2 认证失败,尝试密码认证")
|
||||
|
||||
# 密码认证
|
||||
if self.account.password:
|
||||
self._conn.login(self.account.email, self.account.password)
|
||||
self._connected = True
|
||||
self.record_success()
|
||||
logger.info(f"[{self.account.email}] IMAP 连接成功 (密码认证)")
|
||||
return True
|
||||
|
||||
raise ValueError("没有可用的认证方式")
|
||||
|
||||
except Exception as e:
|
||||
self.disconnect()
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] IMAP 连接失败: {e}")
|
||||
return False
|
||||
|
||||
def _authenticate_xoauth2(self) -> bool:
|
||||
"""
|
||||
使用 XOAUTH2 认证
|
||||
|
||||
Returns:
|
||||
是否认证成功
|
||||
"""
|
||||
if not self._token_manager:
|
||||
self._token_manager = TokenManager(
|
||||
self.account,
|
||||
ProviderType.IMAP_OLD,
|
||||
self.config.proxy_url,
|
||||
self.config.timeout,
|
||||
)
|
||||
|
||||
# 获取 Access Token
|
||||
token = self._token_manager.get_access_token()
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
# 构建 XOAUTH2 认证字符串
|
||||
auth_string = f"user={self.account.email}\x01auth=Bearer {token}\x01\x01"
|
||||
self._conn.authenticate("XOAUTH2", lambda _: auth_string.encode("utf-8"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"[{self.account.email}] XOAUTH2 认证异常: {e}")
|
||||
# 清除缓存的 Token
|
||||
self._token_manager.clear_cache()
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开 IMAP 连接"""
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn = None
|
||||
|
||||
self._connected = False
|
||||
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
if not self._connected:
|
||||
if not self.connect():
|
||||
return []
|
||||
|
||||
try:
|
||||
# 选择收件箱
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
|
||||
# 搜索邮件
|
||||
flag = "UNSEEN" if only_unseen else "ALL"
|
||||
status, data = self._conn.search(None, flag)
|
||||
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return []
|
||||
|
||||
# 获取最新的邮件 ID
|
||||
ids = data[0].split()
|
||||
recent_ids = ids[-count:][::-1] # 倒序,最新的在前
|
||||
|
||||
emails = []
|
||||
for msg_id in recent_ids:
|
||||
try:
|
||||
email_msg = self._fetch_email(msg_id)
|
||||
if email_msg:
|
||||
emails.append(email_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] 解析邮件失败 (ID: {msg_id}): {e}")
|
||||
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
self.record_failure(str(e))
|
||||
logger.error(f"[{self.account.email}] 获取邮件失败: {e}")
|
||||
return []
|
||||
|
||||
def _fetch_email(self, msg_id: bytes) -> Optional[EmailMessage]:
|
||||
"""
|
||||
获取并解析单封邮件
|
||||
|
||||
Args:
|
||||
msg_id: 邮件 ID
|
||||
|
||||
Returns:
|
||||
EmailMessage 对象,失败返回 None
|
||||
"""
|
||||
status, data = self._conn.fetch(msg_id, "(RFC822)")
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return None
|
||||
|
||||
# 获取原始邮件内容
|
||||
raw = b""
|
||||
for part in data:
|
||||
if isinstance(part, tuple) and len(part) > 1:
|
||||
raw = part[1]
|
||||
break
|
||||
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
return self._parse_email(raw)
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(raw: bytes) -> EmailMessage:
|
||||
"""
|
||||
解析原始邮件
|
||||
|
||||
Args:
|
||||
raw: 原始邮件数据
|
||||
|
||||
Returns:
|
||||
EmailMessage 对象
|
||||
"""
|
||||
# 移除 BOM
|
||||
if raw.startswith(b"\xef\xbb\xbf"):
|
||||
raw = raw[3:]
|
||||
|
||||
msg = email.message_from_bytes(raw)
|
||||
|
||||
# 解析邮件头
|
||||
subject = IMAPOldProvider._decode_header(msg.get("Subject", ""))
|
||||
sender = IMAPOldProvider._decode_header(msg.get("From", ""))
|
||||
to = IMAPOldProvider._decode_header(msg.get("To", ""))
|
||||
delivered_to = IMAPOldProvider._decode_header(msg.get("Delivered-To", ""))
|
||||
x_original_to = IMAPOldProvider._decode_header(msg.get("X-Original-To", ""))
|
||||
date_str = IMAPOldProvider._decode_header(msg.get("Date", ""))
|
||||
|
||||
# 提取正文
|
||||
body = IMAPOldProvider._extract_body(msg)
|
||||
|
||||
# 解析日期
|
||||
received_timestamp = 0
|
||||
received_at = None
|
||||
try:
|
||||
if date_str:
|
||||
received_at = parsedate_to_datetime(date_str)
|
||||
received_timestamp = int(received_at.timestamp())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 构建收件人列表
|
||||
recipients = [r for r in [to, delivered_to, x_original_to] if r]
|
||||
|
||||
return EmailMessage(
|
||||
id=msg.get("Message-ID", ""),
|
||||
subject=subject,
|
||||
sender=sender,
|
||||
recipients=recipients,
|
||||
body=body,
|
||||
received_at=received_at,
|
||||
received_timestamp=received_timestamp,
|
||||
is_read=False, # 搜索的是未读邮件
|
||||
raw_data=raw[:500] if len(raw) > 500 else raw,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _decode_header(header: str) -> str:
|
||||
"""解码邮件头"""
|
||||
if not header:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
for chunk, encoding in decode_header(header):
|
||||
if isinstance(chunk, bytes):
|
||||
try:
|
||||
decoded = chunk.decode(encoding or "utf-8", errors="replace")
|
||||
parts.append(decoded)
|
||||
except Exception:
|
||||
parts.append(chunk.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
parts.append(str(chunk))
|
||||
|
||||
return "".join(parts).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_body(msg) -> str:
|
||||
"""提取邮件正文"""
|
||||
import html as html_module
|
||||
import re
|
||||
|
||||
texts = []
|
||||
parts = msg.walk() if msg.is_multipart() else [msg]
|
||||
|
||||
for part in parts:
|
||||
content_type = part.get_content_type()
|
||||
if content_type not in ("text/plain", "text/html"):
|
||||
continue
|
||||
|
||||
payload = part.get_payload(decode=True)
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
try:
|
||||
text = payload.decode(charset, errors="replace")
|
||||
except LookupError:
|
||||
text = payload.decode("utf-8", errors="replace")
|
||||
|
||||
# 如果是 HTML,移除标签
|
||||
if "<html" in text.lower():
|
||||
text = re.sub(r"<[^>]+>", " ", text)
|
||||
|
||||
texts.append(text)
|
||||
|
||||
# 合并并清理文本
|
||||
combined = " ".join(texts)
|
||||
combined = html_module.unescape(combined)
|
||||
combined = re.sub(r"\s+", " ", combined).strip()
|
||||
|
||||
return combined
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""
|
||||
测试 IMAP 连接
|
||||
|
||||
Returns:
|
||||
连接是否正常
|
||||
"""
|
||||
try:
|
||||
with self:
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
self._conn.search(None, "ALL")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.account.email}] IMAP 连接测试失败: {e}")
|
||||
return False
|
||||
487
src/services/outlook/service.py
Normal file
487
src/services/outlook/service.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
Outlook 邮箱服务主类
|
||||
支持多种 IMAP/API 连接方式,自动故障切换
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from ..base import BaseEmailService, EmailServiceError, EmailServiceStatus, EmailServiceType
|
||||
from ...config.constants import EmailServiceType as ServiceType
|
||||
from ...config.settings import get_settings
|
||||
from .account import OutlookAccount
|
||||
from .base import ProviderType, EmailMessage
|
||||
from .email_parser import EmailParser, get_email_parser
|
||||
from .health_checker import HealthChecker, FailoverManager
|
||||
from .providers.base import OutlookProvider, ProviderConfig
|
||||
from .providers.imap_old import IMAPOldProvider
|
||||
from .providers.imap_new import IMAPNewProvider
|
||||
from .providers.graph_api import GraphAPIProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 默认提供者优先级
|
||||
# IMAP_OLD 最兼容(只需 login.live.com token),IMAP_NEW 次之,Graph API 最后
|
||||
# 原因:部分 client_id 没有 Graph API 权限,但有 IMAP 权限
|
||||
DEFAULT_PROVIDER_PRIORITY = [
|
||||
ProviderType.IMAP_OLD,
|
||||
ProviderType.IMAP_NEW,
|
||||
ProviderType.GRAPH_API,
|
||||
]
|
||||
|
||||
|
||||
def get_email_code_settings() -> dict:
|
||||
"""获取验证码等待配置"""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"timeout": settings.email_code_timeout,
|
||||
"poll_interval": settings.email_code_poll_interval,
|
||||
}
|
||||
|
||||
|
||||
class OutlookService(BaseEmailService):
|
||||
"""
|
||||
Outlook 邮箱服务
|
||||
支持多种 IMAP/API 连接方式,自动故障切换
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
"""
|
||||
初始化 Outlook 服务
|
||||
|
||||
Args:
|
||||
config: 配置字典,支持以下键:
|
||||
- accounts: Outlook 账户列表
|
||||
- provider_priority: 提供者优先级列表
|
||||
- health_failure_threshold: 连续失败次数阈值
|
||||
- health_disable_duration: 禁用时长(秒)
|
||||
- timeout: 请求超时时间
|
||||
- proxy_url: 代理 URL
|
||||
name: 服务名称
|
||||
"""
|
||||
super().__init__(ServiceType.OUTLOOK, name)
|
||||
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"accounts": [],
|
||||
"provider_priority": [p.value for p in DEFAULT_PROVIDER_PRIORITY],
|
||||
"health_failure_threshold": 5,
|
||||
"health_disable_duration": 60,
|
||||
"timeout": 30,
|
||||
"proxy_url": None,
|
||||
}
|
||||
|
||||
self.config = {**default_config, **(config or {})}
|
||||
|
||||
# 解析提供者优先级
|
||||
self.provider_priority = [
|
||||
ProviderType(p) for p in self.config.get("provider_priority", [])
|
||||
]
|
||||
if not self.provider_priority:
|
||||
self.provider_priority = DEFAULT_PROVIDER_PRIORITY
|
||||
|
||||
# 提供者配置
|
||||
self.provider_config = ProviderConfig(
|
||||
timeout=self.config.get("timeout", 30),
|
||||
proxy_url=self.config.get("proxy_url"),
|
||||
health_failure_threshold=self.config.get("health_failure_threshold", 3),
|
||||
health_disable_duration=self.config.get("health_disable_duration", 300),
|
||||
)
|
||||
|
||||
# 获取默认 client_id(供无 client_id 的账户使用)
|
||||
try:
|
||||
_default_client_id = get_settings().outlook_default_client_id
|
||||
except Exception:
|
||||
_default_client_id = "24d9a0ed-8787-4584-883c-2fd79308940a"
|
||||
|
||||
# 解析账户
|
||||
self.accounts: List[OutlookAccount] = []
|
||||
self._current_account_index = 0
|
||||
self._account_lock = threading.Lock()
|
||||
|
||||
# 支持两种配置格式
|
||||
if "email" in self.config and "password" in self.config:
|
||||
account = OutlookAccount.from_config(self.config)
|
||||
if not account.client_id and _default_client_id:
|
||||
account.client_id = _default_client_id
|
||||
if account.validate():
|
||||
self.accounts.append(account)
|
||||
else:
|
||||
for account_config in self.config.get("accounts", []):
|
||||
account = OutlookAccount.from_config(account_config)
|
||||
if not account.client_id and _default_client_id:
|
||||
account.client_id = _default_client_id
|
||||
if account.validate():
|
||||
self.accounts.append(account)
|
||||
|
||||
if not self.accounts:
|
||||
logger.warning("未配置有效的 Outlook 账户")
|
||||
|
||||
# 健康检查器和故障切换管理器
|
||||
self.health_checker = HealthChecker(
|
||||
failure_threshold=self.provider_config.health_failure_threshold,
|
||||
disable_duration=self.provider_config.health_disable_duration,
|
||||
)
|
||||
self.failover_manager = FailoverManager(
|
||||
health_checker=self.health_checker,
|
||||
priority_order=self.provider_priority,
|
||||
)
|
||||
|
||||
# 邮件解析器
|
||||
self.email_parser = get_email_parser()
|
||||
|
||||
# 提供者实例缓存: (email, provider_type) -> OutlookProvider
|
||||
self._providers: Dict[tuple, OutlookProvider] = {}
|
||||
self._provider_lock = threading.Lock()
|
||||
|
||||
# IMAP 连接限制(防止限流)
|
||||
self._imap_semaphore = threading.Semaphore(5)
|
||||
|
||||
# 验证码去重机制
|
||||
self._used_codes: Dict[str, set] = {}
|
||||
|
||||
def _get_provider(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
) -> OutlookProvider:
|
||||
"""
|
||||
获取或创建提供者实例
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
|
||||
Returns:
|
||||
提供者实例
|
||||
"""
|
||||
cache_key = (account.email.lower(), provider_type)
|
||||
|
||||
with self._provider_lock:
|
||||
if cache_key not in self._providers:
|
||||
provider = self._create_provider(account, provider_type)
|
||||
self._providers[cache_key] = provider
|
||||
|
||||
return self._providers[cache_key]
|
||||
|
||||
def _create_provider(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
) -> OutlookProvider:
|
||||
"""
|
||||
创建提供者实例
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
|
||||
Returns:
|
||||
提供者实例
|
||||
"""
|
||||
if provider_type == ProviderType.IMAP_OLD:
|
||||
return IMAPOldProvider(account, self.provider_config)
|
||||
elif provider_type == ProviderType.IMAP_NEW:
|
||||
return IMAPNewProvider(account, self.provider_config)
|
||||
elif provider_type == ProviderType.GRAPH_API:
|
||||
return GraphAPIProvider(account, self.provider_config)
|
||||
else:
|
||||
raise ValueError(f"未知的提供者类型: {provider_type}")
|
||||
|
||||
def _get_provider_priority_for_account(self, account: OutlookAccount) -> List[ProviderType]:
|
||||
"""根据账户是否有 OAuth,返回适合的提供者优先级列表"""
|
||||
if account.has_oauth():
|
||||
return self.provider_priority
|
||||
else:
|
||||
# 无 OAuth,直接走旧版 IMAP(密码认证),跳过需要 OAuth 的提供者
|
||||
return [ProviderType.IMAP_OLD]
|
||||
|
||||
def _try_providers_for_emails(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
) -> List[EmailMessage]:
|
||||
"""
|
||||
尝试多个提供者获取邮件
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
count: 获取数量
|
||||
only_unseen: 是否只获取未读
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 根据账户类型选择合适的提供者优先级
|
||||
priority = self._get_provider_priority_for_account(account)
|
||||
|
||||
# 按优先级尝试各提供者
|
||||
for provider_type in priority:
|
||||
# 检查提供者是否可用
|
||||
if not self.health_checker.is_available(provider_type):
|
||||
logger.debug(
|
||||
f"[{account.email}] {provider_type.value} 不可用,跳过"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
provider = self._get_provider(account, provider_type)
|
||||
|
||||
with self._imap_semaphore:
|
||||
with provider:
|
||||
emails = provider.get_recent_emails(count, only_unseen)
|
||||
|
||||
if emails:
|
||||
# 成功获取邮件
|
||||
self.health_checker.record_success(provider_type)
|
||||
logger.debug(
|
||||
f"[{account.email}] {provider_type.value} 获取到 {len(emails)} 封邮件"
|
||||
)
|
||||
return emails
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
errors.append(f"{provider_type.value}: {error_msg}")
|
||||
self.health_checker.record_failure(provider_type, error_msg)
|
||||
logger.warning(
|
||||
f"[{account.email}] {provider_type.value} 获取邮件失败: {e}"
|
||||
)
|
||||
|
||||
logger.error(
|
||||
f"[{account.email}] 所有提供者都失败: {'; '.join(errors)}"
|
||||
)
|
||||
return []
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
选择可用的 Outlook 账户
|
||||
|
||||
Args:
|
||||
config: 配置参数(未使用)
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典
|
||||
"""
|
||||
if not self.accounts:
|
||||
self.update_status(False, EmailServiceError("没有可用的 Outlook 账户"))
|
||||
raise EmailServiceError("没有可用的 Outlook 账户")
|
||||
|
||||
# 轮询选择账户
|
||||
with self._account_lock:
|
||||
account = self.accounts[self._current_account_index]
|
||||
self._current_account_index = (self._current_account_index + 1) % len(self.accounts)
|
||||
|
||||
email_info = {
|
||||
"email": account.email,
|
||||
"service_id": account.email,
|
||||
"account": {
|
||||
"email": account.email,
|
||||
"has_oauth": account.has_oauth()
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"选择 Outlook 账户: {account.email}")
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = None,
|
||||
pattern: str = None,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从 Outlook 邮箱获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 未使用
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则表达式(未使用)
|
||||
otp_sent_at: OTP 发送时间戳
|
||||
|
||||
Returns:
|
||||
验证码字符串
|
||||
"""
|
||||
# 查找对应的账户
|
||||
account = None
|
||||
for acc in self.accounts:
|
||||
if acc.email.lower() == email.lower():
|
||||
account = acc
|
||||
break
|
||||
|
||||
if not account:
|
||||
self.update_status(False, EmailServiceError(f"未找到邮箱对应的账户: {email}"))
|
||||
return None
|
||||
|
||||
# 获取验证码等待配置
|
||||
code_settings = get_email_code_settings()
|
||||
actual_timeout = timeout or code_settings["timeout"]
|
||||
poll_interval = code_settings["poll_interval"]
|
||||
|
||||
logger.info(
|
||||
f"[{email}] 开始获取验证码,超时 {actual_timeout}s,"
|
||||
f"提供者优先级: {[p.value for p in self.provider_priority]}"
|
||||
)
|
||||
|
||||
# 初始化验证码去重集合
|
||||
if email not in self._used_codes:
|
||||
self._used_codes[email] = set()
|
||||
used_codes = self._used_codes[email]
|
||||
|
||||
# 计算最小时间戳(留出 60 秒时钟偏差)
|
||||
min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0
|
||||
|
||||
start_time = time.time()
|
||||
poll_count = 0
|
||||
|
||||
while time.time() - start_time < actual_timeout:
|
||||
poll_count += 1
|
||||
|
||||
# 渐进式邮件检查:前 3 次只检查未读
|
||||
only_unseen = poll_count <= 3
|
||||
|
||||
try:
|
||||
# 尝试多个提供者获取邮件
|
||||
emails = self._try_providers_for_emails(
|
||||
account,
|
||||
count=15,
|
||||
only_unseen=only_unseen,
|
||||
)
|
||||
|
||||
if emails:
|
||||
logger.debug(
|
||||
f"[{email}] 第 {poll_count} 次轮询获取到 {len(emails)} 封邮件"
|
||||
)
|
||||
|
||||
# 从邮件中查找验证码
|
||||
code = self.email_parser.find_verification_code_in_emails(
|
||||
emails,
|
||||
target_email=email,
|
||||
min_timestamp=min_timestamp,
|
||||
used_codes=used_codes,
|
||||
)
|
||||
|
||||
if code:
|
||||
used_codes.add(code)
|
||||
elapsed = int(time.time() - start_time)
|
||||
logger.info(
|
||||
f"[{email}] 找到验证码: {code},"
|
||||
f"总耗时 {elapsed}s,轮询 {poll_count} 次"
|
||||
)
|
||||
self.update_status(True)
|
||||
return code
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[{email}] 检查出错: {e}")
|
||||
|
||||
# 等待下次轮询
|
||||
time.sleep(poll_interval)
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次")
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""列出所有可用的 Outlook 账户"""
|
||||
return [
|
||||
{
|
||||
"email": account.email,
|
||||
"id": account.email,
|
||||
"has_oauth": account.has_oauth(),
|
||||
"type": "outlook"
|
||||
}
|
||||
for account in self.accounts
|
||||
]
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""删除邮箱(Outlook 不支持删除账户)"""
|
||||
logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
|
||||
return False
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""检查 Outlook 服务是否可用"""
|
||||
if not self.accounts:
|
||||
self.update_status(False, EmailServiceError("没有配置的账户"))
|
||||
return False
|
||||
|
||||
# 测试第一个账户的连接
|
||||
test_account = self.accounts[0]
|
||||
|
||||
# 尝试任一提供者连接
|
||||
for provider_type in self.provider_priority:
|
||||
try:
|
||||
provider = self._get_provider(test_account, provider_type)
|
||||
if provider.test_connection():
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Outlook 健康检查失败 ({test_account.email}, {provider_type.value}): {e}"
|
||||
)
|
||||
|
||||
self.update_status(False, EmailServiceError("健康检查失败"))
|
||||
return False
|
||||
|
||||
def get_provider_status(self) -> Dict[str, Any]:
|
||||
"""获取提供者状态"""
|
||||
return self.failover_manager.get_status()
|
||||
|
||||
def get_account_stats(self) -> Dict[str, Any]:
|
||||
"""获取账户统计信息"""
|
||||
total = len(self.accounts)
|
||||
oauth_count = sum(1 for acc in self.accounts if acc.has_oauth())
|
||||
|
||||
return {
|
||||
"total_accounts": total,
|
||||
"oauth_accounts": oauth_count,
|
||||
"password_accounts": total - oauth_count,
|
||||
"accounts": [acc.to_dict() for acc in self.accounts],
|
||||
"provider_status": self.get_provider_status(),
|
||||
}
|
||||
|
||||
def add_account(self, account_config: Dict[str, Any]) -> bool:
|
||||
"""添加新的 Outlook 账户"""
|
||||
try:
|
||||
account = OutlookAccount.from_config(account_config)
|
||||
if not account.validate():
|
||||
return False
|
||||
|
||||
self.accounts.append(account)
|
||||
logger.info(f"添加 Outlook 账户: {account.email}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加 Outlook 账户失败: {e}")
|
||||
return False
|
||||
|
||||
def remove_account(self, email: str) -> bool:
|
||||
"""移除 Outlook 账户"""
|
||||
for i, acc in enumerate(self.accounts):
|
||||
if acc.email.lower() == email.lower():
|
||||
self.accounts.pop(i)
|
||||
logger.info(f"移除 Outlook 账户: {email}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset_provider_health(self):
|
||||
"""重置所有提供者的健康状态"""
|
||||
self.health_checker.reset_all()
|
||||
logger.info("已重置所有提供者的健康状态")
|
||||
|
||||
def force_provider(self, provider_type: ProviderType):
|
||||
"""强制使用指定的提供者"""
|
||||
self.health_checker.force_enable(provider_type)
|
||||
# 禁用其他提供者
|
||||
for pt in ProviderType:
|
||||
if pt != provider_type:
|
||||
self.health_checker.force_disable(pt, 60)
|
||||
logger.info(f"已强制使用提供者: {provider_type.value}")
|
||||
239
src/services/outlook/token_manager.py
Normal file
239
src/services/outlook/token_manager.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Token 管理器
|
||||
支持多个 Microsoft Token 端点,自动选择合适的端点
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
from curl_cffi import requests as _requests
|
||||
|
||||
from .base import ProviderType, TokenEndpoint, TokenInfo
|
||||
from .account import OutlookAccount
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 各提供者的 Scope 配置
|
||||
PROVIDER_SCOPES = {
|
||||
ProviderType.IMAP_OLD: "", # 旧版 IMAP 不需要特定 scope
|
||||
ProviderType.IMAP_NEW: "https://outlook.office.com/IMAP.AccessAsUser.All offline_access",
|
||||
ProviderType.GRAPH_API: "https://graph.microsoft.com/.default",
|
||||
}
|
||||
|
||||
# 各提供者的 Token 端点
|
||||
PROVIDER_TOKEN_URLS = {
|
||||
ProviderType.IMAP_OLD: TokenEndpoint.LIVE.value,
|
||||
ProviderType.IMAP_NEW: TokenEndpoint.CONSUMERS.value,
|
||||
ProviderType.GRAPH_API: TokenEndpoint.COMMON.value,
|
||||
}
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""
|
||||
Token 管理器
|
||||
支持多端点 Token 获取和缓存
|
||||
"""
|
||||
|
||||
# Token 缓存: key = (email, provider_type) -> TokenInfo
|
||||
_token_cache: Dict[tuple, TokenInfo] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
# 默认超时时间
|
||||
DEFAULT_TIMEOUT = 30
|
||||
# Token 刷新提前时间(秒)
|
||||
REFRESH_BUFFER = 120
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
proxy_url: Optional[str] = None,
|
||||
timeout: int = DEFAULT_TIMEOUT,
|
||||
):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
proxy_url: 代理 URL(可选)
|
||||
timeout: 请求超时时间
|
||||
"""
|
||||
self.account = account
|
||||
self.provider_type = provider_type
|
||||
self.proxy_url = proxy_url
|
||||
self.timeout = timeout
|
||||
|
||||
# 获取端点和 Scope
|
||||
self.token_url = PROVIDER_TOKEN_URLS.get(provider_type, TokenEndpoint.LIVE.value)
|
||||
self.scope = PROVIDER_SCOPES.get(provider_type, "")
|
||||
|
||||
def get_cached_token(self) -> Optional[TokenInfo]:
|
||||
"""获取缓存的 Token"""
|
||||
cache_key = (self.account.email.lower(), self.provider_type)
|
||||
with self._cache_lock:
|
||||
token = self._token_cache.get(cache_key)
|
||||
if token and not token.is_expired(self.REFRESH_BUFFER):
|
||||
return token
|
||||
return None
|
||||
|
||||
def set_cached_token(self, token: TokenInfo):
|
||||
"""缓存 Token"""
|
||||
cache_key = (self.account.email.lower(), self.provider_type)
|
||||
with self._cache_lock:
|
||||
self._token_cache[cache_key] = token
|
||||
|
||||
def clear_cache(self):
|
||||
"""清除缓存"""
|
||||
cache_key = (self.account.email.lower(), self.provider_type)
|
||||
with self._cache_lock:
|
||||
self._token_cache.pop(cache_key, None)
|
||||
|
||||
def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
|
||||
"""
|
||||
获取 Access Token
|
||||
|
||||
Args:
|
||||
force_refresh: 是否强制刷新
|
||||
|
||||
Returns:
|
||||
Access Token 字符串,失败返回 None
|
||||
"""
|
||||
# 检查缓存
|
||||
if not force_refresh:
|
||||
cached = self.get_cached_token()
|
||||
if cached:
|
||||
logger.debug(f"[{self.account.email}] 使用缓存的 Token ({self.provider_type.value})")
|
||||
return cached.access_token
|
||||
|
||||
# 刷新 Token
|
||||
try:
|
||||
token = self._refresh_token()
|
||||
if token:
|
||||
self.set_cached_token(token)
|
||||
return token.access_token
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.account.email}] 获取 Token 失败 ({self.provider_type.value}): {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _refresh_token(self) -> Optional[TokenInfo]:
|
||||
"""
|
||||
刷新 Token
|
||||
|
||||
Returns:
|
||||
TokenInfo 对象,失败返回 None
|
||||
"""
|
||||
if not self.account.client_id or not self.account.refresh_token:
|
||||
raise ValueError("缺少 client_id 或 refresh_token")
|
||||
|
||||
logger.debug(f"[{self.account.email}] 正在刷新 Token ({self.provider_type.value})...")
|
||||
logger.debug(f"[{self.account.email}] Token URL: {self.token_url}")
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"client_id": self.account.client_id,
|
||||
"refresh_token": self.account.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
}
|
||||
|
||||
# 添加 Scope(如果需要)
|
||||
if self.scope:
|
||||
data["scope"] = self.scope
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
proxies = None
|
||||
if self.proxy_url:
|
||||
proxies = {"http": self.proxy_url, "https": self.proxy_url}
|
||||
|
||||
try:
|
||||
resp = _requests.post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
proxies=proxies,
|
||||
timeout=self.timeout,
|
||||
impersonate="chrome110",
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_body = resp.text
|
||||
logger.error(f"[{self.account.email}] Token 刷新失败: HTTP {resp.status_code}")
|
||||
logger.debug(f"[{self.account.email}] 错误响应: {error_body[:500]}")
|
||||
|
||||
if "service abuse" in error_body.lower():
|
||||
logger.warning(f"[{self.account.email}] 账号可能被封禁")
|
||||
elif "invalid_grant" in error_body.lower():
|
||||
logger.warning(f"[{self.account.email}] Refresh Token 已失效")
|
||||
|
||||
return None
|
||||
|
||||
response_data = resp.json()
|
||||
|
||||
# 解析响应
|
||||
token = TokenInfo.from_response(response_data, self.scope)
|
||||
logger.info(
|
||||
f"[{self.account.email}] Token 刷新成功 ({self.provider_type.value}), "
|
||||
f"有效期 {int(token.expires_at - time.time())} 秒"
|
||||
)
|
||||
return token
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[{self.account.email}] JSON 解析错误: {e}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.account.email}] 未知错误: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def clear_all_cache(cls):
|
||||
"""清除所有 Token 缓存"""
|
||||
with cls._cache_lock:
|
||||
cls._token_cache.clear()
|
||||
logger.info("已清除所有 Token 缓存")
|
||||
|
||||
@classmethod
|
||||
def get_cache_stats(cls) -> Dict[str, Any]:
|
||||
"""获取缓存统计"""
|
||||
with cls._cache_lock:
|
||||
return {
|
||||
"cache_size": len(cls._token_cache),
|
||||
"entries": [
|
||||
{
|
||||
"email": key[0],
|
||||
"provider": key[1].value,
|
||||
}
|
||||
for key in cls._token_cache.keys()
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def create_token_manager(
|
||||
account: OutlookAccount,
|
||||
provider_type: ProviderType,
|
||||
proxy_url: Optional[str] = None,
|
||||
timeout: int = TokenManager.DEFAULT_TIMEOUT,
|
||||
) -> TokenManager:
|
||||
"""
|
||||
创建 Token 管理器的工厂函数
|
||||
|
||||
Args:
|
||||
account: Outlook 账户
|
||||
provider_type: 提供者类型
|
||||
proxy_url: 代理 URL
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
TokenManager 实例
|
||||
"""
|
||||
return TokenManager(account, provider_type, proxy_url, timeout)
|
||||
763
src/services/outlook_legacy_mail.py
Normal file
763
src/services/outlook_legacy_mail.py
Normal file
@@ -0,0 +1,763 @@
|
||||
"""
|
||||
Outlook 邮箱服务实现
|
||||
支持 IMAP 协议,XOAUTH2 和密码认证
|
||||
"""
|
||||
|
||||
import imaplib
|
||||
import email
|
||||
import re
|
||||
import time
|
||||
import threading
|
||||
import json
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from email.header import decode_header
|
||||
from email.utils import parsedate_to_datetime
|
||||
from urllib.error import HTTPError
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from ..config.constants import (
|
||||
OTP_CODE_PATTERN,
|
||||
OTP_CODE_SIMPLE_PATTERN,
|
||||
OTP_CODE_SEMANTIC_PATTERN,
|
||||
OPENAI_EMAIL_SENDERS,
|
||||
OPENAI_VERIFICATION_KEYWORDS,
|
||||
)
|
||||
from ..config.settings import get_settings
|
||||
|
||||
|
||||
def get_email_code_settings() -> dict:
|
||||
"""
|
||||
获取验证码等待配置
|
||||
|
||||
Returns:
|
||||
dict: 包含 timeout 和 poll_interval 的字典
|
||||
"""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"timeout": settings.email_code_timeout,
|
||||
"poll_interval": settings.email_code_poll_interval,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutlookAccount:
|
||||
"""Outlook 账户信息"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
client_id: str = "",
|
||||
refresh_token: str = ""
|
||||
):
|
||||
self.email = email
|
||||
self.password = password
|
||||
self.client_id = client_id
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "OutlookAccount":
|
||||
"""从配置创建账户"""
|
||||
return cls(
|
||||
email=config.get("email", ""),
|
||||
password=config.get("password", ""),
|
||||
client_id=config.get("client_id", ""),
|
||||
refresh_token=config.get("refresh_token", "")
|
||||
)
|
||||
|
||||
def has_oauth(self) -> bool:
|
||||
"""是否支持 OAuth2"""
|
||||
return bool(self.client_id and self.refresh_token)
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证账户信息是否有效"""
|
||||
return bool(self.email and self.password) or self.has_oauth()
|
||||
|
||||
|
||||
class OutlookIMAPClient:
|
||||
"""
|
||||
Outlook IMAP 客户端
|
||||
支持 XOAUTH2 和密码认证
|
||||
"""
|
||||
|
||||
# Microsoft OAuth2 Token 缓存
|
||||
_token_cache: Dict[str, tuple] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
account: OutlookAccount,
|
||||
host: str = "outlook.office365.com",
|
||||
port: int = 993,
|
||||
timeout: int = 20
|
||||
):
|
||||
self.account = account
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout = timeout
|
||||
self._conn: Optional[imaplib.IMAP4_SSL] = None
|
||||
|
||||
@staticmethod
|
||||
def refresh_ms_token(account: OutlookAccount, timeout: int = 15) -> str:
|
||||
"""刷新 Microsoft access token"""
|
||||
if not account.client_id or not account.refresh_token:
|
||||
raise RuntimeError("缺少 client_id 或 refresh_token")
|
||||
|
||||
key = account.email.lower()
|
||||
with OutlookIMAPClient._cache_lock:
|
||||
cached = OutlookIMAPClient._token_cache.get(key)
|
||||
if cached and time.time() < cached[1]:
|
||||
return cached[0]
|
||||
|
||||
body = urllib.parse.urlencode({
|
||||
"client_id": account.client_id,
|
||||
"refresh_token": account.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
"redirect_uri": "https://login.live.com/oauth20_desktop.srf",
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
"https://login.live.com/oauth20_token.srf",
|
||||
data=body,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read())
|
||||
except HTTPError as e:
|
||||
raise RuntimeError(f"MS OAuth 刷新失败: {e.code}") from e
|
||||
|
||||
token = data.get("access_token")
|
||||
if not token:
|
||||
raise RuntimeError("MS OAuth 响应无 access_token")
|
||||
|
||||
ttl = int(data.get("expires_in", 3600))
|
||||
with OutlookIMAPClient._cache_lock:
|
||||
OutlookIMAPClient._token_cache[key] = (token, time.time() + ttl - 120)
|
||||
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _build_xoauth2(email_addr: str, token: str) -> bytes:
|
||||
"""构建 XOAUTH2 认证字符串"""
|
||||
return f"user={email_addr}\x01auth=Bearer {token}\x01\x01".encode()
|
||||
|
||||
def connect(self):
|
||||
"""连接到 IMAP 服务器"""
|
||||
self._conn = imaplib.IMAP4_SSL(self.host, self.port, timeout=self.timeout)
|
||||
|
||||
# 优先使用 XOAUTH2 认证
|
||||
if self.account.has_oauth():
|
||||
try:
|
||||
token = self.refresh_ms_token(self.account)
|
||||
self._conn.authenticate(
|
||||
"XOAUTH2",
|
||||
lambda _: self._build_xoauth2(self.account.email, token)
|
||||
)
|
||||
logger.debug(f"使用 XOAUTH2 认证连接: {self.account.email}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"XOAUTH2 认证失败,回退密码认证: {e}")
|
||||
|
||||
# 回退到密码认证
|
||||
self._conn.login(self.account.email, self.account.password)
|
||||
logger.debug(f"使用密码认证连接: {self.account.email}")
|
||||
|
||||
def _ensure_connection(self):
|
||||
"""确保连接有效"""
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.noop()
|
||||
return
|
||||
except Exception:
|
||||
self.close()
|
||||
|
||||
self.connect()
|
||||
|
||||
def get_recent_emails(
|
||||
self,
|
||||
count: int = 20,
|
||||
only_unseen: bool = True,
|
||||
timeout: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取最近的邮件
|
||||
|
||||
Args:
|
||||
count: 获取的邮件数量
|
||||
only_unseen: 是否只获取未读邮件
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
邮件列表
|
||||
"""
|
||||
self._ensure_connection()
|
||||
|
||||
flag = "UNSEEN" if only_unseen else "ALL"
|
||||
self._conn.select("INBOX", readonly=True)
|
||||
|
||||
_, data = self._conn.search(None, flag)
|
||||
if not data or not data[0]:
|
||||
return []
|
||||
|
||||
# 获取最新的邮件
|
||||
ids = data[0].split()[-count:]
|
||||
result = []
|
||||
|
||||
for mid in reversed(ids):
|
||||
try:
|
||||
_, payload = self._conn.fetch(mid, "(RFC822)")
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
raw = b""
|
||||
for part in payload:
|
||||
if isinstance(part, tuple) and len(part) > 1:
|
||||
raw = part[1]
|
||||
break
|
||||
|
||||
if raw:
|
||||
result.append(self._parse_email(raw))
|
||||
except Exception as e:
|
||||
logger.warning(f"解析邮件失败 (ID: {mid}): {e}")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(raw: bytes) -> Dict[str, Any]:
|
||||
"""解析邮件内容"""
|
||||
# 移除可能的 BOM
|
||||
if raw.startswith(b"\xef\xbb\xbf"):
|
||||
raw = raw[3:]
|
||||
|
||||
msg = email.message_from_bytes(raw)
|
||||
|
||||
# 解析邮件头
|
||||
subject = OutlookIMAPClient._decode_header(msg.get("Subject", ""))
|
||||
sender = OutlookIMAPClient._decode_header(msg.get("From", ""))
|
||||
date_str = OutlookIMAPClient._decode_header(msg.get("Date", ""))
|
||||
to = OutlookIMAPClient._decode_header(msg.get("To", ""))
|
||||
delivered_to = OutlookIMAPClient._decode_header(msg.get("Delivered-To", ""))
|
||||
x_original_to = OutlookIMAPClient._decode_header(msg.get("X-Original-To", ""))
|
||||
|
||||
# 提取邮件正文
|
||||
body = OutlookIMAPClient._extract_body(msg)
|
||||
|
||||
# 解析日期
|
||||
date_timestamp = 0
|
||||
try:
|
||||
if date_str:
|
||||
dt = parsedate_to_datetime(date_str)
|
||||
date_timestamp = int(dt.timestamp())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"subject": subject,
|
||||
"from": sender,
|
||||
"date": date_str,
|
||||
"date_timestamp": date_timestamp,
|
||||
"to": to,
|
||||
"delivered_to": delivered_to,
|
||||
"x_original_to": x_original_to,
|
||||
"body": body,
|
||||
"raw": raw.hex()[:100] # 存储原始数据的部分哈希用于调试
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _decode_header(header: str) -> str:
|
||||
"""解码邮件头"""
|
||||
if not header:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
for chunk, encoding in decode_header(header):
|
||||
if isinstance(chunk, bytes):
|
||||
try:
|
||||
decoded = chunk.decode(encoding or "utf-8", errors="replace")
|
||||
parts.append(decoded)
|
||||
except Exception:
|
||||
parts.append(chunk.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
parts.append(chunk)
|
||||
|
||||
return "".join(parts).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_body(msg) -> str:
|
||||
"""提取邮件正文"""
|
||||
import html as html_module
|
||||
|
||||
texts = []
|
||||
parts = msg.walk() if msg.is_multipart() else [msg]
|
||||
|
||||
for part in parts:
|
||||
content_type = part.get_content_type()
|
||||
if content_type not in ("text/plain", "text/html"):
|
||||
continue
|
||||
|
||||
payload = part.get_payload(decode=True)
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
try:
|
||||
text = payload.decode(charset, errors="replace")
|
||||
except LookupError:
|
||||
text = payload.decode("utf-8", errors="replace")
|
||||
|
||||
# 如果是 HTML,移除标签
|
||||
if "<html" in text.lower():
|
||||
text = re.sub(r"<[^>]+>", " ", text)
|
||||
|
||||
texts.append(text)
|
||||
|
||||
# 合并并清理文本
|
||||
combined = " ".join(texts)
|
||||
combined = html_module.unescape(combined)
|
||||
combined = re.sub(r"\s+", " ", combined).strip()
|
||||
|
||||
return combined
|
||||
|
||||
def close(self):
|
||||
"""关闭连接"""
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn = None
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class OutlookService(BaseEmailService):
|
||||
"""
|
||||
Outlook 邮箱服务
|
||||
支持多个 Outlook 账户的轮询和验证码获取
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
"""
|
||||
初始化 Outlook 服务
|
||||
|
||||
Args:
|
||||
config: 配置字典,支持以下键:
|
||||
- accounts: Outlook 账户列表,每个账户包含:
|
||||
- email: 邮箱地址
|
||||
- password: 密码
|
||||
- client_id: OAuth2 client_id (可选)
|
||||
- refresh_token: OAuth2 refresh_token (可选)
|
||||
- imap_host: IMAP 服务器 (默认: outlook.office365.com)
|
||||
- imap_port: IMAP 端口 (默认: 993)
|
||||
- timeout: 超时时间 (默认: 30)
|
||||
- max_retries: 最大重试次数 (默认: 3)
|
||||
name: 服务名称
|
||||
"""
|
||||
super().__init__(EmailServiceType.OUTLOOK, name)
|
||||
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"accounts": [],
|
||||
"imap_host": "outlook.office365.com",
|
||||
"imap_port": 993,
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"proxy_url": None,
|
||||
}
|
||||
|
||||
self.config = {**default_config, **(config or {})}
|
||||
|
||||
# 解析账户
|
||||
self.accounts: List[OutlookAccount] = []
|
||||
self._current_account_index = 0
|
||||
self._account_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
# 支持两种配置格式:
|
||||
# 1. 单个账户格式:{"email": "xxx", "password": "xxx"}
|
||||
# 2. 多账户格式:{"accounts": [{"email": "xxx", "password": "xxx"}]}
|
||||
if "email" in self.config and "password" in self.config:
|
||||
# 单个账户格式
|
||||
account = OutlookAccount.from_config(self.config)
|
||||
if account.validate():
|
||||
self.accounts.append(account)
|
||||
self._account_locks[account.email] = threading.Lock()
|
||||
else:
|
||||
logger.warning(f"无效的 Outlook 账户配置: {self.config}")
|
||||
else:
|
||||
# 多账户格式
|
||||
for account_config in self.config.get("accounts", []):
|
||||
account = OutlookAccount.from_config(account_config)
|
||||
if account.validate():
|
||||
self.accounts.append(account)
|
||||
self._account_locks[account.email] = threading.Lock()
|
||||
else:
|
||||
logger.warning(f"无效的 Outlook 账户配置: {account_config}")
|
||||
|
||||
if not self.accounts:
|
||||
logger.warning("未配置有效的 Outlook 账户")
|
||||
|
||||
# IMAP 连接限制(防止限流)
|
||||
self._imap_semaphore = threading.Semaphore(5)
|
||||
|
||||
# 验证码去重机制:email -> set of used codes
|
||||
self._used_codes: Dict[str, set] = {}
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
选择可用的 Outlook 账户
|
||||
|
||||
Args:
|
||||
config: 配置参数(目前未使用)
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典:
|
||||
- email: 邮箱地址
|
||||
- service_id: 账户邮箱(同 email)
|
||||
- account: 账户信息
|
||||
"""
|
||||
if not self.accounts:
|
||||
self.update_status(False, EmailServiceError("没有可用的 Outlook 账户"))
|
||||
raise EmailServiceError("没有可用的 Outlook 账户")
|
||||
|
||||
# 轮询选择账户
|
||||
with threading.Lock():
|
||||
account = self.accounts[self._current_account_index]
|
||||
self._current_account_index = (self._current_account_index + 1) % len(self.accounts)
|
||||
|
||||
email_info = {
|
||||
"email": account.email,
|
||||
"service_id": account.email, # 对于 Outlook,service_id 就是邮箱地址
|
||||
"account": {
|
||||
"email": account.email,
|
||||
"has_oauth": account.has_oauth()
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(f"选择 Outlook 账户: {account.email}")
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = None,
|
||||
pattern: str = OTP_CODE_PATTERN,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从 Outlook 邮箱获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 未使用(对于 Outlook,email 就是标识)
|
||||
timeout: 超时时间(秒),默认使用配置值
|
||||
pattern: 验证码正则表达式
|
||||
otp_sent_at: OTP 发送时间戳,用于过滤旧邮件
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果超时或未找到返回 None
|
||||
"""
|
||||
# 查找对应的账户
|
||||
account = None
|
||||
for acc in self.accounts:
|
||||
if acc.email.lower() == email.lower():
|
||||
account = acc
|
||||
break
|
||||
|
||||
if not account:
|
||||
self.update_status(False, EmailServiceError(f"未找到邮箱对应的账户: {email}"))
|
||||
return None
|
||||
|
||||
# 从数据库获取验证码等待配置
|
||||
code_settings = get_email_code_settings()
|
||||
actual_timeout = timeout or code_settings["timeout"]
|
||||
poll_interval = code_settings["poll_interval"]
|
||||
|
||||
logger.info(f"[{email}] 开始获取验证码,超时 {actual_timeout}s,OTP发送时间: {otp_sent_at}")
|
||||
|
||||
# 初始化验证码去重集合
|
||||
if email not in self._used_codes:
|
||||
self._used_codes[email] = set()
|
||||
used_codes = self._used_codes[email]
|
||||
|
||||
# 计算最小时间戳(留出 60 秒时钟偏差)
|
||||
min_timestamp = (otp_sent_at - 60) if otp_sent_at else 0
|
||||
|
||||
start_time = time.time()
|
||||
poll_count = 0
|
||||
|
||||
while time.time() - start_time < actual_timeout:
|
||||
poll_count += 1
|
||||
loop_start = time.time()
|
||||
|
||||
# 渐进式邮件检查:前 3 次只检查未读,之后检查全部
|
||||
only_unseen = poll_count <= 3
|
||||
|
||||
try:
|
||||
connect_start = time.time()
|
||||
with self._imap_semaphore:
|
||||
with OutlookIMAPClient(
|
||||
account,
|
||||
host=self.config["imap_host"],
|
||||
port=self.config["imap_port"],
|
||||
timeout=10
|
||||
) as client:
|
||||
connect_elapsed = time.time() - connect_start
|
||||
logger.debug(f"[{email}] IMAP 连接耗时 {connect_elapsed:.2f}s")
|
||||
|
||||
# 搜索邮件
|
||||
search_start = time.time()
|
||||
emails = client.get_recent_emails(count=15, only_unseen=only_unseen)
|
||||
search_elapsed = time.time() - search_start
|
||||
logger.debug(f"[{email}] 搜索到 {len(emails)} 封邮件(未读={only_unseen}),耗时 {search_elapsed:.2f}s")
|
||||
|
||||
for mail in emails:
|
||||
# 时间戳过滤
|
||||
mail_ts = mail.get("date_timestamp", 0)
|
||||
if min_timestamp > 0 and mail_ts > 0 and mail_ts < min_timestamp:
|
||||
logger.debug(f"[{email}] 跳过旧邮件: {mail.get('subject', '')[:50]}")
|
||||
continue
|
||||
|
||||
# 检查是否是 OpenAI 验证邮件
|
||||
if not self._is_openai_verification_mail(mail, email):
|
||||
continue
|
||||
|
||||
# 提取验证码
|
||||
code = self._extract_code_from_mail(mail, pattern)
|
||||
if code:
|
||||
# 去重检查
|
||||
if code in used_codes:
|
||||
logger.debug(f"[{email}] 跳过已使用的验证码: {code}")
|
||||
continue
|
||||
|
||||
used_codes.add(code)
|
||||
elapsed = int(time.time() - start_time)
|
||||
logger.info(f"[{email}] 找到验证码: {code},总耗时 {elapsed}s,轮询 {poll_count} 次")
|
||||
self.update_status(True)
|
||||
return code
|
||||
|
||||
except Exception as e:
|
||||
loop_elapsed = time.time() - loop_start
|
||||
logger.warning(f"[{email}] 检查出错: {e},循环耗时 {loop_elapsed:.2f}s")
|
||||
|
||||
# 等待下次轮询
|
||||
time.sleep(poll_interval)
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
logger.warning(f"[{email}] 验证码超时 ({actual_timeout}s),共轮询 {poll_count} 次")
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出所有可用的 Outlook 账户
|
||||
|
||||
Returns:
|
||||
账户列表
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"email": account.email,
|
||||
"id": account.email,
|
||||
"has_oauth": account.has_oauth(),
|
||||
"type": "outlook"
|
||||
}
|
||||
for account in self.accounts
|
||||
]
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""
|
||||
删除邮箱(对于 Outlook,不支持删除账户)
|
||||
|
||||
Args:
|
||||
email_id: 邮箱地址
|
||||
|
||||
Returns:
|
||||
False(Outlook 不支持删除账户)
|
||||
"""
|
||||
logger.warning(f"Outlook 服务不支持删除账户: {email_id}")
|
||||
return False
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""检查 Outlook 服务是否可用"""
|
||||
if not self.accounts:
|
||||
self.update_status(False, EmailServiceError("没有配置的账户"))
|
||||
return False
|
||||
|
||||
# 测试第一个账户的连接
|
||||
test_account = self.accounts[0]
|
||||
try:
|
||||
with self._imap_semaphore:
|
||||
with OutlookIMAPClient(
|
||||
test_account,
|
||||
host=self.config["imap_host"],
|
||||
port=self.config["imap_port"],
|
||||
timeout=10
|
||||
) as client:
|
||||
# 尝试列出邮箱(快速测试)
|
||||
client._conn.select("INBOX", readonly=True)
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Outlook 健康检查失败 ({test_account.email}): {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
|
||||
def _is_oai_mail(self, mail: Dict[str, Any]) -> bool:
|
||||
"""判断是否为 OpenAI 相关邮件(旧方法,保留兼容)"""
|
||||
combined = f"{mail.get('from', '')} {mail.get('subject', '')} {mail.get('body', '')}".lower()
|
||||
keywords = ["openai", "chatgpt", "verification", "验证码", "code"]
|
||||
return any(keyword in combined for keyword in keywords)
|
||||
|
||||
def _is_openai_verification_mail(
|
||||
self,
|
||||
mail: Dict[str, Any],
|
||||
target_email: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
严格判断是否为 OpenAI 验证邮件
|
||||
|
||||
Args:
|
||||
mail: 邮件信息字典
|
||||
target_email: 目标邮箱地址(用于验证收件人)
|
||||
|
||||
Returns:
|
||||
是否为 OpenAI 验证邮件
|
||||
"""
|
||||
sender = mail.get("from", "").lower()
|
||||
|
||||
# 1. 发件人必须是 OpenAI
|
||||
valid_senders = OPENAI_EMAIL_SENDERS
|
||||
if not any(s in sender for s in valid_senders):
|
||||
logger.debug(f"邮件发件人非 OpenAI: {sender}")
|
||||
return False
|
||||
|
||||
# 2. 主题或正文包含验证关键词
|
||||
subject = mail.get("subject", "").lower()
|
||||
body = mail.get("body", "").lower()
|
||||
verification_keywords = OPENAI_VERIFICATION_KEYWORDS
|
||||
combined = f"{subject} {body}"
|
||||
if not any(kw in combined for kw in verification_keywords):
|
||||
logger.debug(f"邮件未包含验证关键词: {subject[:50]}")
|
||||
return False
|
||||
|
||||
# 3. 验证收件人(可选)
|
||||
if target_email:
|
||||
recipients = f"{mail.get('to', '')} {mail.get('delivered_to', '')} {mail.get('x_original_to', '')}".lower()
|
||||
if target_email.lower() not in recipients:
|
||||
logger.debug(f"邮件收件人不匹配: {recipients[:50]}")
|
||||
return False
|
||||
|
||||
logger.debug(f"识别为 OpenAI 验证邮件: {subject[:50]}")
|
||||
return True
|
||||
|
||||
def _extract_code_from_mail(
|
||||
self,
|
||||
mail: Dict[str, Any],
|
||||
fallback_pattern: str = OTP_CODE_PATTERN
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从邮件中提取验证码
|
||||
|
||||
优先级:
|
||||
1. 从主题提取(6位数字)
|
||||
2. 从正文用语义正则提取(如 "code is 123456")
|
||||
3. 兜底:任意 6 位数字
|
||||
|
||||
Args:
|
||||
mail: 邮件信息字典
|
||||
fallback_pattern: 兜底正则表达式
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果未找到返回 None
|
||||
"""
|
||||
# 编译正则
|
||||
re_simple = re.compile(OTP_CODE_SIMPLE_PATTERN)
|
||||
re_semantic = re.compile(OTP_CODE_SEMANTIC_PATTERN, re.IGNORECASE)
|
||||
|
||||
# 1. 主题优先
|
||||
subject = mail.get("subject", "")
|
||||
match = re_simple.search(subject)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.debug(f"从主题提取验证码: {code}")
|
||||
return code
|
||||
|
||||
# 2. 正文语义匹配
|
||||
body = mail.get("body", "")
|
||||
match = re_semantic.search(body)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.debug(f"从正文语义提取验证码: {code}")
|
||||
return code
|
||||
|
||||
# 3. 兜底:任意 6 位数字
|
||||
match = re_simple.search(body)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.debug(f"从正文兜底提取验证码: {code}")
|
||||
return code
|
||||
|
||||
return None
|
||||
|
||||
def get_account_stats(self) -> Dict[str, Any]:
|
||||
"""获取账户统计信息"""
|
||||
total = len(self.accounts)
|
||||
oauth_count = sum(1 for acc in self.accounts if acc.has_oauth())
|
||||
|
||||
return {
|
||||
"total_accounts": total,
|
||||
"oauth_accounts": oauth_count,
|
||||
"password_accounts": total - oauth_count,
|
||||
"accounts": [
|
||||
{
|
||||
"email": acc.email,
|
||||
"has_oauth": acc.has_oauth()
|
||||
}
|
||||
for acc in self.accounts
|
||||
]
|
||||
}
|
||||
|
||||
def add_account(self, account_config: Dict[str, Any]) -> bool:
|
||||
"""添加新的 Outlook 账户"""
|
||||
try:
|
||||
account = OutlookAccount.from_config(account_config)
|
||||
if not account.validate():
|
||||
return False
|
||||
|
||||
self.accounts.append(account)
|
||||
self._account_locks[account.email] = threading.Lock()
|
||||
logger.info(f"添加 Outlook 账户: {account.email}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加 Outlook 账户失败: {e}")
|
||||
return False
|
||||
|
||||
def remove_account(self, email: str) -> bool:
|
||||
"""移除 Outlook 账户"""
|
||||
for i, acc in enumerate(self.accounts):
|
||||
if acc.email.lower() == email.lower():
|
||||
self.accounts.pop(i)
|
||||
self._account_locks.pop(email, None)
|
||||
logger.info(f"移除 Outlook 账户: {email}")
|
||||
return True
|
||||
return False
|
||||
455
src/services/temp_mail.py
Normal file
455
src/services/temp_mail.py
Normal file
@@ -0,0 +1,455 @@
|
||||
"""
|
||||
Temp-Mail 邮箱服务实现
|
||||
基于自部署 Cloudflare Worker 临时邮箱服务
|
||||
接口文档参见 plan/temp-mail.md
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from email import message_from_string
|
||||
from email.header import decode_header, make_header
|
||||
from email.message import Message
|
||||
from email.policy import default as email_policy
|
||||
from html import unescape
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TempMailService(BaseEmailService):
|
||||
"""
|
||||
Temp-Mail 邮箱服务
|
||||
基于自部署 Cloudflare Worker 的临时邮箱,admin 模式管理邮箱
|
||||
不走代理,不使用 requests 库
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
"""
|
||||
初始化 TempMail 服务
|
||||
|
||||
Args:
|
||||
config: 配置字典,支持以下键:
|
||||
- base_url: Worker 域名地址,如 https://mail.example.com (必需)
|
||||
- admin_password: Admin 密码,对应 x-admin-auth header (必需)
|
||||
- domain: 邮箱域名,如 example.com (必需)
|
||||
- enable_prefix: 是否启用前缀,默认 True
|
||||
- timeout: 请求超时时间,默认 30
|
||||
- max_retries: 最大重试次数,默认 3
|
||||
name: 服务名称
|
||||
"""
|
||||
super().__init__(EmailServiceType.TEMP_MAIL, name)
|
||||
|
||||
required_keys = ["base_url", "admin_password", "domain"]
|
||||
missing_keys = [key for key in required_keys if not (config or {}).get(key)]
|
||||
if missing_keys:
|
||||
raise ValueError(f"缺少必需配置: {missing_keys}")
|
||||
|
||||
default_config = {
|
||||
"enable_prefix": True,
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
}
|
||||
self.config = {**default_config, **(config or {})}
|
||||
|
||||
# 不走代理,proxy_url=None
|
||||
http_config = RequestConfig(
|
||||
timeout=self.config["timeout"],
|
||||
max_retries=self.config["max_retries"],
|
||||
)
|
||||
self.http_client = HTTPClient(proxy_url=None, config=http_config)
|
||||
|
||||
# 邮箱缓存:email -> {jwt, address}
|
||||
self._email_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def _decode_mime_header(self, value: str) -> str:
|
||||
"""解码 MIME 头,兼容 RFC 2047 编码主题。"""
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
return str(make_header(decode_header(value)))
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
def _extract_body_from_message(self, message: Message) -> str:
|
||||
"""从 MIME 邮件对象中提取可读正文。"""
|
||||
parts: List[str] = []
|
||||
|
||||
if message.is_multipart():
|
||||
for part in message.walk():
|
||||
if part.get_content_maintype() == "multipart":
|
||||
continue
|
||||
|
||||
content_type = (part.get_content_type() or "").lower()
|
||||
if content_type not in ("text/plain", "text/html"):
|
||||
continue
|
||||
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
text = payload.decode(charset, errors="replace") if payload else ""
|
||||
except Exception:
|
||||
try:
|
||||
text = part.get_content()
|
||||
except Exception:
|
||||
text = ""
|
||||
|
||||
if content_type == "text/html":
|
||||
text = re.sub(r"<[^>]+>", " ", text)
|
||||
parts.append(text)
|
||||
else:
|
||||
try:
|
||||
payload = message.get_payload(decode=True)
|
||||
charset = message.get_content_charset() or "utf-8"
|
||||
body = payload.decode(charset, errors="replace") if payload else ""
|
||||
except Exception:
|
||||
try:
|
||||
body = message.get_content()
|
||||
except Exception:
|
||||
body = str(message.get_payload() or "")
|
||||
|
||||
if "html" in (message.get_content_type() or "").lower():
|
||||
body = re.sub(r"<[^>]+>", " ", body)
|
||||
parts.append(body)
|
||||
|
||||
return unescape("\n".join(part for part in parts if part).strip())
|
||||
|
||||
def _extract_mail_fields(self, mail: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""统一提取邮件字段,兼容 raw MIME 和不同 Worker 返回格式。"""
|
||||
sender = str(
|
||||
mail.get("source")
|
||||
or mail.get("from")
|
||||
or mail.get("from_address")
|
||||
or mail.get("fromAddress")
|
||||
or ""
|
||||
).strip()
|
||||
subject = str(mail.get("subject") or mail.get("title") or "").strip()
|
||||
body_text = str(
|
||||
mail.get("text")
|
||||
or mail.get("body")
|
||||
or mail.get("content")
|
||||
or mail.get("html")
|
||||
or ""
|
||||
).strip()
|
||||
raw = str(mail.get("raw") or "").strip()
|
||||
|
||||
if raw:
|
||||
try:
|
||||
message = message_from_string(raw, policy=email_policy)
|
||||
sender = sender or self._decode_mime_header(message.get("From", ""))
|
||||
subject = subject or self._decode_mime_header(message.get("Subject", ""))
|
||||
parsed_body = self._extract_body_from_message(message)
|
||||
if parsed_body:
|
||||
body_text = f"{body_text}\n{parsed_body}".strip() if body_text else parsed_body
|
||||
except Exception as e:
|
||||
logger.debug(f"解析 TempMail raw 邮件失败: {e}")
|
||||
body_text = f"{body_text}\n{raw}".strip() if body_text else raw
|
||||
|
||||
body_text = unescape(re.sub(r"<[^>]+>", " ", body_text))
|
||||
return {
|
||||
"sender": sender,
|
||||
"subject": subject,
|
||||
"body": body_text,
|
||||
"raw": raw,
|
||||
}
|
||||
|
||||
def _admin_headers(self) -> Dict[str, str]:
|
||||
"""构造 admin 请求头"""
|
||||
return {
|
||||
"x-admin-auth": self.config["admin_password"],
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
def _make_request(self, method: str, path: str, **kwargs) -> Any:
|
||||
"""
|
||||
发送请求并返回 JSON 数据
|
||||
|
||||
Args:
|
||||
method: HTTP 方法
|
||||
path: 请求路径(以 / 开头)
|
||||
**kwargs: 传递给 http_client.request 的额外参数
|
||||
|
||||
Returns:
|
||||
响应 JSON 数据
|
||||
|
||||
Raises:
|
||||
EmailServiceError: 请求失败
|
||||
"""
|
||||
base_url = self.config["base_url"].rstrip("/")
|
||||
url = f"{base_url}{path}"
|
||||
|
||||
# 合并默认 admin headers
|
||||
kwargs.setdefault("headers", {})
|
||||
for k, v in self._admin_headers().items():
|
||||
kwargs["headers"].setdefault(k, v)
|
||||
|
||||
try:
|
||||
response = self.http_client.request(method, url, **kwargs)
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_msg = f"请求失败: {response.status_code}"
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = f"{error_msg} - {error_data}"
|
||||
except Exception:
|
||||
error_msg = f"{error_msg} - {response.text[:200]}"
|
||||
self.update_status(False, EmailServiceError(error_msg))
|
||||
raise EmailServiceError(error_msg)
|
||||
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_response": response.text}
|
||||
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"请求失败: {method} {path} - {e}")
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
通过 admin API 创建临时邮箱
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典:
|
||||
- email: 邮箱地址
|
||||
- jwt: 用户级 JWT token
|
||||
- service_id: 同 email(用作标识)
|
||||
"""
|
||||
import random
|
||||
import string
|
||||
|
||||
# 生成随机邮箱名
|
||||
letters = ''.join(random.choices(string.ascii_lowercase, k=5))
|
||||
digits = ''.join(random.choices(string.digits, k=random.randint(1, 3)))
|
||||
suffix = ''.join(random.choices(string.ascii_lowercase, k=random.randint(1, 3)))
|
||||
name = letters + digits + suffix
|
||||
|
||||
domain = self.config["domain"]
|
||||
enable_prefix = self.config.get("enable_prefix", True)
|
||||
|
||||
body = {
|
||||
"enablePrefix": enable_prefix,
|
||||
"name": name,
|
||||
"domain": domain,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self._make_request("POST", "/admin/new_address", json=body)
|
||||
|
||||
address = response.get("address", "").strip()
|
||||
jwt = response.get("jwt", "").strip()
|
||||
|
||||
if not address:
|
||||
raise EmailServiceError(f"API 返回数据不完整: {response}")
|
||||
|
||||
email_info = {
|
||||
"email": address,
|
||||
"jwt": jwt,
|
||||
"service_id": address,
|
||||
"id": address,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
# 缓存 jwt,供获取验证码时使用
|
||||
self._email_cache[address] = email_info
|
||||
|
||||
logger.info(f"成功创建 TempMail 邮箱: {address}")
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"创建邮箱失败: {e}")
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
pattern: str = OTP_CODE_PATTERN,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从 TempMail 邮箱获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 未使用,保留接口兼容
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则
|
||||
otp_sent_at: OTP 发送时间戳(暂未使用)
|
||||
|
||||
Returns:
|
||||
验证码字符串,超时返回 None
|
||||
"""
|
||||
logger.info(f"正在从 TempMail 邮箱 {email} 获取验证码...")
|
||||
|
||||
start_time = time.time()
|
||||
seen_mail_ids: set = set()
|
||||
|
||||
# 优先使用用户级 JWT,回退到 admin API
|
||||
cached = self._email_cache.get(email, {})
|
||||
jwt = cached.get("jwt")
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
if jwt:
|
||||
response = self._make_request(
|
||||
"GET",
|
||||
"/user_api/mails",
|
||||
params={"limit": 20, "offset": 0},
|
||||
headers={"x-user-token": jwt, "Content-Type": "application/json", "Accept": "application/json"},
|
||||
)
|
||||
else:
|
||||
response = self._make_request(
|
||||
"GET",
|
||||
"/admin/mails",
|
||||
params={"limit": 20, "offset": 0, "address": email},
|
||||
)
|
||||
|
||||
# /user_api/mails 和 /admin/mails 返回格式相同: {"results": [...], "total": N}
|
||||
mails = response.get("results", [])
|
||||
if not isinstance(mails, list):
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
for mail in mails:
|
||||
mail_id = mail.get("id")
|
||||
if not mail_id or mail_id in seen_mail_ids:
|
||||
continue
|
||||
|
||||
seen_mail_ids.add(mail_id)
|
||||
|
||||
parsed = self._extract_mail_fields(mail)
|
||||
sender = parsed["sender"].lower()
|
||||
subject = parsed["subject"]
|
||||
body_text = parsed["body"]
|
||||
raw_text = parsed["raw"]
|
||||
content = f"{sender}\n{subject}\n{body_text}\n{raw_text}".strip()
|
||||
|
||||
# 只处理 OpenAI 邮件
|
||||
if "openai" not in sender and "openai" not in content.lower():
|
||||
continue
|
||||
|
||||
match = re.search(pattern, content)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.info(f"从 TempMail 邮箱 {email} 找到验证码: {code}")
|
||||
self.update_status(True)
|
||||
return code
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"检查 TempMail 邮件时出错: {e}")
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
logger.warning(f"等待 TempMail 验证码超时: {email}")
|
||||
return None
|
||||
|
||||
def list_emails(self, limit: int = 100, offset: int = 0, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出邮箱
|
||||
|
||||
Args:
|
||||
limit: 返回数量上限
|
||||
offset: 分页偏移
|
||||
**kwargs: 额外查询参数,透传给 admin API
|
||||
|
||||
Returns:
|
||||
邮箱列表
|
||||
"""
|
||||
params = {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
params.update({k: v for k, v in kwargs.items() if v is not None})
|
||||
|
||||
try:
|
||||
response = self._make_request("GET", "/admin/mails", params=params)
|
||||
mails = response.get("results", [])
|
||||
if not isinstance(mails, list):
|
||||
raise EmailServiceError(f"API 返回数据格式错误: {response}")
|
||||
|
||||
emails: List[Dict[str, Any]] = []
|
||||
for mail in mails:
|
||||
address = (mail.get("address") or "").strip()
|
||||
mail_id = mail.get("id") or address
|
||||
email_info = {
|
||||
"id": mail_id,
|
||||
"service_id": mail_id,
|
||||
"email": address,
|
||||
"subject": mail.get("subject"),
|
||||
"from": mail.get("source"),
|
||||
"created_at": mail.get("createdAt") or mail.get("created_at"),
|
||||
"raw_data": mail,
|
||||
}
|
||||
emails.append(email_info)
|
||||
|
||||
if address:
|
||||
cached = self._email_cache.get(address, {})
|
||||
self._email_cache[address] = {**cached, **email_info}
|
||||
|
||||
self.update_status(True)
|
||||
return emails
|
||||
except Exception as e:
|
||||
logger.warning(f"列出 TempMail 邮箱失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return list(self._email_cache.values())
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""
|
||||
删除邮箱
|
||||
|
||||
Note:
|
||||
当前 TempMail admin API 文档未见删除地址接口,这里先从本地缓存移除,
|
||||
以满足统一接口并避免服务实例化失败。
|
||||
"""
|
||||
removed = False
|
||||
emails_to_delete = []
|
||||
|
||||
for address, info in self._email_cache.items():
|
||||
candidate_ids = {
|
||||
address,
|
||||
info.get("id"),
|
||||
info.get("service_id"),
|
||||
}
|
||||
if email_id in candidate_ids:
|
||||
emails_to_delete.append(address)
|
||||
|
||||
for address in emails_to_delete:
|
||||
self._email_cache.pop(address, None)
|
||||
removed = True
|
||||
|
||||
if removed:
|
||||
logger.info(f"已从 TempMail 缓存移除邮箱: {email_id}")
|
||||
self.update_status(True)
|
||||
else:
|
||||
logger.info(f"TempMail 缓存中未找到邮箱: {email_id}")
|
||||
|
||||
return removed
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""检查服务健康状态"""
|
||||
try:
|
||||
self._make_request(
|
||||
"GET",
|
||||
"/admin/mails",
|
||||
params={"limit": 1, "offset": 0},
|
||||
)
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"TempMail 健康检查失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
400
src/services/tempmail.py
Normal file
400
src/services/tempmail.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
Tempmail.lol 邮箱服务实现
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
import json
|
||||
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
from .base import BaseEmailService, EmailServiceError, EmailServiceType
|
||||
from ..core.http_client import HTTPClient, RequestConfig
|
||||
from ..config.constants import OTP_CODE_PATTERN
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TempmailService(BaseEmailService):
|
||||
"""
|
||||
Tempmail.lol 邮箱服务
|
||||
基于 Tempmail.lol API v2
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any] = None, name: str = None):
|
||||
"""
|
||||
初始化 Tempmail 服务
|
||||
|
||||
Args:
|
||||
config: 配置字典,支持以下键:
|
||||
- base_url: API 基础地址 (默认: https://api.tempmail.lol/v2)
|
||||
- timeout: 请求超时时间 (默认: 30)
|
||||
- max_retries: 最大重试次数 (默认: 3)
|
||||
- proxy_url: 代理 URL
|
||||
name: 服务名称
|
||||
"""
|
||||
super().__init__(EmailServiceType.TEMPMAIL, name)
|
||||
|
||||
# 默认配置
|
||||
default_config = {
|
||||
"base_url": "https://api.tempmail.lol/v2",
|
||||
"timeout": 30,
|
||||
"max_retries": 3,
|
||||
"proxy_url": None,
|
||||
}
|
||||
|
||||
self.config = {**default_config, **(config or {})}
|
||||
|
||||
# 创建 HTTP 客户端
|
||||
http_config = RequestConfig(
|
||||
timeout=self.config["timeout"],
|
||||
max_retries=self.config["max_retries"],
|
||||
)
|
||||
self.http_client = HTTPClient(
|
||||
proxy_url=self.config.get("proxy_url"),
|
||||
config=http_config
|
||||
)
|
||||
|
||||
# 状态变量
|
||||
self._email_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self._last_check_time: float = 0
|
||||
|
||||
def create_email(self, config: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
创建新的临时邮箱
|
||||
|
||||
Args:
|
||||
config: 配置参数(Tempmail.lol 目前不支持自定义配置)
|
||||
|
||||
Returns:
|
||||
包含邮箱信息的字典:
|
||||
- email: 邮箱地址
|
||||
- service_id: 邮箱 token
|
||||
- token: 邮箱 token(同 service_id)
|
||||
- created_at: 创建时间戳
|
||||
"""
|
||||
try:
|
||||
# 发送创建请求
|
||||
response = self.http_client.post(
|
||||
f"{self.config['base_url']}/inbox/create",
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={}
|
||||
)
|
||||
|
||||
if response.status_code not in (200, 201):
|
||||
self.update_status(False, EmailServiceError(f"请求失败,状态码: {response.status_code}"))
|
||||
raise EmailServiceError(f"Tempmail.lol 请求失败,状态码: {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
email = str(data.get("address", "")).strip()
|
||||
token = str(data.get("token", "")).strip()
|
||||
|
||||
if not email or not token:
|
||||
self.update_status(False, EmailServiceError("返回数据不完整"))
|
||||
raise EmailServiceError("Tempmail.lol 返回数据不完整")
|
||||
|
||||
# 缓存邮箱信息
|
||||
email_info = {
|
||||
"email": email,
|
||||
"service_id": token,
|
||||
"token": token,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
self._email_cache[email] = email_info
|
||||
|
||||
logger.info(f"Tempmail.lol 邮箱创建成功,新鲜热乎: {email}")
|
||||
self.update_status(True)
|
||||
return email_info
|
||||
|
||||
except Exception as e:
|
||||
self.update_status(False, e)
|
||||
if isinstance(e, EmailServiceError):
|
||||
raise
|
||||
raise EmailServiceError(f"创建 Tempmail.lol 邮箱失败: {e}")
|
||||
|
||||
def get_verification_code(
|
||||
self,
|
||||
email: str,
|
||||
email_id: str = None,
|
||||
timeout: int = 120,
|
||||
pattern: str = OTP_CODE_PATTERN,
|
||||
otp_sent_at: Optional[float] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从 Tempmail.lol 获取验证码
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
email_id: 邮箱 token(如果不提供,从缓存中查找)
|
||||
timeout: 超时时间(秒)
|
||||
pattern: 验证码正则表达式
|
||||
otp_sent_at: OTP 发送时间戳(Tempmail 服务暂不使用此参数)
|
||||
|
||||
Returns:
|
||||
验证码字符串,如果超时或未找到返回 None
|
||||
"""
|
||||
token = email_id
|
||||
if not token:
|
||||
# 从缓存中查找 token
|
||||
if email in self._email_cache:
|
||||
token = self._email_cache[email].get("token")
|
||||
else:
|
||||
logger.warning(f"未找到邮箱 {email} 的 token,无法获取验证码")
|
||||
return None
|
||||
|
||||
if not token:
|
||||
logger.warning(f"邮箱 {email} 没有 token,无法获取验证码")
|
||||
return None
|
||||
|
||||
logger.info(f"正在等邮箱 {email} 的验证码,邮差应该在路上了...")
|
||||
|
||||
start_time = time.time()
|
||||
seen_ids = set()
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
# 获取邮件列表
|
||||
response = self.http_client.get(
|
||||
f"{self.config['base_url']}/inbox",
|
||||
params={"token": token},
|
||||
headers={"Accept": "application/json"}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 检查 inbox 是否过期
|
||||
if data is None or (isinstance(data, dict) and not data):
|
||||
logger.warning(f"邮箱 {email} 已过期")
|
||||
return None
|
||||
|
||||
email_list = data.get("emails", []) if isinstance(data, dict) else []
|
||||
|
||||
if not isinstance(email_list, list):
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
for msg in email_list:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
# 使用 date 作为唯一标识
|
||||
msg_date = msg.get("date", 0)
|
||||
if not msg_date or msg_date in seen_ids:
|
||||
continue
|
||||
seen_ids.add(msg_date)
|
||||
|
||||
sender = str(msg.get("from", "")).lower()
|
||||
subject = str(msg.get("subject", ""))
|
||||
body = str(msg.get("body", ""))
|
||||
html = str(msg.get("html") or "")
|
||||
|
||||
content = "\n".join([sender, subject, body, html])
|
||||
|
||||
# 检查是否是 OpenAI 邮件
|
||||
if "openai" not in sender and "openai" not in content.lower():
|
||||
continue
|
||||
|
||||
# 提取验证码
|
||||
match = re.search(pattern, content)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
logger.info(f"找到验证码了,六位嘉宾登场: {code}")
|
||||
self.update_status(True)
|
||||
return code
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"检查邮件时出错: {e}")
|
||||
|
||||
# 等待一段时间再检查
|
||||
time.sleep(3)
|
||||
|
||||
logger.warning(f"等验证码等到超时了: {email}")
|
||||
return None
|
||||
|
||||
def list_emails(self, **kwargs) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出所有缓存的邮箱
|
||||
|
||||
Note:
|
||||
Tempmail.lol API 不支持列出所有邮箱,这里返回缓存的邮箱
|
||||
"""
|
||||
return list(self._email_cache.values())
|
||||
|
||||
def delete_email(self, email_id: str) -> bool:
|
||||
"""
|
||||
删除邮箱
|
||||
|
||||
Note:
|
||||
Tempmail.lol API 不支持删除邮箱,这里从缓存中移除
|
||||
"""
|
||||
# 从缓存中查找并移除
|
||||
emails_to_delete = []
|
||||
for email, info in self._email_cache.items():
|
||||
if info.get("token") == email_id:
|
||||
emails_to_delete.append(email)
|
||||
|
||||
for email in emails_to_delete:
|
||||
del self._email_cache[email]
|
||||
logger.info(f"从缓存中移除邮箱: {email}")
|
||||
|
||||
return len(emails_to_delete) > 0
|
||||
|
||||
def check_health(self) -> bool:
|
||||
"""检查 Tempmail.lol 服务是否可用"""
|
||||
try:
|
||||
response = self.http_client.get(
|
||||
f"{self.config['base_url']}/inbox/create",
|
||||
timeout=10
|
||||
)
|
||||
# 即使返回错误状态码也认为服务可用(只要可以连接)
|
||||
self.update_status(True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Tempmail.lol 健康检查失败: {e}")
|
||||
self.update_status(False, e)
|
||||
return False
|
||||
|
||||
def get_inbox(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取邮箱收件箱内容
|
||||
|
||||
Args:
|
||||
token: 邮箱 token
|
||||
|
||||
Returns:
|
||||
收件箱数据
|
||||
"""
|
||||
try:
|
||||
response = self.http_client.get(
|
||||
f"{self.config['base_url']}/inbox",
|
||||
params={"token": token},
|
||||
headers={"Accept": "application/json"}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
logger.error(f"获取收件箱失败: {e}")
|
||||
return None
|
||||
|
||||
def wait_for_verification_code_with_callback(
|
||||
self,
|
||||
email: str,
|
||||
token: str,
|
||||
callback: callable = None,
|
||||
timeout: int = 120
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
等待验证码并支持回调函数
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
token: 邮箱 token
|
||||
callback: 回调函数,接收当前状态信息
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
验证码或 None
|
||||
"""
|
||||
start_time = time.time()
|
||||
seen_ids = set()
|
||||
check_count = 0
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
check_count += 1
|
||||
|
||||
if callback:
|
||||
callback({
|
||||
"status": "checking",
|
||||
"email": email,
|
||||
"check_count": check_count,
|
||||
"elapsed_time": time.time() - start_time,
|
||||
})
|
||||
|
||||
try:
|
||||
data = self.get_inbox(token)
|
||||
if not data:
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
# 检查 inbox 是否过期
|
||||
if data is None or (isinstance(data, dict) and not data):
|
||||
if callback:
|
||||
callback({
|
||||
"status": "expired",
|
||||
"email": email,
|
||||
"message": "邮箱已过期"
|
||||
})
|
||||
return None
|
||||
|
||||
email_list = data.get("emails", []) if isinstance(data, dict) else []
|
||||
|
||||
for msg in email_list:
|
||||
msg_date = msg.get("date", 0)
|
||||
if not msg_date or msg_date in seen_ids:
|
||||
continue
|
||||
seen_ids.add(msg_date)
|
||||
|
||||
sender = str(msg.get("from", "")).lower()
|
||||
subject = str(msg.get("subject", ""))
|
||||
body = str(msg.get("body", ""))
|
||||
html = str(msg.get("html") or "")
|
||||
|
||||
content = "\n".join([sender, subject, body, html])
|
||||
|
||||
# 检查是否是 OpenAI 邮件
|
||||
if "openai" not in sender and "openai" not in content.lower():
|
||||
continue
|
||||
|
||||
# 提取验证码
|
||||
match = re.search(OTP_CODE_PATTERN, content)
|
||||
if match:
|
||||
code = match.group(1)
|
||||
if callback:
|
||||
callback({
|
||||
"status": "found",
|
||||
"email": email,
|
||||
"code": code,
|
||||
"message": "找到验证码"
|
||||
})
|
||||
return code
|
||||
|
||||
if callback and check_count % 5 == 0:
|
||||
callback({
|
||||
"status": "waiting",
|
||||
"email": email,
|
||||
"check_count": check_count,
|
||||
"message": f"已检查 {len(seen_ids)} 封邮件,等待验证码..."
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"检查邮件时出错: {e}")
|
||||
if callback:
|
||||
callback({
|
||||
"status": "error",
|
||||
"email": email,
|
||||
"error": str(e),
|
||||
"message": "检查邮件时出错"
|
||||
})
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
if callback:
|
||||
callback({
|
||||
"status": "timeout",
|
||||
"email": email,
|
||||
"message": "等待验证码超时"
|
||||
})
|
||||
return None
|
||||
7
src/web/__init__.py
Normal file
7
src/web/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Web UI 应用模块
|
||||
"""
|
||||
|
||||
from .app import app, create_app
|
||||
|
||||
__all__ = ['app', 'create_app']
|
||||
201
src/web/app.py
Normal file
201
src/web/app.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
FastAPI 应用主文件
|
||||
轻量级 Web UI,支持注册、账号管理、设置
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import secrets
|
||||
import hmac
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Request, Form
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
|
||||
from ..config.settings import get_settings
|
||||
from .routes import api_router
|
||||
from .routes.websocket import router as ws_router
|
||||
from .task_manager import task_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 获取项目根目录
|
||||
# PyInstaller 打包后静态资源在 sys._MEIPASS,开发时在源码根目录
|
||||
if getattr(sys, 'frozen', False):
|
||||
_RESOURCE_ROOT = Path(sys._MEIPASS)
|
||||
else:
|
||||
_RESOURCE_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
# 静态文件和模板目录
|
||||
STATIC_DIR = _RESOURCE_ROOT / "static"
|
||||
TEMPLATES_DIR = _RESOURCE_ROOT / "templates"
|
||||
|
||||
|
||||
def _build_static_asset_version(static_dir: Path) -> str:
|
||||
"""基于静态文件最后修改时间生成版本号,避免部署后浏览器继续使用旧缓存。"""
|
||||
latest_mtime = 0
|
||||
if static_dir.exists():
|
||||
for path in static_dir.rglob("*"):
|
||||
if path.is_file():
|
||||
latest_mtime = max(latest_mtime, int(path.stat().st_mtime))
|
||||
return str(latest_mtime or 1)
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""创建 FastAPI 应用实例"""
|
||||
settings = get_settings()
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.app_version,
|
||||
description="OpenAI/Codex CLI 自动注册系统 Web UI",
|
||||
docs_url="/api/docs" if settings.debug else None,
|
||||
redoc_url="/api/redoc" if settings.debug else None,
|
||||
)
|
||||
|
||||
# CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 挂载静态文件
|
||||
if STATIC_DIR.exists():
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
logger.info(f"静态文件目录: {STATIC_DIR}")
|
||||
else:
|
||||
# 创建静态目录
|
||||
STATIC_DIR.mkdir(parents=True, exist_ok=True)
|
||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||
logger.info(f"创建静态文件目录: {STATIC_DIR}")
|
||||
|
||||
# 创建模板目录
|
||||
if not TEMPLATES_DIR.exists():
|
||||
TEMPLATES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"创建模板目录: {TEMPLATES_DIR}")
|
||||
|
||||
# 注册 API 路由
|
||||
app.include_router(api_router, prefix="/api")
|
||||
|
||||
# 注册 WebSocket 路由
|
||||
app.include_router(ws_router, prefix="/api")
|
||||
|
||||
# 模板引擎
|
||||
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
|
||||
templates.env.globals["static_version"] = _build_static_asset_version(STATIC_DIR)
|
||||
|
||||
def _auth_token(password: str) -> str:
|
||||
secret = get_settings().webui_secret_key.get_secret_value().encode("utf-8")
|
||||
return hmac.new(secret, password.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
|
||||
def _is_authenticated(request: Request) -> bool:
|
||||
cookie = request.cookies.get("webui_auth")
|
||||
expected = _auth_token(get_settings().webui_access_password.get_secret_value())
|
||||
return bool(cookie) and secrets.compare_digest(cookie, expected)
|
||||
|
||||
def _redirect_to_login(request: Request) -> RedirectResponse:
|
||||
return RedirectResponse(url=f"/login?next={request.url.path}", status_code=302)
|
||||
|
||||
@app.get("/login", response_class=HTMLResponse)
|
||||
async def login_page(request: Request, next: Optional[str] = "/"):
|
||||
"""登录页面"""
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "error": "", "next": next or "/"}
|
||||
)
|
||||
|
||||
@app.post("/login")
|
||||
async def login_submit(request: Request, password: str = Form(...), next: Optional[str] = "/"):
|
||||
"""处理登录提交"""
|
||||
expected = get_settings().webui_access_password.get_secret_value()
|
||||
if not secrets.compare_digest(password, expected):
|
||||
return templates.TemplateResponse(
|
||||
"login.html",
|
||||
{"request": request, "error": "密码错误", "next": next or "/"},
|
||||
status_code=401
|
||||
)
|
||||
|
||||
response = RedirectResponse(url=next or "/", status_code=302)
|
||||
response.set_cookie("webui_auth", _auth_token(expected), httponly=True, samesite="lax")
|
||||
return response
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(request: Request, next: Optional[str] = "/login"):
|
||||
"""退出登录"""
|
||||
response = RedirectResponse(url=next or "/login", status_code=302)
|
||||
response.delete_cookie("webui_auth")
|
||||
return response
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request):
|
||||
"""首页 - 注册页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
|
||||
@app.get("/accounts", response_class=HTMLResponse)
|
||||
async def accounts_page(request: Request):
|
||||
"""账号管理页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("accounts.html", {"request": request})
|
||||
|
||||
@app.get("/email-services", response_class=HTMLResponse)
|
||||
async def email_services_page(request: Request):
|
||||
"""邮箱服务管理页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("email_services.html", {"request": request})
|
||||
|
||||
@app.get("/settings", response_class=HTMLResponse)
|
||||
async def settings_page(request: Request):
|
||||
"""设置页面"""
|
||||
if not _is_authenticated(request):
|
||||
return _redirect_to_login(request)
|
||||
return templates.TemplateResponse("settings.html", {"request": request})
|
||||
|
||||
@app.get("/payment", response_class=HTMLResponse)
|
||||
async def payment_page(request: Request):
|
||||
"""支付页面"""
|
||||
return templates.TemplateResponse("payment.html", {"request": request})
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""应用启动事件"""
|
||||
import asyncio
|
||||
from ..database.init_db import initialize_database
|
||||
|
||||
# 确保数据库已初始化(reload 模式下子进程也需要初始化)
|
||||
try:
|
||||
initialize_database()
|
||||
except Exception as e:
|
||||
logger.warning(f"数据库初始化: {e}")
|
||||
|
||||
# 设置 TaskManager 的事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
task_manager.set_loop(loop)
|
||||
|
||||
logger.info("=" * 50)
|
||||
logger.info(f"{settings.app_name} v{settings.app_version} 启动中,程序正在伸懒腰...")
|
||||
logger.info(f"调试模式: {settings.debug}")
|
||||
logger.info(f"数据库连接已接好线: {settings.database_url}")
|
||||
logger.info("=" * 50)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""应用关闭事件"""
|
||||
logger.info("应用关闭,今天先收摊啦")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# 创建全局应用实例
|
||||
app = create_app()
|
||||
26
src/web/routes/__init__.py
Normal file
26
src/web/routes/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
API 路由模块
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .accounts import router as accounts_router
|
||||
from .registration import router as registration_router
|
||||
from .settings import router as settings_router
|
||||
from .email import router as email_services_router
|
||||
from .payment import router as payment_router
|
||||
from .upload.cpa_services import router as cpa_services_router
|
||||
from .upload.sub2api_services import router as sub2api_services_router
|
||||
from .upload.tm_services import router as tm_services_router
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# 注册各模块路由
|
||||
api_router.include_router(accounts_router, prefix="/accounts", tags=["accounts"])
|
||||
api_router.include_router(registration_router, prefix="/registration", tags=["registration"])
|
||||
api_router.include_router(settings_router, prefix="/settings", tags=["settings"])
|
||||
api_router.include_router(email_services_router, prefix="/email-services", tags=["email-services"])
|
||||
api_router.include_router(payment_router, prefix="/payment", tags=["payment"])
|
||||
api_router.include_router(cpa_services_router, prefix="/cpa-services", tags=["cpa-services"])
|
||||
api_router.include_router(sub2api_services_router, prefix="/sub2api-services", tags=["sub2api-services"])
|
||||
api_router.include_router(tm_services_router, prefix="/tm-services", tags=["tm-services"])
|
||||
1063
src/web/routes/accounts.py
Normal file
1063
src/web/routes/accounts.py
Normal file
File diff suppressed because it is too large
Load Diff
610
src/web/routes/email.py
Normal file
610
src/web/routes/email.py
Normal file
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
邮箱服务配置 API 路由
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...database import crud
|
||||
from ...database.session import get_db
|
||||
from ...database.models import EmailService as EmailServiceModel
|
||||
from ...services import EmailServiceFactory, EmailServiceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class EmailServiceCreate(BaseModel):
|
||||
"""创建邮箱服务请求"""
|
||||
service_type: str
|
||||
name: str
|
||||
config: Dict[str, Any]
|
||||
enabled: bool = True
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class EmailServiceUpdate(BaseModel):
|
||||
"""更新邮箱服务请求"""
|
||||
name: Optional[str] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
enabled: Optional[bool] = None
|
||||
priority: Optional[int] = None
|
||||
|
||||
|
||||
class EmailServiceResponse(BaseModel):
|
||||
"""邮箱服务响应"""
|
||||
id: int
|
||||
service_type: str
|
||||
name: str
|
||||
enabled: bool
|
||||
priority: int
|
||||
config: Optional[Dict[str, Any]] = None # 过滤敏感信息后的配置
|
||||
last_used: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class EmailServiceListResponse(BaseModel):
|
||||
"""邮箱服务列表响应"""
|
||||
total: int
|
||||
services: List[EmailServiceResponse]
|
||||
|
||||
|
||||
class ServiceTestResult(BaseModel):
|
||||
"""服务测试结果"""
|
||||
success: bool
|
||||
message: str
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class OutlookBatchImportRequest(BaseModel):
|
||||
"""Outlook 批量导入请求"""
|
||||
data: str # 多行数据,每行格式: 邮箱----密码 或 邮箱----密码----client_id----refresh_token
|
||||
enabled: bool = True
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class OutlookBatchImportResponse(BaseModel):
|
||||
"""Outlook 批量导入响应"""
|
||||
total: int
|
||||
success: int
|
||||
failed: int
|
||||
accounts: List[Dict[str, Any]]
|
||||
errors: List[str]
|
||||
|
||||
|
||||
# ============== Helper Functions ==============
|
||||
|
||||
# 敏感字段列表,返回响应时需要过滤
|
||||
SENSITIVE_FIELDS = {'password', 'api_key', 'refresh_token', 'access_token', 'admin_token'}
|
||||
|
||||
def filter_sensitive_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""过滤敏感配置信息"""
|
||||
if not config:
|
||||
return {}
|
||||
|
||||
filtered = {}
|
||||
for key, value in config.items():
|
||||
if key in SENSITIVE_FIELDS:
|
||||
# 敏感字段不返回,但标记是否存在
|
||||
filtered[f"has_{key}"] = bool(value)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
# 为 Outlook 计算是否有 OAuth
|
||||
if config.get('client_id') and config.get('refresh_token'):
|
||||
filtered['has_oauth'] = True
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
def service_to_response(service: EmailServiceModel) -> EmailServiceResponse:
|
||||
"""转换服务模型为响应"""
|
||||
return EmailServiceResponse(
|
||||
id=service.id,
|
||||
service_type=service.service_type,
|
||||
name=service.name,
|
||||
enabled=service.enabled,
|
||||
priority=service.priority,
|
||||
config=filter_sensitive_config(service.config),
|
||||
last_used=service.last_used.isoformat() if service.last_used else None,
|
||||
created_at=service.created_at.isoformat() if service.created_at else None,
|
||||
updated_at=service.updated_at.isoformat() if service.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
# ============== API Endpoints ==============
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_email_services_stats():
|
||||
"""获取邮箱服务统计信息"""
|
||||
with get_db() as db:
|
||||
from sqlalchemy import func
|
||||
|
||||
# 按类型统计
|
||||
type_stats = db.query(
|
||||
EmailServiceModel.service_type,
|
||||
func.count(EmailServiceModel.id)
|
||||
).group_by(EmailServiceModel.service_type).all()
|
||||
|
||||
# 启用数量
|
||||
enabled_count = db.query(func.count(EmailServiceModel.id)).filter(
|
||||
EmailServiceModel.enabled == True
|
||||
).scalar()
|
||||
|
||||
stats = {
|
||||
'outlook_count': 0,
|
||||
'custom_count': 0,
|
||||
'temp_mail_count': 0,
|
||||
'duck_mail_count': 0,
|
||||
'freemail_count': 0,
|
||||
'imap_mail_count': 0,
|
||||
'tempmail_available': True, # 临时邮箱始终可用
|
||||
'enabled_count': enabled_count
|
||||
}
|
||||
|
||||
for service_type, count in type_stats:
|
||||
if service_type == 'outlook':
|
||||
stats['outlook_count'] = count
|
||||
elif service_type == 'moe_mail':
|
||||
stats['custom_count'] = count
|
||||
elif service_type == 'temp_mail':
|
||||
stats['temp_mail_count'] = count
|
||||
elif service_type == 'duck_mail':
|
||||
stats['duck_mail_count'] = count
|
||||
elif service_type == 'freemail':
|
||||
stats['freemail_count'] = count
|
||||
elif service_type == 'imap_mail':
|
||||
stats['imap_mail_count'] = count
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/types")
|
||||
async def get_service_types():
|
||||
"""获取支持的邮箱服务类型"""
|
||||
return {
|
||||
"types": [
|
||||
{
|
||||
"value": "tempmail",
|
||||
"label": "Tempmail.lol",
|
||||
"description": "临时邮箱服务,无需配置",
|
||||
"config_fields": [
|
||||
{"name": "base_url", "label": "API 地址", "default": "https://api.tempmail.lol/v2", "required": False},
|
||||
{"name": "timeout", "label": "超时时间", "default": 30, "required": False},
|
||||
]
|
||||
},
|
||||
{
|
||||
"value": "outlook",
|
||||
"label": "Outlook",
|
||||
"description": "Outlook 邮箱,需要配置账户信息",
|
||||
"config_fields": [
|
||||
{"name": "email", "label": "邮箱地址", "required": True},
|
||||
{"name": "password", "label": "密码", "required": True},
|
||||
{"name": "client_id", "label": "OAuth Client ID", "required": False},
|
||||
{"name": "refresh_token", "label": "OAuth Refresh Token", "required": False},
|
||||
]
|
||||
},
|
||||
{
|
||||
"value": "moe_mail",
|
||||
"label": "MoeMail",
|
||||
"description": "自定义域名邮箱服务",
|
||||
"config_fields": [
|
||||
{"name": "base_url", "label": "API 地址", "required": True},
|
||||
{"name": "api_key", "label": "API Key", "required": True},
|
||||
{"name": "default_domain", "label": "默认域名", "required": False},
|
||||
]
|
||||
},
|
||||
{
|
||||
"value": "temp_mail",
|
||||
"label": "Temp-Mail(自部署)",
|
||||
"description": "自部署 Cloudflare Worker 临时邮箱,admin 模式管理",
|
||||
"config_fields": [
|
||||
{"name": "base_url", "label": "Worker 地址", "required": True, "placeholder": "https://mail.example.com"},
|
||||
{"name": "admin_password", "label": "Admin 密码", "required": True, "secret": True},
|
||||
{"name": "domain", "label": "邮箱域名", "required": True, "placeholder": "example.com"},
|
||||
{"name": "enable_prefix", "label": "启用前缀", "required": False, "default": True},
|
||||
]
|
||||
},
|
||||
{
|
||||
"value": "duck_mail",
|
||||
"label": "DuckMail",
|
||||
"description": "DuckMail 接口邮箱服务,支持 API Key 私有域名访问",
|
||||
"config_fields": [
|
||||
{"name": "base_url", "label": "API 地址", "required": True, "placeholder": "https://api.duckmail.sbs"},
|
||||
{"name": "default_domain", "label": "默认域名", "required": True, "placeholder": "duckmail.sbs"},
|
||||
{"name": "api_key", "label": "API Key", "required": False, "secret": True},
|
||||
{"name": "password_length", "label": "随机密码长度", "required": False, "default": 12},
|
||||
]
|
||||
},
|
||||
{
|
||||
"value": "freemail",
|
||||
"label": "Freemail",
|
||||
"description": "Freemail 自部署 Cloudflare Worker 临时邮箱服务",
|
||||
"config_fields": [
|
||||
{"name": "base_url", "label": "API 地址", "required": True, "placeholder": "https://freemail.example.com"},
|
||||
{"name": "admin_token", "label": "Admin Token", "required": True, "secret": True},
|
||||
{"name": "domain", "label": "邮箱域名", "required": False, "placeholder": "example.com"},
|
||||
]
|
||||
},
|
||||
{
|
||||
"value": "imap_mail",
|
||||
"label": "IMAP 邮箱",
|
||||
"description": "标准 IMAP 协议邮箱(Gmail/QQ/163等),仅用于接收验证码,强制直连",
|
||||
"config_fields": [
|
||||
{"name": "host", "label": "IMAP 服务器", "required": True, "placeholder": "imap.gmail.com"},
|
||||
{"name": "port", "label": "端口", "required": False, "default": 993},
|
||||
{"name": "use_ssl", "label": "使用 SSL", "required": False, "default": True},
|
||||
{"name": "email", "label": "邮箱地址", "required": True},
|
||||
{"name": "password", "label": "密码/授权码", "required": True, "secret": True},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=EmailServiceListResponse)
|
||||
async def list_email_services(
|
||||
service_type: Optional[str] = Query(None, description="服务类型筛选"),
|
||||
enabled_only: bool = Query(False, description="只显示启用的服务"),
|
||||
):
|
||||
"""获取邮箱服务列表"""
|
||||
with get_db() as db:
|
||||
query = db.query(EmailServiceModel)
|
||||
|
||||
if service_type:
|
||||
query = query.filter(EmailServiceModel.service_type == service_type)
|
||||
|
||||
if enabled_only:
|
||||
query = query.filter(EmailServiceModel.enabled == True)
|
||||
|
||||
services = query.order_by(EmailServiceModel.priority.asc(), EmailServiceModel.id.asc()).all()
|
||||
|
||||
return EmailServiceListResponse(
|
||||
total=len(services),
|
||||
services=[service_to_response(s) for s in services]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{service_id}", response_model=EmailServiceResponse)
|
||||
async def get_email_service(service_id: int):
|
||||
"""获取单个邮箱服务详情"""
|
||||
with get_db() as db:
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
return service_to_response(service)
|
||||
|
||||
|
||||
@router.get("/{service_id}/full")
|
||||
async def get_email_service_full(service_id: int):
|
||||
"""获取单个邮箱服务完整详情(包含敏感字段,用于编辑)"""
|
||||
with get_db() as db:
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
return {
|
||||
"id": service.id,
|
||||
"service_type": service.service_type,
|
||||
"name": service.name,
|
||||
"enabled": service.enabled,
|
||||
"priority": service.priority,
|
||||
"config": service.config or {}, # 返回完整配置
|
||||
"last_used": service.last_used.isoformat() if service.last_used else None,
|
||||
"created_at": service.created_at.isoformat() if service.created_at else None,
|
||||
"updated_at": service.updated_at.isoformat() if service.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("", response_model=EmailServiceResponse)
|
||||
async def create_email_service(request: EmailServiceCreate):
|
||||
"""创建邮箱服务配置"""
|
||||
# 验证服务类型
|
||||
try:
|
||||
EmailServiceType(request.service_type)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"无效的服务类型: {request.service_type}")
|
||||
|
||||
with get_db() as db:
|
||||
# 检查名称是否重复
|
||||
existing = db.query(EmailServiceModel).filter(EmailServiceModel.name == request.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="服务名称已存在")
|
||||
|
||||
service = EmailServiceModel(
|
||||
service_type=request.service_type,
|
||||
name=request.name,
|
||||
config=request.config,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority
|
||||
)
|
||||
db.add(service)
|
||||
db.commit()
|
||||
db.refresh(service)
|
||||
|
||||
return service_to_response(service)
|
||||
|
||||
|
||||
@router.patch("/{service_id}", response_model=EmailServiceResponse)
|
||||
async def update_email_service(service_id: int, request: EmailServiceUpdate):
|
||||
"""更新邮箱服务配置"""
|
||||
with get_db() as db:
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
update_data = {}
|
||||
if request.name is not None:
|
||||
update_data["name"] = request.name
|
||||
if request.config is not None:
|
||||
# 合并配置而不是替换
|
||||
current_config = service.config or {}
|
||||
merged_config = {**current_config, **request.config}
|
||||
# 移除空值
|
||||
merged_config = {k: v for k, v in merged_config.items() if v}
|
||||
update_data["config"] = merged_config
|
||||
if request.enabled is not None:
|
||||
update_data["enabled"] = request.enabled
|
||||
if request.priority is not None:
|
||||
update_data["priority"] = request.priority
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(service, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(service)
|
||||
|
||||
return service_to_response(service)
|
||||
|
||||
|
||||
@router.delete("/{service_id}")
|
||||
async def delete_email_service(service_id: int):
|
||||
"""删除邮箱服务配置"""
|
||||
with get_db() as db:
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
db.delete(service)
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "message": f"服务 {service.name} 已删除"}
|
||||
|
||||
|
||||
@router.post("/{service_id}/test", response_model=ServiceTestResult)
|
||||
async def test_email_service(service_id: int):
|
||||
"""测试邮箱服务是否可用"""
|
||||
with get_db() as db:
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
try:
|
||||
service_type = EmailServiceType(service.service_type)
|
||||
email_service = EmailServiceFactory.create(service_type, service.config, name=service.name)
|
||||
|
||||
health = email_service.check_health()
|
||||
|
||||
if health:
|
||||
return ServiceTestResult(
|
||||
success=True,
|
||||
message="服务连接正常",
|
||||
details=email_service.get_service_info() if hasattr(email_service, 'get_service_info') else None
|
||||
)
|
||||
else:
|
||||
return ServiceTestResult(
|
||||
success=False,
|
||||
message="服务连接失败"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试邮箱服务失败: {e}")
|
||||
return ServiceTestResult(
|
||||
success=False,
|
||||
message=f"测试失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{service_id}/enable")
|
||||
async def enable_email_service(service_id: int):
|
||||
"""启用邮箱服务"""
|
||||
with get_db() as db:
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
service.enabled = True
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "message": f"服务 {service.name} 已启用"}
|
||||
|
||||
|
||||
@router.post("/{service_id}/disable")
|
||||
async def disable_email_service(service_id: int):
|
||||
"""禁用邮箱服务"""
|
||||
with get_db() as db:
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="服务不存在")
|
||||
|
||||
service.enabled = False
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "message": f"服务 {service.name} 已禁用"}
|
||||
|
||||
|
||||
@router.post("/reorder")
|
||||
async def reorder_services(service_ids: List[int]):
|
||||
"""重新排序邮箱服务优先级"""
|
||||
with get_db() as db:
|
||||
for index, service_id in enumerate(service_ids):
|
||||
service = db.query(EmailServiceModel).filter(EmailServiceModel.id == service_id).first()
|
||||
if service:
|
||||
service.priority = index
|
||||
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "message": "优先级已更新"}
|
||||
|
||||
|
||||
@router.post("/outlook/batch-import", response_model=OutlookBatchImportResponse)
|
||||
async def batch_import_outlook(request: OutlookBatchImportRequest):
|
||||
"""
|
||||
批量导入 Outlook 邮箱账户
|
||||
|
||||
支持两种格式:
|
||||
- 格式一(密码认证):邮箱----密码
|
||||
- 格式二(XOAUTH2 认证):邮箱----密码----client_id----refresh_token
|
||||
|
||||
每行一个账户,使用四个连字符(----)分隔字段
|
||||
"""
|
||||
lines = request.data.strip().split("\n")
|
||||
total = len(lines)
|
||||
success = 0
|
||||
failed = 0
|
||||
accounts = []
|
||||
errors = []
|
||||
|
||||
with get_db() as db:
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
|
||||
# 跳过空行和注释
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parts = line.split("----")
|
||||
|
||||
# 验证格式
|
||||
if len(parts) < 2:
|
||||
failed += 1
|
||||
errors.append(f"行 {i+1}: 格式错误,至少需要邮箱和密码")
|
||||
continue
|
||||
|
||||
email = parts[0].strip()
|
||||
password = parts[1].strip()
|
||||
|
||||
# 验证邮箱格式
|
||||
if "@" not in email:
|
||||
failed += 1
|
||||
errors.append(f"行 {i+1}: 无效的邮箱地址: {email}")
|
||||
continue
|
||||
|
||||
# 检查是否已存在
|
||||
existing = db.query(EmailServiceModel).filter(
|
||||
EmailServiceModel.service_type == "outlook",
|
||||
EmailServiceModel.name == email
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
failed += 1
|
||||
errors.append(f"行 {i+1}: 邮箱已存在: {email}")
|
||||
continue
|
||||
|
||||
# 构建配置
|
||||
config = {
|
||||
"email": email,
|
||||
"password": password
|
||||
}
|
||||
|
||||
# 检查是否有 OAuth 信息(格式二)
|
||||
if len(parts) >= 4:
|
||||
client_id = parts[2].strip()
|
||||
refresh_token = parts[3].strip()
|
||||
if client_id and refresh_token:
|
||||
config["client_id"] = client_id
|
||||
config["refresh_token"] = refresh_token
|
||||
|
||||
# 创建服务记录
|
||||
try:
|
||||
service = EmailServiceModel(
|
||||
service_type="outlook",
|
||||
name=email,
|
||||
config=config,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority
|
||||
)
|
||||
db.add(service)
|
||||
db.commit()
|
||||
db.refresh(service)
|
||||
|
||||
accounts.append({
|
||||
"id": service.id,
|
||||
"email": email,
|
||||
"has_oauth": bool(config.get("client_id")),
|
||||
"name": email
|
||||
})
|
||||
success += 1
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
errors.append(f"行 {i+1}: 创建失败: {str(e)}")
|
||||
db.rollback()
|
||||
|
||||
return OutlookBatchImportResponse(
|
||||
total=total,
|
||||
success=success,
|
||||
failed=failed,
|
||||
accounts=accounts,
|
||||
errors=errors
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/outlook/batch")
|
||||
async def batch_delete_outlook(service_ids: List[int]):
|
||||
"""批量删除 Outlook 邮箱服务"""
|
||||
deleted = 0
|
||||
with get_db() as db:
|
||||
for service_id in service_ids:
|
||||
service = db.query(EmailServiceModel).filter(
|
||||
EmailServiceModel.id == service_id,
|
||||
EmailServiceModel.service_type == "outlook"
|
||||
).first()
|
||||
if service:
|
||||
db.delete(service)
|
||||
deleted += 1
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "deleted": deleted, "message": f"已删除 {deleted} 个服务"}
|
||||
|
||||
|
||||
# ============== 临时邮箱测试 ==============
|
||||
|
||||
class TempmailTestRequest(BaseModel):
|
||||
"""临时邮箱测试请求"""
|
||||
api_url: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/test-tempmail")
|
||||
async def test_tempmail_service(request: TempmailTestRequest):
|
||||
"""测试临时邮箱服务是否可用"""
|
||||
try:
|
||||
from ...services import EmailServiceFactory, EmailServiceType
|
||||
from ...config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
base_url = request.api_url or settings.tempmail_base_url
|
||||
|
||||
config = {"base_url": base_url}
|
||||
tempmail = EmailServiceFactory.create(EmailServiceType.TEMPMAIL, config)
|
||||
|
||||
# 检查服务健康状态
|
||||
health = tempmail.check_health()
|
||||
|
||||
if health:
|
||||
return {"success": True, "message": "临时邮箱连接正常"}
|
||||
else:
|
||||
return {"success": False, "message": "临时邮箱连接失败"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试临时邮箱失败: {e}")
|
||||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||||
182
src/web/routes/payment.py
Normal file
182
src/web/routes/payment.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
支付相关 API 路由
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...database.session import get_db
|
||||
from ...database.models import Account
|
||||
from ...database import crud
|
||||
from ...config.settings import get_settings
|
||||
from .accounts import resolve_account_ids
|
||||
from ...core.openai.payment import (
|
||||
generate_plus_link,
|
||||
generate_team_link,
|
||||
open_url_incognito,
|
||||
check_subscription_status,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class GenerateLinkRequest(BaseModel):
|
||||
account_id: int
|
||||
plan_type: str # 'plus' or 'team'
|
||||
workspace_name: str = "MyTeam"
|
||||
price_interval: str = "month"
|
||||
seat_quantity: int = 5
|
||||
proxy: Optional[str] = None
|
||||
auto_open: bool = False # 生成后是否自动无痕打开
|
||||
country: str = "SG" # 计费国家,决定货币 # 生成后是否自动无痕打开
|
||||
|
||||
|
||||
class OpenIncognitoRequest(BaseModel):
|
||||
url: str
|
||||
account_id: Optional[int] = None # 可选,用于注入账号 cookie
|
||||
|
||||
|
||||
class MarkSubscriptionRequest(BaseModel):
|
||||
subscription_type: str # 'free' / 'plus' / 'team'
|
||||
|
||||
|
||||
class BatchCheckSubscriptionRequest(BaseModel):
|
||||
ids: List[int] = []
|
||||
proxy: Optional[str] = None
|
||||
select_all: bool = False
|
||||
status_filter: Optional[str] = None
|
||||
email_service_filter: Optional[str] = None
|
||||
search_filter: Optional[str] = None
|
||||
|
||||
|
||||
# ============== 支付链接生成 ==============
|
||||
|
||||
@router.post("/generate-link")
|
||||
def generate_payment_link(request: GenerateLinkRequest):
|
||||
"""生成 Plus 或 Team 支付链接,可选自动无痕打开"""
|
||||
with get_db() as db:
|
||||
account = db.query(Account).filter(Account.id == request.account_id).first()
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="账号不存在")
|
||||
|
||||
proxy = request.proxy or get_settings().proxy_url
|
||||
|
||||
try:
|
||||
if request.plan_type == "plus":
|
||||
link = generate_plus_link(account, proxy, country=request.country)
|
||||
elif request.plan_type == "team":
|
||||
link = generate_team_link(
|
||||
account,
|
||||
workspace_name=request.workspace_name,
|
||||
price_interval=request.price_interval,
|
||||
seat_quantity=request.seat_quantity,
|
||||
proxy=proxy,
|
||||
country=request.country,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="plan_type 必须为 plus 或 team")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"生成支付链接失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成链接失败: {str(e)}")
|
||||
|
||||
opened = False
|
||||
if request.auto_open and link:
|
||||
cookies_str = account.cookies if account else None
|
||||
opened = open_url_incognito(link, cookies_str)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"link": link,
|
||||
"plan_type": request.plan_type,
|
||||
"auto_opened": opened,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/open-incognito")
|
||||
def open_browser_incognito(request: OpenIncognitoRequest):
|
||||
"""后端以无痕模式打开指定 URL,可注入账号 cookie"""
|
||||
if not request.url:
|
||||
raise HTTPException(status_code=400, detail="URL 不能为空")
|
||||
|
||||
cookies_str = None
|
||||
if request.account_id:
|
||||
with get_db() as db:
|
||||
account = db.query(Account).filter(Account.id == request.account_id).first()
|
||||
if account:
|
||||
cookies_str = account.cookies
|
||||
|
||||
success = open_url_incognito(request.url, cookies_str)
|
||||
if success:
|
||||
return {"success": True, "message": "已在无痕模式打开浏览器"}
|
||||
return {"success": False, "message": "未找到可用的浏览器,请手动复制链接"}
|
||||
|
||||
|
||||
# ============== 订阅状态 ==============
|
||||
|
||||
@router.post("/accounts/batch-check-subscription")
|
||||
def batch_check_subscription(request: BatchCheckSubscriptionRequest):
|
||||
"""批量检测账号订阅状态"""
|
||||
proxy = request.proxy or get_settings().proxy_url
|
||||
|
||||
results = {"success_count": 0, "failed_count": 0, "details": []}
|
||||
|
||||
with get_db() as db:
|
||||
ids = resolve_account_ids(
|
||||
db, request.ids, request.select_all,
|
||||
request.status_filter, request.email_service_filter, request.search_filter
|
||||
)
|
||||
for account_id in ids:
|
||||
account = db.query(Account).filter(Account.id == account_id).first()
|
||||
if not account:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account_id, "email": None, "success": False, "error": "账号不存在"}
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
status = check_subscription_status(account, proxy)
|
||||
account.subscription_type = None if status == "free" else status
|
||||
account.subscription_at = datetime.utcnow() if status != "free" else account.subscription_at
|
||||
db.commit()
|
||||
results["success_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account_id, "email": account.email, "success": True, "subscription_type": status}
|
||||
)
|
||||
except Exception as e:
|
||||
results["failed_count"] += 1
|
||||
results["details"].append(
|
||||
{"id": account_id, "email": account.email, "success": False, "error": str(e)}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@router.post("/accounts/{account_id}/mark-subscription")
|
||||
def mark_subscription(account_id: int, request: MarkSubscriptionRequest):
|
||||
"""手动标记账号订阅类型"""
|
||||
allowed = ("free", "plus", "team")
|
||||
if request.subscription_type not in allowed:
|
||||
raise HTTPException(status_code=400, detail=f"subscription_type 必须为 {allowed}")
|
||||
|
||||
with get_db() as db:
|
||||
account = db.query(Account).filter(Account.id == account_id).first()
|
||||
if not account:
|
||||
raise HTTPException(status_code=404, detail="账号不存在")
|
||||
|
||||
account.subscription_type = None if request.subscription_type == "free" else request.subscription_type
|
||||
account.subscription_at = datetime.utcnow() if request.subscription_type != "free" else None
|
||||
db.commit()
|
||||
|
||||
return {"success": True, "subscription_type": request.subscription_type}
|
||||
|
||||
|
||||
1527
src/web/routes/registration.py
Normal file
1527
src/web/routes/registration.py
Normal file
File diff suppressed because it is too large
Load Diff
781
src/web/routes/settings.py
Normal file
781
src/web/routes/settings.py
Normal file
@@ -0,0 +1,781 @@
|
||||
"""
|
||||
设置 API 路由
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...config.settings import get_settings, update_settings
|
||||
from ...database import crud
|
||||
from ...database.session import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class SettingItem(BaseModel):
|
||||
"""设置项"""
|
||||
key: str
|
||||
value: str
|
||||
description: Optional[str] = None
|
||||
category: str = "general"
|
||||
|
||||
|
||||
class SettingUpdateRequest(BaseModel):
|
||||
"""设置更新请求"""
|
||||
value: str
|
||||
|
||||
|
||||
class ProxySettings(BaseModel):
|
||||
"""代理设置"""
|
||||
enabled: bool = False
|
||||
type: str = "http" # http, socks5
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 7890
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
|
||||
class RegistrationSettings(BaseModel):
|
||||
"""注册设置"""
|
||||
max_retries: int = 3
|
||||
timeout: int = 120
|
||||
default_password_length: int = 12
|
||||
sleep_min: int = 5
|
||||
sleep_max: int = 30
|
||||
|
||||
|
||||
class WebUISettings(BaseModel):
|
||||
"""Web UI 设置"""
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
debug: Optional[bool] = None
|
||||
access_password: Optional[str] = None
|
||||
|
||||
|
||||
class AllSettings(BaseModel):
|
||||
"""所有设置"""
|
||||
proxy: ProxySettings
|
||||
registration: RegistrationSettings
|
||||
webui: WebUISettings
|
||||
|
||||
|
||||
# ============== API Endpoints ==============
|
||||
|
||||
@router.get("")
|
||||
async def get_all_settings():
|
||||
"""获取所有设置"""
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"proxy": {
|
||||
"enabled": settings.proxy_enabled,
|
||||
"type": settings.proxy_type,
|
||||
"host": settings.proxy_host,
|
||||
"port": settings.proxy_port,
|
||||
"username": settings.proxy_username,
|
||||
"has_password": bool(settings.proxy_password),
|
||||
"dynamic_enabled": settings.proxy_dynamic_enabled,
|
||||
"dynamic_api_url": settings.proxy_dynamic_api_url,
|
||||
"dynamic_api_key_header": settings.proxy_dynamic_api_key_header,
|
||||
"dynamic_result_field": settings.proxy_dynamic_result_field,
|
||||
"has_dynamic_api_key": bool(settings.proxy_dynamic_api_key and settings.proxy_dynamic_api_key.get_secret_value()),
|
||||
},
|
||||
"registration": {
|
||||
"max_retries": settings.registration_max_retries,
|
||||
"timeout": settings.registration_timeout,
|
||||
"default_password_length": settings.registration_default_password_length,
|
||||
"sleep_min": settings.registration_sleep_min,
|
||||
"sleep_max": settings.registration_sleep_max,
|
||||
},
|
||||
"webui": {
|
||||
"host": settings.webui_host,
|
||||
"port": settings.webui_port,
|
||||
"debug": settings.debug,
|
||||
"has_access_password": bool(settings.webui_access_password and settings.webui_access_password.get_secret_value()),
|
||||
},
|
||||
"tempmail": {
|
||||
"base_url": settings.tempmail_base_url,
|
||||
"timeout": settings.tempmail_timeout,
|
||||
"max_retries": settings.tempmail_max_retries,
|
||||
},
|
||||
"email_code": {
|
||||
"timeout": settings.email_code_timeout,
|
||||
"poll_interval": settings.email_code_poll_interval,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/proxy/dynamic")
|
||||
async def get_dynamic_proxy_settings():
|
||||
"""获取动态代理设置"""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"enabled": settings.proxy_dynamic_enabled,
|
||||
"api_url": settings.proxy_dynamic_api_url,
|
||||
"api_key_header": settings.proxy_dynamic_api_key_header,
|
||||
"result_field": settings.proxy_dynamic_result_field,
|
||||
"has_api_key": bool(settings.proxy_dynamic_api_key and settings.proxy_dynamic_api_key.get_secret_value()),
|
||||
}
|
||||
|
||||
|
||||
class DynamicProxySettings(BaseModel):
|
||||
"""动态代理设置"""
|
||||
enabled: bool = False
|
||||
api_url: str = ""
|
||||
api_key: Optional[str] = None
|
||||
api_key_header: str = "X-API-Key"
|
||||
result_field: str = ""
|
||||
|
||||
|
||||
@router.post("/proxy/dynamic")
|
||||
async def update_dynamic_proxy_settings(request: DynamicProxySettings):
|
||||
"""更新动态代理设置"""
|
||||
update_dict = {
|
||||
"proxy_dynamic_enabled": request.enabled,
|
||||
"proxy_dynamic_api_url": request.api_url,
|
||||
"proxy_dynamic_api_key_header": request.api_key_header,
|
||||
"proxy_dynamic_result_field": request.result_field,
|
||||
}
|
||||
if request.api_key is not None:
|
||||
update_dict["proxy_dynamic_api_key"] = request.api_key
|
||||
|
||||
update_settings(**update_dict)
|
||||
return {"success": True, "message": "动态代理设置已更新"}
|
||||
|
||||
|
||||
@router.post("/proxy/dynamic/test")
|
||||
async def test_dynamic_proxy(request: DynamicProxySettings):
|
||||
"""测试动态代理 API"""
|
||||
from ...core.dynamic_proxy import fetch_dynamic_proxy
|
||||
|
||||
if not request.api_url:
|
||||
raise HTTPException(status_code=400, detail="请填写动态代理 API 地址")
|
||||
|
||||
# 若未传入 api_key,使用已保存的
|
||||
api_key = request.api_key or ""
|
||||
if not api_key:
|
||||
settings = get_settings()
|
||||
if settings.proxy_dynamic_api_key:
|
||||
api_key = settings.proxy_dynamic_api_key.get_secret_value()
|
||||
|
||||
proxy_url = fetch_dynamic_proxy(
|
||||
api_url=request.api_url,
|
||||
api_key=api_key,
|
||||
api_key_header=request.api_key_header,
|
||||
result_field=request.result_field,
|
||||
)
|
||||
|
||||
if not proxy_url:
|
||||
return {"success": False, "message": "动态代理 API 返回为空或请求失败"}
|
||||
|
||||
# 用获取到的代理测试连通性
|
||||
import time
|
||||
from curl_cffi import requests as cffi_requests
|
||||
try:
|
||||
proxies = {"http": proxy_url, "https": proxy_url}
|
||||
start = time.time()
|
||||
resp = cffi_requests.get(
|
||||
"https://api.ipify.org?format=json",
|
||||
proxies=proxies,
|
||||
timeout=10,
|
||||
impersonate="chrome110"
|
||||
)
|
||||
elapsed = round((time.time() - start) * 1000)
|
||||
if resp.status_code == 200:
|
||||
ip = resp.json().get("ip", "")
|
||||
return {"success": True, "proxy_url": proxy_url, "ip": ip, "response_time": elapsed,
|
||||
"message": f"动态代理可用,出口 IP: {ip},响应时间: {elapsed}ms"}
|
||||
return {"success": False, "proxy_url": proxy_url, "message": f"代理连接失败: HTTP {resp.status_code}"}
|
||||
except Exception as e:
|
||||
return {"success": False, "proxy_url": proxy_url, "message": f"代理连接失败: {e}"}
|
||||
|
||||
|
||||
@router.get("/registration")
|
||||
async def get_registration_settings():
|
||||
"""获取注册设置"""
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"max_retries": settings.registration_max_retries,
|
||||
"timeout": settings.registration_timeout,
|
||||
"default_password_length": settings.registration_default_password_length,
|
||||
"sleep_min": settings.registration_sleep_min,
|
||||
"sleep_max": settings.registration_sleep_max,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/registration")
|
||||
async def update_registration_settings(request: RegistrationSettings):
|
||||
"""更新注册设置"""
|
||||
update_settings(
|
||||
registration_max_retries=request.max_retries,
|
||||
registration_timeout=request.timeout,
|
||||
registration_default_password_length=request.default_password_length,
|
||||
registration_sleep_min=request.sleep_min,
|
||||
registration_sleep_max=request.sleep_max,
|
||||
)
|
||||
|
||||
return {"success": True, "message": "注册设置已更新"}
|
||||
|
||||
|
||||
@router.post("/webui")
|
||||
async def update_webui_settings(request: WebUISettings):
|
||||
"""更新 Web UI 设置"""
|
||||
update_dict = {}
|
||||
if request.host is not None:
|
||||
update_dict["webui_host"] = request.host
|
||||
if request.port is not None:
|
||||
update_dict["webui_port"] = request.port
|
||||
if request.debug is not None:
|
||||
update_dict["debug"] = request.debug
|
||||
if request.access_password:
|
||||
update_dict["webui_access_password"] = request.access_password
|
||||
|
||||
update_settings(**update_dict)
|
||||
return {"success": True, "message": "Web UI 设置已更新"}
|
||||
|
||||
|
||||
@router.get("/database")
|
||||
async def get_database_info():
|
||||
"""获取数据库信息"""
|
||||
settings = get_settings()
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
db_path = settings.database_url
|
||||
if db_path.startswith("sqlite:///"):
|
||||
db_path = db_path[10:]
|
||||
|
||||
db_file = Path(db_path) if os.path.isabs(db_path) else Path(db_path)
|
||||
db_size = db_file.stat().st_size if db_file.exists() else 0
|
||||
|
||||
with get_db() as db:
|
||||
from ...database.models import Account, EmailService, RegistrationTask
|
||||
|
||||
account_count = db.query(Account).count()
|
||||
service_count = db.query(EmailService).count()
|
||||
task_count = db.query(RegistrationTask).count()
|
||||
|
||||
return {
|
||||
"database_url": settings.database_url,
|
||||
"database_size_bytes": db_size,
|
||||
"database_size_mb": round(db_size / (1024 * 1024), 2),
|
||||
"accounts_count": account_count,
|
||||
"email_services_count": service_count,
|
||||
"tasks_count": task_count,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/database/backup")
|
||||
async def backup_database():
|
||||
"""备份数据库"""
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
db_path = settings.database_url
|
||||
if db_path.startswith("sqlite:///"):
|
||||
db_path = db_path[10:]
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
raise HTTPException(status_code=404, detail="数据库文件不存在")
|
||||
|
||||
# 创建备份目录
|
||||
from pathlib import Path as FilePath
|
||||
backup_dir = FilePath(db_path).parent / "backups"
|
||||
backup_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 生成备份文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = backup_dir / f"database_backup_{timestamp}.db"
|
||||
|
||||
# 复制数据库文件
|
||||
shutil.copy2(db_path, backup_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "数据库备份成功",
|
||||
"backup_path": str(backup_path)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/database/cleanup")
|
||||
async def cleanup_database(
|
||||
days: int = 30,
|
||||
keep_failed: bool = True
|
||||
):
|
||||
"""清理过期数据"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
|
||||
with get_db() as db:
|
||||
from ...database.models import RegistrationTask
|
||||
from sqlalchemy import delete
|
||||
|
||||
# 删除旧任务
|
||||
conditions = [RegistrationTask.created_at < cutoff_date]
|
||||
if not keep_failed:
|
||||
conditions.append(RegistrationTask.status != "failed")
|
||||
else:
|
||||
conditions.append(RegistrationTask.status.in_(["completed", "cancelled"]))
|
||||
|
||||
result = db.execute(
|
||||
delete(RegistrationTask).where(*conditions)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已清理 {deleted_count} 条过期任务记录",
|
||||
"deleted_count": deleted_count
|
||||
}
|
||||
|
||||
|
||||
@router.get("/logs")
|
||||
async def get_recent_logs(
|
||||
lines: int = 100,
|
||||
level: str = "INFO"
|
||||
):
|
||||
"""获取最近日志"""
|
||||
settings = get_settings()
|
||||
|
||||
log_file = settings.log_file
|
||||
if not log_file:
|
||||
return {"logs": [], "message": "日志文件未配置"}
|
||||
|
||||
from pathlib import Path
|
||||
log_path = Path(log_file)
|
||||
|
||||
if not log_path.exists():
|
||||
return {"logs": [], "message": "日志文件不存在"}
|
||||
|
||||
try:
|
||||
with open(log_path, "r", encoding="utf-8") as f:
|
||||
all_lines = f.readlines()
|
||||
recent_lines = all_lines[-lines:]
|
||||
|
||||
return {
|
||||
"logs": [line.strip() for line in recent_lines],
|
||||
"total_lines": len(all_lines)
|
||||
}
|
||||
except Exception as e:
|
||||
return {"logs": [], "error": str(e)}
|
||||
|
||||
|
||||
# ============== 临时邮箱设置 ==============
|
||||
|
||||
class TempmailSettings(BaseModel):
|
||||
"""临时邮箱设置"""
|
||||
api_url: Optional[str] = None
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class EmailCodeSettings(BaseModel):
|
||||
"""验证码等待设置"""
|
||||
timeout: int = 120 # 验证码等待超时(秒)
|
||||
poll_interval: int = 3 # 验证码轮询间隔(秒)
|
||||
|
||||
|
||||
@router.get("/tempmail")
|
||||
async def get_tempmail_settings():
|
||||
"""获取临时邮箱设置"""
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"api_url": settings.tempmail_base_url,
|
||||
"timeout": settings.tempmail_timeout,
|
||||
"max_retries": settings.tempmail_max_retries,
|
||||
"enabled": True # 临时邮箱默认可用
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tempmail")
|
||||
async def update_tempmail_settings(request: TempmailSettings):
|
||||
"""更新临时邮箱设置"""
|
||||
update_dict = {}
|
||||
|
||||
if request.api_url:
|
||||
update_dict["tempmail_base_url"] = request.api_url
|
||||
|
||||
update_settings(**update_dict)
|
||||
|
||||
return {"success": True, "message": "临时邮箱设置已更新"}
|
||||
|
||||
|
||||
# ============== 验证码等待设置 ==============
|
||||
|
||||
@router.get("/email-code")
|
||||
async def get_email_code_settings():
|
||||
"""获取验证码等待设置"""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"timeout": settings.email_code_timeout,
|
||||
"poll_interval": settings.email_code_poll_interval,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/email-code")
|
||||
async def update_email_code_settings(request: EmailCodeSettings):
|
||||
"""更新验证码等待设置"""
|
||||
# 验证参数范围
|
||||
if request.timeout < 30 or request.timeout > 600:
|
||||
raise HTTPException(status_code=400, detail="超时时间必须在 30-600 秒之间")
|
||||
if request.poll_interval < 1 or request.poll_interval > 30:
|
||||
raise HTTPException(status_code=400, detail="轮询间隔必须在 1-30 秒之间")
|
||||
|
||||
update_settings(
|
||||
email_code_timeout=request.timeout,
|
||||
email_code_poll_interval=request.poll_interval,
|
||||
)
|
||||
|
||||
return {"success": True, "message": "验证码等待设置已更新"}
|
||||
|
||||
|
||||
# ============== 代理列表 CRUD ==============
|
||||
|
||||
class ProxyCreateRequest(BaseModel):
|
||||
"""创建代理请求"""
|
||||
name: str
|
||||
type: str = "http" # http, socks5
|
||||
host: str
|
||||
port: int
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
enabled: bool = True
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class ProxyUpdateRequest(BaseModel):
|
||||
"""更新代理请求"""
|
||||
name: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
host: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
priority: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/proxies")
|
||||
async def get_proxies_list(enabled: Optional[bool] = None):
|
||||
"""获取代理列表"""
|
||||
with get_db() as db:
|
||||
proxies = crud.get_proxies(db, enabled=enabled)
|
||||
return {
|
||||
"proxies": [p.to_dict() for p in proxies],
|
||||
"total": len(proxies)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/proxies")
|
||||
async def create_proxy_item(request: ProxyCreateRequest):
|
||||
"""创建代理"""
|
||||
with get_db() as db:
|
||||
proxy = crud.create_proxy(
|
||||
db,
|
||||
name=request.name,
|
||||
type=request.type,
|
||||
host=request.host,
|
||||
port=request.port,
|
||||
username=request.username,
|
||||
password=request.password,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority
|
||||
)
|
||||
return {"success": True, "proxy": proxy.to_dict()}
|
||||
|
||||
|
||||
@router.get("/proxies/{proxy_id}")
|
||||
async def get_proxy_item(proxy_id: int):
|
||||
"""获取单个代理"""
|
||||
with get_db() as db:
|
||||
proxy = crud.get_proxy_by_id(db, proxy_id)
|
||||
if not proxy:
|
||||
raise HTTPException(status_code=404, detail="代理不存在")
|
||||
return proxy.to_dict(include_password=True)
|
||||
|
||||
|
||||
@router.patch("/proxies/{proxy_id}")
|
||||
async def update_proxy_item(proxy_id: int, request: ProxyUpdateRequest):
|
||||
"""更新代理"""
|
||||
with get_db() as db:
|
||||
update_data = {}
|
||||
if request.name is not None:
|
||||
update_data["name"] = request.name
|
||||
if request.type is not None:
|
||||
update_data["type"] = request.type
|
||||
if request.host is not None:
|
||||
update_data["host"] = request.host
|
||||
if request.port is not None:
|
||||
update_data["port"] = request.port
|
||||
if request.username is not None:
|
||||
update_data["username"] = request.username
|
||||
if request.password is not None:
|
||||
update_data["password"] = request.password
|
||||
if request.enabled is not None:
|
||||
update_data["enabled"] = request.enabled
|
||||
if request.priority is not None:
|
||||
update_data["priority"] = request.priority
|
||||
|
||||
proxy = crud.update_proxy(db, proxy_id, **update_data)
|
||||
if not proxy:
|
||||
raise HTTPException(status_code=404, detail="代理不存在")
|
||||
return {"success": True, "proxy": proxy.to_dict()}
|
||||
|
||||
|
||||
@router.delete("/proxies/{proxy_id}")
|
||||
async def delete_proxy_item(proxy_id: int):
|
||||
"""删除代理"""
|
||||
with get_db() as db:
|
||||
success = crud.delete_proxy(db, proxy_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="代理不存在")
|
||||
return {"success": True, "message": "代理已删除"}
|
||||
|
||||
|
||||
@router.post("/proxies/{proxy_id}/set-default")
|
||||
async def set_proxy_default(proxy_id: int):
|
||||
"""将指定代理设为默认"""
|
||||
with get_db() as db:
|
||||
proxy = crud.set_proxy_default(db, proxy_id)
|
||||
if not proxy:
|
||||
raise HTTPException(status_code=404, detail="代理不存在")
|
||||
return {"success": True, "proxy": proxy.to_dict()}
|
||||
|
||||
|
||||
@router.post("/proxies/{proxy_id}/test")
|
||||
async def test_proxy_item(proxy_id: int):
|
||||
"""测试单个代理"""
|
||||
import time
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
with get_db() as db:
|
||||
proxy = crud.get_proxy_by_id(db, proxy_id)
|
||||
if not proxy:
|
||||
raise HTTPException(status_code=404, detail="代理不存在")
|
||||
|
||||
proxy_url = proxy.proxy_url
|
||||
test_url = "https://api.ipify.org?format=json"
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
proxies = {
|
||||
"http": proxy_url,
|
||||
"https": proxy_url
|
||||
}
|
||||
|
||||
response = cffi_requests.get(
|
||||
test_url,
|
||||
proxies=proxies,
|
||||
timeout=3,
|
||||
impersonate="chrome110"
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
ip_info = response.json()
|
||||
return {
|
||||
"success": True,
|
||||
"ip": ip_info.get("ip", ""),
|
||||
"response_time": round(elapsed_time * 1000),
|
||||
"message": f"代理连接成功,出口 IP: {ip_info.get('ip', 'unknown')}"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"代理返回错误状态码: {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"代理连接失败: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/proxies/test-all")
|
||||
async def test_all_proxies():
|
||||
"""测试所有启用的代理"""
|
||||
import time
|
||||
from curl_cffi import requests as cffi_requests
|
||||
|
||||
with get_db() as db:
|
||||
proxies = crud.get_enabled_proxies(db)
|
||||
|
||||
results = []
|
||||
for proxy in proxies:
|
||||
proxy_url = proxy.proxy_url
|
||||
test_url = "https://api.ipify.org?format=json"
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
proxies_dict = {
|
||||
"http": proxy_url,
|
||||
"https": proxy_url
|
||||
}
|
||||
|
||||
response = cffi_requests.get(
|
||||
test_url,
|
||||
proxies=proxies_dict,
|
||||
timeout=3,
|
||||
impersonate="chrome110"
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if response.status_code == 200:
|
||||
ip_info = response.json()
|
||||
results.append({
|
||||
"id": proxy.id,
|
||||
"name": proxy.name,
|
||||
"success": True,
|
||||
"ip": ip_info.get("ip", ""),
|
||||
"response_time": round(elapsed_time * 1000)
|
||||
})
|
||||
else:
|
||||
results.append({
|
||||
"id": proxy.id,
|
||||
"name": proxy.name,
|
||||
"success": False,
|
||||
"message": f"状态码: {response.status_code}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
results.append({
|
||||
"id": proxy.id,
|
||||
"name": proxy.name,
|
||||
"success": False,
|
||||
"message": str(e)
|
||||
})
|
||||
|
||||
success_count = sum(1 for r in results if r["success"])
|
||||
return {
|
||||
"total": len(proxies),
|
||||
"success": success_count,
|
||||
"failed": len(proxies) - success_count,
|
||||
"results": results
|
||||
}
|
||||
|
||||
|
||||
@router.post("/proxies/{proxy_id}/enable")
|
||||
async def enable_proxy(proxy_id: int):
|
||||
"""启用代理"""
|
||||
with get_db() as db:
|
||||
proxy = crud.update_proxy(db, proxy_id, enabled=True)
|
||||
if not proxy:
|
||||
raise HTTPException(status_code=404, detail="代理不存在")
|
||||
return {"success": True, "message": "代理已启用"}
|
||||
|
||||
|
||||
@router.post("/proxies/{proxy_id}/disable")
|
||||
async def disable_proxy(proxy_id: int):
|
||||
"""禁用代理"""
|
||||
with get_db() as db:
|
||||
proxy = crud.update_proxy(db, proxy_id, enabled=False)
|
||||
if not proxy:
|
||||
raise HTTPException(status_code=404, detail="代理不存在")
|
||||
return {"success": True, "message": "代理已禁用"}
|
||||
|
||||
|
||||
# ============== Outlook 设置 ==============
|
||||
|
||||
class OutlookSettings(BaseModel):
|
||||
"""Outlook 设置"""
|
||||
default_client_id: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/outlook")
|
||||
async def get_outlook_settings():
|
||||
"""获取 Outlook 设置"""
|
||||
settings = get_settings()
|
||||
|
||||
return {
|
||||
"default_client_id": settings.outlook_default_client_id,
|
||||
"provider_priority": settings.outlook_provider_priority,
|
||||
"health_failure_threshold": settings.outlook_health_failure_threshold,
|
||||
"health_disable_duration": settings.outlook_health_disable_duration,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/outlook")
|
||||
async def update_outlook_settings(request: OutlookSettings):
|
||||
"""更新 Outlook 设置"""
|
||||
update_dict = {}
|
||||
|
||||
if request.default_client_id is not None:
|
||||
update_dict["outlook_default_client_id"] = request.default_client_id
|
||||
|
||||
if update_dict:
|
||||
update_settings(**update_dict)
|
||||
|
||||
return {"success": True, "message": "Outlook 设置已更新"}
|
||||
|
||||
|
||||
# ============== Team Manager 设置 ==============
|
||||
|
||||
class TeamManagerSettings(BaseModel):
|
||||
"""Team Manager 设置"""
|
||||
enabled: bool = False
|
||||
api_url: str = ""
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class TeamManagerTestRequest(BaseModel):
|
||||
"""Team Manager 测试请求"""
|
||||
api_url: str
|
||||
api_key: str
|
||||
|
||||
|
||||
@router.get("/team-manager")
|
||||
async def get_team_manager_settings():
|
||||
"""获取 Team Manager 设置"""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"enabled": settings.tm_enabled,
|
||||
"api_url": settings.tm_api_url,
|
||||
"has_api_key": bool(settings.tm_api_key and settings.tm_api_key.get_secret_value()),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/team-manager")
|
||||
async def update_team_manager_settings(request: TeamManagerSettings):
|
||||
"""更新 Team Manager 设置"""
|
||||
update_dict = {
|
||||
"tm_enabled": request.enabled,
|
||||
"tm_api_url": request.api_url,
|
||||
}
|
||||
if request.api_key:
|
||||
update_dict["tm_api_key"] = request.api_key
|
||||
update_settings(**update_dict)
|
||||
return {"success": True, "message": "Team Manager 设置已更新"}
|
||||
|
||||
|
||||
@router.post("/team-manager/test")
|
||||
async def test_team_manager_connection(request: TeamManagerTestRequest):
|
||||
"""测试 Team Manager 连接"""
|
||||
from ...core.upload.team_manager_upload import test_team_manager_connection as do_test
|
||||
|
||||
settings = get_settings()
|
||||
api_key = request.api_key
|
||||
if api_key == 'use_saved_key' or not api_key:
|
||||
if settings.tm_api_key:
|
||||
api_key = settings.tm_api_key.get_secret_value()
|
||||
else:
|
||||
return {"success": False, "message": "未配置 API Key"}
|
||||
|
||||
success, message = do_test(request.api_url, api_key)
|
||||
return {"success": success, "message": message}
|
||||
2
src/web/routes/upload/__init__.py
Normal file
2
src/web/routes/upload/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
171
src/web/routes/upload/cpa_services.py
Normal file
171
src/web/routes/upload/cpa_services.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
CPA 服务管理 API 路由
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....database import crud
|
||||
from ....database.session import get_db
|
||||
from ....core.upload.cpa_upload import test_cpa_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class CpaServiceCreate(BaseModel):
|
||||
name: str
|
||||
api_url: str
|
||||
api_token: str
|
||||
enabled: bool = True
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class CpaServiceUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
api_url: Optional[str] = None
|
||||
api_token: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
priority: Optional[int] = None
|
||||
|
||||
|
||||
class CpaServiceResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
api_url: str
|
||||
has_token: bool
|
||||
enabled: bool
|
||||
priority: int
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CpaServiceTestRequest(BaseModel):
|
||||
api_url: Optional[str] = None
|
||||
api_token: Optional[str] = None
|
||||
|
||||
|
||||
def _to_response(svc) -> CpaServiceResponse:
|
||||
return CpaServiceResponse(
|
||||
id=svc.id,
|
||||
name=svc.name,
|
||||
api_url=svc.api_url,
|
||||
has_token=bool(svc.api_token),
|
||||
enabled=svc.enabled,
|
||||
priority=svc.priority,
|
||||
created_at=svc.created_at.isoformat() if svc.created_at else None,
|
||||
updated_at=svc.updated_at.isoformat() if svc.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
# ============== API Endpoints ==============
|
||||
|
||||
@router.get("", response_model=List[CpaServiceResponse])
|
||||
async def list_cpa_services(enabled: Optional[bool] = None):
|
||||
"""获取 CPA 服务列表"""
|
||||
with get_db() as db:
|
||||
services = crud.get_cpa_services(db, enabled=enabled)
|
||||
return [_to_response(s) for s in services]
|
||||
|
||||
|
||||
@router.post("", response_model=CpaServiceResponse)
|
||||
async def create_cpa_service(request: CpaServiceCreate):
|
||||
"""新增 CPA 服务"""
|
||||
with get_db() as db:
|
||||
service = crud.create_cpa_service(
|
||||
db,
|
||||
name=request.name,
|
||||
api_url=request.api_url,
|
||||
api_token=request.api_token,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority,
|
||||
)
|
||||
return _to_response(service)
|
||||
|
||||
|
||||
@router.get("/{service_id}", response_model=CpaServiceResponse)
|
||||
async def get_cpa_service(service_id: int):
|
||||
"""获取单个 CPA 服务详情"""
|
||||
with get_db() as db:
|
||||
service = crud.get_cpa_service_by_id(db, service_id)
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="CPA 服务不存在")
|
||||
return _to_response(service)
|
||||
|
||||
|
||||
@router.get("/{service_id}/full")
|
||||
async def get_cpa_service_full(service_id: int):
|
||||
"""获取 CPA 服务完整配置(含 token)"""
|
||||
with get_db() as db:
|
||||
service = crud.get_cpa_service_by_id(db, service_id)
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="CPA 服务不存在")
|
||||
return {
|
||||
"id": service.id,
|
||||
"name": service.name,
|
||||
"api_url": service.api_url,
|
||||
"api_token": service.api_token,
|
||||
"enabled": service.enabled,
|
||||
"priority": service.priority,
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/{service_id}", response_model=CpaServiceResponse)
|
||||
async def update_cpa_service(service_id: int, request: CpaServiceUpdate):
|
||||
"""更新 CPA 服务配置"""
|
||||
with get_db() as db:
|
||||
service = crud.get_cpa_service_by_id(db, service_id)
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="CPA 服务不存在")
|
||||
|
||||
update_data = {}
|
||||
if request.name is not None:
|
||||
update_data["name"] = request.name
|
||||
if request.api_url is not None:
|
||||
update_data["api_url"] = request.api_url
|
||||
# api_token 留空则保持原值
|
||||
if request.api_token:
|
||||
update_data["api_token"] = request.api_token
|
||||
if request.enabled is not None:
|
||||
update_data["enabled"] = request.enabled
|
||||
if request.priority is not None:
|
||||
update_data["priority"] = request.priority
|
||||
|
||||
service = crud.update_cpa_service(db, service_id, **update_data)
|
||||
return _to_response(service)
|
||||
|
||||
|
||||
@router.delete("/{service_id}")
|
||||
async def delete_cpa_service(service_id: int):
|
||||
"""删除 CPA 服务"""
|
||||
with get_db() as db:
|
||||
service = crud.get_cpa_service_by_id(db, service_id)
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="CPA 服务不存在")
|
||||
crud.delete_cpa_service(db, service_id)
|
||||
return {"success": True, "message": f"CPA 服务 {service.name} 已删除"}
|
||||
|
||||
|
||||
@router.post("/{service_id}/test")
|
||||
async def test_cpa_service(service_id: int):
|
||||
"""测试 CPA 服务连接"""
|
||||
with get_db() as db:
|
||||
service = crud.get_cpa_service_by_id(db, service_id)
|
||||
if not service:
|
||||
raise HTTPException(status_code=404, detail="CPA 服务不存在")
|
||||
success, message = test_cpa_connection(service.api_url, service.api_token)
|
||||
return {"success": success, "message": message}
|
||||
|
||||
|
||||
@router.post("/test-connection")
|
||||
async def test_cpa_connection_direct(request: CpaServiceTestRequest):
|
||||
"""直接测试 CPA 连接(用于添加前验证)"""
|
||||
if not request.api_url or not request.api_token:
|
||||
raise HTTPException(status_code=400, detail="api_url 和 api_token 不能为空")
|
||||
success, message = test_cpa_connection(request.api_url, request.api_token)
|
||||
return {"success": success, "message": message}
|
||||
207
src/web/routes/upload/sub2api_services.py
Normal file
207
src/web/routes/upload/sub2api_services.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Sub2API 服务管理 API 路由
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....database import crud
|
||||
from ....database.session import get_db
|
||||
from ....core.upload.sub2api_upload import test_sub2api_connection, batch_upload_to_sub2api
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class Sub2ApiServiceCreate(BaseModel):
|
||||
name: str
|
||||
api_url: str
|
||||
api_key: str
|
||||
enabled: bool = True
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class Sub2ApiServiceUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
api_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
priority: Optional[int] = None
|
||||
|
||||
|
||||
class Sub2ApiServiceResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
api_url: str
|
||||
has_key: bool
|
||||
enabled: bool
|
||||
priority: int
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class Sub2ApiTestRequest(BaseModel):
|
||||
api_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
||||
class Sub2ApiUploadRequest(BaseModel):
|
||||
account_ids: List[int]
|
||||
service_id: Optional[int] = None
|
||||
concurrency: int = 3
|
||||
priority: int = 50
|
||||
|
||||
|
||||
def _to_response(svc) -> Sub2ApiServiceResponse:
|
||||
return Sub2ApiServiceResponse(
|
||||
id=svc.id,
|
||||
name=svc.name,
|
||||
api_url=svc.api_url,
|
||||
has_key=bool(svc.api_key),
|
||||
enabled=svc.enabled,
|
||||
priority=svc.priority,
|
||||
created_at=svc.created_at.isoformat() if svc.created_at else None,
|
||||
updated_at=svc.updated_at.isoformat() if svc.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
# ============== API Endpoints ==============
|
||||
|
||||
@router.get("", response_model=List[Sub2ApiServiceResponse])
|
||||
async def list_sub2api_services(enabled: Optional[bool] = None):
|
||||
"""获取 Sub2API 服务列表"""
|
||||
with get_db() as db:
|
||||
services = crud.get_sub2api_services(db, enabled=enabled)
|
||||
return [_to_response(s) for s in services]
|
||||
|
||||
|
||||
@router.post("", response_model=Sub2ApiServiceResponse)
|
||||
async def create_sub2api_service(request: Sub2ApiServiceCreate):
|
||||
"""新增 Sub2API 服务"""
|
||||
with get_db() as db:
|
||||
svc = crud.create_sub2api_service(
|
||||
db,
|
||||
name=request.name,
|
||||
api_url=request.api_url,
|
||||
api_key=request.api_key,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority,
|
||||
)
|
||||
return _to_response(svc)
|
||||
|
||||
|
||||
@router.get("/{service_id}", response_model=Sub2ApiServiceResponse)
|
||||
async def get_sub2api_service(service_id: int):
|
||||
"""获取单个 Sub2API 服务详情"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_sub2api_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Sub2API 服务不存在")
|
||||
return _to_response(svc)
|
||||
|
||||
|
||||
@router.get("/{service_id}/full")
|
||||
async def get_sub2api_service_full(service_id: int):
|
||||
"""获取 Sub2API 服务完整配置(含 API Key)"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_sub2api_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Sub2API 服务不存在")
|
||||
return {
|
||||
"id": svc.id,
|
||||
"name": svc.name,
|
||||
"api_url": svc.api_url,
|
||||
"api_key": svc.api_key,
|
||||
"enabled": svc.enabled,
|
||||
"priority": svc.priority,
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/{service_id}", response_model=Sub2ApiServiceResponse)
|
||||
async def update_sub2api_service(service_id: int, request: Sub2ApiServiceUpdate):
|
||||
"""更新 Sub2API 服务配置"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_sub2api_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Sub2API 服务不存在")
|
||||
|
||||
update_data = {}
|
||||
if request.name is not None:
|
||||
update_data["name"] = request.name
|
||||
if request.api_url is not None:
|
||||
update_data["api_url"] = request.api_url
|
||||
# api_key 留空则保持原值
|
||||
if request.api_key:
|
||||
update_data["api_key"] = request.api_key
|
||||
if request.enabled is not None:
|
||||
update_data["enabled"] = request.enabled
|
||||
if request.priority is not None:
|
||||
update_data["priority"] = request.priority
|
||||
|
||||
svc = crud.update_sub2api_service(db, service_id, **update_data)
|
||||
return _to_response(svc)
|
||||
|
||||
|
||||
@router.delete("/{service_id}")
|
||||
async def delete_sub2api_service(service_id: int):
|
||||
"""删除 Sub2API 服务"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_sub2api_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Sub2API 服务不存在")
|
||||
crud.delete_sub2api_service(db, service_id)
|
||||
return {"success": True, "message": f"Sub2API 服务 {svc.name} 已删除"}
|
||||
|
||||
|
||||
@router.post("/{service_id}/test")
|
||||
async def test_sub2api_service(service_id: int):
|
||||
"""测试 Sub2API 服务连接"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_sub2api_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Sub2API 服务不存在")
|
||||
success, message = test_sub2api_connection(svc.api_url, svc.api_key)
|
||||
return {"success": success, "message": message}
|
||||
|
||||
|
||||
@router.post("/test-connection")
|
||||
async def test_sub2api_connection_direct(request: Sub2ApiTestRequest):
|
||||
"""直接测试 Sub2API 连接(用于添加前验证)"""
|
||||
if not request.api_url or not request.api_key:
|
||||
raise HTTPException(status_code=400, detail="api_url 和 api_key 不能为空")
|
||||
success, message = test_sub2api_connection(request.api_url, request.api_key)
|
||||
return {"success": success, "message": message}
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_accounts_to_sub2api(request: Sub2ApiUploadRequest):
|
||||
"""批量上传账号到 Sub2API 平台"""
|
||||
if not request.account_ids:
|
||||
raise HTTPException(status_code=400, detail="账号 ID 列表不能为空")
|
||||
|
||||
with get_db() as db:
|
||||
if request.service_id:
|
||||
svc = crud.get_sub2api_service_by_id(db, request.service_id)
|
||||
else:
|
||||
svcs = crud.get_sub2api_services(db, enabled=True)
|
||||
svc = svcs[0] if svcs else None
|
||||
|
||||
if not svc:
|
||||
raise HTTPException(status_code=400, detail="未找到可用的 Sub2API 服务")
|
||||
|
||||
api_url = svc.api_url
|
||||
api_key = svc.api_key
|
||||
|
||||
results = batch_upload_to_sub2api(
|
||||
request.account_ids,
|
||||
api_url,
|
||||
api_key,
|
||||
concurrency=request.concurrency,
|
||||
priority=request.priority,
|
||||
)
|
||||
return results
|
||||
153
src/web/routes/upload/tm_services.py
Normal file
153
src/web/routes/upload/tm_services.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Team Manager 服务管理 API 路由
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....database import crud
|
||||
from ....database.session import get_db
|
||||
from ....core.upload.team_manager_upload import test_team_manager_connection
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============== Pydantic Models ==============
|
||||
|
||||
class TmServiceCreate(BaseModel):
|
||||
name: str
|
||||
api_url: str
|
||||
api_key: str
|
||||
enabled: bool = True
|
||||
priority: int = 0
|
||||
|
||||
|
||||
class TmServiceUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
api_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
priority: Optional[int] = None
|
||||
|
||||
|
||||
class TmServiceResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
api_url: str
|
||||
has_key: bool
|
||||
enabled: bool
|
||||
priority: int
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TmTestRequest(BaseModel):
|
||||
api_url: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
|
||||
def _to_response(svc) -> TmServiceResponse:
|
||||
return TmServiceResponse(
|
||||
id=svc.id,
|
||||
name=svc.name,
|
||||
api_url=svc.api_url,
|
||||
has_key=bool(svc.api_key),
|
||||
enabled=svc.enabled,
|
||||
priority=svc.priority,
|
||||
created_at=svc.created_at.isoformat() if svc.created_at else None,
|
||||
updated_at=svc.updated_at.isoformat() if svc.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
# ============== API Endpoints ==============
|
||||
|
||||
@router.get("", response_model=List[TmServiceResponse])
|
||||
async def list_tm_services(enabled: Optional[bool] = None):
|
||||
"""获取 Team Manager 服务列表"""
|
||||
with get_db() as db:
|
||||
services = crud.get_tm_services(db, enabled=enabled)
|
||||
return [_to_response(s) for s in services]
|
||||
|
||||
|
||||
@router.post("", response_model=TmServiceResponse)
|
||||
async def create_tm_service(request: TmServiceCreate):
|
||||
"""新增 Team Manager 服务"""
|
||||
with get_db() as db:
|
||||
svc = crud.create_tm_service(
|
||||
db,
|
||||
name=request.name,
|
||||
api_url=request.api_url,
|
||||
api_key=request.api_key,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority,
|
||||
)
|
||||
return _to_response(svc)
|
||||
|
||||
|
||||
@router.get("/{service_id}", response_model=TmServiceResponse)
|
||||
async def get_tm_service(service_id: int):
|
||||
"""获取单个 Team Manager 服务详情"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_tm_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Team Manager 服务不存在")
|
||||
return _to_response(svc)
|
||||
|
||||
|
||||
@router.patch("/{service_id}", response_model=TmServiceResponse)
|
||||
async def update_tm_service(service_id: int, request: TmServiceUpdate):
|
||||
"""更新 Team Manager 服务配置"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_tm_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Team Manager 服务不存在")
|
||||
|
||||
update_data = {}
|
||||
if request.name is not None:
|
||||
update_data["name"] = request.name
|
||||
if request.api_url is not None:
|
||||
update_data["api_url"] = request.api_url
|
||||
if request.api_key:
|
||||
update_data["api_key"] = request.api_key
|
||||
if request.enabled is not None:
|
||||
update_data["enabled"] = request.enabled
|
||||
if request.priority is not None:
|
||||
update_data["priority"] = request.priority
|
||||
|
||||
svc = crud.update_tm_service(db, service_id, **update_data)
|
||||
return _to_response(svc)
|
||||
|
||||
|
||||
@router.delete("/{service_id}")
|
||||
async def delete_tm_service(service_id: int):
|
||||
"""删除 Team Manager 服务"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_tm_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Team Manager 服务不存在")
|
||||
crud.delete_tm_service(db, service_id)
|
||||
return {"success": True, "message": f"Team Manager 服务 {svc.name} 已删除"}
|
||||
|
||||
|
||||
@router.post("/{service_id}/test")
|
||||
async def test_tm_service(service_id: int):
|
||||
"""测试 Team Manager 服务连接"""
|
||||
with get_db() as db:
|
||||
svc = crud.get_tm_service_by_id(db, service_id)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail="Team Manager 服务不存在")
|
||||
success, message = test_team_manager_connection(svc.api_url, svc.api_key)
|
||||
return {"success": success, "message": message}
|
||||
|
||||
|
||||
@router.post("/test-connection")
|
||||
async def test_tm_connection_direct(request: TmTestRequest):
|
||||
"""直接测试 Team Manager 连接(用于添加前验证)"""
|
||||
if not request.api_url or not request.api_key:
|
||||
raise HTTPException(status_code=400, detail="api_url 和 api_key 不能为空")
|
||||
success, message = test_team_manager_connection(request.api_url, request.api_key)
|
||||
return {"success": success, "message": message}
|
||||
170
src/web/routes/websocket.py
Normal file
170
src/web/routes/websocket.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
WebSocket 路由
|
||||
提供实时日志推送和任务状态更新
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from ..task_manager import task_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.websocket("/ws/task/{task_uuid}")
|
||||
async def task_websocket(websocket: WebSocket, task_uuid: str):
|
||||
"""
|
||||
任务日志 WebSocket
|
||||
|
||||
消息格式:
|
||||
- 服务端发送: {"type": "log", "task_uuid": "xxx", "message": "...", "timestamp": "..."}
|
||||
- 服务端发送: {"type": "status", "task_uuid": "xxx", "status": "running|completed|failed|cancelled", ...}
|
||||
- 客户端发送: {"type": "ping"} - 心跳
|
||||
- 客户端发送: {"type": "cancel"} - 取消任务
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
|
||||
task_manager.register_websocket(task_uuid, websocket)
|
||||
logger.info(f"WebSocket 连接已建立,日志频道正式开麦: {task_uuid}")
|
||||
|
||||
try:
|
||||
# 发送当前状态
|
||||
status = task_manager.get_status(task_uuid)
|
||||
if status:
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"task_uuid": task_uuid,
|
||||
**status
|
||||
})
|
||||
|
||||
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
|
||||
history_logs = task_manager.get_unsent_logs(task_uuid, websocket)
|
||||
for log in history_logs:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
"task_uuid": task_uuid,
|
||||
"message": log
|
||||
})
|
||||
|
||||
# 保持连接,等待客户端消息
|
||||
while True:
|
||||
try:
|
||||
# 使用 wait_for 实现超时,但不是断开连接
|
||||
# 而是发送心跳检测
|
||||
data = await asyncio.wait_for(
|
||||
websocket.receive_json(),
|
||||
timeout=30.0 # 30秒超时
|
||||
)
|
||||
|
||||
# 处理心跳
|
||||
if data.get("type") == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
# 处理取消请求
|
||||
elif data.get("type") == "cancel":
|
||||
task_manager.cancel_task(task_uuid)
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"task_uuid": task_uuid,
|
||||
"status": "cancelling",
|
||||
"message": "取消请求已提交,正在踩刹车,别慌"
|
||||
})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时,发送心跳检测
|
||||
try:
|
||||
await websocket.send_json({"type": "ping"})
|
||||
except Exception:
|
||||
# 发送失败,可能是连接断开
|
||||
logger.info(f"WebSocket 心跳检测失败: {task_uuid}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: {task_uuid}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
|
||||
finally:
|
||||
task_manager.unregister_websocket(task_uuid, websocket)
|
||||
|
||||
|
||||
@router.websocket("/ws/batch/{batch_id}")
|
||||
async def batch_websocket(websocket: WebSocket, batch_id: str):
|
||||
"""
|
||||
批量任务 WebSocket
|
||||
|
||||
用于批量注册任务的实时状态更新
|
||||
|
||||
消息格式:
|
||||
- 服务端发送: {"type": "log", "batch_id": "xxx", "message": "...", "timestamp": "..."}
|
||||
- 服务端发送: {"type": "status", "batch_id": "xxx", "status": "running|completed|cancelled", ...}
|
||||
- 客户端发送: {"type": "ping"} - 心跳
|
||||
- 客户端发送: {"type": "cancel"} - 取消批量任务
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
# 注册连接(会记录当前日志数量,避免重复发送历史日志)
|
||||
task_manager.register_batch_websocket(batch_id, websocket)
|
||||
logger.info(f"批量任务 WebSocket 连接已建立,群聊频道正式开麦: {batch_id}")
|
||||
|
||||
try:
|
||||
# 发送当前状态
|
||||
status = task_manager.get_batch_status(batch_id)
|
||||
if status:
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"batch_id": batch_id,
|
||||
**status
|
||||
})
|
||||
|
||||
# 发送历史日志(只发送注册时已存在的日志,避免与实时推送重复)
|
||||
history_logs = task_manager.get_unsent_batch_logs(batch_id, websocket)
|
||||
for log in history_logs:
|
||||
await websocket.send_json({
|
||||
"type": "log",
|
||||
"batch_id": batch_id,
|
||||
"message": log
|
||||
})
|
||||
|
||||
# 保持连接,等待客户端消息
|
||||
while True:
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
websocket.receive_json(),
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
# 处理心跳
|
||||
if data.get("type") == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
|
||||
# 处理取消请求
|
||||
elif data.get("type") == "cancel":
|
||||
task_manager.cancel_batch(batch_id)
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"batch_id": batch_id,
|
||||
"status": "cancelling",
|
||||
"message": "取消请求已提交,正在让整队缓缓靠边停车"
|
||||
})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时,发送心跳检测
|
||||
try:
|
||||
await websocket.send_json({"type": "ping"})
|
||||
except Exception:
|
||||
logger.info(f"批量任务 WebSocket 心跳检测失败: {batch_id}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"批量任务 WebSocket 断开: {batch_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量任务 WebSocket 错误: {e}")
|
||||
|
||||
finally:
|
||||
task_manager.unregister_batch_websocket(batch_id, websocket)
|
||||
386
src/web/task_manager.py
Normal file
386
src/web/task_manager.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
任务管理器
|
||||
负责管理后台任务、日志队列和 WebSocket 推送
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Optional, List, Callable, Any
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局线程池(支持最多 50 个并发注册任务)
|
||||
_executor = ThreadPoolExecutor(max_workers=50, thread_name_prefix="reg_worker")
|
||||
|
||||
# 全局元锁:保护所有 defaultdict 的首次 key 创建(避免多线程竞态)
|
||||
_meta_lock = threading.Lock()
|
||||
|
||||
# 任务日志队列 (task_uuid -> list of logs)
|
||||
_log_queues: Dict[str, List[str]] = defaultdict(list)
|
||||
_log_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
# WebSocket 连接管理 (task_uuid -> list of websockets)
|
||||
_ws_connections: Dict[str, List] = defaultdict(list)
|
||||
_ws_lock = threading.Lock()
|
||||
|
||||
# WebSocket 已发送日志索引 (task_uuid -> {websocket: sent_count})
|
||||
_ws_sent_index: Dict[str, Dict] = defaultdict(dict)
|
||||
|
||||
# 任务状态
|
||||
_task_status: Dict[str, dict] = {}
|
||||
|
||||
# 任务取消标志
|
||||
_task_cancelled: Dict[str, bool] = {}
|
||||
|
||||
# 批量任务状态 (batch_id -> dict)
|
||||
_batch_status: Dict[str, dict] = {}
|
||||
_batch_logs: Dict[str, List[str]] = defaultdict(list)
|
||||
_batch_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
|
||||
def _get_log_lock(task_uuid: str) -> threading.Lock:
|
||||
"""线程安全地获取或创建任务日志锁"""
|
||||
if task_uuid not in _log_locks:
|
||||
with _meta_lock:
|
||||
if task_uuid not in _log_locks:
|
||||
_log_locks[task_uuid] = threading.Lock()
|
||||
return _log_locks[task_uuid]
|
||||
|
||||
|
||||
def _get_batch_lock(batch_id: str) -> threading.Lock:
|
||||
"""线程安全地获取或创建批量任务日志锁"""
|
||||
if batch_id not in _batch_locks:
|
||||
with _meta_lock:
|
||||
if batch_id not in _batch_locks:
|
||||
_batch_locks[batch_id] = threading.Lock()
|
||||
return _batch_locks[batch_id]
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""任务管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.executor = _executor
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
|
||||
def set_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
"""设置事件循环(在 FastAPI 启动时调用)"""
|
||||
self._loop = loop
|
||||
|
||||
def get_loop(self) -> Optional[asyncio.AbstractEventLoop]:
|
||||
"""获取事件循环"""
|
||||
return self._loop
|
||||
|
||||
def is_cancelled(self, task_uuid: str) -> bool:
|
||||
"""检查任务是否已取消"""
|
||||
return _task_cancelled.get(task_uuid, False)
|
||||
|
||||
def cancel_task(self, task_uuid: str):
|
||||
"""取消任务"""
|
||||
_task_cancelled[task_uuid] = True
|
||||
logger.info(f"任务 {task_uuid} 已标记为取消")
|
||||
|
||||
def add_log(self, task_uuid: str, log_message: str):
|
||||
"""添加日志并推送到 WebSocket(线程安全)"""
|
||||
# 先广播到 WebSocket,确保实时推送
|
||||
# 然后再添加到队列,这样 get_unsent_logs 不会获取到这条日志
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_log(task_uuid, log_message),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _get_log_lock(task_uuid):
|
||||
_log_queues[task_uuid].append(log_message)
|
||||
|
||||
async def _broadcast_log(self, task_uuid: str, log_message: str):
|
||||
"""广播日志到所有 WebSocket 连接"""
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(task_uuid, []).copy()
|
||||
# 注意:不在这里更新 sent_index,因为日志已经通过 add_log 添加到队列
|
||||
# sent_index 应该只在 get_unsent_logs 或发送历史日志时更新
|
||||
# 这样可以避免竞态条件
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "log",
|
||||
"task_uuid": task_uuid,
|
||||
"message": log_message,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
# 发送成功后更新 sent_index
|
||||
with _ws_lock:
|
||||
ws_id = id(ws)
|
||||
if task_uuid in _ws_sent_index and ws_id in _ws_sent_index[task_uuid]:
|
||||
_ws_sent_index[task_uuid][ws_id] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送失败: {e}")
|
||||
|
||||
async def broadcast_status(self, task_uuid: str, status: str, **kwargs):
|
||||
"""广播任务状态更新"""
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(task_uuid, []).copy()
|
||||
|
||||
message = {
|
||||
"type": "status",
|
||||
"task_uuid": task_uuid,
|
||||
"status": status,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
**kwargs
|
||||
}
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送状态失败: {e}")
|
||||
|
||||
def register_websocket(self, task_uuid: str, websocket):
|
||||
"""注册 WebSocket 连接"""
|
||||
with _ws_lock:
|
||||
if task_uuid not in _ws_connections:
|
||||
_ws_connections[task_uuid] = []
|
||||
# 避免重复注册同一个连接
|
||||
if websocket not in _ws_connections[task_uuid]:
|
||||
_ws_connections[task_uuid].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _get_log_lock(task_uuid):
|
||||
_ws_sent_index[task_uuid][id(websocket)] = len(_log_queues.get(task_uuid, []))
|
||||
logger.info(f"WebSocket 连接已注册,日志小喇叭准备开播: {task_uuid}")
|
||||
else:
|
||||
logger.warning(f"WebSocket 连接已存在,跳过重复注册: {task_uuid}")
|
||||
|
||||
def get_unsent_logs(self, task_uuid: str, websocket) -> List[str]:
|
||||
"""获取未发送给该 WebSocket 的日志"""
|
||||
with _ws_lock:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(task_uuid, {}).get(ws_id, 0)
|
||||
|
||||
with _get_log_lock(task_uuid):
|
||||
all_logs = _log_queues.get(task_uuid, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
_ws_sent_index[task_uuid][ws_id] = len(all_logs)
|
||||
return unsent_logs
|
||||
|
||||
def unregister_websocket(self, task_uuid: str, websocket):
|
||||
"""注销 WebSocket 连接"""
|
||||
with _ws_lock:
|
||||
if task_uuid in _ws_connections:
|
||||
try:
|
||||
_ws_connections[task_uuid].remove(websocket)
|
||||
except ValueError:
|
||||
pass
|
||||
# 清理已发送索引
|
||||
if task_uuid in _ws_sent_index:
|
||||
_ws_sent_index[task_uuid].pop(id(websocket), None)
|
||||
logger.info(f"WebSocket 连接已注销: {task_uuid}")
|
||||
|
||||
def get_logs(self, task_uuid: str) -> List[str]:
|
||||
"""获取任务的所有日志"""
|
||||
with _get_log_lock(task_uuid):
|
||||
return _log_queues.get(task_uuid, []).copy()
|
||||
|
||||
def update_status(self, task_uuid: str, status: str, **kwargs):
|
||||
"""更新任务状态"""
|
||||
if task_uuid not in _task_status:
|
||||
_task_status[task_uuid] = {}
|
||||
|
||||
_task_status[task_uuid]["status"] = status
|
||||
_task_status[task_uuid].update(kwargs)
|
||||
|
||||
def get_status(self, task_uuid: str) -> Optional[dict]:
|
||||
"""获取任务状态"""
|
||||
return _task_status.get(task_uuid)
|
||||
|
||||
def cleanup_task(self, task_uuid: str):
|
||||
"""清理任务数据"""
|
||||
# 保留日志队列一段时间,以便后续查询
|
||||
# 只清理取消标志
|
||||
if task_uuid in _task_cancelled:
|
||||
del _task_cancelled[task_uuid]
|
||||
|
||||
# ============== 批量任务管理 ==============
|
||||
|
||||
def init_batch(self, batch_id: str, total: int):
|
||||
"""初始化批量任务"""
|
||||
_batch_status[batch_id] = {
|
||||
"status": "running",
|
||||
"total": total,
|
||||
"completed": 0,
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"current_index": 0,
|
||||
"finished": False
|
||||
}
|
||||
logger.info(f"批量任务 {batch_id} 已初始化,总数: {total}")
|
||||
|
||||
def add_batch_log(self, batch_id: str, log_message: str):
|
||||
"""添加批量任务日志并推送"""
|
||||
# 先广播到 WebSocket,确保实时推送
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_log(batch_id, log_message),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送批量日志到 WebSocket 失败: {e}")
|
||||
|
||||
# 广播后再添加到队列
|
||||
with _get_batch_lock(batch_id):
|
||||
_batch_logs[batch_id].append(log_message)
|
||||
|
||||
async def _broadcast_batch_log(self, batch_id: str, log_message: str):
|
||||
"""广播批量任务日志"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(key, []).copy()
|
||||
# 注意:不在这里更新 sent_index,避免竞态条件
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "log",
|
||||
"batch_id": batch_id,
|
||||
"message": log_message,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
# 发送成功后更新 sent_index
|
||||
with _ws_lock:
|
||||
ws_id = id(ws)
|
||||
if key in _ws_sent_index and ws_id in _ws_sent_index[key]:
|
||||
_ws_sent_index[key][ws_id] += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送批量日志失败: {e}")
|
||||
|
||||
def update_batch_status(self, batch_id: str, **kwargs):
|
||||
"""更新批量任务状态"""
|
||||
if batch_id not in _batch_status:
|
||||
logger.warning(f"批量任务 {batch_id} 不存在")
|
||||
return
|
||||
|
||||
_batch_status[batch_id].update(kwargs)
|
||||
|
||||
# 异步广播状态更新
|
||||
if self._loop and self._loop.is_running():
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_batch_status(batch_id),
|
||||
self._loop
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"广播批量状态失败: {e}")
|
||||
|
||||
async def _broadcast_batch_status(self, batch_id: str):
|
||||
"""广播批量任务状态"""
|
||||
with _ws_lock:
|
||||
connections = _ws_connections.get(f"batch_{batch_id}", []).copy()
|
||||
|
||||
status = _batch_status.get(batch_id, {})
|
||||
|
||||
for ws in connections:
|
||||
try:
|
||||
await ws.send_json({
|
||||
"type": "status",
|
||||
"batch_id": batch_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
**status
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 发送批量状态失败: {e}")
|
||||
|
||||
def get_batch_status(self, batch_id: str) -> Optional[dict]:
|
||||
"""获取批量任务状态"""
|
||||
return _batch_status.get(batch_id)
|
||||
|
||||
def get_batch_logs(self, batch_id: str) -> List[str]:
|
||||
"""获取批量任务日志"""
|
||||
with _get_batch_lock(batch_id):
|
||||
return _batch_logs.get(batch_id, []).copy()
|
||||
|
||||
def is_batch_cancelled(self, batch_id: str) -> bool:
|
||||
"""检查批量任务是否已取消"""
|
||||
status = _batch_status.get(batch_id, {})
|
||||
return status.get("cancelled", False)
|
||||
|
||||
def cancel_batch(self, batch_id: str):
|
||||
"""取消批量任务"""
|
||||
if batch_id in _batch_status:
|
||||
_batch_status[batch_id]["cancelled"] = True
|
||||
_batch_status[batch_id]["status"] = "cancelling"
|
||||
logger.info(f"批量任务 {batch_id} 已标记为取消")
|
||||
|
||||
def register_batch_websocket(self, batch_id: str, websocket):
|
||||
"""注册批量任务 WebSocket 连接"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
if key not in _ws_connections:
|
||||
_ws_connections[key] = []
|
||||
# 避免重复注册同一个连接
|
||||
if websocket not in _ws_connections[key]:
|
||||
_ws_connections[key].append(websocket)
|
||||
# 记录已发送的日志数量,用于发送历史日志时避免重复
|
||||
with _get_batch_lock(batch_id):
|
||||
_ws_sent_index[key][id(websocket)] = len(_batch_logs.get(batch_id, []))
|
||||
logger.info(f"批量任务 WebSocket 连接已注册,批量频道开始集合: {batch_id}")
|
||||
else:
|
||||
logger.warning(f"批量任务 WebSocket 连接已存在,跳过重复注册: {batch_id}")
|
||||
|
||||
def get_unsent_batch_logs(self, batch_id: str, websocket) -> List[str]:
|
||||
"""获取未发送给该 WebSocket 的批量任务日志"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
ws_id = id(websocket)
|
||||
sent_count = _ws_sent_index.get(key, {}).get(ws_id, 0)
|
||||
|
||||
with _get_batch_lock(batch_id):
|
||||
all_logs = _batch_logs.get(batch_id, [])
|
||||
unsent_logs = all_logs[sent_count:]
|
||||
# 更新已发送索引
|
||||
_ws_sent_index[key][ws_id] = len(all_logs)
|
||||
return unsent_logs
|
||||
|
||||
def unregister_batch_websocket(self, batch_id: str, websocket):
|
||||
"""注销批量任务 WebSocket 连接"""
|
||||
key = f"batch_{batch_id}"
|
||||
with _ws_lock:
|
||||
if key in _ws_connections:
|
||||
try:
|
||||
_ws_connections[key].remove(websocket)
|
||||
except ValueError:
|
||||
pass
|
||||
# 清理已发送索引
|
||||
if key in _ws_sent_index:
|
||||
_ws_sent_index[key].pop(id(websocket), None)
|
||||
logger.info(f"批量任务 WebSocket 连接已注销: {batch_id}")
|
||||
|
||||
def create_log_callback(self, task_uuid: str, prefix: str = "", batch_id: str = "") -> Callable[[str], None]:
|
||||
"""创建日志回调函数,可附加任务编号前缀,并同时推送到批量任务频道"""
|
||||
def callback(msg: str):
|
||||
full_msg = f"{prefix} {msg}" if prefix else msg
|
||||
self.add_log(task_uuid, full_msg)
|
||||
# 如果属于批量任务,同步推送到 batch 频道,前端可在混合日志中看到详细步骤
|
||||
if batch_id:
|
||||
self.add_batch_log(batch_id, full_msg)
|
||||
return callback
|
||||
|
||||
def create_check_cancelled_callback(self, task_uuid: str) -> Callable[[], bool]:
|
||||
"""创建检查取消的回调函数"""
|
||||
def callback() -> bool:
|
||||
return self.is_cancelled(task_uuid)
|
||||
return callback
|
||||
|
||||
|
||||
# 全局实例
|
||||
task_manager = TaskManager()
|
||||
Reference in New Issue
Block a user