from datetime import datetime, date, timedelta
import re

from MySQLdb.constants import FIELD_TYPE

try:
    import _mysql
except ImportError:
    # mysqlclient >= 1.4 moved the low-level C module into the MySQLdb package.
    from MySQLdb import _mysql

import setting


# "YYYY-MM-DD" optionally followed by " HH:MM:SS" and an optional ".ffffff".
_DATETIME_RE = re.compile(
    r'(\d{4})-(\d{2})-(\d{2})'
    r'(?: (\d{1,2}):(\d{2}):(\d{2})(?:\.(\d{1,6}))?)?',
    re.ASCII,
)
# MySQL TIME: optional sign, hours (may exceed 24), minutes, seconds, fraction.
_TIME_RE = re.compile(r'(-)?(\d+):(\d{2}):(\d{2})(?:\.(\d{1,6}))?', re.ASCII)
# Named placeholders in statements, e.g. ":user_id".
_PARAM_RE = re.compile(r':([a-zA-Z0-9_]+)')


def _as_text(value):
    """Return *value* as ``str``, decoding ``bytes``/``bytearray`` as ASCII.

    Depending on the mysqlclient version, converters may receive either text
    or raw bytes for the same column type, so parsers normalise first.
    """
    if isinstance(value, (bytes, bytearray)):
        return value.decode('ascii', errors='replace')
    return value


def _fraction_to_microseconds(fraction):
    """Convert a fractional-seconds digit string (1-6 digits) to microseconds."""
    return int(fraction.ljust(6, '0')) if fraction else 0


def _datetime_or_none(value):
    """Parse a MySQL DATE/DATETIME/TIMESTAMP value.

    Returns a ``date`` for "YYYY-MM-DD" and a ``datetime`` for
    "YYYY-MM-DD HH:MM:SS[.ffffff]". Returns ``None`` for anything that is
    malformed or out of range (including MySQL "zero" dates like 0000-00-00).
    """
    match = _DATETIME_RE.fullmatch(_as_text(value))
    if match is None:
        return None
    year, month, day, hour, minute, second, fraction = match.groups()
    try:
        if hour is None:
            return date(int(year), int(month), int(day))
        return datetime(
            int(year), int(month), int(day),
            int(hour), int(minute), int(second),
            _fraction_to_microseconds(fraction),
        )
    except ValueError:
        # Syntactically valid but impossible dates (month 0, day 0, ...).
        return None


def _timedelta_or_none(value):
    """Parse a MySQL TIME value ("[-]H+:MM:SS[.ffffff]") into a ``timedelta``.

    Returns ``None`` if the value does not have that shape.
    """
    match = _TIME_RE.fullmatch(_as_text(value))
    if match is None:
        return None
    sign, hours, minutes, seconds, fraction = match.groups()
    delta = timedelta(
        hours=int(hours),
        minutes=int(minutes),
        seconds=int(seconds),
        microseconds=_fraction_to_microseconds(fraction),
    )
    return -delta if sign else delta


def _year_to_datetime(value):
    """Convert a MySQL YEAR value to January 1st of that year.

    Returns ``None`` for non-numeric input or a year ``datetime`` cannot
    represent (e.g. the zero year 0000).
    """
    text = _as_text(value)
    if not text.isdigit():
        return None
    try:
        return datetime(year=int(text), month=1, day=1)
    except ValueError:
        return None


def _bytes_to_str(value):
    """Decode a textual column value as UTF-8, replacing undecodable bytes.

    Values that are already ``str`` are returned unchanged.
    """
    if isinstance(value, str):
        return value
    return value.decode(encoding='utf8', errors='replace')


def _bit_to_int(value):
    """Convert a MySQL BIT value (raw big-endian bytes) to an ``int``."""
    if isinstance(value, (bytes, bytearray)):
        return int.from_bytes(value, byteorder='big')
    # Non-bytes input: keep the previous plain int() conversion.
    return int(value)


def _none(_):
    """Converter that discards the value (used for unsupported column types)."""
    return None


def _sql_literal(value):
    """Render a Python value as an inline SQL literal.

    Supports ``str``, ``datetime``, ``date``, ``float``, ``int`` and ``None``.

    Raises:
        NameError: for any other type (kept for backward compatibility).
    """
    if isinstance(value, str):
        # Double single quotes and escape backslashes for MySQL string literals.
        return "'{}'".format(value.replace("'", "''").replace('\\', '\\\\'))
    # datetime must be tested before date: datetime is a subclass of date.
    if isinstance(value, datetime):
        return "'{}'".format(value.strftime('%Y-%m-%d %H:%M:%S'))
    if isinstance(value, date):
        return "'{}'".format(value.strftime('%Y-%m-%d'))
    if isinstance(value, (float, int)):
        return str(value)
    if value is None:
        return 'NULL'
    raise NameError('Argument type not allowed here: {} - {}'.format(type(value), value))


class MysqlDB:
    """Thin wrapper around a low-level ``_mysql`` connection.

    The connection is opened lazily on first use. Results are returned as
    lists of ``{column_name: value}`` dicts, with values converted through
    ``CONVERTER``.
    """

    # Maps MySQL field types to the callables that convert raw column values.
    CONVERTER = {
        FIELD_TYPE.BIT: _bit_to_int,
        FIELD_TYPE.BLOB: _bytes_to_str,
        FIELD_TYPE.CHAR: _bytes_to_str,
        FIELD_TYPE.DATE: _datetime_or_none,
        FIELD_TYPE.DATETIME: _datetime_or_none,
        FIELD_TYPE.DECIMAL: float,
        FIELD_TYPE.DOUBLE: float,
        FIELD_TYPE.ENUM: _bytes_to_str,
        FIELD_TYPE.FLOAT: float,
        FIELD_TYPE.GEOMETRY: _none,
        FIELD_TYPE.INT24: int,
        FIELD_TYPE.INTERVAL: _none,
        FIELD_TYPE.LONG: int,
        FIELD_TYPE.LONG_BLOB: _bytes_to_str,
        FIELD_TYPE.LONGLONG: int,
        FIELD_TYPE.MEDIUM_BLOB: _bytes_to_str,
        FIELD_TYPE.NEWDATE: _datetime_or_none,
        FIELD_TYPE.NEWDECIMAL: float,
        FIELD_TYPE.NULL: _none,
        FIELD_TYPE.SET: _bytes_to_str,
        FIELD_TYPE.SHORT: int,
        FIELD_TYPE.STRING: _bytes_to_str,
        FIELD_TYPE.TIME: _timedelta_or_none,
        FIELD_TYPE.TIMESTAMP: _datetime_or_none,
        FIELD_TYPE.TINY: int,
        FIELD_TYPE.TINY_BLOB: _bytes_to_str,
        FIELD_TYPE.VAR_STRING: _bytes_to_str,
        FIELD_TYPE.VARCHAR: _bytes_to_str,
        FIELD_TYPE.YEAR: _year_to_datetime,
    }
    CHARSET = 'utf8'

    def __init__(self, host=setting.MYSQL_HOST, port=setting.MYSQL_PORT,
                 user=setting.MYSQL_USER, pass_=setting.MYSQL_PASS,
                 base=setting.MYSQL_BASE, charset=CHARSET, autocommit=False):
        """Store connection parameters; no connection is opened here.

        Defaults are read from the ``setting`` module once, at class
        definition time.
        """
        self._session = None
        self._host = host
        self._port = port
        self._user = user
        self._pass = pass_
        self._base = base
        self._charset = charset
        self._autocommit = autocommit

    def connect(self):
        """Open the connection if it is not already open (idempotent)."""
        if self._session is None:
            self._session = _mysql.connect(
                host=self._host,
                port=self._port,
                user=self._user,
                passwd=self._pass,
                db=self._base,
                conv=self.CONVERTER,
            )
            self._session.set_character_set(self._charset)
            self._session.autocommit(self._autocommit)

    @staticmethod
    def _build_stmt(stmt, args):
        """Substitute ``:name`` placeholders in *stmt* with SQL literals from *args*.

        *args* is not modified. Other text in *stmt*, including ``%``
        characters, is left untouched.

        Raises:
            NameError: if an argument has an unsupported type.
            KeyError: if *stmt* references a name missing from *args*.
        """
        if args is None:
            return stmt
        # Render every argument up front so unsupported types are reported
        # even when their placeholder is absent from the statement.
        literals = {key: _sql_literal(value) for key, value in args.items()}
        return _PARAM_RE.sub(lambda match: literals[match.group(1)], stmt)

    def query(self, stmt, args=None):
        """Run a statement and return all result rows as a list of dicts.

        Returns an empty list when the statement produces no result set.
        """
        self.connect()
        self._session.query(self._build_stmt(stmt, args))
        result = self._session.use_result()
        if result is None:
            return []
        names = [field[0] for field in result.describe()]
        rows = []
        while True:
            fetched = result.fetch_row()
            if not fetched:
                break
            rows.append(dict(zip(names, fetched[0])))
        return rows

    def exec(self, stmt, args=None):
        """Run a statement that returns no rows; return the connection's insert_id()."""
        self.connect()
        self._session.query(self._build_stmt(stmt, args))
        return self._session.insert_id()

    def rollback(self):
        """Roll back the current transaction if a connection is open."""
        if self._session is not None:
            self._session.rollback()

    def commit(self):
        """Commit if a connection is open AND ``setting.MYSQL_SAVE`` is truthy."""
        if self._session is not None and setting.MYSQL_SAVE:
            self._session.commit()

    def close(self):
        """Close the connection; a later call will transparently reconnect."""
        if self._session is not None:
            try:
                self._session.close()
            finally:
                # Forget the handle so connect() opens a fresh one next time.
                self._session = None

    def get_random_ua(self):
        """Return one random ``useragents`` value from ``user_agents``, or None if empty."""
        for row in self.query("""
            SELECT useragents
            FROM user_agents
            ORDER BY RAND()
            LIMIT 1
        """):
            return row['useragents']

    def get_tables(self, table_name=None):
        """Yield table names in the current database.

        If *table_name* is given, yield only that name (if it exists).
        """
        for row in self.query("""
            SHOW TABLES
        """):
            # SHOW TABLES returns a single column named "Tables_in_<database>".
            name = next(iter(row.values()))
            if table_name is None or name == table_name:
                yield name