diff --git a/mempalace/knowledge_graph.py b/mempalace/knowledge_graph.py index 4a4399a..235dc4c 100644 --- a/mempalace/knowledge_graph.py +++ b/mempalace/knowledge_graph.py @@ -99,9 +99,10 @@ class KnowledgeGraph: def close(self): """Close the database connection.""" - if self._connection is not None: - self._connection.close() - self._connection = None + with self._lock: + if self._connection is not None: + self._connection.close() + self._connection = None def _entity_id(self, name: str) -> str: return name.lower().replace(" ", "_").replace("'", "") diff --git a/tests/test_kg_thread_safety.py b/tests/test_kg_thread_safety.py new file mode 100644 index 0000000..2ffb087 --- /dev/null +++ b/tests/test_kg_thread_safety.py @@ -0,0 +1,13 @@ +"""TDD: KnowledgeGraph.close() must hold self._lock.""" + +import inspect +from mempalace.knowledge_graph import KnowledgeGraph + + +class TestKGCloseLock: + def test_close_holds_lock(self): + src = inspect.getsource(KnowledgeGraph.close) + assert "self._lock" in src, ( + "close() does not acquire self._lock. " + "Closing while a read/write is in progress can corrupt data." + )