# CmdForge/tests/test_embeddings.py
#
# 206 lines
# 6.2 KiB
# Python

"""Unit tests for semantic search embeddings module."""
import json
import sqlite3
import pytest
from cmdforge.registry.embeddings import (
build_embed_text,
cosine_similarity,
pack_embedding,
store_embedding,
unpack_embedding,
)
# ---------------------------------------------------------------------------
# pack / unpack
# ---------------------------------------------------------------------------
def test_pack_unpack_roundtrip():
    """Packing and then unpacking a vector yields the same values."""
    source_vec = [0.1, 0.2, 0.3, -0.5, 1.0, 0.0]
    restored = unpack_embedding(pack_embedding(source_vec))
    assert restored is not None
    assert len(restored) == len(source_vec)
    # Compare element-wise with a small tolerance instead of exact equality.
    assert all(abs(want - got) < 1e-6 for want, got in zip(source_vec, restored))
def test_unpack_empty_blob():
    """Unpacking a zero-length blob yields None."""
    empty_blob = b""
    decoded = unpack_embedding(empty_blob)
    assert decoded is None
def test_unpack_malformed_blob():
    """A blob whose length is not a multiple of 4 unpacks to None."""
    three_bytes = b"\x00\x01\x02"
    decoded = unpack_embedding(three_bytes)
    assert decoded is None
# ---------------------------------------------------------------------------
# cosine_similarity
# ---------------------------------------------------------------------------
def test_cosine_similarity_identical():
    """A vector compared with itself has similarity 1.0."""
    vector = [1.0, 2.0, 3.0]
    score = cosine_similarity(vector, vector)
    deviation = abs(score - 1.0)
    assert deviation < 1e-6
def test_cosine_similarity_orthogonal():
    """Perpendicular vectors have similarity 0.0."""
    x_axis = [1.0, 0.0, 0.0]
    y_axis = [0.0, 1.0, 0.0]
    score = cosine_similarity(x_axis, y_axis)
    assert abs(score) < 1e-6
def test_cosine_similarity_opposite():
    """Vectors pointing in opposite directions have similarity -1.0."""
    forward = [1.0, 0.0]
    backward = [-1.0, 0.0]
    score = cosine_similarity(forward, backward)
    # Distance from the expected value of -1.0.
    assert abs(score + 1.0) < 1e-6
def test_cosine_similarity_dimension_mismatch():
    """Vectors of unequal length compare as exactly 0.0."""
    two_dims = [1.0, 2.0]
    three_dims = [1.0, 2.0, 3.0]
    score = cosine_similarity(two_dims, three_dims)
    assert score == 0.0
def test_cosine_similarity_zero_vector():
    """An all-zero vector compares as exactly 0.0."""
    zeros = [0.0, 0.0, 0.0]
    other = [1.0, 2.0, 3.0]
    score = cosine_similarity(zeros, other)
    assert score == 0.0
# ---------------------------------------------------------------------------
# build_embed_text
# ---------------------------------------------------------------------------
def test_build_embed_text():
    """Name, description and tags are combined into one embed string."""
    expected = "my-tool | Analyzes CSV files | csv, data"
    actual = build_embed_text("my-tool", "Analyzes CSV files", ["csv", "data"])
    assert actual == expected
def test_build_embed_text_no_tags():
    """Passing None for tags leaves only name and description."""
    expected = "my-tool | A description"
    actual = build_embed_text("my-tool", "A description", None)
    assert actual == expected
def test_build_embed_text_empty_tags():
    """An empty tags list leaves only name and description."""
    expected = "my-tool | A description"
    actual = build_embed_text("my-tool", "A description", [])
    assert actual == expected
def test_build_embed_text_all_empty():
    """When every field is None or empty the result is None."""
    blank_inputs = (
        (None, None, None),
        ("", "", []),
    )
    for name, description, tags in blank_inputs:
        assert build_embed_text(name, description, tags) is None
def test_build_embed_text_only_name():
    """With only a name supplied, the result is the name alone."""
    actual = build_embed_text("my-tool", None, None)
    assert actual == "my-tool"
# ---------------------------------------------------------------------------
# store and retrieve (in-memory SQLite)
# ---------------------------------------------------------------------------
def _make_db():
"""Create an in-memory SQLite DB with the embeddings table."""
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys=OFF") # No FK to tools table in test
conn.executescript("""
CREATE TABLE tools (
id INTEGER PRIMARY KEY,
owner TEXT, name TEXT, version TEXT,
description TEXT, category TEXT, tags TEXT,
downloads INTEGER DEFAULT 0,
visibility TEXT DEFAULT 'public',
moderation_status TEXT DEFAULT 'approved'
);
CREATE TABLE tool_embeddings (
tool_id INTEGER PRIMARY KEY,
embedding BLOB NOT NULL,
dimensions INTEGER NOT NULL,
model TEXT NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""")
return conn
def test_store_and_retrieve():
    """Store an embedding and verify the persisted row.

    Checks the ``dimensions`` and ``model`` columns of the row written by
    ``store_embedding`` and that the stored blob unpacks back to the
    original vector within a tolerance of 1e-6.
    """
    conn = _make_db()
    try:
        conn.execute(
            "INSERT INTO tools (id, owner, name, version) VALUES (1, 'test', 'my-tool', '1.0.0')"
        )
        conn.commit()
        vec = [0.1, 0.2, 0.3, 0.4, 0.5]
        store_embedding(conn, 1, vec, "nomic-embed-text")
        row = conn.execute("SELECT * FROM tool_embeddings WHERE tool_id = 1").fetchone()
        assert row is not None
        assert row["dimensions"] == 5
        assert row["model"] == "nomic-embed-text"
        recovered = unpack_embedding(row["embedding"])
        assert recovered is not None
        assert len(recovered) == 5
        for a, b in zip(vec, recovered):
            assert abs(a - b) < 1e-6
    finally:
        # Close explicitly so the connection is released even when an
        # assertion fails; unclosed connections emit ResourceWarning on
        # recent Python versions.
        conn.close()
def test_store_upsert():
    """Storing twice for the same tool replaces the earlier embedding.

    After the second ``store_embedding`` call the row must carry the new
    dimensions, the new model name, and the new vector itself.
    """
    conn = _make_db()
    try:
        conn.execute(
            "INSERT INTO tools (id, owner, name, version) VALUES (1, 'test', 'my-tool', '1.0.0')"
        )
        conn.commit()
        replacement = [3.0, 4.0, 5.0]
        store_embedding(conn, 1, [1.0, 2.0], "model-a")
        store_embedding(conn, 1, replacement, "model-b")
        row = conn.execute("SELECT * FROM tool_embeddings WHERE tool_id = 1").fetchone()
        # Fail with a clear assertion instead of a TypeError on subscripting None.
        assert row is not None
        assert row["dimensions"] == 3
        assert row["model"] == "model-b"
        # The metadata alone does not prove the blob was replaced, so
        # verify the stored vector is the second one.
        recovered = unpack_embedding(row["embedding"])
        assert recovered is not None
        assert len(recovered) == len(replacement)
        for a, b in zip(replacement, recovered):
            assert abs(a - b) < 1e-6
    finally:
        # Release the in-memory database even when an assertion fails.
        conn.close()
def test_model_mismatch_filtered():
    """Filtering stored embeddings by model name excludes other models.

    Stores one embedding under ``model-a`` and checks that a
    ``WHERE model = ?`` query returns it only for the matching model name.
    """
    conn = _make_db()
    try:
        conn.execute(
            "INSERT INTO tools (id, owner, name, version, description, tags, visibility, moderation_status) "
            "VALUES (1, 'test', 'my-tool', '1.0.0', 'test tool', '[]', 'public', 'approved')"
        )
        conn.commit()
        store_embedding(conn, 1, [1.0, 0.0, 0.0], "model-a")
        # Query with model-b should skip this embedding
        rows = conn.execute(
            "SELECT * FROM tool_embeddings WHERE model = ?",
            ["model-b"],
        ).fetchall()
        assert len(rows) == 0
        # Query with model-a should find it
        rows = conn.execute(
            "SELECT * FROM tool_embeddings WHERE model = ?",
            ["model-a"],
        ).fetchall()
        assert len(rows) == 1
    finally:
        # Release the in-memory database even when an assertion fails.
        conn.close()