Commit 7d4baa0f by cinuor

first commit

parent a810275c
......@@ -52,3 +52,6 @@ MANIFEST
.venv*/
.conda*/
.python-version
# ConfigFile
dev.ini
yunheBot-Jie2ooj2
# yunheBot
[base]
secret_key=asdfzxcv
debug=true
host=0.0.0.0
port=9000
[qdrant]
qdrant_host=127.0.0.1
qdrant_port=6333
qdrant_collection=test
[openai]
api_key=qwertyuiopasdfghjkl
[yhbot]
park_score=0.8
This source diff could not be displayed because it is too large. You can view the blob instead.
[tool.poetry]
name = "yunhebot"
version = "0.1.0"
description = "yunhe bot"
authors = ["fanlizhou <fanlizhou@yunqilaohe.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.11,<3.12"
Flask = "^2.2.3"
qdrant-client = "^1.1.4"
openai = "^0.27.4"
flask-restful = "^0.3.9"
jsonschema = "^4.17.3"
[tool.poetry.group.dev.dependencies]
setuptools = "^67.6.1"
setuptools-scm = {extras = ["toml"], version = "^7.1.0"}
tox = "^4.4.12"
#
# [build-system]
# requires = ["poetry-core"]
# build-backend = "poetry.core.masonry.api"
#
[build-system]
# AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD!
requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5"]
requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
......
......@@ -48,7 +48,13 @@ package_dir =
# new major versions. This works if the required packages follow Semantic Versioning.
# For more information, check out https://semver.org/.
install_requires =
importlib-metadata; python_version<"3.8"
importlib-metadata
Flask>=2.2.3
qdrant-client>=1.1.4
openai>=0.27.4
flask-restful>=0.3.9
click==8.1.3
jsonschema==4.17.3
[options.packages.find]
......@@ -69,8 +75,8 @@ testing =
[options.entry_points]
# Add here console scripts like:
# console_scripts =
# script_name = yunhebot.module:function
console_scripts =
yhbot = yunheBot.client:cli
# For example:
# console_scripts =
# fibonacci = yunhebot.skeleton:run
......
......@@ -10,7 +10,12 @@ from setuptools import setup
if __name__ == "__main__":
try:
setup(use_scm_version={"version_scheme": "no-guess-dev"})
setup(
# name="yunhebot",
# version="0.1.1",
# packages=find_packages(include=["yunheBot", "yunheBot.*"])
use_scm_version={"version_scheme": "no-guess-dev"}
)
except: # noqa
print(
"\n\nAn error occurred while building the project, "
......
# coding=utf-8
#!/usr/bin/env python3
from yunheBot.client import cli
if __name__ == "__main__":
cli()
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .v1.industrial_park import IndustrialParkResource
rest_api_bp = Blueprint("api", __name__)
rest_api = Api(rest_api_bp)
rest_api.add_resource(IndustrialParkResource, "/v1/park", endpoint="park")
# coding=utf-8
from flask_restful import Resource
from flask import current_app, request
from ...core.embedding import OpenAIEmbedding
from ...core.datasource import YHQdrantClient
from ...common import response, validator
# JSON schema for the POST body: an object that must carry a string "prompt".
POST_SCHEMA = {
    "type": "object",
    "properties": {"prompt": {"type": "string"}},
    "required": ["prompt"],
}


class IndustrialParkResource(Resource):
    """REST resource that answers a prompt from the configured Qdrant collection.

    All settings (OpenAI key, Qdrant host/port/collection and the minimum
    accepted score) are read from ``current_app.config``.
    """

    def __init__(self):
        self.embedder = OpenAIEmbedding(current_app.config["OPENAI_API_KEY"])
        self.qd_client = YHQdrantClient(
            current_app.config["QDRANT_HOST"],
            current_app.config["QDRANT_PORT"],
            self.embedder,
        )
        self.collection = current_app.config["QDRANT_COLLECTION"]
        # The config value may be stored as text, so normalise it to a float.
        self.park_score = float(current_app.config["PARK_SCORE"])

    @response.wrap_response
    @validator.validator(POST_SCHEMA)
    def post(self):
        """Search the collection for ``prompt`` and return the best payload.

        Returns:
            ``response.OK`` with the payload of the first hit whose score is
            at least ``park_score``; ``response.BadRequest`` when the prompt
            is missing or blank; ``response.NotFound`` when no hit reaches
            the threshold; ``response.InternalServerError`` when the search
            itself raises.
        """
        data = request.get_json(force=True)
        prompt = data.get("prompt")
        # Reject empty prompts here instead of letting the embedding call fail
        # later and surface as a 500.
        if not isinstance(prompt, str) or not prompt.strip():
            return response.BadRequest("Bad Request: 'prompt' must be a non-empty string")

        try:
            raw_results = self.qd_client.search(
                prompt=prompt,
                collection_name=self.collection,
                limit=3,
            )
        except Exception as e:
            return response.InternalServerError(str(e))

        results = [item for item in raw_results if item.score >= self.park_score]
        if not results:
            return response.NotFound("No Answer Found")

        answer = results[0].payload
        return response.OK("", answer)
# coding=utf-8
#!/usr/bin/env python3
import click
import os
from yunheBot.create_app import create_app
from yunheBot.core.embedding import OpenAIEmbedding
from yunheBot.core.datasource import YHQdrantClient
@click.group()
def cli():
    # Root command group; sub-commands are attached with cli.add_command().
    # Deliberately a comment rather than a docstring: click would otherwise
    # publish the docstring as the group's help text.
    return None
@click.command()
@click.option("--conf", type=str)
def server(conf):
    # Build the Flask app from the ini file given by --conf and serve it on the
    # HOST/PORT stored in the app config. (Comment instead of a docstring so the
    # command's help output stays unchanged.)
    app = create_app(config_path=conf)
    app.run(app.config["HOST"], app.config["PORT"])
@click.command()
@click.option("--open_ai_key", type=str, required=True)
@click.option("--db", type=str, required=True)
# Explicit parameter name: without it click passes the value as ``file``,
# which does not match the ``filepath`` argument and raises TypeError.
@click.option("--file", "filepath", type=str, required=True)
@click.option("--collection", type=str, required=True)
def upload(open_ai_key: str, db: str, filepath: str, collection: str):
    """Embed the rows of a CSV file and upload them into a Qdrant collection.

    Args:
        open_ai_key: API key handed to ``OpenAIEmbedding``.
        db: Qdrant address written as ``HOST:PORT``.
        filepath: path of the CSV file to import (value of ``--file``).
        collection: name of the collection to (re)create and fill.
    """
    if os.path.splitext(filepath)[1].lower() != ".csv":
        print(f"{filepath} is not a CSV file.")
        return

    qdrant_host, separator, port_text = db.rpartition(":")
    if not separator or not qdrant_host or not port_text.isdigit():
        raise click.BadParameter("expected HOST:PORT", param_hint="--db")

    embedder = OpenAIEmbedding(open_ai_key)
    # The client expects a numeric port, not the raw text from the CLI.
    client = YHQdrantClient(qdrant_host, int(port_text), embedder)
    client.create_collection(collection)
    client.upload_csv(filepath, collection)
cli.add_command(server)
cli.add_command(upload)
# coding=utf-8
import json
from flask import make_response
def wrap_response(f):
    """Decorate a handler so its result is sent as a Flask JSON response.

    The wrapped callable must return an object exposing ``to_dict()`` and
    ``code``. The body is built from ``to_dict()``, the HTTP status from
    ``code``, and the content type is forced to ``application/json``.

    Args:
        f: the handler to wrap.

    Returns:
        A wrapper with the same call signature that returns the Flask response.
    """
    import functools  # local import; keeps this block self-contained

    # functools.wraps keeps the handler's name/docstring, matching what the
    # validator decorator already does.
    @functools.wraps(f)
    def format_response(*args, **kwargs):
        resp = f(*args, **kwargs)
        flask_resp = make_response(resp.to_dict(), resp.code)
        flask_resp.headers["content-type"] = "application/json"
        return flask_resp

    return format_response
class Response(object):
    """Uniform API reply made of a status code, a message and a payload."""

    def __init__(self, code, message, data):
        self.code, self.message, self.data = code, message, data

    def to_dict(self):
        """Return the reply as a ``code`` / ``message`` / ``data`` mapping."""
        return dict(code=self.code, message=self.message, data=self.data)


class OK(Response):
    """Reply with code 200 carrying a message and a data payload."""

    def __init__(self, message, data):
        super().__init__(200, message, data)


class BadRequest(Response):
    """Reply with code 400 and no payload."""

    def __init__(self, message):
        super().__init__(400, message, None)


class InternalServerError(Response):
    """Reply with code 500 and no payload."""

    def __init__(self, message):
        super().__init__(500, message, None)


class NotFound(Response):
    """Reply with code 404 and no payload."""

    def __init__(self, message):
        super().__init__(404, message, None)


class Forbidden(Response):
    """Reply with code 403 and no payload."""

    def __init__(self, message):
        super().__init__(403, message, None)
class Raw:
    """Reply whose body is the JSON serialisation of a plain dict or list.

    Args:
        code: status code stored on the instance as ``code``.
        raw_data: the dict or list to serialise (subclasses are accepted).

    Raises:
        TypeError: if ``raw_data`` is neither a dict nor a list.
    """

    def __init__(self, code, raw_data):
        # isinstance (rather than an exact type comparison) also accepts
        # dict/list subclasses such as OrderedDict.
        if not isinstance(raw_data, (dict, list)):
            raise TypeError("Raw Response Only Support Dict or List Object")
        self.data = raw_data
        self.code = code

    def to_dict(self):
        """Return the payload as a JSON string (a string, despite the name)."""
        return json.dumps(self.data)
# coding=utf-8
import functools
from flask import request
from jsonschema import validate
from jsonschema.exceptions import SchemaError, ValidationError
from .response import InternalServerError, BadRequest
def validator(schema):
    """Build a decorator that validates the request JSON against ``schema``.

    The decorated handler only runs when validation succeeds. An invalid
    schema yields ``InternalServerError``; a body that does not match the
    schema yields ``BadRequest``.
    """

    def check_json(f):
        @functools.wraps(f)
        def wrap_handler(*args, **kwargs):
            payload = request.get_json(force=True)  # type: ignore
            try:
                validate(instance=payload, schema=schema)
            except SchemaError as exc:
                failure = InternalServerError(f"Schema Error: {exc}")
            except ValidationError as exc:
                failure = BadRequest(f"Bad Request: {exc}")
            else:
                failure = None

            if failure is not None:
                return failure
            return f(*args, **kwargs)

        return wrap_handler

    return check_json
import configparser
from datetime import timedelta
# Default settings; setup_config() overlays the values read from the ini file.
base_config = {
    "SECRET_KEY": "asdfzxcv",  # session encryption key
    "PERMANENT_SESSION_LIFETIME": timedelta(days=30),  # session expiry time
    "DEBUG": True,
    # SERVER_NAME : 'example.com'
    "HOST": "0.0.0.0",
    "PORT": 9000,
}


def setup_config(config_path):
    """Read the ini file at ``config_path`` and merge it into ``base_config``.

    Args:
        config_path: path of an ini file with ``base``, ``qdrant``, ``openai``
            and ``yhbot`` sections.

    Raises:
        FileNotFoundError: if the file is missing or cannot be read.
        configparser.Error: if a required section or option is absent.
        ValueError: if a numeric or boolean option cannot be converted.
    """
    conf = configparser.ConfigParser()
    # ConfigParser.read() silently skips unreadable files and returns the list
    # of files it parsed, so an empty result means nothing was loaded.
    if not conf.read(config_path):
        raise FileNotFoundError(f"Configuration Not Found: {config_path}")

    new_config = {
        # base
        "SECRET_KEY": conf.get("base", "secret_key"),  # session encryption key
        # "PERMANENT_SESSION_LIFETIME": timedelta(days=30),  # session expiry time
        "DEBUG": conf.getboolean("base", "debug"),
        "HOST": conf.get("base", "host"),
        "PORT": conf.getint("base", "port"),
        # qdrant
        "QDRANT_HOST": conf.get("qdrant", "qdrant_host"),
        # Numeric settings are converted here so consumers get real numbers.
        "QDRANT_PORT": conf.getint("qdrant", "qdrant_port"),
        "QDRANT_COLLECTION": conf.get("qdrant", "qdrant_collection"),
        # openai_api_key
        "OPENAI_API_KEY": conf.get("openai", "api_key"),
        # yhbot
        "PARK_SCORE": conf.getfloat("yhbot", "park_score"),
    }
    base_config.update(new_config)
# class BaseConfig:
# """配置基类"""
#
# # SECRET_KEY = os.urandom(24)
# SECRET_KEY = "asdfzxcv" # session加密
# PERMANENT_SESSION_LIFETIME = timedelta(days=30) # 设置session过期时间
# DEBUG = True
# # SERVER_NAME = 'example.com'
# HOST = "0.0.0.0"
# PORT = 9000
#
# @staticmethod
# def init_app(app):
# pass
# class NewConfig(BaseConfig):
# """区分配置文件"""
#
# def __init__(self, config_path):
# self._conf = get_config(config_path)
#
#
# # base
# SECRET_KEY = self._conf.get("base", "secret_key") # session加密
# PERMANENT_SESSION_LIFETIME = timedelta(days=30) # 设置session过期时间
# DEBUG = self._conf.getboolean("base", "debug")
# HOST = self._conf.get("base", "host")
# PORT = self._conf.getint("base", "port")
#
# # qdrant
# QDRANT_HOST = self._conf.get("qdrant", "qdrant_host")
# QDRANT_PORT = self._conf.get("qdrant", "qdrant_port")
#
# # openai_api_key
# OPENAI_KEY = self._conf.get("openai", "api_key")
#
# # # mysql
# # MYSQL_USERNAME = conf.get("mysql", "USERNAME")
# # MYSQL_PASSWORD = conf.get("mysql", "PASSWORD")
# # MYSQL_HOSTNAME = conf.get("mysql", "HOSTNAME")
# # MYSQL_PORT = conf.getint("mysql", "PORT")
# # MYSQL_DATABASE = conf.get("mysql", "DATABASE")
# # DB_URI = "mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8".format(
# # MYSQL_USERNAME, MYSQL_PASSWORD, MYSQL_HOSTNAME, MYSQL_PORT, MYSQL_DATABASE
# # )
# # SQLALCHEMY_DATABASE_URI = DB_URI
# # SQLALCHEMY_TRACK_MODIFICATIONS = True
#
# # # redis
# # redis_obj = {
# # "host": conf.get("redis", "REDIS_HOST"),
# # "port": conf.get("redis", "REDIS_PORT"),
# # "password": conf.get("redis", "REDIS_PWD"),
# # "decode_responses": conf.getboolean("redis", "DECODE_RESPONSES"),
# # "db": conf.getint("redis", "REDIS_DB"),
# # }
# # POOL = redis.ConnectionPool(**redis_obj)
# # R = redis.Redis(connection_pool=POOL)
# coding=utf-8
import csv
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct
from .embedding import Embedding
from ..util import util
@util.singleton
class YHQdrantClient:
    """Wrapper around ``QdrantClient`` that embeds text before storing or searching.

    Args:
        host: Qdrant server host.
        port: Qdrant server port.
        embedder: object providing ``generate_sample`` and ``get_vector``.
    """

    def __init__(self, host: str, port: int, embedder: Embedding):
        self.client = QdrantClient(host=host, port=port)
        self.embedder = embedder

    def create_collection(
        self, collection: str, size: int = 1536, distance: Distance = Distance.COSINE
    ):
        """Create the collection, replacing any existing one of the same name.

        Args:
            collection (str): collection name
            size (int): vector size
            distance (Distance): distance metric used to compare vectors
        """
        # Errors from the client propagate unchanged to the caller.
        self.client.recreate_collection(
            collection_name=collection,
            vectors_config=VectorParams(size=size, distance=distance),
        )

    def upload_csv(self, filepath: str, collection: str):
        """Import a CSV data set into ``collection``.

        Each row is expected to hold a title in column 0 and a text in
        column 1; rows with fewer than two columns (e.g. blank lines) are
        skipped. Points are numbered from 0 in the order they are stored.

        Args:
            filepath: path of the CSV file
            collection: name of the target collection
        """
        # newline="" is what the csv module requires for correct handling of
        # newlines embedded in quoted fields.
        with open(filepath, "r", encoding="utf-8", newline="") as f:
            count = 0
            for line in csv.reader(f):
                if len(line) < 2:
                    continue
                print(line)
                # generate_sample is the method the embedder actually offers;
                # it returns [title, text, vector].
                item = self.embedder.generate_sample(line)
                self.client.upsert(
                    collection_name=collection,
                    wait=True,
                    points=[
                        PointStruct(
                            id=count,
                            vector=item[2],
                            payload={"title": item[0], "text": item[1]},
                        )
                    ],
                )
                count += 1

    def search(self, prompt, collection_name, limit):
        """Embed ``prompt`` and return up to ``limit`` hits from the collection.

        Args:
            prompt: text to search for
            collection_name: collection to query
            limit: maximum number of hits to return

        Returns:
            Whatever ``QdrantClient.search`` returns for the query vector.
        """
        vec = self.embedder.get_vector(prompt)
        results = self.client.search(
            collection_name=collection_name,
            query_vector=vec,
            limit=limit,
            search_params={"exact": False, "hnsw_ef": 128},
        )
        return results
# coding=utf-8
import openai
from abc import abstractmethod, ABCMeta
from typing import List, Tuple
from ..util import util
class Embedding(metaclass=ABCMeta):
    """Interface for objects that turn text into embedding vectors."""

    @abstractmethod
    def generate_sample(self, items: List[str]) -> Tuple[str, str, List[float]]:
        """Build a storable sample from one data row.

        Args:
            items: row fields; the first two are used as title and text.

        Returns:
            A three-element sequence ``(title, text, vector)``.
        """

    @abstractmethod
    def get_vector(self, prompt: str) -> List[float]:
        """Return the embedding vector for ``prompt``."""
@util.singleton
class OpenAIEmbedding(Embedding):
    """``Embedding`` implementation backed by the OpenAI embeddings API.

    Args:
        open_ai_key: API key; it is stored on the instance and also assigned
            to the module-level ``openai.api_key``.
    """

    def __init__(self, open_ai_key):
        self.open_ai_key = open_ai_key
        openai.api_key = open_ai_key

    def _create_embedding(self, text):
        """Request an embedding for ``text`` and return the first result's vector."""
        sentence_embeddings = openai.Embedding.create(
            model="text-embedding-ada-002", input=text
        )
        return sentence_embeddings["data"][0]["embedding"]

    def generate_sample(self, items):
        """Return ``[items[0], items[1], vector]`` where the vector embeds ``items[1]``."""
        vec = self._create_embedding(items[1])
        return [items[0], items[1], vec]

    def get_vector(self, prompt):
        """Return the embedding vector for ``prompt``."""
        vec = self._create_embedding(prompt)
        return vec
# coding=utf-8
#!/usr/bin/env python3
from flask import Flask
from .register.bp_register import register_blueprint
from .register.conf_register import register_config
def create_app(*args, **kwargs):
    """Create the Flask application and run the registration steps on it.

    All positional and keyword arguments are forwarded unchanged to
    ``register_blueprint`` and then to ``register_config``.
    """
    app = Flask(__name__)
    for register in (register_blueprint, register_config):
        register(app, *args, **kwargs)
    return app
# coding=utf-8
from flask import Flask
from ..api import rest_api_bp
def register_blueprint(app: Flask, *args, **kwargs):
    """Mount the REST API blueprint on ``app`` under the ``/api`` prefix."""
    prefix = "/api"
    app.register_blueprint(rest_api_bp, url_prefix=prefix)
from flask import Flask
from ..config.config import base_config, setup_config
def register_config(app: Flask, *args, **kwargs):
    """Load the configuration file named by ``config_path`` into ``app.config``.

    Raises:
        FileNotFoundError: if no ``config_path`` keyword argument was given.
    """
    if kwargs.get("config_path") is None:
        raise FileNotFoundError("Configuration Not Found")

    setup_config(kwargs["config_path"])
    app.config.from_mapping(base_config)
# print(app.config)
"""
This is a skeleton file that can serve as a starting point for a Python
console script. To run this script uncomment the following lines in the
``[options.entry_points]`` section in ``setup.cfg``::
console_scripts =
fibonacci = yunhebot.skeleton:run
Then run ``pip install .`` (or ``pip install -e .`` for editable mode)
which will install the command ``fibonacci`` inside your current environment.
Besides console scripts, the header (i.e. until ``_logger``...) of this file can
also be used as template for Python modules.
Note:
This file can be renamed depending on your needs or safely removed if not needed.
References:
- https://setuptools.pypa.io/en/latest/userguide/entry_point.html
- https://pip.pypa.io/en/stable/reference/pip_install
"""
import argparse
import logging
import sys
from yunhebot import __version__
__author__ = "cinuor"
__copyright__ = "cinuor"
__license__ = "MIT"
_logger = logging.getLogger(__name__)
# ---- Python API ----
# The functions defined in this section can be imported by users in their
# Python scripts/interactive interpreter, e.g. via
# `from yunhebot.skeleton import fib`,
# when using this Python module as a library.
def fib(n):
    """Compute a Fibonacci number.

    Args:
      n (int): 1-based position in the sequence; must be positive

    Returns:
      int: the n-th Fibonacci number (``fib(1) == fib(2) == 1``)
    """
    assert n > 0
    current, following = 1, 1
    # Advance the pair n - 1 times.
    for _step in range(1, n):
        current, following = following, current + following
    return current
# ---- CLI ----
# The functions defined in this section are wrappers around the main Python
# API allowing them to be called directly from the terminal as a CLI
# executable/script.
def parse_args(args):
    """Turn the command line parameters into a namespace.

    Args:
      args (List[str]): command line parameters as list of strings
          (for example  ``["--help"]``).

    Returns:
      :obj:`argparse.Namespace`: command line parameters namespace
    """
    parser = argparse.ArgumentParser(description="Just a Fibonacci demonstration")
    parser.add_argument(
        "--version", action="version", version=f"yunheBot {__version__}"
    )
    parser.add_argument(dest="n", help="n-th Fibonacci number", type=int, metavar="INT")

    # Both verbosity switches store a constant into the same ``loglevel`` slot.
    verbosity_flags = (
        ("-v", "--verbose", "set loglevel to INFO", logging.INFO),
        ("-vv", "--very-verbose", "set loglevel to DEBUG", logging.DEBUG),
    )
    for short_flag, long_flag, help_text, level in verbosity_flags:
        parser.add_argument(
            short_flag,
            long_flag,
            dest="loglevel",
            help=help_text,
            action="store_const",
            const=level,
        )
    return parser.parse_args(args)
def setup_logging(loglevel):
    """Configure basic logging to stdout.

    Args:
      loglevel (int): minimum loglevel for emitting messages
    """
    settings = {
        "level": loglevel,
        "stream": sys.stdout,
        "format": "[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**settings)
def main(args):
    """Run :func:`fib` from string arguments, CLI style.

    The result of :func:`fib` is not returned; it is printed to ``stdout``
    inside a formatted message.

    Args:
      args (List[str]): command line parameters as list of strings
          (for example  ``["--verbose", "42"]``).
    """
    options = parse_args(args)
    setup_logging(options.loglevel)
    _logger.debug("Starting crazy calculations...")
    position = options.n
    print(f"The {position}-th Fibonacci number is {fib(position)}")
    _logger.info("Script ends here")
def run():
    """Call :func:`main` with the CLI arguments taken from :obj:`sys.argv`.

    Suitable as the entry point of a setuptools console script.
    """
    cli_arguments = sys.argv[1:]  # drop the program name
    main(cli_arguments)
if __name__ == "__main__":
# ^ This is a guard statement that will prevent the following code from
# being executed in the case someone imports this file instead of
# executing it as a script.
# https://docs.python.org/3/library/__main__.html
# After installing your project with pip, users can also run your Python
# modules as scripts via the ``-m`` flag, as defined in PEP 338::
#
# python -m yunhebot.skeleton 42
#
run()
def singleton(cls):
    """Class decorator that creates at most one instance of ``cls``.

    The first call constructs the instance with the arguments given; every
    later call returns that same instance and ignores its arguments. The
    decorated name is bound to the factory function, not to the class.
    """
    created = {}

    def getinstance(*args, **kwargs):
        try:
            return created[cls]
        except KeyError:
            instance = created[cls] = cls(*args, **kwargs)
            return instance

    return getinstance
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论