From 6975ef004118efc9a14d8d5899c679dccd79a405 Mon Sep 17 00:00:00 2001
From: zayac <stupidzayac@gmail.com>
Date: Sun, 2 Jun 2024 16:58:22 +0800
Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E5=85=A5=E4=BA=86=E9=99=90=E6=B5=81?=
 =?UTF-8?q?=E5=99=A8,=E7=8E=B0=E5=9C=A8=E4=B8=8D=E4=BC=9A=E5=87=BA?=
 =?UTF-8?q?=E7=8E=B0=E6=9F=A5=E8=AF=A2=E5=A4=B1=E8=B4=A5=E7=9A=84=E6=83=85?=
 =?UTF-8?q?=E5=86=B5=E4=BA=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/config.ini                |  7 +++--
 src/entity/banner_info.py     | 21 ++++++++++++---
 src/entity/database.py        |  3 ++-
 src/entity/finance.py         | 20 +++++++++++---
 src/entity/pay_record.py      | 51 +++++++++++++++++++++++++++++------
 src/entity/visual_list.py     |  8 +++---
 src/ui/thread_pool_manager.py | 42 ++++++++++++++++++++++++-----
 7 files changed, 124 insertions(+), 28 deletions(-)

diff --git a/src/config.ini b/src/config.ini
index becca02..4697356 100644
--- a/src/config.ini
+++ b/src/config.ini
@@ -1,6 +1,9 @@
+;[Credentials]
+;username = zayac
+;password = 123456
 [Credentials]
-username = zayac
-password = 123456
+username = luffy
+password = luffy230505
 
 [Minimum]
 minimum = True
\ No newline at end of file
diff --git a/src/entity/banner_info.py b/src/entity/banner_info.py
index 2b5b910..e439457 100644
--- a/src/entity/banner_info.py
+++ b/src/entity/banner_info.py
@@ -1,5 +1,5 @@
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from typing import List
 
@@ -12,7 +12,7 @@ from src.entity.account import Account
 from src.entity.member import get_today_new_member_list
 from src.entity.pay_record import get_latest_deposit_user
 from src.entity.user import User
-from src.ui.thread_pool_manager import global_thread_pool
+from src.ui.thread_pool_manager import global_thread_pool, rate_limiter
 
 
 @dataclass
@@ -84,5 +84,18 @@ def query_banner_info(account: Account):
 
 
 def get_banner_info_by_user(user: User) -> List[BannerInfo]:
-    futures = [global_thread_pool.submit(get_banner_info, account) for account in user.accounts]
-    return [future.result() for future in futures]
+    futures = []
+    for account in user.accounts:
+        rate_limiter.acquire()
+        futures.append(global_thread_pool.submit(get_banner_info, account))
+
+    banner_info_list = []
+    for future in as_completed(futures):
+        try:
+            banner_info = future.result()
+            banner_info_list.append(banner_info)
+            logger.info(f"Task completed for account: {banner_info.agentCode}")
+        except Exception as e:
+            logger.error(f"Error in future result: {e}")
+
+    return banner_info_list
diff --git a/src/entity/database.py b/src/entity/database.py
index 755bf02..8fd64b8 100644
--- a/src/entity/database.py
+++ b/src/entity/database.py
@@ -59,4 +59,5 @@ class Database:
         session.commit()
 
 
-db = Database('mysql+mysqlconnector://ky_tools:HJQY35seXen8patn@1panel.stupidpz.com:3306/ky_tools')
+# db = Database('mysql+mysqlconnector://ky_tools:HJQY35seXen8patn@1panel.stupidpz.com:3306/ky_tools')
+db = Database('mysql+mysqlconnector://www_luffy_tool:GrCDtSynbK5MHb48@110.40.20.148:3306/www_luffy_tool')
diff --git a/src/entity/finance.py b/src/entity/finance.py
index eb4ddd6..5cdc6ca 100644
--- a/src/entity/finance.py
+++ b/src/entity/finance.py
@@ -1,4 +1,4 @@
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from decimal import Decimal
 from typing import List
@@ -11,7 +11,7 @@ from src.core.constant import FINANCE_URL
 from src.core.util import get_curr_day, get_first_day_by_str
 from src.entity.account import Account
 from src.entity.user import User, get_user_by_telegram_id
-from src.ui.thread_pool_manager import global_thread_pool
+from src.ui.thread_pool_manager import global_thread_pool, rate_limiter
 
 '''
 财务报表
@@ -58,8 +58,20 @@ def get_finances_by_user(user: User, date) -> List[Finance]:
     accounts = user.accounts
     start_date = util.get_first_day_by_str(date)
 
-    futures = [global_thread_pool.submit(get_finance, account, start_date, date) for account in accounts]
-    return [future.result() for future in futures]
+    futures = []
+    for account in accounts:
+        rate_limiter.acquire()
+        futures.append(global_thread_pool.submit(get_finance, account, start_date, date))
+
+    finance_list = []
+    for future in as_completed(futures):
+        try:
+            finance = future.result()
+            finance_list.append(finance)
+        except Exception as e:
+            logger.error(f"Error in future result: {e}")
+
+    return finance_list
 
 
 def get_net_win_by_user(user: User, date: str) -> str:
diff --git a/src/entity/pay_record.py b/src/entity/pay_record.py
index aab5d4f..ebe5810 100644
--- a/src/entity/pay_record.py
+++ b/src/entity/pay_record.py
@@ -10,7 +10,7 @@ from src.entity.account import Account
 from src.entity.member import (MemberList, async_get_member_detail_by_name,
                                get_member_by_name, get_member_list)
 from src.entity.user import User, get_user_by_telegram_id, get_user_by_username_and_password
-from src.ui.thread_pool_manager import global_thread_pool
+from src.ui.thread_pool_manager import global_thread_pool, rate_limiter
 
 
 @dataclass
@@ -50,7 +50,7 @@ def get_pay_record(account: Account):
     return [PayRecord(**item) for item in api_response.data['list']]
 
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed
 from datetime import datetime
 
 
@@ -68,7 +68,11 @@ def get_latest_deposit_user(account: Account, count: int):
 
     # 开启多线程根据用户名查询所有数据
     results = []
-    futures = [global_thread_pool.submit(get_member_by_name, account, name) for name in unique_names_within_time]
+    futures = []
+    for name in unique_names_within_time:
+        rate_limiter.acquire()
+        futures.append(global_thread_pool.submit(get_member_by_name, account, name))
+
     for future in as_completed(futures):
         try:
             result = future.result()
@@ -179,7 +183,10 @@ def get_pay_record_list(account: Account, date: str) -> Dict[str, List[str]]:
     }
     member_list = get_member_list(account, params)
     if member_list is not None and len(member_list) > 0:
-        futures = [global_thread_pool.submit(get_pay_record_detail, account, member, date) for member in member_list]
+        futures = []
+        for member in member_list:
+            rate_limiter.acquire()  # 确保每个任务的速率限制
+            futures.append(global_thread_pool.submit(get_pay_record_detail, account, member, date))
         for future in futures:
             result = future.result()
             if result:
@@ -206,7 +213,10 @@ def get_pay_record_detail(account: Account, member: MemberList, date: str) -> Op
 def get_pay_failed_by_user(user: User, date: str) -> Optional[str]:
     logger.info(f'Getting pay failed by user: {user.username}')
 
-    futures = [global_thread_pool.submit(get_pay_record_list, account, date) for account in user.accounts]
+    futures = []
+    for account in user.accounts:
+        rate_limiter.acquire()  # 确保每个任务的速率限制
+        futures.append(global_thread_pool.submit(get_pay_record_list, account, date))
 
     # 使用列表推导式构建结果字符串
     text_lines = [
@@ -214,7 +224,7 @@ def get_pay_failed_by_user(user: User, date: str) -> Optional[str]:
         for future in futures if (res := future.result())['names']
     ]
 
-    text = '\n'.join(text_lines)
+    text = '\n\n'.join(text_lines)
 
     if not text:
         logger.info('无存款失败用户')
@@ -224,9 +234,34 @@ def get_pay_failed_by_user(user: User, date: str) -> Optional[str]:
     return text
 
 
+# def get_pay_failed_by_user(user: User, date: str) -> Optional[str]:
+#     logger.info(f'Getting pay failed by user: {user.username}')
+#     rate_limiter.acquire()
+#     futures = [global_thread_pool.submit(get_pay_record_list, account, date) for account in user.accounts]
+#
+#     # 使用列表推导式构建结果字符串
+#     text_lines = [
+#         "{}\n{}".format(res['name'], '\n'.join(res['names']))
+#         for future in futures if (res := future.result())['names']
+#     ]
+#
+#     text = '\n'.join(text_lines)
+#
+#     if not text:
+#         logger.info('无存款失败用户')
+#         return '无存款失败用户'
+#
+#     logger.info(text)
+#     return text
+
+
 def get_pay_failed_by_telegram_id(telegram_id: int) -> Optional[str]:
     user = get_user_by_telegram_id(telegram_id)
-    futures = [global_thread_pool.submit(get_pay_record_list, account, get_curr_day()) for account in user.accounts]
+    futures = []
+
+    for account in user.accounts:
+        rate_limiter.acquire()  # 确保每个任务的速率限制
+        futures.append(global_thread_pool.submit(get_pay_record_list, account, get_curr_day()))
 
     # 使用列表推导式构建结果字符串
     text_lines = [
@@ -234,7 +269,7 @@ def get_pay_failed_by_telegram_id(telegram_id: int) -> Optional[str]:
         for future in futures if (res := future.result())['names']
     ]
 
-    text = '\n'.join(text_lines)
+    text = '\n\n'.join(text_lines)
 
     if not text:
         logger.info('无存款失败用户')
diff --git a/src/entity/visual_list.py b/src/entity/visual_list.py
index fff2fe6..e17d5d9 100644
--- a/src/entity/visual_list.py
+++ b/src/entity/visual_list.py
@@ -9,7 +9,7 @@ from src.core.constant import VISUAL_LIST_URL
 from src.core.util import get_curr_day, get_curr_month
 from src.entity.account import Account
 from src.entity.user import User, get_user_by_telegram_id
-from src.ui.thread_pool_manager import global_thread_pool
+from src.ui.thread_pool_manager import global_thread_pool, rate_limiter
 
 
 # 视图列表对象 对应界面上的图表
@@ -103,8 +103,10 @@ def get_statics(account, date=get_curr_day()) -> VisualInfo:
 
 def count_by_user(user: User, date: str):
     accounts = user.accounts
-
-    futures = [global_thread_pool.submit(get_statics, account, date) for account in accounts]
+    futures = []
+    for account in accounts:
+        rate_limiter.acquire()
+        futures.append(global_thread_pool.submit(get_statics, account, date))
     return [future.result() for future in futures]
 
 
diff --git a/src/ui/thread_pool_manager.py b/src/ui/thread_pool_manager.py
index fb3d3cd..0d596a9 100644
--- a/src/ui/thread_pool_manager.py
+++ b/src/ui/thread_pool_manager.py
@@ -1,23 +1,53 @@
-# threadpool_manager.py
 from PyQt6.QtCore import QThreadPool
+from concurrent.futures import ThreadPoolExecutor
+import threading
+import time
 
 # 创建一个全局的线程池实例
 pyqt_thread_pool = QThreadPool.globalInstance()
 
-from concurrent.futures import ThreadPoolExecutor
-
 
 class ThreadPoolManager:
     _instance = None
+    _lock = threading.Lock()
 
     def __new__(cls, *args, **kwargs):
         if not cls._instance:
-            cls._instance = super(ThreadPoolManager, cls).__new__(cls, *args, **kwargs)
-            cls._instance.thread_pool = ThreadPoolExecutor(max_workers=5)
+            with cls._lock:
+                if not cls._instance:
+                    cls._instance = super(ThreadPoolManager, cls).__new__(cls)
+                    max_workers = kwargs.get('max_workers', 5)
+                    cls._instance.thread_pool = ThreadPoolExecutor(max_workers=max_workers)
         return cls._instance
 
     def get_thread_pool(self):
         return self.thread_pool
 
 
-global_thread_pool = ThreadPoolManager().get_thread_pool()
+class RateLimiter:
+    def __init__(self, rate: int, per: float):
+        self._rate = rate
+        self._per = per
+        self._allowance = rate
+        self._last_check = time.monotonic()
+        self._lock = threading.Lock()
+
+    def acquire(self):
+        with self._lock:
+            current = time.monotonic()
+            time_passed = current - self._last_check
+            self._last_check = current
+            self._allowance += time_passed * (self._rate / self._per)
+            if self._allowance > self._rate:
+                self._allowance = self._rate
+            if self._allowance < 1.0:
+                sleep_time = (1.0 - self._allowance) * (self._per / self._rate)
+                time.sleep(sleep_time)
+                self._allowance = 0
+                return
+            self._allowance -= 1.0
+
+
+# 创建全局实例
+global_thread_pool = ThreadPoolManager(max_workers=20).get_thread_pool()
+rate_limiter = RateLimiter(rate=5, per=1.0)