Project author: rishikksh20

Project description:
Compact Convolution Transformers
Language: Python
Repository: git://github.com/rishikksh20/compact-convolution-transformer.git
Created: 2021-04-14T14:13:02Z
Project page: https://github.com/rishikksh20/compact-convolution-transformer

License: MIT License

Compact Convolution Transformers

This repo contains a PyTorch implementation of Compact Convolutional Transformers (CCT), as described in the paper "Escaping the Big Data Paradigm with Compact Transformers". For the official implementation of this paper, see the authors' own repository.
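According to the paper, one of the main changes relative to ViT is how the classification vector is produced: instead of a learnable class token, the model uses sequence pooling (SeqPool), an attention-weighted average over the transformer's output tokens. Below is a minimal PyTorch sketch of that idea, assuming the formulation z = softmax(g(x_L))^T x_L with g a single linear layer mapping each token to one score. The class name `SeqPool` is hypothetical and is not necessarily how `cct.py` in this repo implements it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SeqPool(nn.Module):
    """Sequence pooling in the spirit of the CCT paper (illustrative sketch).

    A linear layer scores every output token, the scores are softmax-normalised
    over the sequence, and the tokens are averaged with those weights, so no
    extra class token is needed.
    """

    def __init__(self, dim: int):
        super().__init__()
        # g(.) in the formula: maps each D-dim token to a single score
        self.attention = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D] token sequence from the last transformer layer
        weights = F.softmax(self.attention(x), dim=1)  # [B, N, 1], sums to 1 over N
        # weighted sum over tokens: [B, 1, N] @ [B, N, D] -> [B, 1, D]
        pooled = torch.matmul(weights.transpose(1, 2), x)
        return pooled.squeeze(1)  # [B, D]


if __name__ == "__main__":
    tokens = torch.randn(2, 196, 256)  # e.g. 196 tokens of width 256
    pool = SeqPool(256)
    head = nn.Linear(256, 1000)  # classifier applied to the pooled vector
    print(head(pool(tokens)).shape)  # torch.Size([2, 1000])
```

Because the softmax weights sum to 1 over the sequence, the pooled vector is a convex combination of the tokens; the classifier head then sees a single vector per image, exactly as it would with a class token.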

Usage:

  import torch
  from cct import CompactTransformer

  img = torch.ones([1, 3, 224, 224])  # dummy input batch: [B, C, H, W]

  # For CVT (Compact Vision Transformer)
  cvt = CompactTransformer(224, 16, 1000)
  n_params = sum(p.numel() for p in cvt.parameters() if p.requires_grad) / 1_000_000
  print('Trainable Parameters in CVT: %.3fM' % n_params)
  out = cvt(img)
  print("Shape of out :", out.shape)  # [B, num_classes]

  # For CCT (Compact Convolutional Transformer): same call plus conv_embed=True
  cct = CompactTransformer(224, 16, 1000, conv_embed=True)
  n_params = sum(p.numel() for p in cct.parameters() if p.requires_grad) / 1_000_000
  print('Trainable Parameters in CCT: %.3fM' % n_params)
  out = cct(img)
  print("Shape of out :", out.shape)  # [B, num_classes]
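In the usage example, CVT and CCT differ only by `conv_embed=True`, which presumably switches how the image is turned into a token sequence. As an illustration (not this repo's code), the sketch below contrasts a ViT/CVT-style patch embedding with a CCT-style convolutional tokenizer built from conv, ReLU and max-pool blocks, as described in the paper. `PatchEmbed`, `ConvTokenizer`, the embedding size 256, the two tokenizer layers and the 3x3 kernels are hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """ViT/CVT-style tokenizer: non-overlapping patches via a strided conv (sketch)."""

    def __init__(self, in_ch=3, dim=256, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                     # [B, D, H/p, W/p]
        return x.flatten(2).transpose(1, 2)  # [B, N, D], N = (H/p) * (W/p)


class ConvTokenizer(nn.Module):
    """CCT-style tokenizer: n_layers of conv -> ReLU -> max-pool (hypothetical config)."""

    def __init__(self, in_ch=3, dim=256, n_layers=2, kernel_size=3):
        super().__init__()
        layers = []
        for i in range(n_layers):
            layers += [
                # 'same' padding keeps H and W unchanged through the conv
                nn.Conv2d(in_ch if i == 0 else dim, dim, kernel_size,
                          stride=1, padding=kernel_size // 2, bias=False),
                nn.ReLU(inplace=True),
                # overlapping pooling that halves H and W
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ]
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        x = self.blocks(x)                   # [B, D, H / 2^n, W / 2^n]
        return x.flatten(2).transpose(1, 2)  # [B, N, D]


if __name__ == "__main__":
    img = torch.ones([1, 3, 224, 224])
    print(PatchEmbed()(img).shape)     # torch.Size([1, 196, 256])
    print(ConvTokenizer()(img).shape)  # torch.Size([1, 3136, 256])
```

At 224x224 the convolutional tokenizer in this configuration yields a much longer sequence (3136 tokens vs 196), since its length is set by the number of pooling stages rather than by a patch size; the number of tokenizer layers therefore has a direct effect on the cost of the transformer layers that follow.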

Citation

  @misc{hassani2021escaping,
      title={Escaping the Big Data Paradigm with Compact Transformers},
      author={Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
      year={2021},
      eprint={2104.05704},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
  }