How many places does the following code use pysftp?
Created on: April 13, 2025
Take a look at the following code: how many places does it use pysftp?
import pysftp
import os
import stat
from datetime import datetime
import paramiko
import tempfile
import time
import zipfile
import io
from pathlib import Path
import logging
import asyncio
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
cnopts = pysftp.CnOpts()
cnopts.hostkeys = None # 在生产环境中应该正确处理主机密钥
process_pool = ThreadPoolExecutor(max_workers=8) # 改为线程池,解决序列化问题
thread_pool = ThreadPoolExecutor(max_workers=16) # 添加一个专用的线程池
logger = logging.getLogger(__name__)
def _sftp_getfo(sftp_handler, remote_path, file_obj, callback=None):
    """可序列化的辅助函数,用于替代lambda"""
    return sftp_handler.sftp.getfo(remote_path, file_obj, callback=callback)

def _sftp_stat(sftp_handler, path):
    """可序列化的辅助函数,用于替代lambda"""
    return sftp_handler.sftp.stat(path)

def _sftp_walk(sftp_handler, remote_path):
    """可序列化的辅助函数,用于替代lambda"""
    return list(sftp_handler.walk(remote_path))
def create_ssh_client(hostname, port, username, password=None, private_key=None, passphrase=None, timeout=30):
    """
    创建SSH客户端并处理连接,支持密码和带有密码短语的私钥认证

    Args:
        hostname: 主机名或IP地址
        port: SSH端口
        username: 用户名
        password: 密码(可选)
        private_key: 私钥文件路径(可选)
        passphrase: 私钥的密码短语(可选)
        timeout: 连接超时时间(秒)

    Returns:
        paramiko.SSHClient: 已连接的SSH客户端

    Raises:
        ValueError: 如果认证方式无效
        ConnectionError: 如果连接失败
    """
    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    connect_kwargs = {
        'hostname': hostname,
        'username': username,
        'port': port,
        'timeout': timeout,
        'banner_timeout': timeout,
        'auth_timeout': timeout
    }

    # 处理私钥和密码短语
    if private_key:
        if isinstance(private_key, str) and os.path.exists(private_key):
            if passphrase:
                # 如果有密码短语,尝试不同的私钥类型
                try:
                    # 首先尝试RSA
                    try:
                        key = paramiko.RSAKey.from_private_key_file(
                            private_key, password=passphrase
                        )
                    except:
                        # 如果RSA失败,尝试DSS
                        try:
                            key = paramiko.DSSKey.from_private_key_file(
                                private_key, password=passphrase
                            )
                        except:
                            # 如果DSS失败,尝试ECDSA
                            try:
                                key = paramiko.ECDSAKey.from_private_key_file(
                                    private_key, password=passphrase
                                )
                            except:
                                # 如果ECDSA失败,尝试Ed25519
                                key = paramiko.Ed25519Key.from_private_key_file(
                                    private_key, password=passphrase
                                )
                    connect_kwargs['pkey'] = key
                except Exception as e:
                    raise ValueError(f"无法加载私钥: {str(e)}")
            else:
                # 没有密码短语,直接使用key_filename
                connect_kwargs['key_filename'] = private_key
        else:
            raise ValueError("无效的私钥路径")
    elif password:
        connect_kwargs['password'] = password
    else:
        raise ValueError("未提供密码或私钥")

    try:
        ssh.connect(**connect_kwargs)
        return ssh
    except Exception as e:
        raise ConnectionError(f"SSH连接失败: {str(e)}")
class SFTPHandler:
    _connection_pool = {} # {connection_key: (handler, last_activity_time)}
    _idle_timeout = 1200 # 空闲连接超时时间(秒)
    _pool_lock = asyncio.Lock() # 添加异步锁来保护连接池
    _connection_locks = {} # 添加连接锁字典来保护每个连接的访问 {connection_key: asyncio.Lock()}
text@classmethod async def get_connection(cls, host, username, port=22, password=None, private_key=None, passphrase=None): """获取或创建SFTP连接""" connection_key = f"{host}:{port}:{username}" current_time = time.time() lock_timeout = 8 # 获取锁的超时时间(秒) # 获取或创建此连接的锁 if connection_key not in cls._connection_locks: cls._connection_locks[connection_key] = asyncio.Lock() try: # 使用超时机制尝试获取连接锁,避免死锁 try: # 使用asyncio.wait_for添加超时控制获取锁 await asyncio.wait_for( cls._connection_locks[connection_key].acquire(), timeout=lock_timeout ) except asyncio.TimeoutError: logger.warning(f"Timeout acquiring connection lock for {connection_key}. Creating new connection.") # 如果获取锁超时,创建一个新的独立连接而不是从池中获取 handler = cls(host, username, port, password, private_key, passphrase) await handler.async_connect() return handler try: async with cls._pool_lock: # 使用异步锁保护连接池访问 # 检查是否存在有效连接 if connection_key in cls._connection_pool: handler, last_activity = cls._connection_pool[connection_key] # 检查连接是否超时 if current_time - last_activity > cls._idle_timeout: # 超时,关闭并删除旧连接 try: handler.close() except Exception as e: logger.warning(f"Error closing timed out connection: {str(e)}") del cls._connection_pool[connection_key] else: # 检查连接是否有效 try: if handler.is_connection_valid(): # 更新活动时间 cls._connection_pool[connection_key] = (handler, current_time) logger.debug(f"Reusing existing SFTP connection for {connection_key}") return handler except Exception as e: logger.warning(f"Connection for {connection_key} is invalid: {str(e)}") try: handler.close() except: pass del cls._connection_pool[connection_key] # 创建新连接 try: # 使用超时控制创建连接 handler = cls(host, username, port, password, private_key, passphrase) await asyncio.wait_for(handler.async_connect(), timeout=30) # 30秒超时 # 存储到连接池 async with cls._pool_lock: cls._connection_pool[connection_key] = (handler, current_time) logger.debug(f"Created new SFTP connection for {connection_key}") return handler except asyncio.TimeoutError: logger.error(f"Connection timeout for {connection_key}") raise ConnectionError(f"Connection timeout for {connection_key}") except TypeError as e: # 特别处理TypeError异常,这通常是由于同时请求导致的 error_msg = f"TypeErr creating connection: {str(e)}" logger.error(error_msg) # 重试一次 try: await asyncio.sleep(0.5) # 短暂等待 handler = cls(host, username, port, password, private_key, passphrase) await asyncio.wait_for(handler.async_connect(), timeout=30) async with cls._pool_lock: cls._connection_pool[connection_key] = (handler, current_time) return handler except Exception as retry_error: logger.error(f"Retry failed: {str(retry_error)}") # 如果重试也失败,创建一个不加入连接池的独立连接 logger.warning(f"Creating standalone connection for {connection_key} after retry failure") handler = cls(host, username, port, password, private_key, passphrase) await handler.async_connect() return handler except Exception as e: logger.error(f"Failed to create connection for {connection_key}: {str(e)}") raise ConnectionError(f"Failed to create connection for {connection_key}: {str(e)}") finally: # 释放连接锁 try: cls._connection_locks[connection_key].release() except Exception as e: logger.warning(f"Error releasing connection lock: {str(e)}") except Exception as e: logger.error(f"Unexpected error in get_connection: {str(e)}") # 发生任何未捕获的异常,创建一个新的独立连接作为后备 try: logger.warning(f"Creating fallback connection for {connection_key}") handler = cls(host, username, port, password, private_key, passphrase) await handler.async_connect() return handler except Exception as fallback_error: logger.error(f"Failed to create fallback connection: {str(fallback_error)}") raise 
ConnectionError(f"Could not establish any connection to {host}: {str(e)} -> {str(fallback_error)}") async def async_connect(self): """异步建立SFTP连接""" # 如果已经连接,直接返回 if self.sftp and self.connection: try: if self.is_connection_valid(): self.update_activity() return True except: pass last_error = None for attempt in range(self.max_retries): try: # 在事件循环的线程池中执行阻塞的SSH连接操作 loop = asyncio.get_event_loop() # 创建SSH客户端 ssh_args = ( self.host, self.port, self.username, self.password, self.private_key, self.passphrase, self.connect_timeout ) # 使用线程池,但不需要替换为辅助函数,因为create_ssh_client是顶级函数 ssh = await loop.run_in_executor( thread_pool, lambda: create_ssh_client(*ssh_args) ) # 创建SFTP客户端,使用线程池 self.sftp = await loop.run_in_executor(thread_pool, ssh.open_sftp) self.connection = ssh self.update_activity() # 设置初始目录 await self._setup_initial_directory() return True except Exception as e: last_error = str(e) if attempt < self.max_retries - 1: await asyncio.sleep(2 ** attempt) # 使用异步睡眠 continue else: raise ConnectionError(f"Failed to connect after {self.max_retries} attempts: {last_error}") async def _setup_initial_directory(self): """异步设置初始目录""" try: # 尝试获取用户的home目录 possible_home_dirs = [ f"/home/{self.username}", f"/usr/home/{self.username}", f"/Users/{self.username}", f"/var/home/{self.username}", ] # 首先尝试通过pwd命令获取 loop = asyncio.get_event_loop() # 使用线程池执行命令 exec_cmd = await loop.run_in_executor( thread_pool, self.connection.exec_command, 'pwd' ) stdin, stdout, stderr = exec_cmd # 使用线程池读取输出 pwd_bytes = await loop.run_in_executor(thread_pool, stdout.read) pwd_dir = pwd_bytes.decode().strip() if pwd_dir: possible_home_dirs.insert(0, pwd_dir) # 尝试每个可能的路径 for home_dir in possible_home_dirs: try: # 使用线程池切换目录 await loop.run_in_executor( thread_pool, self.sftp.chdir, home_dir ) # 验证是否真的可以访问该目录 await loop.run_in_executor( thread_pool, self.sftp.listdir, '.' ) self.current_directory = home_dir logger.info(f"Successfully changed to home directory: {home_dir}") return except (IOError, OSError): continue # 如果所有home目录都失败,尝试其他基本目录 for basic_dir in ['.', '/']: try: # 使用线程池切换目录 await loop.run_in_executor( thread_pool, self.sftp.chdir, basic_dir ) # 使用线程池获取当前目录 current_dir = await loop.run_in_executor( thread_pool, self.sftp.getcwd ) self.current_directory = current_dir or basic_dir logger.info(f"Changed to directory: {self.current_directory}") return except (IOError, OSError): continue # 如果所有尝试都失败 self.current_directory = '.' logger.warning("Could not determine a valid working directory") except Exception as e: logger.warning(f"Error while setting initial directory: {str(e)}") self.current_directory = '.' 
def is_connection_valid(self): """检查连接是否有效""" if not self.sftp or not self.connection: return False try: # 尝试执行一个简单的命令来检测连接是否有效 self.connection.exec_command('echo 1', timeout=5) # 或者尝试列出当前目录 self.sftp.listdir('.') return True except Exception as e: logger.warning(f"Connection validation failed: {str(e)}") return False def update_activity(self): """更新最后活动时间""" self.last_activity = time.time() # 同时更新连接池中的记录 connection_key = f"{self.host}:{self.port}:{self.username}" if connection_key in self._connection_pool: handler, _ = self._connection_pool[connection_key] if handler is self: # 确保是同一个实例 self._connection_pool[connection_key] = (handler, self.last_activity) def connect(self): """建立SFTP连接,带有重试机制""" # 如果已经连接,直接返回 if self.sftp and self.connection: try: # 验证连接是否有效 if self.is_connection_valid(): self.update_activity() return True except: # 连接可能已失效,继续尝试重新连接 pass # 连接不存在或已失效,重新连接 last_error = None for attempt in range(self.max_retries): try: # 使用辅助函数创建SSH客户端 ssh = create_ssh_client( hostname=self.host, port=self.port, username=self.username, password=self.password, private_key=self.private_key, passphrase=self.passphrase, timeout=self.connect_timeout ) # 创建 SFTP 客户端 self.sftp = ssh.open_sftp() self.connection = ssh self.update_activity() # 更新活动时间 # 尝试获取用户的 home 目录 try: # 尝试可能的 home 目录路径 possible_home_dirs = [ f"/home/{self.username}", # 标准路径 f"/usr/home/{self.username}", # BSD 风格路径 f"/Users/{self.username}", # macOS 路径 f"/var/home/{self.username}", # 某些系统的路径 ] # 首先尝试通过 pwd 命令获取 stdin, stdout, stderr = ssh.exec_command('pwd') pwd_dir = stdout.read().decode().strip() if pwd_dir: possible_home_dirs.insert(0, pwd_dir) # 将 pwd 返回的路径放在最前面 # 尝试每个可能的路径 for home_dir in possible_home_dirs: try: self.sftp.chdir(home_dir) # 验证是否真的可以访问该目录 self.sftp.listdir('.') self.current_directory = home_dir print(f"Successfully changed to home directory: {home_dir}") return True except (IOError, OSError): continue # 如果所有 home 目录都失败,尝试其他基本目录 for basic_dir in ['.', '/']: try: self.sftp.chdir(basic_dir) self.current_directory = self.sftp.getcwd() or basic_dir print(f"Changed to directory: {self.current_directory}") return True except (IOError, OSError): continue # 如果所有尝试都失败 self.current_directory = '.' print("Warning: Could not determine a valid working directory") return True except Exception as e: print(f"Warning: Error while setting initial directory: {str(e)}") self.current_directory = '.' return True except Exception as e: last_error = str(e) if attempt < self.max_retries - 1: # 如果不是最后一次尝试,等待后重试 time.sleep(2 ** attempt) # 指数退避 continue else: # 最后一次尝试也失败了 raise ConnectionError(f"Failed to connect after {self.max_retries} attempts: {last_error}") def list_directory(self, path='.'): """列出目录内容""" self.update_activity() # 更新活动时间 try: # 如果path为空或为'.',使用当前目录 if not path or path == '.': path = self.current_directory or '.' 
# 如果提供的是相对路径,并且我们知道当前目录,则构建完整路径 if not path.startswith('/') and self.current_directory and self.current_directory != '.': full_path = os.path.join(self.current_directory, path).replace('\\', '/') else: full_path = path try: items = [] for entry in self.sftp.listdir_attr(full_path): try: # 获取权限字符串 mode = entry.st_mode perms = '' # 检查文件类型 is_symlink = stat.S_ISLNK(mode) if is_symlink: perms += 'l' # 符号链接 try: # 尝试获取链接目标的信息 target_path = os.path.join(full_path, entry.filename).replace('\\', '/') target_stat = self.sftp.stat(target_path) is_target_dir = stat.S_ISDIR(target_stat.st_mode) is_target_readable = bool(target_stat.st_mode & stat.S_IRUSR) except Exception as e: # 如果无法访问目标,记录错误但继续处理 is_target_dir = False is_target_readable = False logger.warning(f"Cannot access symlink target for {entry.filename}: {str(e)}") elif stat.S_ISDIR(mode): perms += 'd' # 目录 else: perms += '-' # 普通文件 perms += 'r' if mode & stat.S_IRUSR else '-' perms += 'w' if mode & stat.S_IWUSR else '-' perms += 'x' if mode & stat.S_IXUSR else '-' perms += 'r' if mode & stat.S_IRGRP else '-' perms += 'w' if mode & stat.S_IWGRP else '-' perms += 'x' if mode & stat.S_IXGRP else '-' perms += 'r' if mode & stat.S_IROTH else '-' perms += 'w' if mode & stat.S_IWOTH else '-' perms += 'x' if mode & stat.S_IXOTH else '-' # 确定文件类型 file_type = 'directory' if stat.S_ISDIR(mode) else 'file' if is_symlink: file_type = 'symlink' item = { 'name': entry.filename, 'path': os.path.join(full_path, entry.filename).replace('\\', '/'), 'size': entry.st_size, 'type': file_type, 'permissions': perms, # 使用新的权限格式 'permissions_octal': oct(mode)[-3:], # 保留八进制格式 'modified': datetime.fromtimestamp(entry.st_mtime).strftime('%Y-%m-%d %H:%M:%S'), 'is_broken_link': is_symlink and not is_target_readable, # 添加标记表示是否是损坏的链接 'is_link_to_dir': is_symlink and is_target_dir # 添加标记表示链接是否指向目录 } items.append(item) except Exception as e: logger.warning(f"Error processing entry {entry.filename}: {str(e)}") continue return items except IOError as e: if "Permission denied" in str(e): # 如果访问被拒绝,尝试使用当前目录 if path != self.current_directory: logger.warning(f"Permission denied for {full_path}, trying current directory") return self.list_directory(self.current_directory) else: raise PermissionError(str(e)) else: raise except Exception as e: # 转换TypeError为更具体的错误 if isinstance(e, TypeError): raise TypeError(f"Load failed: {str(e)}") raise Exception(f"Failed to list directory: {str(e)}") def download_file(self, remote_path, local_path, callback=None): """下载文件,支持进度回调""" self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not remote_path.startswith('/') and self.current_directory and self.current_directory != '.': remote_path = os.path.join(self.current_directory, remote_path).replace('\\', '/') self.sftp.get(remote_path, local_path) if callback: callback(1, 1) # 简单的进度回调 return True except Exception as e: raise Exception(f"Failed to download file: {str(e)}") def upload_file(self, local_path, remote_path, callback=None): """上传文件,支持进度回调""" self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not remote_path.startswith('/') and self.current_directory and self.current_directory != '.': remote_path = os.path.join(self.current_directory, remote_path).replace('\\', '/') self.sftp.put(local_path, remote_path) if callback: callback(1, 1) # 简单的进度回调 return True except Exception as e: raise Exception(f"Failed to upload file: {str(e)}") def create_directory(self, path): """创建目录""" self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not path.startswith('/') and self.current_directory 
and self.current_directory != '.': path = os.path.join(self.current_directory, path).replace('\\', '/') self.sftp.mkdir(path) return True except Exception as e: raise Exception(f"Failed to create directory: {str(e)}") def remove_file(self, path): """删除文件""" self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not path.startswith('/') and self.current_directory and self.current_directory != '.': path = os.path.join(self.current_directory, path).replace('\\', '/') self.sftp.remove(path) return True except Exception as e: raise Exception(f"Failed to remove file: {str(e)}") def remove_directory(self, path): """递归删除目录及其内容 Args: path: 目录路径 """ self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not path.startswith('/') and self.current_directory and self.current_directory != '.': path = os.path.join(self.current_directory, path).replace('\\', '/') # 递归删除目录内容 for entry in self.sftp.listdir_attr(path): entry_path = os.path.join(path, entry.filename).replace('\\', '/') if stat.S_ISDIR(entry.st_mode): # 如果是目录,递归删除 self.remove_directory(entry_path) else: # 如果是文件,直接删除 self.sftp.remove(entry_path) # 删除空目录 self.sftp.rmdir(path) return True except Exception as e: raise Exception(f"Failed to remove directory: {str(e)}") def rename(self, old_path, new_path): """重命名文件或目录""" self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not old_path.startswith('/') and self.current_directory and self.current_directory != '.': old_path = os.path.join(self.current_directory, old_path).replace('\\', '/') if not new_path.startswith('/') and self.current_directory and self.current_directory != '.': new_path = os.path.join(self.current_directory, new_path).replace('\\', '/') self.sftp.rename(old_path, new_path) return True except Exception as e: raise Exception(f"Failed to rename: {str(e)}") def get_file_info(self, path): """获取文件信息""" self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not path.startswith('/') and self.current_directory and self.current_directory != '.': path = os.path.join(self.current_directory, path).replace('\\', '/') stat = self.sftp.stat(path) return { 'size': stat.st_size, 'type': 'directory' if stat.S_ISDIR(stat.st_mode) else 'file', 'permissions': oct(stat.st_mode)[-3:], 'modified': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S') } except Exception as e: raise Exception(f"Failed to get file info: {str(e)}") def chmod(self, path, mode): """修改文件或目录的权限 Args: path: 文件或目录的路径 mode: 权限模式,可以是八进制数字(如0o755)或字符串('755') """ self.update_activity() # 更新活动时间 try: # 如果是相对路径,使用当前目录 if not path.startswith('/') and self.current_directory and self.current_directory != '.': path = os.path.join(self.current_directory, path).replace('\\', '/') # 如果mode是字符串,转换为八进制数字 if isinstance(mode, str): mode = int(mode, 8) self.sftp.chmod(path, mode) return True except Exception as e: raise Exception(f"Failed to change permissions: {str(e)}") def close(self): """关闭连接""" if self.sftp: self.sftp.close() if self.connection: self.connection.close() self.sftp = None self.connection = None async def send_progress(self, info): """异步发送进度信息""" self.update_activity() # 更新活动时间 try: if self.progress_callback and not self.download_cancelled: # 添加会话ID到进度信息中 if self.current_session_id: info['session_id'] = self.current_session_id try: # 尝试获取当前事件循环,如果没有抛出异常 loop = asyncio.get_running_loop() # 使用asyncio.wait_for添加超时保护,防止WebSocket已断开但进度回调阻塞 try: await asyncio.wait_for(self.progress_callback(info), timeout=2.0) except asyncio.TimeoutError: logger.warning("Progress callback timed out, possibly due to closed 
WebSocket") self.download_cancelled = True # 如果回调超时,标记为已取消 raise Exception("Download cancelled due to connection loss") except RuntimeError as e: if "no running event loop" in str(e): logger.debug("No running event loop for progress update, storing for later") # 保存进度信息供以后处理 self._pending_progress_info = info else: raise except Exception as e: logger.error(f"Error sending progress: {str(e)}") # 如果是WebSocket错误,可能是连接已关闭,取消下载 if "WebSocket" in str(e) or "connection" in str(e).lower(): self.download_cancelled = True logger.warning("WebSocket connection issue detected, cancelling download") return # 发送错误消息 try: if self.progress_callback and not self.download_cancelled: # 尝试获取当前事件循环,如果不存在则忽略 try: asyncio.get_running_loop() await asyncio.wait_for( self.progress_callback({ 'type': 'error', 'message': '发送进度信息时出错', 'session_id': self.current_session_id }), timeout=2.0 ) except (RuntimeError, asyncio.TimeoutError): # 没有运行中的事件循环或超时,忽略错误 pass except Exception: pass # 如果连错误消息都发送不了,就忽略它 def cancel_download(self, session_id=None): """取消下载 Args: session_id: 可选的会话ID,如果提供,只有匹配的会话才会被取消 """ # 如果提供了会话ID,只有当前会话ID匹配时才取消 if session_id is None or session_id == self.current_session_id: self.download_cancelled = True logger.info(f"Download cancelled for session {self.current_session_id}") def walk(self, remote_path): """ 遍历远程目录的生成器函数,类似于 os.walk Args: remote_path: 远程目录路径 Yields: tuple: (当前目录路径, [子目录列表], [文件列表]) """ try: files = [] dirs = [] for entry in self.sftp.listdir_attr(remote_path): try: # 检查是否是符号链接 if stat.S_ISLNK(entry.st_mode): try: # 获取链接目标的信息 target_path = os.path.join(remote_path, entry.filename).replace('\\', '/') target_stat = self.sftp.stat(target_path) # 如果链接指向目录,将其作为目录处理 if stat.S_ISDIR(target_stat.st_mode): dirs.append(entry.filename) else: files.append(entry.filename) except Exception as e: # 如果无法访问链接目标,将其添加到文件列表 logger.warning(f"Cannot access symlink target for {entry.filename}: {str(e)}") files.append(entry.filename) elif stat.S_ISDIR(entry.st_mode): dirs.append(entry.filename) else: files.append(entry.filename) except Exception as e: logger.warning(f"Error processing entry {entry.filename}: {str(e)}") continue yield remote_path, dirs, files for dir_name in dirs: dir_path = os.path.join(remote_path, dir_name).replace('\\', '/') try: # 递归遍历子目录 for x in self.walk(dir_path): yield x except Exception as e: logger.warning(f"Error walking directory {dir_path}: {str(e)}") continue except Exception as e: logger.error(f"Error listing directory {remote_path}: {str(e)}") yield remote_path, [], [] async def walk_async(self, remote_path): """ 异步遍历目录,使用线程池执行阻塞操作 """ loop = asyncio.get_event_loop() try: # 使用线程池获取目录列表,使用辅助函数替代lambda result = await loop.run_in_executor( thread_pool, _sftp_walk, self, remote_path ) return result except Exception as e: logger.error(f"Error in async walk: {str(e)}") return [(remote_path, [], [])] async def process_pending_progress(self): """处理在回调函数中设置的待处理进度信息""" if hasattr(self, '_pending_progress_info') and self._pending_progress_info: progress_info = self._pending_progress_info self._pending_progress_info = None # 清除待处理的进度信息 if self.progress_callback: await self.send_progress(progress_info) async def download_file_or_directory(self, remote_path, callback=None, session_id=None): self.progress_callback = callback self.download_cancelled = False # 重置取消标志 self.current_session_id = session_id # 设置当前会话ID self._pending_progress_info = None # 初始化待处理进度信息 is_dir = False is_symlink = False skipped_items = [] total_size = 0 processed_size = 0 start_time = time.time() try: # 处理路径 if not 
remote_path.startswith('/') and self.current_directory and self.current_directory != '.': remote_path = os.path.join(self.current_directory, remote_path).replace('\\', '/') logger.info(f"Starting download of {remote_path} (Session ID: {session_id})") # 获取文件信息 try: stat_result = self.sftp.stat(remote_path) is_dir = stat.S_ISDIR(stat_result.st_mode) is_symlink = stat.S_ISLNK(stat_result.st_mode) # 如果是符号链接,检查目标是否可访问 if is_symlink: try: target_stat = self.sftp.stat(remote_path) is_target_readable = bool(target_stat.st_mode & stat.S_IRUSR) if not is_target_readable: raise Exception("链接目标不可读") except Exception as e: logger.error(f"Cannot access symlink target for {remote_path}: {str(e)}") if self.progress_callback: await self.send_progress({ 'type': 'error', 'message': f"无法访问链接目标: {str(e)}" }) return None, None # 如果是符号链接,我们将其视为普通文件下载 if is_symlink: is_dir = False if not is_dir: total_size = stat_result.st_size except IOError as e: logger.error(f"Failed to access path '{remote_path}': {str(e)}") if self.progress_callback: await self.send_progress({ 'type': 'error', 'message': f"无法访问文件: {str(e)}" }) return None, None if not is_dir: # 如果是文件或符号链接,直接下载 try: # 发送文件开始下载的消息 if self.progress_callback: await self.send_progress({ 'type': 'file_start', 'filename': os.path.basename(remote_path) }) # 使用 BytesIO 来存储文件内容 bio = io.BytesIO() def progress(transferred, total): if self.download_cancelled: raise Exception("Download cancelled by user") try: current_time = time.time() elapsed_time = current_time - start_time speed = transferred / elapsed_time if elapsed_time > 0 else 0 remaining_time = (total - transferred) / speed if speed > 0 else 0 progress_info = { 'type': 'progress', 'filename': os.path.basename(remote_path), 'current_file': os.path.basename(remote_path), 'total': total, 'processed': transferred, 'speed': speed, 'elapsed_time': elapsed_time, 'remaining_time': remaining_time, 'percentage': (transferred / total * 100) if total > 0 else 0 } # 添加节流机制:记录上一次更新时间 now = time.time() # 将节流状态存储在对象属性中,避免每次函数调用都重置 if not hasattr(self, '_last_progress_time') or now - self._last_progress_time > 0.5: # 每0.5秒更新一次 self._last_progress_time = now # 不能在非协程中直接创建异步任务 # 将进度信息保存到对象属性中,稍后处理 self._pending_progress_info = progress_info except Exception as e: logger.error(f"Error in progress callback: {str(e)}") # 不要在这里抛出异常,让下载继续进行 try: # 使用 getfo 方法下载文件并跟踪进度,添加超时控制 loop = asyncio.get_event_loop() await asyncio.wait_for( loop.run_in_executor( thread_pool, _sftp_getfo, self, remote_path, bio, progress ), timeout=1200 # 20分钟超时 ) # 处理可能在回调期间累积的进度信息 await self.process_pending_progress() except asyncio.TimeoutError: logger.error(f"Timeout downloading file {remote_path}") if self.progress_callback: await self.send_progress({ 'type': 'error', 'message': "下载文件超时,请重试或下载较小的文件" }) return None, None except Exception as e: if str(e) == "Download cancelled by user": raise logger.error(f"Error downloading file: {str(e)}") if self.progress_callback: await self.send_progress({ 'type': 'error', 'message': f"下载文件时出错: {str(e)}" }) return None, None # 发送文件完成的消息 if self.progress_callback: await self.send_progress({ 'type': 'file_complete', 'filename': os.path.basename(remote_path) }) return bio.getvalue(), os.path.basename(remote_path) except Exception as e: if str(e) == "Download cancelled by user": logger.info("Download cancelled by user, operation terminated") if self.progress_callback: try: await self.send_progress({ 'type': 'cancelled', 'message': '下载已取消' }) except Exception as msg_error: logger.warning(f"Error sending cancel message: 
{str(msg_error)}") return None, None logger.error(f"Failed to download file '{remote_path}': {str(e)}") if self.progress_callback: await self.send_progress({ 'type': 'error', 'message': f"下载文件失败: {str(e)}" }) return None, None else: # 如果是目录,创建zip文件 try: # 发送目录开始下载的消息 directory_name = os.path.basename(remote_path) if self.progress_callback: await self.send_progress({ 'type': 'directory_start', 'directory_name': directory_name }) # 首先计算总大小 logger.info("Calculating directory size...") if self.progress_callback: await self.send_progress({ 'type': 'progress', 'status': '正在计算目录大小...', 'current_file': directory_name, 'processed': 0, 'total': 0 }) total_size = 0 file_count = 0 scanned_count = 0 # 在此处获取事件循环,避免在循环内部重复获取或引用错误 loop = asyncio.get_event_loop() # 首先扫描所有文件以获取总大小和文件数量 # 使用异步walk替代同步walk walk_results = await self.walk_async(remote_path) for root, dirs, files in walk_results: # 检查是否已取消下载 if self.download_cancelled: logger.info("Download cancelled during directory size calculation") if self.progress_callback: await self.send_progress({ 'type': 'cancelled', 'message': '下载已取消' }) return None, None batch_count = 0 # 批处理计数器 for file in files: # 检查是否已取消下载 if self.download_cancelled: logger.info("Download cancelled during file size calculation") if self.progress_callback: await self.send_progress({ 'type': 'cancelled', 'message': '下载已取消' }) return None, None file_path = os.path.join(root, file).replace('\\', '/') try: # 使用线程池执行stat操作 stat_result = await loop.run_in_executor( thread_pool, _sftp_stat, self, file_path ) total_size += stat_result.st_size file_count += 1 scanned_count += 1 batch_count += 1 # 每扫描10个文件发送一次进度更新 if scanned_count % 10 == 0 and self.progress_callback: await self.send_progress({ 'type': 'progress', 'status': f'正在扫描文件 ({scanned_count} 个已扫描)...', 'current_file': file, 'processed': 0, 'total': 0 }) # 处理可能在线程池操作过程中累积的进度信息 await self.process_pending_progress() # 每批处理50个文件后,让出事件循环 if batch_count >= 50: batch_count = 0 await asyncio.sleep(0) # 让出事件循环 # 处理可能在线程池操作过程中累积的进度信息 await self.process_pending_progress() except Exception as e: logger.warning(f"Error scanning file {file_path}: {str(e)}") logger.info(f"Total size to download: {total_size} bytes") # 发送开始压缩的消息 if self.progress_callback: await self.send_progress({ 'type': 'progress', 'status': '开始压缩文件...', 'current_file': directory_name, 'total': total_size, 'processed': 0 }) zip_buffer = io.BytesIO() # 已在上面获取了loop,此处注释掉以避免重复声明 # loop = asyncio.get_event_loop() with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: processed_size = 0 processed_files = 0 # 使用异步目录遍历 walk_results = await self.walk_async(remote_path) for root, dirs, files in walk_results: # 如果下载已被取消,立即退出循环 if self.download_cancelled: logger.info("Download cancelled, stopping directory traversal") break # 首先添加目录 for dir_name in dirs: # 如果下载已被取消,立即退出循环 if self.download_cancelled: logger.info("Download cancelled, stopping directory processing") break dir_path = os.path.join(root, dir_name).replace('\\', '/') rel_path = os.path.relpath(dir_path, remote_path).replace('\\', '/') zip_info = zipfile.ZipInfo(rel_path + '/') zip_file.writestr(zip_info, '') # 每处理一个目录后让出事件循环 if dirs.index(dir_name) % 20 == 19: # 每20个目录 await asyncio.sleep(0) # 如果在添加目录时取消了下载,直接返回 if self.download_cancelled: if self.progress_callback: await self.send_progress({ 'type': 'cancelled', 'message': '下载已取消' }) return None, None # 然后添加文件 batch_count = 0 # 批处理计数器 for file_name in files: # 如果下载已被取消,立即退出循环 if self.download_cancelled: logger.info("Download cancelled, stopping file processing") 
break file_path = os.path.join(root, file_name).replace('\\', '/') rel_path = os.path.relpath(file_path, remote_path).replace('\\', '/') try: # 发送开始下载文件的消息 if self.progress_callback: await self.send_progress({ 'type': 'file_start', 'filename': file_name }) # 下载文件内容 bio = io.BytesIO() # 获取文件大小使用线程池 file_stat = await loop.run_in_executor( thread_pool, _sftp_stat, self, file_path ) file_size = file_stat.st_size def file_progress(transferred, total): if self.download_cancelled: raise Exception("Download cancelled by user") nonlocal processed_size current_total = processed_size + transferred current_time = time.time() elapsed_time = current_time - start_time speed = current_total / elapsed_time if elapsed_time > 0 else 0 # 创建进度信息 progress_info = { 'type': 'progress', 'filename': file_name, 'current_file': rel_path, 'total': total_size, 'processed': current_total, 'speed': speed, 'elapsed_time': elapsed_time, 'remaining_time': ((total_size - current_total) / speed) if speed > 0 else 0, 'percentage': (current_total / total_size * 100) if total_size > 0 else 0 } # 添加节流机制:记录上一次更新时间 now = time.time() # 将节流状态存储在对象属性中,避免每次函数调用都重置 if not hasattr(self, '_last_progress_time') or now - self._last_progress_time > 0.5: # 每0.5秒更新一次 self._last_progress_time = now # 不能在非协程中直接创建异步任务 # 将进度信息保存到对象属性中,稍后处理 self._pending_progress_info = progress_info await loop.run_in_executor( thread_pool, _sftp_getfo, self, file_path, bio, file_progress ) # 处理可能在回调期间累积的进度信息 await self.process_pending_progress() bio.seek(0) zip_file.writestr(rel_path, bio.getvalue()) processed_size += file_size processed_files += 1 batch_count += 1 # 发送文件完成的消息 if self.progress_callback: await self.send_progress({ 'type': 'file_complete', 'filename': file_name, 'processed_files': processed_files, 'total_files': file_count }) # 每批处理10个文件后,让出事件循环 if batch_count >= 10: batch_count = 0 await asyncio.sleep(0) # 让出事件循环 except Exception as e: logger.warning(f"Error processing file {file_path}: {str(e)}") # 如果是取消下载的异常,停止处理其他文件 if str(e) == "Download cancelled by user": logger.info("Download explicitly cancelled, stopping all file processing") # 跳出所有循环 skipped_items.append(f"{file_path} (用户取消下载)") if self.progress_callback: await self.send_progress({ 'type': 'cancelled', 'message': '下载已取消' }) return None, None else: skipped_items.append(f"{file_path} (Error: {str(e)})") # 如果有跳过的项目,添加说明文件 if skipped_items: info_content = "以下文件或目录无法访问:\n\n" + "\n".join(skipped_items) zip_file.writestr("_skipped_items.txt", info_content.encode('utf-8')) # 发送目录完成的消息 if self.progress_callback: await self.send_progress({ 'type': 'directory_complete', 'directory_name': directory_name, 'total_files': file_count, 'processed_files': processed_files }) zip_data = zip_buffer.getvalue() zip_buffer.close() if not zip_data: raise Exception("No data was written to the ZIP file") suggested_filename = os.path.basename(remote_path) + '.zip' logger.info(f"Successfully created ZIP archive for {remote_path}") return zip_data, suggested_filename except Exception as e: if str(e) == "Download cancelled by user": logger.info("Download cancelled by user, operation terminated") if self.progress_callback: try: await self.send_progress({ 'type': 'cancelled', 'message': '下载已取消' }) except Exception as msg_error: logger.warning(f"Error sending cancel message: {str(msg_error)}") return None, None logger.error(f"Failed to create ZIP archive for directory '{remote_path}': {str(e)}") raise Exception(f"Failed to create ZIP archive for directory '{remote_path}': {str(e)}") except Exception as e: if str(e) == 
"Download cancelled by user": logger.info("Download cancelled by user, operation terminated") if self.progress_callback: try: await self.send_progress({ 'type': 'cancelled', 'message': '下载已取消' }) except Exception as msg_error: logger.warning(f"Error sending cancel message: {str(msg_error)}") return None, None error_type = "directory" if is_dir else "file" logger.error(f"Download error for {error_type} '{remote_path}': {str(e)}") if self.progress_callback: await self.send_progress({ 'type': 'error', 'message': str(e) }) raise finally: # 清理资源 self.progress_callback = None self.download_cancelled = False self.current_session_id = None self._pending_progress_info = None def get_directory_size(self, remote_path): """获取目录大小 Args: remote_path: 远程目录路径 Returns: int: 目录总大小(字节) """ total_size = 0 try: for entry in self.sftp.listdir_attr(remote_path): try: # 检查是否是符号链接 if stat.S_ISLNK(entry.st_mode): try: # 获取链接目标的信息 target_path = os.path.join(remote_path, entry.filename).replace('\\', '/') target_stat = self.sftp.stat(target_path) # 如果链接指向目录,递归计算目录大小 if stat.S_ISDIR(target_stat.st_mode): child_path = os.path.join(remote_path, entry.filename).replace('\\', '/') total_size += self.get_directory_size(child_path) else: total_size += target_stat.st_size except Exception as e: # 如果无法访问链接目标,只计算链接本身的大小 logger.warning(f"Cannot access symlink target for {entry.filename}: {str(e)}") total_size += entry.st_size elif stat.S_ISDIR(entry.st_mode): # 递归计算子目录大小 child_path = os.path.join(remote_path, entry.filename).replace('\\', '/') total_size += self.get_directory_size(child_path) else: total_size += entry.st_size except Exception as e: logger.warning(f"Error processing entry {entry.filename}: {str(e)}") continue return total_size except Exception as e: raise Exception(f"Failed to get directory size: {str(e)}") def read_file(self, path): """读取文件内容 Args: path: 文件路径 Returns: bytes: 文件内容的字节数据 """ try: # 如果是相对路径,使用当前目录 if not path.startswith('/') and self.current_directory and self.current_directory != '.': path = os.path.join(self.current_directory, path).replace('\\', '/') with tempfile.NamedTemporaryFile() as temp_file: self.sftp.get(path, temp_file.name) with open(temp_file.name, 'rb') as f: return f.read() except Exception as e: raise Exception(f"Failed to read file: {str(e)}") def write_file(self, path, content): """写入文件内容 Args: path: 文件路径 content: 要写入的内容 """ try: # 如果是相对路径,使用当前目录 if not path.startswith('/') and self.current_directory and self.current_directory != '.': path = os.path.join(self.current_directory, path).replace('\\', '/') with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False) as temp_file: temp_file.write(content) temp_file.flush() try: self.sftp.put(temp_file.name, path) finally: os.unlink(temp_file.name) return True except Exception as e: raise Exception(f"Failed to write file: {str(e)}") @classmethod async def close_connection(cls, host, username, port=22): """关闭特定的连接""" connection_key = f"{host}:{port}:{username}" async with cls._pool_lock: if connection_key in cls._connection_pool: handler, _ = cls._connection_pool[connection_key] try: # 在事件循环的线程池中执行关闭操作 loop = asyncio.get_event_loop() await loop.run_in_executor(thread_pool, handler.close) except Exception as e: logger.warning(f"Error closing connection: {str(e)}") del cls._connection_pool[connection_key] logger.debug(f"Closed SFTP connection for {connection_key}") @classmethod async def cleanup_idle_connections(cls): """清理空闲连接""" current_time = time.time() keys_to_remove = [] async with cls._pool_lock: for key, (handler, 
last_activity) in cls._connection_pool.items(): if current_time - last_activity > cls._idle_timeout: keys_to_remove.append(key) for key in keys_to_remove: try: handler, _ = cls._connection_pool[key] loop = asyncio.get_event_loop() await loop.run_in_executor(thread_pool, handler.close) except Exception as e: logger.warning(f"Error closing idle connection for {key}: {str(e)}") del cls._connection_pool[key] logger.debug(f"Cleaned up {len(keys_to_remove)} idle connections") def __init__(self, host, username, port=22, password=None, private_key=None, passphrase=None): self.host = host self.port = port self.username = username self.password = password self.private_key = private_key self.passphrase = passphrase self.sftp = None self.connection = None self.max_retries = 2 self.connect_timeout = 8 self.operation_timeout = 60 self.current_directory = None self.progress_callback = None self.download_cancelled = False # 添加取消标志 self.current_session_id = None # 添加会话ID self.last_activity = time.time() # 添加最后活动时间
Judging from the code, pysftp is actually called in only two places (not counting the import statement):

cnopts = pysftp.CnOpts()
cnopts.hostkeys = None

Everything else (connecting, uploading, downloading, and the other SFTP operations) is done with paramiko's SSHClient and SFTPClient. In other words, the code only uses pysftp at the very start to create a CnOpts object and set hostkeys = None; none of the remaining logic ever touches pysftp's session or transfer methods.
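To make that concrete, here is a minimal, self-contained sketch (not taken from the module above; the hostname and credentials are placeholders) of what the paramiko-only path already does: set_missing_host_key_policy(paramiko.AutoAddPolicy()) skips host-key verification much like cnopts.hostkeys = None would for a pysftp.Connection, and open_sftp() returns the SFTPClient that the handler class works with. Since the cnopts object is never passed to any pysftp.Connection, the two pysftp lines could be dropped without changing behavior.

import paramiko

# Placeholder connection details; assumes simple password authentication.
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())  # accept unknown host keys
ssh.connect(hostname="example.com", port=22,
            username="user", password="secret", timeout=30)

sftp = ssh.open_sftp()      # paramiko.SFTPClient, the same object SFTPHandler uses
print(sftp.listdir('.'))    # quick sanity check that the session works
sftp.close()
ssh.close()

If strict host-key checking were wanted instead, the paramiko-only equivalent would be ssh.load_system_host_keys() combined with paramiko.RejectPolicy(), which also removes any need for pysftp.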