Project author: MaximLippeveld

Project description:
Repository for the riverreliability Python package. This package provides the river reliability diagram and PEACE metric, as well as the confidence-reliability diagram and ECE metric.
Primary language: Jupyter Notebook
Repository: git://github.com/MaximLippeveld/riverreliability.git
Created: 2020-08-20T15:08:41Z
Project page: https://github.com/MaximLippeveld/riverreliability

License: Apache License 2.0



River reliability

Install

Install the package with:

    pip install riverreliability

How to use

Below, we show some basic functionality of the package. See the notebooks for more examples and documentation.

    import numpy as np
    np.random.seed(42)  # fix the global NumPy seed so the example is reproducible

We start off by generating a synthetic classification dataset and splitting it into a training set and a test set.

    import sklearn.datasets
    import sklearn.model_selection
    X, y = sklearn.datasets.make_classification(n_samples=5000, n_features=12, n_informative=3, n_classes=3)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y, test_size=0.2, shuffle=True)
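For reference, `train_test_split(..., test_size=0.2, shuffle=True)` shuffles the rows and holds out 20% of them as the test set, here 1000 of the 5000 samples. The NumPy sketch below makes that explicit with a hypothetical helper `shuffled_split`; it is not scikit-learn's implementation and will not reproduce its exact split. It keeps the rows of `X` and `y` aligned by indexing both with the same permutation.

```python
import math

import numpy as np


def shuffled_split(X, y, test_size=0.2, seed=0):
    """Hypothetical helper: shuffle the rows, then hold out a `test_size` fraction."""
    rng = np.random.default_rng(seed)
    n_test = math.ceil(test_size * len(X))  # round the test set size up
    order = rng.permutation(len(X))  # random order of the row indices
    test_idx, train_idx = order[:n_test], order[n_test:]
    # Index X and y with the same indices so features and labels stay paired.
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


# Toy data shaped like the example: 5000 samples, 12 features, 3 classes.
X_demo = np.arange(5000 * 12, dtype=float).reshape(5000, 12)
y_demo = np.arange(5000) % 3
X_tr, X_te, y_tr, y_te = shuffled_split(X_demo, y_demo)
print(X_tr.shape, X_te.shape)  # (4000, 12) (1000, 12)
```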

For this example, we use a support vector machine (SVM). We fit it on the training data and predict class probabilities for the test set.

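    import sklearn.svm
    # The model is an SVM, not a logistic regression; probability=True enables predict_proba
    clf = sklearn.svm.SVC(probability=True)
    clf.fit(X_train, y_train)
    y_probs = clf.predict_proba(X_test)  # shape (n_test_samples, n_classes)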
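`predict_proba` returns one row per test sample and one column per class, and each row sums to one. The plotting and metric calls further down take three arrays derived from it: the highest probability in each row (the confidence), the column where it occurs (the predicted class), and the true label. The NumPy sketch below shows that derivation on a made-up probability matrix; `y_probs_demo` and `y_true_demo` are invented for illustration. For `SVC`, the probabilities come from an internal Platt-scaling fit, and scikit-learn's documentation warns that they can be inconsistent with `predict`; this README derives predicted classes from the probabilities throughout.

```python
import numpy as np

# Made-up predict_proba output for four samples and three classes.
y_probs_demo = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.6, 0.3, 0.1],
])
y_true_demo = np.array([0, 1, 2, 1])

assert np.allclose(y_probs_demo.sum(axis=1), 1.0)  # each row is a distribution

confidences = y_probs_demo.max(axis=1)  # probability of the predicted class
predictions = y_probs_demo.argmax(axis=1)  # index of the predicted class
correct = predictions == y_true_demo  # True where the prediction is right

print(confidences)  # [0.7 0.8 0.6 0.6]
print(predictions)  # [0 1 2 0]
print(correct.mean())  # 0.75
```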

As a sanity check, we compute the accuracy and balanced accuracy on the test set.

    import sklearn.metrics
    print(f"Accuracy: {sklearn.metrics.accuracy_score(y_test, y_probs.argmax(axis=1))}")
    print(f"Balanced accuracy: {sklearn.metrics.balanced_accuracy_score(y_test, y_probs.argmax(axis=1))}")

Output:

    Accuracy: 0.808
    Balanced accuracy: 0.8084048918146675
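Balanced accuracy is the unweighted mean of the per-class recalls, so a classifier cannot score well on it just by favouring the most frequent class. `make_classification` produces roughly equal class sizes by default, which is consistent with the two numbers above being almost identical. The NumPy sketch below spells out the computation on an imbalanced toy example using a hypothetical helper `balanced_accuracy`; for this input it should agree with `sklearn.metrics.balanced_accuracy_score` under its default `adjusted=False`.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Hypothetical helper: average of the per-class recalls over the classes in y_true."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Recall of class c: fraction of samples whose true label is c that are predicted as c.
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


# Imbalanced toy labels: six samples of class 0, two each of classes 1 and 2.
y_true_toy = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
y_pred_toy = [0, 0, 0, 0, 0, 0, 0, 1, 0, 2]
print(np.mean(np.equal(y_true_toy, y_pred_toy)))  # plain accuracy: 0.8
print(balanced_accuracy(y_true_toy, y_pred_toy))  # (1 + 1/2 + 1/2) / 3 ≈ 0.667
```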

To get insight into calibration, we can look at the posterior reliability diagrams and the PEACE metric.
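The PEACE metric is defined in the package and its notebooks, and we do not restate it here. As a simpler point of reference, the package also provides the classic expected calibration error (ECE), which can be computed from the same three inputs the calls below use (top-class confidence, predicted class, true class): it groups predictions into confidence bins and averages the gap between accuracy and mean confidence in each bin, weighted by the fraction of samples in that bin. The NumPy sketch below uses a hypothetical helper `binned_ece`, not the package's own implementation. Its `bins` argument is passed to `np.histogram_bin_edges`, so `bins="fd"` selects NumPy's Freedman-Diaconis rule; we assume, without having checked, that `bins="fd"` in the plotting call means the same.

```python
import numpy as np


def binned_ece(confidences, predictions, labels, bins=10):
    """Hypothetical helper: classic binned expected calibration error (ECE).

    ECE = sum over bins b of (n_b / N) * |accuracy(b) - mean confidence(b)|.
    `bins` is passed to np.histogram_bin_edges over [0, 1], so it can be an
    int (equal-width bins) or a rule name such as "fd" (Freedman-Diaconis).
    """
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(predictions) == np.asarray(labels)
    edges = np.histogram_bin_edges(confidences, bins=bins, range=(0.0, 1.0))
    n_bins = len(edges) - 1
    # Right-closed bins: index i covers (edges[i], edges[i + 1]]; 0.0 goes to bin 0.
    bin_ids = np.clip(np.searchsorted(edges, confidences) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in np.unique(bin_ids):  # empty bins contribute nothing
        in_bin = bin_ids == b
        gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
        ece += in_bin.mean() * gap  # weight = fraction of samples in bin b
    return float(ece)


# Toy example with two equal-width bins, (0, 0.5] and (0.5, 1].
print(binned_ece([0.4, 0.45, 0.7, 0.9], [0, 1, 2, 0], [0, 2, 2, 0], bins=2))  # ≈ 0.1375

# Simulated predictions that are calibrated by construction, binned with the "fd" rule.
rng = np.random.default_rng(0)
conf_sim = rng.uniform(1 / 3, 1.0, size=10_000)
hit = rng.uniform(size=conf_sim.size) < conf_sim  # correct with probability = confidence
# Predictions equal the (all-ones) labels exactly where hit is True.
ece_sim = binned_ece(conf_sim, hit.astype(int), np.ones(conf_sim.size, dtype=int), bins="fd")
print(ece_sim)  # small, since the simulated confidences are calibrated
```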

We can plot the diagrams aggregated over all classes:

    import riverreliability.metrics
    import riverreliability.plots
    # Arguments: top-class confidence, predicted class, true class
    ax = riverreliability.plots.river_reliability_diagram(y_probs.max(axis=1), y_probs.argmax(axis=1), y_test, bins="fd")
    peace_metric = riverreliability.metrics.peace(y_probs.max(axis=1), y_probs.argmax(axis=1), y_test)
    ax.set_title(f"PEACE: {peace_metric:.4f}")
    _ = ax.legend()

[Figure: river reliability diagram aggregated over all classes, titled with the PEACE score]

Or class-wise, to spot miscalibration of particular classes:

    import matplotlib.pyplot as plt
    # One diagram per class; base_error selects PEACE as the per-class metric
    axes = riverreliability.plots.class_wise_river_reliability_diagram(y_probs, y_probs.argmax(axis=1), y_test, bins=15)
    peace_metric = riverreliability.metrics.class_wise_error(y_probs, y_probs.argmax(axis=1), y_test, base_error=riverreliability.metrics.peace)
    _ = plt.suptitle(f"PEACE: {peace_metric:.4f}")

[Figure: class-wise river reliability diagrams, one per class, titled with the class-wise PEACE score]

In this example, the diagrams show that the classifier is well calibrated.
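For the class-wise view, each column of `y_probs` is checked on its own: column k is treated as the predicted probability of the event "the true label is k". One common summary, often called class-wise ECE, computes a binned calibration error per class and averages over the classes. The NumPy sketch below implements that idea with a hypothetical helper `classwise_ece` and equal-width bins. It uses the classic binned error rather than PEACE, and we have not checked how `riverreliability.metrics.class_wise_error` combines the per-class values or uses the predicted labels it is given.

```python
import numpy as np


def classwise_ece(probs, labels, bins=10):
    """Hypothetical helper: mean over classes of a one-vs-rest binned calibration error."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    edges = np.linspace(0.0, 1.0, bins + 1)  # equal-width bins on [0, 1]
    per_class = []
    for k in range(probs.shape[1]):
        conf = probs[:, k]  # predicted probability of class k
        event = (labels == k).astype(float)  # 1.0 where the true label is k
        # Right-closed bins: index i covers (edges[i], edges[i + 1]]; 0.0 goes to bin 0.
        bin_ids = np.clip(np.searchsorted(edges, conf) - 1, 0, bins - 1)
        error = 0.0
        for b in np.unique(bin_ids):  # empty bins contribute nothing
            in_bin = bin_ids == b
            gap = abs(event[in_bin].mean() - conf[in_bin].mean())
            error += in_bin.mean() * gap  # weight = fraction of samples in bin b
        per_class.append(error)
    return float(np.mean(per_class))


probs_toy = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.6, 0.3, 0.1],
]
labels_toy = [0, 1, 2, 0]
print(classwise_ece(probs_toy, labels_toy, bins=2))  # (0.25 + 0.225 + 0.175) / 3 ≈ 0.2167
```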

See the notebooks directory for more examples.