206 lines
6.2 KiB
Python
206 lines
6.2 KiB
Python
"""Unit tests for semantic search embeddings module."""
|
|
|
|
import json
|
|
import sqlite3
|
|
|
|
import pytest
|
|
|
|
from cmdforge.registry.embeddings import (
|
|
build_embed_text,
|
|
cosine_similarity,
|
|
pack_embedding,
|
|
store_embedding,
|
|
unpack_embedding,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# pack / unpack
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_pack_unpack_roundtrip():
    """Round-trip a vector through pack/unpack and compare component-wise."""
    source_vec = [0.1, 0.2, 0.3, -0.5, 1.0, 0.0]

    decoded = unpack_embedding(pack_embedding(source_vec))

    assert decoded is not None
    assert len(decoded) == len(source_vec)
    # Compare with a tolerance instead of exact equality; the packed form
    # presumably stores reduced-precision floats (confirm in pack_embedding).
    for expected, actual in zip(source_vec, decoded):
        assert abs(expected - actual) < 1e-6
|
|
|
|
|
|
def test_unpack_empty_blob():
    """Unpacking a zero-length blob yields None."""
    decoded = unpack_embedding(b"")
    assert decoded is None
|
|
|
|
|
|
def test_unpack_malformed_blob():
    """A 3-byte blob (length not a multiple of 4) unpacks to None."""
    truncated_blob = b"\x00\x01\x02"
    assert unpack_embedding(truncated_blob) is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# cosine_similarity
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_cosine_similarity_identical():
    """A vector compared with itself scores 1.0 (within 1e-6)."""
    v = [1.0, 2.0, 3.0]
    score = cosine_similarity(v, v)
    assert abs(score - 1.0) < 1e-6
|
|
|
|
|
|
def test_cosine_similarity_orthogonal():
    """Two perpendicular unit vectors score 0.0 (within 1e-6)."""
    x_axis = [1.0, 0.0, 0.0]
    y_axis = [0.0, 1.0, 0.0]
    score = cosine_similarity(x_axis, y_axis)
    assert abs(score) < 1e-6
|
|
|
|
|
|
def test_cosine_similarity_opposite():
    """Vectors pointing in opposite directions score -1.0 (within 1e-6)."""
    forward = [1.0, 0.0]
    backward = [-1.0, 0.0]
    score = cosine_similarity(forward, backward)
    # Distance from the expected value -1.0.
    assert abs(score + 1.0) < 1e-6
|
|
|
|
|
|
def test_cosine_similarity_dimension_mismatch():
    """Vectors of unequal length produce exactly 0.0."""
    two_dims = [1.0, 2.0]
    three_dims = [1.0, 2.0, 3.0]
    score = cosine_similarity(two_dims, three_dims)
    assert score == 0.0
|
|
|
|
|
|
def test_cosine_similarity_zero_vector():
    """An all-zero vector against a non-zero one produces exactly 0.0."""
    zeros = [0.0, 0.0, 0.0]
    nonzero = [1.0, 2.0, 3.0]
    score = cosine_similarity(zeros, nonzero)
    assert score == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# build_embed_text
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_build_embed_text():
    """Name, description and tags are joined into one pipe-separated string."""
    expected = "my-tool | Analyzes CSV files | csv, data"
    built = build_embed_text("my-tool", "Analyzes CSV files", ["csv", "data"])
    assert built == expected
|
|
|
|
|
|
def test_build_embed_text_no_tags():
    """With tags=None the result contains only name and description."""
    expected = "my-tool | A description"
    built = build_embed_text("my-tool", "A description", None)
    assert built == expected
|
|
|
|
|
|
def test_build_embed_text_empty_tags():
    """With an empty tag list the result contains only name and description."""
    expected = "my-tool | A description"
    built = build_embed_text("my-tool", "A description", [])
    assert built == expected
|
|
|
|
|
|
def test_build_embed_text_all_empty():
    """When every field is None or empty, the result is None."""
    blank_inputs = [
        (None, None, None),
        ("", "", []),
    ]
    for name, description, tags in blank_inputs:
        assert build_embed_text(name, description, tags) is None
|
|
|
|
|
|
def test_build_embed_text_only_name():
    """With only a name supplied, the result is the bare name."""
    built = build_embed_text("my-tool", None, None)
    assert built == "my-tool"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# store and retrieve (in-memory SQLite)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_db():
    """Build a fresh in-memory SQLite database for the storage tests.

    The database contains a ``tools`` table and a ``tool_embeddings`` table,
    and rows are returned as ``sqlite3.Row`` so columns can be read by name.

    Returns:
        The open ``sqlite3.Connection``; the caller is responsible for it.
    """
    schema = """
        CREATE TABLE tools (
            id INTEGER PRIMARY KEY,
            owner TEXT, name TEXT, version TEXT,
            description TEXT, category TEXT, tags TEXT,
            downloads INTEGER DEFAULT 0,
            visibility TEXT DEFAULT 'public',
            moderation_status TEXT DEFAULT 'approved'
        );
        CREATE TABLE tool_embeddings (
            tool_id INTEGER PRIMARY KEY,
            embedding BLOB NOT NULL,
            dimensions INTEGER NOT NULL,
            model TEXT NOT NULL,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
    """
    db = sqlite3.connect(":memory:")
    db.row_factory = sqlite3.Row
    # Foreign-key enforcement stays off: this test schema declares no FK
    # from tool_embeddings to tools.
    db.execute("PRAGMA foreign_keys=OFF")
    db.executescript(schema)
    return db
|
|
|
|
|
|
def test_store_and_retrieve():
    """Store an embedding for a tool and read it back from ``tool_embeddings``.

    Checks the stored ``dimensions`` and ``model`` columns, then unpacks the
    stored blob and compares it to the original vector within 1e-6 per
    component.
    """
    conn = _make_db()
    try:
        conn.execute(
            "INSERT INTO tools (id, owner, name, version) VALUES (1, 'test', 'my-tool', '1.0.0')"
        )
        conn.commit()

        vec = [0.1, 0.2, 0.3, 0.4, 0.5]
        store_embedding(conn, 1, vec, "nomic-embed-text")

        row = conn.execute("SELECT * FROM tool_embeddings WHERE tool_id = 1").fetchone()
        assert row is not None
        assert row["dimensions"] == 5
        assert row["model"] == "nomic-embed-text"

        recovered = unpack_embedding(row["embedding"])
        assert recovered is not None
        assert len(recovered) == 5
        for expected, actual in zip(vec, recovered):
            assert abs(expected - actual) < 1e-6
    finally:
        # Close explicitly, even when an assertion fails, so the connection
        # is not left for the garbage collector (ResourceWarning on 3.13+).
        conn.close()
|
|
|
|
|
|
def test_store_upsert():
    """Storing twice for the same tool id leaves one row holding the second write.

    After the second ``store_embedding`` call, the row's ``dimensions`` and
    ``model`` columns must reflect the second vector and model name.
    """
    conn = _make_db()
    try:
        conn.execute(
            "INSERT INTO tools (id, owner, name, version) VALUES (1, 'test', 'my-tool', '1.0.0')"
        )
        conn.commit()

        store_embedding(conn, 1, [1.0, 2.0], "model-a")
        store_embedding(conn, 1, [3.0, 4.0, 5.0], "model-b")

        row = conn.execute("SELECT * FROM tool_embeddings WHERE tool_id = 1").fetchone()
        # Fail with a clear assertion instead of a TypeError from subscripting
        # None if nothing was stored.
        assert row is not None
        assert row["dimensions"] == 3
        assert row["model"] == "model-b"
    finally:
        # Close explicitly, even when an assertion fails, so the connection
        # is not left for the garbage collector (ResourceWarning on 3.13+).
        conn.close()
|
|
|
|
|
|
def test_model_mismatch_filtered():
    """Filtering ``tool_embeddings`` by model returns only rows stored under that model.

    One embedding is stored under "model-a"; a query for "model-b" must
    return nothing and a query for "model-a" must return exactly that row.
    The filter is exercised with a direct SQL query against the table.
    """
    conn = _make_db()
    try:
        conn.execute(
            "INSERT INTO tools (id, owner, name, version, description, tags, visibility, moderation_status) "
            "VALUES (1, 'test', 'my-tool', '1.0.0', 'test tool', '[]', 'public', 'approved')"
        )
        conn.commit()

        store_embedding(conn, 1, [1.0, 0.0, 0.0], "model-a")

        # Query with model-b should skip this embedding
        rows = conn.execute(
            "SELECT * FROM tool_embeddings WHERE model = ?",
            ["model-b"],
        ).fetchall()
        assert len(rows) == 0

        # Query with model-a should find it
        rows = conn.execute(
            "SELECT * FROM tool_embeddings WHERE model = ?",
            ["model-a"],
        ).fetchall()
        assert len(rows) == 1
    finally:
        # Close explicitly, even when an assertion fails, so the connection
        # is not left for the garbage collector (ResourceWarning on 3.13+).
        conn.close()
|