diff --git a/Dockerfile b/Dockerfile
index 67d3e0b88e4b0befe33d81c396a48bf940ec885c..7e2f7c1d17a8a45be9ece8300864fc6e3b120933 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,6 +19,7 @@ RUN mamba install -c conda-forge ipywidgets
 RUN mamba install -c conda-forge jax
 
 RUN mamba install -c conda-forge keras
+ENV KERAS_BACKEND=jax
 RUN mamba install -c conda-forge control
 RUN mamba install -c conda-forge casadi
 RUN mamba install -c conda-forge networkx