Project author: JieyuZ2

Project description: WRENCH: Weak supeRvision bENCHmark
Language: Python
Repository: git://github.com/JieyuZ2/wrench.git
Created: 2021-08-23T06:18:23Z
Project community: https://github.com/JieyuZ2/wrench

License: Apache License 2.0

🔧 New

1/25/23

  1. Added the Hyper label model; please find more details in our paper.

4/20/22

  1. Added the WS explainer; please find more details in our paper.

4/20/22

  1. We have updated setup.py to make installation more flexible.

Please use pip install ws-benchmark==1.1.2rc0 to install the latest version. We strongly suggest creating a new environment to install Wrench. We will bring better compatibility in the next stable release.
If you have any problems with installation, please let us know.

Known incompatibilities:

tensorflow==2.8.0, albumentations==0.1.12

3/18/22

  1. Wrench is now available on PyPI as ws-benchmark; run pip install ws-benchmark for a quick install.

2/13/22

  1. Added a script to generate LFs for any tabular dataset, as well as 5 new tabular datasets: mushroom, spambase, PhishingWebsites, Bioresponse, and bank-marketing.

11/04/21

  1. (beta) Added parallel_fit for torch models to support PyTorch DistributedDataParallel; see the example.

10/15/21

  1. A bunch of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
  2. Support for image classification (dataset class / torchvision backbone) as well as the DomainNet/Animals-with-Attributes2 datasets (check out the datasets folder)

🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.
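
To give a taste of the development side, here is a minimal sketch of a custom label model. It assumes Wrench label models follow a BaseLabelModel-style interface with fit/predict_proba and that dataset objects expose weak_labels and n_class; the import path and method signatures are assumptions, not taken verbatim from the docs.

```python
import numpy as np

# Assumed base class and import path; Wrench's actual interface may differ.
from wrench.basemodel import BaseLabelModel


class SimpleVoting(BaseLabelModel):
    """Toy label model: uniform vote over non-abstaining LFs."""

    def fit(self, dataset_train=None, dataset_valid=None, **kwargs):
        # Nothing to learn for uniform voting; just record the class count.
        self.n_class = dataset_train.n_class

    def predict_proba(self, dataset, **kwargs):
        L = np.asarray(dataset.weak_labels)  # (n_examples, n_lfs), -1 = abstain
        probas = np.zeros((len(L), self.n_class))
        for c in range(self.n_class):
            probas[:, c] = (L == c).sum(axis=1)
        # Normalize votes; fall back to uniform when every LF abstains.
        row_sums = probas.sum(axis=1, keepdims=True)
        return np.where(row_sums > 0, probas / np.maximum(row_sums, 1), 1.0 / self.n_class)
```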

For more information, check out our publications:

If you find this repository helpful, feel free to cite our publication:

```
@inproceedings{zhang2021wrench,
  title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
  author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
  booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  year={2021},
  url={https://openreview.net/forum?id=Q9SKS5k8io}
}
```

🔧 What is weak supervision?

Weak Supervision is a paradigm for automated training data creation without manual annotations.
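
To make that concrete: a labeling function (LF) is a small heuristic that votes for a label or abstains. The toy spam-detection LF below is purely illustrative; Wrench datasets already ship the outputs of such LFs as a weak label matrix.

```python
# Toy labeling function: a noisy heuristic that votes a label or abstains (-1).
ABSTAIN, HAM, SPAM = -1, 0, 1

def lf_check_out(text: str) -> int:
    # Spam comments often ask readers to "check out" another channel.
    return SPAM if "check out" in text.lower() else ABSTAIN

assert lf_check_out("Check out my channel!") == SPAM
assert lf_check_out("Great song.") == ABSTAIN
```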

For a brief overview, please check out this blog.

For more context, please check out this survey.

To track recent advances in weak supervision, please follow this repo.

🔧 Installation

[1] Install Anaconda:
Instructions here: https://www.anaconda.com/download/

[2] Clone the repository:

```bash
git clone https://github.com/JieyuZ2/wrench.git
cd wrench
```

[3] Create virtual environment:

```bash
conda env create -f environment.yml
source activate wrench
```

If this doesn't work, or you only want to use a subset of Wrench's modules, check out this wiki page.

[4] Download datasets:

```python
from huggingface_hub import snapshot_download

path = "path to local dir"
snapshot_download(repo_id="jieyuz2/WRENCH", repo_type="dataset", local_dir=path)
```
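
Once downloaded, that directory can be passed as the dataset home to load_dataset, which is used throughout the examples below (dataset names follow the tables in the next section):

```python
from wrench.dataset import load_dataset

# 'path' is the local_dir passed to snapshot_download above.
train_data, valid_data, test_data = load_dataset(path, 'youtube', extract_feature=False)
```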

🔧 Available Datasets

Note that some datasets may have more training examples than reported in the README/paper because we include the dev set, whose indices can be found in labeled_id.json if it exists.
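
For instance, to check which training examples came from the dev set, one could read labeled_id.json directly; the sketch below assumes the file holds a JSON list of train-split indices (see the wiki page below for the exact format):

```python
import json
import os

dataset_path = '../datasets/youtube'  # hypothetical local dataset directory
id_file = os.path.join(dataset_path, 'labeled_id.json')
if os.path.exists(id_file):
    with open(id_file) as f:
        dev_ids = json.load(f)  # assumed schema: a list of indices into the train split
    print(f'{len(dev_ids)} training examples originate from the dev set')
```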

Documentation of the dataset format and usage can be found in this wiki page.

classification:

| Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
|------|------|---------|------|---------|--------------|--------|-------------|-----------|
| Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
| Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
| SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
| IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
| Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
| AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
| TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
| Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
| SemEval | relation classification | 9 | 164 | 1749 | 178 | 600 | link | link |
| CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
| Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
| Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
| Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
| Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |
| DomainNet | image classification | - | - | - | - | - | link | link |

sequence tagging:

| Name | # class | # LF | # train | # validation | # test | data source | LF source |
|------|---------|------|---------|--------------|--------|-------------|-----------|
| CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
| WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
| OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
| BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
| NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
| Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
| MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
| MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |

Detailed documentation is coming soon.

🔧 Available Models

If you find any of the implementations wrong or problematic, don't hesitate to raise an issue or open a pull request; we really appreciate it!

TODO-list: check this out!

classification:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| Majority Voting | Label Model | - | link |
| Weighted Majority Voting | Label Model | - | link |
| Dawid-Skene | Label Model | link | link |
| Data Programming | Label Model | link | link |
| MeTaL | Label Model | link | link |
| FlyingSquid | Label Model | link | link |
| EBCC | Label Model | link | link |
| IBCC | Label Model | link | link |
| FABLE | Label Model | link | link |
| Hyper Label Model | Label Model | link | link |
| Logistic Regression | End Model | - | link |
| MLP | End Model | - | link |
| BERT | End Model | link | link |
| COSINE | End Model | link | link |
| ARS2 | End Model | link | link |
| Denoise | Joint Model | link | link |
| WeaSEL | Joint Model | link | link |
| SepLL | Joint Model | link | link |

sequence tagging:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| Hidden Markov Model | Label Model | link | link |
| Conditional Hidden Markov Model | Label Model | link | link |
| LSTM-CNNs-CRF | End Model | link | link |
| BERT-CRF | End Model | link | link |
| LSTM-ConNet | Joint Model | link | link |
| BERT-ConNet | Joint Model | link | link |

classification-to-sequence-tagging wrapper:

Wrench also provides a SeqLabelModelWrapper that adapts label models designed for classification tasks to sequence tagging tasks.
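
A rough usage sketch follows; the import path and constructor arguments are assumptions about the wrapper's interface, not taken from the source:

```python
# Hypothetical sketch: SeqLabelModelWrapper is named above, but the module
# path and arguments here are assumptions about its interface.
from wrench.labelmodel import MajorityVoting
from wrench.seq_labelmodel import SeqLabelModelWrapper  # assumed module path

# Wrap a classification label model so it can be fit on sequence-tagging data.
seq_label_model = SeqLabelModelWrapper(label_model_class=MajorityVoting)
seq_label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
probas = seq_label_model.predict_proba(test_data)
```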

Robust Learning methods as end model:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| Meta-Weight-Net | End Model | link | link |
| Learning2ReWeight | End Model | link | link |

Semi-Supervised Learning methods as end model:

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| MeanTeacher | End Model | link | link |

Weak Supervision with cleaned labels (Semi-Weak Supervision):

| Model | Model Type | Reference | Link to Wrench |
|-------|------------|-----------|----------------|
| ImplyLoss | Joint Model | link | link |
| ASTRA | Joint Model | link | link |

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

```python
import logging
import numpy as np
import pprint
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    meter.update(target=metric_value)

metrics = meter.get_results()
pprint.pprint(metrics)
```

For detailed guidance on grid_search, please check out this wiki page.

🔧 Run a standard supervised learning pipeline

```python
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Train an MLP classifier
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target,
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)
```

🔧 Build a two-stage weak supervision pipeline

```python
import logging
import torch
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)

#### Train an MLP classifier with soft labels
device = torch.device('cuda:0')
n_steps = 100000
batch_size = 128
test_batch_size = 1000
patience = 200
evaluation_step = 50
target = 'acc'
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label,
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label,
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
```

🔧 Procedural labeling function generator

```python
import logging
from wrench.dataset import load_dataset
from wrench._logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75,        # mean accuracy
    beta=0.1,          # mean propensity
    alpha_radius=0.2,  # radius of accuracy
    beta_radius=0.1    # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load real-world dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)

#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
```

🔧 Contact

Contact person: Jieyu Zhang, jieyuzhang97@gmail.com

Don't hesitate to send us an e-mail if you have any questions.

We’re also open to any collaboration!

🔧 Contributing Dataset and Model

We sincerely welcome any contribution to the datasets or models!

🔧 Citation

```
@inproceedings{zhang2021wrench,
  title={{WRENCH}: A Comprehensive Benchmark for Weak Supervision},
  author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
  booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  year={2021},
  url={https://openreview.net/forum?id=Q9SKS5k8io}
}
```