"""
llm_client.py — Minimal provider abstraction for LLM-assisted entity refinement.

Three providers cover the useful space:

- ``ollama`` (default): local models via http://localhost:11434. Works fully
  offline. Honors MemPalace's "zero-API required" principle.
- ``openai-compat``: any OpenAI-compatible ``/v1/chat/completions`` endpoint.
  Covers OpenRouter, LM Studio, llama.cpp server, vLLM, Groq, Fireworks,
  Together, and most self-hosted setups.
- ``anthropic``: the official Messages API. Opt-in for users who want Haiku
  quality without setting up a local model.

All providers expose the same ``classify(system, user, json_mode)`` method and
the same ``check_available()`` probe. No external SDK dependencies — stdlib
``urllib`` only.

JSON mode matters here: we always ask for structured output. Providers
differ on how to request it (Ollama: ``format: json``; OpenAI-compat:
``response_format``; Anthropic: prompt-level instruction) and this module
normalizes that away from the caller.
"""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Optional
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen


class LLMError(RuntimeError):
    """Raised for any provider failure — transport, parse, auth, missing model."""


@dataclass
class LLMResponse:
    """Provider-independent result of one ``classify`` call."""

    text: str  # assistant text extracted from the provider payload
    model: str  # model name the request was sent with
    provider: str  # ``name`` of the provider class that produced it
    raw: dict  # full decoded JSON payload, kept for debugging


# ==================== BASE ====================


class LLMProvider:
    """Common constructor and interface shared by every provider.

    Subclasses override ``classify`` and ``check_available``; the base
    implementations only raise ``NotImplementedError``.
    """

    name: str = "base"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
    ):
        self.model = model
        self.endpoint = endpoint
        self.api_key = api_key
        self.timeout = timeout

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Send one system + user prompt pair and return the model's text."""
        raise NotImplementedError

    def check_available(self) -> tuple[bool, str]:
        """Return ``(ok, message)``. Fast probe that the provider is reachable."""
        raise NotImplementedError


def _http_post_json(url: str, body: dict, headers: dict, timeout: int) -> dict:
    """POST ``body`` as JSON and return the decoded JSON object.

    Args:
        url: Full request URL.
        body: JSON-serialisable request payload.
        headers: Extra request headers; ``Content-Type`` is always set to
            ``application/json`` unless overridden here.
        timeout: Timeout in seconds, passed straight to ``urlopen``.

    Returns:
        The response body decoded as a JSON object.

    Raises:
        LLMError: on an HTTP error status, a transport failure, or a response
            body that is not a JSON object.
    """
    req = Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json", **headers},
    )
    try:
        with urlopen(req, timeout=timeout) as resp:
            payload = json.loads(resp.read())
    except HTTPError as e:
        # HTTPError must be handled before URLError (it is a subclass).
        detail = ""
        try:
            detail = e.read().decode("utf-8", errors="replace")[:500]
        except Exception:
            # Best effort only: the error body is a nicety, e.reason is the fallback.
            pass
        raise LLMError(f"HTTP {e.code} from {url}: {detail or e.reason}") from e
    except (URLError, OSError) as e:
        raise LLMError(f"Cannot reach {url}: {e}") from e
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # json.loads() on bytes decodes first; undecodable bytes raise
        # UnicodeDecodeError, which is *not* a JSONDecodeError.
        raise LLMError(f"Malformed response from {url}: {e}") from e
    if not isinstance(payload, dict):
        # Valid JSON but not an object (list, string, number, null). Callers
        # index the result like a dict, so reject it here with a clear error.
        raise LLMError(
            f"Malformed response from {url}: expected a JSON object, "
            f"got {type(payload).__name__}"
        )
    return payload


# ==================== OLLAMA ====================


class OllamaProvider(LLMProvider):
    """Local Ollama server, spoken to through ``/api/tags`` and ``/api/chat``."""

    name = "ollama"
    DEFAULT_ENDPOINT = "http://localhost:11434"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        timeout: int = 180,
        **_: object,
    ):
        # Extra keyword arguments (e.g. ``api_key`` from the factory) are
        # accepted and ignored.
        super().__init__(
            model=model,
            endpoint=endpoint or self.DEFAULT_ENDPOINT,
            timeout=timeout,
        )

    def _url(self, path: str) -> str:
        """Join ``path`` onto the endpoint, tolerating a trailing slash."""
        return f"{self.endpoint.rstrip('/')}{path}"

    def check_available(self) -> tuple[bool, str]:
        """Return ``(ok, message)``: server reachable and the model is listed."""
        try:
            with urlopen(self._url("/api/tags"), timeout=5) as resp:
                data = json.loads(resp.read())
        except (URLError, HTTPError, OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
            return False, f"Cannot reach Ollama at {self.endpoint}: {e}"
        # Be defensive about the payload shape: a probe should report a
        # problem, never raise.
        models = data.get("models") if isinstance(data, dict) else None
        if not isinstance(models, list):
            models = []
        names = {m.get("name", "") for m in models if isinstance(m, dict)}
        # Ollama tags may or may not include ':latest' — accept either form
        wanted = {self.model, f"{self.model}:latest"}
        if not names & wanted:
            return (
                False,
                f"Model '{self.model}' not loaded in Ollama. "
                f"Run: ollama pull {self.model}",
            )
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Run one non-streaming chat request against ``/api/chat``.

        Raises:
            LLMError: on any HTTP/parse failure or when the reply has no text.
        """
        body: dict = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "stream": False,
            "options": {"temperature": 0.1},
        }
        if json_mode:
            body["format"] = "json"
        data = _http_post_json(self._url("/api/chat"), body, headers={}, timeout=self.timeout)
        message = data.get("message")
        # A missing or non-object "message" is treated like an empty reply.
        text = message.get("content", "") if isinstance(message, dict) else ""
        if not text:
            raise LLMError(f"Empty response from Ollama (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)


# ==================== OPENAI-COMPAT ====================


class OpenAICompatProvider(LLMProvider):
    """Any OpenAI-compatible ``/v1/chat/completions`` endpoint.

    Supply ``--llm-endpoint http://host:port`` (with or without ``/v1``).
    API key via ``--llm-api-key`` or the ``OPENAI_API_KEY`` env var.
    """

    name = "openai-compat"

    def __init__(
        self,
        model: str,
        endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
        **_: object,
    ):
        resolved_key = api_key or os.environ.get("OPENAI_API_KEY")
        super().__init__(model=model, endpoint=endpoint, api_key=resolved_key, timeout=timeout)

    def _api_root(self) -> str:
        """Return the API root that ``/chat/completions`` and ``/models`` hang off.

        Accepts the endpoint as a bare host, a ``.../v1`` root, or a full
        ``.../chat/completions`` URL.

        Raises:
            LLMError: when no endpoint was configured.
        """
        if not self.endpoint:
            raise LLMError("openai-compat provider requires --llm-endpoint")
        url = self.endpoint.rstrip("/")
        if url.endswith("/chat/completions"):
            # Full URL supplied: trust whatever root precedes it.
            return url.removesuffix("/chat/completions")
        if not url.endswith("/v1"):
            url = f"{url}/v1"
        return url

    def _resolve_url(self) -> str:
        """Return the chat-completions URL used by ``classify``."""
        return f"{self._api_root()}/chat/completions"

    def _models_url(self) -> str:
        """Return the model-listing URL probed by ``check_available``."""
        return f"{self._api_root()}/models"

    def check_available(self) -> tuple[bool, str]:
        """Return ``(ok, message)`` after a GET on the models listing.

        The probe uses the same API root as ``classify`` so the two cannot
        disagree about where the server lives.
        """
        if not self.endpoint:
            return False, "no --llm-endpoint configured"
        try:
            req = Request(self._models_url())
            if self.api_key:
                req.add_header("Authorization", f"Bearer {self.api_key}")
            with urlopen(req, timeout=5):
                pass
        except (URLError, HTTPError, OSError) as e:
            return False, f"Cannot reach {self.endpoint}: {e}"
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Run one chat-completions request and return the first choice's text.

        Raises:
            LLMError: on any HTTP/parse failure, an unexpected payload shape,
                or an empty reply.
        """
        body: dict = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "temperature": 0.1,
        }
        if json_mode:
            body["response_format"] = {"type": "json_object"}
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        data = _http_post_json(self._resolve_url(), body, headers=headers, timeout=self.timeout)
        try:
            text = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as e:
            raise LLMError(f"Unexpected response shape: {e}") from e
        if not text:
            raise LLMError(f"Empty response from {self.name} (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)
# ==================== ANTHROPIC ====================


class AnthropicProvider(LLMProvider):
    """Anthropic Messages API (``/v1/messages``).

    The API key comes from the ``api_key`` argument or, failing that, the
    ``ANTHROPIC_API_KEY`` environment variable.
    """

    name = "anthropic"
    DEFAULT_ENDPOINT = "https://api.anthropic.com"
    API_VERSION = "2023-06-01"
    # Upper bound on generated tokens sent with every request.
    MAX_TOKENS = 2048

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        timeout: int = 120,
        **_: object,
    ):
        key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        super().__init__(
            model=model,
            endpoint=endpoint or self.DEFAULT_ENDPOINT,
            api_key=key,
            timeout=timeout,
        )

    def check_available(self) -> tuple[bool, str]:
        """Return ``(ok, message)`` based solely on whether a key is configured."""
        if not self.api_key:
            return False, "ANTHROPIC_API_KEY not set (use --llm-api-key or env)"
        # Don't probe — a live request would cost money. First real call will
        # surface auth errors if the key is invalid.
        return True, "ok"

    def classify(self, system: str, user: str, json_mode: bool = True) -> LLMResponse:
        """Run one Messages request and return the concatenated text blocks.

        JSON mode is requested by appending an instruction to the system
        prompt; no request-level JSON switch is sent.

        Raises:
            LLMError: when no API key is configured, on any HTTP/parse
                failure, on an unexpected payload shape, or on an empty reply.
        """
        if not self.api_key:
            raise LLMError("Anthropic provider requires ANTHROPIC_API_KEY env or --llm-api-key")
        sys_prompt = system
        if json_mode:
            sys_prompt += "\n\nRespond with valid JSON only, no prose."
        body = {
            "model": self.model,
            "max_tokens": self.MAX_TOKENS,
            "temperature": 0.1,
            "system": sys_prompt,
            "messages": [{"role": "user", "content": user}],
        }
        headers = {
            "X-API-Key": self.api_key,
            "anthropic-version": self.API_VERSION,
        }
        # Tolerate a trailing slash on a custom endpoint so the path never
        # becomes "//v1/messages".
        url = f"{self.endpoint.rstrip('/')}/v1/messages"
        data = _http_post_json(url, body, headers=headers, timeout=self.timeout)
        try:
            # Only "text" blocks carry answer text; other block types are skipped.
            text = "".join(
                b.get("text", "") for b in data.get("content", []) or [] if b.get("type") == "text"
            )
        except (AttributeError, TypeError) as e:
            raise LLMError(f"Unexpected response shape: {e}") from e
        if not text:
            raise LLMError(f"Empty response from Anthropic (model={self.model})")
        return LLMResponse(text=text, model=self.model, provider=self.name, raw=data)


# ==================== FACTORY ====================


PROVIDERS: dict[str, type[LLMProvider]] = {
    "ollama": OllamaProvider,
    "openai-compat": OpenAICompatProvider,
    "anthropic": AnthropicProvider,
}


def get_provider(
    name: str,
    model: str,
    endpoint: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: Optional[int] = None,
) -> LLMProvider:
    """Build a provider by name. Raises LLMError on unknown provider.

    Args:
        name: One of the keys of ``PROVIDERS``.
        model: Model identifier passed through to the provider.
        endpoint: Optional base URL; each provider applies its own default.
        api_key: Optional API key; providers that need one fall back to
            their environment variable.
        timeout: Request timeout in seconds. ``None`` (the default) keeps the
            selected provider's own default instead of forcing one value on
            all of them.
    """
    cls = PROVIDERS.get(name)
    if cls is None:
        raise LLMError(f"Unknown provider '{name}'. Choices: {sorted(PROVIDERS.keys())}")
    kwargs: dict = {"model": model, "endpoint": endpoint, "api_key": api_key}
    if timeout is not None:
        # Only override the provider's default when the caller asked for it.
        kwargs["timeout"] = timeout
    return cls(**kwargs)
"""Tests for mempalace.llm_client.

HTTP is mocked throughout — these tests do not require a running Ollama
or network access. Live-provider smoke tests live outside the unit-test
suite.
"""

import io
import json
from unittest.mock import MagicMock, patch
from urllib.error import HTTPError, URLError

import pytest

from mempalace.llm_client import (
    AnthropicProvider,
    LLMError,
    OllamaProvider,
    OpenAICompatProvider,
    _http_post_json,
    get_provider,
)

# Patch target shared by every test: the name as looked up inside the module.
URLOPEN = "mempalace.llm_client.urlopen"


# ── shared fakes ────────────────────────────────────────────────────────


def _fake_response(raw: bytes) -> MagicMock:
    """Build a context-manager stand-in for an HTTP response returning ``raw``."""
    resp = MagicMock()
    resp.read.return_value = raw
    resp.__enter__.return_value = resp
    resp.__exit__.return_value = False
    return resp


def _json_response(payload) -> MagicMock:
    """Same as ``_fake_response`` but serialises ``payload`` to JSON first."""
    return _fake_response(json.dumps(payload).encode())


def _recording_urlopen(reply):
    """Return ``(fake_urlopen, requests)``; every request seen is appended."""
    requests = []

    def fake_urlopen(req, *, timeout):
        requests.append(req)
        return reply

    return fake_urlopen, requests


# ── factory ─────────────────────────────────────────────────────────────


def test_get_provider_ollama():
    provider = get_provider("ollama", "gemma4:e4b")
    assert isinstance(provider, OllamaProvider)
    assert provider.model == "gemma4:e4b"
    assert provider.endpoint == OllamaProvider.DEFAULT_ENDPOINT


def test_get_provider_openai_compat():
    provider = get_provider("openai-compat", "foo", endpoint="http://localhost:1234")
    assert isinstance(provider, OpenAICompatProvider)


def test_get_provider_anthropic():
    provider = get_provider("anthropic", "claude-haiku", api_key="sk-xxx")
    assert isinstance(provider, AnthropicProvider)
    assert provider.api_key == "sk-xxx"


def test_get_provider_unknown_raises():
    with pytest.raises(LLMError, match="Unknown provider"):
        get_provider("nonsense", "x")


# ── _http_post_json ─────────────────────────────────────────────────────


def test_http_post_json_success():
    with patch(URLOPEN, return_value=_fake_response(b'{"ok": true}')):
        result = _http_post_json("http://x/y", {"a": 1}, {}, timeout=5)
    assert result == {"ok": True}


def test_http_post_json_http_error_wraps_as_llm_error():
    err = HTTPError("http://x", 404, "Not Found", {}, io.BytesIO(b"model missing"))
    with patch(URLOPEN, side_effect=err), pytest.raises(LLMError, match="HTTP 404"):
        _http_post_json("http://x", {}, {}, timeout=5)


def test_http_post_json_url_error_wraps_as_llm_error():
    with patch(URLOPEN, side_effect=URLError("conn refused")), pytest.raises(
        LLMError, match="Cannot reach"
    ):
        _http_post_json("http://x", {}, {}, timeout=5)


def test_http_post_json_malformed_response():
    with patch(URLOPEN, return_value=_fake_response(b"not json")), pytest.raises(
        LLMError, match="Malformed"
    ):
        _http_post_json("http://x", {}, {}, timeout=5)


# ── OllamaProvider ──────────────────────────────────────────────────────


def _mock_ollama_chat_response(content: str):
    return _json_response({"message": {"content": content}})


def test_ollama_check_available_finds_model():
    tags = {"models": [{"name": "gemma4:e4b"}, {"name": "other:latest"}]}
    with patch(URLOPEN, return_value=_json_response(tags)):
        ok, msg = OllamaProvider(model="gemma4:e4b").check_available()
    assert ok
    assert msg == "ok"


def test_ollama_check_available_accepts_latest_suffix():
    tags = {"models": [{"name": "mymodel:latest"}]}
    with patch(URLOPEN, return_value=_json_response(tags)):
        ok, _ = OllamaProvider(model="mymodel").check_available()
    assert ok


def test_ollama_check_available_missing_model():
    tags = {"models": [{"name": "other:latest"}]}
    with patch(URLOPEN, return_value=_json_response(tags)):
        ok, msg = OllamaProvider(model="absent").check_available()
    assert not ok
    assert "ollama pull absent" in msg


def test_ollama_check_available_unreachable():
    with patch(URLOPEN, side_effect=URLError("refused")):
        ok, msg = OllamaProvider(model="gemma4:e4b").check_available()
    assert not ok
    assert "Cannot reach Ollama" in msg


def test_ollama_classify_sends_json_format():
    fake_urlopen, requests = _recording_urlopen(
        _mock_ollama_chat_response('{"classifications": []}')
    )
    with patch(URLOPEN, side_effect=fake_urlopen):
        resp = OllamaProvider(model="gemma4:e4b").classify("sys", "user", json_mode=True)

    sent = requests[0]
    sent_body = json.loads(sent.data.decode())
    assert sent_body["format"] == "json"
    assert sent_body["model"] == "gemma4:e4b"
    assert sent.full_url.endswith("/api/chat")
    assert resp.provider == "ollama"
    assert resp.text == '{"classifications": []}'


def test_ollama_classify_empty_content_raises():
    with patch(URLOPEN, return_value=_mock_ollama_chat_response("")):
        provider = OllamaProvider(model="x")
        with pytest.raises(LLMError, match="Empty response"):
            provider.classify("s", "u")


# ── OpenAICompatProvider ────────────────────────────────────────────────


def _mock_openai_response(content: str):
    return _json_response({"choices": [{"message": {"content": content}}]})


def test_openai_compat_resolves_url_with_v1_suffix():
    fake_urlopen, requests = _recording_urlopen(_mock_openai_response('{"ok": true}'))
    with patch(URLOPEN, side_effect=fake_urlopen):
        OpenAICompatProvider(model="x", endpoint="http://h:1234").classify("s", "u")
    assert requests[0].full_url == "http://h:1234/v1/chat/completions"


def test_openai_compat_resolves_url_with_existing_v1():
    fake_urlopen, requests = _recording_urlopen(_mock_openai_response('{"ok": true}'))
    with patch(URLOPEN, side_effect=fake_urlopen):
        OpenAICompatProvider(model="x", endpoint="http://h:1234/v1").classify("s", "u")
    assert requests[0].full_url == "http://h:1234/v1/chat/completions"


def test_openai_compat_requires_endpoint():
    provider = OpenAICompatProvider(model="x")
    with pytest.raises(LLMError, match="requires --llm-endpoint"):
        provider.classify("s", "u")


def test_openai_compat_sends_authorization_when_key_present():
    fake_urlopen, requests = _recording_urlopen(_mock_openai_response('{"ok": true}'))
    with patch(URLOPEN, side_effect=fake_urlopen):
        OpenAICompatProvider(model="x", endpoint="http://h", api_key="sk-aaa").classify("s", "u")
    assert requests[0].get_header("Authorization") == "Bearer sk-aaa"


def test_openai_compat_uses_env_var_fallback(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env")
    provider = OpenAICompatProvider(model="x", endpoint="http://h")
    assert provider.api_key == "sk-from-env"


def test_openai_compat_sends_response_format_json():
    fake_urlopen, requests = _recording_urlopen(_mock_openai_response('{"ok": true}'))
    with patch(URLOPEN, side_effect=fake_urlopen):
        OpenAICompatProvider(model="x", endpoint="http://h").classify("s", "u", json_mode=True)
    sent_body = json.loads(requests[0].data.decode())
    assert sent_body["response_format"] == {"type": "json_object"}


def test_openai_compat_unexpected_shape_raises():
    with patch(URLOPEN, return_value=_fake_response(b'{"nothing": "here"}')):
        provider = OpenAICompatProvider(model="x", endpoint="http://h")
        with pytest.raises(LLMError, match="Unexpected response shape"):
            provider.classify("s", "u")


# ── AnthropicProvider ───────────────────────────────────────────────────


def _mock_anthropic_response(text: str):
    return _json_response({"content": [{"type": "text", "text": text}]})


def test_anthropic_requires_api_key(monkeypatch):
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    ok, msg = AnthropicProvider(model="claude-haiku").check_available()
    assert not ok
    assert "ANTHROPIC_API_KEY" in msg


def test_anthropic_reads_env_key(monkeypatch):
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-env")
    provider = AnthropicProvider(model="claude-haiku")
    assert provider.api_key == "sk-ant-env"
    ok, _ = provider.check_available()
    assert ok


def test_anthropic_classify_sends_version_and_key():
    fake_urlopen, requests = _recording_urlopen(_mock_anthropic_response('{"ok": true}'))
    with patch(URLOPEN, side_effect=fake_urlopen):
        resp = AnthropicProvider(model="claude-haiku", api_key="sk-ant-abc").classify("s", "u")
    sent = requests[0]
    assert sent.get_header("X-api-key") == "sk-ant-abc"
    assert sent.get_header("Anthropic-version") == AnthropicProvider.API_VERSION
    assert resp.text == '{"ok": true}'


def test_anthropic_joins_multiple_text_blocks():
    payload = {
        "content": [
            {"type": "text", "text": "part one. "},
            {"type": "text", "text": "part two."},
        ]
    }
    with patch(URLOPEN, return_value=_json_response(payload)):
        resp = AnthropicProvider(model="claude-haiku", api_key="sk-ant").classify("s", "u")
    assert resp.text == "part one. part two."


def test_anthropic_no_key_raises_on_classify(monkeypatch):
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    provider = AnthropicProvider(model="claude-haiku")
    with pytest.raises(LLMError, match="requires ANTHROPIC_API_KEY"):
        provider.classify("s", "u")