This project implements an execution-grounded NL→SQL assistant for the Spider dataset using:
- a multi-agent pipeline
- RAG over database schemas (BM25 + tuned dense retrieval)
- SQLite execution + automatic repair loops
- an E5 embedding model fine-tuned on Spider for improved schema retrieval
Given an input example {question, db_id} from Spider:
- retrieve the most relevant schema pieces.
- generate SQL with an LLM.
- validate SQL shape (audit / optional reprompt).
- execute in SQLite.
- auto-repair on errors (deterministic + LLM repair).
- output final SQL + detailed metadata.
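Sketched as code, the per-example flow looks roughly like this (the method names below are simplified placeholders, not the exact agent APIs wired up later in this README):

```python
# Rough sketch of the per-example flow; method names are placeholders,
# not the exact agent APIs.
def answer(question: str, db_id: str) -> str:
    context = router_agent.build_context(question, db_id)  # retrieval + FK expansion
    prompt = prompt_builder.build(question, context)
    sql = llm.generate(prompt)                             # LLM generation
    shape_agent.audit(question, sql)                       # shape audit / optional reprompt
    ok, err = exec_agent.run(db_id, sql)                   # execute in SQLite
    if not ok:
        sql = det_repair.fix(sql, err, db_id)              # deterministic repair first
        ok, err = exec_agent.run(db_id, sql)
    if not ok:
        sql = validator.repair(question, sql, err, db_id)  # LLM repair loop
    return sql
```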
In a Kaggle Notebook (GPU recommended for LLM generation):
!pip -q install transformers accelerate bitsandbytes sentencepiece sqlparse sqlglot
!pip -q install sentence-transformers faiss-cpu rank_bm25
Set paths as in the notebook:
- Spider dataset: spider_root = /kaggle/input/yale-universitys-spider-10-nlp-dataset/spider (contains tables.json and database/)
- Dev split: dev1_path = /kaggle/input/dev-json-splitted-v1/.../dev_1.json
- Tuned embedder: EMBED_MODEL = /kaggle/input/e5model-spider-tuning/transformers/default/1
- Generator LLM: MODEL_NAME = XGenerationLab/XiYanSQL-QwenCoder-7B-2502 (loaded in 4-bit)
- Embedder: tuned E5 (SentenceTransformers format)
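For reference, loading the generator in 4-bit typically looks like this with transformers + bitsandbytes (the quantization settings shown are common defaults, assumed rather than taken from LLMGeneratorAgent):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "XGenerationLab/XiYanSQL-QwenCoder-7B-2502"

# NF4 4-bit quantization with fp16 compute: fits a 7B model on a Kaggle GPU.
# These settings are assumptions, not the project's exact config.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)
```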
# Load the dev split and schema metadata (tables.json), keyed by db_id
data_agent = DataAgent(dev1_path, tables_path)
dev1, schema_by_db, used_db_ids = data_agent.load()

# Load the 4-bit generator LLM
llm = LLMGeneratorAgent(MODEL_NAME)
llm.load()

# Convert schemas into retrievable documents and build per-DB lexicons
schema_agent = SchemaStoreAgent(schema_by_db, db_root)
for db_id in used_db_ids:
    schema_agent.ensure(db_id)
    schema_agent.build_db_lexicon(db_id)

# Build BM25 + FAISS indexes over the schema docs with the tuned embedder
retrieval_agent = RetrievalIndexAgent(schema_agent, embed_model_name=EMBED_MODEL)
retrieval_agent.build_all(used_db_ids)

# Remaining agents: FK expansion, join hints, routing, shape audit, execution, repair
fk_agent = FKBridgeAgent(schema_agent)
join_agent = JoinPlannerAgent()
router_agent = RetrievalRouterAgent(schema_agent, retrieval_agent, fk_agent)
shape_agent = ShapeCheckerAgent()
exec_agent = ExecutionAgent()
det_repair = DeterministicRepairAgent(schema_by_db)
validator = ValidatorRepairAgent(llm, schema_agent, exec_agent, det_repair)

prompt_builder = PromptBuilderAgent(
    llm=llm,
    router=router_agent,
    schema_agent=schema_agent,
    value_linker=None,  # not used in this configuration
    join_hint_agent=join_agent,
)

runner = PipelineRunner(
    llm=llm,
    prompt_builder=prompt_builder,
    shape_agent=shape_agent,
    exec_agent=exec_agent,
    det_repair=det_repair,
    validator=validator,
)
pred_path = "/kaggle/working/pred_rag_baseline_agents.sql"
preds, meta_df = runner.run_config(
    dev1, name="rag_baseline_agents",
    use_value_linker=False, gated=False,
    out_sql_path=pred_path
)

gold_path = "/kaggle/working/dev1_gold.sql"
with open(gold_path, "w", encoding="utf-8") as f:
    for ex in dev1:
        f.write(ex["query"].strip() + "\t" + ex["db_id"].strip() + "\n")
!git clone -q https://github.com/taoyds/spider
!python spider/evaluation.py \
--gold "/kaggle/working/dev1_gold.sql" \
--pred "/kaggle/working/pred_rag_baseline_agents.sql" \
--db "/kaggle/input/yale-universitys-spider-10-nlp-dataset/spider/database" \
--table "/kaggle/input/yale-universitys-spider-10-nlp-dataset/spider/tables.json" \
--etype all

Input: (question, db_id)
Output: final executable SQL (+ metadata)
Pipeline stages:
- DataAgent → loads dev set and schema metadata (tables.json), builds schema_by_db.
- SchemaStoreAgent → converts schemas into retrievable documents + builds FK join graph.
- RetrievalIndexAgent → builds:
- BM25 index (lexical retrieval)
- FAISS index using tuned E5 (dense retrieval)
- RetrievalRouterAgent → chooses retrieval mode and builds schema context:
- BM25-based, dense-based, or fallback to full DDL
- FKBridgeAgent → expands retrieved docs via FK shortest paths (adds bridge tables/edges).
- JoinPlannerAgent → optionally formats a compact list of FK join edges as hints (disabled by default)
- PromptBuilderAgent → composes final NL→SQL prompt using retrieved schema context.
- LLMGeneratorAgent → loads XiYanSQL-QwenCoder (4-bit) and generates SQL, then cleans output.
- ShapeCheckerAgent → parses SQL with sqlglot and audits expected query structure (agg/group/order/limit/distinct)
- ExecutionAgent → runs SQL in SQLite read-only mode (see the sketch after this list)
- DeterministicRepairAgent → fixes common runtime errors:
- unknown table / unknown column via fuzzy matching + alias mapping
- ValidatorRepairAgent → LLM repair loop using (DDL + bad SQL + SQLite error), then re-executes
- PipelineRunner → runs on the full dev set, saves predictions and meta logs
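A minimal sketch of the execution + deterministic-repair core (the helper names and error-matching heuristic here are illustrative, not the project's exact code):

```python
import re
import sqlite3
from difflib import get_close_matches

def execute_readonly(db_path, sql):
    """Run SQL against SQLite opened read-only; return (ok, rows_or_error)."""
    try:
        con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
        try:
            return True, con.execute(sql).fetchall()
        finally:
            con.close()
    except sqlite3.Error as e:
        return False, str(e)

def fix_unknown_column(sql, error, schema_columns):
    """If SQLite reports 'no such column: X', swap X for the closest schema name."""
    m = re.search(r"no such column: (\S+)", error)
    if not m:
        return sql
    bad = m.group(1)
    match = get_close_matches(bad.split(".")[-1], schema_columns, n=1, cutoff=0.6)
    return sql.replace(bad, match[0]) if match else sql
```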
Retrieval is done over schema documents, not natural language docs.
Document types
- TABLE docs: table name + columns (+ types) + PK + FK lines relevant to that table
- EDGE docs: one doc per FK relation like A.col = B.col
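For illustration, both doc types can be derived from a single tables.json entry roughly like this (the field names follow Spider's tables.json format; the exact doc layout is an assumption):

```python
def build_schema_docs(schema):
    """Build TABLE and EDGE docs from one Spider tables.json entry (sketch)."""
    tables = schema["table_names_original"]
    cols = schema["column_names_original"]  # [[table_idx, col_name], ...]
    types = schema["column_types"]
    pks = set(schema["primary_keys"])       # column indices
    fks = schema["foreign_keys"]            # [[col_idx, ref_col_idx], ...]

    def qualified(ci):
        return f"{tables[cols[ci][0]]}.{cols[ci][1]}"

    table_docs = []
    for ti, tname in enumerate(tables):
        parts = [
            f"{cols[ci][1]} ({types[ci]})" + (" PK" if ci in pks else "")
            for ci in range(len(cols)) if cols[ci][0] == ti
        ]
        table_docs.append(f"TABLE {tname}: " + ", ".join(parts))

    edge_docs = [f"EDGE {qualified(ci)} = {qualified(rci)}" for ci, rci in fks]
    return table_docs, edge_docs
```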
Retrieval methods:
- BM25: strong when question contains exact schema tokens
- Dense (tuned E5): strong for paraphrases / semantic match
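A minimal sketch of both retrievers over the same docs (toy documents; the query:/passage: prefixes follow the E5 convention):

```python
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

docs = [
    "TABLE singer: singer_id (number) PK, name (text), country (text)",
    "TABLE concert: concert_id (number) PK, concert_name (text), year (number)",
]

# Lexical: BM25 over whitespace-tokenized docs
bm25 = BM25Okapi([d.lower().split() for d in docs])

# Dense: tuned E5 + FAISS inner-product index (normalized vectors -> cosine)
embedder = SentenceTransformer(EMBED_MODEL)
vecs = embedder.encode(["passage: " + d for d in docs], normalize_embeddings=True)
index = faiss.IndexFlatIP(vecs.shape[1])
index.add(np.asarray(vecs, dtype="float32"))

question = "How many singers are from France?"
bm25_scores = bm25.get_scores(question.lower().split())
q = embedder.encode(["query: " + question], normalize_embeddings=True)
dense_scores, dense_ids = index.search(np.asarray(q, dtype="float32"), 2)
```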
Routing
Router selects the retrieval strategy based on lexical hits + score gaps, otherwise falls back to full DDL.
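In pseudocode, the routing heuristic could look like this (the thresholds are made-up placeholders, not the tuned values):

```python
def route(question_tokens, lexicon, bm25_scores, dense_scores):
    """Pick a retrieval mode; thresholds are illustrative placeholders."""
    lexical_hits = sum(1 for t in question_tokens if t in lexicon)
    top = sorted(bm25_scores, reverse=True)
    bm25_gap = top[0] - top[1] if len(top) > 1 else top[0]

    if lexical_hits >= 2 and bm25_gap > 1.0:
        return "bm25"      # question names schema tokens directly
    if max(dense_scores) > 0.80:
        return "dense"     # paraphrase: trust the semantic match
    return "full_ddl"      # unsure: give the LLM the whole schema
```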
FK expansion
Even if retrieval returns correct tables, it might miss connecting tables. FKBridgeAgent adds minimal join paths so the LLM sees a joinable subgraph.
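A minimal BFS version of this expansion (the graph representation is assumed):

```python
from collections import deque
from itertools import combinations

def fk_bridge(needed_tables, fk_edges):
    """Add the tables on shortest FK paths between every pair of retrieved
    tables. fk_edges is assumed to be a list of (table_a, table_b) pairs."""
    adj = {}
    for a, b in fk_edges:
        adj.setdefault(a, set()).add(b)
        adj.setdefault(b, set()).add(a)

    def shortest_path(src, dst):
        prev = {src: None}
        queue = deque([src])
        while queue:
            node = queue.popleft()
            if node == dst:  # walk back through predecessors
                path = []
                while node is not None:
                    path.append(node)
                    node = prev[node]
                return path
            for nxt in adj.get(node, ()):
                if nxt not in prev:
                    prev[nxt] = node
                    queue.append(nxt)
        return []

    expanded = set(needed_tables)
    for src, dst in combinations(needed_tables, 2):
        expanded.update(shortest_path(src, dst))
    return expanded
```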
The embedder is intfloat/e5-small-v2, fine-tuned on Spider train examples to improve schema retrieval.
Training idea:
- Build “COLUMN docs” for each database (table.col + hints).
- Parse gold SQL to extract used columns.
- Train contrastively on pairs: query: <question> <-> passage: <column_doc>
- Loss: MultipleNegativesRankingLoss (in-batch negatives)
Result:
An embedding model that better aligns Spider questions with the correct schema elements.
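A minimal version of this fine-tuning with sentence-transformers (pair construction and hyperparameters are illustrative):

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# pairs: [(question, column_doc), ...] mined from gold SQL on Spider train
train_examples = [
    InputExample(texts=["query: " + q, "passage: " + doc])
    for q, doc in pairs
]
loader = DataLoader(train_examples, shuffle=True, batch_size=64)

model = SentenceTransformer("intfloat/e5-small-v2")
loss = losses.MultipleNegativesRankingLoss(model)  # in-batch negatives

model.fit(train_objectives=[(loader, loss)], epochs=1, warmup_steps=100)
model.save("/kaggle/working/e5-spider-tuned")
```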
- pred_*.sql -> one predicted SQL query per line
- pred_*_meta.csv -> per-example metadata:
  - router mode, retrieval stats, shape issues
  - execution status, repair traces, validator steps
- dev1_gold.sql -> gold SQL + db_id file for Spider evaluation
- Keep USE_VALIDATOR_REPAIR=True for best executable-SQL rate.
- Use USE_SHAPE_AUDIT=True to analyze failures; enable reprompt only when debugging.
- Start with USE_JOIN_HINTS=False and enable it only as an ablation.
Spider evaluation reports:
- structural matching (exact / partial), and
- execution accuracy.
Execution accuracy can have occasional false positives (same result by coincidence), so meta logs + partial metrics are important for diagnosing errors.