from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch_sparse import coalesce


def maybe_num_nodes(edge_index, num_nodes=None):
    # Return `num_nodes` if given; otherwise infer it from `edge_index`
    # (largest node index plus one for a dense index tensor, or the larger
    # sparse dimension otherwise).
    if num_nodes is not None:
        return num_nodes
    elif isinstance(edge_index, Tensor):
        return int(edge_index.max()) + 1
    else:
        return max(edge_index.size(0), edge_index.size(1))
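
# Illustrative sketch of `maybe_num_nodes` on a small edge index tensor:
#
#     edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
#     maybe_num_nodes(edge_index)       # -> 3 (largest index 2, plus one)
#     maybe_num_nodes(edge_index, 10)   # -> 10 (an explicit count is returned as-is)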


def to_undirected(edge_index: Tensor, edge_attr: Optional[Tensor] = None,
                  num_nodes: Optional[int] = None,
                  reduce: str = "add") -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""Converts the graph given by :attr:`edge_index` to an undirected graph
    such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in
    \mathcal{E}`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
        reduce (str, optional): The reduce operation to use for merging edge
            features. (default: :obj:`"add"`)

    :rtype: :class:`LongTensor` if :attr:`edge_attr` is :obj:`None`, else
        (:class:`LongTensor`, :class:`Tensor`)
    """
    # Maintain backward compatibility to `to_undirected(edge_index, num_nodes)`:
    if isinstance(edge_attr, int):
        edge_attr, num_nodes = None, edge_attr

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # Duplicate every edge in the reverse direction, then merge duplicates.
    row, col = edge_index
    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
    edge_index = torch.stack([row, col], dim=0)
    if edge_attr is not None:
        edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
                                     num_nodes, reduce)

    if edge_attr is None:
        return edge_index
    else:
        return edge_index, edge_attr
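
# Illustrative sketch of `to_undirected`: a single directed edge 0 -> 1 with a
# scalar weight becomes the symmetric pair (0, 1) and (1, 0), each keeping its
# coalesced weight:
#
#     edge_index = torch.tensor([[0], [1]])
#     edge_attr = torch.tensor([1.0])
#     edge_index, edge_attr = to_undirected(edge_index, edge_attr)
#     # edge_index: tensor([[0, 1], [1, 0]]), edge_attr: tensor([1., 1.])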


def remove_self_loops(edge_index, edge_attr: Optional[Tensor] = None):
    r"""Removes every self-loop in the graph given by :attr:`edge_index`, so
    that :math:`(i,i) \not\in \mathcal{E}` for every :math:`i \in \mathcal{V}`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_attr (Tensor, optional): Edge weights or multi-dimensional
            edge features. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    # Keep only edges whose source and target nodes differ.
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    if edge_attr is None:
        return edge_index, None
    else:
        return edge_index, edge_attr[mask]
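
# Illustrative sketch of `remove_self_loops`: the self-loop (2, 2) is dropped
# while the remaining edges and their attributes keep their order:
#
#     edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])
#     edge_attr = torch.tensor([1.0, 2.0, 3.0])
#     edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
#     # edge_index: tensor([[0, 1], [1, 0]]), edge_attr: tensor([1., 2.])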