Introduction to JAX and Deep Learning

Discover the power of JAX in deep learning. Gain insights into its ecosystem and learn about linear algebra, pseudo-random number generation, and optimization algorithms for cleaner, structured coding.

Intermediate

53 Lessons

2h 30min

Certificate of Completion

AI-POWERED

Explanations

This course includes

1 Project
95 Playgrounds
14 Challenges
7 Quizzes

Course Overview

JAX is a Python library designed for high-performance ML research. It is a powerful numerical computing library, much like NumPy, but with some key improvements. In this course, you will learn all about JAX and its ecosystem of libraries (Haiku, Jraph, Chex, Flax, Optax). Aimed at a wide range of audiences, the course covers several topics, including linear algebra, random variable theory, pseudo-random number generation, and optimization algorithms. By the end of this course, you will have a new set of sk...
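As a rough illustration of what "like NumPy, but with some key improvements" means, here is a minimal sketch. It assumes a standard `jax` install; the function `loss` and the array values are hypothetical. It uses the NumPy-style `jax.numpy` API, then applies two of JAX's composable transformations, `jax.grad` and `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# jax.numpy mirrors much of the NumPy API.
x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

def loss(w):
    # Hypothetical scalar-valued function of a vector: sum of squares.
    return jnp.sum(w ** 2)

# jax.grad turns a scalar-valued function into a function that returns its gradient.
grad_loss = jax.grad(loss)

# jax.jit compiles the gradient function with XLA; results match the uncompiled version.
fast_grad_loss = jax.jit(grad_loss)

w = jnp.array([1.0, -2.0, 3.0])
print(grad_loss(w))       # the gradient of sum(w**2) is 2*w
print(fast_grad_loss(w))
```

The transformations compose freely, so `jax.jit(jax.grad(loss))` is an ordinary Python function that can itself be transformed again.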

TAKEAWAY SKILLS

Random Variables

Neural Networks

Functional Programming

Deep Learning Basics

What You'll Learn

Learn the basics of JAX

Learn how to apply automatic differentiation (autograd)

Use auto vectorization for batching

Use Haiku and Flax for implementing neural networks

Cover Optax and get an overview of common optimization algorithms in deep learning

Use Chex for testing JAX programs

Learn the basics of applied linear algebra

Learn random variable theory and probability distributions

Learn pseudo-random number generation

Cover the basics of optimal transport
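Two of the items above, auto-vectorization for batching and pseudo-random number generation, can be sketched together. This is a minimal example under stated assumptions: `predict` is a hypothetical one-example linear model, and the seed and shapes are arbitrary. JAX threads randomness through explicit keys (`jax.random.PRNGKey`, `jax.random.split`) rather than a global seed, and `jax.vmap` lifts a function written for one example to a whole batch.

```python
import jax
import jax.numpy as jnp

# JAX uses explicit PRNG keys instead of a hidden global random state.
key = jax.random.PRNGKey(0)
key_w, key_x = jax.random.split(key)  # derive two independent keys

# Hypothetical linear model: one weight vector, a batch of 8 inputs.
w = jax.random.normal(key_w, (3,))
xs = jax.random.normal(key_x, (8, 3))

def predict(w, x):
    # Written for ONE example: x has shape (3,).
    return jnp.dot(w, x)

# vmap maps predict over axis 0 of xs while sharing w (in_axes=None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batched_predict(w, xs)  # shape (8,)
```

Because keys are explicit, calling `jax.random.normal(key_w, (3,))` again returns exactly the same weights; the usual practice is to split a key whenever fresh randomness is needed.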

Course Content

1. Introduction

Get familiar with JAX, a powerful library for deep learning and numerical computing.

2. JAX Programming Model

Walk through JAX's programming model, including pure functions, JIT, jaxpr, and autodiff.

3. Linear Algebra

Explore the fundamental concepts of vectors, matrices, multivariate calculus, and convolutions in deep learning.

4. Random Variables and Distributions

Grasp the fundamentals of random variables, distributions, PRNGs, and divergence measures in JAX.

5. JAX Ecosystem

Take a closer look at the tools and libraries within the JAX ecosystem for deep learning.

Project: GAN Using the JAX Ecosystem

6. Appendix (6 Lessons)

Focus on installation steps, notable JAX libraries, models, vector calculus, common errors, and key terms.
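To give a flavor of the JAX programming model (pure functions, JIT, jaxpr, and autodiff), here is a minimal sketch; the function `f` is hypothetical. It prints the jaxpr that JAX traces from a pure function, then compiles the function with `jax.jit` and differentiates it with `jax.grad`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A pure function: the output depends only on the input, with no side effects.
    return jnp.sin(x) * 2.0 + x

# make_jaxpr traces f with an abstract input and returns its jaxpr,
# JAX's intermediate representation of the computation.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# jit compiles f with XLA; grad differentiates it. The transformations compose.
f_jit = jax.jit(f)
df = jax.grad(f)  # analytically, f'(x) = 2*cos(x) + 1
```

Purity matters because `jax.jit` traces the function once per input shape and dtype and then reuses the compiled result; a Python side effect such as a `print` inside `f` would run only during tracing, not on every call.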
