crawler_task_management/public_function/asyn_mysql.py

143 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import aiomysql
from typing import List, Tuple, Dict, Any
class AsyncMySQL:
    """Thin asyncio wrapper around an ``aiomysql`` connection pool.

    Typical use::

        db = await AsyncMySQL(config).initialize()
        rows = await db.fetch_all("SELECT ... WHERE id = %s", (1,))
        await db.close()

    or ``async with AsyncMySQL(config) as db: ...``.

    Security note: ``table``, column names and ``where_conditions`` are
    interpolated directly into the SQL text, so they must come from trusted
    code, never from external input. Only *values* are sent as driver
    parameters (``%s`` placeholders).
    """

    def __init__(self, config_data: Dict):
        """Store pool settings; no connection is opened until ``initialize()``.

        :param config_data: mapping with keys ``host``, ``port``, ``user``,
            ``password``, ``db`` and ``max_overflow`` (used as the pool's
            ``maxsize``).
        """
        self.config = {
            'host': config_data['host'],
            'port': config_data['port'],
            'user': config_data['user'],
            'password': config_data['password'],
            'db': config_data['db'],
            # Every statement is committed as soon as it runs; this class
            # does no explicit transaction handling.
            'autocommit': True,
            'minsize': 1,
            'maxsize': config_data['max_overflow'],
        }
        self.pool = None

    async def initialize(self):
        """Create the connection pool and return ``self`` (allows chaining)."""
        self.pool = await aiomysql.create_pool(**self.config)
        return self

    async def close(self):
        """Close the pool and wait for it to shut down.

        Safe to call before ``initialize()`` or more than once.
        """
        if self.pool is not None:
            # Drop the reference first so later calls see "not initialized"
            # instead of reusing a closed pool.
            pool, self.pool = self.pool, None
            pool.close()
            await pool.wait_closed()

    async def __aenter__(self):
        """Support ``async with AsyncMySQL(config) as db:``."""
        return await self.initialize()

    async def __aexit__(self, exc_type, exc, tb):
        """Close the pool on exit; exceptions from the body still propagate."""
        await self.close()

    def _get_pool(self):
        """Return the live pool, or raise a clear error if there is none."""
        if self.pool is None:
            raise RuntimeError(
                "AsyncMySQL pool is not initialized; call initialize() first "
                "(or it has already been closed)"
            )
        return self.pool

    async def execute(self, query: str, params=None) -> int:
        """Execute a single SQL statement.

        :param query: SQL text, optionally with ``%s`` placeholders.
        :param params: values for the placeholders (sequence or None).
        :return: the cursor's ``rowcount`` (rows affected).
        """
        async with self._get_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, params)
                return cursor.rowcount

    async def executemany(self, query: str, params_list: List[Tuple]) -> int:
        """Execute ``query`` once for each parameter tuple in ``params_list``.

        :return: the cursor's ``rowcount``. Returns ``0`` for an empty
            ``params_list``, without acquiring a connection.
        """
        if not params_list:
            return 0
        async with self._get_pool().acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.executemany(query, params_list)
                return cursor.rowcount

    @staticmethod
    def _build_insert_sql(table: str, columns: List[str]) -> str:
        """Build ``INSERT INTO table (c1, c2) VALUES (%s, %s)``."""
        placeholders = ', '.join(['%s'] * len(columns))
        columns_str = ', '.join(columns)
        return f"INSERT INTO {table} ({columns_str}) VALUES ({placeholders})"

    @staticmethod
    def _build_set_clause(set_columns: Dict[str, Any]) -> str:
        """Build ``c1 = %s, c2 = %s`` for an UPDATE; reject an empty mapping."""
        if not set_columns:
            raise ValueError("set_columns must contain at least one column to update")
        return ', '.join(f"{k} = %s" for k in set_columns)

    async def insert_many_tuple(self, table: str, columns: List[str],
                                data: List[Tuple]) -> int:
        """Bulk-insert rows given as tuples ordered like ``columns``.

        :return: rows inserted as reported by the driver; ``0`` if ``data`` is empty.
        """
        query = self._build_insert_sql(table, columns)
        return await self.executemany(query, data)

    async def insert_many(self, table: str, data: List[Dict[str, Any]]) -> int:
        """Bulk-insert rows given as dicts.

        The column list comes from the keys of the first record. Keys missing
        from later records are sent as ``None``; keys those records have
        beyond the first record's are ignored.

        :return: rows inserted as reported by the driver; ``0`` if ``data`` is empty.
        """
        if not data:
            # Nothing to insert (and data[0] below would raise IndexError).
            return 0
        columns = list(data[0].keys())
        # Convert each dict into a tuple ordered like ``columns``.
        params_list = [tuple(record.get(col) for col in columns) for record in data]
        return await self.insert_many_tuple(table, columns, params_list)

    async def fetch_all(self, query: str, params=None) -> List[Dict[str, Any]]:
        """Run a query and return every row as a dict keyed by column name."""
        async with self._get_pool().acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, params)
                return await cursor.fetchall()

    async def delete(self, table: str, where_conditions: str = None,
                     params: Tuple = None) -> int:
        """Delete rows from ``table``.

        WARNING: if ``where_conditions`` is empty or None, EVERY row is deleted.

        :param where_conditions: WHERE fragment (without the keyword), may use ``%s``.
        :param params: values for the placeholders in ``where_conditions``.
        :return: number of rows deleted.
        """
        query = f"DELETE FROM {table}"
        if where_conditions:
            query += f" WHERE {where_conditions}"
        return await self.execute(query, params)

    async def delete_many(self, table: str, conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Run one ``delete`` per ``(where_conditions, params)`` pair.

        Each statement is auto-committed on its own; this is not atomic.

        :return: total number of rows deleted.
        """
        total_affected = 0
        for where_conditions, params in conditions_list:
            total_affected += await self.delete(table, where_conditions, params)
        return total_affected

    async def update(self, table: str, set_columns: Dict[str, Any], where_conditions: str = None,
                     params: Tuple = None) -> int:
        """Update rows in ``table``.

        :param table: table name (trusted identifier).
        :param set_columns: mapping of column -> new value.
        :param where_conditions: WHERE fragment (without the keyword). If empty,
            every row is updated and ``params`` is ignored.
        :param params: values for the placeholders in ``where_conditions``.
        :return: number of rows affected.
        :raises ValueError: if ``set_columns`` is empty, or if
            ``where_conditions`` is given without ``params``.
        """
        query = f"UPDATE {table} SET {self._build_set_clause(set_columns)}"
        update_params = list(set_columns.values())
        if where_conditions:
            if not params:
                raise ValueError("当使用WHERE条件时必须提供参数")
            query += f" WHERE {where_conditions}"
            # WHERE placeholders come after the SET placeholders.
            update_params.extend(params)
        return await self.execute(query, update_params)

    async def update_many(self, table: str, set_columns: Dict[str, Any],
                          conditions_list: List[Tuple[str, Tuple]]) -> int:
        """Apply the same ``set_columns`` once per ``(where_conditions, params)`` pair.

        Each statement is auto-committed on its own; this is not atomic.

        :return: total number of rows affected.
        :raises ValueError: if ``set_columns`` is empty.
        """
        set_clause = self._build_set_clause(set_columns)
        set_values = list(set_columns.values())
        total_affected = 0
        for where_conditions, params in conditions_list:
            query = f"UPDATE {table} SET {set_clause} WHERE {where_conditions}"
            total_affected += await self.execute(query, set_values + list(params or ()))
        return total_affected
if __name__ == '__main__':
    from public_function.public_func import read_config

    async def _main():
        """Smoke test: fetch one 'xiapi' account row and print it."""
        config = read_config(r"C:\workfile\crawler_task_management\public_function\config.yaml")
        # fetch_all is a coroutine and needs an initialized pool, so it must be
        # awaited inside a running event loop.
        db = await AsyncMySQL(config['advert_policy']).initialize()
        try:
            sql_str = (
                "select account_id,password from crawler_account_record_info "
                "where status=1 and app_name='xiapi' limit 1"
            )
            rows = await db.fetch_all(sql_str)
            print(rows)
        finally:
            await db.close()

    asyncio.run(_main())