186 lines
6.3 KiB
Python
186 lines
6.3 KiB
Python
|
from datetime import datetime, date, timedelta
|
||
|
import re
|
||
|
|
||
|
from MySQLdb.constants import FIELD_TYPE
|
||
|
import _mysql
|
||
|
|
||
|
import setting
|
||
|
|
||
|
|
||
|
def _datetime_or_none(value):
    """Parse a MySQL DATE / DATETIME / TIMESTAMP text value.

    ``'YYYY-MM-DD'`` yields a :class:`date`; ``'YYYY-MM-DD HH:MM:SS'`` or
    ``'YYYY-MM-DD HH:MM:SS.ffffff'`` yields a :class:`datetime`.

    Returns None for anything that cannot be parsed, including values that
    are well-formed but not real calendar values (e.g. MySQL "zero" dates
    such as ``'0000-00-00 00:00:00'``), instead of raising ValueError.
    """
    parts = value.split(' ')
    date_text = parts[0]
    if len(date_text) != 10:
        return None
    ymd = date_text.split('-')
    if len(ymd) != 3 or not all(x.isnumeric() for x in ymd):
        return None

    try:
        year, month, day = (int(x) for x in ymd)
        if len(parts) == 1:
            return date(year, month, day)

        hms = parts[1].split(':')
        # Both hour and minute must be numeric before int() is applied.
        if len(hms) != 3 or not all(x.isnumeric() for x in hms[:2]):
            return None
        hour, minute = int(hms[0]), int(hms[1])

        seconds_text = hms[2]
        if seconds_text.isnumeric():
            second, microsecond = int(seconds_text), 0
        else:
            whole, dot, fraction = seconds_text.partition('.')
            if not (dot and whole.isnumeric() and fraction.isnumeric()):
                return None
            second = int(whole)
            # Scale the fraction to microseconds ('5' -> 500000). Digits past
            # the sixth are truncated so the result is always an int.
            microsecond = int(fraction[:6].ljust(6, '0'))
        return datetime(year, month, day, hour, minute, second, microsecond)
    except ValueError:
        # Out-of-range components (zero dates, month 13, ...) or numeric
        # characters int() does not accept.
        return None
|
||
|
|
||
|
|
||
|
def _timedelta_or_none(value):
    """Parse a MySQL TIME text value into a :class:`timedelta`.

    Accepts ``'[-]H:MM:SS[.ffffff]'``. The hour field may exceed 24 (MySQL
    TIME is a duration, not a time of day). A leading ``-`` yields a negative
    duration, and fractional seconds are kept to microsecond precision.
    Returns None for any value that does not have that shape.
    """
    negative = value.startswith('-')
    fields = (value[1:] if negative else value).split(':')
    if len(fields) != 3:
        return None
    hours_text, minutes_text, seconds_text = fields
    whole, dot, fraction = seconds_text.partition('.')
    numeric_parts = [hours_text, minutes_text, whole] + ([fraction] if dot else [])
    if not all(part.isnumeric() for part in numeric_parts):
        return None

    try:
        delta = timedelta(
            hours=int(hours_text),
            minutes=int(minutes_text),
            seconds=int(whole),
            # '5' -> 500000 microseconds; extra digits beyond six are dropped.
            microseconds=int(fraction[:6].ljust(6, '0')) if dot else 0,
        )
    except (ValueError, OverflowError):
        return None
    # The sign applies to the whole duration: '-01:30:00' is -(1h30m).
    return -delta if negative else delta
|
||
|
|
||
|
|
||
|
def _year_to_datetime(value):
    """Convert a MySQL YEAR text value to January 1st of that year.

    Returns None when the value is not numeric or is not a valid
    :class:`datetime` year (e.g. the MySQL zero year ``'0000'``), instead of
    raising ValueError.
    """
    if not value.isnumeric():
        return None
    try:
        return datetime(year=int(value), month=1, day=1)
    except ValueError:
        return None
|
||
|
|
||
|
|
||
|
def _bytes_to_str(value):
    """Decode a raw column value as UTF-8.

    Invalid byte sequences are replaced with U+FFFD rather than raising.
    """
    text = value.decode('utf8', 'replace')
    return text
|
||
|
|
||
|
|
||
|
def _none(_):
    """Converter that ignores the raw column value and always yields None."""
    result = None
    return result
|
||
|
|
||
|
|
||
|
def _bit_to_int(value):
    """Convert a MySQL BIT(n) value, delivered as raw big-endian bytes, to int.

    The plain ``int`` builtin parses bytes as ASCII digits, so
    ``int(b'\\x01')`` raises ValueError. It cannot serve as the BIT converter.
    """
    return int.from_bytes(value, 'big')


def _to_sql_literal(value):
    """Render one Python value as a MySQL literal for :meth:`MysqlDB._build_stmt`.

    Strings are single-quoted with ``'`` and ``\\`` escaped. Dates and
    datetimes are quoted in MySQL's text format; ints and floats use ``str()``;
    None becomes ``NULL``. Any other type raises NameError (kept for
    compatibility with existing callers).
    """
    if isinstance(value, str):
        return "'{}'".format(value.replace("'", "''").replace('\\', '\\\\'))
    # datetime must be tested before date: datetime is a subclass of date.
    if isinstance(value, datetime):
        return "'{}'".format(value.strftime('%Y-%m-%d %H:%M:%S'))
    if isinstance(value, date):
        return "'{}'".format(value.strftime('%Y-%m-%d'))
    if isinstance(value, (float, int)):
        return str(value)
    if value is None:
        return 'NULL'
    raise NameError('Argument type not allowed here: {} - {}'.format(type(value), value))


class MysqlDB:
    """Small wrapper around a low-level ``_mysql`` connection.

    The connection is opened lazily by :meth:`connect`, which :meth:`query`
    and :meth:`exec` call first. Column values are converted with
    :attr:`CONVERTER`. Statements use ``:name`` placeholders that are
    filled in by :meth:`_build_stmt`.
    """

    # Maps MySQL column type codes to the callable applied to each raw value.
    # FIELD_TYPE.CHAR and FIELD_TYPE.INTERVAL are deliberately absent. In
    # MySQLdb.constants they are aliases of TINY and ENUM, so listing them
    # created duplicate dict keys that silently replaced those entries
    # (ENUM values were being mapped to None).
    CONVERTER = {
        FIELD_TYPE.BIT: _bit_to_int,
        FIELD_TYPE.BLOB: _bytes_to_str,
        FIELD_TYPE.DATE: _datetime_or_none,
        FIELD_TYPE.DATETIME: _datetime_or_none,
        FIELD_TYPE.DECIMAL: float,
        FIELD_TYPE.DOUBLE: float,
        FIELD_TYPE.ENUM: _bytes_to_str,
        FIELD_TYPE.FLOAT: float,
        FIELD_TYPE.GEOMETRY: _none,
        FIELD_TYPE.INT24: int,
        FIELD_TYPE.LONG: int,
        FIELD_TYPE.LONG_BLOB: _bytes_to_str,
        FIELD_TYPE.LONGLONG: int,
        FIELD_TYPE.MEDIUM_BLOB: _bytes_to_str,
        FIELD_TYPE.NEWDATE: _datetime_or_none,
        FIELD_TYPE.NEWDECIMAL: float,
        FIELD_TYPE.NULL: _none,
        FIELD_TYPE.SET: _bytes_to_str,
        FIELD_TYPE.SHORT: int,
        FIELD_TYPE.STRING: _bytes_to_str,
        FIELD_TYPE.TIME: _timedelta_or_none,
        FIELD_TYPE.TIMESTAMP: _datetime_or_none,
        FIELD_TYPE.TINY: int,
        FIELD_TYPE.TINY_BLOB: _bytes_to_str,
        FIELD_TYPE.VAR_STRING: _bytes_to_str,
        FIELD_TYPE.VARCHAR: _bytes_to_str,
        FIELD_TYPE.YEAR: _year_to_datetime,
    }

    CHARSET = 'utf8'

    def __init__(self, host=setting.MYSQL_HOST, port=setting.MYSQL_PORT, user=setting.MYSQL_USER,
                 pass_=setting.MYSQL_PASS, base=setting.MYSQL_BASE, charset=CHARSET, autocommit=False):
        """Store connection parameters; no connection is opened here.

        Defaults come from the ``setting`` module and are read once, when the
        class is defined.
        """
        self._session = None
        self._host = host
        self._port = port
        self._user = user
        self._pass = pass_
        self._base = base
        self._charset = charset
        self._autocommit = autocommit

    def connect(self):
        """Open the connection if it is not already open (no-op otherwise)."""
        if self._session is None:
            self._session = _mysql.connect(
                host=self._host, port=self._port, user=self._user, passwd=self._pass, db=self._base, conv=self.CONVERTER
            )
            self._session.set_character_set(self._charset)
            self._session.autocommit(self._autocommit)

    @staticmethod
    def _build_stmt(stmt, args):
        """Substitute ``:name`` placeholders in *stmt* with SQL literals.

        :param stmt: SQL text containing ``:name`` placeholders.
        :param args: mapping of placeholder name to value, or None to return
            *stmt* unchanged. The mapping is no longer modified in place.
        :raises NameError: if a value has an unsupported type.
        :raises KeyError: if a placeholder has no entry in *args*.

        NOTE: when *args* is given, the statement goes through ``%``-formatting,
        so a literal ``%`` in the SQL must be written as ``%%``.
        """
        if args is None:
            return stmt
        literals = {key: _to_sql_literal(value) for key, value in args.items()}
        return re.sub(r':([a-zA-Z0-9_]+)', r'%(\1)s', stmt) % literals

    def query(self, stmt, args=None):
        """Run a row-returning statement and return its rows as a list of dicts.

        Each dict maps column name to the value produced by :attr:`CONVERTER`.
        """
        self.connect()
        self._session.query(self._build_stmt(stmt, args))

        result = self._session.use_result()
        names = [field[0] for field in result.describe()]
        rows = []
        while True:
            batch = result.fetch_row()
            if not batch:
                break
            rows.append(dict(zip(names, batch[0])))
        return rows

    def exec(self, stmt, args=None):
        """Run a statement that returns no rows; return the connection's insert_id()."""
        self.connect()
        self._session.query(self._build_stmt(stmt, args))
        return self._session.insert_id()

    def rollback(self):
        """Roll back the current transaction if a connection is open."""
        if self._session is not None:
            self._session.rollback()

    def commit(self):
        """Commit if a connection is open AND ``setting.MYSQL_SAVE`` is truthy."""
        if self._session is not None and setting.MYSQL_SAVE:
            self._session.commit()

    def close(self):
        """Close the connection if open.

        The session reference is cleared before closing. A later
        :meth:`connect` can then reopen the connection, and a second
        ``close()`` is a no-op.
        """
        if self._session is not None:
            session, self._session = self._session, None
            session.close()

    def get_random_ua(self):
        """Return one random ``useragents`` value, or None if the table is empty."""
        for row in self.query(""" SELECT useragents FROM user_agents ORDER BY RAND() LIMIT 1 """):
            return row['useragents']

    def get_tables(self, table_name=None):
        """Yield table names of the current database.

        When *table_name* is given, only a matching name is yielded.
        """
        for row in self.query(""" SHOW TABLES """):
            # SHOW TABLES returns a single column named "Tables_in_<database>".
            # Read it positionally instead of hard-coding one database name.
            name = next(iter(row.values()))
            if table_name is None or name == table_name:
                yield name
|