Commit fac2c742 by 范立洲

feat: add intent and chatgpt

parent 7d4baa0f
yunheBot-Jie2ooj2
# ---- Stage 1: build the virtualenv with the compiler toolchain available ----
FROM python:3.9-slim AS compile-image

# Steps are chained with `&&` so they run sequentially and any failure aborts
# the build (a single `&` would background each command and ignore its result).
# The package index is refreshed in the same layer as the install, and the apt
# lists are removed in that layer so they never end up in an image layer.
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        gcc && \
    rm -rf /var/lib/apt/lists/* && \
    python -m venv /opt/venv

# Put the virtualenv first on PATH so `pip` below installs into /opt/venv.
ENV PATH="/opt/venv/bin:$PATH"

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# ---- Stage 2: runtime image, only the prebuilt virtualenv is carried over ----
FROM python:3.9-slim AS build-image
COPY --from=compile-image /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH" FLASK_ENV="PRD"
WORKDIR /yuheBot
COPY ./src/ .

# The CMD below writes its log files under ./log; make sure the directory exists.
RUN mkdir -p log

# Documents the port gunicorn binds to in CMD.
EXPOSE 9000

CMD ["gunicorn","-w","4","-b","0.0.0.0:9000","--log-level=info","--access-logfile","./log/access.log","--error-logfile","./log/error.log","wsgi:app"]
[base]
secret_key=asdfzxcv
debug=true
debug=false
host=0.0.0.0
port=9000
[openai]
api_key=qwertyuiopasdfghjkl
temperature=0.7
[qdrant]
qdrant_host=127.0.0.1
qdrant_port=6333
qdrant_collection=test
[openai]
api_key=qwertyuiopasdfghjkl
[qa]
score=0.8
limit=3
collection=qa
[yhbot]
park_score=0.8
[intent]
score=0.8
limit=3
collection=intent
aiohttp==3.8.4
aiosignal==1.3.1
aniso8601==9.0.1
anyio==3.6.2
async-timeout==4.0.2
attrs==23.1.0
cachetools==5.3.0
certifi==2022.12.7
chardet==5.1.0
charset-normalizer==3.1.0
click==8.1.3
colorama==0.4.6
distlib==0.3.6
filelock==3.11.0
Flask==2.2.3
Flask-RESTful==0.3.9
frozenlist==1.3.3
grpcio==1.54.0
grpcio-tools==1.54.0
gunicorn==20.1.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.0
httpx==0.24.0
hyperframe==6.0.1
idna==3.4
itsdangerous==2.1.2
Jinja2==3.1.2
jsonschema==4.17.3
MarkupSafe==2.1.2
multidict==6.0.4
numpy==1.24.2
openai==0.27.4
packaging==23.1
platformdirs==3.2.0
pluggy==1.0.0
protobuf==4.22.3
pydantic==1.10.7
pyproject_api==1.5.1
pyrsistent==0.19.3
pytz==2023.3
qdrant-client==1.1.4
requests==2.28.2
setuptools-scm==7.1.0
six==1.16.0
sniffio==1.3.0
tox==4.4.12
tqdm==4.65.0
typing_extensions==4.5.0
urllib3==1.26.15
virtualenv==20.21.0
Werkzeug==2.2.3
yarl==1.8.2
......@@ -49,12 +49,13 @@ package_dir =
# For more information, check out https://semver.org/.
install_requires =
importlib-metadata
Flask>=2.2.3
qdrant-client>=1.1.4
openai>=0.27.4
flask-restful>=0.3.9
Flask==2.2.3
qdrant-client==1.1.4
openai==0.27.4
flask-restful==0.3.9
click==8.1.3
jsonschema==4.17.3
gunicorn==20.1.0
[options.packages.find]
......
#!/usr/bin/env python3
# coding=utf-8
"""WSGI entry point: builds the Flask app and shares gunicorn's log handlers.

The configuration file path is read from the ``CONFIG`` environment variable;
when it is unset or empty, ``/etc/yunheBot/yhbot.ini`` is used.
"""
import os
import logging

from yunheBot.create_app import create_app

# An unset or empty CONFIG falls back to the system-wide default path.
config_path = os.getenv("CONFIG") or "/etc/yunheBot/yhbot.ini"

app = create_app(config_path=config_path)

# Reuse gunicorn's error-log handlers and level for the Flask app logger.
gunicorn_logger = logging.getLogger("gunicorn.error")
app.logger.handlers = gunicorn_logger.handlers
app.logger.setLevel(gunicorn_logger.level)
......@@ -2,10 +2,10 @@
from flask import Blueprint
from flask_restful import Api
from .v1.industrial_park import IndustrialParkResource
from .v1.completion import CompletionResource
rest_api_bp = Blueprint("api", __name__)
rest_api = Api(rest_api_bp)
rest_api.add_resource(IndustrialParkResource, "/v1/park", endpoint="park")
rest_api.add_resource(CompletionResource, "/v1/completion", endpoint="completion")
# coding=utf-8
# import json
from flask_restful import Resource
from flask import current_app, request
from typing import Dict
from ...core.embedding import OpenAIEmbedding
from ...core.datasource import QDDataSource, OpenAIDataSource
from ...common import response, validator
from ...core.qa import QABot
from ...core.intent import IntentBot
from ...core.chatgpt import ChatGPTBot
from ...core import const
# JSON schema for POST /v1/completion.
# `score` is a similarity threshold compared against fractional search scores
# (the configured default is 0.8), so it must accept any JSON number; typing it
# as "integer" would reject every fractional override.
POST_SCHEMA = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string"},
        "limit": {"type": "integer"},
        "score": {"type": "number"},
    },
    "required": ["prompt"],
}
class CompletionResource(Resource):
    """REST resource that answers a prompt via the QA bot or ChatGPT.

    The prompt is first classified by the intent bot. When the detected intent
    is ``const.INTENT_QA`` the QA bot answers it; otherwise the prompt is
    forwarded to the ChatGPT bot.
    """

    def __init__(self):
        """Build the bots from the current Flask application's configuration."""
        cfg = current_app.config

        embedder = OpenAIEmbedding(cfg["OPENAI_API_KEY"])
        qd_ds = QDDataSource(
            cfg["QDRANT_HOST"],
            cfg["QDRANT_PORT"],
        )
        openai_ds = OpenAIDataSource(
            cfg["OPENAI_API_KEY"],
            float(cfg["OPENAI_TEMPERATURE"]),
        )

        qa_score = float(cfg["QA_SCORE"])
        qa_limit = int(cfg["QA_LIMIT"])
        qa_collection = cfg["QA_COLLECTION"]
        self.qa_bot = QABot(qd_ds, embedder, qa_collection, qa_score, qa_limit)

        intent_score = float(cfg["INTENT_SCORE"])
        intent_limit = int(cfg["INTENT_LIMIT"])
        intent_collection = cfg["INTENT_COLLECTION"]
        self.intent_bot = IntentBot(
            qd_ds, embedder, intent_collection, intent_score, intent_limit
        )

        self.chatgpt_bot = ChatGPTBot(openai_ds)

    @response.wrap_response
    @validator.validator(POST_SCHEMA)
    def post(self):
        """Answer the prompt in the JSON request body.

        Returns:
            ``response.OK`` with the bot payload, ``response.NotFound`` when
            the QA bot has no answer, or ``response.InternalServerError`` when
            any bot raises.
        """
        req: Dict = request.get_json(force=True)
        prompt = req["prompt"]

        try:
            # 1. intent check (inside the try so a failing lookup is reported
            #    the same way as a failing QA / ChatGPT call)
            intent = self.intent_bot.complete(prompt)

            # 1.1 intent ==> qa
            if intent is not None and intent.answer == const.INTENT_QA:
                # Forward only the arguments QABot.complete accepts; the schema
                # does not forbid extra keys, and passing them through would
                # raise TypeError.
                qa_args = {
                    key: req[key] for key in ("prompt", "score", "limit") if key in req
                }
                answer = self.qa_bot.complete(**qa_args)
                if answer is None:
                    return response.NotFound("No Answer Found")
                return response.OK("OK", answer.to_dict())

            # 1.2 intent ==> chatgpt
            resp = self.chatgpt_bot.complete(prompt)
            return response.OK("", resp.to_dict())
        except Exception as e:
            return response.InternalServerError(str(e))
# coding=utf-8
from flask_restful import Resource
from flask import current_app, request
from ...core.embedding import OpenAIEmbedding
from ...core.datasource import YHQdrantClient
from ...common import response, validator
POST_SCHEMA = {"type": "object", "properties": {"prompt": {"type": "string"}}}
class IndustrialParkResource(Resource):
    """REST resource that looks up a prompt in the configured Qdrant collection."""

    def __init__(self):
        """Create the embedder and Qdrant client from the Flask app config."""
        config = current_app.config
        self.embedder = OpenAIEmbedding(config["OPENAI_API_KEY"])
        self.qd_client = YHQdrantClient(
            config["QDRANT_HOST"],
            config["QDRANT_PORT"],
            self.embedder,
        )
        self.collection = config["QDRANT_COLLECTION"]
        self.park_score = float(config["PARK_SCORE"])

    @response.wrap_response
    @validator.validator(POST_SCHEMA)
    def post(self):
        """Return the payload of the best hit whose score reaches ``park_score``."""
        body = request.get_json(force=True)
        prompt = body.get("prompt")

        try:
            hits = self.qd_client.search(
                prompt=prompt,
                collection_name=self.collection,
                limit=3,
            )
        except Exception as e:
            return response.InternalServerError(str(e))

        # Keep only hits that reach the configured similarity threshold.
        accepted = [hit for hit in hits if hit.score >= self.park_score]
        if not accepted:
            return response.NotFound("No Answer Found")

        return response.OK("", accepted[0].payload)
......@@ -3,9 +3,11 @@
import click
import os
from yunheBot.create_app import create_app
from yunheBot.core.embedding import OpenAIEmbedding
from yunheBot.core.datasource import YHQdrantClient
from .create_app import create_app
from .core.embedding import OpenAIEmbedding
from .core.datasource import QDDataSource
from .core.qa import QABot
from .core.intent import IntentBot
@click.group()
......@@ -23,22 +25,33 @@ def server(conf):
@click.command()
@click.option("--bot_type", type=str)
@click.option("--open_ai_key", type=str)
@click.option("--db", type=str)
@click.option("--file", type=str)
@click.option("--collection", type=str)
def upload(bot_type: str, open_ai_key: str, db: str, file: str, collection: str):
    """Create a collection and upload a CSV data set into it.

    Args:
        bot_type: Which bot's row format to use, ``"qa"`` or ``"intent"``.
        open_ai_key: OpenAI API key used to embed each row.
        db: Qdrant address in ``host:port`` form.
        file: Path of the CSV file to upload.
        collection: Name of the target collection.
    """
    bot_classes = {"qa": QABot, "intent": IntentBot}
    bot_cls = bot_classes.get(bot_type)
    if bot_cls is None:
        print("--bot_type must be specificed as 'qa' or 'intent'")
        # Stop here: without a bot there is nothing to call below.
        return

    if os.path.splitext(file)[1].lower() != ".csv":
        print(f"{file} is not a CSV file.")
        return

    embedder = OpenAIEmbedding(open_ai_key)
    qdrant_host, qdrant_port = db.split(":")
    # The port arrives as text from the command line; the data source takes an int.
    qd_ds = QDDataSource(qdrant_host, int(qdrant_port))

    # Score and limit only matter for searching, so 0 is passed for uploads.
    bot = bot_cls(qd_ds, embedder, collection, 0, 0)
    bot.create_collection()
    bot.upload_csv(file)
cli.add_command(server)
cli.add_command(upload)
# cli.add_command(intent)
......@@ -19,79 +19,24 @@ def setup_config(config_path):
new_config = {
# base
"SECRET_KEY": conf.get("base", "secret_key"), # session加密
# "PERMANENT_SESSION_LIFETIME": timedelta(days=30), # 设置session过期时间
"DEBUG": conf.getboolean("base", "debug"),
"HOST": conf.get("base", "host"),
"PORT": conf.getint("base", "port"),
# qdrant
"QDRANT_HOST": conf.get("qdrant", "qdrant_host"),
"QDRANT_PORT": conf.get("qdrant", "qdrant_port"),
"QDRANT_COLLECTION": conf.get("qdrant", "qdrant_collection"),
# "QDRANT_COLLECTION": conf.get("qdrant", "qdrant_collection"),
# openai_api_key
"OPENAI_API_KEY": conf.get("openai", "api_key"),
# yhbot
"PARK_SCORE": conf.get("yhbot", "park_score")
"OPENAI_TEMPERATURE": conf.get("openai", "temperature"),
# qa
"QA_COLLECTION": conf.get("qa", "collection"),
"QA_SCORE": conf.get("qa", "score"),
"QA_LIMIT": conf.get("qa", "limit"),
# intent
"INTENT_COLLECTION": conf.get("intent", "collection"),
"INTENT_SCORE": conf.get("intent", "score"),
"INTENT_LIMIT": conf.get("intent", "limit"),
}
base_config.update(new_config)
# class BaseConfig:
# """配置基类"""
#
# # SECRET_KEY = os.urandom(24)
# SECRET_KEY = "asdfzxcv" # session加密
# PERMANENT_SESSION_LIFETIME = timedelta(days=30) # 设置session过期时间
# DEBUG = True
# # SERVER_NAME = 'example.com'
# HOST = "0.0.0.0"
# PORT = 9000
#
# @staticmethod
# def init_app(app):
# pass
# class NewConfig(BaseConfig):
# """区分配置文件"""
#
# def __init__(self, config_path):
# self._conf = get_config(config_path)
#
#
# # base
# SECRET_KEY = self._conf.get("base", "secret_key") # session加密
# PERMANENT_SESSION_LIFETIME = timedelta(days=30) # 设置session过期时间
# DEBUG = self._conf.getboolean("base", "debug")
# HOST = self._conf.get("base", "host")
# PORT = self._conf.getint("base", "port")
#
# # qdrant
# QDRANT_HOST = self._conf.get("qdrant", "qdrant_host")
# QDRANT_PORT = self._conf.get("qdrant", "qdrant_port")
#
# # openai_api_key
# OPENAI_KEY = self._conf.get("openai", "api_key")
#
# # # mysql
# # MYSQL_USERNAME = conf.get("mysql", "USERNAME")
# # MYSQL_PASSWORD = conf.get("mysql", "PASSWORD")
# # MYSQL_HOSTNAME = conf.get("mysql", "HOSTNAME")
# # MYSQL_PORT = conf.getint("mysql", "PORT")
# # MYSQL_DATABASE = conf.get("mysql", "DATABASE")
# # DB_URI = "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8".format(
# # MYSQL_USERNAME, MYSQL_PASSWORD, MYSQL_HOSTNAME, MYSQL_PORT, MYSQL_DATABASE
# # )
# # SQLALCHEMY_DATABASE_URI = DB_URI
# # SQLALCHEMY_TRACK_MODIFICATIONS = True
#
# # # redis
# # redis_obj = {
# # "host": conf.get("redis", "REDIS_HOST"),
# # "port": conf.get("redis", "REDIS_PORT"),
# # "password": conf.get("redis", "REDIS_PWD"),
# # "decode_responses": conf.getboolean("redis", "DECODE_RESPONSES"),
# # "db": conf.getint("redis", "REDIS_DB"),
# # }
# # POOL = redis.ConnectionPool(**redis_obj)
# # R = redis.Redis(connection_pool=POOL)
base_config.update(new_config)
# coding=utf-8
from .datasource import DataSource
from .common import BotPayload
class ChatGPTBot:
    """Bot that forwards the prompt to its data source and wraps the reply."""

    def __init__(self, data_source: DataSource):
        self.data_source = data_source

    def complete(self, prompt: str) -> BotPayload:
        """Return a payload pairing ``prompt`` with the data source's answer."""
        return BotPayload.factory(
            question=prompt,
            answer=self.data_source.search(prompt),
        )
# coding=utf-8
class BotPayload:
    """Question/answer pair plus arbitrary extra attributes."""

    @classmethod
    def factory(cls, **kwargs):
        """Build a payload from ``kwargs``, ignoring keys that start with ``_``."""
        public = {}
        for name, item in kwargs.items():
            # A leading "_" also covers names starting with "__".
            if name.startswith("_"):
                continue
            public[name] = item
        return cls(**public)

    def __init__(self, question: str, answer: str, **kwargs):
        """Store the question, the answer and every extra keyword as attributes."""
        self.question = question
        self.answer = answer
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def to_dict(self):
        """Return the instance attributes whose names do not start with ``_``."""
        exported = {}
        for name, item in self.__dict__.items():
            if not name.startswith("_"):
                exported[name] = item
        return exported
# coding=utf-8
INTENT_QA = 1
# coding=utf-8
import csv
import openai
from abc import abstractmethod, ABCMeta
from typing import List, Callable, Dict, Union
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct
from .embedding import Embedding
from ..util import util
# from ..util import util
class DataSource(metaclass=ABCMeta):
@abstractmethod
def search(
self, collection: str, prompt: str, limit: int, **kwargs
) -> Union[List, str]:
pass
@util.singleton
class YHQdrantClient:
def __init__(self, host: str, port: int, embedder: Embedding):
# @util.singleton
class QDDataSource:
def __init__(self, host: str, port: int):
self.client = QdrantClient(host=host, port=port)
self.embedder = embedder
def create_collection(
self, collection: str, size: int = 1536, distance: Distance = Distance.COSINE
......@@ -33,7 +45,13 @@ class YHQdrantClient:
except Exception as e:
raise e
def upload_csv(self, filepath: str, collection: str):
def upload_csv(
self,
collection: str,
embedder: Embedding,
filepath: str,
formatter: Callable[[List], Dict],
):
"""导入csv数据集
Args:
......@@ -43,28 +61,55 @@ class YHQdrantClient:
csv_reader = csv.reader(f)
count = 0
for line in csv_reader:
print(line)
item = self.embedder.to_embedding(line)
print(item)
vec = embedder.get_vector(line[0])
print(line, vec)
self.client.upsert(
collection_name=collection,
wait=True,
points=[
PointStruct(
id=count,
vector=item[2],
payload={"title": item[0], "text": item[1]},
vector=vec,
payload=formatter(line),
)
],
)
count += 1
def search(self, prompt, collection_name, limit):
vec = self.embedder.get_vector(prompt)
def search(
self,
prompt: str,
limit: int,
embedder: Embedding = None,
collection: str = "",
) -> List:
if embedder is None:
raise ValueError("Embedding Object is None")
if collection is None or len(collection) == 0:
raise ValueError("Collection is None")
vec = embedder.get_vector(prompt)
results = self.client.search(
collection_name=collection_name,
collection_name=collection,
query_vector=vec,
limit=limit,
search_params={"exact": False, "hnsw_ef": 128},
)
return results
class OpenAIDataSource:
    """Data source that answers a prompt with an OpenAI chat completion."""

    def __init__(self, api_key: str, temperature: float):
        # The key is also installed on the openai module itself.
        openai.api_key = api_key
        self.api_key = api_key
        self.temperature = temperature

    def search(self, prompt: str, limit: int = 0, model="gpt-3.5-turbo") -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        ``limit`` is accepted but not used by this implementation.
        """
        messages = [{"role": "user", "content": prompt}]
        completion = openai.ChatCompletion.create(
            temperature=self.temperature,
            model=model,
            messages=messages,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
......@@ -2,14 +2,14 @@
import openai
from abc import abstractmethod, ABCMeta
from typing import List, Tuple
# from typing import List, Tuple
from ..util import util
class Embedding(metaclass=ABCMeta):
@abstractmethod
def generate_sample(self, items: List[str]) -> Tuple[str, str, float]:
pass
# @abstractmethod
# def generate_sample(self, items: List[str]) -> Tuple[str, str, float]:
# pass
@abstractmethod
def get_vector(self, prompt: str) -> float:
......@@ -18,9 +18,9 @@ class Embedding(metaclass=ABCMeta):
@util.singleton
class OpenAIEmbedding:
def __init__(self, open_ai_key):
self.open_ai_key = open_ai_key
openai.api_key = open_ai_key
def __init__(self, api_key):
self.open_ai_key = api_key
openai.api_key = api_key
def _create_embedding(self, text):
sentence_embeddings = openai.Embedding.create(
......@@ -28,10 +28,6 @@ class OpenAIEmbedding:
)
return sentence_embeddings["data"][0]["embedding"]
def generate_sample(self, items):
vec = self._create_embedding(items[1])
return [items[0], items[1], vec]
def get_vector(self, prompt):
vec = self._create_embedding(prompt)
return vec
# coding=utf-8
from collections import Counter
from typing import List, Dict
from .datasource import QDDataSource
from .embedding import Embedding
from .const import INTENT_QA
from .common import BotPayload
class IntentBot:
    """Bot that classifies a prompt's intent by vector search over labelled questions."""

    def __init__(
        self,
        data_source: QDDataSource,
        embedder: Embedding,
        collection: str,
        score: float,
        limit: int,
    ):
        """
        Args:
            data_source: Vector store used for collection creation, upload and search.
            embedder: Embedding provider handed to the data source.
            collection: Name of the collection holding the intent samples.
            score: Minimum similarity score a hit must reach to be considered.
            limit: Maximum number of hits requested from the data source.
        """
        self.data_source = data_source
        self.embdder = embedder
        self.collection = collection
        self.score = float(score)
        self.limit = int(limit)

    def create_collection(self):
        """Create this bot's collection in the data source.

        Any exception raised by the data source propagates to the caller.
        """
        self.data_source.create_collection(collection=self.collection)

    def formatter(self, line: List) -> Dict:
        """Convert one CSV row into the payload stored with its vector.

        Row format: question, answer (intent). Only the first column is read;
        every row is labelled with ``INTENT_QA``.
        """
        return BotPayload(question=line[0], answer=INTENT_QA).to_dict()

    def upload_csv(self, filepath: str):
        """Import a CSV data set into this bot's collection.

        Args:
            filepath: Path of the CSV file.
        """
        self.data_source.upload_csv(
            self.collection, self.embdder, filepath, self.formatter
        )

    def complete(self, prompt: str) -> BotPayload:
        """Return the payload of the most common intent among qualifying hits.

        Args:
            prompt: Text to classify.

        Returns:
            The first payload carrying the most common ``answer`` among hits
            whose score is at least ``self.score``, or ``None`` when the search
            returns nothing or no hit reaches the threshold.
        """
        points = self.data_source.search(
            prompt,
            self.limit,
            embedder=self.embdder,
            collection=self.collection,
        )
        if not points:
            return None

        payloads = [item.payload for item in points if item.score >= self.score]
        if not payloads:
            return None

        # Majority vote over the intents of the qualifying hits.
        top_answer = Counter(p["answer"] for p in payloads).most_common(1)[0][0]
        winner = next(p for p in payloads if p["answer"] == top_answer)
        return BotPayload.factory(**winner)
# coding=utf-8
from typing import List, Dict
from .datasource import QDDataSource
from .embedding import Embedding
from .common import BotPayload
class QABot:
    """Bot that answers a prompt from a collection of question/answer pairs."""

    def __init__(
        self,
        data_source: QDDataSource,
        embedder: Embedding,
        collection: str,
        score: float,
        limit: int,
    ):
        """
        Args:
            data_source: Vector store used for collection creation, upload and search.
            embedder: Embedding provider handed to the data source.
            collection: Name of the collection holding the QA pairs.
            score: Default minimum similarity score a hit must reach.
            limit: Default maximum number of hits requested from the data source.
        """
        self.data_source = data_source
        self.embdder = embedder
        self.collection = collection
        self.score = float(score)
        self.limit = int(limit)

    def create_collection(self):
        """Create this bot's collection in the data source.

        Any exception raised by the data source propagates to the caller.
        """
        self.data_source.create_collection(collection=self.collection)

    def formatter(self, line: List) -> Dict:
        """Convert one CSV row into the payload stored with its vector.

        Row format: question, answer, jump type. Only the first two columns
        are read; the stored ``type`` is always ``"demo"``.
        """
        return BotPayload(question=line[0], answer=line[1], type="demo").to_dict()

    def upload_csv(self, filepath: str):
        """Import a CSV data set into this bot's collection.

        Args:
            filepath: Path of the CSV file.
        """
        self.data_source.upload_csv(
            self.collection, self.embdder, filepath, self.formatter
        )

    def complete(self, prompt: str, score: float = 0, limit: int = 0) -> BotPayload:
        """Return the first hit for ``prompt`` that reaches the score threshold.

        Args:
            prompt: Question to look up.
            score: Per-call threshold; ``0`` or ``None`` means use ``self.score``.
            limit: Per-call hit limit; ``0`` or ``None`` means use ``self.limit``.

        Returns:
            The payload of the first qualifying hit, or ``None`` when the
            search returns nothing or no hit reaches the threshold.
        """
        # `or` treats None like 0, so an unset value never reaches the
        # numeric comparison below.
        _limit = limit or self.limit
        _score = score or self.score

        points = self.data_source.search(
            prompt,
            _limit,
            embedder=self.embdder,
            collection=self.collection,
        )
        if not points:
            return None

        payloads = [item.payload for item in points if item.score >= _score]
        if not payloads:
            return None

        return BotPayload.factory(**payloads[0])
......@@ -6,7 +6,6 @@ from flask import Flask
from .register.bp_register import register_blueprint
from .register.conf_register import register_config
def create_app(*args, **kwargs):
app = Flask(__name__)
register_blueprint(app, *args, **kwargs)
......
......@@ -11,4 +11,3 @@ def register_config(app: Flask, *args, **kwargs):
setup_config(config_path)
app.config.from_mapping(base_config)
# print(app.config)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论