2023年2月10日現在、JaxとJaxlibのバージョンの整合性が取れていない状況になっています。
おそらく2023/2/3のアップデートでNvidiaのドライバーを更新したことが原因です。
libcudnnのアップデート & jaxのアップデートで解決します。
!apt install --allow-change-held-packages libcudnn8="8.4.0.27-1+cuda11.6"
!pip install --upgrade pip
!pip install -U "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html