Compare commits
10 Commits
890d7cf853
...
55ad41265f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55ad41265f | ||
|
|
6d0d995b1b | ||
|
|
f444e94ff7 | ||
|
|
5f9af317c4 | ||
|
|
4b3bc89d06 | ||
|
|
3323b9d909 | ||
|
|
96a7abcda4 | ||
|
|
534a8344bd | ||
|
|
7b710116a4 | ||
|
|
fc9545c36a |
11
Dockerfile.oauth
Normal file
11
Dockerfile.oauth
Normal file
@@ -0,0 +1,11 @@
|
||||
FROM birdxs/nanobot:latest
|
||||
|
||||
# Copy full project (pyproject.toml + source)
|
||||
COPY pyproject.toml README.md LICENSE /app/
|
||||
COPY nanobot/ /app/nanobot/
|
||||
|
||||
# Install with all dependencies
|
||||
RUN uv pip install --system --no-cache --reinstall /app
|
||||
|
||||
ENTRYPOINT ["nanobot"]
|
||||
CMD ["status"]
|
||||
@@ -270,18 +270,18 @@ This file stores important information that should persist across sessions.
|
||||
|
||||
|
||||
def _make_provider(config):
|
||||
"""Create LiteLLMProvider from config. Exits if no API key found."""
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
"""Create LLM provider from config. Uses OAuth for subscription tokens."""
|
||||
from nanobot.providers import create_provider
|
||||
p = config.get_provider()
|
||||
model = config.agents.defaults.model
|
||||
if not (p and p.api_key) and not model.startswith("bedrock/"):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
return LiteLLMProvider(
|
||||
return create_provider(
|
||||
api_key=p.api_key if p else None,
|
||||
model=model,
|
||||
api_base=config.get_api_base(),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=config.get_provider_name(),
|
||||
)
|
||||
@@ -507,6 +507,9 @@ def agent(
|
||||
channels_app = typer.Typer(help="Manage channels")
|
||||
app.add_typer(channels_app, name="channels")
|
||||
|
||||
from nanobot.cli.oauth import oauth_app
|
||||
app.add_typer(oauth_app, name="oauth")
|
||||
|
||||
|
||||
@channels_app.command("status")
|
||||
def channels_status():
|
||||
|
||||
93
nanobot/cli/oauth.py
Normal file
93
nanobot/cli/oauth.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""OAuth CLI commands for subscription authentication."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
|
||||
oauth_app = typer.Typer(help="Manage OAuth authentication for subscription-based providers")
|
||||
console = Console()
|
||||
|
||||
|
||||
@oauth_app.command("login")
|
||||
def login(
|
||||
provider: str = typer.Argument("anthropic", help="Provider name"),
|
||||
token: Optional[str] = typer.Option(None, "--token", "-t", help="OAuth token (from claude setup-token)"),
|
||||
):
|
||||
"""Login to a provider using OAuth.
|
||||
|
||||
For Anthropic Claude Max/Pro, run 'claude setup-token' and paste the token here.
|
||||
|
||||
Example:
|
||||
nanobot oauth login anthropic --token sk-ant-oat01-xxx
|
||||
"""
|
||||
from nanobot.config.oauth_store import OAuthStore
|
||||
from nanobot.config.schema import OAuthCredentials
|
||||
|
||||
if provider != "anthropic":
|
||||
console.print(f"[red]OAuth login for {provider} not yet supported[/red]")
|
||||
return
|
||||
|
||||
if not token:
|
||||
console.print("Please provide your OAuth token:")
|
||||
console.print(" 1. Run: claude setup-token")
|
||||
console.print(" 2. Copy the sk-ant-oat01-... token")
|
||||
console.print(" 3. Run: nanobot oauth login anthropic --token <your-token>")
|
||||
console.print()
|
||||
token = typer.prompt("Token", hide_input=True)
|
||||
|
||||
if not token or "sk-ant-oat" not in token:
|
||||
console.print("[red]Invalid token. Must contain sk-ant-oat[/red]")
|
||||
return
|
||||
|
||||
store = OAuthStore(Path.home() / ".nanobot")
|
||||
creds = OAuthCredentials(
|
||||
access_token=token,
|
||||
token_type="token" # setup-token doesn't expire
|
||||
)
|
||||
store.save(provider, creds)
|
||||
|
||||
console.print(f"[green]Successfully saved {provider} OAuth credentials![/green]")
|
||||
|
||||
|
||||
@oauth_app.command("status")
|
||||
def status():
|
||||
"""Show OAuth credential status."""
|
||||
from nanobot.config.oauth_store import OAuthStore
|
||||
|
||||
store = OAuthStore(Path.home() / ".nanobot")
|
||||
|
||||
providers = ["anthropic"]
|
||||
found_any = False
|
||||
|
||||
for provider in providers:
|
||||
creds = store.load(provider)
|
||||
if creds:
|
||||
found_any = True
|
||||
st = "valid"
|
||||
if creds.is_expired:
|
||||
st = "EXPIRED"
|
||||
elif creds.expires_soon:
|
||||
st = "expires soon"
|
||||
|
||||
token_preview = creds.access_token[:20] + "..."
|
||||
console.print(f" {provider}: {token_preview} ({st})")
|
||||
|
||||
if not found_any:
|
||||
console.print("No OAuth credentials configured.")
|
||||
console.print("Run: nanobot oauth login anthropic --token <token>")
|
||||
|
||||
|
||||
@oauth_app.command("logout")
|
||||
def logout(
|
||||
provider: str = typer.Argument("anthropic", help="Provider name"),
|
||||
):
|
||||
"""Remove OAuth credentials for a provider."""
|
||||
from nanobot.config.oauth_store import OAuthStore
|
||||
|
||||
store = OAuthStore(Path.home() / ".nanobot")
|
||||
if store.delete(provider):
|
||||
console.print(f"[green]Removed {provider} OAuth credentials[/green]")
|
||||
else:
|
||||
console.print(f"No credentials found for {provider}")
|
||||
@@ -12,35 +12,53 @@ def get_config_path() -> Path:
|
||||
return Path.home() / ".nanobot" / "config.json"
|
||||
|
||||
|
||||
def _get_oauth_store_dir() -> Path:
|
||||
"""Get the OAuth store directory."""
|
||||
return Path.home() / ".nanobot"
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""Get the nanobot data directory."""
|
||||
from nanobot.utils.helpers import get_data_path
|
||||
return get_data_path()
|
||||
|
||||
|
||||
def _inject_oauth_credentials(config: Config) -> Config:
|
||||
"""Inject OAuth credentials from store into config if available."""
|
||||
from nanobot.config.oauth_store import OAuthStore
|
||||
|
||||
store = OAuthStore(_get_oauth_store_dir())
|
||||
creds = store.load("anthropic")
|
||||
if creds and creds.access_token and not creds.is_expired:
|
||||
config.providers.anthropic.api_key = creds.access_token
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def load_config(config_path: Path | None = None) -> Config:
|
||||
"""
|
||||
Load configuration from file or create default.
|
||||
|
||||
|
||||
Args:
|
||||
config_path: Optional path to config file. Uses default if not provided.
|
||||
|
||||
|
||||
Returns:
|
||||
Loaded configuration object.
|
||||
"""
|
||||
path = config_path or get_config_path()
|
||||
|
||||
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
data = _migrate_config(data)
|
||||
return Config.model_validate(convert_keys(data))
|
||||
config = Config.model_validate(convert_keys(data))
|
||||
return _inject_oauth_credentials(config)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Warning: Failed to load config from {path}: {e}")
|
||||
print("Using default configuration.")
|
||||
|
||||
return Config()
|
||||
|
||||
return _inject_oauth_credentials(Config())
|
||||
|
||||
|
||||
def save_config(config: Config, config_path: Path | None = None) -> None:
|
||||
|
||||
59
nanobot/config/oauth_store.py
Normal file
59
nanobot/config/oauth_store.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""OAuth credential storage."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.config.schema import OAuthCredentials
|
||||
|
||||
|
||||
class OAuthStore:
|
||||
"""Stores OAuth credentials in a JSON file."""
|
||||
|
||||
FILENAME = "oauth-credentials.json"
|
||||
|
||||
def __init__(self, config_dir: Path):
|
||||
self.config_dir = config_dir
|
||||
self.file_path = config_dir / self.FILENAME
|
||||
|
||||
def _load_all(self) -> dict[str, Any]:
|
||||
"""Load all credentials from file."""
|
||||
if not self.file_path.exists():
|
||||
return {}
|
||||
|
||||
with open(self.file_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def _save_all(self, data: dict[str, Any]) -> None:
|
||||
"""Save all credentials to file."""
|
||||
self.config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(self.file_path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
# Secure permissions
|
||||
self.file_path.chmod(0o600)
|
||||
|
||||
def save(self, provider: str, credentials: OAuthCredentials) -> None:
|
||||
"""Save credentials for a provider."""
|
||||
data = self._load_all()
|
||||
data[provider] = credentials.model_dump()
|
||||
self._save_all(data)
|
||||
|
||||
def load(self, provider: str) -> OAuthCredentials | None:
|
||||
"""Load credentials for a provider."""
|
||||
data = self._load_all()
|
||||
if provider not in data:
|
||||
return None
|
||||
|
||||
return OAuthCredentials(**data[provider])
|
||||
|
||||
def delete(self, provider: str) -> bool:
|
||||
"""Delete credentials for a provider."""
|
||||
data = self._load_all()
|
||||
if provider not in data:
|
||||
return False
|
||||
|
||||
del data[provider]
|
||||
self._save_all(data)
|
||||
return True
|
||||
@@ -169,11 +169,41 @@ class AgentsConfig(BaseModel):
|
||||
defaults: AgentDefaults = Field(default_factory=AgentDefaults)
|
||||
|
||||
|
||||
class OAuthCredentials(BaseModel):
|
||||
"""OAuth token credentials for subscription-based auth."""
|
||||
access_token: str = ""
|
||||
refresh_token: str = ""
|
||||
expires_at: int = 0 # Unix timestamp
|
||||
token_type: str = "oauth" # "oauth" or "token" (setup-token)
|
||||
|
||||
@property
|
||||
def is_oauth_token(self) -> bool:
|
||||
"""Check if this is an OAuth token (vs regular API key)."""
|
||||
return "sk-ant-oat" in self.access_token
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if token has expired."""
|
||||
import time
|
||||
if self.expires_at == 0:
|
||||
return False # No expiry set (setup-token)
|
||||
return time.time() > self.expires_at
|
||||
|
||||
@property
|
||||
def expires_soon(self) -> bool:
|
||||
"""Check if token expires within 10 minutes."""
|
||||
import time
|
||||
if self.expires_at == 0:
|
||||
return False
|
||||
return time.time() > (self.expires_at - 600)
|
||||
|
||||
|
||||
class ProviderConfig(BaseModel):
|
||||
"""LLM provider configuration."""
|
||||
api_key: str = ""
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
|
||||
oauth_credentials: OAuthCredentials | None = None
|
||||
|
||||
|
||||
class ProvidersConfig(BaseModel):
|
||||
|
||||
@@ -1,6 +1,43 @@
|
||||
"""LLM provider abstraction module."""
|
||||
"""Provider module exports."""
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.anthropic_oauth import AnthropicOAuthProvider
|
||||
from nanobot.providers.registry import should_use_oauth_provider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"]
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"ToolCallRequest",
|
||||
"LiteLLMProvider",
|
||||
"AnthropicOAuthProvider",
|
||||
"create_provider",
|
||||
]
|
||||
|
||||
|
||||
def create_provider(
|
||||
api_key: str,
|
||||
model: str,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
provider_name: str | None = None,
|
||||
) -> LLMProvider:
|
||||
"""Factory function to create appropriate provider.
|
||||
|
||||
Automatically selects AnthropicOAuthProvider for OAuth tokens,
|
||||
LiteLLMProvider for everything else.
|
||||
"""
|
||||
if should_use_oauth_provider(api_key, model):
|
||||
return AnthropicOAuthProvider(
|
||||
oauth_token=api_key,
|
||||
default_model=model,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
return LiteLLMProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
extra_headers=extra_headers,
|
||||
provider_name=provider_name,
|
||||
)
|
||||
|
||||
224
nanobot/providers/anthropic_oauth.py
Normal file
224
nanobot/providers/anthropic_oauth.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Anthropic OAuth provider - direct API calls with Bearer auth.
|
||||
|
||||
This provider bypasses litellm to properly handle OAuth tokens
|
||||
which require Authorization: Bearer header instead of x-api-key.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.oauth_utils import get_auth_headers, get_claude_code_system_prefix
|
||||
|
||||
|
||||
class AnthropicOAuthProvider(LLMProvider):
|
||||
"""
|
||||
Anthropic provider using OAuth token authentication.
|
||||
|
||||
Unlike the LiteLLM provider, this calls the Anthropic API directly
|
||||
with proper Bearer token authentication for Claude Max/Pro subscriptions.
|
||||
"""
|
||||
|
||||
ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
oauth_token: str,
|
||||
default_model: str = "claude-opus-4-5",
|
||||
api_base: str | None = None,
|
||||
):
|
||||
super().__init__(api_key=None, api_base=api_base)
|
||||
self.oauth_token = oauth_token
|
||||
self.default_model = default_model
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
def _get_headers(self) -> dict[str, str]:
|
||||
"""Get request headers with Bearer auth."""
|
||||
return get_auth_headers(self.oauth_token, is_oauth=True)
|
||||
|
||||
def _get_api_url(self) -> str:
|
||||
"""Get API endpoint URL."""
|
||||
if self.api_base:
|
||||
return f"{self.api_base.rstrip('/')}/v1/messages"
|
||||
return self.ANTHROPIC_API_URL
|
||||
|
||||
# Short aliases that need dated suffixes for the API
|
||||
MODEL_ALIASES: dict[str, str] = {
|
||||
"claude-sonnet-4": "claude-sonnet-4-20250514",
|
||||
"claude-opus-4": "claude-opus-4-20250514",
|
||||
"claude-haiku-3-5": "claude-haiku-4-5-20241022",
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
|
||||
"claude-opus-4-5": "claude-opus-4-5-20250929",
|
||||
"claude-opus-4-6": "claude-opus-4-6",
|
||||
}
|
||||
|
||||
def _resolve_model_alias(self, model: str) -> str:
|
||||
"""Resolve short model aliases to full dated IDs."""
|
||||
return self.MODEL_ALIASES.get(model, model)
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create async HTTP client."""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=300.0)
|
||||
return self._client
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]]
|
||||
) -> tuple[str | None, list[dict[str, Any]]]:
|
||||
"""Prepare messages, extracting system prompt and adding Claude Code identity.
|
||||
|
||||
Returns (system_prompt, messages_without_system)
|
||||
"""
|
||||
system_parts = [get_claude_code_system_prefix()]
|
||||
filtered_messages = []
|
||||
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
system_parts.append(msg.get("content", ""))
|
||||
else:
|
||||
filtered_messages.append(msg)
|
||||
|
||||
system_prompt = "\n\n".join(system_parts)
|
||||
return system_prompt, filtered_messages
|
||||
|
||||
def _convert_tools_to_anthropic(
|
||||
self,
|
||||
tools: list[dict[str, Any]] | None
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert OpenAI-format tools to Anthropic format."""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
anthropic_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
anthropic_tools.append({
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"input_schema": func.get("parameters", {"type": "object", "properties": {}})
|
||||
})
|
||||
|
||||
return anthropic_tools if anthropic_tools else None
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
system: str | None = None,
|
||||
model: str = "claude-opus-4-5",
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Make request to Anthropic API."""
|
||||
client = await self._get_client()
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
|
||||
response = await client.post(
|
||||
self._get_api_url(),
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
raise Exception(f"Anthropic API error {response.status_code}: {error_text}")
|
||||
|
||||
return response.json()
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""Send chat completion request to Anthropic API."""
|
||||
model = model or self.default_model
|
||||
|
||||
# Strip provider prefix if present (e.g. "anthropic/claude-opus-4-5" -> "claude-opus-4-5")
|
||||
if "/" in model:
|
||||
model = model.split("/")[-1]
|
||||
|
||||
# Resolve short aliases to dated model IDs (API requires dated suffixes)
|
||||
model = self._resolve_model_alias(model)
|
||||
|
||||
system, prepared_messages = self._prepare_messages(messages)
|
||||
anthropic_tools = self._convert_tools_to_anthropic(tools)
|
||||
|
||||
try:
|
||||
response = await self._make_request(
|
||||
messages=prepared_messages,
|
||||
system=system,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
tools=anthropic_tools,
|
||||
)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse Anthropic API response."""
|
||||
content_blocks = response.get("content", [])
|
||||
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
|
||||
for block in content_blocks:
|
||||
if block.get("type") == "text":
|
||||
text_content += block.get("text", "")
|
||||
elif block.get("type") == "tool_use":
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=block.get("id", ""),
|
||||
name=block.get("name", ""),
|
||||
arguments=block.get("input", {}),
|
||||
))
|
||||
|
||||
usage = {}
|
||||
if "usage" in response:
|
||||
usage = {
|
||||
"prompt_tokens": response["usage"].get("input_tokens", 0),
|
||||
"completion_tokens": response["usage"].get("output_tokens", 0),
|
||||
"total_tokens": (
|
||||
response["usage"].get("input_tokens", 0) +
|
||||
response["usage"].get("output_tokens", 0)
|
||||
),
|
||||
}
|
||||
|
||||
return LLMResponse(
|
||||
content=text_content or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=response.get("stop_reason", "end_turn"),
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
46
nanobot/providers/oauth_utils.py
Normal file
46
nanobot/providers/oauth_utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""OAuth utility functions for Anthropic subscription auth."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def is_oauth_token(token: str | None) -> bool:
|
||||
"""Check if token is an OAuth token (vs regular API key).
|
||||
|
||||
OAuth tokens from Claude Max/Pro contain 'sk-ant-oat' prefix.
|
||||
Regular API keys use 'sk-ant-api03' or similar.
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
return "sk-ant-oat" in token
|
||||
|
||||
|
||||
def get_auth_headers(token: str, is_oauth: bool = False) -> dict[str, str]:
|
||||
"""Get authentication headers for Anthropic API.
|
||||
|
||||
OAuth tokens require Authorization: Bearer header.
|
||||
Regular API keys use x-api-key header.
|
||||
"""
|
||||
headers: dict[str, str] = {
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
if is_oauth:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
# Required headers to mimic Claude Code client
|
||||
headers["anthropic-beta"] = "claude-code-20250219,oauth-2025-04-20"
|
||||
headers["anthropic-dangerous-direct-browser-access"] = "true"
|
||||
headers["user-agent"] = "claude-cli/2.1.2 (external, cli)"
|
||||
headers["x-app"] = "cli"
|
||||
else:
|
||||
headers["x-api-key"] = token
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def get_claude_code_system_prefix() -> str:
|
||||
"""Get the required system prompt prefix for OAuth tokens.
|
||||
|
||||
Anthropic requires this identity declaration for OAuth auth.
|
||||
"""
|
||||
return "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
@@ -357,3 +357,24 @@ def find_by_name(name: str) -> ProviderSpec | None:
|
||||
if spec.name == name:
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def should_use_oauth_provider(api_key: str | None, model: str) -> bool:
|
||||
"""Determine if OAuth provider should be used.
|
||||
|
||||
OAuth provider is used when:
|
||||
1. API key is an OAuth token (contains 'sk-ant-oat')
|
||||
2. Model is an Anthropic model (contains 'claude' or 'anthropic')
|
||||
"""
|
||||
if not api_key:
|
||||
return False
|
||||
|
||||
if "sk-ant-oat" not in api_key:
|
||||
return False
|
||||
|
||||
model_lower = model.lower()
|
||||
anthropic_spec = find_by_name("anthropic")
|
||||
if anthropic_spec:
|
||||
return any(kw in model_lower for kw in anthropic_spec.keywords)
|
||||
|
||||
return False
|
||||
|
||||
88
tests/test_anthropic_oauth.py
Normal file
88
tests/test_anthropic_oauth.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Test Anthropic OAuth provider."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from nanobot.providers.anthropic_oauth import AnthropicOAuthProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider():
|
||||
"""Create provider with test OAuth token."""
|
||||
return AnthropicOAuthProvider(
|
||||
oauth_token="sk-ant-oat01-test-token",
|
||||
default_model="claude-opus-4-5"
|
||||
)
|
||||
|
||||
|
||||
def test_provider_init(provider):
|
||||
"""Provider should initialize with OAuth token."""
|
||||
assert provider.oauth_token == "sk-ant-oat01-test-token"
|
||||
assert provider.default_model == "claude-opus-4-5"
|
||||
|
||||
|
||||
def test_provider_uses_bearer_auth(provider):
|
||||
"""Provider should use Bearer auth, not x-api-key."""
|
||||
headers = provider._get_headers()
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"].startswith("Bearer ")
|
||||
assert "x-api-key" not in headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_prepends_system_prompt(provider):
|
||||
"""Chat should prepend Claude Code identity to system prompt."""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
with patch.object(provider, "_make_request", new_callable=AsyncMock) as mock:
|
||||
mock.return_value = {"content": [{"type": "text", "text": "Hi"}], "stop_reason": "end_turn"}
|
||||
await provider.chat(messages)
|
||||
|
||||
call_args = mock.call_args
|
||||
system = call_args[1]["system"]
|
||||
assert "Claude Code" in system
|
||||
|
||||
|
||||
def test_parse_response_text(provider):
|
||||
"""Should parse text response correctly."""
|
||||
response = {
|
||||
"content": [{"type": "text", "text": "Hello world"}],
|
||||
"stop_reason": "end_turn",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||
}
|
||||
result = provider._parse_response(response)
|
||||
assert result.content == "Hello world"
|
||||
assert result.finish_reason == "end_turn"
|
||||
assert result.usage["prompt_tokens"] == 10
|
||||
|
||||
|
||||
def test_parse_response_tool_calls(provider):
|
||||
"""Should parse tool call response correctly."""
|
||||
response = {
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "call_1", "name": "read_file", "input": {"path": "/tmp/test"}}
|
||||
],
|
||||
"stop_reason": "tool_use",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5},
|
||||
}
|
||||
result = provider._parse_response(response)
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "read_file"
|
||||
assert result.tool_calls[0].arguments == {"path": "/tmp/test"}
|
||||
|
||||
|
||||
def test_convert_tools_to_anthropic(provider):
|
||||
"""Should convert OpenAI-format tools to Anthropic format."""
|
||||
openai_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read a file",
|
||||
"parameters": {"type": "object", "properties": {"path": {"type": "string"}}}
|
||||
}
|
||||
}
|
||||
]
|
||||
anthropic_tools = provider._convert_tools_to_anthropic(openai_tools)
|
||||
assert len(anthropic_tools) == 1
|
||||
assert anthropic_tools[0]["name"] == "read_file"
|
||||
assert "input_schema" in anthropic_tools[0]
|
||||
55
tests/test_cli_oauth.py
Normal file
55
tests/test_cli_oauth.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Test OAuth CLI commands."""
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typer.testing import CliRunner
|
||||
from nanobot.cli.oauth import oauth_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
def test_oauth_login_help(runner):
|
||||
"""Login command should have help text."""
|
||||
result = runner.invoke(oauth_app, ["login", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "token" in result.output.lower()
|
||||
|
||||
|
||||
def test_oauth_status_no_credentials(runner, tmp_path, monkeypatch):
|
||||
"""Status should show no credentials when none exist."""
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
result = runner.invoke(oauth_app, ["status"])
|
||||
assert result.exit_code == 0
|
||||
assert "No OAuth credentials" in result.output
|
||||
|
||||
|
||||
def test_oauth_login_and_status(runner, tmp_path, monkeypatch):
|
||||
"""Login should save credentials, status should show them."""
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
result = runner.invoke(oauth_app, ["login", "--token", "sk-ant-oat01-test-xxx"])
|
||||
assert result.exit_code == 0
|
||||
assert "Successfully saved" in result.output
|
||||
|
||||
result = runner.invoke(oauth_app, ["status"])
|
||||
assert result.exit_code == 0
|
||||
assert "sk-ant-oat01-test-x" in result.output
|
||||
|
||||
|
||||
def test_oauth_logout(runner, tmp_path, monkeypatch):
|
||||
"""Logout should remove credentials."""
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
runner.invoke(oauth_app, ["login", "--token", "sk-ant-oat01-test-xxx"])
|
||||
result = runner.invoke(oauth_app, ["logout"])
|
||||
assert result.exit_code == 0
|
||||
assert "Removed" in result.output
|
||||
|
||||
|
||||
def test_oauth_login_invalid_token(runner, tmp_path, monkeypatch):
|
||||
"""Login should reject non-OAuth tokens."""
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
result = runner.invoke(oauth_app, ["login", "--token", "sk-ant-api03-regular"])
|
||||
assert result.exit_code == 0
|
||||
assert "Invalid token" in result.output
|
||||
65
tests/test_config_oauth_integration.py
Normal file
65
tests/test_config_oauth_integration.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Test OAuth store integration with config loading."""
|
||||
import json
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.oauth_store import OAuthStore
|
||||
from nanobot.config.schema import OAuthCredentials
|
||||
|
||||
|
||||
def test_oauth_token_injected_into_config(tmp_path, monkeypatch):
|
||||
"""OAuth token from store should be injected into provider api_key."""
|
||||
# Create a minimal config file (no api key set)
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({
|
||||
"agents": {"defaults": {"model": "anthropic/claude-opus-4-5"}},
|
||||
"providers": {"anthropic": {"apiKey": ""}}
|
||||
}))
|
||||
|
||||
# Save OAuth credentials
|
||||
store = OAuthStore(tmp_path)
|
||||
creds = OAuthCredentials(access_token="sk-ant-oat01-test-inject")
|
||||
store.save("anthropic", creds)
|
||||
|
||||
# Monkeypatch get_config_path to use our tmp dir
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
# Monkeypatch the OAuth store path
|
||||
monkeypatch.setattr("nanobot.config.loader._get_oauth_store_dir", lambda: tmp_path)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.providers.anthropic.api_key == "sk-ant-oat01-test-inject"
|
||||
|
||||
|
||||
def test_config_without_oauth_unchanged(tmp_path, monkeypatch):
|
||||
"""Config without OAuth store should load normally."""
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({
|
||||
"providers": {"anthropic": {"apiKey": "sk-ant-api03-regular"}}
|
||||
}))
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader._get_oauth_store_dir", lambda: tmp_path / "nonexistent")
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.providers.anthropic.api_key == "sk-ant-api03-regular"
|
||||
|
||||
|
||||
def test_oauth_does_not_overwrite_existing_key(tmp_path, monkeypatch):
|
||||
"""If user already has an API key, OAuth should still override (OAuth takes priority)."""
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({
|
||||
"providers": {"anthropic": {"apiKey": "sk-ant-api03-existing"}}
|
||||
}))
|
||||
|
||||
store = OAuthStore(tmp_path)
|
||||
creds = OAuthCredentials(access_token="sk-ant-oat01-oauth-wins")
|
||||
store.save("anthropic", creds)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader._get_oauth_store_dir", lambda: tmp_path)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
# OAuth token takes priority over existing API key
|
||||
assert config.providers.anthropic.api_key == "sk-ant-oat01-oauth-wins"
|
||||
48
tests/test_oauth_config.py
Normal file
48
tests/test_oauth_config.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Test OAuth configuration schema."""
|
||||
import pytest
|
||||
from nanobot.config.schema import ProviderConfig, OAuthCredentials
|
||||
|
||||
|
||||
def test_provider_config_has_oauth_fields():
|
||||
"""ProviderConfig should have oauth_credentials field."""
|
||||
config = ProviderConfig(api_key="test")
|
||||
assert hasattr(config, "oauth_credentials")
|
||||
assert config.oauth_credentials is None
|
||||
|
||||
|
||||
def test_oauth_credentials_model():
|
||||
"""OAuthCredentials should store token, refresh, expiry."""
|
||||
creds = OAuthCredentials(
|
||||
access_token="sk-ant-oat01-xxx",
|
||||
refresh_token="rt_xxx",
|
||||
expires_at=1234567890,
|
||||
token_type="oauth"
|
||||
)
|
||||
assert creds.access_token.startswith("sk-ant-oat")
|
||||
assert creds.is_oauth_token is True
|
||||
|
||||
|
||||
def test_oauth_credentials_expiry_check():
|
||||
"""OAuthCredentials should detect expired tokens."""
|
||||
import time
|
||||
expired = OAuthCredentials(
|
||||
access_token="sk-ant-oat01-xxx",
|
||||
expires_at=int(time.time()) - 3600 # 1 hour ago
|
||||
)
|
||||
assert expired.is_expired is True
|
||||
|
||||
valid = OAuthCredentials(
|
||||
access_token="sk-ant-oat01-xxx",
|
||||
expires_at=int(time.time()) + 3600 # 1 hour from now
|
||||
)
|
||||
assert valid.is_expired is False
|
||||
|
||||
|
||||
def test_oauth_credentials_no_expiry():
|
||||
"""Setup tokens with expires_at=0 should never be expired."""
|
||||
creds = OAuthCredentials(
|
||||
access_token="sk-ant-oat01-xxx",
|
||||
expires_at=0 # No expiry (setup-token)
|
||||
)
|
||||
assert creds.is_expired is False
|
||||
assert creds.expires_soon is False
|
||||
55
tests/test_oauth_store.py
Normal file
55
tests/test_oauth_store.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Test OAuth credential storage."""
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from nanobot.config.oauth_store import OAuthStore
|
||||
from nanobot.config.schema import OAuthCredentials
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_store():
|
||||
"""Create store with temp directory."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield OAuthStore(Path(tmpdir) / ".nanobot")
|
||||
|
||||
|
||||
def test_save_and_load_credentials(temp_store):
|
||||
"""Should save and load OAuth credentials."""
|
||||
creds = OAuthCredentials(
|
||||
access_token="sk-ant-oat01-xxx",
|
||||
refresh_token="rt_xxx",
|
||||
expires_at=1234567890
|
||||
)
|
||||
|
||||
temp_store.save("anthropic", creds)
|
||||
loaded = temp_store.load("anthropic")
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.access_token == creds.access_token
|
||||
assert loaded.refresh_token == creds.refresh_token
|
||||
|
||||
|
||||
def test_load_nonexistent_returns_none(temp_store):
|
||||
"""Should return None for missing credentials."""
|
||||
assert temp_store.load("nonexistent") is None
|
||||
|
||||
|
||||
def test_delete_credentials(temp_store):
|
||||
"""Should delete saved credentials."""
|
||||
creds = OAuthCredentials(access_token="sk-ant-oat01-xxx")
|
||||
temp_store.save("anthropic", creds)
|
||||
assert temp_store.delete("anthropic") is True
|
||||
assert temp_store.load("anthropic") is None
|
||||
|
||||
|
||||
def test_delete_nonexistent_returns_false(temp_store):
|
||||
"""Should return False when deleting missing credentials."""
|
||||
assert temp_store.delete("nonexistent") is False
|
||||
|
||||
|
||||
def test_file_permissions(temp_store):
|
||||
"""Credentials file should have restricted permissions."""
|
||||
creds = OAuthCredentials(access_token="sk-ant-oat01-xxx")
|
||||
temp_store.save("anthropic", creds)
|
||||
perms = oct(temp_store.file_path.stat().st_mode)[-3:]
|
||||
assert perms == "600"
|
||||
28
tests/test_oauth_utils.py
Normal file
28
tests/test_oauth_utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Test OAuth utility functions."""
|
||||
import pytest
|
||||
from nanobot.providers.oauth_utils import is_oauth_token, get_auth_headers
|
||||
|
||||
|
||||
def test_is_oauth_token_detects_oat():
|
||||
"""Should detect sk-ant-oat tokens as OAuth."""
|
||||
assert is_oauth_token("sk-ant-oat01-buSdhCH2XEkebW7ZQZTvGqH5EwAFh4u52LrdJhAP") is True
|
||||
assert is_oauth_token("sk-ant-api03-regularkey") is False
|
||||
assert is_oauth_token("") is False
|
||||
assert is_oauth_token(None) is False
|
||||
|
||||
|
||||
def test_get_auth_headers_oauth():
|
||||
"""OAuth tokens should use Authorization: Bearer."""
|
||||
headers = get_auth_headers("sk-ant-oat01-xxx", is_oauth=True)
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == "Bearer sk-ant-oat01-xxx"
|
||||
assert "x-api-key" not in headers
|
||||
assert headers["anthropic-beta"] == "claude-code-20250219,oauth-2025-04-20"
|
||||
|
||||
|
||||
def test_get_auth_headers_api_key():
|
||||
"""Regular API keys should use x-api-key."""
|
||||
headers = get_auth_headers("sk-ant-api03-xxx", is_oauth=False)
|
||||
assert "x-api-key" in headers
|
||||
assert headers["x-api-key"] == "sk-ant-api03-xxx"
|
||||
assert "Authorization" not in headers
|
||||
32
tests/test_provider_factory.py
Normal file
32
tests/test_provider_factory.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Test provider factory with OAuth support."""
|
||||
import pytest
|
||||
from nanobot.providers import create_provider
|
||||
from nanobot.providers.anthropic_oauth import AnthropicOAuthProvider
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
|
||||
|
||||
def test_create_provider_oauth_token():
|
||||
"""OAuth tokens should create AnthropicOAuthProvider."""
|
||||
provider = create_provider(
|
||||
api_key="sk-ant-oat01-test-token",
|
||||
model="anthropic/claude-opus-4-5"
|
||||
)
|
||||
assert isinstance(provider, AnthropicOAuthProvider)
|
||||
|
||||
|
||||
def test_create_provider_regular_key():
|
||||
"""Regular API keys should create LiteLLMProvider."""
|
||||
provider = create_provider(
|
||||
api_key="sk-ant-api03-regular-key",
|
||||
model="anthropic/claude-opus-4-5"
|
||||
)
|
||||
assert isinstance(provider, LiteLLMProvider)
|
||||
|
||||
|
||||
def test_create_provider_openrouter():
|
||||
"""OpenRouter keys should create LiteLLMProvider."""
|
||||
provider = create_provider(
|
||||
api_key="sk-or-v1-xxx",
|
||||
model="anthropic/claude-opus-4-5"
|
||||
)
|
||||
assert isinstance(provider, LiteLLMProvider)
|
||||
21
tests/test_registry_oauth.py
Normal file
21
tests/test_registry_oauth.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Test OAuth detection in provider registry."""
|
||||
import pytest
|
||||
from nanobot.providers.registry import should_use_oauth_provider
|
||||
|
||||
|
||||
def test_should_use_oauth_for_oat_token():
|
||||
"""OAuth provider should be used for sk-ant-oat tokens."""
|
||||
assert should_use_oauth_provider("sk-ant-oat01-xxx", "anthropic/claude-opus-4-5") is True
|
||||
assert should_use_oauth_provider("sk-ant-oat01-xxx", "claude-sonnet-4") is True
|
||||
|
||||
|
||||
def test_should_not_use_oauth_for_regular_key():
|
||||
"""Regular API keys should not use OAuth provider."""
|
||||
assert should_use_oauth_provider("sk-ant-api03-xxx", "claude-opus-4-5") is False
|
||||
assert should_use_oauth_provider("sk-or-v1-xxx", "anthropic/claude-opus-4-5") is False
|
||||
|
||||
|
||||
def test_should_not_use_oauth_for_non_anthropic():
|
||||
"""Non-Anthropic models should not use OAuth provider."""
|
||||
assert should_use_oauth_provider("sk-ant-oat01-xxx", "gpt-4") is False
|
||||
assert should_use_oauth_provider("sk-ant-oat01-xxx", "deepseek/deepseek-chat") is False
|
||||
Reference in New Issue
Block a user