“Jax”的版本间差异

来自Shiyin's note
跳到导航 跳到搜索
(创建页面,内容为“*简单的说就是GPU加速、支持自动微分(autodiff)的numpy。参考[https://jax.readthedocs.io/en/latest/notebooks/quickstart.html] ==安装== *jax和jaxli…”)
 
第6行: 第6行:
*安装最新版本 jax 0.2.26和jaxlib0.1.75后会在random函数报错“CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain”
*安装最新版本 jax 0.2.26和jaxlib0.1.75后会在random函数报错“CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain”
*最后安装的 jax0.2.1 (pip install -v jax=0.2.1), jaxlib是0.1.71 [https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.71+cuda111-cp38-none-manylinux2010_x86_64.whl]解决问题
*最后安装的 jax0.2.1 (pip install -v jax=0.2.1), jaxlib是0.1.71 [https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.71+cuda111-cp38-none-manylinux2010_x86_64.whl]解决问题
]

2021年12月21日 (二) 14:11的版本

  • 简单的说就是GPU加速、支持自动微分(autodiff)的numpy。参考[1]

安装

  • jax和jaxlib版本要匹配(注意cuda支持)
  • 我的CUDA版本是11.1 ,卡是A40,系统是Ubuntu,cudnn版本是805
  • 安装最新版本 jax 0.2.26和jaxlib0.1.75后会在random函数报错“CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain”
  • 最后安装的 jax0.2.1 (pip install -v jax=0.2.1), jaxlib是0.1.71 [2]解决问题