import asyncio
import aiomysql
from typing import List, Tuple, Dict, Any, Optional


class AsyncMySQL:
    """Thin async wrapper around an ``aiomysql`` connection pool.

    Call :meth:`initialize` (or use ``async with AsyncMySQL(cfg) as db:``)
    before issuing queries, and :meth:`close` when done. The pool is created
    with ``autocommit=True``, so every statement is committed as soon as it
    executes.

    NOTE: table names, column names and WHERE-clause text are interpolated
    into the SQL string verbatim; only *values* go through driver parameter
    binding (``%s`` placeholders). Never pass untrusted input as a table,
    column or WHERE fragment.
    """

    def __init__(self, config_data: Dict[str, Any]):
        """Build the pool configuration; no connection is opened here.

        :param config_data: mapping providing ``host``, ``port``, ``user``,
            ``password``, ``db`` and ``max_overflow``.
        :raises KeyError: if any of those keys is missing.
        """
        self.config = {
            'host': config_data['host'],
            'port': config_data['port'],
            'user': config_data['user'],
            'password': config_data['password'],
            'db': config_data['db'],
            'autocommit': True,
            'minsize': 1,
            # The pool's upper bound is taken from the config's 'max_overflow' key.
            'maxsize': config_data['max_overflow'],
        }
        self.pool = None

    async def initialize(self) -> "AsyncMySQL":
        """Create the connection pool and return ``self`` for chaining.

        If a pool is already open this is a no-op, so an existing pool is
        never overwritten (and leaked) by a second call.
        """
        if self.pool is None:
            self.pool = await aiomysql.create_pool(**self.config)
        return self

    async def close(self) -> None:
        """Close the pool, wait for its connections to shut down, and reset it.

        Resetting ``self.pool`` to ``None`` allows :meth:`initialize` to be
        called again later; calling ``close`` twice is harmless.
        """
        if self.pool is not None:
            self.pool.close()
            await self.pool.wait_closed()
            self.pool = None

    async def __aenter__(self) -> "AsyncMySQL":
        """Open the pool when entering an ``async with`` block."""
        return await self.initialize()

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the pool when leaving an ``async with`` block."""
        await self.close()

    async def execute(self, query: str, params=None) -> int:
        """Execute a single SQL statement on a pooled connection.

        :param query: SQL text using ``%s`` placeholders.
        :param params: values bound to the placeholders, or ``None``.
        :return: the cursor's ``rowcount`` after execution.
        """
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, params)
                return cursor.rowcount

    async def executemany(self, query: str, params_list: List[Tuple]) -> int:
        """Execute one SQL statement once per parameter tuple.

        :param query: SQL text using ``%s`` placeholders.
        :param params_list: one tuple of values per execution.
        :return: the cursor's ``rowcount`` after the batch, or ``0`` when
            ``params_list`` is empty (no connection is acquired in that case).
        """
        if not params_list:
            return 0
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.executemany(query, params_list)
                return cursor.rowcount

    async def insert_many_tuple(self, table: str, columns: List[str], data: List[Tuple]) -> int:
        """Bulk-insert rows given as tuples into ``table``.

        :param table: target table name (interpolated verbatim into the SQL).
        :param columns: column names, in the same order as each tuple's values.
        :param data: one tuple per row.
        :return: the rowcount reported by :meth:`executemany` (``0`` if ``data`` is empty).
        """
        placeholders = ', '.join(['%s'] * len(columns))
        columns_str = ', '.join(columns)
        query = f"INSERT INTO {table} ({columns_str}) VALUES ({placeholders})"
        return await self.executemany(query, data)

    async def insert_many(self, table: str, data: List[Dict[str, Any]]) -> int:
        """Bulk-insert rows given as dicts into ``table``.

        The column list is taken from the keys of the first record. For every
        record, a key missing relative to that list is sent as ``None`` (SQL
        NULL), and keys not present in the first record are ignored.

        :param table: target table name (interpolated verbatim into the SQL).
        :param data: one dict per row.
        :return: the affected-row count, or ``0`` if ``data`` is empty.
        """
        if not data:
            return 0
        columns = list(data[0].keys())
        # Turn each dict into a value tuple ordered like `columns`.
        params_list = [tuple(record.get(col) for col in columns) for record in data]
        return await self.insert_many_tuple(table, columns, params_list)

    async def fetch_all(self, query: str, params=None) -> List[Dict[str, Any]]:
        """Run a query and return every result row as a dict (``DictCursor``).

        :param query: SQL text using ``%s`` placeholders.
        :param params: values bound to the placeholders, or ``None``.
        :return: all rows, each mapping column name to value.
        """
        async with self.pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, params)
                return await cursor.fetchall()

    async def delete(self, table: str, where_conditions: Optional[str] = None,
                     params: Optional[Tuple] = None) -> int:
        """Delete rows from ``table``.

        WARNING: without ``where_conditions`` every row of the table is deleted.

        :param table: table name (interpolated verbatim into the SQL).
        :param where_conditions: WHERE clause body using ``%s`` placeholders.
        :param params: values for the WHERE placeholders.
        :return: number of affected rows.
        """
        if where_conditions:
            query = f"DELETE FROM {table} WHERE {where_conditions}"
        else:
            query = f"DELETE FROM {table}"
        return await self.execute(query, params)

    async def delete_many(self, table: str, conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Run one DELETE per ``(where_conditions, params)`` pair.

        Each DELETE goes through :meth:`execute` separately and the pool uses
        autocommit, so the batch is not atomic: earlier deletes stay committed
        if a later one fails.

        :return: total affected rows across all statements.
        """
        total_affected = 0
        for where_conditions, params in conditions_list:
            affected_rows = await self.delete(table, where_conditions, params)
            total_affected += affected_rows
        return total_affected

    async def update(self, table: str, set_columns: Dict[str, Any],
                     where_conditions: Optional[str] = None,
                     params: Optional[Tuple] = None) -> int:
        """Update rows of ``table``.

        :param table: table name (interpolated verbatim into the SQL).
        :param set_columns: mapping of column name -> new value (values are bound).
        :param where_conditions: WHERE clause body using ``%s`` placeholders;
            without it every row of the table is updated.
        :param params: values for the WHERE placeholders. Required (non-empty)
            whenever ``where_conditions`` is given; ignored otherwise.
        :return: number of affected rows.
        :raises ValueError: if ``where_conditions`` is given without ``params``.
        """
        set_clause = ', '.join([f"{k} = %s" for k in set_columns])
        query = f"UPDATE {table} SET {set_clause}"
        if where_conditions:
            query += f" WHERE {where_conditions}"

        # SET values come first, followed by the WHERE values, matching placeholder order.
        update_params = list(set_columns.values())
        if where_conditions:
            if params:
                update_params.extend(params)
            else:
                raise ValueError("当使用WHERE条件时,必须提供参数")

        return await self.execute(query, update_params)

    async def update_many(self, table: str, set_columns: Dict[str, Any],
                          conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Apply the same SET values once per ``(where_conditions, params)`` pair.

        Each UPDATE goes through :meth:`execute` separately and the pool uses
        autocommit, so the batch is not atomic.

        :param table: table name (interpolated verbatim into the SQL).
        :param set_columns: mapping of column name -> new value.
        :param conditions_list: list of ``(where_conditions, params)`` tuples.
        :return: total affected rows across all statements.
        """
        set_clause = ', '.join([f"{k} = %s" for k in set_columns])
        total_affected = 0
        for where_conditions, params in conditions_list:
            query = f"UPDATE {table} SET {set_clause} WHERE {where_conditions}"
            update_params = list(set_columns.values())
            if params:
                update_params.extend(params)
            affected_rows = await self.execute(query, update_params)
            total_affected += affected_rows
        return total_affected


if __name__ == '__main__':
    from public_function.public_func import read_config

    async def _main():
        # Coroutines only run when awaited on an event loop, and the pool must
        # be initialized first; `async with` opens and closes it for us.
        config = read_config(r"C:\workfile\crawler_task_management\public_function\config.yaml")
        async with AsyncMySQL(config['advert_policy']) as obj:
            sql_str = """select account_id,password from crawler_account_record_info where status=1 and app_name='xiapi' limit 1"""
            rows = await obj.fetch_all(sql_str)
            print(rows)

    asyncio.run(_main())