Project author: francois-rozet

Project description:
NumPy-style histograms in PyTorch
Language: Python
Repository: git://github.com/francois-rozet/torchist.git
Created: 2021-02-18T15:52:15Z
Project page: https://github.com/francois-rozet/torchist

License: MIT License


NumPy-style histograms in PyTorch

The torchist package implements NumPy’s histogram and histogramdd functions in PyTorch with CUDA support. The package also features implementations of ravel_multi_index, unravel_index and some useful functionals like entropy or kl_divergence.
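The index helpers named above share their names with NumPy's `np.ravel_multi_index` and `np.unravel_index`, which by default use C (row-major) order. The minimal pure-Python sketch below illustrates the arithmetic involved; it is not torchist's code and ignores its tensor batching and argument conventions. It also computes the entropy and KL divergence of count histograms, assuming counts are first normalized to probabilities and natural logarithms (nats) are used. torchist's own `entropy` and `kl_divergence` may differ in signature and conventions.

```python
import math

# Pure-Python illustration only; these are NOT torchist's functions.


def ravel_multi_index(coords, shape):
    """Flatten a multi-index into a linear index (C / row-major order)."""
    index = 0
    for c, s in zip(coords, shape):
        if not 0 <= c < s:
            raise ValueError(f"coordinate {c} out of range for size {s}")
        index = index * s + c
    return index


def unravel_index(index, shape):
    """Inverse of ravel_multi_index: linear index -> multi-index."""
    coords = []
    for s in reversed(shape):
        index, c = divmod(index, s)  # last axis varies fastest
        coords.append(c)
    return tuple(reversed(coords))


def entropy(counts):
    """Shannon entropy (in nats) of a count histogram, normalized first."""
    total = sum(counts)
    return -sum(c / total * math.log(c / total) for c in counts if c > 0)


def kl_divergence(p_counts, q_counts):
    """KL(p || q) in nats between two count histograms over the same bins."""
    p_total, q_total = sum(p_counts), sum(q_counts)
    kl = 0.0
    for pc, qc in zip(p_counts, q_counts):
        if pc == 0:
            continue  # 0 * log(0 / q) is taken as 0
        if qc == 0:
            return math.inf  # p has mass where q has none
        p, q = pc / p_total, qc / q_total
        kl += p * math.log(p / q)
    return kl


# Usage: (1, 2, 3) in a 10 x 10 x 10 grid maps to 1*100 + 2*10 + 3
flat = ravel_multi_index((1, 2, 3), (10, 10, 10))
assert unravel_index(flat, (10, 10, 10)) == (1, 2, 3)
```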

Installation

The torchist package is available on PyPI, which means it is installable with pip.

```sh
pip install torchist
```

Alternatively, if you need the latest features, you can install it from the repository.

```sh
pip install git+https://github.com/francois-rozet/torchist
```

Getting Started

```python
import torch
import torchist

x = torch.rand(100, 3).cuda()  # requires a CUDA-capable GPU
hist = torchist.histogramdd(x, bins=10, low=0.0, upp=1.0)
print(hist.shape)  # torch.Size([10, 10, 10])
```
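The call above asks for 10 equal-width bins per dimension on [low, upp]. With uniform bins, each coordinate's bin index follows from a single division, and the per-dimension indices can then be raveled into one flat index and counted. The pure-Python sketch below illustrates that idea; it is not torchist's implementation, and the handling of out-of-range points and of values equal to `upp` follows NumPy's convention as an assumption.

```python
# Pure-Python sketch of uniform-bin histogramdd (NOT torchist's code).
# Assumption (NumPy's convention): bins are half-open [a, b) except the
# last one, which also contains `upp`; points outside [low, upp] are dropped.


def uniform_bin(v, bins, low, upp):
    """Return the bin index of scalar v in [0, bins), or None if out of range."""
    if v < low or v > upp:
        return None
    i = int((v - low) / (upp - low) * bins)  # O(1) arithmetic, no search
    return min(i, bins - 1)  # v == upp goes into the last bin


def histogramdd_sketch(points, bins, low, upp):
    """Count D-dimensional points into a flat, row-major list of bins**D counts."""
    d = len(points[0])
    counts = [0] * (bins ** d)
    for p in points:
        idx = [uniform_bin(v, bins, low, upp) for v in p]
        if None in idx:
            continue  # outside [low, upp] along some dimension
        flat = 0
        for i in idx:
            flat = flat * bins + i  # ravel the D bin indices (C order)
        counts[flat] += 1
    return counts


points = [(0.05, 0.95), (0.5, 0.5), (1.0, 0.0), (1.2, 0.3)]
counts = histogramdd_sketch(points, bins=2, low=0.0, upp=1.0)
# counts holds 2 * 2 = 4 bins; the point (1.2, 0.3) is out of range and dropped
```

Reshaping the flat list to `(bins,) * D` gives the same layout as the `(10, 10, 10)` result above.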

Benchmark

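On CPU, the torchist implementations are on par with or faster than those of NumPy (about twice as fast for the uniform histogramdd), and they benefit greatly from CUDA: in the timings below, the CUDA versions run roughly 9 to 55 times faster than NumPy on CPU.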

```
$ python torchist/__init__.py
CPU
---
np.histogram : 1.2559 s
np.histogramdd : 20.7816 s
np.histogram (non-uniform) : 5.4878 s
np.histogramdd (non-uniform) : 17.3757 s
torchist.histogram : 1.3975 s
torchist.histogramdd : 9.6160 s
torchist.histogram (non-uniform) : 5.0883 s
torchist.histogramdd (non-uniform) : 17.2743 s
CUDA
----
torchist.histogram : 0.1363 s
torchist.histogramdd : 0.3754 s
torchist.histogram (non-uniform) : 0.1355 s
torchist.histogramdd (non-uniform) : 0.5137 s
```
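The "(non-uniform)" rows refer to histograms whose bins are given by explicit, unevenly spaced edges. In that case a value's bin index can no longer be computed with a single division; one common approach, used in the pure-Python sketch below, is a binary search over the sorted edges, which costs O(log B) per value for B bins instead of O(1). This is an illustration under NumPy's edge convention, not torchist's or NumPy's actual implementation.

```python
import bisect

# Pure-Python sketch of non-uniform binning (NOT torchist's code).


def edges_bin(v, edges):
    """Bin index of v for sorted, possibly unevenly spaced edges, or None.

    Assumes NumPy's convention: bins are [edges[k], edges[k+1]) except the
    last, which also includes edges[-1].
    """
    if v < edges[0] or v > edges[-1]:
        return None
    k = bisect.bisect_right(edges, v) - 1  # O(log B) binary search
    return min(k, len(edges) - 2)  # v == edges[-1] goes into the last bin


def histogram_edges_sketch(values, edges):
    """1D histogram over explicit edges; out-of-range values are dropped."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        k = edges_bin(v, edges)
        if k is not None:
            counts[k] += 1
    return counts


edges = [0.0, 0.1, 0.5, 1.0]  # three bins of different widths
counts = histogram_edges_sketch([0.05, 0.1, 0.3, 0.99, 1.0, 1.5], edges)
# 1.5 lies outside [0.0, 1.0] and is not counted
```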