feat(llm): pluggable provider abstraction for entity refinement
Three providers cover the useful space while keeping the zero-API default: - `ollama` (default): local models via http://localhost:11434. Works fully offline. Tag-matching check accepts both `model` and `model:latest` forms. - `openai-compat`: any /v1/chat/completions endpoint. Covers OpenRouter, LM Studio, llama.cpp server, vLLM, Groq, Together, Fireworks, and most self-hosted frameworks. API key falls back to $OPENAI_API_KEY. Endpoint normalization is forgiving about trailing `/v1`. - `anthropic`: Messages API v2023-06-01. API key falls back to $ANTHROPIC_API_KEY. Concatenates multi-block text responses. JSON mode is normalized across providers — Ollama uses `format: "json"`, OpenAI-compat uses `response_format`, Anthropic uses prompt-level instruction. Callers request JSON once; this module handles the provider-specific plumbing. No external SDK dependency; stdlib `urllib` throughout. HTTP errors are wrapped into a single `LLMError` class so callers don't need to distinguish transport, auth, and parse failures at the call site. 26 tests, all with mocked HTTP — suite runs offline with no real provider required.
This commit is contained in:
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
llm_client.py — Minimal provider abstraction for LLM-assisted entity refinement.
|
||||||
|
|
||||||
|
Three providers cover the useful space:
|
||||||
|
|
||||||
|
- ``ollama`` (default): local models via http://localhost:11434. Works fully
|
||||||
|
offline. Honors MemPalace's "zero-API required" principle.
|
||||||
|
- ``openai-compat``: any OpenAI-compatible ``/v1/chat/completions`` endpoint.
|
||||||
|
Covers OpenRouter, LM Studio, llama.cpp server, vLLM, Groq, Fireworks,
|
||||||
|
Together, and most self-hosted setups.
|
||||||
|
- ``anthropic``: the official Messages API. Opt-in for users who want Haiku
|
||||||
|
quality without setting up a local model.
|
||||||
|
|
||||||
|
All providers expose the same ``classify(system, user, json_mode)`` method and
|
||||||
|
the same ``check_available()`` probe. No external SDK dependencies — stdlib
|
||||||
|
``urllib`` only.
|
||||||
|
|
||||||
|
JSON mode matters here: we always ask for structured output. Providers
|
||||||
|
differ on how to request it (Ollama: ``format: json``; OpenAI-compat:
|
||||||
|
``response_format``; Anthropic: prompt-level instruction) and this module
|
||||||
|
normalizes that away from the caller.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.error import HTTPError, URLError
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
|
||||||
|
class LLMError(RuntimeError):
|
||||||
|
"""Raised for any provider failure — transport, parse, auth, missing model."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
text: str
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
raw: dict
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== BASE ====================
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProvider:
|
||||||
|
name: str = "base"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
endpoint: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: int = 120,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.api_key = api_key
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def check_available(self) -> tuple[bool, str]:
|
||||||
|
"""Return ``(ok, message)``. Fast probe that the provider is reachable."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def _http_post_json(url: str, body: dict, headers: dict, timeout: int) -> dict:
|
||||||
|
"""POST JSON and return the parsed response. Raises LLMError on any failure."""
|
||||||
|
req = Request(
|
||||||
|
url,
|
||||||
|
data=json.dumps(body).encode("utf-8"),
|
||||||
|
headers={"Content-Type": "application/json", **headers},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with urlopen(req, timeout=timeout) as resp:
|
||||||
|
return json.loads(resp.read())
|
||||||
|
except HTTPError as e:
|
||||||
|
detail = ""
|
||||||
|
try:
|
||||||
|
detail = e.read().decode("utf-8", errors="replace")[:500]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise LLMError(f"HTTP {e.code} from {url}: {detail or e.reason}") from e
|
||||||
|
except (URLError, OSError) as e:
|
||||||
|
raise LLMError(f"Cannot reach {url}: {e}") from e
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise LLMError(f"Malformed response from {url}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== OLLAMA ====================
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaProvider(LLMProvider):
|
||||||
|
name = "ollama"
|
||||||
|
DEFAULT_ENDPOINT = "http://localhost:11434"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
endpoint: Optional[str] = None,
|
||||||
|
timeout: int = 180,
|
||||||
|
**_: object,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
model=model,
|
||||||
|
endpoint=endpoint or self.DEFAULT_ENDPOINT,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_available(self) -> tuple[bool, str]:
|
||||||
|
try:
|
||||||
|
with urlopen(f"{self.endpoint}/api/tags", timeout=5) as resp:
|
||||||
|
data = json.loads(resp.read())
|
||||||
|
except (URLError, HTTPError, OSError, json.JSONDecodeError) as e:
|
||||||
|
return False, f"Cannot reach Ollama at {self.endpoint}: {e}"
|
||||||
|
names = {m.get("name", "") for m in data.get("models", []) or []}
|
||||||
|
# Ollama tags may or may not include ':latest' — accept either form
|
||||||
|
wanted = {self.model, f"{self.model}:latest"}
|
||||||
|
if not names & wanted:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
f"Model '{self.model}' not loaded in Ollama. " f"Run: ollama pull {self.model}",
|
||||||
|
)
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
|
||||||
|
body: dict = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": system},
|
||||||
|
{"role": "user", "content": user},
|
||||||
|
],
|
||||||
|
"stream": False,
|
||||||
|
"options": {"temperature": 0.1},
|
||||||
|
}
|
||||||
|
if json_mode:
|
||||||
|
body["format"] = "json"
|
||||||
|
data = _http_post_json(f"{self.endpoint}/api/chat", body, headers={}, timeout=self.timeout)
|
||||||
|
text = (data.get("message") or {}).get("content", "")
|
||||||
|
if not text:
|
||||||
|
raise LLMError(f"Empty response from Ollama (model={self.model})")
|
||||||
|
return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== OPENAI-COMPAT ====================
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatProvider(LLMProvider):
|
||||||
|
"""Any OpenAI-compatible ``/v1/chat/completions`` endpoint.
|
||||||
|
|
||||||
|
Supply ``--llm-endpoint http://host:port`` (with or without ``/v1``).
|
||||||
|
API key via ``--llm-api-key`` or the ``OPENAI_API_KEY`` env var.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "openai-compat"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
endpoint: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: int = 120,
|
||||||
|
**_: object,
|
||||||
|
):
|
||||||
|
resolved_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||||
|
super().__init__(model=model, endpoint=endpoint, api_key=resolved_key, timeout=timeout)
|
||||||
|
|
||||||
|
def _resolve_url(self) -> str:
|
||||||
|
if not self.endpoint:
|
||||||
|
raise LLMError("openai-compat provider requires --llm-endpoint")
|
||||||
|
url = self.endpoint.rstrip("/")
|
||||||
|
if url.endswith("/chat/completions"):
|
||||||
|
return url
|
||||||
|
if not url.endswith("/v1"):
|
||||||
|
url = f"{url}/v1"
|
||||||
|
return f"{url}/chat/completions"
|
||||||
|
|
||||||
|
def check_available(self) -> tuple[bool, str]:
|
||||||
|
if not self.endpoint:
|
||||||
|
return False, "no --llm-endpoint configured"
|
||||||
|
base = self.endpoint.rstrip("/")
|
||||||
|
base = base.removesuffix("/chat/completions").removesuffix("/v1")
|
||||||
|
try:
|
||||||
|
req = Request(f"{base}/v1/models")
|
||||||
|
if self.api_key:
|
||||||
|
req.add_header("Authorization", f"Bearer {self.api_key}")
|
||||||
|
with urlopen(req, timeout=5):
|
||||||
|
pass
|
||||||
|
except (URLError, HTTPError, OSError) as e:
|
||||||
|
return False, f"Cannot reach {self.endpoint}: {e}"
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
|
||||||
|
body: dict = {
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": system},
|
||||||
|
{"role": "user", "content": user},
|
||||||
|
],
|
||||||
|
"temperature": 0.1,
|
||||||
|
}
|
||||||
|
if json_mode:
|
||||||
|
body["response_format"] = {"type": "json_object"}
|
||||||
|
headers = {}
|
||||||
|
if self.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
data = _http_post_json(self._resolve_url(), body, headers=headers, timeout=self.timeout)
|
||||||
|
try:
|
||||||
|
text = data["choices"][0]["message"]["content"]
|
||||||
|
except (KeyError, IndexError, TypeError) as e:
|
||||||
|
raise LLMError(f"Unexpected response shape: {e}") from e
|
||||||
|
if not text:
|
||||||
|
raise LLMError(f"Empty response from {self.name} (model={self.model})")
|
||||||
|
return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== ANTHROPIC ====================
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(LLMProvider):
|
||||||
|
name = "anthropic"
|
||||||
|
DEFAULT_ENDPOINT = "https://api.anthropic.com"
|
||||||
|
API_VERSION = "2023-06-01"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
endpoint: Optional[str] = None,
|
||||||
|
timeout: int = 120,
|
||||||
|
**_: object,
|
||||||
|
):
|
||||||
|
key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
||||||
|
super().__init__(
|
||||||
|
model=model,
|
||||||
|
endpoint=endpoint or self.DEFAULT_ENDPOINT,
|
||||||
|
api_key=key,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_available(self) -> tuple[bool, str]:
|
||||||
|
if not self.api_key:
|
||||||
|
return False, "ANTHROPIC_API_KEY not set (use --llm-api-key or env)"
|
||||||
|
# Don't probe — a live request would cost money. First real call will
|
||||||
|
# surface auth errors if the key is invalid.
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
|
||||||
|
if not self.api_key:
|
||||||
|
raise LLMError("Anthropic provider requires ANTHROPIC_API_KEY env or --llm-api-key")
|
||||||
|
sys_prompt = system
|
||||||
|
if json_mode:
|
||||||
|
sys_prompt += "\n\nRespond with valid JSON only, no prose."
|
||||||
|
body = {
|
||||||
|
"model": self.model,
|
||||||
|
"max_tokens": 2048,
|
||||||
|
"temperature": 0.1,
|
||||||
|
"system": sys_prompt,
|
||||||
|
"messages": [{"role": "user", "content": user}],
|
||||||
|
}
|
||||||
|
headers = {
|
||||||
|
"X-API-Key": self.api_key,
|
||||||
|
"anthropic-version": self.API_VERSION,
|
||||||
|
}
|
||||||
|
data = _http_post_json(
|
||||||
|
f"{self.endpoint}/v1/messages", body, headers=headers, timeout=self.timeout
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
text = "".join(
|
||||||
|
b.get("text", "") for b in data.get("content", []) or [] if b.get("type") == "text"
|
||||||
|
)
|
||||||
|
except (AttributeError, TypeError) as e:
|
||||||
|
raise LLMError(f"Unexpected response shape: {e}") from e
|
||||||
|
if not text:
|
||||||
|
raise LLMError(f"Empty response from Anthropic (model={self.model})")
|
||||||
|
return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== FACTORY ====================
|
||||||
|
|
||||||
|
|
||||||
|
PROVIDERS: dict[str, type[LLMProvider]] = {
|
||||||
|
"ollama": OllamaProvider,
|
||||||
|
"openai-compat": OpenAICompatProvider,
|
||||||
|
"anthropic": AnthropicProvider,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
name: str,
|
||||||
|
model: str,
|
||||||
|
endpoint: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: int = 120,
|
||||||
|
) -> LLMProvider:
|
||||||
|
"""Build a provider by name. Raises LLMError on unknown provider."""
|
||||||
|
cls = PROVIDERS.get(name)
|
||||||
|
if cls is None:
|
||||||
|
raise LLMError(f"Unknown provider '{name}'. Choices: {sorted(PROVIDERS.keys())}")
|
||||||
|
return cls(model=model, endpoint=endpoint, api_key=api_key, timeout=timeout)
|
||||||
@@ -0,0 +1,327 @@
|
|||||||
|
"""Tests for mempalace.llm_client.
|
||||||
|
|
||||||
|
HTTP is mocked throughout — these tests do not require a running Ollama
|
||||||
|
or network access. Live-provider smoke tests live outside the unit-test
|
||||||
|
suite.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mempalace.llm_client import (
|
||||||
|
AnthropicProvider,
|
||||||
|
LLMError,
|
||||||
|
OllamaProvider,
|
||||||
|
OpenAICompatProvider,
|
||||||
|
_http_post_json,
|
||||||
|
get_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── factory ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_ollama():
|
||||||
|
p = get_provider("ollama", "gemma4:e4b")
|
||||||
|
assert isinstance(p, OllamaProvider)
|
||||||
|
assert p.model == "gemma4:e4b"
|
||||||
|
assert p.endpoint == OllamaProvider.DEFAULT_ENDPOINT
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_openai_compat():
|
||||||
|
p = get_provider("openai-compat", "foo", endpoint="http://localhost:1234")
|
||||||
|
assert isinstance(p, OpenAICompatProvider)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_anthropic():
|
||||||
|
p = get_provider("anthropic", "claude-haiku", api_key="sk-xxx")
|
||||||
|
assert isinstance(p, AnthropicProvider)
|
||||||
|
assert p.api_key == "sk-xxx"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_provider_unknown_raises():
|
||||||
|
with pytest.raises(LLMError, match="Unknown provider"):
|
||||||
|
get_provider("nonsense", "x")
|
||||||
|
|
||||||
|
|
||||||
|
# ── _http_post_json ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_success():
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.read.return_value = b'{"ok": true}'
|
||||||
|
mock_resp.__enter__.return_value = mock_resp
|
||||||
|
mock_resp.__exit__.return_value = False
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=mock_resp):
|
||||||
|
result = _http_post_json("http://x/y", {"a": 1}, {}, timeout=5)
|
||||||
|
assert result == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_http_error_wraps_as_llm_error():
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
import io
|
||||||
|
|
||||||
|
err = HTTPError("http://x", 404, "Not Found", {}, io.BytesIO(b"model missing"))
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=err):
|
||||||
|
with pytest.raises(LLMError, match="HTTP 404"):
|
||||||
|
_http_post_json("http://x", {}, {}, timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_url_error_wraps_as_llm_error():
|
||||||
|
from urllib.error import URLError
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=URLError("conn refused")):
|
||||||
|
with pytest.raises(LLMError, match="Cannot reach"):
|
||||||
|
_http_post_json("http://x", {}, {}, timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_post_json_malformed_response():
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.read.return_value = b"not json"
|
||||||
|
mock_resp.__enter__.return_value = mock_resp
|
||||||
|
mock_resp.__exit__.return_value = False
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=mock_resp):
|
||||||
|
with pytest.raises(LLMError, match="Malformed"):
|
||||||
|
_http_post_json("http://x", {}, {}, timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OllamaProvider ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_ollama_chat_response(content: str):
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.read.return_value = json.dumps({"message": {"content": content}}).encode()
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_finds_model():
|
||||||
|
tags = {"models": [{"name": "gemma4:e4b"}, {"name": "other:latest"}]}
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.read.return_value = json.dumps(tags).encode()
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=mock):
|
||||||
|
p = OllamaProvider(model="gemma4:e4b")
|
||||||
|
ok, msg = p.check_available()
|
||||||
|
assert ok
|
||||||
|
assert msg == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_accepts_latest_suffix():
|
||||||
|
tags = {"models": [{"name": "mymodel:latest"}]}
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.read.return_value = json.dumps(tags).encode()
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=mock):
|
||||||
|
p = OllamaProvider(model="mymodel")
|
||||||
|
ok, _ = p.check_available()
|
||||||
|
assert ok
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_missing_model():
|
||||||
|
tags = {"models": [{"name": "other:latest"}]}
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.read.return_value = json.dumps(tags).encode()
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=mock):
|
||||||
|
p = OllamaProvider(model="absent")
|
||||||
|
ok, msg = p.check_available()
|
||||||
|
assert not ok
|
||||||
|
assert "ollama pull absent" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_check_available_unreachable():
|
||||||
|
from urllib.error import URLError
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=URLError("refused")):
|
||||||
|
p = OllamaProvider(model="gemma4:e4b")
|
||||||
|
ok, msg = p.check_available()
|
||||||
|
assert not ok
|
||||||
|
assert "Cannot reach Ollama" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_classify_sends_json_format():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_urlopen(req, *, timeout):
|
||||||
|
captured["url"] = req.full_url
|
||||||
|
captured["body"] = json.loads(req.data.decode())
|
||||||
|
return _mock_ollama_chat_response('{"classifications": []}')
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
|
||||||
|
p = OllamaProvider(model="gemma4:e4b")
|
||||||
|
resp = p.classify("sys", "user", json_mode=True)
|
||||||
|
|
||||||
|
assert captured["body"]["format"] == "json"
|
||||||
|
assert captured["body"]["model"] == "gemma4:e4b"
|
||||||
|
assert captured["url"].endswith("/api/chat")
|
||||||
|
assert resp.provider == "ollama"
|
||||||
|
assert resp.text == '{"classifications": []}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_ollama_classify_empty_content_raises():
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=_mock_ollama_chat_response("")):
|
||||||
|
p = OllamaProvider(model="x")
|
||||||
|
with pytest.raises(LLMError, match="Empty response"):
|
||||||
|
p.classify("s", "u")
|
||||||
|
|
||||||
|
|
||||||
|
# ── OpenAICompatProvider ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_openai_response(content: str):
|
||||||
|
mock = MagicMock()
|
||||||
|
payload = {"choices": [{"message": {"content": content}}]}
|
||||||
|
mock.read.return_value = json.dumps(payload).encode()
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_resolves_url_with_v1_suffix():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_urlopen(req, *, timeout):
|
||||||
|
captured["url"] = req.full_url
|
||||||
|
return _mock_openai_response('{"ok": true}')
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
|
||||||
|
p = OpenAICompatProvider(model="x", endpoint="http://h:1234")
|
||||||
|
p.classify("s", "u")
|
||||||
|
assert captured["url"] == "http://h:1234/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_resolves_url_with_existing_v1():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_urlopen(req, *, timeout):
|
||||||
|
captured["url"] = req.full_url
|
||||||
|
return _mock_openai_response('{"ok": true}')
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
|
||||||
|
p = OpenAICompatProvider(model="x", endpoint="http://h:1234/v1")
|
||||||
|
p.classify("s", "u")
|
||||||
|
assert captured["url"] == "http://h:1234/v1/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_requires_endpoint():
|
||||||
|
p = OpenAICompatProvider(model="x")
|
||||||
|
with pytest.raises(LLMError, match="requires --llm-endpoint"):
|
||||||
|
p.classify("s", "u")
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_sends_authorization_when_key_present():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_urlopen(req, *, timeout):
|
||||||
|
captured["auth"] = req.get_header("Authorization")
|
||||||
|
return _mock_openai_response('{"ok": true}')
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
|
||||||
|
p = OpenAICompatProvider(model="x", endpoint="http://h", api_key="sk-aaa")
|
||||||
|
p.classify("s", "u")
|
||||||
|
assert captured["auth"] == "Bearer sk-aaa"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_uses_env_var_fallback(monkeypatch):
|
||||||
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env")
|
||||||
|
p = OpenAICompatProvider(model="x", endpoint="http://h")
|
||||||
|
assert p.api_key == "sk-from-env"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_sends_response_format_json():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_urlopen(req, *, timeout):
|
||||||
|
captured["body"] = json.loads(req.data.decode())
|
||||||
|
return _mock_openai_response('{"ok": true}')
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
|
||||||
|
p = OpenAICompatProvider(model="x", endpoint="http://h")
|
||||||
|
p.classify("s", "u", json_mode=True)
|
||||||
|
assert captured["body"]["response_format"] == {"type": "json_object"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_unexpected_shape_raises():
|
||||||
|
mock = MagicMock()
|
||||||
|
mock.read.return_value = b'{"nothing": "here"}'
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=mock):
|
||||||
|
p = OpenAICompatProvider(model="x", endpoint="http://h")
|
||||||
|
with pytest.raises(LLMError, match="Unexpected response shape"):
|
||||||
|
p.classify("s", "u")
|
||||||
|
|
||||||
|
|
||||||
|
# ── AnthropicProvider ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_anthropic_response(text: str):
|
||||||
|
mock = MagicMock()
|
||||||
|
payload = {"content": [{"type": "text", "text": text}]}
|
||||||
|
mock.read.return_value = json.dumps(payload).encode()
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_requires_api_key(monkeypatch):
|
||||||
|
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||||
|
p = AnthropicProvider(model="claude-haiku")
|
||||||
|
ok, msg = p.check_available()
|
||||||
|
assert not ok
|
||||||
|
assert "ANTHROPIC_API_KEY" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_reads_env_key(monkeypatch):
|
||||||
|
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-env")
|
||||||
|
p = AnthropicProvider(model="claude-haiku")
|
||||||
|
assert p.api_key == "sk-ant-env"
|
||||||
|
ok, _ = p.check_available()
|
||||||
|
assert ok
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_classify_sends_version_and_key():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_urlopen(req, *, timeout):
|
||||||
|
captured["api_key"] = req.get_header("X-api-key")
|
||||||
|
captured["version"] = req.get_header("Anthropic-version")
|
||||||
|
return _mock_anthropic_response('{"ok": true}')
|
||||||
|
|
||||||
|
with patch("mempalace.llm_client.urlopen", side_effect=fake_urlopen):
|
||||||
|
p = AnthropicProvider(model="claude-haiku", api_key="sk-ant-abc")
|
||||||
|
resp = p.classify("s", "u")
|
||||||
|
assert captured["api_key"] == "sk-ant-abc"
|
||||||
|
assert captured["version"] == AnthropicProvider.API_VERSION
|
||||||
|
assert resp.text == '{"ok": true}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_joins_multiple_text_blocks():
|
||||||
|
mock = MagicMock()
|
||||||
|
payload = {
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "part one. "},
|
||||||
|
{"type": "text", "text": "part two."},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock.read.return_value = json.dumps(payload).encode()
|
||||||
|
mock.__enter__.return_value = mock
|
||||||
|
mock.__exit__.return_value = False
|
||||||
|
with patch("mempalace.llm_client.urlopen", return_value=mock):
|
||||||
|
p = AnthropicProvider(model="claude-haiku", api_key="sk-ant")
|
||||||
|
resp = p.classify("s", "u")
|
||||||
|
assert resp.text == "part one. part two."
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_no_key_raises_on_classify(monkeypatch):
|
||||||
|
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||||
|
p = AnthropicProvider(model="claude-haiku")
|
||||||
|
with pytest.raises(LLMError, match="requires ANTHROPIC_API_KEY"):
|
||||||
|
p.classify("s", "u")
|
||||||
Reference in New Issue
Block a user