crawler_task_management/public_function/asyn_mysql.py

143 lines
5.3 KiB
Python
Raw Normal View History

2025-11-26 17:40:11 +08:00
import asyncio
from typing import Any, Dict, List, Optional, Tuple

import aiomysql
class AsyncMySQL:
    """Small asynchronous MySQL helper backed by an ``aiomysql`` connection pool.

    Typical use::

        db = await AsyncMySQL(cfg).initialize()
        try:
            rows = await db.fetch_all("SELECT ...")
        finally:
            await db.close()

    or, equivalently, ``async with AsyncMySQL(cfg) as db: ...``.

    Values are always passed to the driver as bound parameters (``%s``
    placeholders).  Table names, column names and WHERE fragments, however,
    are interpolated into the SQL text as-is, so they must come from trusted
    code and never from external input.
    """

    def __init__(self, config_data: Dict):
        """Store the pool settings; no connection is opened here.

        :param config_data: mapping providing the keys ``host``, ``port``,
            ``user``, ``password``, ``db`` and ``max_overflow`` (the latter
            is used as the pool's ``maxsize``).
        :raises KeyError: if one of those keys is missing.
        """
        self.config = {
            'host': config_data['host'],
            'port': config_data['port'],
            'user': config_data['user'],
            'password': config_data['password'],
            'db': config_data['db'],
            # Connections are opened in autocommit mode.
            'autocommit': True,
            'minsize': 1,
            'maxsize': config_data['max_overflow'],
        }
        self.pool = None

    async def __aenter__(self):
        """Create the pool on entering ``async with`` and return ``self``."""
        return await self.initialize()

    async def __aexit__(self, exc_type, exc, tb):
        """Close the pool on leaving ``async with``; exceptions propagate."""
        await self.close()

    def _require_pool(self):
        """Return the live pool, or fail loudly if ``initialize`` was skipped."""
        if self.pool is None:
            raise RuntimeError(
                "connection pool is not initialized; "
                "call 'await initialize()' before running queries"
            )
        return self.pool

    async def initialize(self):
        """Create the connection pool.

        Calling this again while a pool is already held is a no-op, so an
        existing pool is never overwritten (and leaked).

        :return: ``self``, to allow ``db = await AsyncMySQL(cfg).initialize()``.
        """
        if self.pool is None:
            self.pool = await aiomysql.create_pool(**self.config)
        return self

    async def close(self):
        """Close the connection pool and wait until it has shut down.

        Safe to call when no pool exists.  The reference is dropped first so
        a closed pool can never be reused by a later query.
        """
        pool, self.pool = self.pool, None
        if pool is not None:
            pool.close()
            await pool.wait_closed()

    async def execute(self, query: str, params=None) -> int:
        """Execute a single SQL statement.

        :param query: SQL text, using ``%s`` placeholders for values.
        :param params: values bound to the placeholders, or ``None``.
        :return: the cursor's ``rowcount`` after execution.
        :raises RuntimeError: if the pool has not been initialized.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, params)
                return cursor.rowcount

    async def executemany(self, query: str, params_list: List[Tuple]) -> int:
        """Execute one SQL statement once per parameter tuple.

        :param query: SQL text, using ``%s`` placeholders for values.
        :param params_list: one parameter tuple per execution.
        :return: the cursor's ``rowcount`` after execution, or ``0`` when
            ``params_list`` is empty.
        :raises RuntimeError: if the pool has not been initialized.
        """
        if not params_list:
            # Nothing to run: do not borrow a connection for a no-op.
            return 0
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.executemany(query, params_list)
                return cursor.rowcount

    async def insert_many_tuple(self, table: str, columns: List[str], data: List[Tuple]) -> int:
        """Bulk-insert rows given as tuples.

        :param table: target table name (interpolated into the SQL text).
        :param columns: column names, in the order used by each tuple.
        :param data: one tuple of values per row.
        :return: the value returned by :meth:`executemany`.
        """
        placeholders = ', '.join(['%s'] * len(columns))
        columns_str = ', '.join(columns)
        query = f"INSERT INTO {table} ({columns_str}) VALUES ({placeholders})"
        return await self.executemany(query, data)

    async def insert_many(self, table: str, data: List[Dict[str, Any]]) -> int:
        """Bulk-insert rows given as dictionaries.

        The column list is taken from the keys of the first record.  A later
        record that lacks one of those keys contributes ``None`` for it; keys
        that only appear in later records are ignored.

        :param table: target table name (interpolated into the SQL text).
        :param data: one ``{column: value}`` mapping per row.
        :return: the value returned by :meth:`executemany`, or ``0`` when
            ``data`` is empty.
        """
        if not data:
            # No first record to read the column names from.
            return 0
        columns = list(data[0])
        # Flatten every record into a tuple ordered like ``columns``.
        params_list = [tuple(record.get(col) for col in columns) for record in data]
        return await self.insert_many_tuple(table, columns, params_list)

    async def fetch_all(self, query: str, params=None) -> List[Dict[str, Any]]:
        """Run a query and return every row.

        A ``DictCursor`` is used, so rows are keyed by column name.

        :param query: SQL text, using ``%s`` placeholders for values.
        :param params: values bound to the placeholders, or ``None``.
        :return: the result of the cursor's ``fetchall()``.
        :raises RuntimeError: if the pool has not been initialized.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, params)
                return await cursor.fetchall()

    async def delete(self, table: str, where_conditions: Optional[str] = None,
                     params: Optional[Tuple] = None) -> int:
        """Delete rows from a table.

        NOTE: with an empty ``where_conditions`` the statement has no WHERE
        clause and therefore targets every row of the table.

        :param table: table name (interpolated into the SQL text).
        :param where_conditions: WHERE fragment without the ``WHERE`` keyword.
        :param params: values bound to placeholders in ``where_conditions``.
        :return: number of affected rows as reported by :meth:`execute`.
        """
        query = f"DELETE FROM {table}"
        if where_conditions:
            query += f" WHERE {where_conditions}"
        return await self.execute(query, params)

    async def delete_many(self, table: str, conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Run one DELETE per ``(where_conditions, params)`` pair.

        :param table: table name (interpolated into the SQL text).
        :param conditions_list: pairs forwarded one by one to :meth:`delete`.
        :return: sum of the affected-row counts of all statements.
        """
        total_affected = 0
        for where_conditions, params in conditions_list:
            total_affected += await self.delete(table, where_conditions, params)
        return total_affected

    async def update(self, table: str, set_columns: Dict[str, Any],
                     where_conditions: Optional[str] = None,
                     params: Optional[Tuple] = None) -> int:
        """Update rows of a table.

        NOTE: with an empty ``where_conditions`` the statement has no WHERE
        clause and therefore targets every row; ``params`` is ignored then.

        :param table: table name (interpolated into the SQL text).
        :param set_columns: ``{column: new_value}`` mapping for the SET clause.
        :param where_conditions: WHERE fragment without the ``WHERE`` keyword.
        :param params: values bound to placeholders in ``where_conditions``.
        :return: number of affected rows as reported by :meth:`execute`.
        :raises ValueError: if ``where_conditions`` is given without ``params``.
        """
        set_clause = ', '.join(f"{column} = %s" for column in set_columns)
        query = f"UPDATE {table} SET {set_clause}"
        # SET values come first, matching the placeholder order in the SQL.
        update_params = list(set_columns.values())
        if where_conditions:
            if not params:
                raise ValueError("当使用WHERE条件时必须提供参数")
            query += f" WHERE {where_conditions}"
            update_params.extend(params)
        return await self.execute(query, update_params)

    async def update_many(self, table: str, set_columns: Dict[str, Any],
                          conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Apply the same SET clause once per ``(where_conditions, params)`` pair.

        Unlike :meth:`update`, a WHERE clause is always emitted and ``params``
        may be empty (for conditions that contain no placeholders).

        :param table: table name (interpolated into the SQL text).
        :param set_columns: ``{column: new_value}`` mapping for the SET clause.
        :param conditions_list: ``(where_conditions, params)`` pairs.
        :return: sum of the affected-row counts of all statements.
        """
        # Both pieces are identical for every statement, so build them once.
        set_clause = ', '.join(f"{column} = %s" for column in set_columns)
        set_values = list(set_columns.values())
        total_affected = 0
        for where_conditions, params in conditions_list:
            query = f"UPDATE {table} SET {set_clause} WHERE {where_conditions}"
            # Fresh list per statement: SET values followed by WHERE values.
            update_params = [*set_values, *(params or ())]
            total_affected += await self.execute(query, update_params)
        return total_affected
if __name__ == '__main__':
    # Manual smoke test: run one query against the configured database.
    from public_function.public_func import read_config

    async def _demo():
        """Open the pool, run one sample query, and always close the pool."""
        config = read_config(r"C:\workfile\crawler_task_management\public_function\config.yaml")
        obj = AsyncMySQL(config['advert_policy'])
        sql_str = """select account_id,password from crawler_account_record_info
where status=1 and app_name='xiapi' limit 1"""
        # The pool must exist before any query can acquire a connection.
        await obj.initialize()
        try:
            rows = await obj.fetch_all(sql_str)
            # Report only the row count; the rows contain credentials.
            print(f"fetched {len(rows)} row(s)")
        finally:
            await obj.close()

    # fetch_all is a coroutine function: calling it without awaiting runs
    # nothing, so the demo is driven by an event loop.
    asyncio.run(_demo())