Fix: set cosine distance metadata on all collection creation sites
ChromaDB's HNSW index defaults to L2 (Euclidean) distance, but
MemPalace scoring computes `1 - distance`, which requires cosine distance (range 0-2).
Add metadata={"hnsw:space": "cosine"} to the 4 production and 3 test
call sites that were missing it.
Closes #218
This commit is contained in:
+20
-7
@@ -27,7 +27,8 @@ def test_project_mining():
|
||||
os.makedirs(project_root / "backend")
|
||||
|
||||
write_file(
|
||||
project_root / "backend" / "app.py", "def main():\n print('hello world')\n" * 20
|
||||
project_root / "backend" / "app.py",
|
||||
"def main():\n print('hello world')\n" * 20,
|
||||
)
|
||||
with open(project_root / "mempalace.yaml", "w") as f:
|
||||
yaml.dump(
|
||||
@@ -59,7 +60,9 @@ def test_scan_project_respects_gitignore():
|
||||
write_file(project_root / ".gitignore", "ignored.py\ngenerated/\n")
|
||||
write_file(project_root / "src" / "app.py", "print('hello')\n" * 20)
|
||||
write_file(project_root / "ignored.py", "print('ignore me')\n" * 20)
|
||||
write_file(project_root / "generated" / "artifact.py", "print('artifact')\n" * 20)
|
||||
write_file(
|
||||
project_root / "generated" / "artifact.py", "print('artifact')\n" * 20
|
||||
)
|
||||
|
||||
assert scanned_files(project_root) == ["src/app.py"]
|
||||
finally:
|
||||
@@ -74,7 +77,9 @@ def test_scan_project_respects_nested_gitignore():
|
||||
write_file(project_root / ".gitignore", "*.log\n")
|
||||
write_file(project_root / "subrepo" / ".gitignore", "tasks/\n")
|
||||
write_file(project_root / "subrepo" / "src" / "main.py", "print('main')\n" * 20)
|
||||
write_file(project_root / "subrepo" / "tasks" / "task.py", "print('task')\n" * 20)
|
||||
write_file(
|
||||
project_root / "subrepo" / "tasks" / "task.py", "print('task')\n" * 20
|
||||
)
|
||||
write_file(project_root / "subrepo" / "debug.log", "debug\n" * 20)
|
||||
|
||||
assert scanned_files(project_root) == ["subrepo/src/main.py"]
|
||||
@@ -133,7 +138,9 @@ def test_scan_project_can_disable_gitignore():
|
||||
write_file(project_root / ".gitignore", "data/\n")
|
||||
write_file(project_root / "data" / "stuff.csv", "a,b,c\n" * 20)
|
||||
|
||||
assert scanned_files(project_root, respect_gitignore=False) == ["data/stuff.csv"]
|
||||
assert scanned_files(project_root, respect_gitignore=False) == [
|
||||
"data/stuff.csv"
|
||||
]
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
@@ -146,7 +153,9 @@ def test_scan_project_can_include_ignored_directory():
|
||||
write_file(project_root / ".gitignore", "docs/\n")
|
||||
write_file(project_root / "docs" / "guide.md", "# Guide\n" * 20)
|
||||
|
||||
assert scanned_files(project_root, include_ignored=["docs"]) == ["docs/guide.md"]
|
||||
assert scanned_files(project_root, include_ignored=["docs"]) == [
|
||||
"docs/guide.md"
|
||||
]
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
@@ -215,7 +224,9 @@ def test_file_already_mined_check_mtime():
|
||||
palace_path = os.path.join(tmpdir, "palace")
|
||||
os.makedirs(palace_path)
|
||||
client = chromadb.PersistentClient(path=palace_path)
|
||||
col = client.get_or_create_collection("mempalace_drawers")
|
||||
col = client.get_or_create_collection(
|
||||
"mempalace_drawers", metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
test_file = os.path.join(tmpdir, "test.txt")
|
||||
with open(test_file, "w") as f:
|
||||
@@ -269,7 +280,9 @@ def test_mine_dry_run_with_tiny_file_no_crash():
|
||||
project_root = Path(tmpdir).resolve()
|
||||
|
||||
# One normal file and one that falls below MIN_CHUNK_SIZE
|
||||
write_file(project_root / "good.py", "def main():\n print('hello world')\n" * 20)
|
||||
write_file(
|
||||
project_root / "good.py", "def main():\n print('hello world')\n" * 20
|
||||
)
|
||||
write_file(project_root / "tiny.txt", "x")
|
||||
|
||||
with open(project_root / "mempalace.yaml", "w") as f:
|
||||
|
||||
Reference in New Issue
Block a user