123 lines
3.2 KiB
Python
123 lines
3.2 KiB
Python
import os
|
||
import threading
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||
from sqlalchemy.pool import QueuePool
|
||
|
||
from app.db.models import Base
|
||
from app.utils import ExceptionUtils, PathUtils
|
||
from config import Config
|
||
|
||
# Module-wide lock; within this file it serialises schema creation in MainDb.init_db.
lock = threading.Lock()

# Engine for the SQLite file ``user.db`` inside the configured config directory.
# ``check_same_thread=False`` lets a pooled connection be used from a thread other
# than the one that opened it. The pool holds at most 50 connections (no overflow).
# Connections are recycled after 600 seconds and pinged before being handed out.
_Engine = create_engine(
    "sqlite:///{}?check_same_thread=False".format(
        os.path.join(Config().get_config_path(), 'user.db')
    ),
    poolclass=QueuePool,
    pool_size=50,
    max_overflow=0,
    pool_recycle=600,
    pool_pre_ping=True,
    echo=False,
)

# Scoped session registry bound to _Engine (thread-local by default in SQLAlchemy).
# Sessions autoflush before queries, keep attribute values after commit, and use
# explicit (non-autocommit) transactions.
_Session = scoped_session(
    session_factory=sessionmaker(
        bind=_Engine,
        expire_on_commit=False,
        autocommit=False,
        autoflush=True,
    )
)
|
||
|
||
|
||
class MainDb:
    """Thin facade over the scoped SQLAlchemy session bound to ``user.db``."""

    @property
    def session(self):
        """Return the session for the current scope from the ``_Session`` registry."""
        return _Session()

    @staticmethod
    def init_db():
        """Create all tables declared on ``Base.metadata`` using ``_Engine``.

        Runs under the module-level ``lock`` so concurrent callers in this
        process do not run schema creation at the same time.
        """
        with lock:
            Base.metadata.create_all(_Engine)

    def init_data(self):
        """Execute each ``.sql`` file from the script directory exactly once.

        Files are listed with ``PathUtils.get_dir_level1_files`` under
        ``Config().get_script_path()``. Basenames of processed files are kept
        in the ``app.init_files`` config entry. Files already listed there are
        skipped; the config is saved only if at least one new file was run.

        Each file is split on ``";\\n"``. Every non-blank statement is executed
        and committed on its own. A failing statement is printed, rolled back
        and skipped. The file is still recorded as processed afterwards, so it
        is never retried.
        """
        conf = Config()
        config = conf.get_config()
        init_files = conf.get_config("app").get("init_files") or []
        script_dir = conf.get_script_path()
        sql_files = PathUtils.get_dir_level1_files(in_path=script_dir, exts=".sql") or []
        changed = False
        for sql_file in sql_files:
            file_name = os.path.basename(sql_file)
            if file_name in init_files:
                continue
            changed = True
            with open(sql_file, "r", encoding="utf-8") as f:
                statements = f.read().split(';\n')
            for sql in statements:
                # The split leaves blank fragments (e.g. after the final ";\n");
                # don't send empty statements to the database.
                if not sql.strip():
                    continue
                try:
                    self.excute(sql)
                    self.commit()
                except Exception as err:
                    # Best effort: report the error, then roll back so the session
                    # is usable again (a failed flush/commit otherwise leaves it in
                    # a pending-rollback state and every later statement fails).
                    print(str(err))
                    self.rollback()
            init_files.append(file_name)
        if changed:
            config['app']['init_files'] = init_files
            conf.save_config(config)

    def insert(self, data):
        """Add ORM object(s) to the current session without committing.

        :param data: a single mapped instance, or a list/tuple of instances.
        """
        if isinstance(data, (list, tuple)):
            self.session.add_all(data)
        else:
            self.session.add(data)

    def query(self, *obj):
        """Return ``session.query(*obj)`` for the given entities/columns."""
        return self.session.query(*obj)

    def excute(self, sql):
        """Execute a SQL statement on the current session (result is discarded).

        :param sql: a plain SQL string or any SQLAlchemy executable. Plain
            strings are wrapped in ``text()``: SQLAlchemy 2.0 rejects raw
            strings in ``Session.execute``, and 1.4 only coerced them to
            ``text()`` with a deprecation warning, so behavior is unchanged there.
        """
        from sqlalchemy import text

        if isinstance(sql, str):
            sql = text(sql)
        self.session.execute(sql)

    def flush(self):
        """Flush pending changes of the current session to the database."""
        self.session.flush()

    def commit(self):
        """Commit the current session's transaction."""
        self.session.commit()

    def rollback(self):
        """Roll back the current session's transaction."""
        self.session.rollback()
|
||
|
||
|
||
class DbPersist(object):
    """Decorator that commits ``db`` after the wrapped call succeeds.

    Usage: ``@DbPersist(db)``, where ``db`` provides ``commit()`` and
    ``rollback()``.

    The decorated function returns:

    * ``True`` if the wrapped function returned ``None``;
    * the wrapped function's own return value otherwise;
    * ``False`` if the wrapped call or the commit raised. In that case the
      exception is reported via ``ExceptionUtils.exception_traceback`` and
      ``db`` is rolled back.
    """

    def __init__(self, db):
        self.db = db

    def __call__(self, f):
        import functools

        # wraps() keeps __name__, __doc__, __module__, etc. of the decorated function.
        @functools.wraps(f)
        def persist(*args, **kwargs):
            try:
                ret = f(*args, **kwargs)
                self.db.commit()
                return True if ret is None else ret
            except Exception as e:
                ExceptionUtils.exception_traceback(e)
                self.db.rollback()
                return False

        return persist
|