Files
mempalace/tests/test_backends.py
T

81 lines
2.3 KiB
Python
Raw Normal View History

2026-04-11 19:16:49 -04:00
import chromadb
import pytest
from mempalace.backends.chroma import ChromaBackend, ChromaCollection
class _FakeCollection:
def __init__(self):
self.calls = []
def add(self, **kwargs):
self.calls.append(("add", kwargs))
def upsert(self, **kwargs):
self.calls.append(("upsert", kwargs))
def query(self, **kwargs):
self.calls.append(("query", kwargs))
return {"kind": "query"}
def get(self, **kwargs):
self.calls.append(("get", kwargs))
return {"kind": "get"}
def delete(self, **kwargs):
self.calls.append(("delete", kwargs))
def count(self):
self.calls.append(("count", {}))
return 7
def test_chroma_collection_delegates_methods():
fake = _FakeCollection()
collection = ChromaCollection(fake)
collection.add(documents=["d"], ids=["1"], metadatas=[{"wing": "w"}])
collection.upsert(documents=["u"], ids=["2"], metadatas=[{"room": "r"}])
assert collection.query(query_texts=["q"]) == {"kind": "query"}
assert collection.get(where={"wing": "w"}) == {"kind": "get"}
collection.delete(ids=["1"])
assert collection.count() == 7
assert fake.calls == [
("add", {"documents": ["d"], "ids": ["1"], "metadatas": [{"wing": "w"}]}),
("upsert", {"documents": ["u"], "ids": ["2"], "metadatas": [{"room": "r"}]}),
("query", {"query_texts": ["q"]}),
("get", {"where": {"wing": "w"}}),
("delete", {"ids": ["1"]}),
("count", {}),
]
def test_chroma_backend_create_false_raises_without_creating_directory(tmp_path):
palace_path = tmp_path / "missing-palace"
with pytest.raises(FileNotFoundError):
ChromaBackend().get_collection(
str(palace_path),
collection_name="mempalace_drawers",
create=False,
)
assert not palace_path.exists()
def test_chroma_backend_create_true_creates_directory_and_collection(tmp_path):
palace_path = tmp_path / "palace"
collection = ChromaBackend().get_collection(
str(palace_path),
collection_name="mempalace_drawers",
create=True,
)
assert palace_path.is_dir()
assert isinstance(collection, ChromaCollection)
client = chromadb.PersistentClient(path=str(palace_path))
client.get_collection("mempalace_drawers")