Project author: shivamsaboo17

Project description:
Gold Loss Correction for training neural networks with labels corrupted by severe noise
Primary language: Jupyter Notebook
Repository URL: git://github.com/shivamsaboo17/GLC.git
Created: 2019-08-17T16:07:59Z
Project community: https://github.com/shivamsaboo17/GLC


GLC

Unofficial PyTorch implementation of "Using Trusted Data to Train Deep Networks on
Labels Corrupted by Severe Noise" (NIPS 2018).

Usage

(See example.ipynb for a walkthrough on MNIST)

    from torch.utils.data import DataLoader

    from datasets import GoldCorrectionDataset
    from glc import CorrectionGenerator, GoldCorrectionLossFunction

    # `trn_ds` (the training dataset) and `trainer` (an object holding the
    # model) are assumed to be defined by the caller.
    c_gen = CorrectionGenerator(simulate=True, dataset=trn_ds, randomization_strength=1.0)

    # Fetch both the clean (trusted) and corrupted (untrusted) datasets when in simulate mode
    trusted_dataset, untrusted_dataset = c_gen.fetch_datasets()

    # Step 1: train the model on untrusted_dataset (training loop not shown)

    # Generate the label correction matrix from the trained model
    label_correction_matrix = c_gen.generate_correction_matrix(trainer.model, 32)

    # Wrap the trusted and untrusted datasets together using the GoldCorrectionDataset class
    gold_ds = GoldCorrectionDataset(trusted_dataset, untrusted_dataset)
    gold_dl = DataLoader(gold_ds, batch_size=32, shuffle=True)

    # Modified loss function that applies the correction matrix
    gold_loss = GoldCorrectionLossFunction(label_correction_matrix)

    # Step 2: train the model on gold_dl with gold_loss until convergence (training loop not shown)
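The two stages above follow the procedure described in the GLC paper: the correction matrix is estimated by averaging, for each true class, the noisy-label model's predicted probabilities over the trusted examples of that class, and untrusted examples are then scored against the corrected distribution C^T softmax(logits). The following is a minimal, framework-free sketch of that idea. It is not the code in `glc.py`; the function names (`estimate_correction_matrix`, `corrected_nll`, `clean_nll`) and the uniform fallback for classes missing from the trusted set are assumptions made for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def estimate_correction_matrix(trusted_probs, trusted_labels, num_classes):
    """Estimate C, where C[i][j] is the mean probability that the model
    trained on noisy labels assigns to class j, averaged over trusted
    examples whose true label is i."""
    sums = [[0.0] * num_classes for _ in range(num_classes)]
    counts = [0] * num_classes
    for probs, label in zip(trusted_probs, trusted_labels):
        counts[label] += 1
        for j, p in enumerate(probs):
            sums[label][j] += p
    matrix = []
    for i in range(num_classes):
        if counts[i]:
            matrix.append([s / counts[i] for s in sums[i]])
        else:
            # Assumption: fall back to a uniform row if class i never
            # appears in the trusted set.
            matrix.append([1.0 / num_classes] * num_classes)
    return matrix


def corrected_nll(logits, noisy_label, correction):
    """Loss for an untrusted example: -log((C^T softmax(logits))[noisy_label])."""
    probs = softmax(logits)
    p_noisy = sum(probs[i] * correction[i][noisy_label] for i in range(len(probs)))
    return -math.log(max(p_noisy, 1e-12))


def clean_nll(logits, true_label):
    """Loss for a trusted example: ordinary negative log-likelihood."""
    return -math.log(max(softmax(logits)[true_label], 1e-12))
```

In a training loop, each batch would sum `clean_nll` over its trusted examples and `corrected_nll` over its untrusted ones. With an identity correction matrix the two losses coincide, which is a quick sanity check.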

Results

MNIST

Regular training on trusted data only (~5% of the entire dataset) -> 61.12% accuracy

Gold Loss Correction with 5% trusted data -> 95.45% accuracy (all samples in the untrusted data, 95% of the total, are corrupted by randomly assigning labels)
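To make the setup behind these numbers concrete, here is a hedged sketch of how such a trusted/untrusted split could be simulated: a small fraction of examples keeps its true labels, and every remaining example gets a label drawn at random. This is not the repository's `CorrectionGenerator`; the function `split_and_corrupt`, its parameters, and the choice of a uniform draw over all classes are assumptions for illustration.

```python
import random


def split_and_corrupt(labels, num_classes, trusted_fraction=0.05, seed=0):
    """Split example indices into a trusted set (true labels kept) and an
    untrusted set (labels replaced by uniformly random classes).

    Returns two lists of (index, label) pairs.
    """
    rng = random.Random(seed)
    indices = list(range(len(labels)))
    rng.shuffle(indices)
    n_trusted = int(round(trusted_fraction * len(labels)))
    trusted_idx = sorted(indices[:n_trusted])
    untrusted_idx = sorted(indices[n_trusted:])
    trusted = [(i, labels[i]) for i in trusted_idx]
    # Every untrusted label is redrawn, so it matches the true label only
    # by chance (about 1 / num_classes of the time).
    untrusted = [(i, rng.randrange(num_classes)) for i in untrusted_idx]
    return trusted, untrusted


# Example: 1000 examples over 10 classes, 5% trusted.
true_labels = [i % 10 for i in range(1000)]
trusted, untrusted = split_and_corrupt(true_labels, num_classes=10)
```

Because the random labels carry almost no information about the true class, a model trained only on the untrusted set has little signal to learn from, which is why the correction step depends on the small trusted set.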