#!/usr/bin/env python3
"""
Huawei Cloud token dynamic gateway.

- Token cache (``CACHE_TTL``, 5.5 hours)
- Automatic refresh by scanning process memory
- ``HUAWEI_TOKEN`` environment variable has the highest priority
- Token persisted to /etc/huawei-gateway.env
- SSE streaming passthrough
- Connection pool (20 connections)
- Automatic retry on 401
- Production serving (Waitress with 32 threads / Gunicorn)
"""
import os
import re
import sys
import time
import logging
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from http.cookiejar import DefaultCookiePolicy

from flask import Flask, request, Response

# Import requests separately so a missing dependency yields a clear hint.
try:
    import requests
except ImportError:
    print("错误:缺少 requests 模块。请运行: pip install requests")
    sys.exit(1)

# ================= Configuration =================
CACHE_TTL = 19800  # 5.5 hours (safety margin)
MAX_WORKERS = 8  # worker threads used for memory scanning
MAX_MEM_SEGMENT = 200 * 1024 * 1024  # mappings larger than 200 MB are skipped
TOKEN_PATTERN = re.compile(b'Bearer ([A-Za-z0-9+/=_-]{100,})')
TARGET_HOST = 'tokenhub.developer.huaweicloud.com'
TOKEN_ENV_FILE = '/etc/huawei-gateway.env'
MIN_TOKEN_LEN = 200  # a usable token must be strictly longer than this
SCAN_CHUNK_SIZE = 64 * 1024  # bytes read from /proc/<pid>/mem per call
# An incomplete "Bearer <99 chars>" prefix is at most 106 bytes long, so this
# many trailing bytes are enough to re-join a candidate split across chunks.
BOUNDARY_OVERLAP = 128
# Upper bound for a candidate carried across chunks; prevents unbounded growth
# when memory contains a very long run of token-alphabet bytes.
MAX_CANDIDATE_LEN = 64 * 1024

# ================= Logging =================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger('huawei-gateway')


# ================= Cache =================
class TokenCache:
    """Thread-safe holder for the current token plus a blacklist of dead ones.

    All state is guarded by a re-entrant lock. Blacklisted tokens are stored
    as fingerprints (see ``_fingerprint``), not as full strings.
    """

    def __init__(self):
        self._token = None
        self._expires_at = 0
        self._lock = threading.RLock()
        self._last_scan = 0
        self._scan_interval = 60  # minimum number of seconds between scans
        self._blacklist = set()  # fingerprints of tokens known to be invalid

    def _fingerprint(self, token):
        """Return the first 16 plus last 16 characters of *token*.

        Tokens of 32 characters or fewer are returned unchanged.
        """
        if len(token) <= 32:
            return token
        return token[:16] + token[-16:]

    def get(self):
        """Return the cached token, or ``None`` if absent or expired."""
        with self._lock:
            if self._token and time.time() < self._expires_at:
                return self._token
            return None

    def set(self, token, ttl=CACHE_TTL):
        """Cache *token* for *ttl* seconds and remove it from the blacklist.

        Also restarts the scan cooldown.
        """
        with self._lock:
            now = time.time()
            self._token = token
            self._expires_at = now + ttl
            self._last_scan = now
            self._blacklist.discard(self._fingerprint(token))

    def blacklist_current(self):
        """Blacklist the current token and drop it from the cache.

        The scan cooldown is reset so the caller can look for a replacement
        immediately instead of waiting out the interval.
        """
        with self._lock:
            if self._token:
                self._blacklist.add(self._fingerprint(self._token))
            self._token = None
            self._expires_at = 0
            self._last_scan = 0

    def is_blacklisted(self, token):
        """Return ``True`` if *token*'s fingerprint is blacklisted."""
        with self._lock:
            return self._fingerprint(token) in self._blacklist

    def is_scan_cooldown(self):
        """Return ``True`` while the last scan is younger than the interval."""
        with self._lock:
            return (time.time() - self._last_scan) < self._scan_interval

    def mark_scanned(self):
        """Record that a scan just finished, whether or not it found a token."""
        with self._lock:
            self._last_scan = time.time()

    def clear(self):
        """Drop the cached token (the blacklist is kept) and allow a rescan."""
        with self._lock:
            self._token = None
            self._expires_at = 0
            self._last_scan = 0


cache = TokenCache()


# ================= Memory scanning =================
def _parse_rw_region(maps_line):
    """Parse one /proc/<pid>/maps line into ``(start, size)``.

    Returns ``None`` for lines that are malformed, not both readable and
    writable, smaller than 1 KiB, or larger than ``MAX_MEM_SEGMENT``.
    """
    fields = maps_line.split()
    if len(fields) < 2:
        return None
    perms = fields[1]
    if 'r' not in perms or 'w' not in perms:
        return None
    bounds = fields[0].split('-')
    if len(bounds) != 2:
        return None
    try:
        start = int(bounds[0], 16)
        end = int(bounds[1], 16)
    except ValueError:
        return None
    size = end - start
    if size > MAX_MEM_SEGMENT or size < 1024:
        return None
    return start, size


def _scan_region(mem, start, size):
    """Search ``size`` bytes of *mem* starting at ``start`` for a token.

    The region is read in ``SCAN_CHUNK_SIZE`` pieces. The tail of each piece
    is prepended to the next one so that a candidate split across two reads
    is matched whole instead of being missed or returned truncated.

    Returns the first non-blacklisted token longer than ``MIN_TOKEN_LEN``,
    or ``None``. I/O errors propagate to the caller.
    """
    mem.seek(start)
    remaining = size
    carry = b''
    while remaining > 0:
        data = mem.read(min(SCAN_CHUNK_SIZE, remaining))
        # An empty read means the region ended early: treat it as the end.
        remaining = remaining - len(data) if data else 0
        window = carry + data
        carry = window[-BOUNDARY_OVERLAP:]
        for match in TOKEN_PATTERN.finditer(window):
            if remaining > 0 and match.end() == len(window):
                # The candidate touches the end of the window, so it may
                # continue in the next chunk; re-evaluate it once that is read.
                pending = window[match.start():]
                if len(pending) <= MAX_CANDIDATE_LEN:
                    carry = pending
                break
            token = match.group(1).decode('ascii', errors='replace')
            # Skipping blacklisted tokens lets a fresh token located later in
            # the same process be found after a stale one.
            if len(token) > MIN_TOKEN_LEN and not cache.is_blacklisted(token):
                return token
    return None


def scan_pid_mem(pid):
    """Scan the read-write memory mappings of process *pid* for a token.

    Returns the token string, or ``None`` when nothing is found or the
    process cannot be read (gone, or permission denied).
    """
    maps_path = f'/proc/{pid}/maps'
    mem_path = f'/proc/{pid}/mem'

    try:
        # Unbuffered so each read maps to exactly the bytes requested.
        with open(maps_path, 'r') as maps, \
                open(mem_path, 'rb', buffering=0) as mem:
            for line in maps:
                region = _parse_rw_region(line)
                if region is None:
                    continue
                try:
                    token = _scan_region(mem, *region)
                except (OSError, ValueError, OverflowError):
                    # Unreadable mapping: move on to the next one.
                    continue
                if token:
                    return token
    except OSError:
        # Covers PermissionError, ProcessLookupError and FileNotFoundError.
        pass
    return None


def _use_configured_token(token, log_message):
    """Return *token* if it is usable, caching it when it is new.

    A token is usable when it is longer than ``MIN_TOKEN_LEN`` and not
    blacklisted. *log_message* is emitted only when the cache changes.
    """
    if not token or len(token) <= MIN_TOKEN_LEN or cache.is_blacklisted(token):
        return None
    if cache.get() != token:
        cache.set(token)
        logger.info(log_message)
    return token


def _read_persisted_token():
    """Return the value of the first ``HUAWEI_TOKEN=`` line in the env file.

    Surrounding whitespace and quotes are stripped. Returns ``None`` when the
    file is missing, unreadable, or has no such line.
    """
    try:
        with open(TOKEN_ENV_FILE, 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('HUAWEI_TOKEN='):
                    return line.split('=', 1)[1].strip().strip('"').strip("'")
    except (OSError, UnicodeDecodeError):
        pass
    return None


def _ordered_scan_targets(pids):
    """Order *pids* so that likely token holders are scanned first.

    Processes whose executable path mentions a common runtime come first.
    The gateway's own PID is excluded: its memory only holds tokens it has
    already seen, including stale ones.
    """
    own_pid = str(os.getpid())
    priority_pids = []
    other_pids = []
    for pid in pids:
        if pid == own_pid:
            continue
        try:
            exe_path = os.readlink(f'/proc/{pid}/exe')
        except OSError:
            other_pids.append(pid)
            continue
        if any(x in exe_path for x in ['python', 'node', 'java', 'chrome', 'electron']):
            priority_pids.append(pid)
        else:
            other_pids.append(pid)
    return priority_pids + other_pids


def _scan_processes(pids):
    """Scan *pids* concurrently; return ``(token, pid)`` for the first hit.

    Once a usable token is found, scans that have not started yet are
    cancelled. Returns ``None`` when no process yields a usable token.
    """
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(scan_pid_mem, pid): pid for pid in pids}
        for future in as_completed(futures):
            try:
                token = future.result()
            except Exception:
                # One unreadable process must not abort the whole scan.
                logger.debug("scan of PID %s failed", futures[future], exc_info=True)
                continue
            if token and not cache.is_blacklisted(token):
                for pending in futures:
                    pending.cancel()
                return token, futures[future]
    return None


def find_token_in_memory():
    """Return a usable token, or ``None`` if none can be obtained.

    Sources are tried in this order:

    1. the ``HUAWEI_TOKEN`` environment variable,
    2. the persisted env file,
    3. the cache, while the scan cooldown is active,
    4. a scan of the memory of all processes under /proc.
    """
    token = _use_configured_token(
        os.environ.get('HUAWEI_TOKEN', '').strip(),
        "Token 从 HUAWEI_TOKEN 环境变量加载",
    )
    if token:
        return token

    token = _use_configured_token(
        _read_persisted_token(),
        "Token 从持久化文件加载",
    )
    if token:
        return token

    # During the cooldown, answer from the cache even when it is empty:
    # a scan that just found nothing must not be repeated on every request.
    cached = cache.get()
    if cache.is_scan_cooldown():
        return cached

    try:
        pids = [pid for pid in os.listdir('/proc') if pid.isdigit()]
    except OSError:
        logger.error("无法访问 /proc 目录")
        return None

    found = _scan_processes(_ordered_scan_targets(pids))
    cache.mark_scanned()

    if found:
        token, pid = found
        cache.set(token)
        logger.info("Token 已刷新 (来源 PID: %s)", pid)
        return token

    return cache.get()


# ================= HTTP session pool =================
http_session = requests.Session()
# The session is shared by every client of the gateway. Refuse to store
# upstream cookies so one caller's Set-Cookie is never replayed for another;
# cookies sent by the caller are still forwarded per request.
http_session.cookies.set_policy(DefaultCookiePolicy(allowed_domains=[]))
adapter = requests.adapters.HTTPAdapter(
    pool_connections=20,
    pool_maxsize=20,
    max_retries=3
)
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)

# ================= Flask application =================
app = Flask(__name__)
@app.route('/health')
def health():
    """Health-check endpoint reporting cache and blacklist state."""
    token = cache.get()
    return {
        "status": "healthy",
        "token_cached": token is not None,
        "token_expires_in": max(0, cache._expires_at - time.time()),
        "blacklisted": len(cache._blacklist)
    }, 200


def _persist_token(token):
    """Write *token* to the env file, readable by the owner only.

    Raises ``OSError`` if the file cannot be written.
    """
    fd = os.open('/etc/huawei-gateway.env',
                 os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        # The mode passed to os.open only applies on creation; tighten an
        # already existing file as well, since it stores a bearer credential.
        os.fchmod(fd, 0o600)
        f.write(f'HUAWEI_TOKEN={token}\n')


@app.route('/set_token', methods=['POST'])
def set_token():
    """Manually inject a token and persist it for the next restart.

    Expects a JSON object ``{"token": "..."}``. Responds 400 when the body is
    not a JSON object, or the token is not a string, contains whitespace, or
    is not longer than 200 characters (the length the file and environment
    loaders require; shorter tokens would be dropped on the next lookup).
    """
    payload = request.get_json(silent=True) if request.is_json else None
    if not isinstance(payload, dict):
        payload = {}
    token = payload.get('token', '')
    if isinstance(token, str):
        token = token.strip()
    # Embedded whitespace would corrupt both the env file and the
    # Authorization header, so reject it up front.
    if (not isinstance(token, str) or len(token) <= 200
            or any(ch.isspace() for ch in token)):
        return {"error": "请提供有效的 token"}, 400

    cache.set(token)
    try:
        _persist_token(token)
    except OSError as exc:
        persisted = False
        logger.warning("Token 已手动注入,但持久化失败: %s", exc)
    else:
        persisted = True
        logger.info("Token 已手动注入并持久化")
    return {
        "status": "ok",
        "token_fingerprint": cache._fingerprint(token),
        "persisted": persisted,
    }, 200


_PROXY_METHODS = ['POST', 'GET', 'OPTIONS', 'PUT', 'DELETE']
# Request headers that must not be copied to the upstream request; the
# Authorization header is always replaced with the gateway's own token.
_SKIP_REQUEST_HEADERS = frozenset((
    'host', 'content-length', 'connection', 'accept-encoding',
    'transfer-encoding', 'authorization',
))
# Hop-by-hop and encoding headers that must not be copied to the client,
# because the body is re-framed (and already decoded) by this gateway.
_SKIP_RESPONSE_HEADERS = frozenset((
    'transfer-encoding', 'content-encoding', 'content-length',
    'connection', 'keep-alive', 'upgrade',
))


def _send_upstream(url, headers, body):
    """Forward the current Flask request to *url* with streaming enabled."""
    return http_session.request(
        method=request.method,
        url=url,
        headers=headers,
        data=body,
        cookies=request.cookies,
        allow_redirects=False,
        timeout=60,
        stream=True
    )


@app.route('/v2/', defaults={'subpath': ''}, methods=_PROXY_METHODS)
@app.route('/v2/<path:subpath>', methods=_PROXY_METHODS)
def proxy(subpath):
    """Forward ``/v2/<subpath>`` to the upstream host with a valid token.

    OPTIONS requests are answered locally with permissive CORS headers. On an
    upstream 401 the token is blacklisted and the request is retried once if
    a different token can be obtained. Event-stream and chunked responses are
    streamed; everything else is buffered.
    """
    if request.method == 'OPTIONS':
        return Response(status=200, headers={
            'Access-Control-Allow-Origin': '*',
            'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
            'Access-Control-Allow-Headers': 'Content-Type, Authorization'
        })

    real_token = find_token_in_memory()
    if not real_token:
        logger.error("未在内存中找到华为云 Token")
        return {"error": "未在内存中找到华为云Token,请确保华为云相关应用正在运行"}, 500

    headers = {
        k: v for k, v in request.headers
        if k.lower() not in _SKIP_REQUEST_HEADERS
    }
    headers['Authorization'] = f'Bearer {real_token}'
    headers['Host'] = TARGET_HOST

    target_url = f'https://{TARGET_HOST}/v2/{subpath}'
    if request.query_string:
        # Forward the raw query string; without it GET parameters are lost.
        target_url += '?' + request.query_string.decode('latin-1')
    body = request.get_data()

    try:
        resp = _send_upstream(target_url, headers, body)

        # 401 fallback: the token may have expired early.
        if resp.status_code == 401:
            logger.warning("收到 401,将当前 Token 加入黑名单并强制刷新...")
            # Only blacklist when the cache still holds the token that failed;
            # another thread may already have replaced it with a good one.
            if cache.get() == real_token:
                cache.blacklist_current()
            new_token = find_token_in_memory()
            if new_token and new_token != real_token:
                # Close the first response only when it is actually replaced,
                # so a non-retried 401 still carries the upstream body.
                resp.close()
                headers['Authorization'] = f'Bearer {new_token}'
                resp = _send_upstream(target_url, headers, body)

        response_headers = [
            (k, v) for k, v in resp.headers.items()
            if k.lower() not in _SKIP_RESPONSE_HEADERS
        ]

        content_type = resp.headers.get('Content-Type', '')
        is_streaming = (
            'text/event-stream' in content_type
            or resp.headers.get('Transfer-Encoding', '') == 'chunked'
        )
        if is_streaming:
            def sse_stream():
                try:
                    for chunk in resp.iter_content(chunk_size=4096):
                        if chunk:
                            yield chunk
                finally:
                    resp.close()

            return Response(
                sse_stream(),
                status=resp.status_code,
                headers=response_headers,
                direct_passthrough=True
            )

        # Non-streaming response: read the complete body, then release the
        # pooled connection even if reading fails.
        try:
            content = resp.content
        finally:
            resp.close()
        return Response(
            content,
            status=resp.status_code,
            headers=response_headers
        )

    except requests.exceptions.Timeout:
        logger.error("请求华为云 API 超时")
        return {"error": "网关超时,请稍后重试"}, 504
    except requests.exceptions.ConnectionError:
        logger.error("无法连接到华为云 API")
        return {"error": "无法连接到华为云服务"}, 502
    except Exception as e:
        logger.error("网关转发失败: %s", traceback.format_exc())
        return {"error": f"网关转发失败: {str(e)}"}, 500


def _restore_persisted_token():
    """Seed the cache from the env file unless ``HUAWEI_TOKEN`` is set.

    The token goes into the cache rather than ``os.environ``: an environment
    value has the highest lookup priority, so copying the file token there
    would make it outrank tokens injected later through ``/set_token``.
    """
    if 'HUAWEI_TOKEN' in os.environ:
        return
    try:
        with open('/etc/huawei-gateway.env', 'r') as f:
            for line in f:
                line = line.strip()
                if line.startswith('HUAWEI_TOKEN='):
                    val = line.split('=', 1)[1].strip().strip('"').strip("'")
                    if val and len(val) > 200:
                        cache.set(val)
                        logger.info("从持久化文件恢复 Token")
                    break
    except OSError:
        pass


def main():
    """Start the gateway.

    Usage: ``huawei_gateway.py [port] [host]`` (defaults: 8080, 127.0.0.1).
    Prefers Waitress, then Gunicorn, then the Flask development server.
    """
    port = int(sys.argv[1]) if len(sys.argv) > 1 else 8080
    host = sys.argv[2] if len(sys.argv) > 2 else '127.0.0.1'

    _restore_persisted_token()

    # Keep the try narrow: an ImportError raised while serving must not be
    # mistaken for "waitress is not installed".
    try:
        import waitress
    except ImportError:
        waitress = None

    if waitress is not None:
        logger.info("使用 Waitress 启动网关 (%s:%s)", host, port)
        waitress.serve(app, host=host, port=port, threads=32)
        return

    try:
        import gunicorn.app.wsgiapp  # noqa: F401 -- availability probe only
        logger.info("使用 Gunicorn 启动网关 (%s:%s)", host, port)
        script = os.path.abspath(__file__)
        module = os.path.splitext(os.path.basename(script))[0]
        # --chdir lets Gunicorn import this module regardless of the
        # directory the gateway was launched from.
        os.execlp(
            'gunicorn', 'gunicorn', '-w', '4', '-b', f'{host}:{port}',
            '--access-logfile', '-',
            '--chdir', os.path.dirname(script),
            f'{module}:app',
        )
    except (ImportError, OSError):
        logger.warning("未安装 Waitress/Gunicorn,使用 Flask 开发服务器(建议生产环境安装 waitress)")
        logger.info("启动网关 (%s:%s)", host, port)
        app.run(host=host, port=port, debug=False, threaded=True)


if __name__ == '__main__':
    main()
$PIP_CMD install --upgrade pip --quiet -i https://pypi.tuna.tsinghua.edu.cn/simple 2>/dev/null -$PIP_CMD install flask requests --quiet -i https://pypi.tuna.tsinghua.edu.cn/simple +$PIP_CMD install flask requests waitress --quiet -i https://pypi.tuna.tsinghua.edu.cn/simple # ================= 创建网关核心代码 ================= mkdir -p /usr/local/bin /var/log /var/run @@ -419,8 +419,12 @@ cat << 'PYEOF' > /usr/local/bin/huawei_gateway.py 华为云 Token 动态网关 - 6小时缓存机制 - 支持内存扫描自动刷新 +- HUAWEI_TOKEN 环境变量最高优先级 +- Token 持久化到 /etc/huawei-gateway.env +- SSE 流式转发 +- 连接池 (20 连接) - 401 自动重试 -- 兼容生产环境 (Waitress/Gunicorn) +- 兼容生产环境 (Waitress 32线程 / Gunicorn) """ import os import re @@ -557,7 +561,7 @@ def scan_pid_mem(pid): for match in TOKEN_PATTERN.finditer(data): token = match.group(1).decode('ascii', errors='replace') # 验证 token 格式(华为云 token 通常是 JWT 格式) - if len(token) > 200 and token.count('.') >= 2: + if len(token) > 200: return token remaining -= len(data) @@ -571,6 +575,34 @@ def scan_pid_mem(pid): def find_token_in_memory(): """在所有进程中扫描 Token""" + # HUAWEI_TOKEN 环境变量优先级最高 + env_token = os.environ.get('HUAWEI_TOKEN', '').strip() + if env_token and len(env_token) > 200 and not cache.is_blacklisted(env_token): + cached = cache.get() + if cached != env_token: + cache.set(env_token) + logger.info("Token 从 HUAWEI_TOKEN 环境变量加载") + return env_token + + # 从持久化文件加载 + env_file = '/etc/huawei-gateway.env' + if os.path.isfile(env_file): + try: + with open(env_file, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('HUAWEI_TOKEN='): + file_token = line.split('=', 1)[1].strip().strip('"').strip("'") + if file_token and len(file_token) > 200 and not cache.is_blacklisted(file_token): + cached = cache.get() + if cached != file_token: + cache.set(file_token) + logger.info("Token 从持久化文件加载") + return file_token + break + except (OSError, IOError): + pass + if cache.is_scan_cooldown() and cache.get(): return cache.get() @@ -616,6 +648,16 @@ def find_token_in_memory(): return 
cache.get() # 返回可能过期的缓存作为兜底 +# ================= HTTP 会话池 ================= +http_session = requests.Session() +adapter = requests.adapters.HTTPAdapter( + pool_connections=20, + pool_maxsize=20, + max_retries=3 +) +http_session.mount('https://', adapter) +http_session.mount('http://', adapter) + # ================= Flask 应用 ================= app = Flask(__name__) @@ -639,7 +681,13 @@ def set_token(): if not token or len(token) < 100: return {"error": "请提供有效的 token"}, 400 cache.set(token) - logger.info("Token 已手动注入") + # 持久化到文件以便重启后恢复 + try: + with open('/etc/huawei-gateway.env', 'w') as f: + f.write(f'HUAWEI_TOKEN={token}\n') + except (OSError, IOError): + pass + logger.info("Token 已手动注入并持久化") return {"status": "ok", "token_fingerprint": cache._fingerprint(token)}, 200 @@ -661,7 +709,7 @@ def proxy(subpath): headers = {} for k, v in request.headers: kl = k.lower() - if kl not in ('host', 'content-length', 'connection', 'accept-encoding'): + if kl not in ('host', 'content-length', 'connection', 'accept-encoding', 'transfer-encoding'): headers[k] = v headers['Authorization'] = f'Bearer {real_token}' @@ -670,7 +718,8 @@ def proxy(subpath): target_url = f'https://{TARGET_HOST}/v2/{subpath}' try: - resp = requests.request( + # 使用 stream=True 支持 SSE 流式转发 + resp = http_session.request( method=request.method, url=target_url, headers=headers, @@ -678,37 +727,62 @@ def proxy(subpath): cookies=request.cookies, allow_redirects=False, timeout=60, - stream=False + stream=True ) # 401 兜底:Token 可能提前过期,加入黑名单后强制刷新重试 if resp.status_code == 401: + resp.close() logger.warning("收到 401,将当前 Token 加入黑名单并强制刷新...") cache.blacklist_current() new_token = find_token_in_memory() if new_token and new_token != real_token: headers['Authorization'] = f'Bearer {new_token}' - resp = requests.request( + resp = http_session.request( method=request.method, url=target_url, headers=headers, data=request.get_data(), cookies=request.cookies, allow_redirects=False, - timeout=60 + timeout=60, + stream=True ) - 
# 构建响应 + # 过滤 hop-by-hop 头和压缩编码头 + skip_headers = {'transfer-encoding', 'content-encoding', 'content-length', + 'connection', 'keep-alive', 'upgrade'} response_headers = [] for k, v in resp.headers.items(): - if k.lower() not in ('transfer-encoding', 'content-encoding', 'content-length'): + if k.lower() not in skip_headers: response_headers.append((k, v)) - return Response( - resp.content, - status=resp.status_code, - headers=response_headers - ) + # SSE 流式转发 + content_type = resp.headers.get('Content-Type', '') + if 'text/event-stream' in content_type or resp.headers.get('Transfer-Encoding', '') == 'chunked': + def sse_stream(): + try: + for chunk in resp.iter_content(chunk_size=4096): + if chunk: + yield chunk + finally: + resp.close() + + return Response( + sse_stream(), + status=resp.status_code, + headers=response_headers, + direct_passthrough=True + ) + else: + # 非流式响应:读取完整内容 + content = resp.content + resp.close() + return Response( + content, + status=resp.status_code, + headers=response_headers + ) except requests.exceptions.Timeout: logger.error("请求华为云 API 超时") @@ -725,11 +799,27 @@ def main(): port = int(sys.argv[1]) if len(sys.argv) > 1 else 8080 host = sys.argv[2] if len(sys.argv) > 2 else '127.0.0.1' + # 从持久化文件加载 token + env_file = '/etc/huawei-gateway.env' + if os.path.isfile(env_file): + try: + with open(env_file, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('HUAWEI_TOKEN=') and 'HUAWEI_TOKEN' not in os.environ: + val = line.split('=', 1)[1].strip().strip('"').strip("'") + if val and len(val) > 200: + os.environ['HUAWEI_TOKEN'] = val + logger.info("从持久化文件恢复 Token") + break + except (OSError, IOError): + pass + # 尝试使用生产级 WSGI 服务器 try: import waitress logger.info(f"使用 Waitress 启动网关 ({host}:{port})") - waitress.serve(app, host=host, port=port, threads=16) + waitress.serve(app, host=host, port=port, threads=32) except ImportError: try: import gunicorn.app.wsgiapp