# Base image: NVIDIA CUDA 11.7.1 with cuDNN 8, "devel" flavour, on Ubuntu 20.04
# (as encoded in the tag below).
FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04
# Commented-out alternative base tag (CUDA 11.6.0):
# nvidia/cuda:11.6.0-cudnn8-devel-ubuntu20.04

# OS-level tools. wget is required by the Miniconda download step below.
# apt-get is used instead of apt because apt's command-line interface is not
# guaranteed stable for scripted use. The package lists are deleted in the same
# layer that created them so they do not add to the image size.
# DEBIAN_FRONTEND is set inline so it does not leak into the runtime environment.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
        git \
        pandoc \
        vim \
        wget \
    && rm -rf /var/lib/apt/lists/*

# File name of the Miniconda installer under https://repo.anaconda.com/miniconda/.
# The default tracks the moving "latest" installer; pass
#   --build-arg MINICONDA_INSTALLER=<pinned installer file name>
# to get a reproducible build.
ARG MINICONDA_INSTALLER=Miniconda3-latest-Linux-x86_64.sh

# Download the installer, run it in batch mode (-b) with prefix /opt/conda (-p),
# and remove it in the same layer so it is not kept in the image.
RUN wget -q -P /tmp \
        "https://repo.anaconda.com/miniconda/${MINICONDA_INSTALLER}" \
    && bash "/tmp/${MINICONDA_INSTALLER}" -b -p /opt/conda \
    && rm "/tmp/${MINICONDA_INSTALLER}"

# Put the conda-provided executables ahead of the system ones for every
# subsequent build step and for the running container.
ENV PATH="/opt/conda/bin:$PATH"

# Pin the conda environment to Python 3.10 and make sure pip is installed in it.
# conda's package cache and downloaded tarballs are removed in the same layer
# so they do not stay in the image.
RUN conda install -y \
        pip \
        python=3.10 \
    && conda clean -afy

# Only the requirements file is copied, so this layer and the install below are
# rebuilt only when that file changes.
COPY requirements/requirements-docker-v2.txt /requirements-docker-v2.txt

# Install the Python requirements.
#   -i  uses the TUNA (Tsinghua) PyPI mirror as the package index
#   -f  adds the JAX CUDA release page as an extra place to look for wheels
#   --no-cache-dir keeps pip's download cache out of the image layer
# Variant using pip's default index instead of the mirror:
# RUN pip install --no-cache-dir -r /requirements-docker-v2.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install --no-cache-dir \
    -i https://pypi.tuna.tsinghua.edu.cn/simple \
    -r /requirements-docker-v2.txt \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install a pinned cuDNN wheel and ray in a separate step, after the main
# requirements. --no-cache-dir keeps pip's download cache out of the layer.
# Variant using pip's default index instead of the mirror:
# RUN pip install --no-cache-dir nvidia-cudnn-cu11==8.6.0.163 ray
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple \
    nvidia-cudnn-cu11==8.6.0.163 \
    ray

# Installed from pip's default index (no mirror) in its own layer.
# NOTE(review): the package is unpinned, so a rebuild may pick up a newer
# release; pin a version here if reproducible builds are required.
RUN pip install --no-cache-dir tensorcircuit-ng

# Note on the requirements conflict around ray:
# JAX needs cuDNN >= 8.6, otherwise it fails when initializing an array on the GPU,
# while torch declares cuDNN 8.5 in its setup requirements but also works with 8.6.

# Set TF_CPP_MIN_LOG_LEVEL as an image-level environment variable rather than
# appending an export to ~/.bashrc: ~/.bashrc is only read by interactive bash
# shells, so commands started directly (e.g. `docker run <image> python ...`)
# would never see it. As ENV it applies to every process in the container and
# can still be overridden with `docker run -e TF_CPP_MIN_LOG_LEVEL=...`.
ENV TF_CPP_MIN_LOG_LEVEL=3

# Default to an interactive shell (exec form, so bash runs as PID 1).
CMD ["/bin/bash"]