297 lines
10 KiB
Python
297 lines
10 KiB
Python
import base64
|
|
import json
|
|
|
|
from src.config.constants import EmailServiceType, OPENAI_API_ENDPOINTS, OPENAI_PAGE_TYPES
|
|
from src.core.http_client import OpenAIHTTPClient
|
|
from src.core.openai.oauth import OAuthStart
|
|
from src.core.register import RegistrationEngine
|
|
from src.services.base import BaseEmailService
|
|
|
|
|
|
class DummyResponse:
    """Minimal canned HTTP response used by the scripted sessions in this module.

    Args:
        status_code: Value exposed as ``status_code``.
        payload: Object returned by :meth:`json`; ``None`` means "no JSON body".
        text: Value exposed as ``text``.
        headers: Mapping exposed as ``headers``; an empty dict when omitted.
        on_return: Optional callable stored as ``on_return``. ``QueueSession``
            calls it with the session just before returning this response.
    """

    def __init__(self, status_code=200, payload=None, text="", headers=None, on_return=None):
        self.status_code = status_code
        self._payload = payload
        self.text = text
        # A falsy ``headers`` argument is replaced by a fresh dict per instance.
        self.headers = headers or {}
        self.on_return = on_return

    def json(self):
        """Return the canned payload.

        Raises:
            ValueError: If the response was built without a payload.
        """
        # Only ``None`` counts as "missing"; an empty dict is a valid payload.
        if self._payload is None:
            raise ValueError("no json payload")
        return self._payload
|
|
|
|
|
|
class QueueSession:
    """Scripted HTTP session that replays a fixed queue of expected requests.

    Args:
        steps: Iterable of ``(method, url, response)`` tuples, consumed in
            order. ``response`` may be a response object or a callable that
            receives this session and returns the response object.

    Attributes:
        steps: Remaining scripted steps.
        calls: One ``{"method", "url", "kwargs"}`` dict per request received,
            recorded even when the request turns out to be unexpected.
        cookies: Plain dict that response hooks may mutate.
    """

    def __init__(self, steps):
        self.steps = list(steps)
        self.calls = []
        self.cookies = {}

    def get(self, url, **kwargs):
        """Handle a GET request against the scripted queue."""
        return self._request("GET", url, **kwargs)

    def post(self, url, **kwargs):
        """Handle a POST request against the scripted queue."""
        return self._request("POST", url, **kwargs)

    def request(self, method, url, **kwargs):
        """Handle a request with an arbitrary method (upper-cased before matching)."""
        return self._request(method.upper(), url, **kwargs)

    def close(self):
        """No-op; there is nothing to release."""
        return None

    def _request(self, method, url, **kwargs):
        """Record the call, check it against the next step and return its response.

        Raises:
            AssertionError: If the queue is exhausted, or the method or URL
                differs from the next scripted step.
        """
        # Record first so that failed expectations are still visible in ``calls``.
        self.calls.append({
            "method": method,
            "url": url,
            "kwargs": kwargs,
        })
        if not self.steps:
            raise AssertionError(f"unexpected request: {method} {url}")

        expected_method, expected_url, response = self.steps.pop(0)
        assert method == expected_method, (
            f"expected {expected_method} {expected_url}, got {method} {url}"
        )
        assert url == expected_url, (
            f"expected {expected_method} {expected_url}, got {method} {url}"
        )

        # A callable step builds its response lazily from the session state.
        if callable(response):
            response = response(self)

        # Responses without an ``on_return`` attribute are treated as hook-less.
        on_return = getattr(response, "on_return", None)
        if on_return:
            on_return(self)
        return response
|
|
|
|
|
|
class FakeEmailService(BaseEmailService):
    """In-memory email service that hands out pre-queued verification codes.

    Args:
        codes: Iterable of codes returned, in order, by
            :meth:`get_verification_code`.

    Attributes:
        codes: Codes not yet handed out.
        otp_requests: One dict per :meth:`get_verification_code` call holding
            the ``email``, ``email_id`` and ``otp_sent_at`` arguments.
    """

    def __init__(self, codes):
        super().__init__(EmailServiceType.TEMPMAIL)
        self.codes = list(codes)
        self.otp_requests = []

    def create_email(self, config=None):
        """Return a fixed mailbox description; ``config`` is ignored."""
        return {
            "email": "tester@example.com",
            "service_id": "mailbox-1",
        }

    def get_verification_code(self, email, email_id=None, timeout=120, pattern=r"(?<!\d)(\d{6})(?!\d)", otp_sent_at=None):
        """Record the request and return the next queued code.

        ``timeout`` and ``pattern`` are accepted but not used by this fake.

        Raises:
            AssertionError: If no code is left in the queue.
        """
        self.otp_requests.append({
            "email": email,
            "email_id": email_id,
            "otp_sent_at": otp_sent_at,
        })
        if not self.codes:
            raise AssertionError("no verification code queued")
        return self.codes.pop(0)

    def list_emails(self, **kwargs):
        """Return an empty list; this fake tracks no mailboxes."""
        return []

    def delete_email(self, email_id):
        """Report success without doing anything."""
        return True

    def check_health(self):
        """Always report the service as healthy."""
        return True
|
|
|
|
|
|
class FakeOAuthManager:
    """OAuth manager stand-in that counts starts and records callbacks.

    Attributes:
        start_calls: Number of times :meth:`start_oauth` has been called.
        callback_calls: One dict per :meth:`handle_callback` call holding the
            arguments it received.
    """

    def __init__(self):
        self.start_calls = 0
        self.callback_calls = []

    def start_oauth(self):
        """Return an ``OAuthStart`` whose URL, state and verifier embed the call count."""
        self.start_calls += 1
        # The counter is baked into every field so tests can tell flows apart.
        return OAuthStart(
            auth_url=f"https://auth.example.test/flow/{self.start_calls}",
            state=f"state-{self.start_calls}",
            code_verifier=f"verifier-{self.start_calls}",
            redirect_uri="http://localhost:1455/auth/callback",
        )

    def handle_callback(self, callback_url, expected_state, code_verifier):
        """Record the callback arguments and return a fixed token bundle."""
        self.callback_calls.append({
            "callback_url": callback_url,
            "expected_state": expected_state,
            "code_verifier": code_verifier,
        })
        return {
            "account_id": "acct-1",
            "access_token": "access-1",
            "refresh_token": "refresh-1",
            "id_token": "id-1",
        }
|
|
|
|
|
|
class FakeOpenAIClient:
    """HTTP client stand-in that walks through pre-built sessions and sentinel tokens.

    Args:
        sessions: Non-empty iterable of session objects. The first one is
            active initially; each :meth:`close` advances to the next.
        sentinel_tokens: Iterable of tokens returned, in order, by
            :meth:`check_sentinel`.

    Raises:
        AssertionError: If ``sessions`` is empty.
    """

    def __init__(self, sessions, sentinel_tokens):
        self._sessions = list(sessions)
        # Fail with a readable message rather than a bare IndexError below.
        if not self._sessions:
            raise AssertionError("at least one session must be provided")
        self._session_index = 0
        self._session = self._sessions[0]
        self._sentinel_tokens = list(sentinel_tokens)

    @property
    def session(self):
        """The currently active session object."""
        return self._session

    def check_ip_location(self):
        """Return a fixed ``(True, "US")`` result."""
        return True, "US"

    def check_sentinel(self, did):
        """Return the next queued sentinel token; ``did`` is ignored.

        Raises:
            AssertionError: If no token is left in the queue.
        """
        if not self._sentinel_tokens:
            raise AssertionError("no sentinel token queued")
        return self._sentinel_tokens.pop(0)

    def close(self):
        """Switch to the next session, staying on the last one once reached."""
        next_index = self._session_index + 1
        if next_index < len(self._sessions):
            self._session_index = next_index
            self._session = self._sessions[next_index]
|
|
|
|
|
|
def _workspace_cookie(workspace_id):
    """Build a cookie value of the form ``<payload>.sig`` naming one workspace.

    The payload is the URL-safe base64 encoding (``=`` padding stripped) of the
    JSON document ``{"workspaces": [{"id": workspace_id}]}``.
    """
    claims = json.dumps({"workspaces": [{"id": workspace_id}]}).encode("utf-8")
    encoded = base64.urlsafe_b64encode(claims).decode("ascii")
    payload = encoded.rstrip("=")
    return f"{payload}.sig"
|
|
|
|
|
|
def _response_with_did(did):
    """Return a 200 ``"ok"`` response that sets the ``oai-did`` cookie when returned.

    The cookie is written onto the session through the response's
    ``on_return`` hook.
    """

    def setter(session):
        session.cookies["oai-did"] = did

    return DummyResponse(
        status_code=200,
        text="ok",
        on_return=setter,
    )
|
|
|
|
|
|
def _response_with_login_cookies(workspace_id="ws-1", session_token="session-1"):
    """Return a 200 response with an empty JSON payload that sets login cookies.

    When the response is returned, its ``on_return`` hook stores two cookies on
    the session: ``oai-client-auth-session`` (built by ``_workspace_cookie``
    from ``workspace_id``) and ``__Secure-next-auth.session-token`` (set to
    ``session_token``).
    """

    def setter(session):
        session.cookies["oai-client-auth-session"] = _workspace_cookie(workspace_id)
        session.cookies["__Secure-next-auth.session-token"] = session_token

    return DummyResponse(status_code=200, payload={}, on_return=setter)
|
|
|
|
|
|
def test_check_sentinel_sends_non_empty_pow(monkeypatch):
    """``check_sentinel`` posts the device id, flow name and the built PoW token."""
    session = QueueSession([
        ("POST", OPENAI_API_ENDPOINTS["sentinel"], DummyResponse(payload={"token": "sentinel-token"})),
    ])
    client = OpenAIHTTPClient()
    # Swap in the scripted session so no real network traffic happens.
    client._session = session

    # Pin the PoW builder's output so the request body can be asserted exactly.
    monkeypatch.setattr(
        "src.core.http_client.build_sentinel_pow_token",
        lambda user_agent: "gAAAAACpow-token",
    )

    token = client.check_sentinel("device-1")

    assert token == "sentinel-token"
    body = json.loads(session.calls[0]["kwargs"]["data"])
    assert body["id"] == "device-1"
    assert body["flow"] == "authorize_continue"
    assert body["p"] == "gAAAAACpow-token"
|
|
|
|
|
|
def test_run_registers_then_relogs_to_fetch_token():
    """A fresh registration is followed by a second login flow that yields the token.

    Two scripted sessions are supplied: the first covers the sign-up steps, the
    second the login steps ending in the OAuth callback redirect.
    """

    def count_calls(session, endpoint_key):
        """Count the requests ``session`` received for the named endpoint."""
        url = OPENAI_API_ENDPOINTS[endpoint_key]
        return sum(1 for call in session.calls if call["url"] == url)

    # First flow: sign-up, password registration, OTP send/validate, account creation.
    session_one = QueueSession([
        ("GET", "https://auth.example.test/flow/1", _response_with_did("did-1")),
        (
            "POST",
            OPENAI_API_ENDPOINTS["signup"],
            DummyResponse(payload={"page": {"type": OPENAI_PAGE_TYPES["PASSWORD_REGISTRATION"]}}),
        ),
        ("POST", OPENAI_API_ENDPOINTS["register"], DummyResponse(payload={})),
        ("GET", OPENAI_API_ENDPOINTS["send_otp"], DummyResponse(payload={})),
        ("POST", OPENAI_API_ENDPOINTS["validate_otp"], DummyResponse(payload={})),
        ("POST", OPENAI_API_ENDPOINTS["create_account"], DummyResponse(payload={})),
    ])
    # Second flow: password login, OTP validation, workspace selection, callback redirect.
    session_two = QueueSession([
        ("GET", "https://auth.example.test/flow/2", _response_with_did("did-2")),
        (
            "POST",
            OPENAI_API_ENDPOINTS["signup"],
            DummyResponse(payload={"page": {"type": OPENAI_PAGE_TYPES["LOGIN_PASSWORD"]}}),
        ),
        (
            "POST",
            OPENAI_API_ENDPOINTS["password_verify"],
            DummyResponse(payload={"page": {"type": OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"]}}),
        ),
        ("POST", OPENAI_API_ENDPOINTS["validate_otp"], _response_with_login_cookies()),
        (
            "POST",
            OPENAI_API_ENDPOINTS["select_workspace"],
            DummyResponse(payload={"continue_url": "https://auth.example.test/continue"}),
        ),
        (
            "GET",
            "https://auth.example.test/continue",
            DummyResponse(
                status_code=302,
                headers={"Location": "http://localhost:1455/auth/callback?code=code-2&state=state-2"},
            ),
        ),
    ])

    email_service = FakeEmailService(["123456", "654321"])
    engine = RegistrationEngine(email_service)
    fake_oauth = FakeOAuthManager()
    engine.http_client = FakeOpenAIClient([session_one, session_two], ["sentinel-1", "sentinel-2"])
    engine.oauth_manager = fake_oauth

    result = engine.run()

    assert result.success is True
    assert result.source == "register"
    assert result.workspace_id == "ws-1"
    assert result.session_token == "session-1"
    assert fake_oauth.start_calls == 2
    assert len(email_service.otp_requests) == 2
    assert all(item["otp_sent_at"] is not None for item in email_service.otp_requests)
    # OTP is sent manually only during registration; workspace selection only on relogin.
    assert count_calls(session_one, "send_otp") == 1
    assert count_calls(session_two, "send_otp") == 0
    assert count_calls(session_one, "select_workspace") == 0
    assert count_calls(session_two, "select_workspace") == 1
    # calls[1] is the second scripted step of session_two (POST to the signup endpoint).
    relogin_start_body = json.loads(session_two.calls[1]["kwargs"]["data"])
    assert relogin_start_body["screen_hint"] == "login"
    assert relogin_start_body["username"]["value"] == "tester@example.com"
    # calls[2] is the password_verify POST.
    password_verify_body = json.loads(session_two.calls[2]["kwargs"]["data"])
    assert password_verify_body == {"password": result.password}
    assert result.metadata["token_acquired_via_relogin"] is True
|
|
|
|
|
|
def test_existing_account_login_uses_auto_sent_otp_without_manual_send():
    """An account that lands directly on the OTP page logs in without a manual OTP send.

    The scripted session contains no ``send_otp`` step, so any such request
    would fail the run as an unexpected or mismatched call.
    """
    session = QueueSession([
        ("GET", "https://auth.example.test/flow/1", _response_with_did("did-1")),
        (
            "POST",
            OPENAI_API_ENDPOINTS["signup"],
            DummyResponse(payload={"page": {"type": OPENAI_PAGE_TYPES["EMAIL_OTP_VERIFICATION"]}}),
        ),
        ("POST", OPENAI_API_ENDPOINTS["validate_otp"], _response_with_login_cookies("ws-existing", "session-existing")),
        (
            "POST",
            OPENAI_API_ENDPOINTS["select_workspace"],
            DummyResponse(payload={"continue_url": "https://auth.example.test/continue-existing"}),
        ),
        (
            "GET",
            "https://auth.example.test/continue-existing",
            DummyResponse(
                status_code=302,
                headers={"Location": "http://localhost:1455/auth/callback?code=code-1&state=state-1"},
            ),
        ),
    ])

    email_service = FakeEmailService(["246810"])
    engine = RegistrationEngine(email_service)
    fake_oauth = FakeOAuthManager()
    engine.http_client = FakeOpenAIClient([session], ["sentinel-1"])
    engine.oauth_manager = fake_oauth

    result = engine.run()

    assert result.success is True
    assert result.source == "login"
    assert fake_oauth.start_calls == 1
    assert sum(1 for call in session.calls if call["url"] == OPENAI_API_ENDPOINTS["send_otp"]) == 0
    assert len(email_service.otp_requests) == 1
    assert email_service.otp_requests[0]["otp_sent_at"] is not None
    assert result.metadata["token_acquired_via_relogin"] is False
|