Author: n2cholas

Description:
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Language: Python
Repository: git://github.com/n2cholas/jax-resnet.git
Created: 2020-12-29T06:10:33Z
Community: https://github.com/n2cholas/jax-resnet

License: MIT License


JAX ResNet - Implementations and Checkpoints for ResNet Variants

A Flax (Linen) implementation of ResNet (He et al. 2015), Wide ResNet
(Zagoruyko & Komodakis 2016), ResNeXt (Xie et al. 2017), ResNet-D (He et al.
2020), and ResNeSt (Zhang et al. 2020). The code is modular so you can mix and
match the various stem, residual, and bottleneck implementations.

Installation

You can install this package from PyPI:

  pip install jax-resnet

Or directly from GitHub:

  pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git

Usage

See the bottom of jax_resnet/resnet.py for the available aliases and options
for the ResNet variants (all models are in Flax).

Pretrained checkpoints from torch.hub are available for the following
networks:

  • ResNet [18, 34, 50, 101, 152]
  • WideResNet [50, 101]
  • ResNeXt [50, 101]
  • ResNeSt [50-Fast, 50, 101, 200, 269]

The models are tested to have the same intermediate activations and outputs as
the torch.hub implementations, except ResNeSt-50 Fast, whose activations don't
match exactly but whose final accuracy does.

A pretrained checkpoint for ResNetD-50 is available from fast.ai. The
activations do not match exactly, but the final accuracy matches.

  import jax.numpy as jnp
  from jax_resnet import pretrained_resnest

  ResNeSt50, variables = pretrained_resnest(50)
  model = ResNeSt50()
  out = model.apply(variables,
                    jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
                    mutable=False)  # Ensure `batch_stats` aren't updated.

You must install PyTorch yourself (see the PyTorch installation instructions)
to use these pretrained-checkpoint functions.

Transfer Learning

To extract a subset of the model, you can use
Sequential(model.layers[start:end]).

The slice_variables function (found in common.py) allows you to extract the
corresponding subset of the variables dict. Check out its docstring for more
information.
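
To make the idea of slicing the variables dict concrete, here is a minimal pure-Python sketch. It is not the library's slice_variables: the helper name slice_layer_entries is hypothetical, and it assumes that each collection (such as 'params' and 'batch_stats') is keyed by 'layers_<index>' and that the kept entries are renumbered from zero so they line up with the sliced Sequential. Consult the real docstring for the actual behaviour.

```python
def slice_layer_entries(variables, start=0, end=None):
    """Hypothetical helper: keep the entries for layers[start:end].

    Assumes every collection is a dict keyed by 'layers_<index>'. This is an
    assumption about the layout, not a documented guarantee of the library.
    """
    indices = [
        int(key.split("_")[-1])
        for key in variables["params"]
        if key.startswith("layers_")
    ]
    n_layers = max(indices) + 1 if indices else 0
    # Slicing a range resolves missing or negative bounds like a list slice.
    kept = range(n_layers)[start:end]

    sliced = {}
    for collection, entries in variables.items():
        # Layers without an entry in this collection are simply skipped.
        sliced[collection] = {
            f"layers_{new}": entries[f"layers_{old}"]
            for new, old in enumerate(kept)
            if f"layers_{old}" in entries
        }
    return sliced


# Toy stand-in for a variables dict; the strings replace real arrays.
toy = {
    "params": {"layers_0": "stem", "layers_1": "block1", "layers_3": "head"},
    "batch_stats": {"layers_0": "stem_bn", "layers_1": "block1_bn"},
}
backbone = slice_layer_entries(toy, 0, 2)
tail = slice_layer_entries(toy, 1)
```

Renumbering from zero matters because a freshly built Sequential over a sliced layer list would look up its first layer under index 0, whatever position that layer had in the full model.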

Checkpoint Accuracies

The top-1 and top-5 accuracies reported below are on the ImageNet2012
validation split. The data was preprocessed as in the official PyTorch example.

Model        Size  Top 1   Top 5
ResNet       18    69.75%  89.06%
ResNet       34    73.29%  91.42%
ResNet       50    76.13%  92.86%
ResNet       101   77.37%  93.53%
ResNet       152   78.30%  94.04%
Wide ResNet  50    78.48%  94.08%
Wide ResNet  101   78.88%  94.29%
ResNeXt      50    77.60%  93.70%
ResNeXt      101   79.30%  94.51%
ResNet-D     50    77.57%  93.85%
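
For readers unfamiliar with the preprocessing used for the table above, the sketch below shows the usual validation pipeline of the PyTorch ImageNet example as I understand it: resize the short side to 256, center-crop to 224, scale to [0, 1], and normalize with the standard ImageNet channel statistics. The resize step is omitted because it needs an image library, the function name is hypothetical, and the output is kept channels-last to match the (N, 224, 224, 3) inputs used in the usage example. Treat it as an illustration rather than the exact evaluation code behind the reported numbers.

```python
import numpy as np

# Standard ImageNet channel statistics (RGB order), applied after scaling to [0, 1].
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def center_crop_and_normalize(image_uint8, crop=224):
    """Hypothetical helper for an (H, W, 3) uint8 image that was already resized."""
    height, width, _ = image_uint8.shape
    top = (height - crop) // 2
    left = (width - crop) // 2
    patch = image_uint8[top:top + crop, left:left + crop]
    scaled = patch.astype(np.float32) / 255.0
    # Broadcasting applies the per-channel statistics along the last axis.
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD


# A dummy all-white image standing in for a resized validation image.
image = np.full((256, 256, 3), 255, dtype=np.uint8)
x = center_crop_and_normalize(image)
batch = x[None]  # add a leading batch axis: (1, 224, 224, 3)
```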

The ResNeSt validation data was preprocessed as in zhang1989/ResNeSt.

Model         Size  Crop Size  Top 1   Top 5
ResNeSt-Fast  50    224        80.53%  95.34%
ResNeSt       50    224        81.05%  95.42%
ResNeSt       101   256        82.82%  96.32%
ResNeSt       200   320        83.84%  96.86%
ResNeSt       269   416        84.53%  96.98%
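
As a reminder of what the Top 1 and Top 5 columns above measure, here is a small NumPy sketch of top-k accuracy: a prediction counts as correct when the true label is among the k highest-scoring classes. The function name and the toy logits are illustrative assumptions, not part of this package.

```python
import numpy as np


def top_k_accuracy(logits, labels, k=1):
    """Fraction of rows whose true label is among the k largest logits."""
    # argsort is ascending, so the last k columns hold the top-k class indices.
    top_k = np.argsort(logits, axis=-1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=-1)
    return float(hits.mean())


# Toy batch: 4 examples, 3 classes.
logits = np.array([
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])
labels = np.array([1, 1, 0, 2])

top1 = top_k_accuracy(logits, labels, k=1)
top2 = top_k_accuracy(logits, labels, k=2)
```

With model outputs of shape (N, 1000), calling the same function with k=1 and k=5 over the whole validation split would give the two metrics.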

References