diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 4936a75..d05075d 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -4,12 +4,9 @@ import mempalace.embedding as embedding @pytest.fixture(autouse=True) -def clear_embedding_state(): - embedding._EF_CACHE.clear() - embedding._WARNED.clear() - yield - embedding._EF_CACHE.clear() - embedding._WARNED.clear() +def isolate_embedding_state(monkeypatch): + monkeypatch.setattr(embedding, "_EF_CACHE", {}) + monkeypatch.setattr(embedding, "_WARNED", set()) def test_auto_picks_cuda(monkeypatch):