Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@ JWT_ALGORITHM=HS256
JWT_EXPIRE_MINUTES=1440
FRONTEND_URL=http://localhost:5173
DATABASE_URL=sqlite:///./workmate.db

# AI / LLM
GEMINI_API_KEY=your-gemini-api-key
VOYAGE_API_KEY=your-voyageai-api-key

# Notion (direct API access — optional, JSON import works without this)
NOTION_TOKEN=your-notion-api-token
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,10 @@ chroma_db/

# Antigravity SKILLS
.agent/
CLAUDE.md
CLAUDE.md
.claude/

# Debugging
Tasks_for_Claude.txt
debug_rag.py
generate_token.py
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ dependencies = [
"python-multipart>=0.0.22",
"pypdf2>=3.0.1",
"sse-starlette>=3.3.2",
"voyageai>=0.3.7",
"bm25s>=0.2.12",
]
3 changes: 1 addition & 2 deletions src/backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from src.backend.config import settings
from src.backend.database import Base, engine
from src.backend.routers import admin, auth, chat, conversations, upload
from src.backend.routers import admin, auth, conversations, upload


def create_app() -> FastAPI:
Expand All @@ -21,7 +21,6 @@ def create_app() -> FastAPI:

app.include_router(auth.router)
app.include_router(admin.router)
app.include_router(chat.router)
app.include_router(conversations.router)
app.include_router(upload.router)

Expand Down
1 change: 1 addition & 0 deletions src/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Settings(BaseSettings):
FRONTEND_URL: str = "http://localhost:5173"
DATABASE_URL: str = "sqlite:///./workmate.db"
GEMINI_API_KEY: str = ""
VOYAGE_API_KEY: str = ""
NOTION_TOKEN: str = ""

model_config = {"env_file": ".env", "extra": "ignore"}
Expand Down
36 changes: 36 additions & 0 deletions src/backend/dependencies/services.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import logging
import os

from fastapi import HTTPException

from src.backend.load.bm25_manager import BM25Manager, BM25_INDEX_PATH
from src.backend.load.chroma_manager import ChromaManager
from src.backend.load.hybrid_retriever import HybridRetriever
from src.backend.llm.gemini_client import GeminiClient
from src.backend.llm.voyage_reranker import VoyageReranker

logger = logging.getLogger(__name__)

_chroma_manager = None
_gemini_client = None
_voyage_reranker = None
_bm25_manager = None
_hybrid_retriever = None


def get_chroma_manager() -> ChromaManager:
Expand All @@ -33,3 +40,32 @@ def get_gemini_client() -> GeminiClient:
logger.error(f"Failed to initialize GeminiClient: {e}")
raise HTTPException(status_code=500, detail="LLM configuration missing.")
return _gemini_client


def get_voyage_reranker() -> VoyageReranker:
    """Return the process-wide VoyageReranker, creating it lazily on first use.

    Raises:
        HTTPException: 500 if the reranker cannot be constructed
            (e.g. missing configuration).
    """
    global _voyage_reranker
    if _voyage_reranker is not None:
        return _voyage_reranker
    try:
        _voyage_reranker = VoyageReranker()
    except Exception as e:
        logger.error(f"Failed to initialize VoyageReranker: {e}")
        raise HTTPException(status_code=500, detail="Reranker configuration missing.")
    return _voyage_reranker


def get_bm25_manager() -> BM25Manager:
    """Return the process-wide BM25Manager, loading the persisted index if present.

    If no index file exists yet, the manager is returned empty and a warning
    is logged; searches will yield no results until the index is built.
    """
    global _bm25_manager
    if _bm25_manager is not None:
        return _bm25_manager
    manager = BM25Manager()
    if os.path.exists(BM25_INDEX_PATH):
        manager.load(BM25_INDEX_PATH)
    else:
        logger.warning(f"BM25 index not found at {BM25_INDEX_PATH}. Run NotionIngestor to build it.")
    _bm25_manager = manager
    return _bm25_manager


def get_hybrid_retriever() -> HybridRetriever:
    """Return the process-wide HybridRetriever, wiring together the
    Chroma (dense) and BM25 (lexical) singletons on first use."""
    global _hybrid_retriever
    if _hybrid_retriever is not None:
        return _hybrid_retriever
    _hybrid_retriever = HybridRetriever(get_chroma_manager(), get_bm25_manager())
    return _hybrid_retriever
66 changes: 37 additions & 29 deletions src/backend/llm/gemini_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,25 @@ def filter_chunks(
) -> List[Dict[str, Any]]:
"""
LLM Re-ranking: Asks Gemini to filter out irrelevant chunks before generation.
Uses sequential chunk_N IDs (chunk_1, chunk_2, ...) for the filter prompt so
the LLM and the matching logic agree on the same ID format.
"""
if not chunks:
return []

prompt = prompts.get_filter_prompt(chunks, user_question)
# Remap chunks to sequential IDs for the filter prompt so the LLM
# returns predictable IDs like "chunk_1, chunk_3" instead of UUIDs.
sequential_id_map: Dict[str, Dict[str, Any]] = {}
remapped_chunks = []
for i, chunk in enumerate(chunks, start=1):
seq_id = f"chunk_{i}"
sequential_id_map[seq_id] = chunk
remapped_chunks.append({**chunk, "chunk_id": seq_id})

prompt = prompts.get_filter_prompt(remapped_chunks, user_question)
cfg = types.GenerateContentConfig(
system_instruction=prompts.FILTER_SYSTEM_INSTRUCTION,
temperature=0.0, # Zero precision for extraction
temperature=0.0,
max_output_tokens=100,
)

Expand All @@ -51,40 +62,37 @@ def filter_chunks(
config=cfg,
)
output = getattr(response, "text", "") or ""
print(f"🧠 Re-ranker Output: {output}")
print(f"Re-ranker Output: {output}")

if "NONE" in output.upper():
print("🧠 Re-ranker kept 0 chunks.")
print("Re-ranker kept 0 chunks.")
return []

# Parse the output safely: sometimes Gemini returns a clean comma-separated list of UUIDs
# But sometimes it's lazy and returns just the first part e.g. "30, 4f" instead of "302f24..."
filtered_chunks = []

# Match returned IDs (e.g. "chunk_1, chunk_3") back to original chunks.
output_parts = [p.strip() for p in output.split(",") if p.strip()]

for chunk in chunks:
chunk_id = str(chunk["chunk_id"])

# Check 1: Is the full chunk_id anywhere in the raw output string?
if chunk_id in output:
filtered_chunks.append(chunk)
filtered_chunks = []
seen = set()
for part in output_parts:
# Exact match: "chunk_3"
if part in sequential_id_map and part not in seen:
filtered_chunks.append(sequential_id_map[part])
seen.add(part)
continue

# Check 2: Did the LLM abbreviate the IDs? Check each comma-separated part.
for part in output_parts:
if len(part) >= 2 and chunk_id.startswith(part):
filtered_chunks.append(chunk)
break

print(f"🧠 Re-ranker kept {len(filtered_chunks)}/{len(chunks)} chunks.")

# Fallback to all chunks if zero were matched but it didn't explicitly say "NONE"
# Plain number match: LLM returned "3" instead of "chunk_3"
candidate = f"chunk_{part}"
if candidate in sequential_id_map and candidate not in seen:
filtered_chunks.append(sequential_id_map[candidate])
seen.add(candidate)

print(f"Re-ranker kept {len(filtered_chunks)}/{len(chunks)} chunks.")

# Fallback: if nothing matched but LLM didn't say NONE, return all chunks.
if not filtered_chunks:
return chunks
return chunks
return filtered_chunks

except Exception as e:
logger.warning(f"⚠️ Re-ranking failed (falling back to all chunks): {e}")
logger.warning(f"Re-ranking failed (falling back to all chunks): {e}")
return chunks

def ask_workmate(
Expand Down
106 changes: 106 additions & 0 deletions src/backend/llm/voyage_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from __future__ import annotations

import logging
import os
from typing import Any, Dict, List, Tuple

logger = logging.getLogger(__name__)

DEFAULT_RERANK_MODEL = "rerank-2"
RELEVANCE_THRESHOLD = 0.3


class VoyageReranker:
    """
    Reranks retrieved chunks using the VoyageAI rerank API.
    Replaces the LLM-based filter_chunks step in the RAG pipeline.
    """

    def __init__(
        self,
        model: str = DEFAULT_RERANK_MODEL,
        threshold: float = RELEVANCE_THRESHOLD,
    ):
        api_key = os.getenv("VOYAGE_API_KEY")
        if api_key:
            # Imported lazily so the voyageai package is only required
            # when an API key is actually configured.
            import voyageai
            self.client = voyageai.Client(api_key=api_key)
        else:
            logger.warning(
                "[VoyageReranker] VOYAGE_API_KEY not set — reranking disabled. "
                "Set VOYAGE_API_KEY in .env to enable VoyageAI reranking."
            )
            self.client = None

        self.model = model
        self.threshold = threshold

    @staticmethod
    def _as_document(chunk: Dict[str, Any]) -> str:
        """Render one chunk as the plain-text document sent to the rerank API."""
        if chunk.get("section"):
            return f"Page: {chunk['page_title']}\nSection: {chunk['section']}\n{chunk['text']}"
        return f"Page: {chunk['page_title']}\n{chunk['text']}"

    def rerank(
        self,
        chunks: List[Dict[str, Any]],
        query: str,
        top_k: int = 5,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Rerank chunks by relevance to the query.

        Returns:
            final_chunks — top_k chunks above threshold, score field stripped (clean for generation)
            scored_chunks — all input chunks sorted by score descending, with rerank_score added
        """
        if not chunks:
            return [], []

        # No API key: degrade gracefully to the retriever's own ordering.
        if self.client is None:
            logger.warning("[VoyageReranker] Reranking skipped (no API key). Returning top_k unranked.")
            return chunks[:top_k], []

        doc_texts = [self._as_document(c) for c in chunks]

        try:
            response = self.client.rerank(
                query=query,
                documents=doc_texts,
                model=self.model,
                top_k=len(chunks),  # fetch all scores; threshold + top_k are applied locally
            )

            # Attach scores to the original chunks, then order best-first.
            ranked: List[Dict[str, Any]] = sorted(
                (
                    {**chunks[hit.index], "rerank_score": round(hit.relevance_score, 4)}
                    for hit in response.results
                ),
                key=lambda c: c["rerank_score"],
                reverse=True,
            )

            logger.info(
                f"[VoyageReranker] scores: "
                f"{[(c['page_title'], c['rerank_score']) for c in ranked]}"
            )

            kept = [c for c in ranked if c["rerank_score"] >= self.threshold][:top_k]

            if not kept:
                logger.warning(
                    f"[VoyageReranker] All {len(chunks)} chunks below threshold "
                    f"({self.threshold}). Top score: "
                    f"{ranked[0]['rerank_score'] if ranked else 'N/A'}."
                )

            # Strip rerank_score before passing to generation prompt
            cleaned = [
                {k: v for k, v in c.items() if k != "rerank_score"} for c in kept
            ]

            return cleaned, ranked

        except Exception as e:
            logger.warning(
                f"[VoyageReranker] Reranking failed, falling back to top {top_k} unranked: {e}"
            )
            return chunks[:top_k], []
77 changes: 77 additions & 0 deletions src/backend/load/bm25_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import logging
import os
import pickle

import bm25s

logger = logging.getLogger(__name__)

PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
BM25_INDEX_PATH = os.path.join(PROJECT_ROOT, "workmate_db", "bm25_index.pkl")


class BM25Manager:
    """Maintains a bm25s lexical index over ingested chunks, with pickle persistence.

    The index stores list positions as its corpus, so `chunks`, `metadatas`
    and `ids` are kept as parallel lists and looked up by position at
    search time.
    """

    def __init__(self):
        # No index until build_index() or load() is called.
        self.index = None
        self.chunks: list[str] = []
        self.metadatas: list[dict] = []
        self.ids: list[str] = []

    def build_index(self, chunks: list[str], metadatas: list[dict], ids: list[str]):
        """Build a fresh BM25 index over lowercased "title + body" text per chunk."""
        self.chunks = chunks
        self.metadatas = metadatas
        self.ids = ids

        searchable = [
            f"{meta.get('title', '')} {text}".lower()
            for text, meta in zip(chunks, metadatas)
        ]
        tokens = bm25s.tokenize(searchable)
        # Corpus holds positions so retrieve() hands back indices into the lists.
        self.index = bm25s.BM25(corpus=list(range(len(chunks))))
        self.index.index(tokens)
        logger.info(f"BM25 index built with {len(chunks)} documents")

    def search(self, query: str, top_k: int = 10) -> list[dict]:
        """Return up to top_k chunk dicts ranked by BM25 score; [] when no index exists."""
        if self.index is None:
            logger.warning("BM25 index not built, returning empty results")
            return []

        limit = min(top_k, len(self.chunks))
        hits, _ = self.index.retrieve(bm25s.tokenize([query.lower()]), k=limit)

        results = []
        for pos in hits[0]:
            meta = self.metadatas[pos]
            results.append({
                "chunk_id": self.ids[pos],
                "text": self.chunks[pos],
                "page_title": meta.get("title", "Unknown Source"),
                "section": meta.get("parent_title", ""),
                **meta,
            })
        return results

    def save(self, path: str = BM25_INDEX_PATH):
        """Pickle the index plus its parallel chunk/metadata/id lists to *path*."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        state = {
            "index": self.index,
            "chunks": self.chunks,
            "metadatas": self.metadatas,
            "ids": self.ids,
        }
        with open(path, "wb") as f:
            pickle.dump(state, f)
        logger.info(f"BM25 index saved to {path}")

    def load(self, path: str = BM25_INDEX_PATH):
        """Restore state written by save(). NOTE: pickle.load — only load trusted files."""
        with open(path, "rb") as f:
            state = pickle.load(f)
        self.index = state["index"]
        self.chunks = state["chunks"]
        self.metadatas = state["metadatas"]
        self.ids = state["ids"]
        logger.info(f"BM25 index loaded from {path} ({len(self.chunks)} documents)")
Loading