Commit d972c690 by 范立洲

feat: add yunheBot.server

parent 229bbccf
FROM python:3.9-slim AS compile-image FROM python:3.9-slim AS compile-image
RUN apt-get upgrade & \ WORKDIR /yunheBot
apt-get install -y --no-install-recommends build-essential gcc & \ COPY . .
python -m venv /opt/venv RUN apt-get update && \
ENV PATH="/opt/venv/bin:$PATH" apt-get install -y --no-install-recommends git && \
COPY requirements.txt . apt-get purge -y --auto-remove && \
RUN pip install -r requirements.txt rm -rf /var/lib/apt/lists/* && \
pip install -r requirements.txt
RUN tox -e build
FROM python:3.9-slim AS build-image FROM python:3.9-slim
COPY --from=compile-image /opt/venv /opt/venv WORKDIR /yunheBot
ENV PATH="/opt/venv/bin:$PATH" FLASK_ENV="PRD" COPY --from=compile-image /yunheBot/dist/*.whl /yunheBot/dist/
WORKDIR /yuheBot RUN pip install /yunheBot/dist/*.whl
COPY ./src/ . CMD ["gunicorn","-w","4","-b","0.0.0.0:9000","--log-level=info","--access-logfile","./log/access.log","--error-logfile","./log/error.log","yunheBot.server:app"]
CMD ["gunicorn","-w","4","-b","0.0.0.0:9000","--log-level=info","--access-logfile","./log/access.log","--error-logfile","./log/error.log","wsgi:app"]
...@@ -56,9 +56,15 @@ class CompletionResource(Resource): ...@@ -56,9 +56,15 @@ class CompletionResource(Resource):
@validator.validator(POST_SCHEMA) @validator.validator(POST_SCHEMA)
def post(self): def post(self):
data = request.get_json(force=True) data = request.get_json(force=True)
# req: Dict = json.loads(data)
req: Dict = data req: Dict = data
# if not req["stream"]:
# return self._completion(req)
# else:
# pass
# 1. intent check # 1. intent check
intent = self.intent_bot.complete(req["prompt"]) intent = self.intent_bot.complete(req["prompt"])
...@@ -78,3 +84,30 @@ class CompletionResource(Resource): ...@@ -78,3 +84,30 @@ class CompletionResource(Resource):
return response.OK("", resp.to_dict()) return response.OK("", resp.to_dict())
except Exception as e: except Exception as e:
return response.InternalServerError(str(e)) return response.InternalServerError(str(e))
def _completion(self, req: Dict):
    """Answer a non-streaming completion request.

    The prompt is first classified by the intent bot. When the detected
    intent equals ``const.INTENT_QA`` the QA bot answers; in every other
    case (including no intent found) the request falls through to the
    ChatGPT bot.

    Args:
        req: Parsed request body. Must contain ``"prompt"``; may contain
            ``"score"`` and ``"limit"``, which are forwarded to the QA bot.
            Any other key is ignored by this method.

    Returns:
        The object built by the matching ``response`` helper: ``OK`` with
        the answer payload, ``NotFound`` when the QA bot returns ``None``,
        or ``InternalServerError`` carrying the exception text when a bot
        raises.

    Raises:
        KeyError: If ``req`` has no ``"prompt"`` key.
        Exception: Whatever ``self.intent_bot.complete`` raises; the intent
            check is not wrapped in an error handler.
    """
    # 1. intent check
    intent = self.intent_bot.complete(req["prompt"])

    # 1.1 intent ==> qa
    if intent is not None and intent.answer == const.INTENT_QA:
        # Forward only the keyword arguments QABot.complete() declares
        # (prompt, score, limit). Unpacking the whole request would raise
        # TypeError for any additional field such as "stream" and turn a
        # valid request into a 500 response.
        qa_kwargs = {
            key: req[key] for key in ("prompt", "score", "limit") if key in req
        }
        try:
            answer = self.qa_bot.complete(**qa_kwargs)
            if answer is None:
                return response.NotFound("No Answer Found")
            return response.OK("OK", answer.to_dict())
        except Exception as e:
            return response.InternalServerError(str(e))

    # 1.2 intent ==> chatgpt
    try:
        resp = self.chatgpt_bot.complete(req["prompt"])
        return response.OK("", resp.to_dict())
    except Exception as e:
        return response.InternalServerError(str(e))
def _stream_completion(self, req):
# # 1. intent check
# intent = self.intent_bot.complete(req["prompt"])
pass
# coding=utf-8 # coding=utf-8
from .datasource import QDDataSource
class BaseQDBot:
    """Shared collection-management behaviour for bots backed by a data source.

    The base class only stores ``data_source``. The methods below also read
    the following attributes, which are not set here and must be provided
    by the subclass before the methods are called:

    * ``self.collection`` - collection identifier passed to the data source.
    * ``self.embdder``    - embedder object passed through to ``upload_csv``
      (the attribute name is spelled this way by the subclasses).
    * ``self.formatter``  - callable passed through to ``upload_csv``.
    """

    def __init__(self, data_source: "QDDataSource"):
        """Store the data source used by all collection operations.

        Args:
            data_source: Object exposing ``create_collection`` and
                ``upload_csv``.
        """
        self.data_source = data_source

    def create_collection(self):
        """Create the collection named by ``self.collection``.

        Raises:
            Exception: Any error raised by the data source propagates
                unchanged to the caller.
        """
        # No try/except here: catching an exception only to re-raise it
        # adds nothing, so the call is made directly.
        self.data_source.create_collection(collection=self.collection)

    def upload_csv(self, filepath: str):
        """Import a CSV data set into ``self.collection``.

        Args:
            filepath: Path of the CSV file to import.
        """
        self.data_source.upload_csv(
            self.collection, self.embdder, filepath, self.formatter
        )
class BotPayload: class BotPayload:
@classmethod @classmethod
......
...@@ -4,11 +4,10 @@ from collections import Counter ...@@ -4,11 +4,10 @@ from collections import Counter
from typing import List, Dict from typing import List, Dict
from .datasource import QDDataSource from .datasource import QDDataSource
from .embedding import Embedding from .embedding import Embedding
from .const import INTENT_QA from .common import BotPayload, BaseQDBot
from .common import BotPayload
class IntentBot: class IntentBot(BaseQDBot):
def __init__( def __init__(
self, self,
data_source: QDDataSource, data_source: QDDataSource,
...@@ -17,38 +16,16 @@ class IntentBot: ...@@ -17,38 +16,16 @@ class IntentBot:
score: float, score: float,
limit: int, limit: int,
): ):
self.data_source = data_source super(IntentBot, self).__init__(data_source)
# self.data_source = data_source
self.embdder = embedder self.embdder = embedder
self.collection = collection self.collection = collection
self.score = float(score) self.score = float(score)
self.limit = int(limit) self.limit = int(limit)
def create_collection(self):
"""创建数据集
Args:
size (int): 向量大小
distance (QdrantClient.Distance): 损失函数
"""
try:
self.data_source.create_collection(collection=self.collection)
except Exception as e:
raise e
def formatter(self, line: List) -> Dict: def formatter(self, line: List) -> Dict:
# line format: 问题,答案(意图) # line format: 问题,答案, type, 意图
# return {"question": line[0], "answer": INTENT_QA} return BotPayload(question=line[0], answer=int(line[3])).to_dict()
return BotPayload(question=line[0], answer=INTENT_QA).to_dict()
def upload_csv(self, filepath: str):
"""导入csv数据集
Args:
filepath: 文件路径
"""
self.data_source.upload_csv(
self.collection, self.embdder, filepath, self.formatter
)
def complete(self, prompt: str) -> BotPayload: def complete(self, prompt: str) -> BotPayload:
points = self.data_source.search( points = self.data_source.search(
......
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
from typing import List, Dict from typing import List, Dict
from .datasource import QDDataSource from .datasource import QDDataSource
from .embedding import Embedding from .embedding import Embedding
from .common import BotPayload from .common import BotPayload, BaseQDBot
class QABot: class QABot(BaseQDBot):
def __init__( def __init__(
self, self,
data_source: QDDataSource, data_source: QDDataSource,
...@@ -15,39 +15,15 @@ class QABot: ...@@ -15,39 +15,15 @@ class QABot:
score: int, score: int,
limit: int, limit: int,
): ):
self.data_source = data_source super(QABot, self).__init__(data_source)
self.embdder = embedder self.embdder = embedder
self.collection = collection self.collection = collection
self.score = float(score) self.score = float(score)
self.limit = int(limit) self.limit = int(limit)
def create_collection(self):
"""创建数据集
Args:
size (int): 向量大小
distance (QdrantClient.Distance): 损失函数
"""
try:
self.data_source.create_collection(collection=self.collection)
except Exception as e:
raise e
def formatter(self, line: List) -> Dict: def formatter(self, line: List) -> Dict:
# line format: 问题,答案,跳转类型 # line format: 问题,答案, type, 意图
# return {"question": line[0], "answer": line[1], "type": line[2]} return BotPayload(question=line[0], answer=line[1], type=line[2]).to_dict()
# return {"question": line[0], "answer": line[1], "type": "demo"}
return BotPayload(question=line[0], answer=line[1], type="demo").to_dict()
def upload_csv(self, filepath: str):
"""导入csv数据集
Args:
filepath: 文件路径
"""
self.data_source.upload_csv(
self.collection, self.embdder, filepath, self.formatter
)
def complete(self, prompt: str, score: int = 0, limit: int = 0) -> BotPayload: def complete(self, prompt: str, score: int = 0, limit: int = 0) -> BotPayload:
_limit = self.limit if limit == 0 else limit _limit = self.limit if limit == 0 else limit
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论