Commit d972c690 by 范立洲

feat: add yunheBot.server

parent 229bbccf
FROM python:3.9-slim AS compile-image
RUN apt-get upgrade & \
apt-get install -y --no-install-recommends build-essential gcc & \
python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
COPY requirements.txt .
RUN pip install -r requirements.txt
WORKDIR /yunheBot
COPY . .
RUN apt-get update && \
apt-get install -y --no-install-recommends git && \
apt-get purge -y --auto-remove && \
rm -rf /var/lib/apt/lists/* && \
pip install -r requirements.txt
RUN tox -e build
FROM python:3.9-slim AS build-image
COPY --from=compile-image /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH" FLASK_ENV="PRD"
WORKDIR /yuheBot
COPY ./src/ .
CMD ["gunicorn","-w","4","-b","0.0.0.0:9000","--log-level=info","--access-logfile","./log/access.log","--error-logfile","./log/error.log","wsgi:app"]
FROM python:3.9-slim
WORKDIR /yunheBot
COPY --from=compile-image /yunheBot/dist/*.whl /yunheBot/dist/
RUN pip install /yunheBot/dist/*.whl
CMD ["gunicorn","-w","4","-b","0.0.0.0:9000","--log-level=info","--access-logfile","./log/access.log","--error-logfile","./log/error.log","yunheBot.server:app"]
......@@ -56,9 +56,15 @@ class CompletionResource(Resource):
@validator.validator(POST_SCHEMA)
def post(self):
data = request.get_json(force=True)
# req: Dict = json.loads(data)
req: Dict = data
# if not req["stream"]:
# return self._completion(req)
# else:
# pass
# 1. intent check
intent = self.intent_bot.complete(req["prompt"])
......@@ -78,3 +84,30 @@ class CompletionResource(Resource):
return response.OK("", resp.to_dict())
except Exception as e:
return response.InternalServerError(str(e))
def _completion(self, req: Dict):
# 1. intent check
intent = self.intent_bot.complete(req["prompt"])
# 1.1 intent ==> qa
if intent is not None and intent.answer == const.INTENT_QA:
try:
answer = self.qa_bot.complete(**req)
if answer is None:
return response.NotFound("No Answer Found")
return response.OK("OK", answer.to_dict())
except Exception as e:
return response.InternalServerError(str(e))
# 1.2 intent ==> chatgpt
try:
resp = self.chatgpt_bot.complete(req["prompt"])
return response.OK("", resp.to_dict())
except Exception as e:
return response.InternalServerError(str(e))
def _stream_completion(self, req):
# # 1. intent check
# intent = self.intent_bot.complete(req["prompt"])
pass
# coding=utf-8
from .datasource import QDDataSource
class BaseQDBot:
def __init__(self, data_source: QDDataSource):
self.data_source = data_source
def create_collection(self):
"""创建数据集
Args:
size (int): 向量大小
distance (QdrantClient.Distance): 损失函数
"""
try:
self.data_source.create_collection(collection=self.collection)
except Exception as e:
raise e
def upload_csv(self, filepath: str):
"""导入csv数据集
Args:
filepath: 文件路径
"""
self.data_source.upload_csv(
self.collection, self.embdder, filepath, self.formatter
)
class BotPayload:
@classmethod
......
......@@ -4,11 +4,10 @@ from collections import Counter
from typing import List, Dict
from .datasource import QDDataSource
from .embedding import Embedding
from .const import INTENT_QA
from .common import BotPayload
from .common import BotPayload, BaseQDBot
class IntentBot:
class IntentBot(BaseQDBot):
def __init__(
self,
data_source: QDDataSource,
......@@ -17,38 +16,16 @@ class IntentBot:
score: float,
limit: int,
):
self.data_source = data_source
super(IntentBot, self).__init__(data_source)
# self.data_source = data_source
self.embdder = embedder
self.collection = collection
self.score = float(score)
self.limit = int(limit)
def create_collection(self):
"""创建数据集
Args:
size (int): 向量大小
distance (QdrantClient.Distance): 损失函数
"""
try:
self.data_source.create_collection(collection=self.collection)
except Exception as e:
raise e
def formatter(self, line: List) -> Dict:
# line format: 问题,答案(意图)
# return {"question": line[0], "answer": INTENT_QA}
return BotPayload(question=line[0], answer=INTENT_QA).to_dict()
def upload_csv(self, filepath: str):
"""导入csv数据集
Args:
filepath: 文件路径
"""
self.data_source.upload_csv(
self.collection, self.embdder, filepath, self.formatter
)
# line format: 问题,答案, type, 意图
return BotPayload(question=line[0], answer=int(line[3])).to_dict()
def complete(self, prompt: str) -> BotPayload:
points = self.data_source.search(
......
......@@ -3,10 +3,10 @@
from typing import List, Dict
from .datasource import QDDataSource
from .embedding import Embedding
from .common import BotPayload
from .common import BotPayload, BaseQDBot
class QABot:
class QABot(BaseQDBot):
def __init__(
self,
data_source: QDDataSource,
......@@ -15,39 +15,15 @@ class QABot:
score: int,
limit: int,
):
self.data_source = data_source
super(QABot, self).__init__(data_source)
self.embdder = embedder
self.collection = collection
self.score = float(score)
self.limit = int(limit)
def create_collection(self):
"""创建数据集
Args:
size (int): 向量大小
distance (QdrantClient.Distance): 损失函数
"""
try:
self.data_source.create_collection(collection=self.collection)
except Exception as e:
raise e
def formatter(self, line: List) -> Dict:
# line format: 问题,答案,跳转类型
# return {"question": line[0], "answer": line[1], "type": line[2]}
# return {"question": line[0], "answer": line[1], "type": "demo"}
return BotPayload(question=line[0], answer=line[1], type="demo").to_dict()
def upload_csv(self, filepath: str):
"""导入csv数据集
Args:
filepath: 文件路径
"""
self.data_source.upload_csv(
self.collection, self.embdder, filepath, self.formatter
)
# line format: 问题,答案, type, 意图
return BotPayload(question=line[0], answer=line[1], type=line[2]).to_dict()
def complete(self, prompt: str, score: int = 0, limit: int = 0) -> BotPayload:
_limit = self.limit if limit == 0 else limit
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论