Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import torch_geometric
2import torch_scatter
def scatter_(name, src, index, out=None, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along dimension :attr:`dim`.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (e.g. :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    Note:
        This method was copied from torch-geometric v1.3.0 to maintain
        backwards-compatibility. It is a thin wrapper around
        :func:`torch_scatter.scatter`.

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"max"`). Forwarded unchanged as the ``reduce`` argument
            of :func:`torch_scatter.scatter`.
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        out (Tensor, optional): Output tensor forwarded to
            :func:`torch_scatter.scatter`. If set to :obj:`None`, a new
            output tensor is created. (default: :obj:`None`)
        dim (int, optional): The dimension along which to scatter.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in dimension :attr:`dim`. If set to :obj:`None`,
            a minimal sized output tensor is returned. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    # Forward the caller's ``out`` instead of hard-coding ``None``; the
    # previous version accepted ``out`` but silently discarded it.
    return torch_scatter.scatter(
        src, index, dim=dim, out=out, dim_size=dim_size, reduce=name
    )
# Compatibility shim: if the installed torch_geometric still provides
# ``utils.scatter_``, import it, rebinding this module's ``scatter_`` to the
# upstream version. Otherwise (ImportError), attach the local copy defined
# above to ``torch_geometric.utils`` so code importing it from there keeps
# working.
try:
    from torch_geometric.utils import scatter_  # noqa
except ImportError:
    setattr(torch_geometric.utils, "scatter_", scatter_)