Project author: harsh306

Project description:
Continuation methods of Deep Neural Networks (topics: optimization, deep-learning, homotopy, bifurcation-analysis, continuation)
Language: Python
Repository: git://github.com/harsh306/continuation-jax.git
Created: 2021-02-04T00:30:57Z
Project page: https://github.com/harsh306/continuation-jax

License: MIT License



continuation-jax: Continuation Framework for lambda

Continuation methods of Deep Neural Networks
Tags: optimization, deep-learning, homotopy, bifurcation-analysis, continuation

Code style: black

Install using pip:

Package: https://pypi.org/project/continuation-jax/

    pip install continuation-jax

Import

    import cjax

Math operations on Pytrees

    >>> import cjax
    >>> from cjax.utils import math_trees
    >>> math_trees.pytree_element_mul([2,3,5], 2)
    [4, 6, 10]
    >>> math_trees.pytree_sub([2,3,5], [1,1,1])
    [DeviceArray(1, dtype=int32), DeviceArray(2, dtype=int32), DeviceArray(4, dtype=int32)]
    >>> math_trees.pytree_zeros_like({'a':12, 'b':45, 'c':[1,1]})
    {'a': 0, 'b': 0, 'c': [0, 0]}
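The `math_trees` helpers apply an arithmetic operation leaf by leaf over a pytree, i.e. a nested structure of lists, tuples and dicts. Below is a minimal, dependency-free sketch of that idea. The names `tree_map`, `tree_scale`, `tree_sub` and `tree_zeros_like` are hypothetical and chosen for illustration; this is not cjax's implementation. In JAX code the same pattern is normally written with `jax.tree_util.tree_map`, which also supports registered custom node types.

```python
def tree_map(fn, *trees):
    """Apply fn leaf-wise across trees that share the same structure.

    Hypothetical helper: only dicts, lists and plain tuples are treated as
    containers; everything else is a leaf.
    """
    first = trees[0]
    if isinstance(first, dict):
        # Recurse into each key, taking the matching subtree from every tree.
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        # Rebuild the same container type from the mapped children.
        return type(first)(tree_map(fn, *children) for children in zip(*trees))
    return fn(*trees)


def tree_scale(tree, scalar):
    """Multiply every leaf by a scalar (cf. pytree_element_mul)."""
    return tree_map(lambda leaf: leaf * scalar, tree)


def tree_sub(a, b):
    """Subtract two trees leaf by leaf (cf. pytree_sub)."""
    return tree_map(lambda x, y: x - y, a, b)


def tree_zeros_like(tree):
    """Return a tree of the same shape with every leaf set to zero."""
    return tree_map(lambda leaf: 0 * leaf, tree)


print(tree_scale([2, 3, 5], 2))                           # [4, 6, 10]
print(tree_sub([2, 3, 5], [1, 1, 1]))                     # [1, 2, 4]
print(tree_zeros_like({"a": 12, "b": 45, "c": [1, 1]}))   # {'a': 0, 'b': 0, 'c': [0, 0]}
```

Unlike the cjax output above, this plain-Python version returns ordinary ints rather than JAX arrays.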

Examples:

    """
    Main file to run continuation on a user-defined problem. Examples can be found in the examples/ directory.
    Continuation is a topological procedure for training a neural network. This module tracks all
    the critical points or fixed points and dumps them to the output location set in hparams.json.

    Typical usage example:

        continuation = ContinuationCreator(
            problem=problem, hparams=hparams
        ).get_continuation_method()
        continuation.run()
    """
    import json
    from datetime import datetime

    import mlflow
    from jax import config  # works on both older and current JAX releases

    from cjax.continuation.creator.continuation_creator import ContinuationCreator
    from cjax.utils.abstract_problem import ProblemWraper
    from cjax.utils.visualizer import pick_array, bif_plot
    from examples.model_simple_classifier.model_classifier import ModelContClassifier

    # Raise an error as soon as a JAX computation produces a NaN.
    config.update("jax_debug_nans", True)

    # TODO: use **kwargs to reduce params
    if __name__ == "__main__":
        # Wrap the user-defined model so the continuation driver can use it.
        problem = ModelContClassifier()
        problem = ProblemWraper(problem)

        with open(problem.HPARAMS_PATH, "r") as hfile:
            hparams = json.load(hfile)

        mlflow.set_tracking_uri(hparams["meta"]["mlflow_uri"])
        mlflow.set_experiment(hparams["meta"]["name"])
        run_name = hparams["meta"]["method"] + "-" + hparams["meta"]["optimizer"]

        with mlflow.start_run(run_name=run_name) as run:
            mlflow.log_dict(hparams, artifact_file="hparams/hparams.json")
            # Create the output/ artifact folder so its URI can be used as output_dir.
            mlflow.log_text("", artifact_file="output/_touch.txt")
            artifact_uri = mlflow.get_artifact_uri("output/")
            hparams["meta"]["output_dir"] = artifact_uri
            print(f"URI: {artifact_uri}")
            start_time = datetime.now()

            if hparams["n_perturbs"] > 1:
                # One continuation run per perturbation, passing the index as `key`.
                for perturb in range(hparams["n_perturbs"]):
                    print(f"Running perturb {perturb}")
                    continuation = ContinuationCreator(
                        problem=problem, hparams=hparams, key=perturb
                    ).get_continuation_method()
                    continuation.run()
            else:
                continuation = ContinuationCreator(
                    problem=problem, hparams=hparams
                ).get_continuation_method()
                continuation.run()

            end_time = datetime.now()
            print(f"Duration: {end_time - start_time}")

            # Draw a bifurcation plot from the files in output_dir and log it to MLflow.
            figure = bif_plot(hparams["meta"]["output_dir"], pick_array)
            mlflow.log_figure(figure, artifact_file="plots/fig.png")
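The example's docstring describes continuation as tracking a path of solutions while a parameter changes. As a rough illustration only, here is natural-parameter continuation, one of the simplest such schemes, on the scalar toy equation F(x, lambda) = x^3 + x - lambda = 0: lambda is stepped from 0 to 2 and each Newton solve is warm-started from the previous solution. The function names and the toy equation are assumptions made for this sketch, not cjax's API. In the neural-network setting, F would instead be the gradient of a lambda-dependent loss with respect to the network parameters.

```python
def newton(f, dfdx, x0, lam, tol=1e-12, max_iter=50):
    """Solve f(x, lam) = 0 for x by Newton's method, starting from x0."""
    x = x0
    for _ in range(max_iter):
        step = f(x, lam) / dfdx(x, lam)
        x -= step
        if abs(step) < tol:
            return x
    raise RuntimeError(f"Newton did not converge at lambda={lam}")


def natural_continuation(f, dfdx, x_start, lam_start, lam_end, n_steps):
    """Track the solution x(lam) on an evenly spaced grid of lam values."""
    path = []
    x = x_start
    for i in range(n_steps + 1):
        lam = lam_start + (lam_end - lam_start) * i / n_steps
        # Warm start: the previous solution is the initial guess for the next lam.
        x = newton(f, dfdx, x, lam)
        path.append((lam, x))
    return path


# Hypothetical toy problem: F(x, lam) = x^3 + x - lam has one root for each lam.
def f(x, lam):
    return x**3 + x - lam


def dfdx(x, lam):
    return 3 * x**2 + 1


path = natural_continuation(f, dfdx, x_start=0.0, lam_start=0.0, lam_end=2.0, n_steps=8)
for lam, x in path:
    print(f"lambda={lam:.2f}  x={x:.6f}")
```

Natural-parameter continuation breaks down at fold (turning) points, where dF/dx vanishes; pseudo-arclength continuation, which parametrizes the path by arclength instead of by lambda, is the usual remedy.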

Note on Hyperparameters
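The example script reads only a few keys from hparams.json: `meta.mlflow_uri`, `meta.name`, `meta.method`, `meta.optimizer` and `n_perturbs` (it writes `meta.output_dir` itself). Below is a minimal sketch, with hypothetical placeholder values, of those keys and of a check that they are present before launching a run. The helper names are hypothetical, and the real configuration files in the examples/ directory will contain further keys for `ContinuationCreator` that are not shown here.

```python
import json

# Keys under "meta" that the example script reads.
REQUIRED_META_KEYS = ("mlflow_uri", "name", "method", "optimizer")


def missing_keys(hparams: dict) -> list:
    """Return dotted names of keys the example script reads but hparams lacks."""
    meta = hparams.get("meta", {})
    missing = [f"meta.{k}" for k in REQUIRED_META_KEYS if k not in meta]
    if "n_perturbs" not in hparams:
        missing.append("n_perturbs")
    return missing


def run_name(hparams: dict) -> str:
    """Build the MLflow run name the same way the example script does."""
    return hparams["meta"]["method"] + "-" + hparams["meta"]["optimizer"]


# All values are hypothetical placeholders, not defaults from the repository.
example_hparams = {
    "meta": {
        "mlflow_uri": "./mlruns",      # hypothetical local tracking directory
        "name": "simple-classifier",   # hypothetical experiment name
        "method": "natural",           # hypothetical method label
        "optimizer": "adam",           # hypothetical optimizer label
    },
    "n_perturbs": 1,
}

print(json.dumps(example_hparams, indent=2))
print(missing_keys(example_hparams))  # []
print(run_name(example_hparams))      # natural-adam
```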

Papers:

Contact:

harshnpathak@gmail.com