cronpy/core/mysqldb.py

186 lines
6.3 KiB
Python
Raw Normal View History

2020-10-03 21:17:53 +00:00
from datetime import datetime, date, timedelta
import re
from MySQLdb.constants import FIELD_TYPE
import _mysql
import setting
def _datetime_or_none(value):
    """Parse a MySQL DATE / DATETIME / TIMESTAMP text value.

    ``'YYYY-MM-DD'`` becomes a :class:`date`.  ``'YYYY-MM-DD HH:MM:SS'``,
    optionally followed by ``.`` and fractional-second digits, becomes a
    :class:`datetime` (digits beyond microsecond precision are dropped).

    Returns ``None`` instead of raising when *value* has any other shape or
    names an impossible calendar value (such as MySQL's zero date
    ``'0000-00-00'``).
    """
    space_split = value.split(' ')
    if len(space_split[0]) != 10:
        return None
    dash_split = space_split[0].split('-')
    if len(dash_split) != 3 or not all(x.isdecimal() for x in dash_split):
        return None
    try:
        year, month, day = (int(x) for x in dash_split)
        if len(space_split) == 1:
            return date(year, month, day)
        colon_split = space_split[1].split(':')
        # Hour AND minute must both be digits before int() is applied to them.
        if len(colon_split) != 3 or not all(x.isdecimal() for x in colon_split[:2]):
            return None
        hour = int(colon_split[0])
        minute = int(colon_split[1])
        if colon_split[2].isdecimal():
            second = int(colon_split[2])
            microsecond = 0
        else:
            point_split = colon_split[2].split('.')
            if len(point_split) != 2 or not all(x.isdecimal() for x in point_split):
                return None
            second = int(point_split[0])
            # Right-pad to six digits ('5' -> 500000); ignore digits past six.
            microsecond = int(point_split[1][:6].ljust(6, '0'))
        return datetime(year, month, day, hour, minute, second, microsecond)
    except ValueError:
        # Out-of-range components, e.g. month/day 0 in MySQL zero dates.
        return None
def _timedelta_or_none(value):
    """Parse a MySQL TIME text value into a :class:`timedelta`.

    Accepts ``'H:MM:SS'`` with any number of hour digits (TIME values can
    exceed 24 hours).  Also accepts an optional leading ``-`` for a negative
    duration and an optional ``.`` plus fractional-second digits (kept to
    microsecond precision).

    Returns ``None`` when *value* has any other shape.
    """
    negative = value.startswith('-')
    if negative:
        value = value[1:]
    colon_split = value.split(':')
    if len(colon_split) != 3:
        return None
    hours, minutes, seconds = colon_split
    whole_seconds, point, fraction = seconds.partition('.')
    # A '.' with no digits after it is malformed, so the fraction is validated too.
    numeric_parts = [hours, minutes, whole_seconds] + ([fraction] if point else [])
    if not all(part.isdecimal() for part in numeric_parts):
        return None
    microseconds = int(fraction[:6].ljust(6, '0')) if point else 0
    delta = timedelta(hours=int(hours), minutes=int(minutes),
                      seconds=int(whole_seconds), microseconds=microseconds)
    return -delta if negative else delta
def _year_to_datetime(value):
    """Convert a MySQL YEAR value to a :class:`datetime` on January 1st.

    Accepts ``str`` or ``bytes`` input, because the driver may hand non-date
    column types to converters as raw bytes.

    Returns ``None`` for non-numeric input and for years :class:`datetime`
    cannot represent (e.g. MySQL's ``0000``).
    """
    if isinstance(value, (bytes, bytearray)):
        value = value.decode('ascii', errors='replace')
    if not value.isdecimal():
        return None
    try:
        return datetime(year=int(value), month=1, day=1)
    except ValueError:
        # datetime only supports years 1..9999.
        return None
def _bytes_to_str(value):
    """Decode raw column bytes as UTF-8, substituting U+FFFD for invalid sequences."""
    decoded = value.decode('utf8', 'replace')
    return decoded
def _none(_):
    """Converter that discards its column value; the call always evaluates to ``None``."""
class MysqlDB:
    """Thin wrapper around a low-level ``_mysql`` connection.

    Statements use ``:name`` placeholders that are filled in client-side by
    :meth:`_build_stmt`.  :meth:`query` returns rows as
    ``{column_name: value}`` dicts, with values converted through
    :attr:`CONVERTER`.  The connection is opened lazily on first use.
    """

    # MySQL column type -> function turning the raw column value into a Python
    # object; passed to _mysql.connect() as its ``conv`` mapping.
    CONVERTER = {
        # BIT values are sent as a raw big-endian binary string (e.g. b'\x01'),
        # which int() cannot parse, so decode the bytes numerically instead.
        FIELD_TYPE.BIT: lambda value: int.from_bytes(value, 'big'),
        FIELD_TYPE.BLOB: _bytes_to_str,
        FIELD_TYPE.CHAR: _bytes_to_str,
        FIELD_TYPE.DATE: _datetime_or_none,
        FIELD_TYPE.DATETIME: _datetime_or_none,
        FIELD_TYPE.DECIMAL: float,
        FIELD_TYPE.DOUBLE: float,
        FIELD_TYPE.ENUM: _bytes_to_str,
        FIELD_TYPE.FLOAT: float,
        FIELD_TYPE.GEOMETRY: _none,
        FIELD_TYPE.INT24: int,
        FIELD_TYPE.INTERVAL: _none,
        FIELD_TYPE.LONG: int,
        FIELD_TYPE.LONG_BLOB: _bytes_to_str,
        FIELD_TYPE.LONGLONG: int,
        FIELD_TYPE.MEDIUM_BLOB: _bytes_to_str,
        FIELD_TYPE.NEWDATE: _datetime_or_none,
        FIELD_TYPE.NEWDECIMAL: float,
        FIELD_TYPE.NULL: _none,
        FIELD_TYPE.SET: _bytes_to_str,
        FIELD_TYPE.SHORT: int,
        FIELD_TYPE.STRING: _bytes_to_str,
        FIELD_TYPE.TIME: _timedelta_or_none,
        FIELD_TYPE.TIMESTAMP: _datetime_or_none,
        FIELD_TYPE.TINY: int,
        FIELD_TYPE.TINY_BLOB: _bytes_to_str,
        FIELD_TYPE.VAR_STRING: _bytes_to_str,
        FIELD_TYPE.VARCHAR: _bytes_to_str,
        FIELD_TYPE.YEAR: _year_to_datetime
    }
    CHARSET = 'utf8'

    def __init__(self, host=setting.MYSQL_HOST, port=setting.MYSQL_PORT, user=setting.MYSQL_USER,
                 pass_=setting.MYSQL_PASS, base=setting.MYSQL_BASE, charset=CHARSET, autocommit=False):
        """Store connection parameters; no connection is opened until needed.

        Defaults come from the project ``setting`` module and are read once,
        when this class is defined.
        """
        self._session = None
        self._host = host
        self._port = port
        self._user = user
        self._pass = pass_
        self._base = base
        self._charset = charset
        self._autocommit = autocommit

    def connect(self):
        """Open the connection if none is currently held."""
        if self._session is None:
            self._session = _mysql.connect(
                host=self._host, port=self._port, user=self._user, passwd=self._pass, db=self._base, conv=self.CONVERTER
            )
            self._session.set_character_set(self._charset)
            self._session.autocommit(self._autocommit)

    @staticmethod
    def _build_stmt(stmt, args):
        """Return *stmt* with each ``:name`` placeholder replaced by a SQL literal.

        :param stmt: SQL text using ``:name`` placeholders.  When *args* is
            given, the text goes through ``%`` formatting, so a literal ``%``
            must be written as ``%%``.
        :param args: mapping of placeholder name to a str, datetime, date,
            float, int or None value; or None to return *stmt* unchanged.
            The mapping itself is not modified.
        :raises NameError: if a value has any other type.
        :raises KeyError: if *stmt* uses a placeholder missing from *args*.
        """
        if args is None:
            return stmt
        # Render literals into a fresh dict: writing them back into the
        # caller's mapping would double-quote strings if it were reused.
        literals = {}
        for key, value in args.items():
            if isinstance(value, str):
                literals[key] = "'{}'".format(value.replace("'", "''").replace('\\', '\\\\'))
            elif isinstance(value, datetime):  # checked before date: datetime subclasses date
                literals[key] = "'{}'".format(value.strftime('%Y-%m-%d %H:%M:%S'))
            elif isinstance(value, date):
                literals[key] = "'{}'".format(value.strftime('%Y-%m-%d'))
            elif isinstance(value, (float, int)):
                literals[key] = str(value)
            elif value is None:
                literals[key] = 'NULL'
            else:
                raise NameError('Argument type not allowed here: {} - {}'.format(type(value), value))
        return re.sub(r':([a-zA-Z0-9_]+)', r'%(\1)s', stmt) % literals

    def query(self, stmt, args=None):
        """Run *stmt* and return every result row as a ``{column: value}`` dict.

        A statement that produces no result set yields an empty list.
        """
        self.connect()
        self._session.query(self._build_stmt(stmt, args))
        result = self._session.use_result()
        if result is None:
            return []
        names = [field[0] for field in result.describe()]
        rows = []
        while True:
            fetched = result.fetch_row()  # one row per call; empty when exhausted
            if not fetched:
                break
            rows.append(dict(zip(names, fetched[0])))
        return rows

    def exec(self, stmt, args=None):
        """Run *stmt* without reading results; return the connection's ``insert_id()``."""
        self.connect()
        self._session.query(self._build_stmt(stmt, args))
        return self._session.insert_id()

    def rollback(self):
        """Roll back the current transaction, if a connection is open."""
        if self._session is not None:
            self._session.rollback()

    def commit(self):
        """Commit the current transaction, only when ``setting.MYSQL_SAVE`` is truthy."""
        if self._session is not None and setting.MYSQL_SAVE:
            self._session.commit()

    def close(self):
        """Close the connection, if one is open."""
        if self._session is not None:
            self._session.close()
            # Drop the dead handle so connect() reopens it and
            # commit()/rollback() become no-ops rather than using a closed link.
            self._session = None

    def get_random_ua(self):
        """Return one random ``useragents`` value, or None if the table is empty."""
        for row in self.query(""" SELECT useragents FROM user_agents ORDER BY RAND() LIMIT 1 """):
            return row['useragents']

    def get_tables(self, table_name=None):
        """Yield table names of the current database, or only *table_name* if given."""
        for row in self.query(""" SHOW TABLES """):
            # SHOW TABLES returns a single column named 'Tables_in_<database>';
            # read it positionally rather than hard-coding one database name.
            name = next(iter(row.values()))
            if table_name is None or name == table_name:
                yield name