官网:GitHub - defog-ai/sqlcoder: SoTA LLM for converting natural language questions to SQL queries
参考:大模型应用实践:「智能问数应用」SQL Coder 构建大模型数据分析助手 - 百度智能云千帆社区
这次的测试环境不是我的老朋友了,是公司的一台机器( Windows 10 22H2 19045 + i5-6600K CPU + 16GB 内存 + RTX 3060 12GB )
这里测试的模型是:
项目地址:GitHub - defog-ai/sqlcoder: SoTA LLM for converting natural language questions to SQL queries
代码是“2024-01-02”在Github下载的,这里将代码放到了这里:
# 创建虚拟环境
python -m venv venv
# 激活虚拟环境
.\venv\scripts\activate
# 退出虚拟环境
deactivate
# 安装依赖(添加到这里,使用的时候方便一些)
pip install -r requirements.txt
# 检查 torch
python
import torch
print("torch版本:"+torch.__version__+" ;cuda是否可用:"+str(torch.cuda.is_available()))
# 安装 GPU版本 torch 下面两个都可以
pip install torch==2.1.2+cu121 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch --extra-index-url https://download.pytorch.org/whl/cu121
# 为了测试更简单,这里将模型权重代码写死了
def get_tokenizer_model(model_name):
model_path = "D:\\0-llm\\defog\\sqlcoder-7b"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.float16,
device_map="auto",
use_cache=True,
)
return tokenizer, model
数据库(例如:MySQL)各个表的结构以及表之间的关系,使用
CREATE TABLE `task` (
`id` bigint(20) NOT NULL, -- Unique ID for each Task
`title` varchar(500) COLLATE utf8mb4_unicode_ci DEFAULT NULL, -- Task Title
`createdate` datetime DEFAULT NULL, -- Date of the Task Created
`createuser` bigint(20) DEFAULT NULL, -- Creator of the Task
`remark` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL, -- Remark of the Task
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
说明:如果这个文件不修改,打印的结果中没有回答的SQL语句,参照开头百度的教程修改之后就没问题了,文件内容如下:
### Task
Generate a SQL query to answer the following question:
`{user_question}`
### Database Schema
This query will run on a MySQL database whose schema is represented in this string:
{table_metadata_string}
### SQL
Given the database schema, here is the SQL query that answers `{user_question}`:
```sql
按照文章中的顺序到这里,直接提问,会提示缺少依赖,请安装:
python inference.py -q "Count the number of Task."
下面是未修改
下面是修改
python inference.py -q "Count the number of task created each month and sort the results in ascending order by month."
# 回答如下:
SELECT date_trunc('month', created_at) AS MONTH, COUNT(*) AS task_count FROM task GROUP BY MONTH ORDER BY MONTH ASC;
# 需要稍微修改一下:
SELECT date_format(createdate, '%Y-%m') AS MONTH, COUNT(*) AS task_count FROM task GROUP BY MONTH ORDER BY MONTH ASC;
我去,这次生成的语句有两处问题:mysql不支持 date_trunc 函数(提示词中已经指明是 mysql);表结构中没有“created_at”字段,是“createdate”。还有个问题就是之前研究他的时候,用的库表是项目中正式的,当时测试的时候“createdate”是没有问题的,这里为了写这篇笔记改了一下表名、修改字段名、删了几个字段等……这有点不稳定啊
在官方示例的基础上测试中文支持情况,在项目根目录下添加
这里在上面修改的基础上又修改了下面的代码:
def generate_prompt(question, prompt_file="zh/prompt.md", metadata_file="zh/metadata.sql"):
with open(prompt_file, "r",encoding="utf-8") as f:
prompt = f.read()
with open(metadata_file, "r",encoding="utf-8") as f:
table_metadata_string = f.read()
prompt = prompt.format(
user_question=question, table_metadata_string=table_metadata_string
)
return prompt
# 结构和上面的相同,只是修改成中文注释
CREATE TABLE `task` (
`id` bigint(20) NOT NULL, -- 每个任务的唯一ID
`title` varchar(500) COLLATE utf8mb4_unicode_ci DEFAULT NULL, -- 任务标题
`createdate` datetime DEFAULT NULL, -- 任务的创建日期
`createuser` bigint(20) DEFAULT NULL, -- 任务的创建人
`remark` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL, -- 任务的备注
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
### 任务
生成一个SQL查询以回答以下问题:
`{user_question}`
### 数据库模式
此查询将在一个MySQL数据库上运行,该数据库的模式在以下字符串中表示:
{table_metadata_string}
### SQL
根据数据库模式,以下是回答 `{user_question}` 的SQL查询:
```sql
还是上面的几个问题
python zh/inference.py -q "统计一下任务的数量"
python zh/inference.py -q "统计每个月创建的任务的数量,并且对结果按月份升序排序"
SELECT date_format(createdate, '%Y-%m') AS MONTH, COUNT(*) AS task_count FROM task GROUP BY MONTH ORDER BY MONTH ASC;
这简直就是惊喜啊,目前测试结果来看,中文prompt要比英文prompt效果好