diff --git a/Dockerfile b/Dockerfile
index 36882deb1a06abaeba8f0d6178e396cce6a06a1e..9ea4d08df8a3e58ecdb9027c78c0a97a5c57c632 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -16,7 +16,7 @@ RUN mamba install -c conda-forge myst-nb
 RUN mamba install -c conda-forge ipympl
 RUN mamba install -c conda-forge ipywidgets
 
-RUN mamba install -c conda-forge jax
+RUN mamba install -c conda-forge "jaxlib=*=*cuda*" jax
 
 RUN mamba install -c conda-forge keras
 RUN mamba install -c conda-forge control