❝"在数据驱动的时代,让AI理解你的数据库就像教会外星人说人话一样困难。但Vanna做到了,而且做得相当优雅。"
想象一下这样的场景:你兴冲冲地打开ChatGPT,输入"帮我查询一下德国有多少客户",期待着AI能够生成一条完美的SQL语句。结果呢?AI给你返回了一个看起来很专业的查询:
SELECTCOUNT(*)FROMcustomersWHEREcountry ='Germany';
看起来不错对吧?但当你兴奋地复制到数据库中执行时,系统无情地抛出了错误:Table 'customers' doesn't exist。
这就是当前AI生成SQL面临的核心困境:LLM虽然掌握了SQL语法的精髓,却对你的具体数据库结构一无所知。就像一个语言天才试图在不了解当地文化的情况下进行深度交流一样,注定会闹出笑话。
但是,如果我告诉你有一个开源项目能够将SQL生成的准确率从令人绝望的3%提升到令人惊艳的80%,你会相信吗?这就是我们今天要深入探讨的主角——Vanna。
Vanna不是简单的"ChatGPT + SQL"的组合,而是一个基于RAG(Retrieval-Augmented Generation)架构的智能SQL生成框架。它的核心理念可以用一句话概括:让AI不仅懂SQL语法,更要懂你的数据。
从技术架构上看,Vanna采用了经典的RAG模式:
graph TD
A[用户问题] --> B[向量化检索]
B --> C[相关上下文]
C --> D[LLM生成SQL]
D --> E[执行验证]
E --> F[结果反馈]
F --> G[自动训练]
G --> B
这个架构的精妙之处在于,它不是简单地把问题扔给LLM,而是先从知识库中检索出最相关的上下文信息,然后再让LLM基于这些信息生成SQL。这就像给一个外国朋友不仅提供了字典,还提供了当地的文化背景和使用习惯。
让我们深入Vanna的技术内核。通过分析其源码结构,我们可以发现Vanna采用了高度模块化的设计:
# Vanna的核心抽象基类
classVannaBase(ABC):
def__init__(self, config=None):
self.config = config
self.run_sql_is_set =False
self.static_documentation =""
self.dialect = self.config.get("dialect","SQL")
self.language = self.config.get("language",None)
self.max_tokens = self.config.get("max_tokens",14000)
@abstractmethod
defgenerate_embedding(self, data: str, **kwargs)-> List[float]:
"""生成文本嵌入向量"""
pass
@abstractmethod
defget_similar_question_sql(self, question: str, **kwargs)-> list:
"""检索相似的问题-SQL对"""
pass
@abstractmethod
defsubmit_prompt(self, prompt, **kwargs)-> str:
"""提交提示词到LLM"""
pass
这种设计的巧妙之处在于,它将复杂的SQL生成过程分解为三个可插拔的组件:
Vanna的另一个令人印象深刻的特点是其广泛的生态系统支持。从项目结构可以看出,它支持:
LLM提供商(9+):
向量数据库(10+):
关系数据库(10+):
这种"大一统"的设计哲学让Vanna能够适应几乎任何技术栈,这在企业级应用中尤为重要。
在深入Vanna的解决方案之前,我们先来理解传统方法的局限性。Vanna团队进行了一项令人印象深刻的实验,使用Cybersyn SEC数据集测试了不同方法的SQL生成准确率:
实验设置:
结果令人震惊:
这个结果揭示了一个重要的洞察:上下文比模型更重要。即使是最强大的GPT-4,在没有合适上下文的情况下,准确率也只有可怜的10%。
Vanna的成功秘诀在于其精心设计的三层上下文策略:
defadd_ddl(self, ddl: str, **kwargs)-> str:
"""添加数据定义语言到训练数据"""
id = deterministic_uuid(ddl) +"-ddl"
self.ddl_collection.add(
documents=ddl,
embeddings=self.generate_embedding(ddl),
ids=id,
)
returnid
这一层提供了数据库的结构信息,包括表名、字段名、数据类型等。但仅有这些还不够,因为它无法告诉AI如何正确地使用这些表。
defadd_documentation(self, documentation: str, **kwargs)-> str:
"""添加业务文档到训练数据"""
id = deterministic_uuid(documentation) +"-doc"
self.documentation_collection.add(
documents=documentation,
embeddings=self.generate_embedding(documentation),
ids=id,
)
returnid
这一层包含了业务规则、字段含义、计算逻辑等信息。比如"revenue"字段的具体定义,或者某个表中数据的业务含义。
defadd_question_sql(self, question: str, sql: str, **kwargs)-> str:
"""添加问题-SQL对到训练数据"""
question_sql_json = json.dumps({
"question": question,
"sql": sql,
}, ensure_ascii=False)
id = deterministic_uuid(question_sql_json) +"-sql"
self.sql_collection.add(
documents=question_sql_json,
embeddings=self.generate_embedding(question_sql_json),
ids=id,
)
returnid
这是最关键的一层,它提供了具体的问题-SQL对应关系,让AI能够学习到如何将自然语言问题转换为正确的SQL查询。
Vanna的核心创新在于其智能检索机制。当用户提出问题时,系统会:
defgenerate_sql(self, question: str, allow_llm_to_see_data=False, **kwargs)-> str:
# 检索相似的问题-SQL对
question_sql_list = self.get_similar_question_sql(question, **kwargs)
# 检索相关的DDL信息
ddl_list = self.get_related_ddl(question, **kwargs)
# 检索相关的文档
doc_list = self.get_related_documentation(question, **kwargs)
# 构建提示词
prompt = self.get_sql_prompt(
initial_prompt=initial_prompt,
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)
# 提交给LLM生成SQL
llm_response = self.submit_prompt(prompt, **kwargs)
returnself.extract_sql(llm_response)
这种方法的精妙之处在于,它不是简单地把所有信息都塞给LLM,而是智能地选择最相关的信息。这样既保证了上下文的质量,又避免了超出LLM的上下文窗口限制。
让我们通过一个具体的例子来看看Vanna是如何工作的:
fromvanna.openai.openai_chatimportOpenAI_Chat
fromvanna.chromadb.chromadb_vectorimportChromaDB_VectorStore
# 创建自定义的Vanna类
classMyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def__init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
# 初始化
vn = MyVanna(config={
'api_key':'your-openai-key',
'model':'gpt-4'
})
# 连接数据库
vn.connect_to_postgres(
host="localhost",
dbname="ecommerce",
user="admin",
password="password"
)
训练Vanna就像教一个新员工熟悉公司的数据库:
# 1. 添加表结构信息
vn.train(ddl="""
CREATE TABLE customers (
id SERIAL PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
country VARCHAR(50),
created_at TIMESTAMP
);
""")
# 2. 添加业务文档
vn.train(documentation="""
customers表存储了所有注册用户的基本信息。
country字段使用ISO 3166-1标准的国家代码。
created_at表示用户注册时间。
""")
# 3. 添加示例查询
vn.train(
question="德国有多少客户?",
sql="SELECT COUNT(*) FROM customers WHERE country = 'DE';"
)
vn.train(
question="最近一个月新注册的用户数量?",
sql="""
SELECT COUNT(*)
FROM customers
WHERE created_at >= CURRENT_DATE - INTERVAL '1 month';
"""
)
训练完成后,你就可以开始享受AI助手的服务了:
# 提问
sql, df, fig = vn.ask("显示每个国家的客户数量,按数量降序排列")
# Vanna会自动:
# 1. 理解问题意图
# 2. 检索相关上下文
# 3. 生成SQL查询
# 4. 执行查询
# 5. 返回结果和可视化图表
生成的SQL可能是这样的:
SELECT
country,
COUNT(*)ascustomer_count
FROMcustomers
GROUPBYcountry
ORDERBYcustomer_countDESC;
defask(self, question: str, auto_train: bool = True, **kwargs):
# ... 生成和执行SQL ...
# 如果查询成功且auto_train=True,自动添加到训练数据
iflen(df) >0andauto_train:
self.add_question_sql(question=question, sql=sql)
这个特性让Vanna能够从每次成功的查询中学习,不断改进自己的性能。
# 当需要探索数据时,Vanna可以生成中间查询
if'intermediate_sql'inllm_response:
intermediate_sql = self.extract_sql(llm_response)
df = self.run_sql(intermediate_sql)
# 基于中间结果生成最终SQL
prompt = self.get_sql_prompt(
# ... 包含中间结果的上下文 ...
doc_list=doc_list + [f"中间查询结果: \n{df.to_markdown()}"],
)
这个特性让AI能够像人类分析师一样,先探索数据再生成最终查询。
def_response_language(self)-> str:
ifself.languageisNone:
return""
returnf"Respond in the{self.language}language."
Vanna支持多种语言的问答,这对国际化企业尤为重要。
通过Vanna团队的详细实验,我们可以清晰地看到不同策略对准确率的影响:
实验数据深度分析:
Schema-only方法的失败原因:
静态示例的局限性:
上下文相关方法的优势:
有趣的是,实验结果显示了不同LLM在不同上下文策略下的表现差异:
GPT-4.1:
Google Bison:
GPT-4.1-mini:
基于实验结果,我们可以总结出几个关键的性能优化策略:
defget_sql_prompt(self, question: str, **kwargs):
# 动态调整检索数量
n_results = min(10, max(3, len(question.split()) //2))
question_sql_list = self.get_similar_question_sql(
question, n_results=n_results
)
# ...
# 高质量的核心示例
core_examples = [
{"question":"...","sql":"...","priority":"high"},
# ...
]
# 自动生成的示例
auto_examples = [
{"question":"...","sql":"...","priority":"medium"},
# ...
]
defcontinuous_learning(self):
# 定期分析查询日志
successful_queries = self.get_successful_queries()
# 自动提取新的训练样本
forqueryinsuccessful_queries:
ifself.is_novel_pattern(query):
self.add_question_sql(query.question, query.sql)
在企业环境中部署Vanna需要考虑更多的因素:
# 企业级配置示例
classEnterpriseVanna:
def__init__(self):
self.config = {
# 多模型支持
'primary_llm':'gpt-4',
'fallback_llm':'gpt-3.5-turbo',
# 向量数据库集群
'vector_store': {
'type':'qdrant',
'cluster_urls': ['http://qdrant-1:6333','http://qdrant-2:6333'],
'collection_name':'enterprise_sql_kb'
},
# 安全配置
'security': {
'enable_query_validation':True,
'allowed_operations': ['SELECT'],
'max_result_rows':10000,
'query_timeout':30
},
# 监控配置
'monitoring': {
'enable_logging':True,
'log_level':'INFO',
'metrics_endpoint':'/metrics'
}
}
企业级部署必须考虑安全性:
classSecureVanna(VannaBase):
defvalidate_sql(self, sql: str)-> bool:
"""SQL安全验证"""
# 检查危险操作
dangerous_keywords = ['DROP','DELETE','UPDATE','INSERT','TRUNCATE']
sql_upper = sql.upper()
forkeywordindangerous_keywords:
ifkeywordinsql_upper:
raiseSecurityError(f"Dangerous operation detected:{keyword}")
# 检查表访问权限
tables = self.extract_table_names(sql)
fortableintables:
ifnotself.user_has_access(table):
raisePermissionError(f"Access denied to table:{table}")
returnTrue
defrun_sql(self, sql: str, **kwargs)-> pd.DataFrame:
# 验证SQL安全性
self.validate_sql(sql)
# 添加行数限制
if'LIMIT'notinsql.upper():
sql +=f" LIMIT{self.config['max_result_rows']}"
returnsuper().run_sql(sql, **kwargs)
importlogging
fromprometheus_clientimportCounter, Histogram, Gauge
classMonitoredVanna(VannaBase):
def__init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Prometheus指标
self.query_counter = Counter('vanna_queries_total','Total queries')
self.query_duration = Histogram('vanna_query_duration_seconds','Query duration')
self.accuracy_gauge = Gauge('vanna_accuracy_rate','Current accuracy rate')
# 日志配置
logging.basicConfig(level=logging.INFO)
self.logger = logging.getLogger(__name__)
defask(self, question: str, **kwargs):
start_time = time.time()
self.query_counter.inc()
try:
result = super().ask(question, **kwargs)
# 记录成功查询
self.logger.info(f"Successful query:{question}")
duration = time.time() - start_time
self.query_duration.observe(duration)
returnresult
exceptExceptionase:
# 记录失败查询
self.logger.error(f"Failed query:{question}, Error:{str(e)}")
raise
classMultiTenantVanna(VannaBase):
def__init__(self, tenant_id: str, **kwargs):
self.tenant_id = tenant_id
# 租户隔离的配置
config = kwargs.get('config', {})
config['collection_name'] =f"vanna_{tenant_id}"
super().__init__(config=config)
defadd_question_sql(self, question: str, sql: str, **kwargs):
# 添加租户标识
metadata = kwargs.get('metadata', {})
metadata['tenant_id'] = self.tenant_id
returnsuper().add_question_sql(
question, sql, metadata=metadata, **kwargs
)
背景:某大型电商公司有复杂的数据仓库,包含用户、订单、商品、物流等多个业务域的数据。业务分析师经常需要进行复杂的数据查询。
挑战:
Vanna解决方案:
# 训练数据示例
training_examples = [
{
"question":"最近30天每日GMV趋势",
"sql":"""
SELECT
DATE(order_time) as date,
SUM(total_amount) as gmv
FROM orders
WHERE order_time >= CURRENT_DATE - INTERVAL '30 days'
AND order_status = 'completed'
GROUP BY DATE(order_time)
ORDER BY date;
"""
},
{
"question":"各品类的复购率",
"sql":"""
WITH user_category_orders AS (
SELECT
u.user_id,
p.category_id,
COUNT(DISTINCT o.order_id) as order_count
FROM users u
JOIN orders o ON u.user_id = o.user_id
JOIN order_items oi ON o.order_id = oi.order_id
JOIN products p ON oi.product_id = p.product_id
WHERE o.order_status = 'completed'
GROUP BY u.user_id, p.category_id
)
SELECT
c.category_name,
COUNT(CASE WHEN uco.order_count > 1 THEN 1 END) * 100.0 / COUNT(*) as repurchase_rate
FROM user_category_orders uco
JOIN categories c ON uco.category_id = c.category_id
GROUP BY c.category_name;
"""
}
]
效果:
背景:金融公司需要实时监控各种风险指标,业务人员需要快速获取风控数据。
特殊要求:
Vanna定制方案:
classFinanceVanna(VannaBase):
def__init__(self, user_role: str, **kwargs):
super().__init__(**kwargs)
self.user_role = user_role
self.audit_logger = AuditLogger()
defask(self, question: str, **kwargs):
# 审计日志
self.audit_logger.log_query_request(
user_role=self.user_role,
question=question,
timestamp=datetime.now()
)
# 基于角色的查询限制
ifself.user_role =='analyst':
# 分析师只能查询汇总数据
kwargs['aggregation_only'] =True
elifself.user_role =='manager':
# 经理可以查询详细数据但有行数限制
kwargs['max_rows'] =1000
result = super().ask(question, **kwargs)
# 记录查询结果
self.audit_logger.log_query_result(
user_role=self.user_role,
question=question,
sql=result[0]ifresultelseNone,
row_count=len(result[1])ifresultandresult[1]isnotNoneelse0
)
returnresult
背景:制造企业有大量IoT设备数据,需要进行设备状态监控和预测性维护分析。
技术挑战:
解决方案:
# 专门的时序数据训练
time_series_examples = [
{
"question":"设备A最近24小时的温度异常点",
"sql":"""
WITH temp_stats AS (
SELECT
AVG(temperature) as avg_temp,
STDDEV(temperature) as std_temp
FROM sensor_data
WHERE device_id = 'A'
AND timestamp >= NOW() - INTERVAL '24 hours'
)
SELECT
timestamp,
temperature,
ABS(temperature - ts.avg_temp) / ts.std_temp as z_score
FROM sensor_data sd
CROSS JOIN temp_stats ts
WHERE sd.device_id = 'A'
AND sd.timestamp >= NOW() - INTERVAL '24 hours'
AND ABS(sd.temperature - ts.avg_temp) / ts.std_temp > 2
ORDER BY timestamp;
"""
}
]
未来的Vanna可能支持:
# 未来可能的功能
classSmartVanna(VannaBase):
defauto_discover_schema(self):
"""自动发现和理解数据库结构"""
pass
defsuggest_data_quality_checks(self):
"""基于查询模式建议数据质量检查"""
pass
defauto_generate_documentation(self):
"""自动生成数据字典和业务文档"""
pass
# 跨数据源查询
vn.ask("比较我们在MySQL中的销售数据和Snowflake中的财务数据")
# 自动生成跨数据源的查询计划
Vanna这样的工具正在推动"数据民主化"的实现:
# 安装Vanna
pip install vanna
# 安装可选依赖
pip install vanna[openai,chromadb,postgres]
importos
fromvanna.openai.openai_chatimportOpenAI_Chat
fromvanna.chromadb.chromadb_vectorimportChromaDB_VectorStore
classEcommerceVanna(ChromaDB_VectorStore, OpenAI_Chat):
def__init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
# 初始化
vn = EcommerceVanna(config={
'api_key': os.getenv('OPENAI_API_KEY'),
'model':'gpt-4',
'path':'./vanna_db'# ChromaDB存储路径
})
# 连接数据库
vn.connect_to_postgres(
host="localhost",
dbname="ecommerce",
user="postgres",
password="password"
)
# 训练数据
deftrain_ecommerce_model():
# 添加表结构
ddl_statements = [
"""
CREATE TABLE users (
user_id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
country VARCHAR(2)
);
""",
"""
CREATE TABLE products (
product_id SERIAL PRIMARY KEY,
product_name VARCHAR(200) NOT NULL,
category_id INTEGER,
price DECIMAL(10,2),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""",
"""
CREATE TABLE orders (
order_id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(user_id),
order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
total_amount DECIMAL(10,2),
status VARCHAR(20) DEFAULT 'pending'
);
"""
]
forddlinddl_statements:
vn.train(ddl=ddl)
# 添加业务文档
documentation = [
"用户表(users)存储所有注册用户信息,country字段使用ISO 3166-1 alpha-2标准",
"订单表(orders)记录所有订单,status字段可能值:pending, completed, cancelled",
"产品表(products)存储商品信息,price字段为美元价格"
]
fordocindocumentation:
vn.train(documentation=doc)
# 添加示例查询
examples = [
{
"question":"今天有多少新用户注册?",
"sql":"SELECT COUNT(*) FROM users WHERE DATE(created_at) = CURRENT_DATE;"
},
{
"question":"最近7天的日均订单金额是多少?",
"sql":"""
SELECT AVG(daily_total) as avg_daily_amount
FROM (
SELECT DATE(order_date) as date, SUM(total_amount) as daily_total
FROM orders
WHERE order_date >= CURRENT_DATE - INTERVAL '7 days'
AND status = 'completed'
GROUP BY DATE(order_date)
) daily_totals;
"""
},
{
"question":"哪个国家的用户最多?",
"sql":"""
SELECT country, COUNT(*) as user_count
FROM users
WHERE country IS NOT NULL
GROUP BY country
ORDER BY user_count DESC
LIMIT 1;
"""
}
]
forexampleinexamples:
vn.train(question=example["question"], sql=example["sql"])
# 执行训练
train_ecommerce_model()
# 开始使用
if__name__ =="__main__":
whileTrue:
question = input("请输入你的问题(输入'quit'退出): ")
ifquestion.lower() =='quit':
break
try:
sql, df, fig = vn.ask(question)
print(f"\n生成的SQL:\n{sql}")
print(f"\n查询结果:\n{df}")
iffig:
fig.show() # 显示图表
exceptExceptionase:
print(f"查询失败:{str(e)}")
classAdvancedEcommerceVanna(EcommerceVanna):
def__init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.business_metrics = {
'GMV':'Gross Merchandise Value - 总商品交易额',
'AOV':'Average Order Value - 平均订单价值',
'LTV':'Lifetime Value - 客户生命周期价值',
'CAC':'Customer Acquisition Cost - 客户获取成本'
}
defpreprocess_question(self, question: str)-> str:
"""预处理问题,替换业务术语"""
forabbr, full_nameinself.business_metrics.items():
ifabbr.lower()inquestion.lower():
question = question.replace(abbr, full_name)
returnquestion
defask(self, question: str, **kwargs):
# 预处理问题
processed_question = self.preprocess_question(question)
# 添加业务上下文
ifany(metricinprocessed_question.lower()formetricinself.business_metrics.values()):
kwargs['include_business_context'] =True
returnsuper().ask(processed_question, **kwargs)
defgenerate_business_report(self, period: str ="last_30_days"):
"""生成业务报告"""
questions = [
f"What was the GMV for the{period}?",
f"What was the AOV for the{period}?",
f"How many new customers did we acquire in the{period}?",
f"What was the top-selling product category in the{period}?"
]
report = {}
forquestioninquestions:
try:
sql, df, _ = self.ask(question, print_results=False)
report[question] = {
'sql': sql,
'result': df.to_dict('records')ifdfisnotNoneelseNone
}
exceptExceptionase:
report[question] = {'error': str(e)}
returnreport
通过深入分析Vanna项目,我们看到了RAG技术在SQL生成领域的巨大潜力。从3%到80%的准确率提升不仅仅是一个数字的变化,更代表着一种全新的数据分析范式的诞生。
Vanna只是AI+SQL领域的一个开始。随着技术的不断发展,我们可以期待:
对于技术从业者:
对于企业决策者:
读到这里,相信你对Vanna和RAG技术有了深入的了解。但学习的旅程永远不会结束,我特别想听听你的想法:
我为大家准备了一个小挑战:
挑战题目:基于本文介绍的Vanna架构,设计一个针对你所在行业的AI SQL助手。请考虑:
| 欢迎光临 链载Ai (https://www.lianzai.com/) | Powered by Discuz! X3.5 |