# Source code for pynetim.utils

# -*- coding: utf-8 -*-
"""工具函数模块。

提供通用的工具函数。
"""

from typing import Optional, List, Tuple, TYPE_CHECKING
from collections import deque
import heapq

if TYPE_CHECKING:
    from ..graph import IMGraph

from .graph_utils import (
    generate_er_graph,
    generate_ba_graph,
    generate_ws_graph,
    compute_sir_beta,
    load_edgelist,
    save_edgelist,
    to_networkx,
    to_igraph,
    to_scipy_sparse,
    to_pyg,
)

try:
    from .rr_utils import sample_rr_set_ic, sample_rr_set_lt, generate_rr_sets
    RR_UTILS_AVAILABLE = True
except ImportError:
    RR_UTILS_AVAILABLE = False


def renumber_edges(edges: List[Tuple[int, int]]) -> Tuple[List[Tuple[int, int]], dict, List[int]]:
    """Relabel the nodes of an edge list with consecutive integers from 0.

    New IDs follow the ascending order of the original node IDs: the smallest
    original ID becomes 0, the next one 1, and so on.

    Args:
        edges: Edge list whose elements are ``(u, v)`` pairs. Any iterable is
            accepted, including generators and other one-shot iterators.

    Returns:
        Tuple[List[Tuple[int, int]], dict, List[int]]:
            - the edge list rewritten with the new IDs, in the input order
            - mapping from original node ID to new node ID
            - list whose index is the new node ID and whose value is the
              original node ID

    Example:
        >>> from pynetim.utils import renumber_edges
        >>> edges = [(10, 20), (20, 30), (30, 40)]
        >>> new_edges, mapping, reverse_mapping = renumber_edges(edges)
        >>> print(new_edges)  # [(0, 1), (1, 2), (2, 3)]
    """
    # The edges are walked twice (collect nodes, then rewrite), so snapshot
    # them first; otherwise a generator would be exhausted by the first pass
    # and the rewritten edge list would silently come back empty.
    edge_list = [(u, v) for u, v in edges]

    sorted_nodes = sorted({node for edge in edge_list for node in edge})
    node_to_id = {node: new_id for new_id, node in enumerate(sorted_nodes)}
    # The sorted node list doubles as the reverse mapping (index -> original).
    id_to_node = sorted_nodes

    new_edges = [(node_to_id[u], node_to_id[v]) for u, v in edge_list]
    return new_edges, node_to_id, id_to_node
def shortest_path_length(
    graph: 'IMGraph',
    source: int,
    target: int,
    use_weight: bool = False
) -> Optional[float]:
    """Return the length of the shortest path from ``source`` to ``target``.

    Args:
        graph: Graph object to search.
        source: Start node.
        target: End node.
        use_weight: Selects the distance measure.
            - False (default): number of hops (edges) along the path
            - True: sum of the edge weights along the path

    Returns:
        Optional[float]: The shortest path length, or None when the helper
        search reports that ``target`` cannot be reached.
            - an integer hop count when ``use_weight`` is False
            - a float weight sum when ``use_weight`` is True

    Example:
        >>> from pynetim.utils import shortest_path_length
        >>> # hop-based distance
        >>> dist = shortest_path_length(graph, 0, 5, use_weight=False)
        >>> # weight-based distance
        >>> dist = shortest_path_length(graph, 0, 5, use_weight=True)
    """
    if source != target:
        # Weighted queries go through Dijkstra, unweighted ones through BFS.
        search = _dijkstra_shortest_path if use_weight else _bfs_shortest_path
        return search(graph, source, target)

    # A node is at distance zero from itself; the graph is not consulted.
    if use_weight:
        return 0.0
    return 0
def _bfs_shortest_path(
    graph: 'IMGraph',
    source: int,
    target: int
) -> Optional[int]:
    """Compute the unweighted shortest path length (hop count) with BFS.

    Args:
        graph: Graph object; only ``graph.out_neighbors(node)`` is used.
        source: Start node.
        target: End node.

    Returns:
        Optional[int]: Number of hops on a shortest path, 0 when ``source``
        equals ``target``, or None if ``target`` is not reachable.
    """
    # Zero hops to oneself. Without this guard the search would report the
    # length of the shortest cycle through ``source`` (or None), which
    # disagrees with the Dijkstra helper returning 0.0 for the same query.
    if source == target:
        return 0

    visited = {source}
    queue = deque([(source, 0)])

    while queue:
        node, dist = queue.popleft()

        for neighbor in graph.out_neighbors(node):
            # Test on discovery rather than on dequeue so the search stops
            # as soon as the target is first seen.
            if neighbor == target:
                return dist + 1

            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1))

    return None


def _dijkstra_shortest_path(
    graph: 'IMGraph',
    source: int,
    target: int
) -> Optional[float]:
    """Compute the weighted shortest path length with Dijkstra's algorithm.

    Args:
        graph: Graph object; only ``graph.out_neighbors_arrays(node)`` is
            used, which is unpacked into parallel sequences of neighbors and
            edge weights.
        source: Start node.
        target: End node.

    Returns:
        Optional[float]: Sum of edge weights on a shortest path (0.0 when
        ``source`` equals ``target``), or None if ``target`` is not reachable.
    """
    # NOTE(review): as with any Dijkstra search, edge weights are presumably
    # non-negative - confirm against how IMGraph weights are produced.
    distances = {source: 0.0}
    heap = [(0.0, source)]
    visited = set()

    while heap:
        current_dist, node = heapq.heappop(heap)

        # Stale heap entry: the node was already finalized via a shorter path.
        if node in visited:
            continue
        visited.add(node)

        if node == target:
            return current_dist

        neighbors, weights = graph.out_neighbors_arrays(node)
        for neighbor, weight in zip(neighbors, weights):
            if neighbor in visited:
                continue

            new_dist = current_dist + weight
            if neighbor not in distances or new_dist < distances[neighbor]:
                distances[neighbor] = new_dist
                heapq.heappush(heap, (new_dist, neighbor))

    return None
def all_pairs_shortest_path_length(
    graph: 'IMGraph',
    nodes: Optional[List[int]] = None,
    use_weight: bool = False
) -> dict:
    """Compute shortest path lengths between every ordered pair of nodes.

    One traversal is run per source node (BFS for hop counts, Dijkstra for
    weighted distances) instead of one search per (source, target) pair.

    Args:
        graph: Graph object.
        nodes: Nodes to compute distances between. If None, all nodes
            ``0 .. graph.num_nodes - 1`` are used.
        use_weight: Whether to use edge weights (sum of weights) instead of
            hop counts.

    Returns:
        dict: Nested dictionary where ``distances[u][v]`` is the shortest
        path length from ``u`` to ``v``. The value is None when ``v`` is not
        reachable from ``u``, and 0 (or 0.0 when ``use_weight`` is True) when
        ``u`` and ``v`` are the same node.

    Example:
        >>> from pynetim.utils import all_pairs_shortest_path_length
        >>> distances = all_pairs_shortest_path_length(graph, nodes=[0, 1, 2])
        >>> print(distances[0][1])
    """
    if nodes is None:
        nodes = range(graph.num_nodes)
    # Snapshot: the node collection is walked once per source below.
    nodes = list(nodes)
    wanted = set(nodes)

    search = _dijkstra_distances_from if use_weight else _bfs_distances_from
    zero = 0.0 if use_weight else 0

    distances = {}
    for source in nodes:
        reached = search(graph, source, wanted - {source})
        # Targets missing from ``reached`` are unreachable -> None.
        row = {target: reached.get(target) for target in nodes}
        row[source] = zero
        distances[source] = row

    return distances


def _bfs_distances_from(graph: 'IMGraph', source: int, targets: set) -> dict:
    """Return hop counts from ``source`` to each reachable node in ``targets``."""
    pending = set(targets)
    pending.discard(source)
    found = {}
    # Nothing to look for: do not touch the graph at all.
    if not pending:
        return found

    visited = {source}
    queue = deque([(source, 0)])
    while queue:
        node, dist = queue.popleft()
        for neighbor in graph.out_neighbors(node):
            if neighbor in visited:
                continue
            visited.add(neighbor)
            if neighbor in pending:
                found[neighbor] = dist + 1
                pending.discard(neighbor)
                # Stop as soon as every requested target has been reached.
                if not pending:
                    return found
            queue.append((neighbor, dist + 1))
    return found


def _dijkstra_distances_from(graph: 'IMGraph', source: int, targets: set) -> dict:
    """Return weighted distances from ``source`` to each reachable node in ``targets``."""
    pending = set(targets)
    pending.discard(source)
    found = {}
    # Nothing to look for: do not touch the graph at all.
    if not pending:
        return found

    best = {source: 0.0}
    heap = [(0.0, source)]
    settled = set()
    while heap:
        dist, node = heapq.heappop(heap)
        # Stale heap entry: the node was already finalized via a shorter path.
        if node in settled:
            continue
        settled.add(node)

        if node in pending:
            found[node] = dist
            pending.discard(node)
            # Stop as soon as every requested target has been finalized.
            if not pending:
                break

        neighbors, weights = graph.out_neighbors_arrays(node)
        for neighbor, weight in zip(neighbors, weights):
            if neighbor in settled:
                continue
            candidate = dist + weight
            if neighbor not in best or candidate < best[neighbor]:
                best[neighbor] = candidate
                heapq.heappush(heap, (candidate, neighbor))
    return found
# Names exported by ``from pynetim.utils import *``.
__all__ = [
    # Re-exported from .graph_utils
    'generate_er_graph',
    'generate_ba_graph',
    'generate_ws_graph',
    'compute_sir_beta',
    'load_edgelist',
    'save_edgelist',
    'to_networkx',
    'to_igraph',
    'to_scipy_sparse',
    'to_pyg',
    # Defined in this module
    'renumber_edges',
    'shortest_path_length',
    'all_pairs_shortest_path_length',
]

# The RR-set helpers are exported only when the .rr_utils import succeeded.
if RR_UTILS_AVAILABLE:
    __all__ += ['sample_rr_set_ic', 'sample_rr_set_lt', 'generate_rr_sets']