diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index fdf1868..3b9adb0 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -169,11 +169,41 @@ class AgentsConfig(BaseModel): defaults: AgentDefaults = Field(default_factory=AgentDefaults) +class OAuthCredentials(BaseModel): + """OAuth token credentials for subscription-based auth.""" + access_token: str = "" + refresh_token: str = "" + expires_at: int = 0 # Unix timestamp + token_type: str = "oauth" # "oauth" or "token" (setup-token) + + @property + def is_oauth_token(self) -> bool: + """Check if this is an OAuth token (vs regular API key).""" + return "sk-ant-oat" in self.access_token + + @property + def is_expired(self) -> bool: + """Check if token has expired.""" + import time + if self.expires_at == 0: + return False # No expiry set (setup-token) + return time.time() > self.expires_at + + @property + def expires_soon(self) -> bool: + """Check if token expires within 10 minutes.""" + import time + if self.expires_at == 0: + return False + return time.time() > (self.expires_at - 600) + + class ProviderConfig(BaseModel): """LLM provider configuration.""" api_key: str = "" api_base: str | None = None extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix) + oauth_credentials: OAuthCredentials | None = None class ProvidersConfig(BaseModel): diff --git a/tests/test_oauth_config.py b/tests/test_oauth_config.py new file mode 100644 index 0000000..67a4fd0 --- /dev/null +++ b/tests/test_oauth_config.py @@ -0,0 +1,48 @@ +"""Test OAuth configuration schema.""" +import pytest +from nanobot.config.schema import ProviderConfig, OAuthCredentials + + +def test_provider_config_has_oauth_fields(): + """ProviderConfig should have oauth_credentials field.""" + config = ProviderConfig(api_key="test") + assert hasattr(config, "oauth_credentials") + assert config.oauth_credentials is None + + +def test_oauth_credentials_model(): + """OAuthCredentials should store token, refresh, expiry.""" + creds = OAuthCredentials( + access_token="sk-ant-oat01-xxx", + refresh_token="rt_xxx", + expires_at=1234567890, + token_type="oauth" + ) + assert creds.access_token.startswith("sk-ant-oat") + assert creds.is_oauth_token is True + + +def test_oauth_credentials_expiry_check(): + """OAuthCredentials should detect expired tokens.""" + import time + expired = OAuthCredentials( + access_token="sk-ant-oat01-xxx", + expires_at=int(time.time()) - 3600 # 1 hour ago + ) + assert expired.is_expired is True + + valid = OAuthCredentials( + access_token="sk-ant-oat01-xxx", + expires_at=int(time.time()) + 3600 # 1 hour from now + ) + assert valid.is_expired is False + + +def test_oauth_credentials_no_expiry(): + """Setup tokens with expires_at=0 should never be expired.""" + creds = OAuthCredentials( + access_token="sk-ant-oat01-xxx", + expires_at=0 # No expiry (setup-token) + ) + assert creds.is_expired is False + assert creds.expires_soon is False