143 lines
5.3 KiB
Python
143 lines
5.3 KiB
Python
import asyncio
|
||
import aiomysql
|
||
from typing import List, Tuple, Dict, Any
|
||
|
||
|
||
class AsyncMySQL:
    """Thin asyncio wrapper around an ``aiomysql`` connection pool.

    Await ``initialize()`` (or use ``async with``) before running any query,
    and await ``close()`` when done.

    NOTE: ``table``, column names and ``where_conditions`` are interpolated
    into the SQL text verbatim; only values are sent as bound parameters.
    Never pass untrusted input in those arguments.
    """

    def __init__(self, config_data: Dict):
        """Build the pool configuration from a config mapping.

        :param config_data: mapping with keys ``host``, ``port``, ``user``,
            ``password``, ``db`` and ``max_overflow`` (used as the pool's
            ``maxsize``). A missing key raises ``KeyError``.
        """
        self.config = {
            'host': config_data['host'],
            'port': config_data['port'],
            'user': config_data['user'],
            'password': config_data['password'],
            'db': config_data['db'],
            # Every statement is committed on its own; this class never
            # groups statements into an explicit transaction.
            'autocommit': True,
            'minsize': 1,
            'maxsize': config_data['max_overflow'],
        }
        self.pool = None

    async def __aenter__(self):
        """Create the pool on ``async with`` entry and return ``self``."""
        return await self.initialize()

    async def __aexit__(self, exc_type, exc, tb):
        """Close the pool on ``async with`` exit."""
        await self.close()

    async def initialize(self):
        """Create the connection pool and return ``self`` for chaining."""
        self.pool = await aiomysql.create_pool(**self.config)
        return self

    async def close(self):
        """Close the connection pool and wait until it has shut down.

        Does nothing if the pool was never created or is already closed.
        """
        if self.pool is not None:
            self.pool.close()
            await self.pool.wait_closed()
            # Drop the reference so later queries fail with a clear error
            # instead of trying to use a closed pool.
            self.pool = None

    def _require_pool(self):
        """Return the pool, or raise ``RuntimeError`` if it is not initialized."""
        if self.pool is None:
            raise RuntimeError(
                "AsyncMySQL pool is not initialized; await initialize() first")
        return self.pool

    @staticmethod
    def _build_insert_query(table: str, columns: List[str]) -> str:
        """Return a parameterized ``INSERT`` statement for ``columns``."""
        placeholders = ', '.join(['%s'] * len(columns))
        columns_str = ', '.join(columns)
        return f"INSERT INTO {table} ({columns_str}) VALUES ({placeholders})"

    async def execute(self, query: str, params=None):
        """Execute a single SQL statement.

        :param query: SQL text with ``%s`` placeholders.
        :param params: values bound to the placeholders, or ``None``.
        :return: the cursor's ``rowcount`` after execution.
        :raises RuntimeError: if the pool has not been initialized.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, params)
                return cursor.rowcount

    async def executemany(self, query: str, params_list: List[Tuple]):
        """Execute one SQL statement once per parameter tuple.

        :param query: SQL text with ``%s`` placeholders.
        :param params_list: one tuple of values per execution.
        :return: the cursor's ``rowcount`` after the batch.
        :raises RuntimeError: if the pool has not been initialized.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.executemany(query, params_list)
                return cursor.rowcount

    async def insert_many_tuple(self, table: str, columns: List[str], data: List[Tuple]) -> int:
        """Bulk-insert rows given as tuples ordered like ``columns``.

        :param table: target table name.
        :param columns: column names, in the same order as each tuple.
        :param data: rows to insert.
        :return: affected row count; ``0`` for an empty ``data`` (no
            connection is acquired in that case).
        """
        if not data:
            return 0
        query = self._build_insert_query(table, columns)
        return await self.executemany(query, data)

    async def insert_many(self, table: str, data: List[Dict[str, Any]]) -> int:
        """Bulk-insert rows given as dicts.

        The column list comes from the keys of the first record. A key
        missing from a later record is sent as ``None``; extra keys are
        ignored.

        :param table: target table name.
        :param data: rows to insert.
        :return: affected row count; ``0`` for an empty ``data``.
        """
        if not data:
            # Nothing to insert; also avoids IndexError on data[0].
            return 0
        columns = list(data[0].keys())
        # Convert each dict into a tuple ordered like ``columns``.
        params_list = [tuple(record.get(col) for col in columns) for record in data]
        query = self._build_insert_query(table, columns)
        return await self.executemany(query, params_list)

    async def fetch_all(self, query: str, params=None) -> List[Dict[str, Any]]:
        """Run a query and return every row as a dict keyed by column name.

        :raises RuntimeError: if the pool has not been initialized.
        """
        async with self._require_pool().acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, params)
                return await cursor.fetchall()

    async def delete(self, table: str, where_conditions: str = None, params: Tuple = None) -> int:
        """Delete rows from ``table``.

        :param table: table name.
        :param where_conditions: body of the ``WHERE`` clause, with ``%s``
            placeholders. When falsy, EVERY row in the table is deleted.
        :param params: values for the placeholders.
        :return: affected row count.
        """
        query = f"DELETE FROM {table}"
        if where_conditions:
            query += f" WHERE {where_conditions}"
        return await self.execute(query, params)

    async def delete_many(self, table: str, conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Run one ``DELETE`` per ``(where_conditions, params)`` pair.

        Each statement acquires a connection separately and is autocommitted,
        so the batch as a whole is not atomic.

        :return: total affected row count.
        """
        total_affected = 0
        for where_conditions, params in conditions_list:
            total_affected += await self.delete(table, where_conditions, params)
        return total_affected

    async def update(self, table: str, set_columns: Dict[str, Any], where_conditions: str = None,
                     params: Tuple = None) -> int:
        """Update rows of ``table``.

        :param table: table name.
        :param set_columns: mapping of column name -> new value.
        :param where_conditions: body of the ``WHERE`` clause, with ``%s``
            placeholders. When falsy, EVERY row is updated.
        :param params: values for the ``WHERE`` placeholders; required
            whenever ``where_conditions`` is given.
        :return: affected row count.
        :raises ValueError: if ``set_columns`` is empty, or if
            ``where_conditions`` is given without ``params``.
        """
        if not set_columns:
            # An empty SET clause would produce invalid SQL.
            raise ValueError("set_columns must not be empty")
        set_clause = ', '.join(f"{k} = %s" for k in set_columns)
        query = f"UPDATE {table} SET {set_clause}"
        # SET values come first, then WHERE params, matching placeholder order.
        update_params = list(set_columns.values())
        if where_conditions:
            if not params:
                raise ValueError("当使用WHERE条件时,必须提供参数")
            query += f" WHERE {where_conditions}"
            update_params.extend(params)
        return await self.execute(query, update_params)

    async def update_many(self, table: str, set_columns: Dict[str, Any],
                          conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Apply the same ``SET`` values under several ``WHERE`` conditions.

        :param table: table name.
        :param set_columns: mapping of column name -> new value.
        :param conditions_list: ``(where_conditions, params)`` pairs; ``params``
            may be empty or ``None`` for conditions without placeholders.
        :return: total affected row count.
        :raises ValueError: if ``set_columns`` is empty.
        """
        if not set_columns:
            raise ValueError("set_columns must not be empty")
        set_clause = ', '.join(f"{k} = %s" for k in set_columns)
        # The SET values are identical for every statement; compute them once.
        set_values = list(set_columns.values())
        total_affected = 0
        for where_conditions, params in conditions_list:
            query = f"UPDATE {table} SET {set_clause} WHERE {where_conditions}"
            update_params = set_values + list(params or ())
            total_affected += await self.execute(query, update_params)
        return total_affected
if __name__ == '__main__':
    from public_function.public_func import read_config

    async def _demo(db_config):
        """Open a pool, run one sample query, print the rows, then close the pool."""
        obj = AsyncMySQL(db_config)
        await obj.initialize()
        try:
            sql_str = """select account_id,password from crawler_account_record_info
where status=1 and app_name='xiapi' limit 1"""
            # fetch_all is a coroutine: it must be awaited to actually run.
            rows = await obj.fetch_all(sql_str)
            print(rows)
        finally:
            # Always release pooled connections, even if the query fails.
            await obj.close()

    config = read_config(r"C:\workfile\crawler_task_management\public_function\config.yaml")
    asyncio.run(_demo(config['advert_policy']))