diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ee1623e822..981079813d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -139,13 +139,15 @@ jobs: - name: Install dependencies shell: bash -l {0} run: | - mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy<1.26" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy + pushd conda-envs/ci/unix/test 2>/dev/null + mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" --file test.yml # numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but # not numpy, even though scipy 1.7 requires numpy<1.23. When installing # PyTensor next, pip installs a lower version of numpy via the PyPI. - if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi - if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi - if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi + if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" --file numba-py39.yml; fi + if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" --file numba.yml; fi + if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" --file jax.yml; fi + popd 2>/dev/null pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' @@ -198,7 +200,9 @@ jobs: - name: Install dependencies shell: bash -l {0} run: | - mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" numba-scipy jax jaxlib pytest-benchmark + pushd conda-envs/ci/unix/benchmark 2>/dev/null + mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" --file benchmark.yml + popd 2>/dev/null pip install -e ./ mamba list && pip freeze python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' diff --git a/conda-envs/ci/unix/benchmark/benchmark.yml b/conda-envs/ci/unix/benchmark/benchmark.yml new file mode 100644 index 0000000000..95c9f119d9 --- /dev/null +++ b/conda-envs/ci/unix/benchmark/benchmark.yml @@ -0,0 +1,16 @@ +name: unix_benchmark_environment +channels: + - conda-forge +dependencies: + - cython + - jax + - jaxlib + - mkl + - mkl-service + - "numba>=0.57" + - numba-scipy + - numpy + - pip + - pytest + - pytest-benchmark + - scipy diff --git a/conda-envs/ci/unix/test/jax.yml b/conda-envs/ci/unix/test/jax.yml new file mode 100644 index 0000000000..89cf7b961e --- /dev/null +++ b/conda-envs/ci/unix/test/jax.yml @@ -0,0 +1,10 @@ +name: unix_test_environment +channels: + - conda-forge +dependencies: + - jax + - jaxlib + - numpyro + - pip + - pip: + - tensorflow-probability diff --git a/conda-envs/ci/unix/test/numba-py39.yml b/conda-envs/ci/unix/test/numba-py39.yml new file mode 100644 index 0000000000..de6be555cb --- /dev/null +++ b/conda-envs/ci/unix/test/numba-py39.yml @@ -0,0 +1,7 @@ +name: unix_test_environment +channels: + - conda-forge +dependencies: + - "numpy<1.23" + - "numba>=0.57" + - numba-scipy diff --git a/conda-envs/ci/unix/test/numba.yml b/conda-envs/ci/unix/test/numba.yml new file mode 100644 index 0000000000..2aa6b59888 --- /dev/null +++ b/conda-envs/ci/unix/test/numba.yml @@ -0,0 +1,6 @@ +name: unix_test_environment +channels: + - conda-forge +dependencies: + - "numba>=0.57" + - numba-scipy diff --git a/conda-envs/ci/unix/test/test.yml b/conda-envs/ci/unix/test/test.yml new file mode 100644 index 0000000000..9180edf22d --- /dev/null +++ b/conda-envs/ci/unix/test/test.yml @@ -0,0 +1,15 @@ +name: unix_test_environment +dependencies: + - coverage + - cython + - graphviz + - mkl + - mkl-service + - "numpy<1.26" + - pip + - pytest + - pytest-benchmark + - pytest-cov + - pytest-mock + - scipy + - sympy