Keras 3.0 deep learning API backs TensorFlow, PyTorch, Jax

Keras 3.0, a “full rewrite” of the Keras deep learning API, has arrived, providing a new multi-back-end implementation of the API.

Unveiled November 27 and available on GitHub, Keras 3.0 lets developers run Keras workflows on top of the Jax, TensorFlow, or PyTorch machine learning frameworks, with support for large-scale model training and deployment. Keras can also serve as a low-level, cross-framework language for developing custom components such as layers, models, or metrics that work in native workflows in all three frameworks from a single codebase.

Keras enables high-velocity development through a focus on UX, API design, and debugging, the Keras team said. They noted that Keras has been chosen by more than 2.5 million developers, and powers some of the most sophisticated, largest-scale machine learning systems in the world, such as the Waymo self-driving fleet and the YouTube recommendation engine.

Other benefits of Keras 3 the team cited include:

  • The ability to get the best performance from a model by dynamically selecting the back end that performs best for it, without requiring code changes.
  • Any Keras 3 model can be instantiated as a PyTorch module, exported as a TensorFlow SavedModel, or instantiated as a stateless Jax function.
  • The ability to leverage large-scale model parallelism and data parallelism with Jax.
  • Keras provides a full implementation of the NumPy API and a set of neural network-specific functions such as ops.softmax, ops.binary_crossentropy, and ops.conv.
