Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import torch_geometric 

2import torch_scatter 

3 

4 

def scatter_(name, src, index, out=None, dim=0, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along dimension :attr:`dim`.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).

    Note:
        This method was copied from torch-geometric v1.3.0 to maintain
        backwards-compatibility. Here it is implemented as a thin wrapper
        around :func:`torch_scatter.scatter`, with :attr:`name` passed as
        its ``reduce`` argument.

    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        out (Tensor, optional): Destination tensor, forwarded unchanged to
            :func:`torch_scatter.scatter`. (default: :obj:`None`)
        dim (int, optional): The axis along which to scatter.
            (default: :obj:`0`)
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in dimension :attr:`dim`. If set to :attr:`None`,
            a minimal sized output tensor is returned. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    # Forward the caller's ``out`` tensor. The previous implementation always
    # passed ``out=None``, so the ``out`` argument was silently ignored.
    return torch_scatter.scatter(
        src, index, dim=dim, out=out, dim_size=dim_size, reduce=name
    )

29 

30 

# Prefer the upstream implementation when this torch-geometric release still
# ships ``torch_geometric.utils.scatter_``; doing so rebinds the module-level
# ``scatter_`` name to the upstream function.
try:
    from torch_geometric.utils import scatter_  # noqa
except ImportError:
    # Upstream no longer provides it: register the local shim on
    # ``torch_geometric.utils`` so code that looks it up there still works.
    setattr(torch_geometric.utils, "scatter_", scatter_)