diff --git a/.gitignore b/.gitignore index 038d1952..238181b5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ dist/ build/ docs/build/ +logs/ env/ venv/ diff --git a/.travis.yml b/.travis.yml index c06cab3d..55b7afa9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,11 +8,10 @@ services: - docker install: - - ./mk_dockerfile.sh - - docker-compose build + - docker build --build-arg PG_VERSION="${PG_VERSION}" --build-arg PYTHON_VERSION="${PYTHON_VERSION}" -t tests -f Dockerfile--${TEST_PLATFORM}.tmpl . script: - - docker-compose run $(bash <(curl -s https://codecov.io/env)) tests + - docker run $(bash <(curl -s https://codecov.io/env)) -t tests notifications: email: @@ -20,15 +19,19 @@ notifications: on_failure: always env: - - PYTHON_VERSION=3 PG_VERSION=14 - - PYTHON_VERSION=3 PG_VERSION=13 - - PYTHON_VERSION=3 PG_VERSION=12 - - PYTHON_VERSION=3 PG_VERSION=11 - - PYTHON_VERSION=3 PG_VERSION=10 -# - PYTHON_VERSION=3 PG_VERSION=9.6 -# - PYTHON_VERSION=3 PG_VERSION=9.5 -# - PYTHON_VERSION=3 PG_VERSION=9.4 -# - PYTHON_VERSION=2 PG_VERSION=10 -# - PYTHON_VERSION=2 PG_VERSION=9.6 -# - PYTHON_VERSION=2 PG_VERSION=9.5 -# - PYTHON_VERSION=2 PG_VERSION=9.4 + - TEST_PLATFORM=std2-all PYTHON_VERSION=3.8.0 PG_VERSION=17 + - TEST_PLATFORM=std2-all PYTHON_VERSION=3.8 PG_VERSION=17 + - TEST_PLATFORM=std2-all PYTHON_VERSION=3.9 PG_VERSION=17 + - TEST_PLATFORM=std2-all PYTHON_VERSION=3.10 PG_VERSION=17 + - TEST_PLATFORM=std2-all PYTHON_VERSION=3.11 PG_VERSION=17 + - TEST_PLATFORM=std PYTHON_VERSION=3 PG_VERSION=16 + - TEST_PLATFORM=std PYTHON_VERSION=3 PG_VERSION=15 + - TEST_PLATFORM=std PYTHON_VERSION=3 PG_VERSION=14 + - TEST_PLATFORM=std PYTHON_VERSION=3 PG_VERSION=13 + - TEST_PLATFORM=std PYTHON_VERSION=3 PG_VERSION=12 + - TEST_PLATFORM=std PYTHON_VERSION=3 PG_VERSION=11 + - TEST_PLATFORM=std PYTHON_VERSION=3 PG_VERSION=10 + - TEST_PLATFORM=std-all PYTHON_VERSION=3 PG_VERSION=17 + - TEST_PLATFORM=ubuntu_24_04 PYTHON_VERSION=3 PG_VERSION=17 + - TEST_PLATFORM=altlinux_10 PYTHON_VERSION=3 PG_VERSION=17 + - TEST_PLATFORM=altlinux_11 PYTHON_VERSION=3 PG_VERSION=17 diff --git a/Dockerfile--altlinux_10.tmpl b/Dockerfile--altlinux_10.tmpl new file mode 100644 index 00000000..d78b05f5 --- /dev/null +++ b/Dockerfile--altlinux_10.tmpl @@ -0,0 +1,118 @@ +ARG PG_VERSION +ARG PYTHON_VERSION + +# --------------------------------------------- base1 +FROM alt:p10 as base1 +ARG PG_VERSION + +RUN apt-get update +RUN apt-get install -y sudo curl ca-certificates +RUN apt-get update +RUN apt-get install -y openssh-server openssh-clients +RUN apt-get install -y time + +# RUN apt-get install -y mc + +RUN apt-get install -y libsqlite3-devel + +EXPOSE 22 + +RUN ssh-keygen -A + +# --------------------------------------------- postgres +FROM base1 as base1_with_dev_tools + +RUN apt-get update + +RUN apt-get install -y git +RUN apt-get install -y gcc +RUN apt-get install -y make + +RUN apt-get install -y meson +RUN apt-get install -y flex +RUN apt-get install -y bison + +RUN apt-get install -y pkg-config +RUN apt-get install -y libssl-devel +RUN apt-get install -y libicu-devel +RUN apt-get install -y libzstd-devel +RUN apt-get install -y zlib-devel +RUN apt-get install -y liblz4-devel +RUN apt-get install -y libzstd-devel +RUN apt-get install -y libxml2-devel + +# --------------------------------------------- postgres +FROM base1_with_dev_tools as base1_with_pg-17 + +RUN git clone https://github.com/postgres/postgres.git -b REL_17_STABLE /pg/postgres/source + +WORKDIR /pg/postgres/source + +RUN ./configure --prefix=/pg/postgres/install --with-zlib --with-openssl --without-readline --with-lz4 --with-zstd --with-libxml +RUN make -j 4 install +RUN make -j 4 -C contrib install + +# SETUP PG_CONFIG +# When pg_config symlink in /usr/local/bin it returns a real (right) result of --bindir +RUN ln -s /pg/postgres/install/bin/pg_config -t /usr/local/bin + +# SETUP PG CLIENT LIBRARY +# libpq.so.5 is enough +RUN ln -s /pg/postgres/install/lib/libpq.so.5.17 /usr/lib64/libpq.so.5 + +# --------------------------------------------- base2_with_python-3 +FROM base1_with_pg-${PG_VERSION} as base2_with_python-3 +RUN apt-get install -y python3 +RUN apt-get install -y python3-dev +RUN apt-get install -y python3-module-virtualenv +RUN apt-get install -y python3-modules-sqlite3 + +# AltLinux does not have "generic" virtualenv utility. Let's create it. +RUN if [[ -f "/usr/bin/virtualenv" ]] ; then \ + echo AAA; \ + elif [[ -f "/usr/bin/virtualenv3" ]] ; then \ + ln -s /usr/bin/virtualenv3 /usr/bin/virtualenv; \ + else \ + echo "/usr/bin/virtualenv is not created!"; \ + exit 1; \ + fi + +ENV PYTHON_VERSION=3 + +# --------------------------------------------- final +FROM base2_with_python-${PYTHON_VERSION} as final + +RUN adduser test -G wheel + +# It enables execution of "sudo service ssh start" without password +RUN sh -c "echo \"WHEEL_USERS ALL=(ALL:ALL) NOPASSWD: ALL\"" >> /etc/sudoers + +ADD . /pg/testgres +WORKDIR /pg/testgres +RUN chown -R test /pg/testgres + +ENV LANG=C.UTF-8 + +USER test + +RUN chmod 700 ~/ +RUN mkdir -p ~/.ssh + +# +# Altlinux 10 and 11 too slowly create a new SSH connection (x6). +# +# So, we exclude the "remote" tests until this problem has been resolved. +# + +ENTRYPOINT sh -c " \ +set -eux; \ +echo HELLO FROM ENTRYPOINT; \ +echo HOME DIR IS [`realpath ~/`]; \ +sudo /usr/sbin/sshd; \ +ssh-keyscan -H localhost >> ~/.ssh/known_hosts; \ +ssh-keyscan -H 127.0.0.1 >> ~/.ssh/known_hosts; \ +ssh-keygen -t rsa -f ~/.ssh/id_rsa -q -N ''; \ +cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys; \ +chmod 600 ~/.ssh/authorized_keys; \ +ls -la ~/.ssh/; \ +TEST_FILTER=\"TestTestgresLocal or TestOsOpsLocal or local\" bash ./run_tests.sh;" diff --git a/Dockerfile--altlinux_11.tmpl b/Dockerfile--altlinux_11.tmpl new file mode 100644 index 00000000..5c88585d --- /dev/null +++ b/Dockerfile--altlinux_11.tmpl @@ -0,0 +1,118 @@ +ARG PG_VERSION +ARG PYTHON_VERSION + +# --------------------------------------------- base1 +FROM alt:p11 as base1 +ARG PG_VERSION + +RUN apt-get update +RUN apt-get install -y sudo curl ca-certificates +RUN apt-get update +RUN apt-get install -y openssh-server openssh-clients +RUN apt-get install -y time + +# RUN apt-get install -y mc + +RUN apt-get install -y libsqlite3-devel + +EXPOSE 22 + +RUN ssh-keygen -A + +# --------------------------------------------- postgres +FROM base1 as base1_with_dev_tools + +RUN apt-get update + +RUN apt-get install -y git +RUN apt-get install -y gcc +RUN apt-get install -y make + +RUN apt-get install -y meson +RUN apt-get install -y flex +RUN apt-get install -y bison + +RUN apt-get install -y pkg-config +RUN apt-get install -y libssl-devel +RUN apt-get install -y libicu-devel +RUN apt-get install -y libzstd-devel +RUN apt-get install -y zlib-devel +RUN apt-get install -y liblz4-devel +RUN apt-get install -y libzstd-devel +RUN apt-get install -y libxml2-devel + +# --------------------------------------------- postgres +FROM base1_with_dev_tools as base1_with_pg-17 + +RUN git clone https://github.com/postgres/postgres.git -b REL_17_STABLE /pg/postgres/source + +WORKDIR /pg/postgres/source + +RUN ./configure --prefix=/pg/postgres/install --with-zlib --with-openssl --without-readline --with-lz4 --with-zstd --with-libxml +RUN make -j 4 install +RUN make -j 4 -C contrib install + +# SETUP PG_CONFIG +# When pg_config symlink in /usr/local/bin it returns a real (right) result of --bindir +RUN ln -s /pg/postgres/install/bin/pg_config -t /usr/local/bin + +# SETUP PG CLIENT LIBRARY +# libpq.so.5 is enough +RUN ln -s /pg/postgres/install/lib/libpq.so.5.17 /usr/lib64/libpq.so.5 + +# --------------------------------------------- base2_with_python-3 +FROM base1_with_pg-${PG_VERSION} as base2_with_python-3 +RUN apt-get install -y python3 +RUN apt-get install -y python3-dev +RUN apt-get install -y python3-module-virtualenv +RUN apt-get install -y python3-modules-sqlite3 + +# AltLinux does not have "generic" virtualenv utility. Let's create it. +RUN if [[ -f "/usr/bin/virtualenv" ]] ; then \ + echo AAA; \ + elif [[ -f "/usr/bin/virtualenv3" ]] ; then \ + ln -s /usr/bin/virtualenv3 /usr/bin/virtualenv; \ + else \ + echo "/usr/bin/virtualenv is not created!"; \ + exit 1; \ + fi + +ENV PYTHON_VERSION=3 + +# --------------------------------------------- final +FROM base2_with_python-${PYTHON_VERSION} as final + +RUN adduser test -G wheel + +# It enables execution of "sudo service ssh start" without password +RUN sh -c "echo \"WHEEL_USERS ALL=(ALL:ALL) NOPASSWD: ALL\"" >> /etc/sudoers + +ADD . /pg/testgres +WORKDIR /pg/testgres +RUN chown -R test /pg/testgres + +ENV LANG=C.UTF-8 + +USER test + +RUN chmod 700 ~/ +RUN mkdir -p ~/.ssh + +# +# Altlinux 10 and 11 too slowly create a new SSH connection (x6). +# +# So, we exclude the "remote" tests until this problem has been resolved. +# + +ENTRYPOINT sh -c " \ +set -eux; \ +echo HELLO FROM ENTRYPOINT; \ +echo HOME DIR IS [`realpath ~/`]; \ +sudo /usr/sbin/sshd; \ +ssh-keyscan -H localhost >> ~/.ssh/known_hosts; \ +ssh-keyscan -H 127.0.0.1 >> ~/.ssh/known_hosts; \ +ssh-keygen -t rsa -f ~/.ssh/id_rsa -q -N ''; \ +cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys; \ +chmod 600 ~/.ssh/authorized_keys; \ +ls -la ~/.ssh/; \ +TEST_FILTER=\"TestTestgresLocal or TestOsOpsLocal or local\" bash ./run_tests.sh;" diff --git a/Dockerfile--std-all.tmpl b/Dockerfile--std-all.tmpl new file mode 100644 index 00000000..d19f52a6 --- /dev/null +++ b/Dockerfile--std-all.tmpl @@ -0,0 +1,62 @@ +ARG PG_VERSION +ARG PYTHON_VERSION + +# --------------------------------------------- base1 +FROM postgres:${PG_VERSION}-alpine as base1 + +# --------------------------------------------- base2_with_python-3 +FROM base1 as base2_with_python-3 +RUN apk add --no-cache curl python3 python3-dev build-base musl-dev linux-headers py-virtualenv +ENV PYTHON_VERSION=3 + +# --------------------------------------------- final +FROM base2_with_python-${PYTHON_VERSION} as final + +#RUN apk add --no-cache mc + +# Full version of "ps" command +RUN apk add --no-cache procps + +RUN apk add --no-cache openssh +RUN apk add --no-cache sudo + +ENV LANG=C.UTF-8 + +RUN addgroup -S sudo +RUN adduser postgres sudo + +EXPOSE 22 +RUN ssh-keygen -A + +ADD . /pg/testgres +WORKDIR /pg/testgres +RUN chown -R postgres:postgres /pg + +# It allows to use sudo without password +RUN sh -c "echo \"postgres ALL=(ALL:ALL) NOPASSWD:ALL\"">>/etc/sudoers + +# THIS CMD IS NEEDED TO CONNECT THROUGH SSH WITHOUT PASSWORD +RUN sh -c "echo "postgres:*" | chpasswd -e" + +USER postgres + +# THIS CMD IS NEEDED TO CONNECT THROUGH SSH WITHOUT PASSWORD +RUN chmod 700 ~/ + +RUN mkdir -p ~/.ssh +#RUN chmod 700 ~/.ssh + +#ENTRYPOINT PYTHON_VERSION=${PYTHON_VERSION} bash run_tests.sh + +ENTRYPOINT sh -c " \ +set -eux; \ +echo HELLO FROM ENTRYPOINT; \ +echo HOME DIR IS [`realpath ~/`]; \ +ssh-keygen -t rsa -f ~/.ssh/id_rsa -q -N ''; \ +cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys; \ +chmod 600 ~/.ssh/authorized_keys; \ +ls -la ~/.ssh/; \ +sudo /usr/sbin/sshd; \ +ssh-keyscan -H localhost >> ~/.ssh/known_hosts; \ +ssh-keyscan -H 127.0.0.1 >> ~/.ssh/known_hosts; \ +TEST_FILTER=\"\" bash run_tests.sh;" diff --git a/Dockerfile--std.tmpl b/Dockerfile--std.tmpl new file mode 100644 index 00000000..67aa30b4 --- /dev/null +++ b/Dockerfile--std.tmpl @@ -0,0 +1,22 @@ +ARG PG_VERSION +ARG PYTHON_VERSION + +# --------------------------------------------- base1 +FROM postgres:${PG_VERSION}-alpine as base1 + +# --------------------------------------------- base2_with_python-3 +FROM base1 as base2_with_python-3 +RUN apk add --no-cache curl python3 python3-dev build-base musl-dev linux-headers py-virtualenv +ENV PYTHON_VERSION=3 + +# --------------------------------------------- final +FROM base2_with_python-${PYTHON_VERSION} as final + +ENV LANG=C.UTF-8 + +ADD . /pg/testgres +WORKDIR /pg/testgres +RUN chown -R postgres:postgres /pg + +USER postgres +ENTRYPOINT bash run_tests.sh diff --git a/Dockerfile--std2-all.tmpl b/Dockerfile--std2-all.tmpl new file mode 100644 index 00000000..10d8280c --- /dev/null +++ b/Dockerfile--std2-all.tmpl @@ -0,0 +1,96 @@ +ARG PG_VERSION +ARG PYTHON_VERSION + +# --------------------------------------------- base1 +FROM postgres:${PG_VERSION}-alpine as base1 + +# --------------------------------------------- base2_with_python-3 +FROM base1 as base2_with_python-3 +RUN apk add --no-cache curl python3 python3-dev build-base musl-dev linux-headers + +# For pyenv +RUN apk add patch +RUN apk add git +RUN apk add xz-dev +RUN apk add zip +RUN apk add zlib-dev +RUN apk add libffi-dev +RUN apk add readline-dev +RUN apk add openssl openssl-dev +RUN apk add sqlite-dev +RUN apk add bzip2-dev + +# --------------------------------------------- base3_with_python-3.8.0 +FROM base2_with_python-3 as base3_with_python-3.8.0 +ENV PYTHON_VERSION=3.8.0 + +# --------------------------------------------- base3_with_python-3.8 +FROM base2_with_python-3 as base3_with_python-3.8 +ENV PYTHON_VERSION=3.8 + +# --------------------------------------------- base3_with_python-3.9 +FROM base2_with_python-3 as base3_with_python-3.9 +ENV PYTHON_VERSION=3.9 + +# --------------------------------------------- base3_with_python-3.10 +FROM base2_with_python-3 as base3_with_python-3.10 +ENV PYTHON_VERSION=3.10 + +# --------------------------------------------- base3_with_python-3.11 +FROM base2_with_python-3 as base3_with_python-3.11 +ENV PYTHON_VERSION=3.11 + +# --------------------------------------------- final +FROM base3_with_python-${PYTHON_VERSION} as final + +#RUN apk add --no-cache mc + +# Full version of "ps" command +RUN apk add --no-cache procps + +RUN apk add --no-cache openssh +RUN apk add --no-cache sudo + +ENV LANG=C.UTF-8 + +RUN addgroup -S sudo +RUN adduser postgres sudo + +EXPOSE 22 +RUN ssh-keygen -A + +ADD . /pg/testgres +WORKDIR /pg/testgres +RUN chown -R postgres:postgres /pg + +# It allows to use sudo without password +RUN sh -c "echo \"postgres ALL=(ALL:ALL) NOPASSWD:ALL\"">>/etc/sudoers + +# THIS CMD IS NEEDED TO CONNECT THROUGH SSH WITHOUT PASSWORD +RUN sh -c "echo "postgres:*" | chpasswd -e" + +USER postgres + +RUN curl https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer | bash + +RUN ~/.pyenv/bin/pyenv install ${PYTHON_VERSION} + +# THIS CMD IS NEEDED TO CONNECT THROUGH SSH WITHOUT PASSWORD +RUN chmod 700 ~/ + +RUN mkdir -p ~/.ssh +#RUN chmod 700 ~/.ssh + +ENTRYPOINT sh -c " \ +set -eux; \ +echo HELLO FROM ENTRYPOINT; \ +echo HOME DIR IS [`realpath ~/`]; \ +ssh-keygen -t rsa -f ~/.ssh/id_rsa -q -N ''; \ +cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys; \ +chmod 600 ~/.ssh/authorized_keys; \ +ls -la ~/.ssh/; \ +sudo /usr/sbin/sshd; \ +ssh-keyscan -H localhost >> ~/.ssh/known_hosts; \ +ssh-keyscan -H 127.0.0.1 >> ~/.ssh/known_hosts; \ +export PATH=\"~/.pyenv/bin:$PATH\"; \ +TEST_FILTER=\"\" bash run_tests2.sh;" diff --git a/Dockerfile--ubuntu_24_04.tmpl b/Dockerfile--ubuntu_24_04.tmpl new file mode 100644 index 00000000..7a559776 --- /dev/null +++ b/Dockerfile--ubuntu_24_04.tmpl @@ -0,0 +1,73 @@ +ARG PG_VERSION +ARG PYTHON_VERSION + +# --------------------------------------------- base1 +FROM ubuntu:24.04 as base1 +ARG PG_VERSION + +RUN apt update +RUN apt install -y sudo curl ca-certificates +RUN apt update +RUN apt install -y openssh-server +RUN apt install -y time +RUN apt install -y netcat-traditional + +RUN apt update +RUN apt install -y postgresql-common + +RUN bash /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y + +RUN install -d /usr/share/postgresql-common/pgdg +RUN curl -o /usr/share/postgresql-common/pgdg/apt.postgresql.org.asc --fail https://www.postgresql.org/media/keys/ACCC4CF8.asc + +# It does not work +# RUN sh -c 'echo "deb [signed-by=/usr/share/postgresql-common/pgdg/apt.postgresql.org.asc] https://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' + +RUN apt update +RUN apt install -y postgresql-${PG_VERSION} + +# RUN apt install -y mc + +# [2025-02-26] It adds the user 'postgres' in the group 'sudo' +# [2025-02-27] It is not required. +# RUN adduser postgres sudo + +EXPOSE 22 + +RUN ssh-keygen -A + +# It enables execution of "sudo service ssh start" without password +RUN sh -c "echo postgres ALL=NOPASSWD:/usr/sbin/service ssh start" >> /etc/sudoers + +# --------------------------------------------- base2_with_python-3 +FROM base1 as base2_with_python-3 +RUN apt install -y python3 python3-dev python3-virtualenv libpq-dev +ENV PYTHON_VERSION=3 + +# --------------------------------------------- final +FROM base2_with_python-${PYTHON_VERSION} as final + +ADD . /pg/testgres +WORKDIR /pg/testgres +RUN chown -R postgres /pg + +ENV LANG=C.UTF-8 + +USER postgres + +RUN chmod 700 ~/ +RUN mkdir -p ~/.ssh + +ENTRYPOINT sh -c " \ +#set -eux; \ +echo HELLO FROM ENTRYPOINT; \ +echo HOME DIR IS [`realpath ~/`]; \ +service ssh enable; \ +sudo service ssh start; \ +ssh-keyscan -H localhost >> ~/.ssh/known_hosts; \ +ssh-keyscan -H 127.0.0.1 >> ~/.ssh/known_hosts; \ +ssh-keygen -t rsa -f ~/.ssh/id_rsa -q -N ''; \ +cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys; \ +chmod 600 ~/.ssh/authorized_keys; \ +ls -la ~/.ssh/; \ +TEST_FILTER=\"\" bash ./run_tests.sh;" diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl deleted file mode 100644 index dc5878b6..00000000 --- a/Dockerfile.tmpl +++ /dev/null @@ -1,23 +0,0 @@ -FROM postgres:${PG_VERSION}-alpine - -ENV PYTHON=python${PYTHON_VERSION} -RUN if [ "${PYTHON_VERSION}" = "2" ] ; then \ - apk add --no-cache curl python2 python2-dev build-base musl-dev \ - linux-headers py-virtualenv py-pip; \ - fi -RUN if [ "${PYTHON_VERSION}" = "3" ] ; then \ - apk add --no-cache curl python3 python3-dev build-base musl-dev \ - linux-headers py-virtualenv; \ - fi -ENV LANG=C.UTF-8 - -RUN mkdir -p /pg -COPY run_tests.sh /run.sh -RUN chmod 755 /run.sh - -ADD . /pg/testgres -WORKDIR /pg/testgres -RUN chown -R postgres:postgres /pg - -USER postgres -ENTRYPOINT PYTHON_VERSION=${PYTHON_VERSION} /run.sh diff --git a/README.md b/README.md index a2a0ec7e..defbc8b3 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![Build Status](https://travis-ci.com/postgrespro/testgres.svg?branch=master)](https://app.travis-ci.com/github/postgrespro/testgres/branches) +[![Build Status](https://api.travis-ci.com/postgrespro/testgres.svg?branch=master)](https://travis-ci.com/github/postgrespro/testgres) [![codecov](https://codecov.io/gh/postgrespro/testgres/branch/master/graph/badge.svg)](https://codecov.io/gh/postgrespro/testgres) [![PyPI version](https://badge.fury.io/py/testgres.svg)](https://badge.fury.io/py/testgres) @@ -6,7 +6,7 @@ # testgres -PostgreSQL testing utility. Both Python 2.7 and 3.3+ are supported. +PostgreSQL testing utility. Python 3.8+ is supported. ## Installation @@ -59,12 +59,12 @@ with testgres.get_new_node() as node: # ... node stops and its files are about to be removed ``` -There are four API methods for runnig queries: +There are four API methods for running queries: | Command | Description | |----------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------| | `node.psql(query, ...)` | Runs query via `psql` command and returns tuple `(error code, stdout, stderr)`. | -| `node.safe_psql(query, ...)` | Same as `psql()` except that it returns only `stdout`. If an error occures during the execution, an exception will be thrown. | +| `node.safe_psql(query, ...)` | Same as `psql()` except that it returns only `stdout`. If an error occurs during the execution, an exception will be thrown. | | `node.execute(query, ...)` | Connects to PostgreSQL using `psycopg2` or `pg8000` (depends on which one is installed in your system) and returns two-dimensional array with data. | | `node.connect(dbname, ...)` | Returns connection wrapper (`NodeConnection`) capable of running several queries within a single transaction. | diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 86edf9a4..00000000 --- a/docker-compose.yml +++ /dev/null @@ -1,4 +0,0 @@ -version: '3.8' -services: - tests: - build: . diff --git a/mk_dockerfile.sh b/mk_dockerfile.sh deleted file mode 100755 index d2aa3a8a..00000000 --- a/mk_dockerfile.sh +++ /dev/null @@ -1,2 +0,0 @@ -set -eu -sed -e 's/${PYTHON_VERSION}/'${PYTHON_VERSION}/g -e 's/${PG_VERSION}/'${PG_VERSION}/g Dockerfile.tmpl > Dockerfile diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..9f5fa375 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +testpaths = tests testgres/plugins/pg_probackup2/pg_probackup2/tests +addopts = --strict-markers +markers = +#log_file = logs/pytest.log +log_file_level = NOTSET +log_file_format = %(levelname)8s [%(asctime)s] %(message)s +log_file_date_format=%Y-%m-%d %H:%M:%S + diff --git a/run_tests.sh b/run_tests.sh index 73c459be..65c17dbf 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -4,29 +4,25 @@ set -eux - -# choose python version -echo python version is $PYTHON_VERSION -VIRTUALENV="virtualenv --python=/usr/bin/python$PYTHON_VERSION" -PIP="pip$PYTHON_VERSION" +if [ -z ${TEST_FILTER+x} ]; \ +then export TEST_FILTER="TestTestgresLocal or (TestTestgresCommon and (not remote))"; \ +fi # fail early echo check that pg_config is in PATH command -v pg_config -# prepare environment -VENV_PATH=/tmp/testgres_venv +# prepare python environment +VENV_PATH="/tmp/testgres_venv" rm -rf $VENV_PATH -$VIRTUALENV $VENV_PATH +virtualenv --python="/usr/bin/python${PYTHON_VERSION}" "${VENV_PATH}" export VIRTUAL_ENV_DISABLE_PROMPT=1 -source $VENV_PATH/bin/activate - -# install utilities -$PIP install coverage flake8 psutil Sphinx +source "${VENV_PATH}/bin/activate" +pip install coverage flake8 psutil Sphinx pytest pytest-xdist psycopg2 six psutil # install testgres' dependencies export PYTHONPATH=$(pwd) -$PIP install . +# $PIP install . # test code quality flake8 . @@ -38,21 +34,17 @@ rm -f $COVERAGE_FILE # run tests (PATH) -time coverage run -a tests/test_simple.py +time coverage run -a -m pytest -l -v -n 4 -k "${TEST_FILTER}" # run tests (PG_BIN) -time \ - PG_BIN=$(dirname $(which pg_config)) \ - ALT_CONFIG=1 \ - coverage run -a tests/test_simple.py +PG_BIN=$(pg_config --bindir) \ +time coverage run -a -m pytest -l -v -n 4 -k "${TEST_FILTER}" # run tests (PG_CONFIG) -time \ - PG_CONFIG=$(which pg_config) \ - ALT_CONFIG=1 \ - coverage run -a tests/test_simple.py +PG_CONFIG=$(pg_config --bindir)/pg_config \ +time coverage run -a -m pytest -l -v -n 4 -k "${TEST_FILTER}" # show coverage diff --git a/run_tests2.sh b/run_tests2.sh new file mode 100755 index 00000000..173b19dc --- /dev/null +++ b/run_tests2.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash + +# Copyright (c) 2017-2025 Postgres Professional + +set -eux + +eval "$(pyenv init -)" +eval "$(pyenv virtualenv-init -)" + +pyenv virtualenv --force ${PYTHON_VERSION} cur +pyenv activate cur + +if [ -z ${TEST_FILTER+x} ]; \ +then export TEST_FILTER="TestTestgresLocal or (TestTestgresCommon and (not remote))"; \ +fi + +# fail early +echo check that pg_config is in PATH +command -v pg_config + +# prepare python environment +VENV_PATH="/tmp/testgres_venv" +rm -rf $VENV_PATH +python -m venv "${VENV_PATH}" +export VIRTUAL_ENV_DISABLE_PROMPT=1 +source "${VENV_PATH}/bin/activate" +pip install coverage flake8 psutil Sphinx pytest pytest-xdist psycopg2 six psutil + +# install testgres' dependencies +export PYTHONPATH=$(pwd) +# $PIP install . + +# test code quality +flake8 . + + +# remove existing coverage file +export COVERAGE_FILE=.coverage +rm -f $COVERAGE_FILE + + +# run tests (PATH) +time coverage run -a -m pytest -l -v -n 4 -k "${TEST_FILTER}" + + +# run tests (PG_BIN) +PG_BIN=$(pg_config --bindir) \ +time coverage run -a -m pytest -l -v -n 4 -k "${TEST_FILTER}" + + +# run tests (PG_CONFIG) +PG_CONFIG=$(pg_config --bindir)/pg_config \ +time coverage run -a -m pytest -l -v -n 4 -k "${TEST_FILTER}" + + +# show coverage +coverage report + +# build documentation +cd docs +make html +cd .. + +# attempt to fix codecov +set +eux + +# send coverage stats to Codecov +bash <(curl -s https://codecov.io/bash) diff --git a/setup.py b/setup.py index 412e8823..0b209181 100755 --- a/setup.py +++ b/setup.py @@ -27,9 +27,9 @@ readme = f.read() setup( - version='1.10.1', + version='1.11.0', name='testgres', - packages=['testgres', 'testgres.operations', 'testgres.helpers'], + packages=['testgres', 'testgres.operations', 'testgres.impl'], description='Testing utility for PostgreSQL and its extensions', url='https://github.com/postgrespro/testgres', long_description=readme, diff --git a/testgres/__init__.py b/testgres/__init__.py index 8d0e38c6..339ae62e 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -23,7 +23,8 @@ CatchUpException, \ StartNodeException, \ InitNodeException, \ - BackupException + BackupException, \ + InvalidOperationException from .enums import \ XLogMethod, \ @@ -33,6 +34,7 @@ DumpFormat from .node import PostgresNode, NodeApp +from .node import PortManager from .utils import \ reserve_port, \ @@ -52,18 +54,17 @@ from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations -from .helpers.port_manager import PortManager - __all__ = [ "get_new_node", "get_remote_node", "NodeBackup", "testgres_config", "TestgresConfig", "configure_testgres", "scoped_config", "push_config", "pop_config", "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", - "TestgresException", "ExecUtilException", "QueryException", "TimeoutException", "CatchUpException", "StartNodeException", "InitNodeException", "BackupException", + "TestgresException", "ExecUtilException", "QueryException", "TimeoutException", "CatchUpException", "StartNodeException", "InitNodeException", "BackupException", "InvalidOperationException", "XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat", "PostgresNode", "NodeApp", + "PortManager", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", - "First", "Any", "PortManager", + "First", "Any", "OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams" ] diff --git a/testgres/api.py b/testgres/api.py index e4b1cdd5..6a96ee84 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -47,8 +47,6 @@ def get_remote_node(name=None, conn_params=None): Simply a wrapper around :class:`.PostgresNode` constructor for remote node. See :meth:`.PostgresNode.__init__` for details. For remote connection you can add the next parameter: - conn_params = ConnectionParams(host='127.0.0.1', - ssh_key=None, - username=default_username()) + conn_params = ConnectionParams(host='127.0.0.1', ssh_key=None, username=default_username()) """ return get_new_node(name=name, conn_params=conn_params) diff --git a/testgres/backup.py b/testgres/backup.py index cecb0f7b..857c46d4 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -15,9 +15,11 @@ from .exceptions import BackupException +from .operations.os_ops import OsOperations + from .utils import \ - get_bin_path, \ - execute_utility, \ + get_bin_path2, \ + execute_utility2, \ clean_on_error @@ -44,6 +46,9 @@ def __init__(self, username: database user name. xlog_method: none | fetch | stream (see docs) """ + assert node.os_ops is not None + assert isinstance(node.os_ops, OsOperations) + if not options: options = [] self.os_ops = node.os_ops @@ -73,7 +78,7 @@ def __init__(self, data_dir = os.path.join(self.base_dir, DATA_DIR) _params = [ - get_bin_path("pg_basebackup"), + get_bin_path2(self.os_ops, "pg_basebackup"), "-p", str(node.port), "-h", node.host, "-U", username, @@ -81,7 +86,7 @@ def __init__(self, "-X", xlog_method.value ] # yapf: disable _params += options - execute_utility(_params, self.log_file) + execute_utility2(self.os_ops, _params, self.log_file) def __enter__(self): return self @@ -142,8 +147,19 @@ def spawn_primary(self, name=None, destroy=True): base_dir = self._prepare_dir(destroy) # Build a new PostgresNode - NodeClass = self.original_node.__class__ - with clean_on_error(NodeClass(name=name, base_dir=base_dir, conn_params=self.original_node.os_ops.conn_params)) as node: + assert self.original_node is not None + + if (hasattr(self.original_node, "clone_with_new_name_and_base_dir")): + node = self.original_node.clone_with_new_name_and_base_dir(name=name, base_dir=base_dir) + else: + # For backward compatibility + NodeClass = self.original_node.__class__ + node = NodeClass(name=name, base_dir=base_dir, conn_params=self.original_node.os_ops.conn_params) + + assert node is not None + assert type(node) == self.original_node.__class__ # noqa: E721 + + with clean_on_error(node) as node: # New nodes should always remove dir tree node._should_rm_dirs = True @@ -168,14 +184,19 @@ def spawn_replica(self, name=None, destroy=True, slot=None): """ # Build a new PostgresNode - with clean_on_error(self.spawn_primary(name=name, - destroy=destroy)) as node: + node = self.spawn_primary(name=name, destroy=destroy) + assert node is not None + try: # Assign it a master and a recovery file (private magic) node._assign_master(self.original_node) node._create_recovery_conf(username=self.username, slot=slot) + except: # noqa: E722 + # TODO: Pass 'final=True' ? + node.cleanup(release_resources=True) + raise - return node + return node def cleanup(self): """ diff --git a/testgres/cache.py b/testgres/cache.py index f17b54b5..499cce91 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -15,23 +15,39 @@ ExecUtilException from .utils import \ - get_bin_path, \ - execute_utility + get_bin_path2, \ + execute_utility2 from .operations.local_ops import LocalOperations from .operations.os_ops import OsOperations -def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = LocalOperations(), bin_path=None, cached=True): +def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = None, bin_path=None, cached=True): """ Perform initdb or use cached node files. """ + assert os_ops is None or isinstance(os_ops, OsOperations) + + if os_ops is None: + os_ops = LocalOperations.get_single_instance() + + assert isinstance(os_ops, OsOperations) + + def make_utility_path(name): + assert name is not None + assert type(name) == str # noqa: E721 + + if bin_path: + return os.path.join(bin_path, name) + + return get_bin_path2(os_ops, name) + def call_initdb(initdb_dir, log=logfile): try: - initdb_path = os.path.join(bin_path, 'initdb') if bin_path else get_bin_path("initdb") + initdb_path = make_utility_path("initdb") _params = [initdb_path, "-D", initdb_dir, "-N"] - execute_utility(_params + (params or []), log) + execute_utility2(os_ops, _params + (params or []), log) except ExecUtilException as e: raise_from(InitNodeException("Failed to run initdb"), e) @@ -63,8 +79,8 @@ def call_initdb(initdb_dir, log=logfile): os_ops.write(pg_control, new_pg_control, truncate=True, binary=True, read_and_write=True) # XXX: build new WAL segment with our system id - _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile) + _params = [make_utility_path("pg_resetwal"), "-D", data_dir, "-f"] + execute_utility2(os_ops, _params, logfile) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index b6c43926..55d52426 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -2,6 +2,8 @@ import atexit import copy +import logging +import os import tempfile from contextlib import contextmanager @@ -10,6 +12,10 @@ from .operations.os_ops import OsOperations from .operations.local_ops import LocalOperations +log_level = os.getenv('LOGGING_LEVEL', 'WARNING').upper() +log_format = os.getenv('LOGGING_FORMAT', '%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(level=log_level, format=log_format) + class GlobalConfig(object): """ @@ -44,8 +50,9 @@ class GlobalConfig(object): _cached_initdb_dir = None """ underlying class attribute for cached_initdb_dir property """ - os_ops = LocalOperations() + os_ops = LocalOperations.get_single_instance() """ OsOperation object that allows work on remote host """ + @property def cached_initdb_dir(self): """ path to a temp directory for cached initdb. """ diff --git a/testgres/connection.py b/testgres/connection.py index 49b74844..b8dc49a9 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -1,4 +1,5 @@ # coding: utf-8 +import logging # we support both pg8000 and psycopg2 try: @@ -41,11 +42,13 @@ def __init__(self, self._node = node - self._connection = node.os_ops.db_connect(dbname=dbname, - user=username, - password=password, - host=node.host, - port=node.port) + self._connection = pglib.connect( + database=dbname, + user=username, + password=password, + host=node.host, + port=node.port + ) self._connection.autocommit = autocommit self._cursor = self.connection.cursor() @@ -110,7 +113,7 @@ def execute(self, query, *args): except ProgrammingError: return None except Exception as e: - print("Error executing query: {}\n {}".format(repr(e), query)) + logging.error("Error executing query: {}\n {}".format(repr(e), query)) return None def close(self): diff --git a/testgres/consts.py b/testgres/consts.py index 98c84af6..89c49ab7 100644 --- a/testgres/consts.py +++ b/testgres/consts.py @@ -35,3 +35,7 @@ # logical replication settings LOGICAL_REPL_MAX_CATCHUP_ATTEMPTS = 60 + +PG_CTL__STATUS__OK = 0 +PG_CTL__STATUS__NODE_IS_STOPPED = 3 +PG_CTL__STATUS__BAD_DATADIR = 4 diff --git a/testgres/exceptions.py b/testgres/exceptions.py index ee329031..20c1a8cf 100644 --- a/testgres/exceptions.py +++ b/testgres/exceptions.py @@ -7,15 +7,20 @@ class TestgresException(Exception): pass +class PortForException(TestgresException): + pass + + @six.python_2_unicode_compatible class ExecUtilException(TestgresException): - def __init__(self, message=None, command=None, exit_code=0, out=None): + def __init__(self, message=None, command=None, exit_code=0, out=None, error=None): super(ExecUtilException, self).__init__(message) self.message = message self.command = command self.exit_code = exit_code self.out = out + self.error = error def __str__(self): msg = [] @@ -24,13 +29,17 @@ def __str__(self): msg.append(self.message) if self.command: - msg.append(u'Command: {}'.format(self.command)) + command_s = ' '.join(self.command) if isinstance(self.command, list) else self.command, + msg.append(u'Command: {}'.format(command_s)) if self.exit_code: msg.append(u'Exit code: {}'.format(self.exit_code)) + if self.error: + msg.append(u'---- Error:\n{}'.format(self.error)) + if self.out: - msg.append(u'----\n{}'.format(self.out)) + msg.append(u'---- Out:\n{}'.format(self.out)) return self.convert_and_join(msg) @@ -98,3 +107,7 @@ class InitNodeException(TestgresException): class BackupException(TestgresException): pass + + +class InvalidOperationException(TestgresException): + pass diff --git a/testgres/helpers/port_manager.py b/testgres/helpers/port_manager.py deleted file mode 100644 index 6afdf8a9..00000000 --- a/testgres/helpers/port_manager.py +++ /dev/null @@ -1,40 +0,0 @@ -import socket -import random -from typing import Set, Iterable, Optional - - -class PortForException(Exception): - pass - - -class PortManager: - def __init__(self, ports_range=(1024, 65535)): - self.ports_range = ports_range - - @staticmethod - def is_port_free(port: int) -> bool: - """Check if a port is free to use.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - try: - s.bind(("", port)) - return True - except OSError: - return False - - def find_free_port(self, ports: Optional[Set[int]] = None, exclude_ports: Optional[Iterable[int]] = None) -> int: - """Return a random unused port number.""" - if ports is None: - ports = set(range(1024, 65535)) - - if exclude_ports is None: - exclude_ports = set() - - ports.difference_update(set(exclude_ports)) - - sampled_ports = random.sample(tuple(ports), min(len(ports), 100)) - - for port in sampled_ports: - if self.is_port_free(port): - return port - - raise PortForException("Can't select a port") diff --git a/testgres/impl/port_manager__generic.py b/testgres/impl/port_manager__generic.py new file mode 100755 index 00000000..a51af2bd --- /dev/null +++ b/testgres/impl/port_manager__generic.py @@ -0,0 +1,64 @@ +from ..operations.os_ops import OsOperations + +from ..port_manager import PortManager +from ..exceptions import PortForException + +import threading +import random +import typing + + +class PortManager__Generic(PortManager): + _os_ops: OsOperations + _guard: object + # TODO: is there better to use bitmap fot _available_ports? + _available_ports: typing.Set[int] + _reserved_ports: typing.Set[int] + + def __init__(self, os_ops: OsOperations): + assert os_ops is not None + assert isinstance(os_ops, OsOperations) + self._os_ops = os_ops + self._guard = threading.Lock() + self._available_ports: typing.Set[int] = set(range(1024, 65535)) + self._reserved_ports: typing.Set[int] = set() + + def reserve_port(self) -> int: + assert self._guard is not None + assert type(self._available_ports) == set # noqa: E721t + assert type(self._reserved_ports) == set # noqa: E721 + + with self._guard: + t = tuple(self._available_ports) + assert len(t) == len(self._available_ports) + sampled_ports = random.sample(t, min(len(t), 100)) + t = None + + for port in sampled_ports: + assert not (port in self._reserved_ports) + assert port in self._available_ports + + if not self._os_ops.is_port_free(port): + continue + + self._reserved_ports.add(port) + self._available_ports.discard(port) + assert port in self._reserved_ports + assert not (port in self._available_ports) + return port + + raise PortForException("Can't select a port.") + + def release_port(self, number: int) -> None: + assert type(number) == int # noqa: E721 + + assert self._guard is not None + assert type(self._reserved_ports) == set # noqa: E721 + + with self._guard: + assert number in self._reserved_ports + assert not (number in self._available_ports) + self._available_ports.add(number) + self._reserved_ports.discard(number) + assert not (number in self._reserved_ports) + assert number in self._available_ports diff --git a/testgres/impl/port_manager__this_host.py b/testgres/impl/port_manager__this_host.py new file mode 100755 index 00000000..0d56f356 --- /dev/null +++ b/testgres/impl/port_manager__this_host.py @@ -0,0 +1,33 @@ +from ..port_manager import PortManager + +from .. import utils + +import threading + + +class PortManager__ThisHost(PortManager): + sm_single_instance: PortManager = None + sm_single_instance_guard = threading.Lock() + + @staticmethod + def get_single_instance() -> PortManager: + assert __class__ == PortManager__ThisHost + assert __class__.sm_single_instance_guard is not None + + if __class__.sm_single_instance is not None: + assert type(__class__.sm_single_instance) == __class__ # noqa: E721 + return __class__.sm_single_instance + + with __class__.sm_single_instance_guard: + if __class__.sm_single_instance is None: + __class__.sm_single_instance = __class__() + assert __class__.sm_single_instance is not None + assert type(__class__.sm_single_instance) == __class__ # noqa: E721 + return __class__.sm_single_instance + + def reserve_port(self) -> int: + return utils.reserve_port() + + def release_port(self, number: int) -> None: + assert type(number) == int # noqa: E721 + return utils.release_port(number) diff --git a/testgres/node.py b/testgres/node.py index e5e8fd5f..9a2f4e77 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1,13 +1,18 @@ # coding: utf-8 +from __future__ import annotations +import logging import os import random import signal import subprocess import threading +import tempfile +import platform from queue import Queue import time +import typing try: from collections.abc import Iterable @@ -47,7 +52,9 @@ RECOVERY_CONF_FILE, \ PG_LOG_FILE, \ UTILS_LOG_FILE, \ - PG_PID_FILE + PG_CTL__STATUS__OK, \ + PG_CTL__STATUS__NODE_IS_STOPPED, \ + PG_CTL__STATUS__BAD_DATADIR \ from .consts import \ MAX_LOGICAL_REPLICATION_WORKERS, \ @@ -63,7 +70,6 @@ from .defaults import \ default_dbname, \ - default_username, \ generate_app_name from .exceptions import \ @@ -74,7 +80,12 @@ TimeoutException, \ InitNodeException, \ TestgresException, \ - BackupException + BackupException, \ + InvalidOperationException + +from .port_manager import PortManager +from .impl.port_manager__this_host import PortManager__ThisHost +from .impl.port_manager__generic import PortManager__Generic from .logger import TestgresLogger @@ -82,22 +93,22 @@ from .standby import First +from . import utils + from .utils import \ PgVer, \ eprint, \ - get_bin_path, \ - get_pg_version, \ - reserve_port, \ - release_port, \ - execute_utility, \ + get_bin_path2, \ + get_pg_version2, \ + execute_utility2, \ options_string, \ clean_on_error from .backup import NodeBackup from .operations.os_ops import ConnectionParams +from .operations.os_ops import OsOperations from .operations.local_ops import LocalOperations -from .operations.remote_ops import RemoteOperations InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError @@ -121,13 +132,31 @@ def __getattr__(self, name): return getattr(self.process, name) def __repr__(self): - return '{}(ptype={}, process={})'.format(self.__class__.__name__, - str(self.ptype), - repr(self.process)) + return '{}(ptype={}, process={})'.format( + self.__class__.__name__, + str(self.ptype), + repr(self.process)) class PostgresNode(object): - def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionParams = ConnectionParams(), bin_dir=None, prefix=None): + # a max number of node start attempts + _C_MAX_START_ATEMPTS = 5 + + _name: typing.Optional[str] + _port: typing.Optional[int] + _should_free_port: bool + _os_ops: OsOperations + _port_manager: PortManager + + def __init__(self, + name=None, + base_dir=None, + port: typing.Optional[int] = None, + conn_params: ConnectionParams = None, + bin_dir=None, + prefix=None, + os_ops: typing.Optional[OsOperations] = None, + port_manager: typing.Optional[PortManager] = None): """ PostgresNode constructor. @@ -136,11 +165,30 @@ def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionP port: port to accept connections. base_dir: path to node's data directory. bin_dir: path to node's binary directory. + os_ops: None or correct OS operation object. + port_manager: None or correct port manager object. """ + assert port is None or type(port) == int # noqa: E721 + assert os_ops is None or isinstance(os_ops, OsOperations) + assert port_manager is None or isinstance(port_manager, PortManager) + + if conn_params is not None: + assert type(conn_params) == ConnectionParams # noqa: E721 + + raise InvalidOperationException("conn_params is deprecated, please use os_ops parameter instead.") # private - self._pg_version = PgVer(get_pg_version(bin_dir)) - self._should_free_port = port is None + if os_ops is None: + self._os_ops = __class__._get_os_ops() + else: + assert isinstance(os_ops, OsOperations) + self._os_ops = os_ops + pass + + assert self._os_ops is not None + assert isinstance(self._os_ops, OsOperations) + + self._pg_version = PgVer(get_pg_version2(self._os_ops, bin_dir)) self._base_dir = base_dir self._bin_dir = bin_dir self._prefix = prefix @@ -148,18 +196,32 @@ def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionP self._master = None # basic - self.name = name or generate_app_name() - if testgres_config.os_ops: - self.os_ops = testgres_config.os_ops - elif conn_params.ssh_key: - self.os_ops = RemoteOperations(conn_params) + self._name = name or generate_app_name() + + if port is not None: + assert type(port) == int # noqa: E721 + assert port_manager is None + self._port = port + self._should_free_port = False + self._port_manager = None else: - self.os_ops = LocalOperations(conn_params) + if port_manager is None: + self._port_manager = __class__._get_port_manager(self._os_ops) + elif os_ops is None: + raise InvalidOperationException("When port_manager is not None you have to define os_ops, too.") + else: + assert isinstance(port_manager, PortManager) + assert self._os_ops is os_ops + self._port_manager = port_manager - self.port = port or reserve_port() + assert self._port_manager is not None + assert isinstance(self._port_manager, PortManager) - self.host = self.os_ops.host - self.ssh_key = self.os_ops.ssh_key + self._port = self._port_manager.reserve_port() # raises + assert type(self._port) == int # noqa: E721 + self._should_free_port = True + + assert type(self._port) == int # noqa: E721 # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit @@ -177,8 +239,6 @@ def __enter__(self): return self def __exit__(self, type, value, traceback): - self.free_port() - # NOTE: Ctrl+C does not count! got_exception = type is not None and type != KeyboardInterrupt @@ -192,9 +252,93 @@ def __exit__(self, type, value, traceback): else: self._try_shutdown(attempts) + self._release_resources() + def __repr__(self): return "{}(name='{}', port={}, base_dir='{}')".format( - self.__class__.__name__, self.name, self.port, self.base_dir) + self.__class__.__name__, + self.name, + str(self._port) if self._port is not None else "None", + self.base_dir + ) + + @staticmethod + def _get_os_ops() -> OsOperations: + if testgres_config.os_ops: + return testgres_config.os_ops + + return LocalOperations.get_single_instance() + + @staticmethod + def _get_port_manager(os_ops: OsOperations) -> PortManager: + assert os_ops is not None + assert isinstance(os_ops, OsOperations) + + if os_ops is LocalOperations.get_single_instance(): + assert utils._old_port_manager is not None + assert type(utils._old_port_manager) == PortManager__Generic # noqa: E721 + assert utils._old_port_manager._os_ops is os_ops + return PortManager__ThisHost.get_single_instance() + + # TODO: Throw the exception "Please define a port manager." ? + return PortManager__Generic(os_ops) + + def clone_with_new_name_and_base_dir(self, name: str, base_dir: str): + assert name is None or type(name) == str # noqa: E721 + assert base_dir is None or type(base_dir) == str # noqa: E721 + + assert __class__ == PostgresNode + + if self._port_manager is None: + raise InvalidOperationException("PostgresNode without PortManager can't be cloned.") + + assert self._port_manager is not None + assert isinstance(self._port_manager, PortManager) + assert self._os_ops is not None + assert isinstance(self._os_ops, OsOperations) + + node = PostgresNode( + name=name, + base_dir=base_dir, + bin_dir=self._bin_dir, + prefix=self._prefix, + os_ops=self._os_ops, + port_manager=self._port_manager) + + return node + + @property + def os_ops(self) -> OsOperations: + assert self._os_ops is not None + assert isinstance(self._os_ops, OsOperations) + return self._os_ops + + @property + def name(self) -> str: + if self._name is None: + raise InvalidOperationException("PostgresNode name is not defined.") + assert type(self._name) == str # noqa: E721 + return self._name + + @property + def host(self) -> str: + assert self._os_ops is not None + assert isinstance(self._os_ops, OsOperations) + return self._os_ops.host + + @property + def port(self) -> int: + if self._port is None: + raise InvalidOperationException("PostgresNode port is not defined.") + + assert type(self._port) == int # noqa: E721 + return self._port + + @property + def ssh_key(self) -> typing.Optional[str]: + assert self._os_ops is not None + assert isinstance(self._os_ops, OsOperations) + return self._os_ops.ssh_key @property def pid(self): @@ -202,14 +346,136 @@ def pid(self): Return postmaster's PID if node is running, else 0. """ - if self.status(): - pid_file = os.path.join(self.data_dir, PG_PID_FILE) - lines = self.os_ops.readlines(pid_file) - pid = int(lines[0]) if lines else None - return pid + self__data_dir = self.data_dir + + _params = [ + self._get_bin_path('pg_ctl'), + "-D", self__data_dir, + "status" + ] # yapf: disable + + status_code, out, error = execute_utility2( + self.os_ops, + _params, + self.utils_log_file, + verbose=True, + ignore_errors=True) + + assert type(status_code) == int # noqa: E721 + assert type(out) == str # noqa: E721 + assert type(error) == str # noqa: E721 + + # ----------------- + if status_code == PG_CTL__STATUS__NODE_IS_STOPPED: + return 0 + + # ----------------- + if status_code == PG_CTL__STATUS__BAD_DATADIR: + return 0 + + # ----------------- + if status_code != PG_CTL__STATUS__OK: + errMsg = "Getting of a node status [data_dir is {0}] failed.".format(self__data_dir) + + raise ExecUtilException( + message=errMsg, + command=_params, + exit_code=status_code, + out=out, + error=error, + ) + + # ----------------- + assert status_code == PG_CTL__STATUS__OK + + if out == "": + __class__._throw_error__pg_ctl_returns_an_empty_string( + _params + ) + + C_PID_PREFIX = "(PID: " + + i = out.find(C_PID_PREFIX) - # for clarity - return 0 + if i == -1: + __class__._throw_error__pg_ctl_returns_an_unexpected_string( + out, + _params + ) + + assert i > 0 + assert i < len(out) + assert len(C_PID_PREFIX) <= len(out) + assert i <= len(out) - len(C_PID_PREFIX) + + i += len(C_PID_PREFIX) + start_pid_s = i + + while True: + if i == len(out): + __class__._throw_error__pg_ctl_returns_an_unexpected_string( + out, + _params + ) + + ch = out[i] + + if ch == ")": + break + + if ch.isdigit(): + i += 1 + continue + + __class__._throw_error__pg_ctl_returns_an_unexpected_string( + out, + _params + ) + assert False + + if i == start_pid_s: + __class__._throw_error__pg_ctl_returns_an_unexpected_string( + out, + _params + ) + + # TODO: Let's verify a length of pid string. + + pid = int(out[start_pid_s:i]) + + if pid == 0: + __class__._throw_error__pg_ctl_returns_a_zero_pid( + out, + _params + ) + + assert pid != 0 + return pid + + @staticmethod + def _throw_error__pg_ctl_returns_an_empty_string(_params): + errLines = [] + errLines.append("Utility pg_ctl returns empty string.") + errLines.append("Command line is {0}".format(_params)) + raise RuntimeError("\n".join(errLines)) + + @staticmethod + def _throw_error__pg_ctl_returns_an_unexpected_string(out, _params): + errLines = [] + errLines.append("Utility pg_ctl returns an unexpected string:") + errLines.append(out) + errLines.append("------------") + errLines.append("Command line is {0}".format(_params)) + raise RuntimeError("\n".join(errLines)) + + @staticmethod + def _throw_error__pg_ctl_returns_a_zero_pid(out, _params): + errLines = [] + errLines.append("Utility pg_ctl returns a zero pid. Output string is:") + errLines.append(out) + errLines.append("------------") + errLines.append("Command line is {0}".format(_params)) + raise RuntimeError("\n".join(errLines)) @property def auxiliary_pids(self): @@ -262,9 +528,11 @@ def source_walsender(self): where application_name = %s """ - if not self.master: + if self.master is None: raise TestgresException("Node doesn't have a master") + assert type(self.master) == PostgresNode # noqa: E721 + # master should be on the same host assert self.master.host == self.host @@ -295,7 +563,7 @@ def base_dir(self): @property def bin_dir(self): if not self._bin_dir: - self._bin_dir = os.path.dirname(get_bin_path("pg_config")) + self._bin_dir = os.path.dirname(get_bin_path2(self.os_ops, "pg_config")) return self._bin_dir @property @@ -331,22 +599,88 @@ def version(self): """ return self._pg_version - def _try_shutdown(self, max_attempts): + def _try_shutdown(self, max_attempts, with_force=False): + assert type(max_attempts) == int # noqa: E721 + assert type(with_force) == bool # noqa: E721 + assert max_attempts > 0 + attempts = 0 # try stopping server N times while attempts < max_attempts: + attempts += 1 try: self.stop() - break # OK except ExecUtilException: - pass # one more time + continue # one more time except Exception: - # TODO: probably should kill stray instance eprint('cannot stop node {}'.format(self.name)) break - attempts += 1 + return # OK + + # If force stopping is enabled and PID is valid + if not with_force: + return + + node_pid = self.pid + assert node_pid is not None + assert type(node_pid) == int # noqa: E721 + + if node_pid == 0: + return + + # TODO: [2025-02-28] It is really the old ugly code. We have to rewrite it! + + ps_command = ['ps', '-o', 'pid=', '-p', str(node_pid)] + + ps_output = self.os_ops.exec_command(cmd=ps_command, shell=True, ignore_errors=True).decode('utf-8') + assert type(ps_output) == str # noqa: E721 + + if ps_output == "": + return + + if ps_output != str(node_pid): + __class__._throw_bugcheck__unexpected_result_of_ps( + ps_output, + ps_command) + + try: + eprint('Force stopping node {0} with PID {1}'.format(self.name, node_pid)) + self.os_ops.kill(node_pid, signal.SIGKILL, expect_error=False) + except Exception: + # The node has already stopped + pass + + # Check that node stopped - print only column pid without headers + ps_output = self.os_ops.exec_command(cmd=ps_command, shell=True, ignore_errors=True).decode('utf-8') + assert type(ps_output) == str # noqa: E721 + + if ps_output == "": + eprint('Node {0} has been stopped successfully.'.format(self.name)) + return + + if ps_output == str(node_pid): + eprint('Failed to stop node {0}.'.format(self.name)) + return + + __class__._throw_bugcheck__unexpected_result_of_ps( + ps_output, + ps_command) + + def _release_resources(self): + self.free_port() + + @staticmethod + def _throw_bugcheck__unexpected_result_of_ps(result, cmd): + assert type(result) == str # noqa: E721 + assert type(cmd) == list # noqa: E721 + errLines = [] + errLines.append("[BUG CHECK] Unexpected result of command ps:") + errLines.append(result) + errLines.append("-----") + errLines.append("Command line is {0}".format(cmd)) + raise RuntimeError("\n".join(errLines)) def _assign_master(self, master): """NOTE: this is a private method!""" @@ -465,10 +799,13 @@ def init(self, initdb_params=None, cached=True, **kwargs): """ # initialize this PostgreSQL node + assert self._os_ops is not None + assert isinstance(self._os_ops, OsOperations) + cached_initdb( data_dir=self.data_dir, logfile=self.utils_log_file, - os_ops=self.os_ops, + os_ops=self._os_ops, params=initdb_params, bin_path=self.bin_dir, cached=False) @@ -529,7 +866,9 @@ def get_auth_method(t): u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), - u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host) + u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), + u"host\tall\tall\tall\t{}\n".format(auth_host), + u"host\treplication\tall\tall\t{}\n".format(auth_host) ] # yapf: disable # write missing lines @@ -634,7 +973,7 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - status_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + status_code, out, error = execute_utility2(self.os_ops, _params, self.utils_log_file, verbose=True) if error and 'does not exist' in error: return NodeStatus.Uninitialized elif 'no server running' in out: @@ -660,7 +999,7 @@ def get_control_data(self): _params += ["-D"] if self._pg_version >= PgVer('9.5') else [] _params += [self.data_dir] - data = execute_utility(_params, self.utils_log_file) + data = execute_utility2(self.os_ops, _params, self.utils_log_file) out_dict = {} @@ -670,7 +1009,7 @@ def get_control_data(self): return out_dict - def slow_start(self, replica=False, dbname='template1', username=None, max_attempts=0): + def slow_start(self, replica=False, dbname='template1', username=None, max_attempts=0, exec_env=None): """ Starts the PostgreSQL instance and then polls the instance until it reaches the expected state (primary or replica). The state is checked @@ -683,9 +1022,9 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem If False, waits for the instance to be in primary mode. Default is False. max_attempts: """ - if not username: - username = default_username() - self.start() + assert exec_env is None or type(exec_env) == dict # noqa: E721 + + self.start(exec_env=exec_env) if replica: query = 'SELECT pg_is_in_recovery()' @@ -694,14 +1033,14 @@ def slow_start(self, replica=False, dbname='template1', username=None, max_attem # Call poll_query_until until the expected value is returned self.poll_query_until(query=query, dbname=dbname, - username=username, + username=username or self.os_ops.username, suppress={InternalError, QueryException, ProgrammingError, OperationalError}, max_attempts=max_attempts) - def start(self, params=[], wait=True): + def start(self, params=[], wait=True, exec_env=None): """ Starts the PostgreSQL node using pg_ctl if node has not been started. By default, it waits for the operation to complete before returning. @@ -715,41 +1054,89 @@ def start(self, params=[], wait=True): Returns: This instance of :class:`.PostgresNode`. """ + assert exec_env is None or type(exec_env) == dict # noqa: E721 + assert __class__._C_MAX_START_ATEMPTS > 1 + if self.is_started: return self - _params = [ - self._get_bin_path("pg_ctl"), - "-D", self.data_dir, - "-l", self.pg_log_file, - "-w" if wait else '-W', # --wait or --no-wait - "start" - ] + params # yapf: disable + if self._port is None: + raise InvalidOperationException("Can't start PostgresNode. Port is not defined.") - startup_retries = 5 - while True: + assert type(self._port) == int # noqa: E721 + + _params = [self._get_bin_path("pg_ctl"), + "-D", self.data_dir, + "-l", self.pg_log_file, + "-w" if wait else '-W', # --wait or --no-wait + "start"] + params # yapf: disable + + def LOCAL__start_node(): + # 'error' will be None on Windows + _, _, error = execute_utility2(self.os_ops, _params, self.utils_log_file, verbose=True, exec_env=exec_env) + assert error is None or type(error) == str # noqa: E721 + if error and 'does not exist' in error: + raise Exception(error) + + def LOCAL__raise_cannot_start_node(from_exception, msg): + assert isinstance(from_exception, Exception) + assert type(msg) == str # noqa: E721 + files = self._collect_special_files() + raise_from(StartNodeException(msg, files), from_exception) + + def LOCAL__raise_cannot_start_node__std(from_exception): + assert isinstance(from_exception, Exception) + LOCAL__raise_cannot_start_node(from_exception, 'Cannot start node') + + if not self._should_free_port: try: - exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) - if error and 'does not exist' in error: - raise Exception + LOCAL__start_node() except Exception as e: - files = self._collect_special_files() - if any(len(file) > 1 and 'Is another postmaster already ' - 'running on port' in file[1].decode() for - file in files): - print("Detected an issue with connecting to port {0}. " - "Trying another port after a 5-second sleep...".format(self.port)) - self.port = reserve_port() - options = {} - options['port'] = str(self.port) - self.set_auto_conf(options) - startup_retries -= 1 - time.sleep(5) - continue + LOCAL__raise_cannot_start_node__std(e) + else: + assert self._should_free_port + assert self._port_manager is not None + assert isinstance(self._port_manager, PortManager) + assert __class__._C_MAX_START_ATEMPTS > 1 + + log_reader = PostgresNodeLogReader(self, from_beginnig=False) + + nAttempt = 0 + timeout = 1 + while True: + assert nAttempt >= 0 + assert nAttempt < __class__._C_MAX_START_ATEMPTS + nAttempt += 1 + try: + LOCAL__start_node() + except Exception as e: + assert nAttempt > 0 + assert nAttempt <= __class__._C_MAX_START_ATEMPTS + if nAttempt == __class__._C_MAX_START_ATEMPTS: + LOCAL__raise_cannot_start_node(e, "Cannot start node after multiple attempts.") + + is_it_port_conflict = PostgresNodeUtils.delect_port_conflict(log_reader) + + if not is_it_port_conflict: + LOCAL__raise_cannot_start_node__std(e) - msg = 'Cannot start node' - raise_from(StartNodeException(msg, files), e) - break + logging.warning( + "Detected a conflict with using the port {0}. Trying another port after a {1}-second sleep...".format(self._port, timeout) + ) + time.sleep(timeout) + timeout = min(2 * timeout, 5) + cur_port = self._port + new_port = self._port_manager.reserve_port() # can raise + try: + options = {'port': new_port} + self.set_auto_conf(options) + except: # noqa: E722 + self._port_manager.release_port(new_port) + raise + self._port = new_port + self._port_manager.release_port(cur_port) + continue + break self._maybe_start_logger() self.is_started = True return self @@ -775,7 +1162,7 @@ def stop(self, params=[], wait=True): "stop" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility2(self.os_ops, _params, self.utils_log_file) self._maybe_stop_logger() self.is_started = False @@ -817,7 +1204,7 @@ def restart(self, params=[]): ] + params # yapf: disable try: - error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + error_code, out, error = execute_utility2(self.os_ops, _params, self.utils_log_file, verbose=True) if error and 'could not start server' in error: raise ExecUtilException except ExecUtilException as e: @@ -846,7 +1233,7 @@ def reload(self, params=[]): "reload" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility2(self.os_ops, _params, self.utils_log_file) return self @@ -868,7 +1255,7 @@ def promote(self, dbname=None, username=None): "promote" ] # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility2(self.os_ops, _params, self.utils_log_file) # for versions below 10 `promote` is asynchronous so we need to wait # until it actually becomes writable @@ -903,25 +1290,36 @@ def pg_ctl(self, params): "-w" # wait ] + params # yapf: disable - return execute_utility(_params, self.utils_log_file) + return execute_utility2(self.os_ops, _params, self.utils_log_file) def free_port(self): """ Reclaim port owned by this node. - NOTE: does not free auto selected ports. + NOTE: this method does not release manually defined port but reset it. """ + assert type(self._should_free_port) == bool # noqa: E721 - if self._should_free_port: + if not self._should_free_port: + self._port = None + else: + assert type(self._port) == int # noqa: E721 + + assert self._port_manager is not None + assert isinstance(self._port_manager, PortManager) + + port = self._port self._should_free_port = False - release_port(self.port) + self._port = None + self._port_manager.release_port(port) - def cleanup(self, max_attempts=3): + def cleanup(self, max_attempts=3, full=False, release_resources=False): """ Stop node if needed and remove its data/logs directory. NOTE: take a look at TestgresConfig.node_cleanup_full. Args: max_attempts: how many times should we try to stop()? + full: clean full base dir Returns: This instance of :class:`.PostgresNode`. @@ -930,12 +1328,15 @@ def cleanup(self, max_attempts=3): self._try_shutdown(max_attempts) # choose directory to be removed - if testgres_config.node_cleanup_full: + if testgres_config.node_cleanup_full or full: rm_dir = self.base_dir # everything else: rm_dir = self.data_dir # just data, save logs - self.os_ops.rmdirs(rm_dir, ignore_errors=True) + self.os_ops.rmdirs(rm_dir, ignore_errors=False) + + if release_resources: + self._release_resources() return self @@ -946,6 +1347,8 @@ def psql(self, dbname=None, username=None, input=None, + host: typing.Optional[str] = None, + port: typing.Optional[int] = None, **variables): """ Execute a query using psql. @@ -956,6 +1359,8 @@ def psql(self, dbname: database name to connect to. username: database user name. input: raw input to be passed. + host: an explicit host of server. + port: an explicit port of server. **variables: vars to be set before execution. Returns: @@ -967,15 +1372,64 @@ def psql(self, >>> psql(query='select 3', ON_ERROR_STOP=1) """ - # Set default arguments - dbname = dbname or default_dbname() - username = username or default_username() + assert host is None or type(host) == str # noqa: E721 + assert port is None or type(port) == int # noqa: E721 + assert type(variables) == dict # noqa: E721 + + return self._psql( + ignore_errors=True, + query=query, + filename=filename, + dbname=dbname, + username=username, + input=input, + host=host, + port=port, + **variables + ) + + def _psql( + self, + ignore_errors, + query=None, + filename=None, + dbname=None, + username=None, + input=None, + host: typing.Optional[str] = None, + port: typing.Optional[int] = None, + **variables): + assert host is None or type(host) == str # noqa: E721 + assert port is None or type(port) == int # noqa: E721 + assert type(variables) == dict # noqa: E721 + + # + # We do not support encoding. It may be added later. Ok? + # + if input is None: + pass + elif type(input) == bytes: # noqa: E721 + pass + else: + raise Exception("Input data must be None or bytes.") + + if host is None: + host = self.host + + if port is None: + port = self.port + + assert host is not None + assert port is not None + assert type(host) == str # noqa: E721 + assert type(port) == int # noqa: E721 psql_params = [ self._get_bin_path("psql"), - "-p", str(self.port), - "-h", self.host, - "-U", username, + "-p", str(port), + "-h", host, + "-U", username or self.os_ops.username, + "-d", dbname or default_dbname(), "-X", # no .psqlrc "-A", # unaligned output "-t", # print rows only @@ -988,31 +1442,19 @@ def psql(self, # select query source if query: - if self.os_ops.remote: - psql_params.extend(("-c", '"{}"'.format(query))) - else: - psql_params.extend(("-c", query)) + psql_params.extend(("-c", query)) elif filename: psql_params.extend(("-f", filename)) else: raise QueryException('Query or filename must be provided') - # should be the last one - psql_params.append(dbname) - if not self.os_ops.remote: - # start psql process - process = subprocess.Popen(psql_params, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - # wait until it finishes and get stdout and stderr - out, err = process.communicate(input=input) - return process.returncode, out, err - else: - status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) - - return status_code, out, err + return self.os_ops.exec_command( + psql_params, + verbose=True, + input=input, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + ignore_errors=ignore_errors) @method_decorator(positional_args_hack(['dbname', 'query'])) def safe_psql(self, query=None, expect_error=False, **kwargs): @@ -1033,22 +1475,27 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): Returns: psql's output as str. """ + assert type(kwargs) == dict # noqa: E721 + assert not ("ignore_errors" in kwargs.keys()) + assert not ("expect_error" in kwargs.keys()) # force this setting kwargs['ON_ERROR_STOP'] = 1 try: - ret, out, err = self.psql(query=query, **kwargs) + ret, out, err = self._psql(ignore_errors=False, query=query, **kwargs) except ExecUtilException as e: - ret = e.exit_code - out = e.out - err = e.message - if ret: - if expect_error: - out = (err or b'').decode('utf-8') - else: - raise QueryException((err or b'').decode('utf-8'), query) - elif expect_error: - assert False, "Exception was expected, but query finished successfully: `{}` ".format(query) + if not expect_error: + raise QueryException(e.message, query) + + if type(e.error) == bytes: # noqa: E721 + return e.error.decode("utf-8") # throw + + # [2024-12-09] This situation is not expected + assert False + return e.error + + if expect_error: + raise InvalidOperationException("Exception was expected, but query finished successfully: `{}`.".format(query)) return out @@ -1087,9 +1534,6 @@ def tmpfile(): fname = self.os_ops.mkstemp(prefix=TMP_DUMP) return fname - # Set default arguments - dbname = dbname or default_dbname() - username = username or default_username() filename = filename or tmpfile() _params = [ @@ -1097,12 +1541,12 @@ def tmpfile(): "-p", str(self.port), "-h", self.host, "-f", filename, - "-U", username, - "-d", dbname, + "-U", username or self.os_ops.username, + "-d", dbname or default_dbname(), "-F", format.value ] # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility2(self.os_ops, _params, self.utils_log_file) return filename @@ -1118,7 +1562,7 @@ def restore(self, filename, dbname=None, username=None): # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or self.os_ops.username _params = [ self._get_bin_path("pg_restore"), @@ -1129,9 +1573,9 @@ def restore(self, filename, dbname=None, username=None): filename ] # yapf: disable - # try pg_restore if dump is binary formate, and psql if not + # try pg_restore if dump is binary format, and psql if not try: - execute_utility(_params, self.utils_log_name) + execute_utility2(self.os_ops, _params, self.utils_log_name) except ExecUtilException: self.psql(filename=filename, dbname=dbname, username=username) @@ -1171,7 +1615,6 @@ def poll_query_until(self, assert sleep_time > 0 attempts = 0 while max_attempts == 0 or attempts < max_attempts: - print(f"Pooling {attempts}") try: res = self.execute(dbname=dbname, query=query, @@ -1195,6 +1638,7 @@ def poll_query_until(self, return # done except tuple(suppress or []): + logging.info(f"Trying execute, attempt {attempts + 1}.\nQuery: {query}") pass # we're suppressing them time.sleep(sleep_time) @@ -1272,7 +1716,7 @@ def set_synchronous_standbys(self, standbys): Args: standbys: either :class:`.First` or :class:`.Any` object specifying - sychronization parameters or just a plain list of + synchronization parameters or just a plain list of :class:`.PostgresNode`s replicas which would be equivalent to passing ``First(1, )``. For PostgreSQL 9.5 and below it is only possible to specify a plain list of standbys as @@ -1388,15 +1832,13 @@ def pgbench(self, if options is None: options = [] - # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() _params = [ self._get_bin_path("pgbench"), "-p", str(self.port), "-h", self.host, - "-U", username, + "-U", username or self.os_ops.username ] + options # yapf: disable # should be the last one @@ -1463,15 +1905,13 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): >>> pgbench_run(time=10) """ - # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() _params = [ self._get_bin_path("pgbench"), "-p", str(self.port), "-h", self.host, - "-U", username, + "-U", username or self.os_ops.username ] + options # yapf: disable for key, value in iteritems(kwargs): @@ -1488,7 +1928,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # should be the last one _params.append(dbname) - return execute_utility(_params, self.utils_log_file) + return execute_utility2(self.os_ops, _params, self.utils_log_file) def connect(self, dbname=None, @@ -1616,30 +2056,38 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): name, var = line.partition('=')[::2] name = name.strip() - var = var.strip() - var = var.strip('"') - var = var.strip("'") - # remove options specified in rm_options list + # Remove options specified in rm_options list if name in rm_options: continue current_options[name] = var for option in options: - current_options[option] = options[option] + assert type(option) == str # noqa: E721 + assert option != "" + assert option.strip() == option + + value = options[option] + valueType = type(value) + + if valueType == str: + value = __class__._escape_config_value(value) + elif valueType == bool: + value = "on" if value else "off" + + current_options[option] = value auto_conf = '' for option in current_options: - auto_conf += "{0} = '{1}'\n".format( - option, current_options[option]) + auto_conf += option + " = " + str(current_options[option]) + "\n" for directive in current_directives: auto_conf += directive + "\n" self.os_ops.write(path, auto_conf, truncate=True) - def upgrade_from(self, old_node, options=None): + def upgrade_from(self, old_node, options=None, expect_error=False): """ Upgrade this node from an old node using pg_upgrade. @@ -1667,35 +2115,235 @@ def upgrade_from(self, old_node, options=None): "--old-datadir", old_node.data_dir, "--new-datadir", self.data_dir, "--old-port", str(old_node.port), - "--new-port", str(self.port), + "--new-port", str(self.port) ] upgrade_command += options - return self.os_ops.exec_command(upgrade_command) + return self.os_ops.exec_command(upgrade_command, expect_error=expect_error) def _get_bin_path(self, filename): if self.bin_dir: bin_path = os.path.join(self.bin_dir, filename) else: - bin_path = get_bin_path(filename) + bin_path = get_bin_path2(self.os_ops, filename) return bin_path + def _escape_config_value(value): + assert type(value) == str # noqa: E721 + + result = "'" + + for ch in value: + if ch == "'": + result += "\\'" + elif ch == "\n": + result += "\\n" + elif ch == "\r": + result += "\\r" + elif ch == "\t": + result += "\\t" + elif ch == "\b": + result += "\\b" + elif ch == "\\": + result += "\\\\" + else: + result += ch + + result += "'" + return result + + +class PostgresNodeLogReader: + class LogInfo: + position: int + + def __init__(self, position: int): + self.position = position + + # -------------------------------------------------------------------- + class LogDataBlock: + _file_name: str + _position: int + _data: str + + def __init__( + self, + file_name: str, + position: int, + data: str + ): + assert type(file_name) == str # noqa: E721 + assert type(position) == int # noqa: E721 + assert type(data) == str # noqa: E721 + assert file_name != "" + assert position >= 0 + self._file_name = file_name + self._position = position + self._data = data + + @property + def file_name(self) -> str: + assert type(self._file_name) == str # noqa: E721 + assert self._file_name != "" + return self._file_name + + @property + def position(self) -> int: + assert type(self._position) == int # noqa: E721 + assert self._position >= 0 + return self._position + + @property + def data(self) -> str: + assert type(self._data) == str # noqa: E721 + return self._data + + # -------------------------------------------------------------------- + _node: PostgresNode + _logs: typing.Dict[str, LogInfo] + + # -------------------------------------------------------------------- + def __init__(self, node: PostgresNode, from_beginnig: bool): + assert node is not None + assert isinstance(node, PostgresNode) + assert type(from_beginnig) == bool # noqa: E721 + + self._node = node + + if from_beginnig: + self._logs = dict() + else: + self._logs = self._collect_logs() + + assert type(self._logs) == dict # noqa: E721 + return + + def read(self) -> typing.List[LogDataBlock]: + assert self._node is not None + assert isinstance(self._node, PostgresNode) + + cur_logs: typing.Dict[__class__.LogInfo] = self._collect_logs() + assert cur_logs is not None + assert type(cur_logs) == dict # noqa: E721 + + assert type(self._logs) == dict # noqa: E721 + + result = list() + + for file_name, cur_log_info in cur_logs.items(): + assert type(file_name) == str # noqa: E721 + assert type(cur_log_info) == __class__.LogInfo # noqa: E721 + + read_pos = 0 + + if file_name in self._logs.keys(): + prev_log_info = self._logs[file_name] + assert type(prev_log_info) == __class__.LogInfo # noqa: E721 + read_pos = prev_log_info.position # the previous size + + file_content_b = self._node.os_ops.read_binary(file_name, read_pos) + assert type(file_content_b) == bytes # noqa: E721 + + # + # A POTENTIAL PROBLEM: file_content_b may contain an incompleted UTF-8 symbol. + # + file_content_s = file_content_b.decode() + assert type(file_content_s) == str # noqa: E721 + + next_read_pos = read_pos + len(file_content_b) + + # It is a research/paranoja check. + # When we will process partial UTF-8 symbol, it must be adjusted. + assert cur_log_info.position <= next_read_pos + + cur_log_info.position = next_read_pos + + block = __class__.LogDataBlock( + file_name, + read_pos, + file_content_s + ) + + result.append(block) + + # A new check point + self._logs = cur_logs + + return result + + def _collect_logs(self) -> typing.Dict[LogInfo]: + assert self._node is not None + assert isinstance(self._node, PostgresNode) + + files = [ + self._node.pg_log_file + ] # yapf: disable + + result = dict() + + for f in files: + assert type(f) == str # noqa: E721 + + # skip missing files + if not self._node.os_ops.path_exists(f): + continue + + file_size = self._node.os_ops.get_file_size(f) + assert type(file_size) == int # noqa: E721 + assert file_size >= 0 + + result[f] = __class__.LogInfo(file_size) + + return result + + +class PostgresNodeUtils: + @staticmethod + def delect_port_conflict(log_reader: PostgresNodeLogReader) -> bool: + assert type(log_reader) == PostgresNodeLogReader # noqa: E721 + + blocks = log_reader.read() + assert type(blocks) == list # noqa: E721 + + for block in blocks: + assert type(block) == PostgresNodeLogReader.LogDataBlock # noqa: E721 + + if 'Is another postmaster already running on port' in block.data: + return True + + return False + class NodeApp: - def __init__(self, test_path, nodes_to_cleanup, os_ops=LocalOperations()): - self.test_path = test_path - self.nodes_to_cleanup = nodes_to_cleanup + def __init__(self, test_path=None, nodes_to_cleanup=None, os_ops=None): + assert os_ops is None or isinstance(os_ops, OsOperations) + + if os_ops is None: + os_ops = LocalOperations.get_single_instance() + + assert isinstance(os_ops, OsOperations) + + if test_path: + if os.path.isabs(test_path): + self.test_path = test_path + else: + self.test_path = os.path.join(os_ops.cwd(), test_path) + else: + self.test_path = os_ops.cwd() + self.nodes_to_cleanup = nodes_to_cleanup if nodes_to_cleanup else [] self.os_ops = os_ops def make_empty( self, - base_dir=None): + base_dir=None, + port=None, + bin_dir=None): real_base_dir = os.path.join(self.test_path, base_dir) self.os_ops.rmdirs(real_base_dir, ignore_errors=True) self.os_ops.makedirs(real_base_dir) - node = PostgresNode(base_dir=real_base_dir) + node = PostgresNode(base_dir=real_base_dir, port=port, bin_dir=bin_dir) node.should_rm_dirs = True self.nodes_to_cleanup.append(node) @@ -1704,14 +2352,18 @@ def make_empty( def make_simple( self, base_dir=None, + port=None, set_replication=False, ptrack_enable=False, initdb_params=[], pg_options={}, - checksum=True): + checksum=True, + bin_dir=None): + assert type(pg_options) == dict # noqa: E721 + if checksum and '--data-checksums' not in initdb_params: initdb_params.append('--data-checksums') - node = self.make_empty(base_dir) + node = self.make_empty(base_dir, port, bin_dir=bin_dir) node.init( initdb_params=initdb_params, allow_streaming=set_replication) @@ -1721,19 +2373,22 @@ def make_simple( node.major_version = float(node.major_version_str) # Set default parameters - options = {'max_connections': 100, - 'shared_buffers': '10MB', - 'fsync': 'off', - 'wal_level': 'logical', - 'hot_standby': 'off', - 'log_line_prefix': '%t [%p]: [%l-1] ', - 'log_statement': 'none', - 'log_duration': 'on', - 'log_min_duration_statement': 0, - 'log_connections': 'on', - 'log_disconnections': 'on', - 'restart_after_crash': 'off', - 'autovacuum': 'off'} + options = { + 'max_connections': 100, + 'shared_buffers': '10MB', + 'fsync': 'off', + 'wal_level': 'logical', + 'hot_standby': 'off', + 'log_line_prefix': '%t [%p]: [%l-1] ', + 'log_statement': 'none', + 'log_duration': 'on', + 'log_min_duration_statement': 0, + 'log_connections': 'on', + 'log_disconnections': 'on', + 'restart_after_crash': 'off', + 'autovacuum': 'off', + # unix_socket_directories will be defined later + } # Allow replication in pg_hba.conf if set_replication: @@ -1748,11 +2403,16 @@ def make_simple( else: options['wal_keep_segments'] = '12' - # set default values - node.set_auto_conf(options) - # Apply given parameters - node.set_auto_conf(pg_options) + for option_name, option_value in iteritems(pg_options): + options[option_name] = option_value + + # Define delayed propertyes + if not ("unix_socket_directories" in options.keys()): + options["unix_socket_directories"] = __class__._gettempdir_for_socket() + + # Set config values + node.set_auto_conf(options) # kludge for testgres # https://github.com/postgrespro/testgres/issues/54 @@ -1761,3 +2421,56 @@ def make_simple( node.set_auto_conf({}, 'postgresql.conf', ['wal_keep_segments']) return node + + @staticmethod + def _gettempdir_for_socket(): + platform_system_name = platform.system().lower() + + if platform_system_name == "windows": + return __class__._gettempdir() + + # + # [2025-02-17] Hot fix. + # + # Let's use hard coded path as Postgres likes. + # + # pg_config_manual.h: + # + # #ifndef WIN32 + # #define DEFAULT_PGSOCKET_DIR "/tmp" + # #else + # #define DEFAULT_PGSOCKET_DIR "" + # #endif + # + # On the altlinux-10 tempfile.gettempdir() may return + # the path to "private" temp directiry - "/temp/.private//" + # + # But Postgres want to find a socket file in "/tmp" (see above). + # + + return "/tmp" + + @staticmethod + def _gettempdir(): + v = tempfile.gettempdir() + + # + # Paranoid checks + # + if type(v) != str: # noqa: E721 + __class__._raise_bugcheck("tempfile.gettempdir returned a value with type {0}.".format(type(v).__name__)) + + if v == "": + __class__._raise_bugcheck("tempfile.gettempdir returned an empty string.") + + if not os.path.exists(v): + __class__._raise_bugcheck("tempfile.gettempdir returned a not exist path [{0}].".format(v)) + + # OK + return v + + @staticmethod + def _raise_bugcheck(msg): + assert type(msg) == str # noqa: E721 + assert msg != "" + raise Exception("[BUG CHECK] " + msg) diff --git a/testgres/operations/helpers.py b/testgres/operations/helpers.py new file mode 100644 index 00000000..ebbf0f73 --- /dev/null +++ b/testgres/operations/helpers.py @@ -0,0 +1,55 @@ +import locale + + +class Helpers: + @staticmethod + def _make_get_default_encoding_func(): + # locale.getencoding is added in Python 3.11 + if hasattr(locale, 'getencoding'): + return locale.getencoding + + # It must exist + return locale.getpreferredencoding + + # Prepared pointer on function to get a name of system codepage + _get_default_encoding_func = _make_get_default_encoding_func.__func__() + + @staticmethod + def GetDefaultEncoding(): + # + # Original idea/source was: + # + # def os_ops.get_default_encoding(): + # if not hasattr(locale, 'getencoding'): + # locale.getencoding = locale.getpreferredencoding + # return locale.getencoding() or 'UTF-8' + # + + assert __class__._get_default_encoding_func is not None + + r = __class__._get_default_encoding_func() + + if r: + assert r is not None + assert type(r) == str # noqa: E721 + assert r != "" + return r + + # Is it an unexpected situation? + return 'UTF-8' + + @staticmethod + def PrepareProcessInput(input, encoding): + if not input: + return None + + if type(input) == str: # noqa: E721 + if encoding is None: + return input.encode(__class__.GetDefaultEncoding()) + + assert type(encoding) == str # noqa: E721 + return input.encode(encoding) + + # It is expected! + assert type(input) == bytes # noqa: E721 + return input diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index ef360d3b..ccf1ab82 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -1,14 +1,22 @@ import getpass +import logging import os import shutil import stat import subprocess import tempfile +import time +import socket import psutil +import typing +import threading from ..exceptions import ExecUtilException -from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding +from ..exceptions import InvalidOperationException +from .os_ops import ConnectionParams, OsOperations, get_default_encoding +from .raise_error import RaiseError +from .helpers import Helpers try: from shutil import which as find_executable @@ -18,18 +26,12 @@ from distutils import rmtree CMD_TIMEOUT_SEC = 60 -error_markers = [b'error', b'Permission denied', b'fatal'] - - -def has_errors(output): - if output: - if isinstance(output, str): - output = output.encode(get_default_encoding()) - return any(marker in output for marker in error_markers) - return False class LocalOperations(OsOperations): + sm_single_instance: OsOperations = None + sm_single_instance_guard = threading.Lock() + def __init__(self, conn_params=None): if conn_params is None: conn_params = ConnectionParams() @@ -38,15 +40,23 @@ def __init__(self, conn_params=None): self.host = conn_params.host self.ssh_key = None self.remote = False - self.username = conn_params.username or self.get_user() + self.username = conn_params.username or getpass.getuser() @staticmethod - def _raise_exec_exception(message, command, exit_code, output): - """Raise an ExecUtilException.""" - raise ExecUtilException(message=message.format(output), - command=' '.join(command) if isinstance(command, list) else command, - exit_code=exit_code, - out=output) + def get_single_instance() -> OsOperations: + assert __class__ == LocalOperations + assert __class__.sm_single_instance_guard is not None + + if __class__.sm_single_instance is not None: + assert type(__class__.sm_single_instance) == __class__ # noqa: E721 + return __class__.sm_single_instance + + with __class__.sm_single_instance_guard: + if __class__.sm_single_instance is None: + __class__.sm_single_instance = __class__() + assert __class__.sm_single_instance is not None + assert type(__class__.sm_single_instance) == __class__ # noqa: E721 + return __class__.sm_single_instance @staticmethod def _process_output(encoding, temp_file_path): @@ -57,83 +67,177 @@ def _process_output(encoding, temp_file_path): output = output.decode(encoding) return output, None # In Windows stderr writing in stdout - def _run_command(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding): - """Execute a command and return the process and its output.""" - if os.name == 'nt' and stdout is None: # Windows - with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as temp_file: - stdout = temp_file - stderr = subprocess.STDOUT - process = subprocess.Popen( - cmd, - shell=shell, - stdin=stdin or subprocess.PIPE if input is not None else None, - stdout=stdout, - stderr=stderr, - ) - if get_process: - return process, None, None - temp_file_path = temp_file.name - - # Wait process finished - process.wait() - - output, error = self._process_output(encoding, temp_file_path) - return process, output, error - else: # Other OS + def _run_command__nt(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding, exec_env=None): + assert exec_env is None or type(exec_env) == dict # noqa: E721 + + # TODO: why don't we use the data from input? + + extParams: typing.Dict[str, str] = dict() + + if exec_env is None: + pass + elif len(exec_env) == 0: + pass + else: + env = os.environ.copy() + assert type(env) == dict # noqa: E721 + for v in exec_env.items(): + assert type(v) == tuple # noqa: E721 + assert len(v) == 2 + assert type(v[0]) == str # noqa: E721 + assert v[0] != "" + + if v[1] is None: + env.pop(v[0], None) + else: + assert type(v[1]) == str # noqa: E721 + env[v[0]] = v[1] + + extParams["env"] = env + + with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as temp_file: + stdout = temp_file + stderr = subprocess.STDOUT process = subprocess.Popen( cmd, shell=shell, stdin=stdin or subprocess.PIPE if input is not None else None, - stdout=stdout or subprocess.PIPE, - stderr=stderr or subprocess.PIPE, + stdout=stdout, + stderr=stderr, + **extParams, ) if get_process: return process, None, None - try: - output, error = process.communicate(input=input.encode(encoding) if input else None, timeout=timeout) - if encoding: - output = output.decode(encoding) - error = error.decode(encoding) - return process, output, error - except subprocess.TimeoutExpired: - process.kill() - raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) + temp_file_path = temp_file.name + + # Wait process finished + process.wait() + + output, error = self._process_output(encoding, temp_file_path) + return process, output, error + + def _run_command__generic(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding, exec_env=None): + assert exec_env is None or type(exec_env) == dict # noqa: E721 + + input_prepared = None + if not get_process: + input_prepared = Helpers.PrepareProcessInput(input, encoding) # throw + + assert input_prepared is None or (type(input_prepared) == bytes) # noqa: E721 + + extParams: typing.Dict[str, str] = dict() + + if exec_env is None: + pass + elif len(exec_env) == 0: + pass + else: + env = os.environ.copy() + assert type(env) == dict # noqa: E721 + for v in exec_env.items(): + assert type(v) == tuple # noqa: E721 + assert len(v) == 2 + assert type(v[0]) == str # noqa: E721 + assert v[0] != "" + + if v[1] is None: + env.pop(v[0], None) + else: + assert type(v[1]) == str # noqa: E721 + env[v[0]] = v[1] + + extParams["env"] = env + + process = subprocess.Popen( + cmd, + shell=shell, + stdin=stdin or subprocess.PIPE if input is not None else None, + stdout=stdout or subprocess.PIPE, + stderr=stderr or subprocess.PIPE, + **extParams + ) + assert not (process is None) + if get_process: + return process, None, None + try: + output, error = process.communicate(input=input_prepared, timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() + raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) + + assert type(output) == bytes # noqa: E721 + assert type(error) == bytes # noqa: E721 + + if encoding: + output = output.decode(encoding) + error = error.decode(encoding) + return process, output, error + + def _run_command(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding, exec_env=None): + """Execute a command and return the process and its output.""" + if os.name == 'nt' and stdout is None: # Windows + method = __class__._run_command__nt + else: # Other OS + method = __class__._run_command__generic + + return method(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding, exec_env=exec_env) def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=False, - text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=False, timeout=None): + text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=False, timeout=None, + ignore_errors=False, exec_env=None): """ Execute a command in a subprocess and handle the output based on the provided parameters. """ - process, output, error = self._run_command(cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding) + assert type(expect_error) == bool # noqa: E721 + assert type(ignore_errors) == bool # noqa: E721 + assert exec_env is None or type(exec_env) == dict # noqa: E721 + + process, output, error = self._run_command(cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding, exec_env=exec_env) if get_process: return process - if process.returncode != 0 or (has_errors(error) and not expect_error): - self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode, error) + + if expect_error: + if process.returncode == 0: + raise InvalidOperationException("We expected an execution error.") + elif ignore_errors: + pass + elif process.returncode == 0: + pass + else: + assert not expect_error + assert not ignore_errors + assert process.returncode != 0 + RaiseError.UtilityExitedWithNonZeroCode( + cmd=cmd, + exit_code=process.returncode, + msg_arg=error or output, + error=error, + out=output) if verbose: return process.returncode, output, error - else: - return output + + return output # Environment setup def environ(self, var_name): return os.environ.get(var_name) + def cwd(self): + return os.getcwd() + def find_executable(self, executable): return find_executable(executable) def is_executable(self, file): # Check if the file is executable - return os.stat(file).st_mode & stat.S_IXUSR + assert stat.S_IXUSR != 0 + return (os.stat(file).st_mode & stat.S_IXUSR) == stat.S_IXUSR def set_env(self, var_name, var_val): # Check if the directory is already in PATH os.environ[var_name] = var_val - # Get environment variables - def get_user(self): - return self.username or getpass.getuser() - def get_name(self): return os.name @@ -146,8 +250,56 @@ def makedirs(self, path, remove_existing=False): except FileExistsError: pass - def rmdirs(self, path, ignore_errors=True): - return rmtree(path, ignore_errors=ignore_errors) + def makedir(self, path: str): + assert type(path) == str # noqa: E721 + os.mkdir(path) + + # [2025-02-03] Old name of parameter attempts is "retries". + def rmdirs(self, path, ignore_errors=True, attempts=3, delay=1): + """ + Removes a directory and its contents, retrying on failure. + + :param path: Path to the directory. + :param ignore_errors: If True, ignore errors. + :param retries: Number of attempts to remove the directory. + :param delay: Delay between attempts in seconds. + """ + assert type(path) == str # noqa: E721 + assert type(ignore_errors) == bool # noqa: E721 + assert type(attempts) == int # noqa: E721 + assert type(delay) == int or type(delay) == float # noqa: E721 + assert attempts > 0 + assert delay >= 0 + + attempt = 0 + while True: + assert attempt < attempts + attempt += 1 + try: + rmtree(path) + except FileNotFoundError: + pass + except Exception as e: + if attempt < attempt: + errMsg = "Failed to remove directory {0} on attempt {1} ({2}): {3}".format( + path, attempt, type(e).__name__, e + ) + logging.warning(errMsg) + time.sleep(delay) + continue + + assert attempt == attempts + if not ignore_errors: + raise + + return False + + # OK! + return True + + def rmdir(self, path: str): + assert type(path) == str # noqa: E721 + os.rmdir(path) def listdir(self, path): return os.listdir(path) @@ -184,27 +336,56 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal Args: filename: The file path where the data will be written. data: The data to be written to the file. - truncate: If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - binary: If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - read_and_write: If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + truncate: If True, the file will be truncated before writing ('w' option); + if False (default), data will be appended ('a' option). + binary: If True, the data will be written in binary mode ('b' option); + if False (default), the data will be written in text mode. + read_and_write: If True, the file will be opened with read and write permissions ('+' option); + if False (default), only write permission will be used. """ - # If it is a bytes str or list if isinstance(data, bytes) or isinstance(data, list) and all(isinstance(item, bytes) for item in data): binary = True - mode = "wb" if binary else "w" - if not truncate: - mode = "ab" if binary else "a" + + mode = "w" if truncate else "a" + if read_and_write: - mode = "r+b" if binary else "r+" + mode += "+" + + # If it is a bytes str or list + if binary: + mode += "b" + + assert type(mode) == str # noqa: E721 + assert mode != "" with open(filename, mode) as file: if isinstance(data, list): - file.writelines(data) + data2 = [__class__._prepare_line_to_write(s, binary) for s in data] + file.writelines(data2) else: - file.write(data) + data2 = __class__._prepare_data_to_write(data, binary) + file.write(data2) + + @staticmethod + def _prepare_line_to_write(data, binary): + data = __class__._prepare_data_to_write(data, binary) + + if binary: + assert type(data) == bytes # noqa: E721 + return data.rstrip(b'\n') + b'\n' + + assert type(data) == str # noqa: E721 + return data.rstrip('\n') + '\n' + + @staticmethod + def _prepare_data_to_write(data, binary): + if isinstance(data, bytes): + return data if binary else data.decode() + + if isinstance(data, str): + return data if not binary else data.encode() + + raise InvalidOperationException("Unknown type of data type [{0}].".format(type(data).__name__)) def touch(self, filename): """ @@ -219,13 +400,35 @@ def touch(self, filename): os.utime(filename, None) def read(self, filename, encoding=None, binary=False): - mode = "rb" if binary else "r" - with open(filename, mode) as file: + assert type(filename) == str # noqa: E721 + assert encoding is None or type(encoding) == str # noqa: E721 + assert type(binary) == bool # noqa: E721 + + if binary: + if encoding is not None: + raise InvalidOperationException("Enconding is not allowed for read binary operation") + + return self._read__binary(filename) + + # python behavior + assert (None or "abc") == "abc" + assert ("" or "abc") == "abc" + + return self._read__text_with_encoding(filename, encoding or get_default_encoding()) + + def _read__text_with_encoding(self, filename, encoding): + assert type(filename) == str # noqa: E721 + assert type(encoding) == str # noqa: E721 + with open(filename, mode='r', encoding=encoding) as file: # open in a text mode content = file.read() - if binary: - return content - if isinstance(content, bytes): - return content.decode(encoding or get_default_encoding()) + assert type(content) == str # noqa: E721 + return content + + def _read__binary(self, filename): + assert type(filename) == str # noqa: E721 + with open(filename, 'rb') as file: # open in a binary mode + content = file.read() + assert type(content) == bytes # noqa: E721 return content def readlines(self, filename, num_lines=0, binary=False, encoding=None): @@ -233,12 +436,26 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): Read lines from a local file. If num_lines is greater than 0, only the last num_lines lines will be read. """ + assert type(num_lines) == int # noqa: E721 + assert type(filename) == str # noqa: E721 + assert type(binary) == bool # noqa: E721 + assert encoding is None or type(encoding) == str # noqa: E721 assert num_lines >= 0 + + if binary: + assert encoding is None + pass + elif encoding is None: + encoding = get_default_encoding() + assert type(encoding) == str # noqa: E721 + else: + assert type(encoding) == str # noqa: E721 + pass + mode = 'rb' if binary else 'r' if num_lines == 0: with open(filename, mode, encoding=encoding) as file: # open in binary mode return file.readlines() - else: bufsize = 8192 buffers = 1 @@ -261,35 +478,60 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): buffers * max(2, int(num_lines / max(cur_lines, 1))) ) # Adjust buffer size + def read_binary(self, filename, offset): + assert type(filename) == str # noqa: E721 + assert type(offset) == int # noqa: E721 + + if offset < 0: + raise ValueError("Negative 'offset' is not supported.") + + with open(filename, 'rb') as file: # open in a binary mode + file.seek(offset, os.SEEK_SET) + r = file.read() + assert type(r) == bytes # noqa: E721 + return r + def isfile(self, remote_file): return os.path.isfile(remote_file) def isdir(self, dirname): return os.path.isdir(dirname) + def get_file_size(self, filename): + assert filename is not None + assert type(filename) == str # noqa: E721 + return os.path.getsize(filename) + def remove_file(self, filename): return os.remove(filename) # Processes control - def kill(self, pid, signal): + def kill(self, pid, signal, expect_error=False): # Kill the process cmd = "kill -{} {}".format(signal, pid) - return self.exec_command(cmd) + return self.exec_command(cmd, expect_error=expect_error) def get_pid(self): # Get current process id return os.getpid() def get_process_children(self, pid): + assert type(pid) == int # noqa: E721 return psutil.Process(pid).children() - # Database control - def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - conn = pglib.connect( - host=host, - port=port, - database=dbname, - user=user, - password=password, - ) - return conn + def is_port_free(self, number: int) -> bool: + assert type(number) == int # noqa: E721 + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", number)) + return True + except OSError: + return False + + def get_tempdir(self) -> str: + r = tempfile.gettempdir() + assert r is not None + assert type(r) == str # noqa: E721 + assert os.path.exists(r) + return r diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index dd6613cf..45e4f71c 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -1,29 +1,25 @@ +import getpass import locale -try: - import psycopg2 as pglib # noqa: F401 -except ImportError: - try: - import pg8000 as pglib # noqa: F401 - except ImportError: - raise ImportError("You must have psycopg2 or pg8000 modules installed") - class ConnectionParams: - def __init__(self, host='127.0.0.1', ssh_key=None, username=None): + def __init__(self, host='127.0.0.1', port=None, ssh_key=None, username=None): self.host = host + self.port = port self.ssh_key = ssh_key self.username = username def get_default_encoding(): - return locale.getdefaultlocale()[1] or 'UTF-8' + if not hasattr(locale, 'getencoding'): + locale.getencoding = locale.getpreferredencoding + return locale.getencoding() or 'UTF-8' class OsOperations: def __init__(self, username=None): self.ssh_key = None - self.username = username + self.username = username or getpass.getuser() # Command execution def exec_command(self, cmd, **kwargs): @@ -33,6 +29,9 @@ def exec_command(self, cmd, **kwargs): def environ(self, var_name): raise NotImplementedError() + def cwd(self): + raise NotImplementedError() + def find_executable(self, executable): raise NotImplementedError() @@ -44,9 +43,8 @@ def set_env(self, var_name, var_val): # Check if the directory is already in PATH raise NotImplementedError() - # Get environment variables def get_user(self): - raise NotImplementedError() + return self.username def get_name(self): raise NotImplementedError() @@ -55,9 +53,17 @@ def get_name(self): def makedirs(self, path, remove_existing=False): raise NotImplementedError() + def makedir(self, path: str): + assert type(path) == str # noqa: E721 + raise NotImplementedError() + def rmdirs(self, path, ignore_errors=True): raise NotImplementedError() + def rmdir(self, path: str): + assert type(path) == str # noqa: E721 + raise NotImplementedError() + def listdir(self, path): raise NotImplementedError() @@ -71,6 +77,9 @@ def pathsep(self): def mkdtemp(self, prefix=None): raise NotImplementedError() + def mkstemp(self, prefix=None): + raise NotImplementedError() + def copytree(self, src, dst): raise NotImplementedError() @@ -87,9 +96,25 @@ def read(self, filename, encoding, binary): def readlines(self, filename): raise NotImplementedError() + def read_binary(self, filename, offset): + assert type(filename) == str # noqa: E721 + assert type(offset) == int # noqa: E721 + assert offset >= 0 + raise NotImplementedError() + def isfile(self, remote_file): raise NotImplementedError() + def isdir(self, dirname): + raise NotImplementedError() + + def get_file_size(self, filename): + raise NotImplementedError() + + def remove_file(self, filename): + assert type(filename) == str # noqa: E721 + raise NotImplementedError() + # Processes control def kill(self, pid, signal): # Kill the process @@ -102,6 +127,9 @@ def get_pid(self): def get_process_children(self, pid): raise NotImplementedError() - # Database control - def db_connect(self, dbname, user, password=None, host="localhost", port=5432): + def is_port_free(self, number: int): + assert type(number) == int # noqa: E721 + raise NotImplementedError() + + def get_tempdir(self) -> str: raise NotImplementedError() diff --git a/testgres/operations/raise_error.py b/testgres/operations/raise_error.py new file mode 100644 index 00000000..0d14be5a --- /dev/null +++ b/testgres/operations/raise_error.py @@ -0,0 +1,57 @@ +from ..exceptions import ExecUtilException +from .helpers import Helpers + + +class RaiseError: + @staticmethod + def UtilityExitedWithNonZeroCode(cmd, exit_code, msg_arg, error, out): + assert type(exit_code) == int # noqa: E721 + + msg_arg_s = __class__._TranslateDataIntoString(msg_arg) + assert type(msg_arg_s) == str # noqa: E721 + + msg_arg_s = msg_arg_s.strip() + if msg_arg_s == "": + msg_arg_s = "#no_error_message" + + message = "Utility exited with non-zero code (" + str(exit_code) + "). Error: `" + msg_arg_s + "`" + raise ExecUtilException( + message=message, + command=cmd, + exit_code=exit_code, + out=out, + error=error) + + @staticmethod + def CommandExecutionError(cmd, exit_code, message, error, out): + assert type(exit_code) == int # noqa: E721 + assert type(message) == str # noqa: E721 + assert message != "" + + raise ExecUtilException( + message=message, + command=cmd, + exit_code=exit_code, + out=out, + error=error) + + @staticmethod + def _TranslateDataIntoString(data): + if data is None: + return "" + + if type(data) == bytes: # noqa: E721 + return __class__._TranslateDataIntoString__FromBinary(data) + + return str(data) + + @staticmethod + def _TranslateDataIntoString__FromBinary(data): + assert type(data) == bytes # noqa: E721 + + try: + return data.decode(Helpers.GetDefaultEncoding()) + except UnicodeDecodeError: + pass + + return "#cannot_decode_text" diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 01251e1c..a478b453 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,139 +1,142 @@ -import logging +import getpass import os +import platform import subprocess import tempfile -import platform - -# we support both pg8000 and psycopg2 -try: - import psycopg2 as pglib -except ImportError: - try: - import pg8000 as pglib - except ImportError: - raise ImportError("You must have psycopg2 or pg8000 modules installed") +import io +import logging +import typing from ..exceptions import ExecUtilException +from ..exceptions import InvalidOperationException from .os_ops import OsOperations, ConnectionParams, get_default_encoding +from .raise_error import RaiseError +from .helpers import Helpers error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] class PsUtilProcessProxy: def __init__(self, ssh, pid): + assert isinstance(ssh, RemoteOperations) + assert type(pid) == int # noqa: E721 self.ssh = ssh self.pid = pid def kill(self): - command = "kill {}".format(self.pid) - self.ssh.exec_command(command) + assert isinstance(self.ssh, RemoteOperations) + assert type(self.pid) == int # noqa: E721 + command = ["kill", str(self.pid)] + self.ssh.exec_command(command, encoding=get_default_encoding()) def cmdline(self): - command = "ps -p {} -o cmd --no-headers".format(self.pid) - stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=get_default_encoding()) - cmdline = stdout.strip() + assert isinstance(self.ssh, RemoteOperations) + assert type(self.pid) == int # noqa: E721 + command = ["ps", "-p", str(self.pid), "-o", "cmd", "--no-headers"] + output = self.ssh.exec_command(command, encoding=get_default_encoding()) + assert type(output) == str # noqa: E721 + cmdline = output.strip() + # TODO: This code work wrong if command line contains quoted values. Yes? return cmdline.split() class RemoteOperations(OsOperations): def __init__(self, conn_params: ConnectionParams): - if not platform.system().lower() == "linux": raise EnvironmentError("Remote operations are supported only on Linux!") super().__init__(conn_params.username) self.conn_params = conn_params self.host = conn_params.host + self.port = conn_params.port self.ssh_key = conn_params.ssh_key + self.ssh_args = [] if self.ssh_key: - self.ssh_cmd = ["-i", self.ssh_key] - else: - self.ssh_cmd = [] + self.ssh_args += ["-i", self.ssh_key] + if self.port: + self.ssh_args += ["-p", self.port] self.remote = True - self.username = conn_params.username or self.get_user() - self.add_known_host(self.host) - self.tunnel_process = None + self.username = conn_params.username or getpass.getuser() + self.ssh_dest = f"{self.username}@{self.host}" if conn_params.username else self.host def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): - self.close_ssh_tunnel() - - def establish_ssh_tunnel(self, local_port, remote_port): - """ - Establish an SSH tunnel from a local port to a remote PostgreSQL port. - """ - ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"] - self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300) - - def close_ssh_tunnel(self): - if hasattr(self, 'tunnel_process'): - self.tunnel_process.terminate() - self.tunnel_process.wait() - del self.tunnel_process - else: - print("No active tunnel to close.") - - def add_known_host(self, host): - known_hosts_path = os.path.expanduser("~/.ssh/known_hosts") - cmd = 'ssh-keyscan -H %s >> %s' % (host, known_hosts_path) - - try: - subprocess.check_call(cmd, shell=True) - logging.info("Successfully added %s to known_hosts." % host) - except subprocess.CalledProcessError as e: - raise Exception("Failed to add %s to known_hosts. Error: %s" % (host, str(e))) - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, - stderr=None, get_process=None, timeout=None): + stderr=None, get_process=None, timeout=None, ignore_errors=False, + exec_env=None): """ Execute a command in the SSH session. Args: - cmd (str): The command to be executed. """ - ssh_cmd = [] - if isinstance(cmd, str): - ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + [cmd] - elif isinstance(cmd, list): - ssh_cmd = ['ssh', f"{self.username}@{self.host}"] + self.ssh_cmd + cmd + assert type(expect_error) == bool # noqa: E721 + assert type(ignore_errors) == bool # noqa: E721 + assert exec_env is None or type(exec_env) == dict # noqa: E721 + + input_prepared = None + if not get_process: + input_prepared = Helpers.PrepareProcessInput(input, encoding) # throw + + assert input_prepared is None or (type(input_prepared) == bytes) # noqa: E721 + + if type(cmd) == str: # noqa: E721 + cmd_s = cmd + elif type(cmd) == list: # noqa: E721 + cmd_s = subprocess.list2cmdline(cmd) + else: + raise ValueError("Invalid 'cmd' argument type - {0}".format(type(cmd).__name__)) + + assert type(cmd_s) == str # noqa: E721 + + cmd_items = __class__._make_exec_env_list(exec_env=exec_env) + cmd_items.append(cmd_s) + + env_cmd_s = ';'.join(cmd_items) + + ssh_cmd = ['ssh', self.ssh_dest] + self.ssh_args + [env_cmd_s] + process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + assert not (process is None) if get_process: return process try: - result, error = process.communicate(input, timeout=timeout) + output, error = process.communicate(input=input_prepared, timeout=timeout) except subprocess.TimeoutExpired: process.kill() raise ExecUtilException("Command timed out after {} seconds.".format(timeout)) - exit_status = process.returncode + assert type(output) == bytes # noqa: E721 + assert type(error) == bytes # noqa: E721 if encoding: - result = result.decode(encoding) + output = output.decode(encoding) error = error.decode(encoding) if expect_error: - raise Exception(result, error) - - if not error: - error_found = 0 + if process.returncode == 0: + raise InvalidOperationException("We expected an execution error.") + elif ignore_errors: + pass + elif process.returncode == 0: + pass else: - error_found = exit_status != 0 or any( - marker in error for marker in [b'error', b'Permission denied', b'fatal', b'No such file or directory']) - - if error_found: - if isinstance(error, bytes): - message = b"Utility exited with non-zero code. Error: " + error - else: - message = f"Utility exited with non-zero code. Error: {error}" - raise ExecUtilException(message=message, command=cmd, exit_code=exit_status, out=result) + assert not expect_error + assert not ignore_errors + assert process.returncode != 0 + RaiseError.UtilityExitedWithNonZeroCode( + cmd=cmd, + exit_code=process.returncode, + msg_arg=error, + error=error, + out=output) if verbose: - return exit_status, result, error - else: - return result + return process.returncode, output, error + + return output # Environment setup def environ(self, var_name: str) -> str: @@ -145,6 +148,10 @@ def environ(self, var_name: str) -> str: cmd = "echo ${}".format(var_name) return self.exec_command(cmd, encoding=get_default_encoding()).strip() + def cwd(self): + cmd = 'pwd' + return self.exec_command(cmd, encoding=get_default_encoding()).rstrip() + def find_executable(self, executable): search_paths = self.environ("PATH") if not search_paths: @@ -160,8 +167,30 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - is_exec = self.exec_command("test -x {} && echo OK".format(file)) - return is_exec == b"OK\n" + command = ["test", "-x", file] + + exit_status, output, error = self.exec_command(cmd=command, encoding=get_default_encoding(), ignore_errors=True, verbose=True) + + assert type(output) == str # noqa: E721 + assert type(error) == str # noqa: E721 + + if exit_status == 0: + return True + + if exit_status == 1: + return False + + errMsg = "Test operation returns an unknown result code: {0}. File name is [{1}].".format( + exit_status, + file) + + RaiseError.CommandExecutionError( + cmd=command, + exit_code=exit_status, + message=errMsg, + error=error, + out=output + ) def set_env(self, var_name: str, var_val: str): """ @@ -172,10 +201,6 @@ def set_env(self, var_name: str, var_val: str): """ return self.exec_command("export {}={}".format(var_name, var_val)) - # Get environment variables - def get_user(self): - return self.exec_command("echo $USER", encoding=get_default_encoding()).strip() - def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' return self.exec_command(cmd, encoding=get_default_encoding()).strip() @@ -200,20 +225,55 @@ def makedirs(self, path, remove_existing=False): raise Exception("Couldn't create dir {} because of error {}".format(path, error)) return result - def rmdirs(self, path, verbose=False, ignore_errors=True): + def makedir(self, path: str): + assert type(path) == str # noqa: E721 + cmd = ["mkdir", path] + self.exec_command(cmd) + + def rmdirs(self, path, ignore_errors=True): """ Remove a directory in the remote server. Args: - path (str): The path to the directory to be removed. - - verbose (bool): If True, return exit status, result, and error. - ignore_errors (bool): If True, do not raise error if directory does not exist. """ - cmd = "rm -rf {}".format(path) - exit_status, result, error = self.exec_command(cmd, verbose=True) - if verbose: - return exit_status, result, error - else: - return result + assert type(path) == str # noqa: E721 + assert type(ignore_errors) == bool # noqa: E721 + + # ENOENT = 2 - No such file or directory + # ENOTDIR = 20 - Not a directory + + cmd1 = [ + "if", "[", "-d", path, "]", ";", + "then", "rm", "-rf", path, ";", + "elif", "[", "-e", path, "]", ";", + "then", "{", "echo", "cannot remove '" + path + "': it is not a directory", ">&2", ";", "exit", "20", ";", "}", ";", + "else", "{", "echo", "directory '" + path + "' does not exist", ">&2", ";", "exit", "2", ";", "}", ";", + "fi" + ] + + cmd2 = ["sh", "-c", subprocess.list2cmdline(cmd1)] + + try: + self.exec_command(cmd2, encoding=Helpers.GetDefaultEncoding()) + except ExecUtilException as e: + if e.exit_code == 2: # No such file or directory + return True + + if not ignore_errors: + raise + + errMsg = "Failed to remove directory {0} ({1}): {2}".format( + path, type(e).__name__, e + ) + logging.warning(errMsg) + return False + return True + + def rmdir(self, path: str): + assert type(path) == str # noqa: E721 + cmd = ["rmdir", path] + self.exec_command(cmd) def listdir(self, path): """ @@ -221,12 +281,38 @@ def listdir(self, path): Args: path (str): The path to the directory. """ - result = self.exec_command("ls {}".format(path)) - return result.splitlines() + command = ["ls", path] + output = self.exec_command(cmd=command, encoding=get_default_encoding()) + assert type(output) == str # noqa: E721 + result = output.splitlines() + assert type(result) == list # noqa: E721 + return result def path_exists(self, path): - result = self.exec_command("test -e {}; echo $?".format(path), encoding=get_default_encoding()) - return int(result.strip()) == 0 + command = ["test", "-e", path] + + exit_status, output, error = self.exec_command(cmd=command, encoding=get_default_encoding(), ignore_errors=True, verbose=True) + + assert type(output) == str # noqa: E721 + assert type(error) == str # noqa: E721 + + if exit_status == 0: + return True + + if exit_status == 1: + return False + + errMsg = "Test operation returns an unknown result code: {0}. Path is [{1}].".format( + exit_status, + path) + + RaiseError.CommandExecutionError( + cmd=command, + exit_code=exit_status, + message=errMsg, + error=error, + out=output + ) @property def pathsep(self): @@ -246,32 +332,54 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. """ if prefix: - command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"] + command = ["mktemp", "-d", "-t", prefix + "XXXXXX"] else: - command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", "mktemp -d"] + command = ["mktemp", "-d"] - result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + exec_exitcode, exec_output, exec_error = self.exec_command(command, verbose=True, encoding=get_default_encoding(), ignore_errors=True) - if result.returncode == 0: - temp_dir = result.stdout.strip() - if not os.path.isabs(temp_dir): - temp_dir = os.path.join('/home', self.username, temp_dir) - return temp_dir - else: - raise ExecUtilException(f"Could not create temporary directory. Error: {result.stderr}") + assert type(exec_exitcode) == int # noqa: E721 + assert type(exec_output) == str # noqa: E721 + assert type(exec_error) == str # noqa: E721 + + if exec_exitcode != 0: + RaiseError.CommandExecutionError( + cmd=command, + exit_code=exec_exitcode, + message="Could not create temporary directory.", + error=exec_error, + out=exec_output) + + temp_dir = exec_output.strip() + return temp_dir def mkstemp(self, prefix=None): + """ + Creates a temporary file in the remote server. + Args: + - prefix (str): The prefix of the temporary directory name. + """ if prefix: - temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=get_default_encoding()) + command = ["mktemp", "-t", prefix + "XXXXXX"] else: - temp_dir = self.exec_command("mktemp", encoding=get_default_encoding()) + command = ["mktemp"] - if temp_dir: - if not os.path.isabs(temp_dir): - temp_dir = os.path.join('/home', self.username, temp_dir.strip()) - return temp_dir - else: - raise ExecUtilException("Could not create temporary directory.") + exec_exitcode, exec_output, exec_error = self.exec_command(command, verbose=True, encoding=get_default_encoding(), ignore_errors=True) + + assert type(exec_exitcode) == int # noqa: E721 + assert type(exec_output) == str # noqa: E721 + assert type(exec_error) == str # noqa: E721 + + if exec_exitcode != 0: + RaiseError.CommandExecutionError( + cmd=command, + exit_code=exec_exitcode, + message="Could not create temporary file.", + error=exec_error, + out=exec_output) + + temp_file = exec_output.strip() + return temp_file def copytree(self, src, dst): if not os.path.isabs(dst): @@ -285,38 +393,54 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal if not encoding: encoding = get_default_encoding() mode = "wb" if binary else "w" - if not truncate: - mode = "ab" if binary else "a" - if read_and_write: - mode = "r+b" if binary else "r+" with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: + # For scp the port is specified by a "-P" option + scp_args = ['-P' if x == '-p' else x for x in self.ssh_args] + if not truncate: - scp_cmd = ['scp'] + self.ssh_cmd + [f"{self.username}@{self.host}:{filename}", tmp_file.name] + scp_cmd = ['scp'] + scp_args + [f"{self.ssh_dest}:{filename}", tmp_file.name] subprocess.run(scp_cmd, check=False) # The file might not exist yet tmp_file.seek(0, os.SEEK_END) - if isinstance(data, bytes) and not binary: - data = data.decode(encoding) - elif isinstance(data, str) and binary: - data = data.encode(encoding) - if isinstance(data, list): - data = [(s if isinstance(s, str) else s.decode(get_default_encoding())).rstrip('\n') + '\n' for s in data] - tmp_file.writelines(data) + data2 = [__class__._prepare_line_to_write(s, binary, encoding) for s in data] + tmp_file.writelines(data2) else: - tmp_file.write(data) + data2 = __class__._prepare_data_to_write(data, binary, encoding) + tmp_file.write(data2) tmp_file.flush() - scp_cmd = ['scp'] + self.ssh_cmd + [tmp_file.name, f"{self.username}@{self.host}:{filename}"] + scp_cmd = ['scp'] + scp_args + [tmp_file.name, f"{self.ssh_dest}:{filename}"] subprocess.run(scp_cmd, check=True) remote_directory = os.path.dirname(filename) - mkdir_cmd = ['ssh'] + self.ssh_cmd + [f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"] + mkdir_cmd = ['ssh'] + self.ssh_args + [self.ssh_dest, f"mkdir -p {remote_directory}"] subprocess.run(mkdir_cmd, check=True) os.remove(tmp_file.name) + @staticmethod + def _prepare_line_to_write(data, binary, encoding): + data = __class__._prepare_data_to_write(data, binary, encoding) + + if binary: + assert type(data) == bytes # noqa: E721 + return data.rstrip(b'\n') + b'\n' + + assert type(data) == str # noqa: E721 + return data.rstrip('\n') + '\n' + + @staticmethod + def _prepare_data_to_write(data, binary, encoding): + if isinstance(data, bytes): + return data if binary else data.decode(encoding) + + if isinstance(data, str): + return data if not binary else data.encode(encoding) + + raise InvalidOperationException("Unknown type of data type [{0}].".format(type(data).__name__)) + def touch(self, filename): """ Create a new file or update the access and modification times of an existing file on the remote server. @@ -329,29 +453,86 @@ def touch(self, filename): self.exec_command("touch {}".format(filename)) def read(self, filename, binary=False, encoding=None): - cmd = "cat {}".format(filename) - result = self.exec_command(cmd, encoding=encoding) - - if not binary and result: - result = result.decode(encoding or get_default_encoding()) - - return result + assert type(filename) == str # noqa: E721 + assert encoding is None or type(encoding) == str # noqa: E721 + assert type(binary) == bool # noqa: E721 + + if binary: + if encoding is not None: + raise InvalidOperationException("Enconding is not allowed for read binary operation") + + return self._read__binary(filename) + + # python behavior + assert (None or "abc") == "abc" + assert ("" or "abc") == "abc" + + return self._read__text_with_encoding(filename, encoding or get_default_encoding()) + + def _read__text_with_encoding(self, filename, encoding): + assert type(filename) == str # noqa: E721 + assert type(encoding) == str # noqa: E721 + content = self._read__binary(filename) + assert type(content) == bytes # noqa: E721 + buf0 = io.BytesIO(content) + buf1 = io.TextIOWrapper(buf0, encoding=encoding) + content_s = buf1.read() + assert type(content_s) == str # noqa: E721 + return content_s + + def _read__binary(self, filename): + assert type(filename) == str # noqa: E721 + cmd = ["cat", filename] + content = self.exec_command(cmd) + assert type(content) == bytes # noqa: E721 + return content def readlines(self, filename, num_lines=0, binary=False, encoding=None): + assert type(num_lines) == int # noqa: E721 + assert type(filename) == str # noqa: E721 + assert type(binary) == bool # noqa: E721 + assert encoding is None or type(encoding) == str # noqa: E721 + if num_lines > 0: - cmd = "tail -n {} {}".format(num_lines, filename) + cmd = ["tail", "-n", str(num_lines), filename] + else: + cmd = ["cat", filename] + + if binary: + assert encoding is None + pass + elif encoding is None: + encoding = get_default_encoding() + assert type(encoding) == str # noqa: E721 else: - cmd = "cat {}".format(filename) + assert type(encoding) == str # noqa: E721 + pass result = self.exec_command(cmd, encoding=encoding) + assert result is not None - if not binary and result: - lines = result.decode(encoding or get_default_encoding()).splitlines() + if binary: + assert type(result) == bytes # noqa: E721 + lines = result.splitlines() else: + assert type(result) == str # noqa: E721 lines = result.splitlines() + assert type(lines) == list # noqa: E721 return lines + def read_binary(self, filename, offset): + assert type(filename) == str # noqa: E721 + assert type(offset) == int # noqa: E721 + + if offset < 0: + raise ValueError("Negative 'offset' is not supported.") + + cmd = ["tail", "-c", "+{}".format(offset + 1), filename] + r = self.exec_command(cmd) + assert type(r) == bytes # noqa: E721 + return r + def isfile(self, remote_file): stdout = self.exec_command("test -f {}; echo $?".format(remote_file)) result = int(stdout.strip()) @@ -362,6 +543,70 @@ def isdir(self, dirname): response = self.exec_command(cmd) return response.strip() == b"True" + def get_file_size(self, filename): + C_ERR_SRC = "RemoteOpertions::get_file_size" + + assert filename is not None + assert type(filename) == str # noqa: E721 + cmd = ["du", "-b", filename] + + s = self.exec_command(cmd, encoding=get_default_encoding()) + assert type(s) == str # noqa: E721 + + if len(s) == 0: + raise Exception( + "[BUG CHECK] Can't get size of file [{2}]. Remote operation returned an empty string. Check point [{0}][{1}].".format( + C_ERR_SRC, + "#001", + filename + ) + ) + + i = 0 + + while i < len(s) and s[i].isdigit(): + assert s[i] >= '0' + assert s[i] <= '9' + i += 1 + + if i == 0: + raise Exception( + "[BUG CHECK] Can't get size of file [{2}]. Remote operation returned a bad formatted string. Check point [{0}][{1}].".format( + C_ERR_SRC, + "#002", + filename + ) + ) + + if i == len(s): + raise Exception( + "[BUG CHECK] Can't get size of file [{2}]. Remote operation returned a bad formatted string. Check point [{0}][{1}].".format( + C_ERR_SRC, + "#003", + filename + ) + ) + + if not s[i].isspace(): + raise Exception( + "[BUG CHECK] Can't get size of file [{2}]. Remote operation returned a bad formatted string. Check point [{0}][{1}].".format( + C_ERR_SRC, + "#004", + filename + ) + ) + + r = 0 + + for i2 in range(0, i): + ch = s[i2] + assert ch >= '0' + assert ch <= '9' + # Here is needed to check overflow or that it is a human-valid result? + r = (r * 10) + ord(ch) - ord('0') + + return r + def remove_file(self, filename): cmd = "rm {}".format(filename) return self.exec_command(cmd) @@ -377,30 +622,158 @@ def get_pid(self): return int(self.exec_command("echo $$", encoding=get_default_encoding())) def get_process_children(self, pid): - command = ["ssh"] + self.ssh_cmd + [f"{self.username}@{self.host}", f"pgrep -P {pid}"] + assert type(pid) == int # noqa: E721 + command = ["ssh"] + self.ssh_args + [self.ssh_dest, "pgrep", "-P", str(pid)] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) if result.returncode == 0: children = result.stdout.strip().splitlines() return [PsUtilProcessProxy(self, int(child_pid.strip())) for child_pid in children] + + raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}") + + def is_port_free(self, number: int) -> bool: + assert type(number) == int # noqa: E721 + + cmd = ["nc", "-w", "5", "-z", "-v", "localhost", str(number)] + + exit_status, output, error = self.exec_command(cmd=cmd, encoding=get_default_encoding(), ignore_errors=True, verbose=True) + + assert type(output) == str # noqa: E721 + assert type(error) == str # noqa: E721 + + if exit_status == 0: + return __class__._is_port_free__process_0(error) + + if exit_status == 1: + return __class__._is_port_free__process_1(error) + + errMsg = "nc returns an unknown result code: {0}".format(exit_status) + + RaiseError.CommandExecutionError( + cmd=cmd, + exit_code=exit_status, + message=errMsg, + error=error, + out=output + ) + + def get_tempdir(self) -> str: + command = ["mktemp", "-u", "-d"] + + exec_exitcode, exec_output, exec_error = self.exec_command( + command, + verbose=True, + encoding=get_default_encoding(), + ignore_errors=True + ) + + assert type(exec_exitcode) == int # noqa: E721 + assert type(exec_output) == str # noqa: E721 + assert type(exec_error) == str # noqa: E721 + + if exec_exitcode != 0: + RaiseError.CommandExecutionError( + cmd=command, + exit_code=exec_exitcode, + message="Could not detect a temporary directory.", + error=exec_error, + out=exec_output) + + temp_subdir = exec_output.strip() + assert type(temp_subdir) == str # noqa: E721 + temp_dir = os.path.dirname(temp_subdir) + assert type(temp_dir) == str # noqa: E721 + return temp_dir + + @staticmethod + def _is_port_free__process_0(error: str) -> bool: + assert type(error) == str # noqa: E721 + # + # Example of error text: + # "Connection to localhost (127.0.0.1) 1024 port [tcp/*] succeeded!\n" + # + # May be here is needed to check error message? + # + return False + + @staticmethod + def _is_port_free__process_1(error: str) -> bool: + assert type(error) == str # noqa: E721 + # + # Example of error text: + # "nc: connect to localhost (127.0.0.1) port 1024 (tcp) failed: Connection refused\n" + # + # May be here is needed to check error message? + # + return True + + @staticmethod + def _make_exec_env_list(exec_env: typing.Dict) -> typing.List[str]: + env: typing.Dict[str, str] = dict() + + # ---------------------------------- SYSTEM ENV + for envvar in os.environ.items(): + if __class__._does_put_envvar_into_exec_cmd(envvar[0]): + env[envvar[0]] = envvar[1] + + # ---------------------------------- EXEC (LOCAL) ENV + if exec_env is None: + pass else: - raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}") + for envvar in exec_env.items(): + assert type(envvar) == tuple # noqa: E721 + assert len(envvar) == 2 + assert type(envvar[0]) == str # noqa: E721 + env[envvar[0]] = envvar[1] + + # ---------------------------------- FINAL BUILD + result: typing.List[str] = list() + for envvar in env.items(): + assert type(envvar) == tuple # noqa: E721 + assert len(envvar) == 2 + assert type(envvar[0]) == str # noqa: E721 + + if envvar[1] is None: + result.append("unset " + envvar[0]) + else: + assert type(envvar[1]) == str # noqa: E721 + qvalue = __class__._quote_envvar(envvar[1]) + assert type(qvalue) == str # noqa: E721 + result.append(envvar[0] + "=" + qvalue) + continue - # Database control - def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - """ - Established SSH tunnel and Connects to a PostgreSQL - """ - self.establish_ssh_tunnel(local_port=port, remote_port=5432) - try: - conn = pglib.connect( - host=host, - port=port, - database=dbname, - user=user, - password=password, - ) - return conn - except Exception as e: - raise Exception(f"Could not connect to the database. Error: {e}") + return result + + sm_envs_for_exec_cmd = ["LANG", "LANGUAGE"] + + @staticmethod + def _does_put_envvar_into_exec_cmd(name: str) -> bool: + assert type(name) == str # noqa: E721 + name = name.upper() + if name.startswith("LC_"): + return True + if name in __class__.sm_envs_for_exec_cmd: + return True + return False + + @staticmethod + def _quote_envvar(value: str) -> str: + assert type(value) == str # noqa: E721 + result = "\"" + for ch in value: + if ch == "\"": + result += "\\\"" + elif ch == "\\": + result += "\\\\" + else: + result += ch + result += "\"" + return result + + +def normalize_error(error): + if isinstance(error, bytes): + return error.decode() + return error diff --git a/testgres/plugins/__init__.py b/testgres/plugins/__init__.py index 8c19a23b..824eadc6 100644 --- a/testgres/plugins/__init__.py +++ b/testgres/plugins/__init__.py @@ -1,7 +1,7 @@ -from pg_probackup2.gdb import GDBobj -from pg_probackup2.app import ProbackupApp, ProbackupException -from pg_probackup2.init_helpers import init_params -from pg_probackup2.storage.fs_backup import FSTestBackupDir +from .pg_probackup2.pg_probackup2.gdb import GDBobj +from .pg_probackup2.pg_probackup2.app import ProbackupApp, ProbackupException +from .pg_probackup2.pg_probackup2.init_helpers import init_params +from .pg_probackup2.pg_probackup2.storage.fs_backup import FSTestBackupDir __all__ = [ "ProbackupApp", "ProbackupException", "init_params", "FSTestBackupDir", "GDBobj" diff --git a/testgres/plugins/pg_probackup2/pg_probackup2/app.py b/testgres/plugins/pg_probackup2/pg_probackup2/app.py index 1a4ca9e7..5166e9b8 100644 --- a/testgres/plugins/pg_probackup2/pg_probackup2/app.py +++ b/testgres/plugins/pg_probackup2/pg_probackup2/app.py @@ -1,6 +1,7 @@ import contextlib import importlib import json +import logging import os import re import subprocess @@ -43,23 +44,52 @@ def __str__(self): class ProbackupApp: def __init__(self, test_class: unittest.TestCase, - pg_node, pb_log_path, test_env, auto_compress_alg, backup_dir): + pg_node, pb_log_path, test_env, auto_compress_alg, backup_dir, probackup_path=None): + self.process = None self.test_class = test_class self.pg_node = pg_node self.pb_log_path = pb_log_path self.test_env = test_env self.auto_compress_alg = auto_compress_alg self.backup_dir = backup_dir - self.probackup_path = init_params.probackup_path + self.probackup_path = probackup_path or init_params.probackup_path self.probackup_old_path = init_params.probackup_old_path self.remote = init_params.remote + self.wal_tree_enabled = init_params.wal_tree_enabled self.verbose = init_params.verbose self.archive_compress = init_params.archive_compress self.test_class.output = None self.execution_time = None + def form_daemon_process(self, cmdline, env): + def stream_output(stream: subprocess.PIPE) -> None: + try: + for line in iter(stream.readline, ''): + print(line) + self.test_class.output += line + finally: + stream.close() + + self.process = subprocess.Popen( + cmdline, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=env + ) + logging.info(f"Process started in background with PID: {self.process.pid}") + + if self.process.stdout and self.process.stderr: + stdout_thread = threading.Thread(target=stream_output, args=(self.process.stdout,), daemon=True) + stderr_thread = threading.Thread(target=stream_output, args=(self.process.stderr,), daemon=True) + + stdout_thread.start() + stderr_thread.start() + + return self.process.pid + def run(self, command, gdb=False, old_binary=False, return_id=True, env=None, - skip_log_directory=False, expect_error=False, use_backup_dir=True): + skip_log_directory=False, expect_error=False, use_backup_dir=True, daemonize=False): """ Run pg_probackup backup_dir: target directory for making backup @@ -72,9 +102,11 @@ def run(self, command, gdb=False, old_binary=False, return_id=True, env=None, command = [command[0], *use_backup_dir.pb_args, *command[1:]] elif use_backup_dir: command = [command[0], *self.backup_dir.pb_args, *command[1:]] + else: + command = [command[0], *self.backup_dir.pb_args[2:], *command[1:]] if not self.probackup_old_path and old_binary: - print('PGPROBACKUPBIN_OLD is not set') + logging.error('PGPROBACKUPBIN_OLD is not set') exit(1) if old_binary: @@ -107,19 +139,21 @@ def run(self, command, gdb=False, old_binary=False, return_id=True, env=None, return GDBobj(cmdline, self.test_class) try: - result = None if type(gdb) is tuple and gdb[0] == 'suspend': # special test flow for manually debug probackup gdb_port = gdb[1] cmdline = ['gdbserver'] + ['localhost:' + str(gdb_port)] + cmdline - print("pg_probackup gdb suspended, waiting gdb connection on localhost:{0}".format(gdb_port)) + logging.warning("pg_probackup gdb suspended, waiting gdb connection on localhost:{0}".format(gdb_port)) start_time = time.time() - self.test_class.output = subprocess.check_output( - cmdline, - stderr=subprocess.STDOUT, - env=env - ).decode('utf-8', errors='replace') + if daemonize: + return self.form_daemon_process(cmdline, env) + else: + self.test_class.output = subprocess.check_output( + cmdline, + stderr=subprocess.STDOUT, + env=env + ).decode('utf-8', errors='replace') end_time = time.time() self.execution_time = end_time - start_time @@ -185,6 +219,9 @@ def add_instance(self, instance, node, old_binary=False, options=None, expect_er '--remote-proto=ssh', '--remote-host=localhost'] + if self.wal_tree_enabled: + options = options + ['--wal-tree'] + return self.run(cmd + options, old_binary=old_binary, expect_error=expect_error) def set_config(self, instance, old_binary=False, options=None, expect_error=False): @@ -233,7 +270,7 @@ def backup_node( if options is None: options = [] if not node and not data_dir: - print('You must provide ether node or data_dir for backup') + logging.error('You must provide ether node or data_dir for backup') exit(1) if not datname: @@ -388,6 +425,7 @@ def catchup_node( backup_mode, source_pgdata, destination_node, options=None, remote_host='localhost', + remote_port=None, expect_error=False, gdb=False ): @@ -401,7 +439,9 @@ def catchup_node( '--destination-pgdata={0}'.format(destination_node.data_dir) ] if self.remote: - cmd_list += ['--remote-proto=ssh', '--remote-host=%s' % remote_host] + cmd_list += ['--remote-proto=ssh', f'--remote-host={remote_host}'] + if remote_port: + cmd_list.append(f'--remote-port={remote_port}') if self.verbose: cmd_list += [ '--log-level-file=VERBOSE', @@ -499,7 +539,7 @@ def show( if i == '': backup_record_split.remove(i) if len(header_split) != len(backup_record_split): - print(warning.format( + logging.error(warning.format( header=header, body=body, header_split=header_split, body_split=backup_record_split) @@ -578,7 +618,7 @@ def show_archive( else: show_splitted = self.run(cmd_list + options, old_binary=old_binary, expect_error=expect_error).splitlines() - print(show_splitted) + logging.error(show_splitted) exit(1) def validate( @@ -702,9 +742,13 @@ def set_archiving( if overwrite: archive_command += ' --overwrite' - archive_command += ' --log-level-console=VERBOSE' - archive_command += ' -j 5' - archive_command += ' --batch-size 10' + if init_params.major_version > 2: + archive_command += ' --log-level-console=trace' + else: + archive_command += ' --log-level-console=VERBOSE' + archive_command += ' -j 5' + archive_command += ' --batch-size 10' + archive_command += ' --no-sync' if archive_timeout: @@ -766,7 +810,7 @@ def load_backup_class(fs_type): if fs_type: implementation = fs_type - print("Using ", implementation) + logging.info("Using ", implementation) module_name, class_name = implementation.rsplit(sep='.', maxsplit=1) module = importlib.import_module(module_name) @@ -798,5 +842,22 @@ def archive_get(self, instance, wal_file_name, wal_file_path, options=None, expe ] return self.run(cmd + options, expect_error=expect_error) + def maintain( + self, instance=None, backup_id=None, + options=None, old_binary=False, gdb=False, expect_error=False + ): + if options is None: + options = [] + cmd_list = [ + 'maintain', + ] + if instance: + cmd_list += ['--instance={0}'.format(instance)] + if backup_id: + cmd_list += ['-i', backup_id] + + return self.run(cmd_list + options, old_binary=old_binary, gdb=gdb, + expect_error=expect_error) + def build_backup_dir(self, backup='backup'): return fs_backup_class(rel_path=self.rel_path, backup=backup) diff --git a/testgres/plugins/pg_probackup2/pg_probackup2/gdb.py b/testgres/plugins/pg_probackup2/pg_probackup2/gdb.py index 0b61da65..2424c04d 100644 --- a/testgres/plugins/pg_probackup2/pg_probackup2/gdb.py +++ b/testgres/plugins/pg_probackup2/pg_probackup2/gdb.py @@ -37,7 +37,7 @@ def __init__(self, cmd, env, attach=False): " to run GDB tests") raise GdbException("No gdb usage possible.") - # Check gdb presense + # Check gdb presence try: gdb_version, _ = subprocess.Popen( ['gdb', '--version'], @@ -108,6 +108,9 @@ def kill(self): self.proc.stdin.close() self.proc.stdout.close() + def terminate_subprocess(self): + self._execute('kill') + def set_breakpoint(self, location): result = self._execute('break ' + location) diff --git a/testgres/plugins/pg_probackup2/pg_probackup2/init_helpers.py b/testgres/plugins/pg_probackup2/pg_probackup2/init_helpers.py index 73731a6e..c4570a39 100644 --- a/testgres/plugins/pg_probackup2/pg_probackup2/init_helpers.py +++ b/testgres/plugins/pg_probackup2/pg_probackup2/init_helpers.py @@ -1,3 +1,4 @@ +import logging from functools import reduce import getpass import os @@ -31,7 +32,7 @@ cached_initdb_dir=False, node_cleanup_full=delete_logs) except Exception as e: - print("Can't configure testgres: {0}".format(e)) + logging.warning("Can't configure testgres: {0}".format(e)) class Init(object): @@ -104,7 +105,7 @@ def __init__(self): if os.path.isfile(probackup_path_tmp): if not os.access(probackup_path_tmp, os.X_OK): - print('{0} is not an executable file'.format( + logging.warning('{0} is not an executable file'.format( probackup_path_tmp)) else: self.probackup_path = probackup_path_tmp @@ -114,14 +115,13 @@ def __init__(self): if os.path.isfile(probackup_path_tmp): if not os.access(probackup_path_tmp, os.X_OK): - print('{0} is not an executable file'.format( + logging.warning('{0} is not an executable file'.format( probackup_path_tmp)) else: self.probackup_path = probackup_path_tmp if not self.probackup_path: - print('pg_probackup binary is not found') - exit(1) + raise Exception('pg_probackup binary is not found') if os.name == 'posix': self.EXTERNAL_DIRECTORY_DELIMITER = ':' @@ -169,6 +169,11 @@ def __init__(self): self.remote = test_env.get('PGPROBACKUP_SSH_REMOTE', None) == 'ON' self.ptrack = test_env.get('PG_PROBACKUP_PTRACK', None) == 'ON' and self.pg_config_version >= 110000 + self.wal_tree_enabled = test_env.get('PG_PROBACKUP_WAL_TREE_ENABLED', None) == 'ON' + + self.bckp_source = test_env.get('PG_PROBACKUP_SOURCE', 'pro').lower() + if self.bckp_source not in ('base', 'direct', 'pro'): + raise Exception("Wrong PG_PROBACKUP_SOURCE value. Available options: base|direct|pro") self.paranoia = test_env.get('PG_PROBACKUP_PARANOIA', None) == 'ON' env_compress = test_env.get('ARCHIVE_COMPRESSION', None) @@ -207,11 +212,15 @@ def __init__(self): if self.probackup_version.split('.')[0].isdigit(): self.major_version = int(self.probackup_version.split('.')[0]) else: - print('Can\'t process pg_probackup version \"{}\": the major version is expected to be a number'.format(self.probackup_version)) - sys.exit(1) + raise Exception('Can\'t process pg_probackup version \"{}\": the major version is expected to be a number'.format(self.probackup_version)) def test_env(self): return self._test_env.copy() -init_params = Init() +try: + init_params = Init() +except Exception as e: + logging.error(str(e)) + logging.warning("testgres.plugins.probackup2.init_params is set to None.") + init_params = None diff --git a/testgres/helpers/__init__.py b/testgres/plugins/pg_probackup2/pg_probackup2/tests/__init__.py similarity index 100% rename from testgres/helpers/__init__.py rename to testgres/plugins/pg_probackup2/pg_probackup2/tests/__init__.py diff --git a/testgres/plugins/pg_probackup2/pg_probackup2/tests/basic_test.py b/testgres/plugins/pg_probackup2/pg_probackup2/tests/basic_test.py deleted file mode 100644 index f5a82d38..00000000 --- a/testgres/plugins/pg_probackup2/pg_probackup2/tests/basic_test.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -import shutil -import unittest -import testgres -from pg_probackup2.app import ProbackupApp -from pg_probackup2.init_helpers import Init, init_params -from pg_probackup2.app import build_backup_dir - - -class TestUtils: - @staticmethod - def get_module_and_function_name(test_id): - try: - module_name = test_id.split('.')[-2] - fname = test_id.split('.')[-1] - except IndexError: - print(f"Couldn't get module name and function name from test_id: `{test_id}`") - module_name, fname = test_id.split('(')[1].split('.')[1], test_id.split('(')[0] - return module_name, fname - - -class ProbackupTest(unittest.TestCase): - def setUp(self): - self.setup_test_environment() - self.setup_test_paths() - self.setup_backup_dir() - self.setup_probackup() - - def setup_test_environment(self): - self.output = None - self.cmd = None - self.nodes_to_cleanup = [] - self.module_name, self.fname = TestUtils.get_module_and_function_name(self.id()) - self.test_env = Init().test_env() - - def setup_test_paths(self): - self.rel_path = os.path.join(self.module_name, self.fname) - self.test_path = os.path.join(init_params.tmp_path, self.rel_path) - os.makedirs(self.test_path) - self.pb_log_path = os.path.join(self.test_path, "pb_log") - - def setup_backup_dir(self): - self.backup_dir = build_backup_dir(self, 'backup') - self.backup_dir.cleanup() - - def setup_probackup(self): - self.pg_node = testgres.NodeApp(self.test_path, self.nodes_to_cleanup) - self.pb = ProbackupApp(self, self.pg_node, self.pb_log_path, self.test_env, - auto_compress_alg='zlib', backup_dir=self.backup_dir) - - def tearDown(self): - if os.path.exists(self.test_path): - shutil.rmtree(self.test_path) - - -class BasicTest(ProbackupTest): - def test_full_backup(self): - # Setting up a simple test node - node = self.pg_node.make_simple('node', pg_options={"fsync": "off", "synchronous_commit": "off"}) - - # Initialize and configure Probackup - self.pb.init() - self.pb.add_instance('node', node) - self.pb.set_archiving('node', node) - - # Start the node and initialize pgbench - node.slow_start() - node.pgbench_init(scale=100, no_vacuum=True) - - # Perform backup and validation - backup_id = self.pb.backup_node('node', node) - out = self.pb.validate('node', backup_id) - - # Check if the backup is valid - self.assertIn(f"INFO: Backup {backup_id} is valid", out) - - -if __name__ == "__main__": - unittest.main() diff --git a/testgres/plugins/pg_probackup2/pg_probackup2/tests/test_basic.py b/testgres/plugins/pg_probackup2/pg_probackup2/tests/test_basic.py new file mode 100644 index 00000000..2540ddb0 --- /dev/null +++ b/testgres/plugins/pg_probackup2/pg_probackup2/tests/test_basic.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import os +import shutil +import pytest + +import testgres +from ...pg_probackup2.app import ProbackupApp +from ...pg_probackup2.init_helpers import Init, init_params +from ..storage.fs_backup import FSTestBackupDir + + +class ProbackupTest: + pg_node: testgres.NodeApp + + @staticmethod + def probackup_is_available() -> bool: + p = os.environ.get("PGPROBACKUPBIN") + + if p is None: + return False + + if not os.path.exists(p): + return False + + return True + + @pytest.fixture(autouse=True, scope="function") + def implicit_fixture(self, request: pytest.FixtureRequest): + assert isinstance(request, pytest.FixtureRequest) + self.helper__setUp(request) + yield + self.helper__tearDown() + + def helper__setUp(self, request: pytest.FixtureRequest): + assert isinstance(request, pytest.FixtureRequest) + + self.helper__setup_test_environment(request) + self.helper__setup_test_paths() + self.helper__setup_backup_dir() + self.helper__setup_probackup() + + def helper__setup_test_environment(self, request: pytest.FixtureRequest): + assert isinstance(request, pytest.FixtureRequest) + + self.output = None + self.cmd = None + self.nodes_to_cleanup = [] + self.module_name, self.fname = request.node.cls.__name__, request.node.name + self.test_env = Init().test_env() + + def helper__setup_test_paths(self): + self.rel_path = os.path.join(self.module_name, self.fname) + self.test_path = os.path.join(init_params.tmp_path, self.rel_path) + os.makedirs(self.test_path, exist_ok=True) + self.pb_log_path = os.path.join(self.test_path, "pb_log") + + def helper__setup_backup_dir(self): + self.backup_dir = self.helper__build_backup_dir('backup') + self.backup_dir.cleanup() + + def helper__setup_probackup(self): + self.pg_node = testgres.NodeApp(self.test_path, self.nodes_to_cleanup) + self.pb = ProbackupApp(self, self.pg_node, self.pb_log_path, self.test_env, + auto_compress_alg='zlib', backup_dir=self.backup_dir) + + def helper__tearDown(self): + if os.path.exists(self.test_path): + shutil.rmtree(self.test_path) + + def helper__build_backup_dir(self, backup='backup'): + return FSTestBackupDir(rel_path=self.rel_path, backup=backup) + + +@pytest.mark.skipif(not ProbackupTest.probackup_is_available(), reason="Check that PGPROBACKUPBIN is defined and is valid.") +class TestBasic(ProbackupTest): + def test_full_backup(self): + assert self.pg_node is not None + assert type(self.pg_node) == testgres.NodeApp # noqa: E721 + assert self.pb is not None + assert type(self.pb) == ProbackupApp # noqa: E721 + + # Setting up a simple test node + node = self.pg_node.make_simple('node', pg_options={"fsync": "off", "synchronous_commit": "off"}) + + assert node is not None + assert type(node) == testgres.PostgresNode # noqa: E721 + + with node: + # Initialize and configure Probackup + self.pb.init() + self.pb.add_instance('node', node) + self.pb.set_archiving('node', node) + + # Start the node and initialize pgbench + node.slow_start() + node.pgbench_init(scale=100, no_vacuum=True) + + # Perform backup and validation + backup_id = self.pb.backup_node('node', node) + out = self.pb.validate('node', backup_id) + + # Check if the backup is valid + assert f"INFO: Backup {backup_id} is valid" in out diff --git a/testgres/plugins/pg_probackup2/setup.py b/testgres/plugins/pg_probackup2/setup.py index 2ff5a503..7a3212e4 100644 --- a/testgres/plugins/pg_probackup2/setup.py +++ b/testgres/plugins/pg_probackup2/setup.py @@ -4,7 +4,7 @@ from distutils.core import setup setup( - version='0.0.3', + version='0.1.0', name='testgres_pg_probackup2', packages=['pg_probackup2', 'pg_probackup2.storage'], description='Plugin for testgres that manages pg_probackup2', diff --git a/testgres/port_manager.py b/testgres/port_manager.py new file mode 100644 index 00000000..1ae696c8 --- /dev/null +++ b/testgres/port_manager.py @@ -0,0 +1,10 @@ +class PortManager: + def __init__(self): + super().__init__() + + def reserve_port(self) -> int: + raise NotImplementedError("PortManager::reserve_port is not implemented.") + + def release_port(self, number: int) -> None: + assert type(number) == int # noqa: E721 + raise NotImplementedError("PortManager::release_port is not implemented.") diff --git a/testgres/utils.py b/testgres/utils.py index 745a2555..d231eec3 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -13,15 +13,25 @@ from six import iteritems -from .helpers.port_manager import PortManager from .exceptions import ExecUtilException from .config import testgres_config as tconf +from .operations.os_ops import OsOperations +from .operations.remote_ops import RemoteOperations +from .operations.local_ops import LocalOperations +from .operations.helpers import Helpers as OsHelpers + +from .impl.port_manager__generic import PortManager__Generic # rows returned by PG_CONFIG _pg_config_data = {} +# +# The old, global "port manager" always worked with LOCAL system +# +_old_port_manager = PortManager__Generic(LocalOperations.get_single_instance()) + # ports used by nodes -bound_ports = set() +bound_ports = _old_port_manager._reserved_ports # re-export version type @@ -34,23 +44,24 @@ def __init__(self, version: str) -> None: super().__init__(version) -def reserve_port(): +def internal__reserve_port(): """ Generate a new port and add it to 'bound_ports'. """ - port_mng = PortManager() - port = port_mng.find_free_port(exclude_ports=bound_ports) - bound_ports.add(port) + return _old_port_manager.reserve_port() - return port - -def release_port(port): +def internal__release_port(port): """ Free port provided by reserve_port(). """ - bound_ports.discard(port) + assert type(port) == int # noqa: E721 + return _old_port_manager.release_port(port) + + +reserve_port = internal__reserve_port +release_port = internal__release_port def execute_utility(args, logfile=None, verbose=False): @@ -64,22 +75,40 @@ def execute_utility(args, logfile=None, verbose=False): Returns: stdout of executed utility. """ - exit_status, out, error = tconf.os_ops.exec_command(args, verbose=True) - # decode result + return execute_utility2(tconf.os_ops, args, logfile, verbose) + + +def execute_utility2( + os_ops: OsOperations, + args, + logfile=None, + verbose=False, + ignore_errors=False, + exec_env=None, +): + assert os_ops is not None + assert isinstance(os_ops, OsOperations) + assert type(verbose) == bool # noqa: E721 + assert type(ignore_errors) == bool # noqa: E721 + assert exec_env is None or type(exec_env) == dict # noqa: E721 + + exit_status, out, error = os_ops.exec_command( + args, + verbose=True, + ignore_errors=ignore_errors, + encoding=OsHelpers.GetDefaultEncoding(), + exec_env=exec_env) + out = '' if not out else out - if isinstance(out, bytes): - out = out.decode('utf-8') - if isinstance(error, bytes): - error = error.decode('utf-8') # write new log entry if possible if logfile: try: - tconf.os_ops.write(filename=logfile, data=args, truncate=True) + os_ops.write(filename=logfile, data=args, truncate=True) if out: # comment-out lines lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] - tconf.os_ops.write(filename=logfile, data=lines) + os_ops.write(filename=logfile, data=lines) except IOError: raise ExecUtilException( "Problem with writing to logfile `{}` during run command `{}`".format(logfile, args)) @@ -94,25 +123,32 @@ def get_bin_path(filename): Return absolute path to an executable using PG_BIN or PG_CONFIG. This function does nothing if 'filename' is already absolute. """ + return get_bin_path2(tconf.os_ops, filename) + + +def get_bin_path2(os_ops: OsOperations, filename): + assert os_ops is not None + assert isinstance(os_ops, OsOperations) + # check if it's already absolute if os.path.isabs(filename): return filename - if tconf.os_ops.remote: + if isinstance(os_ops, RemoteOperations): pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") else: # try PG_CONFIG - get from local machine pg_config = os.environ.get("PG_CONFIG") if pg_config: - bindir = get_pg_config()["BINDIR"] + bindir = get_pg_config(pg_config, os_ops)["BINDIR"] return os.path.join(bindir, filename) # try PG_BIN - pg_bin = tconf.os_ops.environ("PG_BIN") + pg_bin = os_ops.environ("PG_BIN") if pg_bin: return os.path.join(pg_bin, filename) - pg_config_path = tconf.os_ops.find_executable('pg_config') + pg_config_path = os_ops.find_executable('pg_config') if pg_config_path: bindir = get_pg_config(pg_config_path)["BINDIR"] return os.path.join(bindir, filename) @@ -125,12 +161,20 @@ def get_pg_config(pg_config_path=None, os_ops=None): Return output of pg_config (provided that it is installed). NOTE: this function caches the result by default (see GlobalConfig). """ - if os_ops: - tconf.os_ops = os_ops + + if os_ops is None: + os_ops = tconf.os_ops + + return get_pg_config2(os_ops, pg_config_path) + + +def get_pg_config2(os_ops: OsOperations, pg_config_path): + assert os_ops is not None + assert isinstance(os_ops, OsOperations) def cache_pg_config_data(cmd): # execute pg_config and get the output - out = tconf.os_ops.exec_command(cmd, encoding='utf-8') + out = os_ops.exec_command(cmd, encoding='utf-8') data = {} for line in out.splitlines(): @@ -154,11 +198,15 @@ def cache_pg_config_data(cmd): return _pg_config_data # try specified pg_config path or PG_CONFIG - if tconf.os_ops.remote: - pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + if pg_config_path: + return cache_pg_config_data(pg_config_path) + + if isinstance(os_ops, RemoteOperations): + pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") else: # try PG_CONFIG - get from local machine - pg_config = pg_config_path or os.environ.get("PG_CONFIG") + pg_config = os.environ.get("PG_CONFIG") + if pg_config: return cache_pg_config_data(pg_config) @@ -172,19 +220,29 @@ def cache_pg_config_data(cmd): return cache_pg_config_data("pg_config") -def get_pg_version(bin_dir=None): +def get_pg_version2(os_ops: OsOperations, bin_dir=None): """ Return PostgreSQL version provided by postmaster. """ + assert os_ops is not None + assert isinstance(os_ops, OsOperations) # Get raw version (e.g., postgres (PostgreSQL) 9.5.7) - postgres_path = os.path.join(bin_dir, 'postgres') if bin_dir else get_bin_path('postgres') - _params = [postgres_path, '--version'] - raw_ver = tconf.os_ops.exec_command(_params, encoding='utf-8') + postgres_path = os.path.join(bin_dir, 'postgres') if bin_dir else get_bin_path2(os_ops, 'postgres') + cmd = [postgres_path, '--version'] + raw_ver = os_ops.exec_command(cmd, encoding='utf-8') return parse_pg_version(raw_ver) +def get_pg_version(bin_dir=None): + """ + Return PostgreSQL version provided by postmaster. + """ + + return get_pg_version2(tconf.os_ops, bin_dir) + + def parse_pg_version(version_out): # Generalize removal of system-specific suffixes (anything in parentheses) raw_ver = re.sub(r'\([^)]*\)', '', version_out).strip() @@ -228,7 +286,6 @@ def eprint(*args, **kwargs): """ Print stuff to stderr. """ - print(*args, file=sys.stderr, **kwargs) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..111edc87 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,989 @@ +# ///////////////////////////////////////////////////////////////////////////// +# PyTest Configuration + +import pluggy +import pytest +import os +import logging +import pathlib +import math +import datetime +import typing + +import _pytest.outcomes +import _pytest.unittest +import _pytest.logging + +# ///////////////////////////////////////////////////////////////////////////// + +C_ROOT_DIR__RELATIVE = ".." + +# ///////////////////////////////////////////////////////////////////////////// +# TestConfigPropNames + + +class TestConfigPropNames: + TEST_CFG__LOG_DIR = "TEST_CFG__LOG_DIR" + + +# ///////////////////////////////////////////////////////////////////////////// + +T_TUPLE__str_int = typing.Tuple[str, int] + +# ///////////////////////////////////////////////////////////////////////////// +# TestStartupData__Helper + + +class TestStartupData__Helper: + sm_StartTS = datetime.datetime.now() + + # -------------------------------------------------------------------- + def GetStartTS() -> datetime.datetime: + assert type(__class__.sm_StartTS) == datetime.datetime # noqa: E721 + return __class__.sm_StartTS + + # -------------------------------------------------------------------- + def CalcRootDir() -> str: + r = os.path.abspath(__file__) + r = os.path.dirname(r) + r = os.path.join(r, C_ROOT_DIR__RELATIVE) + r = os.path.abspath(r) + return r + + # -------------------------------------------------------------------- + def CalcRootLogDir() -> str: + if TestConfigPropNames.TEST_CFG__LOG_DIR in os.environ: + resultPath = os.environ[TestConfigPropNames.TEST_CFG__LOG_DIR] + else: + rootDir = __class__.CalcRootDir() + resultPath = os.path.join(rootDir, "logs") + + assert type(resultPath) == str # noqa: E721 + return resultPath + + # -------------------------------------------------------------------- + def CalcCurrentTestWorkerSignature() -> str: + currentPID = os.getpid() + assert type(currentPID) + + startTS = __class__.sm_StartTS + assert type(startTS) + + result = "pytest-{0:04d}{1:02d}{2:02d}_{3:02d}{4:02d}{5:02d}".format( + startTS.year, + startTS.month, + startTS.day, + startTS.hour, + startTS.minute, + startTS.second, + ) + + gwid = os.environ.get("PYTEST_XDIST_WORKER") + + if gwid is not None: + result += "--xdist_" + str(gwid) + + result += "--" + "pid" + str(currentPID) + return result + + +# ///////////////////////////////////////////////////////////////////////////// +# TestStartupData + + +class TestStartupData: + sm_RootDir: str = TestStartupData__Helper.CalcRootDir() + sm_CurrentTestWorkerSignature: str = ( + TestStartupData__Helper.CalcCurrentTestWorkerSignature() + ) + + sm_RootLogDir: str = TestStartupData__Helper.CalcRootLogDir() + + # -------------------------------------------------------------------- + def GetRootDir() -> str: + assert type(__class__.sm_RootDir) == str # noqa: E721 + return __class__.sm_RootDir + + # -------------------------------------------------------------------- + def GetRootLogDir() -> str: + assert type(__class__.sm_RootLogDir) == str # noqa: E721 + return __class__.sm_RootLogDir + + # -------------------------------------------------------------------- + def GetCurrentTestWorkerSignature() -> str: + assert type(__class__.sm_CurrentTestWorkerSignature) == str # noqa: E721 + return __class__.sm_CurrentTestWorkerSignature + + +# ///////////////////////////////////////////////////////////////////////////// +# TEST_PROCESS_STATS + + +class TEST_PROCESS_STATS: + cTotalTests: int = 0 + cNotExecutedTests: int = 0 + cExecutedTests: int = 0 + cPassedTests: int = 0 + cFailedTests: int = 0 + cXFailedTests: int = 0 + cSkippedTests: int = 0 + cNotXFailedTests: int = 0 + cWarningTests: int = 0 + cUnexpectedTests: int = 0 + cAchtungTests: int = 0 + + FailedTests: typing.List[T_TUPLE__str_int] = list() + XFailedTests: typing.List[T_TUPLE__str_int] = list() + NotXFailedTests: typing.List[str] = list() + WarningTests: typing.List[T_TUPLE__str_int] = list() + AchtungTests: typing.List[str] = list() + + cTotalDuration: datetime.timedelta = datetime.timedelta() + + cTotalErrors: int = 0 + cTotalWarnings: int = 0 + + # -------------------------------------------------------------------- + def incrementTotalTestCount() -> None: + assert type(__class__.cTotalTests) == int # noqa: E721 + assert __class__.cTotalTests >= 0 + + __class__.cTotalTests += 1 + + assert __class__.cTotalTests > 0 + + # -------------------------------------------------------------------- + def incrementNotExecutedTestCount() -> None: + assert type(__class__.cNotExecutedTests) == int # noqa: E721 + assert __class__.cNotExecutedTests >= 0 + + __class__.cNotExecutedTests += 1 + + assert __class__.cNotExecutedTests > 0 + + # -------------------------------------------------------------------- + def incrementExecutedTestCount() -> int: + assert type(__class__.cExecutedTests) == int # noqa: E721 + assert __class__.cExecutedTests >= 0 + + __class__.cExecutedTests += 1 + + assert __class__.cExecutedTests > 0 + return __class__.cExecutedTests + + # -------------------------------------------------------------------- + def incrementPassedTestCount() -> None: + assert type(__class__.cPassedTests) == int # noqa: E721 + assert __class__.cPassedTests >= 0 + + __class__.cPassedTests += 1 + + assert __class__.cPassedTests > 0 + + # -------------------------------------------------------------------- + def incrementFailedTestCount(testID: str, errCount: int) -> None: + assert type(testID) == str # noqa: E721 + assert type(errCount) == int # noqa: E721 + assert errCount > 0 + assert type(__class__.FailedTests) == list # noqa: E721 + assert type(__class__.cFailedTests) == int # noqa: E721 + assert __class__.cFailedTests >= 0 + + __class__.FailedTests.append((testID, errCount)) # raise? + __class__.cFailedTests += 1 + + assert len(__class__.FailedTests) > 0 + assert __class__.cFailedTests > 0 + assert len(__class__.FailedTests) == __class__.cFailedTests + + # -------- + assert type(__class__.cTotalErrors) == int # noqa: E721 + assert __class__.cTotalErrors >= 0 + + __class__.cTotalErrors += errCount + + assert __class__.cTotalErrors > 0 + + # -------------------------------------------------------------------- + def incrementXFailedTestCount(testID: str, errCount: int) -> None: + assert type(testID) == str # noqa: E721 + assert type(errCount) == int # noqa: E721 + assert errCount >= 0 + assert type(__class__.XFailedTests) == list # noqa: E721 + assert type(__class__.cXFailedTests) == int # noqa: E721 + assert __class__.cXFailedTests >= 0 + + __class__.XFailedTests.append((testID, errCount)) # raise? + __class__.cXFailedTests += 1 + + assert len(__class__.XFailedTests) > 0 + assert __class__.cXFailedTests > 0 + assert len(__class__.XFailedTests) == __class__.cXFailedTests + + # -------------------------------------------------------------------- + def incrementSkippedTestCount() -> None: + assert type(__class__.cSkippedTests) == int # noqa: E721 + assert __class__.cSkippedTests >= 0 + + __class__.cSkippedTests += 1 + + assert __class__.cSkippedTests > 0 + + # -------------------------------------------------------------------- + def incrementNotXFailedTests(testID: str) -> None: + assert type(testID) == str # noqa: E721 + assert type(__class__.NotXFailedTests) == list # noqa: E721 + assert type(__class__.cNotXFailedTests) == int # noqa: E721 + assert __class__.cNotXFailedTests >= 0 + + __class__.NotXFailedTests.append(testID) # raise? + __class__.cNotXFailedTests += 1 + + assert len(__class__.NotXFailedTests) > 0 + assert __class__.cNotXFailedTests > 0 + assert len(__class__.NotXFailedTests) == __class__.cNotXFailedTests + + # -------------------------------------------------------------------- + def incrementWarningTestCount(testID: str, warningCount: int) -> None: + assert type(testID) == str # noqa: E721 + assert type(warningCount) == int # noqa: E721 + assert testID != "" + assert warningCount > 0 + assert type(__class__.WarningTests) == list # noqa: E721 + assert type(__class__.cWarningTests) == int # noqa: E721 + assert __class__.cWarningTests >= 0 + + __class__.WarningTests.append((testID, warningCount)) # raise? + __class__.cWarningTests += 1 + + assert len(__class__.WarningTests) > 0 + assert __class__.cWarningTests > 0 + assert len(__class__.WarningTests) == __class__.cWarningTests + + # -------- + assert type(__class__.cTotalWarnings) == int # noqa: E721 + assert __class__.cTotalWarnings >= 0 + + __class__.cTotalWarnings += warningCount + + assert __class__.cTotalWarnings > 0 + + # -------------------------------------------------------------------- + def incrementUnexpectedTests() -> None: + assert type(__class__.cUnexpectedTests) == int # noqa: E721 + assert __class__.cUnexpectedTests >= 0 + + __class__.cUnexpectedTests += 1 + + assert __class__.cUnexpectedTests > 0 + + # -------------------------------------------------------------------- + def incrementAchtungTestCount(testID: str) -> None: + assert type(testID) == str # noqa: E721 + assert type(__class__.AchtungTests) == list # noqa: E721 + assert type(__class__.cAchtungTests) == int # noqa: E721 + assert __class__.cAchtungTests >= 0 + + __class__.AchtungTests.append(testID) # raise? + __class__.cAchtungTests += 1 + + assert len(__class__.AchtungTests) > 0 + assert __class__.cAchtungTests > 0 + assert len(__class__.AchtungTests) == __class__.cAchtungTests + + +# ///////////////////////////////////////////////////////////////////////////// + + +def timedelta_to_human_text(delta: datetime.timedelta) -> str: + assert isinstance(delta, datetime.timedelta) + + C_SECONDS_IN_MINUTE = 60 + C_SECONDS_IN_HOUR = 60 * C_SECONDS_IN_MINUTE + + v = delta.seconds + + cHours = int(v / C_SECONDS_IN_HOUR) + v = v - cHours * C_SECONDS_IN_HOUR + cMinutes = int(v / C_SECONDS_IN_MINUTE) + cSeconds = v - cMinutes * C_SECONDS_IN_MINUTE + + result = "" if delta.days == 0 else "{0} day(s) ".format(delta.days) + + result = result + "{:02d}:{:02d}:{:02d}.{:06d}".format( + cHours, cMinutes, cSeconds, delta.microseconds + ) + + return result + + +# ///////////////////////////////////////////////////////////////////////////// + + +def helper__build_test_id(item: pytest.Function) -> str: + assert item is not None + assert isinstance(item, pytest.Function) + + testID = "" + + if item.cls is not None: + testID = item.cls.__module__ + "." + item.cls.__name__ + "::" + + testID = testID + item.name + + return testID + + +# ///////////////////////////////////////////////////////////////////////////// + +g_error_msg_count_key = pytest.StashKey[int]() +g_warning_msg_count_key = pytest.StashKey[int]() +g_critical_msg_count_key = pytest.StashKey[int]() + +# ///////////////////////////////////////////////////////////////////////////// + + +def helper__makereport__setup( + item: pytest.Function, call: pytest.CallInfo, outcome: pluggy.Result +): + assert item is not None + assert call is not None + assert outcome is not None + # it may be pytest.Function or _pytest.unittest.TestCaseFunction + assert isinstance(item, pytest.Function) + assert type(call) == pytest.CallInfo # noqa: E721 + assert type(outcome) == pluggy.Result # noqa: E721 + + C_LINE1 = "******************************************************" + + # logging.info("pytest_runtest_makereport - setup") + + TEST_PROCESS_STATS.incrementTotalTestCount() + + rep: pytest.TestReport = outcome.get_result() + assert rep is not None + assert type(rep) == pytest.TestReport # noqa: E721 + + if rep.outcome == "skipped": + TEST_PROCESS_STATS.incrementNotExecutedTestCount() + return + + testID = helper__build_test_id(item) + + if rep.outcome == "passed": + testNumber = TEST_PROCESS_STATS.incrementExecutedTestCount() + + logging.info(C_LINE1) + logging.info("* START TEST {0}".format(testID)) + logging.info("*") + logging.info("* Path : {0}".format(item.path)) + logging.info("* Number: {0}".format(testNumber)) + logging.info("*") + return + + assert rep.outcome != "passed" + + TEST_PROCESS_STATS.incrementAchtungTestCount(testID) + + logging.info(C_LINE1) + logging.info("* ACHTUNG TEST {0}".format(testID)) + logging.info("*") + logging.info("* Path : {0}".format(item.path)) + logging.info("* Outcome is [{0}]".format(rep.outcome)) + + if rep.outcome == "failed": + assert call.excinfo is not None + assert call.excinfo.value is not None + logging.info("*") + logging.error(call.excinfo.value) + + logging.info("*") + return + + +# ------------------------------------------------------------------------ +def helper__makereport__call( + item: pytest.Function, call: pytest.CallInfo, outcome: pluggy.Result +): + assert item is not None + assert call is not None + assert outcome is not None + # it may be pytest.Function or _pytest.unittest.TestCaseFunction + assert isinstance(item, pytest.Function) + assert type(call) == pytest.CallInfo # noqa: E721 + assert type(outcome) == pluggy.Result # noqa: E721 + + # -------- + item_error_msg_count1 = item.stash.get(g_error_msg_count_key, 0) + assert type(item_error_msg_count1) == int # noqa: E721 + assert item_error_msg_count1 >= 0 + + item_error_msg_count2 = item.stash.get(g_critical_msg_count_key, 0) + assert type(item_error_msg_count2) == int # noqa: E721 + assert item_error_msg_count2 >= 0 + + item_error_msg_count = item_error_msg_count1 + item_error_msg_count2 + + # -------- + item_warning_msg_count = item.stash.get(g_warning_msg_count_key, 0) + assert type(item_warning_msg_count) == int # noqa: E721 + assert item_warning_msg_count >= 0 + + # -------- + rep = outcome.get_result() + assert rep is not None + assert type(rep) == pytest.TestReport # noqa: E721 + + # -------- + testID = helper__build_test_id(item) + + # -------- + assert call.start <= call.stop + + startDT = datetime.datetime.fromtimestamp(call.start) + assert type(startDT) == datetime.datetime # noqa: E721 + stopDT = datetime.datetime.fromtimestamp(call.stop) + assert type(stopDT) == datetime.datetime # noqa: E721 + + testDurration = stopDT - startDT + assert type(testDurration) == datetime.timedelta # noqa: E721 + + # -------- + exitStatus = None + if rep.outcome == "skipped": + assert call.excinfo is not None # research + assert call.excinfo.value is not None # research + + if type(call.excinfo.value) == _pytest.outcomes.Skipped: # noqa: E721 + assert not hasattr(rep, "wasxfail") + + exitStatus = "SKIPPED" + reasonText = str(call.excinfo.value) + reasonMsgTempl = "SKIP REASON: {0}" + + TEST_PROCESS_STATS.incrementSkippedTestCount() + + elif type(call.excinfo.value) == _pytest.outcomes.XFailed: # noqa: E721 + exitStatus = "XFAILED" + reasonText = str(call.excinfo.value) + reasonMsgTempl = "XFAIL REASON: {0}" + + TEST_PROCESS_STATS.incrementXFailedTestCount(testID, item_error_msg_count) + + else: + exitStatus = "XFAILED" + assert hasattr(rep, "wasxfail") + assert rep.wasxfail is not None + assert type(rep.wasxfail) == str # noqa: E721 + + reasonText = rep.wasxfail + reasonMsgTempl = "XFAIL REASON: {0}" + + if type(call.excinfo.value) == SIGNAL_EXCEPTION: # noqa: E721 + pass + else: + logging.error(call.excinfo.value) + item_error_msg_count += 1 + + TEST_PROCESS_STATS.incrementXFailedTestCount(testID, item_error_msg_count) + + assert type(reasonText) == str # noqa: E721 + + if reasonText != "": + assert type(reasonMsgTempl) == str # noqa: E721 + logging.info("*") + logging.info("* " + reasonMsgTempl.format(reasonText)) + + elif rep.outcome == "failed": + assert call.excinfo is not None + assert call.excinfo.value is not None + + if type(call.excinfo.value) == SIGNAL_EXCEPTION: # noqa: E721 + assert item_error_msg_count > 0 + pass + else: + logging.error(call.excinfo.value) + item_error_msg_count += 1 + + assert item_error_msg_count > 0 + TEST_PROCESS_STATS.incrementFailedTestCount(testID, item_error_msg_count) + + exitStatus = "FAILED" + elif rep.outcome == "passed": + assert call.excinfo is None + + if hasattr(rep, "wasxfail"): + assert type(rep.wasxfail) == str # noqa: E721 + + TEST_PROCESS_STATS.incrementNotXFailedTests(testID) + + warnMsg = "NOTE: Test is marked as xfail" + + if rep.wasxfail != "": + warnMsg += " [" + rep.wasxfail + "]" + + logging.info(warnMsg) + exitStatus = "NOT XFAILED" + else: + assert not hasattr(rep, "wasxfail") + + TEST_PROCESS_STATS.incrementPassedTestCount() + exitStatus = "PASSED" + else: + TEST_PROCESS_STATS.incrementUnexpectedTests() + exitStatus = "UNEXPECTED [{0}]".format(rep.outcome) + # [2025-03-28] It may create a useless problem in new environment. + # assert False + + # -------- + if item_warning_msg_count > 0: + TEST_PROCESS_STATS.incrementWarningTestCount(testID, item_warning_msg_count) + + # -------- + assert type(TEST_PROCESS_STATS.cTotalDuration) == datetime.timedelta # noqa: E721 + assert type(testDurration) == datetime.timedelta # noqa: E721 + + TEST_PROCESS_STATS.cTotalDuration += testDurration + + assert testDurration <= TEST_PROCESS_STATS.cTotalDuration + + # -------- + logging.info("*") + logging.info("* DURATION : {0}".format(timedelta_to_human_text(testDurration))) + logging.info("*") + logging.info("* EXIT STATUS : {0}".format(exitStatus)) + logging.info("* ERROR COUNT : {0}".format(item_error_msg_count)) + logging.info("* WARNING COUNT: {0}".format(item_warning_msg_count)) + logging.info("*") + logging.info("* STOP TEST {0}".format(testID)) + logging.info("*") + + +# ///////////////////////////////////////////////////////////////////////////// + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_makereport(item: pytest.Function, call: pytest.CallInfo): + # + # https://docs.pytest.org/en/7.1.x/how-to/writing_hook_functions.html#hookwrapper-executing-around-other-hooks + # + # Note that hook wrappers donโ€™t return results themselves, + # they merely perform tracing or other side effects around the actual hook implementations. + # + # https://docs.pytest.org/en/7.1.x/reference/reference.html#test-running-runtest-hooks + # + assert item is not None + assert call is not None + # it may be pytest.Function or _pytest.unittest.TestCaseFunction + assert isinstance(item, pytest.Function) + assert type(call) == pytest.CallInfo # noqa: E721 + + outcome: pluggy.Result = yield + assert outcome is not None + assert type(outcome) == pluggy.Result # noqa: E721 + + assert type(call.when) == str # noqa: E721 + + if call.when == "collect": + return + + if call.when == "setup": + helper__makereport__setup(item, call, outcome) + return + + if call.when == "call": + helper__makereport__call(item, call, outcome) + return + + if call.when == "teardown": + return + + errMsg = "[pytest_runtest_makereport] unknown 'call.when' value: [{0}].".format( + call.when + ) + + raise RuntimeError(errMsg) + + +# ///////////////////////////////////////////////////////////////////////////// + + +class LogWrapper2: + _old_method: any + _err_counter: typing.Optional[int] + _warn_counter: typing.Optional[int] + + _critical_counter: typing.Optional[int] + + # -------------------------------------------------------------------- + def __init__(self): + self._old_method = None + self._err_counter = None + self._warn_counter = None + + self._critical_counter = None + + # -------------------------------------------------------------------- + def __enter__(self): + assert self._old_method is None + assert self._err_counter is None + assert self._warn_counter is None + + assert self._critical_counter is None + + assert logging.root is not None + assert isinstance(logging.root, logging.RootLogger) + + self._old_method = logging.root.handle + self._err_counter = 0 + self._warn_counter = 0 + + self._critical_counter = 0 + + logging.root.handle = self + return self + + # -------------------------------------------------------------------- + def __exit__(self, exc_type, exc_val, exc_tb): + assert self._old_method is not None + assert self._err_counter is not None + assert self._warn_counter is not None + + assert logging.root is not None + assert isinstance(logging.root, logging.RootLogger) + + assert logging.root.handle is self + + logging.root.handle = self._old_method + + self._old_method = None + self._err_counter = None + self._warn_counter = None + self._critical_counter = None + return False + + # -------------------------------------------------------------------- + def __call__(self, record: logging.LogRecord): + assert record is not None + assert isinstance(record, logging.LogRecord) + assert self._old_method is not None + assert self._err_counter is not None + assert self._warn_counter is not None + assert self._critical_counter is not None + + assert type(self._err_counter) == int # noqa: E721 + assert self._err_counter >= 0 + assert type(self._warn_counter) == int # noqa: E721 + assert self._warn_counter >= 0 + assert type(self._critical_counter) == int # noqa: E721 + assert self._critical_counter >= 0 + + r = self._old_method(record) + + if record.levelno == logging.ERROR: + self._err_counter += 1 + assert self._err_counter > 0 + elif record.levelno == logging.WARNING: + self._warn_counter += 1 + assert self._warn_counter > 0 + elif record.levelno == logging.CRITICAL: + self._critical_counter += 1 + assert self._critical_counter > 0 + + return r + + +# ///////////////////////////////////////////////////////////////////////////// + + +class SIGNAL_EXCEPTION(Exception): + def __init__(self): + pass + + +# ///////////////////////////////////////////////////////////////////////////// + + +@pytest.hookimpl(hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem: pytest.Function): + assert pyfuncitem is not None + assert isinstance(pyfuncitem, pytest.Function) + + assert logging.root is not None + assert isinstance(logging.root, logging.RootLogger) + assert logging.root.handle is not None + + debug__log_handle_method = logging.root.handle + assert debug__log_handle_method is not None + + debug__log_error_method = logging.error + assert debug__log_error_method is not None + + debug__log_warning_method = logging.warning + assert debug__log_warning_method is not None + + pyfuncitem.stash[g_error_msg_count_key] = 0 + pyfuncitem.stash[g_warning_msg_count_key] = 0 + pyfuncitem.stash[g_critical_msg_count_key] = 0 + + try: + with LogWrapper2() as logWrapper: + assert type(logWrapper) == LogWrapper2 # noqa: E721 + assert logWrapper._old_method is not None + assert type(logWrapper._err_counter) == int # noqa: E721 + assert logWrapper._err_counter == 0 + assert type(logWrapper._warn_counter) == int # noqa: E721 + assert logWrapper._warn_counter == 0 + assert type(logWrapper._critical_counter) == int # noqa: E721 + assert logWrapper._critical_counter == 0 + assert logging.root.handle is logWrapper + + r: pluggy.Result = yield + + assert r is not None + assert type(r) == pluggy.Result # noqa: E721 + + assert logWrapper._old_method is not None + assert type(logWrapper._err_counter) == int # noqa: E721 + assert logWrapper._err_counter >= 0 + assert type(logWrapper._warn_counter) == int # noqa: E721 + assert logWrapper._warn_counter >= 0 + assert type(logWrapper._critical_counter) == int # noqa: E721 + assert logWrapper._critical_counter >= 0 + assert logging.root.handle is logWrapper + + assert g_error_msg_count_key in pyfuncitem.stash + assert g_warning_msg_count_key in pyfuncitem.stash + assert g_critical_msg_count_key in pyfuncitem.stash + + assert pyfuncitem.stash[g_error_msg_count_key] == 0 + assert pyfuncitem.stash[g_warning_msg_count_key] == 0 + assert pyfuncitem.stash[g_critical_msg_count_key] == 0 + + pyfuncitem.stash[g_error_msg_count_key] = logWrapper._err_counter + pyfuncitem.stash[g_warning_msg_count_key] = logWrapper._warn_counter + pyfuncitem.stash[g_critical_msg_count_key] = logWrapper._critical_counter + + if r.exception is not None: + pass + elif logWrapper._err_counter > 0: + r.force_exception(SIGNAL_EXCEPTION()) + elif logWrapper._critical_counter > 0: + r.force_exception(SIGNAL_EXCEPTION()) + finally: + assert logging.error is debug__log_error_method + assert logging.warning is debug__log_warning_method + assert logging.root.handle == debug__log_handle_method + pass + + +# ///////////////////////////////////////////////////////////////////////////// + + +def helper__calc_W(n: int) -> int: + assert n > 0 + + x = int(math.log10(n)) + assert type(x) == int # noqa: E721 + assert x >= 0 + x += 1 + return x + + +# ------------------------------------------------------------------------ +def helper__print_test_list(tests: typing.List[str]) -> None: + assert type(tests) == list # noqa: E721 + + assert helper__calc_W(9) == 1 + assert helper__calc_W(10) == 2 + assert helper__calc_W(11) == 2 + assert helper__calc_W(99) == 2 + assert helper__calc_W(100) == 3 + assert helper__calc_W(101) == 3 + assert helper__calc_W(999) == 3 + assert helper__calc_W(1000) == 4 + assert helper__calc_W(1001) == 4 + + W = helper__calc_W(len(tests)) + + templateLine = "{0:0" + str(W) + "d}. {1}" + + nTest = 0 + + for t in tests: + assert type(t) == str # noqa: E721 + assert t != "" + nTest += 1 + logging.info(templateLine.format(nTest, t)) + + +# ------------------------------------------------------------------------ +def helper__print_test_list2(tests: typing.List[T_TUPLE__str_int]) -> None: + assert type(tests) == list # noqa: E721 + + assert helper__calc_W(9) == 1 + assert helper__calc_W(10) == 2 + assert helper__calc_W(11) == 2 + assert helper__calc_W(99) == 2 + assert helper__calc_W(100) == 3 + assert helper__calc_W(101) == 3 + assert helper__calc_W(999) == 3 + assert helper__calc_W(1000) == 4 + assert helper__calc_W(1001) == 4 + + W = helper__calc_W(len(tests)) + + templateLine = "{0:0" + str(W) + "d}. {1} ({2})" + + nTest = 0 + + for t in tests: + assert type(t) == tuple # noqa: E721 + assert len(t) == 2 + assert type(t[0]) == str # noqa: E721 + assert type(t[1]) == int # noqa: E721 + assert t[0] != "" + assert t[1] >= 0 + nTest += 1 + logging.info(templateLine.format(nTest, t[0], t[1])) + + +# ///////////////////////////////////////////////////////////////////////////// + + +@pytest.fixture(autouse=True, scope="session") +def run_after_tests(request: pytest.FixtureRequest): + assert isinstance(request, pytest.FixtureRequest) + + yield + + C_LINE1 = "---------------------------" + + def LOCAL__print_line1_with_header(header: str): + assert type(C_LINE1) == str # noqa: E721 + assert type(header) == str # noqa: E721 + assert header != "" + logging.info(C_LINE1 + " [" + header + "]") + + def LOCAL__print_test_list(header: str, test_count: int, test_list: typing.List[str]): + assert type(header) == str # noqa: E721 + assert type(test_count) == int # noqa: E721 + assert type(test_list) == list # noqa: E721 + assert header != "" + assert test_count >= 0 + assert len(test_list) == test_count + + LOCAL__print_line1_with_header(header) + logging.info("") + if len(test_list) > 0: + helper__print_test_list(test_list) + logging.info("") + + def LOCAL__print_test_list2( + header: str, test_count: int, test_list: typing.List[T_TUPLE__str_int] + ): + assert type(header) == str # noqa: E721 + assert type(test_count) == int # noqa: E721 + assert type(test_list) == list # noqa: E721 + assert header != "" + assert test_count >= 0 + assert len(test_list) == test_count + + LOCAL__print_line1_with_header(header) + logging.info("") + if len(test_list) > 0: + helper__print_test_list2(test_list) + logging.info("") + + # fmt: off + LOCAL__print_test_list( + "ACHTUNG TESTS", + TEST_PROCESS_STATS.cAchtungTests, + TEST_PROCESS_STATS.AchtungTests, + ) + + LOCAL__print_test_list2( + "FAILED TESTS", + TEST_PROCESS_STATS.cFailedTests, + TEST_PROCESS_STATS.FailedTests + ) + + LOCAL__print_test_list2( + "XFAILED TESTS", + TEST_PROCESS_STATS.cXFailedTests, + TEST_PROCESS_STATS.XFailedTests, + ) + + LOCAL__print_test_list( + "NOT XFAILED TESTS", + TEST_PROCESS_STATS.cNotXFailedTests, + TEST_PROCESS_STATS.NotXFailedTests, + ) + + LOCAL__print_test_list2( + "WARNING TESTS", + TEST_PROCESS_STATS.cWarningTests, + TEST_PROCESS_STATS.WarningTests, + ) + # fmt: on + + LOCAL__print_line1_with_header("SUMMARY STATISTICS") + logging.info("") + logging.info("[TESTS]") + logging.info(" TOTAL : {0}".format(TEST_PROCESS_STATS.cTotalTests)) + logging.info(" EXECUTED : {0}".format(TEST_PROCESS_STATS.cExecutedTests)) + logging.info(" NOT EXECUTED : {0}".format(TEST_PROCESS_STATS.cNotExecutedTests)) + logging.info(" ACHTUNG : {0}".format(TEST_PROCESS_STATS.cAchtungTests)) + logging.info("") + logging.info(" PASSED : {0}".format(TEST_PROCESS_STATS.cPassedTests)) + logging.info(" FAILED : {0}".format(TEST_PROCESS_STATS.cFailedTests)) + logging.info(" XFAILED : {0}".format(TEST_PROCESS_STATS.cXFailedTests)) + logging.info(" NOT XFAILED : {0}".format(TEST_PROCESS_STATS.cNotXFailedTests)) + logging.info(" SKIPPED : {0}".format(TEST_PROCESS_STATS.cSkippedTests)) + logging.info(" WITH WARNINGS: {0}".format(TEST_PROCESS_STATS.cWarningTests)) + logging.info(" UNEXPECTED : {0}".format(TEST_PROCESS_STATS.cUnexpectedTests)) + logging.info("") + + assert type(TEST_PROCESS_STATS.cTotalDuration) == datetime.timedelta # noqa: E721 + + LOCAL__print_line1_with_header("TIME") + logging.info("") + logging.info( + " TOTAL DURATION: {0}".format( + timedelta_to_human_text(TEST_PROCESS_STATS.cTotalDuration) + ) + ) + logging.info("") + + LOCAL__print_line1_with_header("TOTAL INFORMATION") + logging.info("") + logging.info(" TOTAL ERROR COUNT : {0}".format(TEST_PROCESS_STATS.cTotalErrors)) + logging.info(" TOTAL WARNING COUNT: {0}".format(TEST_PROCESS_STATS.cTotalWarnings)) + logging.info("") + + +# ///////////////////////////////////////////////////////////////////////////// + + +@pytest.hookimpl(trylast=True) +def pytest_configure(config: pytest.Config) -> None: + assert isinstance(config, pytest.Config) + + log_name = TestStartupData.GetCurrentTestWorkerSignature() + log_name += ".log" + + log_dir = TestStartupData.GetRootLogDir() + + pathlib.Path(log_dir).mkdir(exist_ok=True) + + logging_plugin: _pytest.logging.LoggingPlugin = config.pluginmanager.get_plugin( + "logging-plugin" + ) + + assert logging_plugin is not None + assert isinstance(logging_plugin, _pytest.logging.LoggingPlugin) + + logging_plugin.set_log_path(os.path.join(log_dir, log_name)) + + +# ///////////////////////////////////////////////////////////////////////////// diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/global_data.py b/tests/helpers/global_data.py new file mode 100644 index 00000000..f3df41a3 --- /dev/null +++ b/tests/helpers/global_data.py @@ -0,0 +1,78 @@ +from testgres.operations.os_ops import OsOperations +from testgres.operations.os_ops import ConnectionParams +from testgres.operations.local_ops import LocalOperations +from testgres.operations.remote_ops import RemoteOperations + +from testgres.node import PortManager +from testgres.node import PortManager__ThisHost +from testgres.node import PortManager__Generic + +import os + + +class OsOpsDescr: + sign: str + os_ops: OsOperations + + def __init__(self, sign: str, os_ops: OsOperations): + assert type(sign) == str # noqa: E721 + assert isinstance(os_ops, OsOperations) + self.sign = sign + self.os_ops = os_ops + + +class OsOpsDescrs: + sm_remote_conn_params = ConnectionParams( + host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1', + username=os.getenv('USER'), + ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY')) + + sm_remote_os_ops = RemoteOperations(sm_remote_conn_params) + + sm_remote_os_ops_descr = OsOpsDescr("remote_ops", sm_remote_os_ops) + + sm_local_os_ops = LocalOperations.get_single_instance() + + sm_local_os_ops_descr = OsOpsDescr("local_ops", sm_local_os_ops) + + +class PortManagers: + sm_remote_port_manager = PortManager__Generic(OsOpsDescrs.sm_remote_os_ops) + + sm_local_port_manager = PortManager__ThisHost.get_single_instance() + + sm_local2_port_manager = PortManager__Generic(OsOpsDescrs.sm_local_os_ops) + + +class PostgresNodeService: + sign: str + os_ops: OsOperations + port_manager: PortManager + + def __init__(self, sign: str, os_ops: OsOperations, port_manager: PortManager): + assert type(sign) == str # noqa: E721 + assert isinstance(os_ops, OsOperations) + assert isinstance(port_manager, PortManager) + self.sign = sign + self.os_ops = os_ops + self.port_manager = port_manager + + +class PostgresNodeServices: + sm_remote = PostgresNodeService( + "remote", + OsOpsDescrs.sm_remote_os_ops, + PortManagers.sm_remote_port_manager + ) + + sm_local = PostgresNodeService( + "local", + OsOpsDescrs.sm_local_os_ops, + PortManagers.sm_local_port_manager + ) + + sm_local2 = PostgresNodeService( + "local2", + OsOpsDescrs.sm_local_os_ops, + PortManagers.sm_local2_port_manager + ) diff --git a/tests/helpers/run_conditions.py b/tests/helpers/run_conditions.py new file mode 100644 index 00000000..11357c30 --- /dev/null +++ b/tests/helpers/run_conditions.py @@ -0,0 +1,12 @@ +# coding: utf-8 +import pytest +import platform + + +class RunConditions: + # It is not a test kit! + __test__ = False + + def skip_if_windows(): + if platform.system().lower() == "windows": + pytest.skip("This test does not support Windows.") diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..a80a11f1 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,41 @@ +from testgres import TestgresConfig +from testgres import configure_testgres +from testgres import scoped_config +from testgres import pop_config + +import testgres + +import pytest + + +class TestConfig: + def test_config_stack(self): + # no such option + with pytest.raises(expected_exception=TypeError): + configure_testgres(dummy=True) + + # we have only 1 config in stack + with pytest.raises(expected_exception=IndexError): + pop_config() + + d0 = TestgresConfig.cached_initdb_dir + d1 = 'dummy_abc' + d2 = 'dummy_def' + + with scoped_config(cached_initdb_dir=d1) as c1: + assert (c1.cached_initdb_dir == d1) + + with scoped_config(cached_initdb_dir=d2) as c2: + stack_size = len(testgres.config.config_stack) + + # try to break a stack + with pytest.raises(expected_exception=TypeError): + with scoped_config(dummy=True): + pass + + assert (c2.cached_initdb_dir == d2) + assert (len(testgres.config.config_stack) == stack_size) + + assert (c1.cached_initdb_dir == d1) + + assert (TestgresConfig.cached_initdb_dir == d0) diff --git a/tests/test_conftest.py--devel b/tests/test_conftest.py--devel new file mode 100644 index 00000000..67c1dafe --- /dev/null +++ b/tests/test_conftest.py--devel @@ -0,0 +1,80 @@ +import pytest +import logging + + +class TestConfest: + def test_failed(self): + raise Exception("TEST EXCEPTION!") + + def test_ok(self): + pass + + @pytest.mark.skip() + def test_mark_skip__no_reason(self): + pass + + @pytest.mark.xfail() + def test_mark_xfail__no_reason(self): + raise Exception("XFAIL EXCEPTION") + + @pytest.mark.xfail() + def test_mark_xfail__no_reason___no_error(self): + pass + + @pytest.mark.skip(reason="reason") + def test_mark_skip__with_reason(self): + pass + + @pytest.mark.xfail(reason="reason") + def test_mark_xfail__with_reason(self): + raise Exception("XFAIL EXCEPTION") + + @pytest.mark.xfail(reason="reason") + def test_mark_xfail__with_reason___no_error(self): + pass + + def test_exc_skip__no_reason(self): + pytest.skip() + + def test_exc_xfail__no_reason(self): + pytest.xfail() + + def test_exc_skip__with_reason(self): + pytest.skip(reason="SKIP REASON") + + def test_exc_xfail__with_reason(self): + pytest.xfail(reason="XFAIL EXCEPTION") + + def test_log_error(self): + logging.error("IT IS A LOG ERROR!") + + def test_log_error_and_exc(self): + logging.error("IT IS A LOG ERROR!") + + raise Exception("TEST EXCEPTION!") + + def test_log_error_and_warning(self): + logging.error("IT IS A LOG ERROR!") + logging.warning("IT IS A LOG WARNING!") + logging.error("IT IS THE SECOND LOG ERROR!") + logging.warning("IT IS THE SECOND LOG WARNING!") + + @pytest.mark.xfail() + def test_log_error_and_xfail_mark_without_reason(self): + logging.error("IT IS A LOG ERROR!") + + @pytest.mark.xfail(reason="It is a reason message") + def test_log_error_and_xfail_mark_with_reason(self): + logging.error("IT IS A LOG ERROR!") + + @pytest.mark.xfail() + def test_two_log_error_and_xfail_mark_without_reason(self): + logging.error("IT IS THE FIRST LOG ERROR!") + logging.info("----------") + logging.error("IT IS THE SECOND LOG ERROR!") + + @pytest.mark.xfail(reason="It is a reason message") + def test_two_log_error_and_xfail_mark_with_reason(self): + logging.error("IT IS THE FIRST LOG ERROR!") + logging.info("----------") + logging.error("IT IS THE SECOND LOG ERROR!") diff --git a/tests/test_os_ops_common.py b/tests/test_os_ops_common.py new file mode 100644 index 00000000..5ae3a61f --- /dev/null +++ b/tests/test_os_ops_common.py @@ -0,0 +1,1115 @@ +# coding: utf-8 +from .helpers.global_data import OsOpsDescr +from .helpers.global_data import OsOpsDescrs +from .helpers.global_data import OsOperations +from .helpers.run_conditions import RunConditions + +import os + +import pytest +import re +import tempfile +import logging +import socket +import threading +import typing +import uuid + +from testgres import InvalidOperationException +from testgres import ExecUtilException + +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future as ThreadFuture + + +class TestOsOpsCommon: + sm_os_ops_descrs: typing.List[OsOpsDescr] = [ + OsOpsDescrs.sm_local_os_ops_descr, + OsOpsDescrs.sm_remote_os_ops_descr + ] + + @pytest.fixture( + params=[descr.os_ops for descr in sm_os_ops_descrs], + ids=[descr.sign for descr in sm_os_ops_descrs] + ) + def os_ops(self, request: pytest.FixtureRequest) -> OsOperations: + assert isinstance(request, pytest.FixtureRequest) + assert isinstance(request.param, OsOperations) + return request.param + + def test_exec_command_success(self, os_ops: OsOperations): + """ + Test exec_command for successful command execution. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + cmd = ["sh", "-c", "python3 --version"] + + response = os_ops.exec_command(cmd) + + assert b'Python 3.' in response + + def test_exec_command_failure(self, os_ops: OsOperations): + """ + Test exec_command for command execution failure. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + cmd = ["sh", "-c", "nonexistent_command"] + + while True: + try: + os_ops.exec_command(cmd) + except ExecUtilException as e: + assert type(e.exit_code) == int # noqa: E721 + assert e.exit_code == 127 + + assert type(e.message) == str # noqa: E721 + assert type(e.error) == bytes # noqa: E721 + + assert e.message.startswith("Utility exited with non-zero code (127). Error:") + assert "nonexistent_command" in e.message + assert "not found" in e.message + assert b"nonexistent_command" in e.error + assert b"not found" in e.error + break + raise Exception("We wait an exception!") + + def test_exec_command_failure__expect_error(self, os_ops: OsOperations): + """ + Test exec_command for command execution failure. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + cmd = ["sh", "-c", "nonexistent_command"] + + exit_status, result, error = os_ops.exec_command(cmd, verbose=True, expect_error=True) + + assert exit_status == 127 + assert result == b'' + assert type(error) == bytes # noqa: E721 + assert b"nonexistent_command" in error + assert b"not found" in error + + def test_exec_command_with_exec_env(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + C_ENV_NAME = "TESTGRES_TEST__EXEC_ENV_20250414" + + cmd = ["sh", "-c", "echo ${}".format(C_ENV_NAME)] + + exec_env = {C_ENV_NAME: "Hello!"} + + response = os_ops.exec_command(cmd, exec_env=exec_env) + assert response is not None + assert type(response) == bytes # noqa: E721 + assert response == b'Hello!\n' + + response = os_ops.exec_command(cmd) + assert response is not None + assert type(response) == bytes # noqa: E721 + assert response == b'\n' + + def test_exec_command__test_unset(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + C_ENV_NAME = "LANG" + + cmd = ["sh", "-c", "echo ${}".format(C_ENV_NAME)] + + response1 = os_ops.exec_command(cmd) + assert response1 is not None + assert type(response1) == bytes # noqa: E721 + + if response1 == b'\n': + logging.warning("Environment variable {} is not defined.".format(C_ENV_NAME)) + return + + exec_env = {C_ENV_NAME: None} + response2 = os_ops.exec_command(cmd, exec_env=exec_env) + assert response2 is not None + assert type(response2) == bytes # noqa: E721 + assert response2 == b'\n' + + response3 = os_ops.exec_command(cmd) + assert response3 is not None + assert type(response3) == bytes # noqa: E721 + assert response3 == response1 + + def test_exec_command__test_unset_dummy_var(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + C_ENV_NAME = "TESTGRES_TEST__DUMMY_VAR_20250414" + + cmd = ["sh", "-c", "echo ${}".format(C_ENV_NAME)] + + exec_env = {C_ENV_NAME: None} + response2 = os_ops.exec_command(cmd, exec_env=exec_env) + assert response2 is not None + assert type(response2) == bytes # noqa: E721 + assert response2 == b'\n' + + def test_is_executable_true(self, os_ops: OsOperations): + """ + Test is_executable for an existing executable. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + response = os_ops.is_executable("/bin/sh") + + assert response is True + + def test_is_executable_false(self, os_ops: OsOperations): + """ + Test is_executable for a non-executable. + """ + assert isinstance(os_ops, OsOperations) + + response = os_ops.is_executable(__file__) + + assert response is False + + def test_makedirs_and_rmdirs_success(self, os_ops: OsOperations): + """ + Test makedirs and rmdirs for successful directory creation and removal. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + cmd = "pwd" + pwd = os_ops.exec_command(cmd, wait_exit=True, encoding='utf-8').strip() + + path = "{}/test_dir".format(pwd) + + # Test makedirs + os_ops.makedirs(path) + assert os.path.exists(path) + assert os_ops.path_exists(path) + + # Test rmdirs + os_ops.rmdirs(path) + assert not os.path.exists(path) + assert not os_ops.path_exists(path) + + def test_makedirs_failure(self, os_ops: OsOperations): + """ + Test makedirs for failure. + """ + # Try to create a directory in a read-only location + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + path = "/root/test_dir" + + # Test makedirs + with pytest.raises(Exception): + os_ops.makedirs(path) + + def test_listdir(self, os_ops: OsOperations): + """ + Test listdir for listing directory contents. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + path = "/etc" + files = os_ops.listdir(path) + assert isinstance(files, list) + for f in files: + assert f is not None + assert type(f) == str # noqa: E721 + + def test_path_exists_true__directory(self, os_ops: OsOperations): + """ + Test path_exists for an existing directory. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + assert os_ops.path_exists("/etc") is True + + def test_path_exists_true__file(self, os_ops: OsOperations): + """ + Test path_exists for an existing file. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + assert os_ops.path_exists(__file__) is True + + def test_path_exists_false__directory(self, os_ops: OsOperations): + """ + Test path_exists for a non-existing directory. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + assert os_ops.path_exists("/nonexistent_path") is False + + def test_path_exists_false__file(self, os_ops: OsOperations): + """ + Test path_exists for a non-existing file. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + assert os_ops.path_exists("/etc/nonexistent_path.txt") is False + + def test_mkdtemp__default(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + path = os_ops.mkdtemp() + logging.info("Path is [{0}].".format(path)) + assert os.path.exists(path) + os.rmdir(path) + assert not os.path.exists(path) + + def test_mkdtemp__custom(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + C_TEMPLATE = "abcdef" + path = os_ops.mkdtemp(C_TEMPLATE) + logging.info("Path is [{0}].".format(path)) + assert os.path.exists(path) + assert C_TEMPLATE in os.path.basename(path) + os.rmdir(path) + assert not os.path.exists(path) + + def test_rmdirs(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + path = os_ops.mkdtemp() + assert os.path.exists(path) + + assert os_ops.rmdirs(path, ignore_errors=False) is True + assert not os.path.exists(path) + + def test_rmdirs__01_with_subfolder(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + # folder with subfolder + path = os_ops.mkdtemp() + assert os.path.exists(path) + + dir1 = os.path.join(path, "dir1") + assert not os.path.exists(dir1) + + os_ops.makedirs(dir1) + assert os.path.exists(dir1) + + assert os_ops.rmdirs(path, ignore_errors=False) is True + assert not os.path.exists(path) + assert not os.path.exists(dir1) + + def test_rmdirs__02_with_file(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + # folder with file + path = os_ops.mkdtemp() + assert os.path.exists(path) + + file1 = os.path.join(path, "file1.txt") + assert not os.path.exists(file1) + + os_ops.touch(file1) + assert os.path.exists(file1) + + assert os_ops.rmdirs(path, ignore_errors=False) is True + assert not os.path.exists(path) + assert not os.path.exists(file1) + + def test_rmdirs__03_with_subfolder_and_file(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + # folder with subfolder and file + path = os_ops.mkdtemp() + assert os.path.exists(path) + + dir1 = os.path.join(path, "dir1") + assert not os.path.exists(dir1) + + os_ops.makedirs(dir1) + assert os.path.exists(dir1) + + file1 = os.path.join(dir1, "file1.txt") + assert not os.path.exists(file1) + + os_ops.touch(file1) + assert os.path.exists(file1) + + assert os_ops.rmdirs(path, ignore_errors=False) is True + assert not os.path.exists(path) + assert not os.path.exists(dir1) + assert not os.path.exists(file1) + + def test_write_text_file(self, os_ops: OsOperations): + """ + Test write for writing data to a text file. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + filename = os_ops.mkstemp() + data = "Hello, world!" + + os_ops.write(filename, data, truncate=True) + os_ops.write(filename, data) + + response = os_ops.read(filename) + + assert response == data + data + + os_ops.remove_file(filename) + + def test_write_binary_file(self, os_ops: OsOperations): + """ + Test write for writing data to a binary file. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + filename = "/tmp/test_file.bin" + data = b"\x00\x01\x02\x03" + + os_ops.write(filename, data, binary=True, truncate=True) + + response = os_ops.read(filename, binary=True) + + assert response == data + + def test_read_text_file(self, os_ops: OsOperations): + """ + Test read for reading data from a text file. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + filename = "/etc/hosts" + + response = os_ops.read(filename) + + assert isinstance(response, str) + + def test_read_binary_file(self, os_ops: OsOperations): + """ + Test read for reading data from a binary file. + """ + assert isinstance(os_ops, OsOperations) + + RunConditions.skip_if_windows() + + filename = "/usr/bin/python3" + + response = os_ops.read(filename, binary=True) + + assert isinstance(response, bytes) + + def test_read__text(self, os_ops: OsOperations): + """ + Test OsOperations::read for text data. + """ + assert isinstance(os_ops, OsOperations) + + filename = __file__ # current file + + with open(filename, 'r') as file: # open in a text mode + response0 = file.read() + + assert type(response0) == str # noqa: E721 + + response1 = os_ops.read(filename) + assert type(response1) == str # noqa: E721 + assert response1 == response0 + + response2 = os_ops.read(filename, encoding=None, binary=False) + assert type(response2) == str # noqa: E721 + assert response2 == response0 + + response3 = os_ops.read(filename, encoding="") + assert type(response3) == str # noqa: E721 + assert response3 == response0 + + response4 = os_ops.read(filename, encoding="UTF-8") + assert type(response4) == str # noqa: E721 + assert response4 == response0 + + def test_read__binary(self, os_ops: OsOperations): + """ + Test OsOperations::read for binary data. + """ + filename = __file__ # current file + + with open(filename, 'rb') as file: # open in a binary mode + response0 = file.read() + + assert type(response0) == bytes # noqa: E721 + + response1 = os_ops.read(filename, binary=True) + assert type(response1) == bytes # noqa: E721 + assert response1 == response0 + + def test_read__binary_and_encoding(self, os_ops: OsOperations): + """ + Test OsOperations::read for binary data and encoding. + """ + assert isinstance(os_ops, OsOperations) + + filename = __file__ # current file + + with pytest.raises( + InvalidOperationException, + match=re.escape("Enconding is not allowed for read binary operation")): + os_ops.read(filename, encoding="", binary=True) + + def test_read_binary__spec(self, os_ops: OsOperations): + """ + Test OsOperations::read_binary. + """ + assert isinstance(os_ops, OsOperations) + + filename = __file__ # currnt file + + with open(filename, 'rb') as file: # open in a binary mode + response0 = file.read() + + assert type(response0) == bytes # noqa: E721 + + response1 = os_ops.read_binary(filename, 0) + assert type(response1) == bytes # noqa: E721 + assert response1 == response0 + + response2 = os_ops.read_binary(filename, 1) + assert type(response2) == bytes # noqa: E721 + assert len(response2) < len(response1) + assert len(response2) + 1 == len(response1) + assert response2 == response1[1:] + + response3 = os_ops.read_binary(filename, len(response1)) + assert type(response3) == bytes # noqa: E721 + assert len(response3) == 0 + + response4 = os_ops.read_binary(filename, len(response2)) + assert type(response4) == bytes # noqa: E721 + assert len(response4) == 1 + assert response4[0] == response1[len(response1) - 1] + + response5 = os_ops.read_binary(filename, len(response1) + 1) + assert type(response5) == bytes # noqa: E721 + assert len(response5) == 0 + + def test_read_binary__spec__negative_offset(self, os_ops: OsOperations): + """ + Test OsOperations::read_binary with negative offset. + """ + assert isinstance(os_ops, OsOperations) + + with pytest.raises( + ValueError, + match=re.escape("Negative 'offset' is not supported.")): + os_ops.read_binary(__file__, -1) + + def test_get_file_size(self, os_ops: OsOperations): + """ + Test OsOperations::get_file_size. + """ + assert isinstance(os_ops, OsOperations) + + filename = __file__ # current file + + sz0 = os.path.getsize(filename) + assert type(sz0) == int # noqa: E721 + + sz1 = os_ops.get_file_size(filename) + assert type(sz1) == int # noqa: E721 + assert sz1 == sz0 + + def test_isfile_true(self, os_ops: OsOperations): + """ + Test isfile for an existing file. + """ + assert isinstance(os_ops, OsOperations) + + filename = __file__ + + response = os_ops.isfile(filename) + + assert response is True + + def test_isfile_false__not_exist(self, os_ops: OsOperations): + """ + Test isfile for a non-existing file. + """ + assert isinstance(os_ops, OsOperations) + + filename = os.path.join(os.path.dirname(__file__), "nonexistent_file.txt") + + response = os_ops.isfile(filename) + + assert response is False + + def test_isfile_false__directory(self, os_ops: OsOperations): + """ + Test isfile for a firectory. + """ + assert isinstance(os_ops, OsOperations) + + name = os.path.dirname(__file__) + + assert os_ops.isdir(name) + + response = os_ops.isfile(name) + + assert response is False + + def test_isdir_true(self, os_ops: OsOperations): + """ + Test isdir for an existing directory. + """ + assert isinstance(os_ops, OsOperations) + + name = os.path.dirname(__file__) + + response = os_ops.isdir(name) + + assert response is True + + def test_isdir_false__not_exist(self, os_ops: OsOperations): + """ + Test isdir for a non-existing directory. + """ + assert isinstance(os_ops, OsOperations) + + name = os.path.join(os.path.dirname(__file__), "it_is_nonexistent_directory") + + response = os_ops.isdir(name) + + assert response is False + + def test_isdir_false__file(self, os_ops: OsOperations): + """ + Test isdir for a file. + """ + assert isinstance(os_ops, OsOperations) + + name = __file__ + + assert os_ops.isfile(name) + + response = os_ops.isdir(name) + + assert response is False + + def test_cwd(self, os_ops: OsOperations): + """ + Test cwd. + """ + assert isinstance(os_ops, OsOperations) + + v = os_ops.cwd() + + assert v is not None + assert type(v) == str # noqa: E721 + assert v != "" + + class tagWriteData001: + def __init__(self, sign, source, cp_rw, cp_truncate, cp_binary, cp_data, result): + self.sign = sign + self.source = source + self.call_param__rw = cp_rw + self.call_param__truncate = cp_truncate + self.call_param__binary = cp_binary + self.call_param__data = cp_data + self.result = result + + sm_write_data001 = [ + tagWriteData001("A001", "1234567890", False, False, False, "ABC", "1234567890ABC"), + tagWriteData001("A002", b"1234567890", False, False, True, b"ABC", b"1234567890ABC"), + + tagWriteData001("B001", "1234567890", False, True, False, "ABC", "ABC"), + tagWriteData001("B002", "1234567890", False, True, False, "ABC1234567890", "ABC1234567890"), + tagWriteData001("B003", b"1234567890", False, True, True, b"ABC", b"ABC"), + tagWriteData001("B004", b"1234567890", False, True, True, b"ABC1234567890", b"ABC1234567890"), + + tagWriteData001("C001", "1234567890", True, False, False, "ABC", "1234567890ABC"), + tagWriteData001("C002", b"1234567890", True, False, True, b"ABC", b"1234567890ABC"), + + tagWriteData001("D001", "1234567890", True, True, False, "ABC", "ABC"), + tagWriteData001("D002", "1234567890", True, True, False, "ABC1234567890", "ABC1234567890"), + tagWriteData001("D003", b"1234567890", True, True, True, b"ABC", b"ABC"), + tagWriteData001("D004", b"1234567890", True, True, True, b"ABC1234567890", b"ABC1234567890"), + + tagWriteData001("E001", "\0001234567890\000", False, False, False, "\000ABC\000", "\0001234567890\000\000ABC\000"), + tagWriteData001("E002", b"\0001234567890\000", False, False, True, b"\000ABC\000", b"\0001234567890\000\000ABC\000"), + + tagWriteData001("F001", "a\nb\n", False, False, False, ["c", "d"], "a\nb\nc\nd\n"), + tagWriteData001("F002", b"a\nb\n", False, False, True, [b"c", b"d"], b"a\nb\nc\nd\n"), + + tagWriteData001("G001", "a\nb\n", False, False, False, ["c\n\n", "d\n"], "a\nb\nc\nd\n"), + tagWriteData001("G002", b"a\nb\n", False, False, True, [b"c\n\n", b"d\n"], b"a\nb\nc\nd\n"), + ] + + @pytest.fixture( + params=sm_write_data001, + ids=[x.sign for x in sm_write_data001], + ) + def write_data001(self, request): + assert isinstance(request, pytest.FixtureRequest) + assert type(request.param) == __class__.tagWriteData001 # noqa: E721 + return request.param + + def test_write(self, write_data001: tagWriteData001, os_ops: OsOperations): + assert type(write_data001) == __class__.tagWriteData001 # noqa: E721 + assert isinstance(os_ops, OsOperations) + + mode = "w+b" if write_data001.call_param__binary else "w+" + + with tempfile.NamedTemporaryFile(mode=mode, delete=True) as tmp_file: + tmp_file.write(write_data001.source) + tmp_file.flush() + + os_ops.write( + tmp_file.name, + write_data001.call_param__data, + read_and_write=write_data001.call_param__rw, + truncate=write_data001.call_param__truncate, + binary=write_data001.call_param__binary) + + tmp_file.seek(0) + + s = tmp_file.read() + + assert s == write_data001.result + + def test_touch(self, os_ops: OsOperations): + """ + Test touch for creating a new file or updating access and modification times of an existing file. + """ + assert isinstance(os_ops, OsOperations) + + filename = os_ops.mkstemp() + + # TODO: this test does not check the result of 'touch' command! + + os_ops.touch(filename) + + assert os_ops.isfile(filename) + + os_ops.remove_file(filename) + + def test_is_port_free__true(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + C_LIMIT = 128 + + ports = set(range(1024, 65535)) + assert type(ports) == set # noqa: E721 + + ok_count = 0 + no_count = 0 + + for port in ports: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + except OSError: + continue + + r = os_ops.is_port_free(port) + + if r: + ok_count += 1 + logging.info("OK. Port {} is free.".format(port)) + else: + no_count += 1 + logging.warning("NO. Port {} is not free.".format(port)) + + if ok_count == C_LIMIT: + return + + if no_count == C_LIMIT: + raise RuntimeError("To many false positive test attempts.") + + if ok_count == 0: + raise RuntimeError("No one free port was found.") + + def test_is_port_free__false(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + C_LIMIT = 10 + + ports = set(range(1024, 65535)) + assert type(ports) == set # noqa: E721 + + def LOCAL_server(s: socket.socket): + assert s is not None + assert type(s) == socket.socket # noqa: E721 + + try: + while True: + r = s.accept() + + if r is None: + break + except Exception as e: + assert e is not None + pass + + ok_count = 0 + no_count = 0 + + for port in ports: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + except OSError: + continue + + th = threading.Thread(target=LOCAL_server, args=[s]) + + s.listen(10) + + assert type(th) == threading.Thread # noqa: E721 + th.start() + + try: + r = os_ops.is_port_free(port) + finally: + s.shutdown(2) + th.join() + + if not r: + ok_count += 1 + logging.info("OK. Port {} is not free.".format(port)) + else: + no_count += 1 + logging.warning("NO. Port {} does not accept connection.".format(port)) + + if ok_count == C_LIMIT: + return + + if no_count == C_LIMIT: + raise RuntimeError("To many false positive test attempts.") + + if ok_count == 0: + raise RuntimeError("No one free port was found.") + + def test_get_tmpdir(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + dir = os_ops.get_tempdir() + assert type(dir) == str # noqa: E721 + assert os_ops.path_exists(dir) + assert os.path.exists(dir) + + file_path = os.path.join(dir, "testgres--" + uuid.uuid4().hex + ".tmp") + + os_ops.write(file_path, "1234", binary=False) + + assert os_ops.path_exists(file_path) + assert os.path.exists(file_path) + + d = os_ops.read(file_path, binary=False) + + assert d == "1234" + + os_ops.remove_file(file_path) + + assert not os_ops.path_exists(file_path) + assert not os.path.exists(file_path) + + def test_get_tmpdir__compare_with_py_info(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + actual_dir = os_ops.get_tempdir() + assert actual_dir is not None + assert type(actual_dir) == str # noqa: E721 + expected_dir = str(tempfile.tempdir) + assert actual_dir == expected_dir + + class tagData_OS_OPS__NUMS: + os_ops_descr: OsOpsDescr + nums: int + + def __init__(self, os_ops_descr: OsOpsDescr, nums: int): + assert isinstance(os_ops_descr, OsOpsDescr) + assert type(nums) == int # noqa: E721 + + self.os_ops_descr = os_ops_descr + self.nums = nums + + sm_test_exclusive_creation__mt__data = [ + tagData_OS_OPS__NUMS(OsOpsDescrs.sm_local_os_ops_descr, 100000), + tagData_OS_OPS__NUMS(OsOpsDescrs.sm_remote_os_ops_descr, 120), + ] + + @pytest.fixture( + params=sm_test_exclusive_creation__mt__data, + ids=[x.os_ops_descr.sign for x in sm_test_exclusive_creation__mt__data] + ) + def data001(self, request: pytest.FixtureRequest) -> tagData_OS_OPS__NUMS: + assert isinstance(request, pytest.FixtureRequest) + return request.param + + def test_mkdir__mt(self, data001: tagData_OS_OPS__NUMS): + assert type(data001) == __class__.tagData_OS_OPS__NUMS # noqa: E721 + + N_WORKERS = 4 + N_NUMBERS = data001.nums + assert type(N_NUMBERS) == int # noqa: E721 + + os_ops = data001.os_ops_descr.os_ops + assert isinstance(os_ops, OsOperations) + + lock_dir_prefix = "test_mkdir_mt--" + uuid.uuid4().hex + + lock_dir = os_ops.mkdtemp(prefix=lock_dir_prefix) + + logging.info("A lock file [{}] is creating ...".format(lock_dir)) + + assert os.path.exists(lock_dir) + + def MAKE_PATH(lock_dir: str, num: int) -> str: + assert type(lock_dir) == str # noqa: E721 + assert type(num) == int # noqa: E721 + return os.path.join(lock_dir, str(num) + ".lock") + + def LOCAL_WORKER(os_ops: OsOperations, + workerID: int, + lock_dir: str, + cNumbers: int, + reservedNumbers: typing.Set[int]) -> None: + assert isinstance(os_ops, OsOperations) + assert type(workerID) == int # noqa: E721 + assert type(lock_dir) == str # noqa: E721 + assert type(cNumbers) == int # noqa: E721 + assert type(reservedNumbers) == set # noqa: E721 + assert cNumbers > 0 + assert len(reservedNumbers) == 0 + + assert os.path.exists(lock_dir) + + def LOG_INFO(template: str, *args: list) -> None: + assert type(template) == str # noqa: E721 + assert type(args) == tuple # noqa: E721 + + msg = template.format(*args) + assert type(msg) == str # noqa: E721 + + logging.info("[Worker #{}] {}".format(workerID, msg)) + return + + LOG_INFO("HELLO! I am here!") + + for num in range(cNumbers): + assert not (num in reservedNumbers) + + file_path = MAKE_PATH(lock_dir, num) + + try: + os_ops.makedir(file_path) + except Exception as e: + LOG_INFO( + "Can't reserve {}. Error ({}): {}", + num, + type(e).__name__, + str(e) + ) + continue + + LOG_INFO("Number {} is reserved!", num) + assert os_ops.path_exists(file_path) + reservedNumbers.add(num) + continue + + n_total = cNumbers + n_ok = len(reservedNumbers) + assert n_ok <= n_total + + LOG_INFO("Finish! OK: {}. FAILED: {}.", n_ok, n_total - n_ok) + return + + # ----------------------- + logging.info("Worker are creating ...") + + threadPool = ThreadPoolExecutor( + max_workers=N_WORKERS, + thread_name_prefix="ex_creator" + ) + + class tadWorkerData: + future: ThreadFuture + reservedNumbers: typing.Set[int] + + workerDatas: typing.List[tadWorkerData] = list() + + nErrors = 0 + + try: + for n in range(N_WORKERS): + logging.info("worker #{} is creating ...".format(n)) + + workerDatas.append(tadWorkerData()) + + workerDatas[n].reservedNumbers = set() + + workerDatas[n].future = threadPool.submit( + LOCAL_WORKER, + os_ops, + n, + lock_dir, + N_NUMBERS, + workerDatas[n].reservedNumbers + ) + + assert workerDatas[n].future is not None + + logging.info("OK. All the workers were created!") + except Exception as e: + nErrors += 1 + logging.error("A problem is detected ({}): {}".format(type(e).__name__, str(e))) + + logging.info("Will wait for stop of all the workers...") + + nWorkers = 0 + + assert type(workerDatas) == list # noqa: E721 + + for i in range(len(workerDatas)): + worker = workerDatas[i].future + + if worker is None: + continue + + nWorkers += 1 + + assert isinstance(worker, ThreadFuture) + + try: + logging.info("Wait for worker #{}".format(i)) + worker.result() + except Exception as e: + nErrors += 1 + logging.error("Worker #{} finished with error ({}): {}".format( + i, + type(e).__name__, + str(e), + )) + continue + + assert nWorkers == N_WORKERS + + if nErrors != 0: + raise RuntimeError("Some problems were detected. Please examine the log messages.") + + logging.info("OK. Let's check worker results!") + + reservedNumbers: typing.Dict[int, int] = dict() + + for i in range(N_WORKERS): + logging.info("Worker #{} is checked ...".format(i)) + + workerNumbers = workerDatas[i].reservedNumbers + assert type(workerNumbers) == set # noqa: E721 + + for n in workerNumbers: + if n < 0 or n >= N_NUMBERS: + nErrors += 1 + logging.error("Unexpected number {}".format(n)) + continue + + if n in reservedNumbers.keys(): + nErrors += 1 + logging.error("Number {} was already reserved by worker #{}".format( + n, + reservedNumbers[n] + )) + else: + reservedNumbers[n] = i + + file_path = MAKE_PATH(lock_dir, n) + if not os_ops.path_exists(file_path): + nErrors += 1 + logging.error("File {} is not found!".format(file_path)) + continue + + continue + + logging.info("OK. Let's check reservedNumbers!") + + for n in range(N_NUMBERS): + if not (n in reservedNumbers.keys()): + nErrors += 1 + logging.error("Number {} is not reserved!".format(n)) + continue + + file_path = MAKE_PATH(lock_dir, n) + if not os_ops.path_exists(file_path): + nErrors += 1 + logging.error("File {} is not found!".format(file_path)) + continue + + # OK! + continue + + logging.info("Verification is finished! Total error count is {}.".format(nErrors)) + + if nErrors == 0: + logging.info("Root lock-directory [{}] will be deleted.".format( + lock_dir + )) + + for n in range(N_NUMBERS): + file_path = MAKE_PATH(lock_dir, n) + try: + os_ops.rmdir(file_path) + except Exception as e: + nErrors += 1 + logging.error("Cannot delete directory [{}]. Error ({}): {}".format( + file_path, + type(e).__name__, + str(e) + )) + continue + + if os_ops.path_exists(file_path): + nErrors += 1 + logging.error("Directory {} is not deleted!".format(file_path)) + continue + + if nErrors == 0: + try: + os_ops.rmdir(lock_dir) + except Exception as e: + nErrors += 1 + logging.error("Cannot delete directory [{}]. Error ({}): {}".format( + lock_dir, + type(e).__name__, + str(e) + )) + + logging.info("Test is finished! Total error count is {}.".format(nErrors)) + return diff --git a/tests/test_os_ops_local.py b/tests/test_os_ops_local.py new file mode 100644 index 00000000..f60c3fc9 --- /dev/null +++ b/tests/test_os_ops_local.py @@ -0,0 +1,60 @@ +# coding: utf-8 +from .helpers.global_data import OsOpsDescrs +from .helpers.global_data import OsOperations + +import os + +import pytest +import re + + +class TestOsOpsLocal: + @pytest.fixture + def os_ops(self): + return OsOpsDescrs.sm_local_os_ops + + def test_read__unknown_file(self, os_ops: OsOperations): + """ + Test LocalOperations::read with unknown file. + """ + + with pytest.raises(FileNotFoundError, match=re.escape("[Errno 2] No such file or directory: '/dummy'")): + os_ops.read("/dummy") + + def test_read_binary__spec__unk_file(self, os_ops: OsOperations): + """ + Test LocalOperations::read_binary with unknown file. + """ + + with pytest.raises( + FileNotFoundError, + match=re.escape("[Errno 2] No such file or directory: '/dummy'")): + os_ops.read_binary("/dummy", 0) + + def test_get_file_size__unk_file(self, os_ops: OsOperations): + """ + Test LocalOperations::get_file_size. + """ + assert isinstance(os_ops, OsOperations) + + with pytest.raises(FileNotFoundError, match=re.escape("[Errno 2] No such file or directory: '/dummy'")): + os_ops.get_file_size("/dummy") + + def test_cwd(self, os_ops: OsOperations): + """ + Test cwd. + """ + assert isinstance(os_ops, OsOperations) + + v = os_ops.cwd() + + assert v is not None + assert type(v) == str # noqa: E721 + + expectedValue = os.getcwd() + assert expectedValue is not None + assert type(expectedValue) == str # noqa: E721 + assert expectedValue != "" # research + + # Comp result + assert v == expectedValue diff --git a/tests/test_os_ops_remote.py b/tests/test_os_ops_remote.py new file mode 100755 index 00000000..65830218 --- /dev/null +++ b/tests/test_os_ops_remote.py @@ -0,0 +1,79 @@ +# coding: utf-8 + +from .helpers.global_data import OsOpsDescrs +from .helpers.global_data import OsOperations + +from testgres import ExecUtilException + +import os +import pytest + + +class TestOsOpsRemote: + @pytest.fixture + def os_ops(self): + return OsOpsDescrs.sm_remote_os_ops + + def test_rmdirs__try_to_delete_nonexist_path(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + path = "/root/test_dir" + + assert os_ops.rmdirs(path, ignore_errors=False) is True + + def test_rmdirs__try_to_delete_file(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + path = os_ops.mkstemp() + assert type(path) == str # noqa: E721 + assert os.path.exists(path) + + with pytest.raises(ExecUtilException) as x: + os_ops.rmdirs(path, ignore_errors=False) + + assert os.path.exists(path) + assert type(x.value) == ExecUtilException # noqa: E721 + assert x.value.message == "Utility exited with non-zero code (20). Error: `cannot remove '" + path + "': it is not a directory`" + assert type(x.value.error) == str # noqa: E721 + assert x.value.error.strip() == "cannot remove '" + path + "': it is not a directory" + assert type(x.value.exit_code) == int # noqa: E721 + assert x.value.exit_code == 20 + + def test_read__unknown_file(self, os_ops: OsOperations): + """ + Test RemoteOperations::read with unknown file. + """ + assert isinstance(os_ops, OsOperations) + + with pytest.raises(ExecUtilException) as x: + os_ops.read("/dummy") + + assert "Utility exited with non-zero code (1)." in str(x.value) + assert "No such file or directory" in str(x.value) + assert "/dummy" in str(x.value) + + def test_read_binary__spec__unk_file(self, os_ops: OsOperations): + """ + Test RemoteOperations::read_binary with unknown file. + """ + assert isinstance(os_ops, OsOperations) + + with pytest.raises(ExecUtilException) as x: + os_ops.read_binary("/dummy", 0) + + assert "Utility exited with non-zero code (1)." in str(x.value) + assert "No such file or directory" in str(x.value) + assert "/dummy" in str(x.value) + + def test_get_file_size__unk_file(self, os_ops: OsOperations): + """ + Test RemoteOperations::get_file_size. + """ + assert isinstance(os_ops, OsOperations) + + with pytest.raises(ExecUtilException) as x: + os_ops.get_file_size("/dummy") + + assert "Utility exited with non-zero code (1)." in str(x.value) + assert "No such file or directory" in str(x.value) + assert "/dummy" in str(x.value) diff --git a/tests/test_remote.py b/tests/test_remote.py deleted file mode 100755 index e0e4a555..00000000 --- a/tests/test_remote.py +++ /dev/null @@ -1,194 +0,0 @@ -import os - -import pytest - -from testgres import ExecUtilException -from testgres import RemoteOperations -from testgres import ConnectionParams - - -class TestRemoteOperations: - - @pytest.fixture(scope="function", autouse=True) - def setup(self): - conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1', - username=os.getenv('USER'), - ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY')) - self.operations = RemoteOperations(conn_params) - - def test_exec_command_success(self): - """ - Test exec_command for successful command execution. - """ - cmd = "python3 --version" - response = self.operations.exec_command(cmd, wait_exit=True) - - assert b'Python 3.' in response - - def test_exec_command_failure(self): - """ - Test exec_command for command execution failure. - """ - cmd = "nonexistent_command" - try: - exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) - except ExecUtilException as e: - error = e.message - assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' - - def test_is_executable_true(self): - """ - Test is_executable for an existing executable. - """ - cmd = os.getenv('PG_CONFIG') - response = self.operations.is_executable(cmd) - - assert response is True - - def test_is_executable_false(self): - """ - Test is_executable for a non-executable. - """ - cmd = "python" - response = self.operations.is_executable(cmd) - - assert response is False - - def test_makedirs_and_rmdirs_success(self): - """ - Test makedirs and rmdirs for successful directory creation and removal. - """ - cmd = "pwd" - pwd = self.operations.exec_command(cmd, wait_exit=True, encoding='utf-8').strip() - - path = "{}/test_dir".format(pwd) - - # Test makedirs - self.operations.makedirs(path) - assert self.operations.path_exists(path) - - # Test rmdirs - self.operations.rmdirs(path) - assert not self.operations.path_exists(path) - - def test_makedirs_and_rmdirs_failure(self): - """ - Test makedirs and rmdirs for directory creation and removal failure. - """ - # Try to create a directory in a read-only location - path = "/root/test_dir" - - # Test makedirs - with pytest.raises(Exception): - self.operations.makedirs(path) - - # Test rmdirs - try: - exit_status, result, error = self.operations.rmdirs(path, verbose=True) - except ExecUtilException as e: - error = e.message - assert error == b"Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" - - def test_listdir(self): - """ - Test listdir for listing directory contents. - """ - path = "/etc" - files = self.operations.listdir(path) - - assert isinstance(files, list) - - def test_path_exists_true(self): - """ - Test path_exists for an existing path. - """ - path = "/etc" - response = self.operations.path_exists(path) - - assert response is True - - def test_path_exists_false(self): - """ - Test path_exists for a non-existing path. - """ - path = "/nonexistent_path" - response = self.operations.path_exists(path) - - assert response is False - - def test_write_text_file(self): - """ - Test write for writing data to a text file. - """ - filename = "/tmp/test_file.txt" - data = "Hello, world!" - - self.operations.write(filename, data, truncate=True) - self.operations.write(filename, data) - - response = self.operations.read(filename) - - assert response == data + data - - def test_write_binary_file(self): - """ - Test write for writing data to a binary file. - """ - filename = "/tmp/test_file.bin" - data = b"\x00\x01\x02\x03" - - self.operations.write(filename, data, binary=True, truncate=True) - - response = self.operations.read(filename, binary=True) - - assert response == data - - def test_read_text_file(self): - """ - Test read for reading data from a text file. - """ - filename = "/etc/hosts" - - response = self.operations.read(filename) - - assert isinstance(response, str) - - def test_read_binary_file(self): - """ - Test read for reading data from a binary file. - """ - filename = "/usr/bin/python3" - - response = self.operations.read(filename, binary=True) - - assert isinstance(response, bytes) - - def test_touch(self): - """ - Test touch for creating a new file or updating access and modification times of an existing file. - """ - filename = "/tmp/test_file.txt" - - self.operations.touch(filename) - - assert self.operations.isfile(filename) - - def test_isfile_true(self): - """ - Test isfile for an existing file. - """ - filename = "/etc/hosts" - - response = self.operations.isfile(filename) - - assert response is True - - def test_isfile_false(self): - """ - Test isfile for a non-existing file. - """ - filename = "/nonexistent_file.txt" - - response = self.operations.isfile(filename) - - assert response is False diff --git a/tests/test_simple.py b/tests/test_simple.py deleted file mode 100644 index ba23c3ed..00000000 --- a/tests/test_simple.py +++ /dev/null @@ -1,1063 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -import os -import re -import subprocess -import tempfile -import testgres -import time -import six -import unittest -import psutil - -import logging.config - -from contextlib import contextmanager -from shutil import rmtree - -from testgres import \ - InitNodeException, \ - StartNodeException, \ - ExecUtilException, \ - BackupException, \ - QueryException, \ - TimeoutException, \ - TestgresException - -from testgres import \ - TestgresConfig, \ - configure_testgres, \ - scoped_config, \ - pop_config - -from testgres import \ - NodeStatus, \ - ProcessType, \ - IsolationLevel, \ - get_new_node - -from testgres import \ - get_bin_path, \ - get_pg_config, \ - get_pg_version - -from testgres import \ - First, \ - Any - -# NOTE: those are ugly imports -from testgres import bound_ports -from testgres.utils import PgVer, parse_pg_version -from testgres.node import ProcessProxy - - -def pg_version_ge(version): - cur_ver = PgVer(get_pg_version()) - min_ver = PgVer(version) - return cur_ver >= min_ver - - -def util_exists(util): - def good_properties(f): - return (os.path.exists(f) and # noqa: W504 - os.path.isfile(f) and # noqa: W504 - os.access(f, os.X_OK)) # yapf: disable - - # try to resolve it - if good_properties(get_bin_path(util)): - return True - - # check if util is in PATH - for path in os.environ["PATH"].split(os.pathsep): - if good_properties(os.path.join(path, util)): - return True - - -def rm_carriage_returns(out): - """ - In Windows we have additional '\r' symbols in output. - Let's get rid of them. - """ - if os.name == 'nt': - if isinstance(out, (int, float, complex)): - return out - elif isinstance(out, tuple): - return tuple(rm_carriage_returns(item) for item in out) - elif isinstance(out, bytes): - return out.replace(b'\r', b'') - else: - return out.replace('\r', '') - else: - return out - - -@contextmanager -def removing(f): - try: - yield f - finally: - if os.path.isfile(f): - os.remove(f) - elif os.path.isdir(f): - rmtree(f, ignore_errors=True) - - -class TestgresTests(unittest.TestCase): - def test_node_repr(self): - with get_new_node() as node: - pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" - self.assertIsNotNone(re.match(pattern, str(node))) - - def test_custom_init(self): - with get_new_node() as node: - # enable page checksums - node.init(initdb_params=['-k']).start() - - with get_new_node() as node: - node.init( - allow_streaming=True, - initdb_params=['--auth-local=reject', '--auth-host=reject']) - - hba_file = os.path.join(node.data_dir, 'pg_hba.conf') - with open(hba_file, 'r') as conf: - lines = conf.readlines() - - # check number of lines - self.assertGreaterEqual(len(lines), 6) - - # there should be no trust entries at all - self.assertFalse(any('trust' in s for s in lines)) - - def test_double_init(self): - with get_new_node().init() as node: - # can't initialize node more than once - with self.assertRaises(InitNodeException): - node.init() - - def test_init_after_cleanup(self): - with get_new_node() as node: - node.init().start().execute('select 1') - node.cleanup() - node.init().start().execute('select 1') - - @unittest.skipUnless(util_exists('pg_resetwal.exe' if os.name == 'nt' else 'pg_resetwal'), 'pgbench might be missing') - @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') - def test_init_unique_system_id(self): - # this function exists in PostgreSQL 9.6+ - query = 'select system_identifier from pg_control_system()' - - with scoped_config(cache_initdb=False): - with get_new_node().init().start() as node0: - id0 = node0.execute(query)[0] - - with scoped_config(cache_initdb=True, - cached_initdb_unique=True) as config: - - self.assertTrue(config.cache_initdb) - self.assertTrue(config.cached_initdb_unique) - - # spawn two nodes; ids must be different - with get_new_node().init().start() as node1, \ - get_new_node().init().start() as node2: - - id1 = node1.execute(query)[0] - id2 = node2.execute(query)[0] - - # ids must increase - self.assertGreater(id1, id0) - self.assertGreater(id2, id1) - - def test_node_exit(self): - base_dir = None - - with self.assertRaises(QueryException): - with get_new_node().init() as node: - base_dir = node.base_dir - node.safe_psql('select 1') - - # we should save the DB for "debugging" - self.assertTrue(os.path.exists(base_dir)) - rmtree(base_dir, ignore_errors=True) - - with get_new_node().init() as node: - base_dir = node.base_dir - - # should have been removed by default - self.assertFalse(os.path.exists(base_dir)) - - def test_double_start(self): - with get_new_node().init().start() as node: - # can't start node more than once - node.start() - self.assertTrue(node.is_started) - - def test_uninitialized_start(self): - with get_new_node() as node: - # node is not initialized yet - with self.assertRaises(StartNodeException): - node.start() - - def test_restart(self): - with get_new_node() as node: - node.init().start() - - # restart, ok - res = node.execute('select 1') - self.assertEqual(res, [(1, )]) - node.restart() - res = node.execute('select 2') - self.assertEqual(res, [(2, )]) - - # restart, fail - with self.assertRaises(StartNodeException): - node.append_conf('pg_hba.conf', 'DUMMY') - node.restart() - - def test_reload(self): - with get_new_node() as node: - node.init().start() - - # change client_min_messages and save old value - cmm_old = node.execute('show client_min_messages') - node.append_conf(client_min_messages='DEBUG1') - - # reload config - node.reload() - - # check new value - cmm_new = node.execute('show client_min_messages') - self.assertEqual('debug1', cmm_new[0][0].lower()) - self.assertNotEqual(cmm_old, cmm_new) - - def test_pg_ctl(self): - with get_new_node() as node: - node.init().start() - - status = node.pg_ctl(['status']) - self.assertTrue('PID' in status) - - def test_status(self): - self.assertTrue(NodeStatus.Running) - self.assertFalse(NodeStatus.Stopped) - self.assertFalse(NodeStatus.Uninitialized) - - # check statuses after each operation - with get_new_node() as node: - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Uninitialized) - - node.init() - - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Stopped) - - node.start() - - self.assertNotEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Running) - - node.stop() - - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Stopped) - - node.cleanup() - - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Uninitialized) - - def test_psql(self): - with get_new_node().init().start() as node: - - # check returned values (1 arg) - res = node.psql('select 1') - self.assertEqual(rm_carriage_returns(res), (0, b'1\n', b'')) - - # check returned values (2 args) - res = node.psql('postgres', 'select 2') - self.assertEqual(rm_carriage_returns(res), (0, b'2\n', b'')) - - # check returned values (named) - res = node.psql(query='select 3', dbname='postgres') - self.assertEqual(rm_carriage_returns(res), (0, b'3\n', b'')) - - # check returned values (1 arg) - res = node.safe_psql('select 4') - self.assertEqual(rm_carriage_returns(res), b'4\n') - - # check returned values (2 args) - res = node.safe_psql('postgres', 'select 5') - self.assertEqual(rm_carriage_returns(res), b'5\n') - - # check returned values (named) - res = node.safe_psql(query='select 6', dbname='postgres') - self.assertEqual(rm_carriage_returns(res), b'6\n') - - # check feeding input - node.safe_psql('create table horns (w int)') - node.safe_psql('copy horns from stdin (format csv)', - input=b"1\n2\n3\n\\.\n") - _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(rm_carriage_returns(_sum), b'6\n') - - # check psql's default args, fails - with self.assertRaises(QueryException): - node.psql() - - node.stop() - - # check psql on stopped node, fails - with self.assertRaises(QueryException): - node.safe_psql('select 1') - - def test_transactions(self): - with get_new_node().init().start() as node: - - with node.connect() as con: - con.begin() - con.execute('create table test(val int)') - con.execute('insert into test values (1)') - con.commit() - - con.begin() - con.execute('insert into test values (2)') - res = con.execute('select * from test order by val asc') - self.assertListEqual(res, [(1, ), (2, )]) - con.rollback() - - con.begin() - res = con.execute('select * from test') - self.assertListEqual(res, [(1, )]) - con.rollback() - - con.begin() - con.execute('drop table test') - con.commit() - - def test_control_data(self): - with get_new_node() as node: - - # node is not initialized yet - with self.assertRaises(ExecUtilException): - node.get_control_data() - - node.init() - data = node.get_control_data() - - # check returned dict - self.assertIsNotNone(data) - self.assertTrue(any('pg_control' in s for s in data.keys())) - - def test_backup_simple(self): - with get_new_node() as master: - - # enable streaming for backups - master.init(allow_streaming=True) - - # node must be running - with self.assertRaises(BackupException): - master.backup() - - # it's time to start node - master.start() - - # fill node with some data - master.psql('create table test as select generate_series(1, 4) i') - - with master.backup(xlog_method='stream') as backup: - with backup.spawn_primary().start() as slave: - res = slave.execute('select * from test order by i asc') - self.assertListEqual(res, [(1, ), (2, ), (3, ), (4, )]) - - def test_backup_multiple(self): - with get_new_node() as node: - node.init(allow_streaming=True).start() - - with node.backup(xlog_method='fetch') as backup1, \ - node.backup(xlog_method='fetch') as backup2: - - self.assertNotEqual(backup1.base_dir, backup2.base_dir) - - with node.backup(xlog_method='fetch') as backup: - with backup.spawn_primary('node1', destroy=False) as node1, \ - backup.spawn_primary('node2', destroy=False) as node2: - - self.assertNotEqual(node1.base_dir, node2.base_dir) - - def test_backup_exhaust(self): - with get_new_node() as node: - node.init(allow_streaming=True).start() - - with node.backup(xlog_method='fetch') as backup: - - # exhaust backup by creating new node - with backup.spawn_primary(): - pass - - # now let's try to create one more node - with self.assertRaises(BackupException): - backup.spawn_primary() - - def test_backup_wrong_xlog_method(self): - with get_new_node() as node: - node.init(allow_streaming=True).start() - - with self.assertRaises(BackupException, - msg='Invalid xlog_method "wrong"'): - node.backup(xlog_method='wrong') - - def test_pg_ctl_wait_option(self): - with get_new_node() as node: - node.init().start(wait=False) - while True: - try: - node.stop(wait=False) - break - except ExecUtilException: - # it's ok to get this exception here since node - # could be not started yet - pass - - def test_replicate(self): - with get_new_node() as node: - node.init(allow_streaming=True).start() - - with node.replicate().start() as replica: - res = replica.execute('select 1') - self.assertListEqual(res, [(1, )]) - - node.execute('create table test (val int)', commit=True) - - replica.catchup() - - res = node.execute('select * from test') - self.assertListEqual(res, []) - - @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') - def test_synchronous_replication(self): - with get_new_node() as master: - old_version = not pg_version_ge('9.6') - - master.init(allow_streaming=True).start() - - if not old_version: - master.append_conf('synchronous_commit = remote_apply') - - # create standby - with master.replicate() as standby1, master.replicate() as standby2: - standby1.start() - standby2.start() - - # check formatting - self.assertEqual( - '1 ("{}", "{}")'.format(standby1.name, standby2.name), - str(First(1, (standby1, standby2)))) # yapf: disable - self.assertEqual( - 'ANY 1 ("{}", "{}")'.format(standby1.name, standby2.name), - str(Any(1, (standby1, standby2)))) # yapf: disable - - # set synchronous_standby_names - master.set_synchronous_standbys(First(2, [standby1, standby2])) - master.restart() - - # the following part of the test is only applicable to newer - # versions of PostgresQL - if not old_version: - master.safe_psql('create table abc(a int)') - - # Create a large transaction that will take some time to apply - # on standby to check that it applies synchronously - # (If set synchronous_commit to 'on' or other lower level then - # standby most likely won't catchup so fast and test will fail) - master.safe_psql( - 'insert into abc select generate_series(1, 1000000)') - res = standby1.safe_psql('select count(*) from abc') - self.assertEqual(rm_carriage_returns(res), b'1000000\n') - - @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') - def test_logical_replication(self): - with get_new_node() as node1, get_new_node() as node2: - node1.init(allow_logical=True) - node1.start() - node2.init().start() - - create_table = 'create table test (a int, b int)' - node1.safe_psql(create_table) - node2.safe_psql(create_table) - - # create publication / create subscription - pub = node1.publish('mypub') - sub = node2.subscribe(pub, 'mysub') - - node1.safe_psql('insert into test values (1, 1), (2, 2)') - - # wait until changes apply on subscriber and check them - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [(1, 1), (2, 2)]) - - # disable and put some new data - sub.disable() - node1.safe_psql('insert into test values (3, 3)') - - # enable and ensure that data successfully transfered - sub.enable() - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [(1, 1), (2, 2), (3, 3)]) - - # Add new tables. Since we added "all tables" to publication - # (default behaviour of publish() method) we don't need - # to explicitely perform pub.add_tables() - create_table = 'create table test2 (c char)' - node1.safe_psql(create_table) - node2.safe_psql(create_table) - sub.refresh() - - # put new data - node1.safe_psql('insert into test2 values (\'a\'), (\'b\')') - sub.catchup() - res = node2.execute('select * from test2') - self.assertListEqual(res, [('a', ), ('b', )]) - - # drop subscription - sub.drop() - pub.drop() - - # create new publication and subscription for specific table - # (ommitting copying data as it's already done) - pub = node1.publish('newpub', tables=['test']) - sub = node2.subscribe(pub, 'newsub', copy_data=False) - - node1.safe_psql('insert into test values (4, 4)') - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [(1, 1), (2, 2), (3, 3), (4, 4)]) - - # explicitely add table - with self.assertRaises(ValueError): - pub.add_tables([]) # fail - pub.add_tables(['test2']) - node1.safe_psql('insert into test2 values (\'c\')') - sub.catchup() - res = node2.execute('select * from test2') - self.assertListEqual(res, [('a', ), ('b', )]) - - @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') - def test_logical_catchup(self): - """ Runs catchup for 100 times to be sure that it is consistent """ - with get_new_node() as node1, get_new_node() as node2: - node1.init(allow_logical=True) - node1.start() - node2.init().start() - - create_table = 'create table test (key int primary key, val int); ' - node1.safe_psql(create_table) - node1.safe_psql('alter table test replica identity default') - node2.safe_psql(create_table) - - # create publication / create subscription - sub = node2.subscribe(node1.publish('mypub'), 'mysub') - - for i in range(0, 100): - node1.execute('insert into test values ({0}, {0})'.format(i)) - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [( - i, - i, - )]) - node1.execute('delete from test') - - @unittest.skipIf(pg_version_ge('10'), 'requires <10') - def test_logical_replication_fail(self): - with get_new_node() as node: - with self.assertRaises(InitNodeException): - node.init(allow_logical=True) - - def test_replication_slots(self): - with get_new_node() as node: - node.init(allow_streaming=True).start() - - with node.replicate(slot='slot1').start() as replica: - replica.execute('select 1') - - # cannot create new slot with the same name - with self.assertRaises(TestgresException): - node.replicate(slot='slot1') - - def test_incorrect_catchup(self): - with get_new_node() as node: - node.init(allow_streaming=True).start() - - # node has no master, can't catch up - with self.assertRaises(TestgresException): - node.catchup() - - def test_promotion(self): - with get_new_node() as master: - master.init().start() - master.safe_psql('create table abc(id serial)') - - with master.replicate().start() as replica: - master.stop() - replica.promote() - - # make standby becomes writable master - replica.safe_psql('insert into abc values (1)') - res = replica.safe_psql('select * from abc') - self.assertEqual(rm_carriage_returns(res), b'1\n') - - def test_dump(self): - query_create = 'create table test as select generate_series(1, 2) as val' - query_select = 'select * from test order by val asc' - - with get_new_node().init().start() as node1: - - node1.execute(query_create) - for format in ['plain', 'custom', 'directory', 'tar']: - with removing(node1.dump(format=format)) as dump: - with get_new_node().init().start() as node3: - if format == 'directory': - self.assertTrue(os.path.isdir(dump)) - else: - self.assertTrue(os.path.isfile(dump)) - # restore dump - node3.restore(filename=dump) - res = node3.execute(query_select) - self.assertListEqual(res, [(1, ), (2, )]) - - def test_users(self): - with get_new_node().init().start() as node: - node.psql('create role test_user login') - value = node.safe_psql('select 1', username='test_user') - value = rm_carriage_returns(value) - self.assertEqual(value, b'1\n') - - def test_poll_query_until(self): - with get_new_node() as node: - node.init().start() - - get_time = 'select extract(epoch from now())' - check_time = 'select extract(epoch from now()) - {} >= 5' - - start_time = node.execute(get_time)[0][0] - node.poll_query_until(query=check_time.format(start_time)) - end_time = node.execute(get_time)[0][0] - - self.assertTrue(end_time - start_time >= 5) - - # check 0 columns - with self.assertRaises(QueryException): - node.poll_query_until( - query='select from pg_catalog.pg_class limit 1') - - # check None, fail - with self.assertRaises(QueryException): - node.poll_query_until(query='create table abc (val int)') - - # check None, ok - node.poll_query_until(query='create table def()', - expected=None) # returns nothing - - # check 0 rows equivalent to expected=None - node.poll_query_until( - query='select * from pg_catalog.pg_class where true = false', - expected=None) - - # check arbitrary expected value, fail - with self.assertRaises(TimeoutException): - node.poll_query_until(query='select 3', - expected=1, - max_attempts=3, - sleep_time=0.01) - - # check arbitrary expected value, ok - node.poll_query_until(query='select 2', expected=2) - - # check timeout - with self.assertRaises(TimeoutException): - node.poll_query_until(query='select 1 > 2', - max_attempts=3, - sleep_time=0.01) - - # check ProgrammingError, fail - with self.assertRaises(testgres.ProgrammingError): - node.poll_query_until(query='dummy1') - - # check ProgrammingError, ok - with self.assertRaises(TimeoutException): - node.poll_query_until(query='dummy2', - max_attempts=3, - sleep_time=0.01, - suppress={testgres.ProgrammingError}) - - # check 1 arg, ok - node.poll_query_until('select true') - - def test_logging(self): - logfile = tempfile.NamedTemporaryFile('w', delete=True) - - log_conf = { - 'version': 1, - 'handlers': { - 'file': { - 'class': 'logging.FileHandler', - 'filename': logfile.name, - 'formatter': 'base_format', - 'level': logging.DEBUG, - }, - }, - 'formatters': { - 'base_format': { - 'format': '%(node)-5s: %(message)s', - }, - }, - 'root': { - 'handlers': ('file', ), - 'level': 'DEBUG', - }, - } - - logging.config.dictConfig(log_conf) - - with scoped_config(use_python_logging=True): - node_name = 'master' - - with get_new_node(name=node_name) as master: - master.init().start() - - # execute a dummy query a few times - for i in range(20): - master.execute('select 1') - time.sleep(0.01) - - # let logging worker do the job - time.sleep(0.1) - - # check that master's port is found - with open(logfile.name, 'r') as log: - lines = log.readlines() - self.assertTrue(any(node_name in s for s in lines)) - - # test logger after stop/start/restart - master.stop() - master.start() - master.restart() - self.assertTrue(master._logger.is_alive()) - - @unittest.skipUnless(util_exists('pgbench.exe' if os.name == 'nt' else 'pgbench'), 'pgbench might be missing') - def test_pgbench(self): - with get_new_node().init().start() as node: - - # initialize pgbench DB and run benchmarks - node.pgbench_init(scale=2, foreign_keys=True, - options=['-q']).pgbench_run(time=2) - - # run TPC-B benchmark - proc = node.pgbench(stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - options=['-T3']) - - out, _ = proc.communicate() - out = out.decode('utf-8') - - proc.stdout.close() - - self.assertTrue('tps' in out) - - def test_pg_config(self): - # check same instances - a = get_pg_config() - b = get_pg_config() - self.assertEqual(id(a), id(b)) - - # save right before config change - c1 = get_pg_config() - - # modify setting for this scope - with scoped_config(cache_pg_config=False) as config: - - # sanity check for value - self.assertFalse(config.cache_pg_config) - - # save right after config change - c2 = get_pg_config() - - # check different instances after config change - self.assertNotEqual(id(c1), id(c2)) - - # check different instances - a = get_pg_config() - b = get_pg_config() - self.assertNotEqual(id(a), id(b)) - - def test_config_stack(self): - # no such option - with self.assertRaises(TypeError): - configure_testgres(dummy=True) - - # we have only 1 config in stack - with self.assertRaises(IndexError): - pop_config() - - d0 = TestgresConfig.cached_initdb_dir - d1 = 'dummy_abc' - d2 = 'dummy_def' - - with scoped_config(cached_initdb_dir=d1) as c1: - self.assertEqual(c1.cached_initdb_dir, d1) - - with scoped_config(cached_initdb_dir=d2) as c2: - - stack_size = len(testgres.config.config_stack) - - # try to break a stack - with self.assertRaises(TypeError): - with scoped_config(dummy=True): - pass - - self.assertEqual(c2.cached_initdb_dir, d2) - self.assertEqual(len(testgres.config.config_stack), stack_size) - - self.assertEqual(c1.cached_initdb_dir, d1) - - self.assertEqual(TestgresConfig.cached_initdb_dir, d0) - - def test_unix_sockets(self): - with get_new_node() as node: - node.init(unix_sockets=False, allow_streaming=True) - node.start() - - node.execute('select 1') - node.safe_psql('select 1') - - with node.replicate().start() as r: - r.execute('select 1') - r.safe_psql('select 1') - - def test_auto_name(self): - with get_new_node().init(allow_streaming=True).start() as m: - with m.replicate().start() as r: - - # check that nodes are running - self.assertTrue(m.status()) - self.assertTrue(r.status()) - - # check their names - self.assertNotEqual(m.name, r.name) - self.assertTrue('testgres' in m.name) - self.assertTrue('testgres' in r.name) - - def test_file_tail(self): - from testgres.utils import file_tail - - s1 = "the quick brown fox jumped over that lazy dog\n" - s2 = "abc\n" - s3 = "def\n" - - with tempfile.NamedTemporaryFile(mode='r+', delete=True) as f: - sz = 0 - while sz < 3 * 8192: - sz += len(s1) - f.write(s1) - f.write(s2) - f.write(s3) - - f.seek(0) - lines = file_tail(f, 3) - self.assertEqual(lines[0], s1) - self.assertEqual(lines[1], s2) - self.assertEqual(lines[2], s3) - - f.seek(0) - lines = file_tail(f, 1) - self.assertEqual(lines[0], s3) - - def test_isolation_levels(self): - with get_new_node().init().start() as node: - with node.connect() as con: - # string levels - con.begin('Read Uncommitted').commit() - con.begin('Read Committed').commit() - con.begin('Repeatable Read').commit() - con.begin('Serializable').commit() - - # enum levels - con.begin(IsolationLevel.ReadUncommitted).commit() - con.begin(IsolationLevel.ReadCommitted).commit() - con.begin(IsolationLevel.RepeatableRead).commit() - con.begin(IsolationLevel.Serializable).commit() - - # check wrong level - with self.assertRaises(QueryException): - con.begin('Garbage').commit() - - def test_ports_management(self): - # check that no ports have been bound yet - self.assertEqual(len(bound_ports), 0) - - with get_new_node() as node: - # check that we've just bound a port - self.assertEqual(len(bound_ports), 1) - - # check that bound_ports contains our port - port_1 = list(bound_ports)[0] - port_2 = node.port - self.assertEqual(port_1, port_2) - - # check that port has been freed successfully - self.assertEqual(len(bound_ports), 0) - - def test_exceptions(self): - str(StartNodeException('msg', [('file', 'lines')])) - str(ExecUtilException('msg', 'cmd', 1, 'out')) - str(QueryException('msg', 'query')) - - def test_version_management(self): - a = PgVer('10.0') - b = PgVer('10') - c = PgVer('9.6.5') - d = PgVer('15.0') - e = PgVer('15rc1') - f = PgVer('15beta4') - h = PgVer('15.3biha') - i = PgVer('15.3') - g = PgVer('15.3.1bihabeta1') - k = PgVer('15.3.1') - - self.assertTrue(a == b) - self.assertTrue(b > c) - self.assertTrue(a > c) - self.assertTrue(d > e) - self.assertTrue(e > f) - self.assertTrue(d > f) - self.assertTrue(h > f) - self.assertTrue(h == i) - self.assertTrue(g == k) - self.assertTrue(g > h) - - version = get_pg_version() - with get_new_node() as node: - self.assertTrue(isinstance(version, six.string_types)) - self.assertTrue(isinstance(node.version, PgVer)) - self.assertEqual(node.version, PgVer(version)) - - def test_child_pids(self): - master_processes = [ - ProcessType.AutovacuumLauncher, - ProcessType.BackgroundWriter, - ProcessType.Checkpointer, - ProcessType.StatsCollector, - ProcessType.WalSender, - ProcessType.WalWriter, - ] - - if pg_version_ge('10'): - master_processes.append(ProcessType.LogicalReplicationLauncher) - - repl_processes = [ - ProcessType.Startup, - ProcessType.WalReceiver, - ] - - with get_new_node().init().start() as master: - - # master node doesn't have a source walsender! - with self.assertRaises(TestgresException): - master.source_walsender - - with master.connect() as con: - self.assertGreater(con.pid, 0) - - with master.replicate().start() as replica: - - # test __str__ method - str(master.child_processes[0]) - - master_pids = master.auxiliary_pids - for ptype in master_processes: - self.assertIn(ptype, master_pids) - - replica_pids = replica.auxiliary_pids - for ptype in repl_processes: - self.assertIn(ptype, replica_pids) - - # there should be exactly 1 source walsender for replica - self.assertEqual(len(master_pids[ProcessType.WalSender]), 1) - pid1 = master_pids[ProcessType.WalSender][0] - pid2 = replica.source_walsender.pid - self.assertEqual(pid1, pid2) - - replica.stop() - - # there should be no walsender after we've stopped replica - with self.assertRaises(TestgresException): - replica.source_walsender - - def test_child_process_dies(self): - # test for FileNotFound exception during child_processes() function - cmd = ["timeout", "60"] if os.name == 'nt' else ["sleep", "60"] - - with subprocess.Popen(cmd, shell=True) as process: # shell=True might be needed on Windows - self.assertEqual(process.poll(), None) - # collect list of processes currently running - children = psutil.Process(os.getpid()).children() - # kill a process, so received children dictionary becomes invalid - process.kill() - process.wait() - # try to handle children list -- missing processes will have ptype "ProcessType.Unknown" - [ProcessProxy(p) for p in children] - - def test_upgrade_node(self): - old_bin_dir = os.path.dirname(get_bin_path("pg_config")) - new_bin_dir = os.path.dirname(get_bin_path("pg_config")) - node_old = get_new_node(prefix='node_old', bin_dir=old_bin_dir) - node_old.init() - node_old.start() - node_old.stop() - node_new = get_new_node(prefix='node_new', bin_dir=new_bin_dir) - node_new.init(cached=False) - res = node_new.upgrade_from(old_node=node_old) - node_new.start() - self.assertTrue(b'Upgrade Complete' in res) - - def test_parse_pg_version(self): - # Linux Mint - assert parse_pg_version("postgres (PostgreSQL) 15.5 (Ubuntu 15.5-1.pgdg22.04+1)") == "15.5" - # Linux Ubuntu - assert parse_pg_version("postgres (PostgreSQL) 12.17") == "12.17" - # Windows - assert parse_pg_version("postgres (PostgreSQL) 11.4") == "11.4" - # Macos - assert parse_pg_version("postgres (PostgreSQL) 14.9 (Homebrew)") == "14.9" - - def test_the_same_port(self): - with get_new_node() as node: - node.init().start() - - with get_new_node() as node2: - node2.port = node.port - node2.init().start() - - -if __name__ == '__main__': - if os.environ.get('ALT_CONFIG'): - suite = unittest.TestSuite() - - # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG) - suite.addTest(TestgresTests('test_pg_config')) - suite.addTest(TestgresTests('test_pg_ctl')) - suite.addTest(TestgresTests('test_psql')) - suite.addTest(TestgresTests('test_replicate')) - - print('Running tests for alternative config:') - for t in suite: - print(t) - print() - - runner = unittest.TextTestRunner() - runner.run(suite) - else: - unittest.main() diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py deleted file mode 100755 index d51820ba..00000000 --- a/tests/test_simple_remote.py +++ /dev/null @@ -1,995 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 - -import os -import re -import subprocess -import tempfile - -import testgres -import time -import six -import unittest -import psutil - -import logging.config - -from contextlib import contextmanager - -from testgres.exceptions import \ - InitNodeException, \ - StartNodeException, \ - ExecUtilException, \ - BackupException, \ - QueryException, \ - TimeoutException, \ - TestgresException - -from testgres.config import \ - TestgresConfig, \ - configure_testgres, \ - scoped_config, \ - pop_config, testgres_config - -from testgres import \ - NodeStatus, \ - ProcessType, \ - IsolationLevel, \ - get_remote_node, \ - RemoteOperations - -from testgres import \ - get_bin_path, \ - get_pg_config, \ - get_pg_version - -from testgres import \ - First, \ - Any - -# NOTE: those are ugly imports -from testgres import bound_ports -from testgres.utils import PgVer -from testgres.node import ProcessProxy, ConnectionParams - -conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '127.0.0.1', - username=os.getenv('USER'), - ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY')) -os_ops = RemoteOperations(conn_params) -testgres_config.set_os_ops(os_ops=os_ops) - - -def pg_version_ge(version): - cur_ver = PgVer(get_pg_version()) - min_ver = PgVer(version) - return cur_ver >= min_ver - - -def util_exists(util): - def good_properties(f): - return (os_ops.path_exists(f) and # noqa: W504 - os_ops.isfile(f) and # noqa: W504 - os_ops.is_executable(f)) # yapf: disable - - # try to resolve it - if good_properties(get_bin_path(util)): - return True - - # check if util is in PATH - for path in os_ops.environ("PATH").split(os_ops.pathsep): - if good_properties(os.path.join(path, util)): - return True - - -@contextmanager -def removing(f): - try: - yield f - finally: - if os_ops.isfile(f): - os_ops.remove_file(f) - - elif os_ops.isdir(f): - os_ops.rmdirs(f, ignore_errors=True) - - -class TestgresRemoteTests(unittest.TestCase): - - def test_node_repr(self): - with get_remote_node(conn_params=conn_params) as node: - pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" - self.assertIsNotNone(re.match(pattern, str(node))) - - def test_custom_init(self): - with get_remote_node(conn_params=conn_params) as node: - # enable page checksums - node.init(initdb_params=['-k']).start() - - with get_remote_node(conn_params=conn_params) as node: - node.init( - allow_streaming=True, - initdb_params=['--auth-local=reject', '--auth-host=reject']) - - hba_file = os.path.join(node.data_dir, 'pg_hba.conf') - lines = os_ops.readlines(hba_file) - - # check number of lines - self.assertGreaterEqual(len(lines), 6) - - # there should be no trust entries at all - self.assertFalse(any('trust' in s for s in lines)) - - def test_double_init(self): - with get_remote_node(conn_params=conn_params).init() as node: - # can't initialize node more than once - with self.assertRaises(InitNodeException): - node.init() - - def test_init_after_cleanup(self): - with get_remote_node(conn_params=conn_params) as node: - node.init().start().execute('select 1') - node.cleanup() - node.init().start().execute('select 1') - - @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') - @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') - def test_init_unique_system_id(self): - # this function exists in PostgreSQL 9.6+ - query = 'select system_identifier from pg_control_system()' - - with scoped_config(cache_initdb=False): - with get_remote_node(conn_params=conn_params).init().start() as node0: - id0 = node0.execute(query)[0] - - with scoped_config(cache_initdb=True, - cached_initdb_unique=True) as config: - self.assertTrue(config.cache_initdb) - self.assertTrue(config.cached_initdb_unique) - - # spawn two nodes; ids must be different - with get_remote_node(conn_params=conn_params).init().start() as node1, \ - get_remote_node(conn_params=conn_params).init().start() as node2: - id1 = node1.execute(query)[0] - id2 = node2.execute(query)[0] - - # ids must increase - self.assertGreater(id1, id0) - self.assertGreater(id2, id1) - - def test_node_exit(self): - with self.assertRaises(QueryException): - with get_remote_node(conn_params=conn_params).init() as node: - base_dir = node.base_dir - node.safe_psql('select 1') - - # we should save the DB for "debugging" - self.assertTrue(os_ops.path_exists(base_dir)) - os_ops.rmdirs(base_dir, ignore_errors=True) - - with get_remote_node(conn_params=conn_params).init() as node: - base_dir = node.base_dir - - # should have been removed by default - self.assertFalse(os_ops.path_exists(base_dir)) - - def test_double_start(self): - with get_remote_node(conn_params=conn_params).init().start() as node: - # can't start node more than once - node.start() - self.assertTrue(node.is_started) - - def test_uninitialized_start(self): - with get_remote_node(conn_params=conn_params) as node: - # node is not initialized yet - with self.assertRaises(StartNodeException): - node.start() - - def test_restart(self): - with get_remote_node(conn_params=conn_params) as node: - node.init().start() - - # restart, ok - res = node.execute('select 1') - self.assertEqual(res, [(1,)]) - node.restart() - res = node.execute('select 2') - self.assertEqual(res, [(2,)]) - - # restart, fail - with self.assertRaises(StartNodeException): - node.append_conf('pg_hba.conf', 'DUMMY') - node.restart() - - def test_reload(self): - with get_remote_node(conn_params=conn_params) as node: - node.init().start() - - # change client_min_messages and save old value - cmm_old = node.execute('show client_min_messages') - node.append_conf(client_min_messages='DEBUG1') - - # reload config - node.reload() - - # check new value - cmm_new = node.execute('show client_min_messages') - self.assertEqual('debug1', cmm_new[0][0].lower()) - self.assertNotEqual(cmm_old, cmm_new) - - def test_pg_ctl(self): - with get_remote_node(conn_params=conn_params) as node: - node.init().start() - - status = node.pg_ctl(['status']) - self.assertTrue('PID' in status) - - def test_status(self): - self.assertTrue(NodeStatus.Running) - self.assertFalse(NodeStatus.Stopped) - self.assertFalse(NodeStatus.Uninitialized) - - # check statuses after each operation - with get_remote_node(conn_params=conn_params) as node: - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Uninitialized) - - node.init() - - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Stopped) - - node.start() - - self.assertNotEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Running) - - node.stop() - - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Stopped) - - node.cleanup() - - self.assertEqual(node.pid, 0) - self.assertEqual(node.status(), NodeStatus.Uninitialized) - - def test_psql(self): - with get_remote_node(conn_params=conn_params).init().start() as node: - # check returned values (1 arg) - res = node.psql('select 1') - self.assertEqual(res, (0, b'1\n', b'')) - - # check returned values (2 args) - res = node.psql('postgres', 'select 2') - self.assertEqual(res, (0, b'2\n', b'')) - - # check returned values (named) - res = node.psql(query='select 3', dbname='postgres') - self.assertEqual(res, (0, b'3\n', b'')) - - # check returned values (1 arg) - res = node.safe_psql('select 4') - self.assertEqual(res, b'4\n') - - # check returned values (2 args) - res = node.safe_psql('postgres', 'select 5') - self.assertEqual(res, b'5\n') - - # check returned values (named) - res = node.safe_psql(query='select 6', dbname='postgres') - self.assertEqual(res, b'6\n') - - # check feeding input - node.safe_psql('create table horns (w int)') - node.safe_psql('copy horns from stdin (format csv)', - input=b"1\n2\n3\n\\.\n") - _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(_sum, b'6\n') - - # check psql's default args, fails - with self.assertRaises(QueryException): - node.psql() - - node.stop() - - # check psql on stopped node, fails - with self.assertRaises(QueryException): - node.safe_psql('select 1') - - def test_transactions(self): - with get_remote_node(conn_params=conn_params).init().start() as node: - with node.connect() as con: - con.begin() - con.execute('create table test(val int)') - con.execute('insert into test values (1)') - con.commit() - - con.begin() - con.execute('insert into test values (2)') - res = con.execute('select * from test order by val asc') - self.assertListEqual(res, [(1,), (2,)]) - con.rollback() - - con.begin() - res = con.execute('select * from test') - self.assertListEqual(res, [(1,)]) - con.rollback() - - con.begin() - con.execute('drop table test') - con.commit() - - def test_control_data(self): - with get_remote_node(conn_params=conn_params) as node: - # node is not initialized yet - with self.assertRaises(ExecUtilException): - node.get_control_data() - - node.init() - data = node.get_control_data() - - # check returned dict - self.assertIsNotNone(data) - self.assertTrue(any('pg_control' in s for s in data.keys())) - - def test_backup_simple(self): - with get_remote_node(conn_params=conn_params) as master: - # enable streaming for backups - master.init(allow_streaming=True) - - # node must be running - with self.assertRaises(BackupException): - master.backup() - - # it's time to start node - master.start() - - # fill node with some data - master.psql('create table test as select generate_series(1, 4) i') - - with master.backup(xlog_method='stream') as backup: - with backup.spawn_primary().start() as slave: - res = slave.execute('select * from test order by i asc') - self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) - - def test_backup_multiple(self): - with get_remote_node(conn_params=conn_params) as node: - node.init(allow_streaming=True).start() - - with node.backup(xlog_method='fetch') as backup1, \ - node.backup(xlog_method='fetch') as backup2: - self.assertNotEqual(backup1.base_dir, backup2.base_dir) - - with node.backup(xlog_method='fetch') as backup: - with backup.spawn_primary('node1', destroy=False) as node1, \ - backup.spawn_primary('node2', destroy=False) as node2: - self.assertNotEqual(node1.base_dir, node2.base_dir) - - def test_backup_exhaust(self): - with get_remote_node(conn_params=conn_params) as node: - node.init(allow_streaming=True).start() - - with node.backup(xlog_method='fetch') as backup: - # exhaust backup by creating new node - with backup.spawn_primary(): - pass - - # now let's try to create one more node - with self.assertRaises(BackupException): - backup.spawn_primary() - - def test_backup_wrong_xlog_method(self): - with get_remote_node(conn_params=conn_params) as node: - node.init(allow_streaming=True).start() - - with self.assertRaises(BackupException, - msg='Invalid xlog_method "wrong"'): - node.backup(xlog_method='wrong') - - def test_pg_ctl_wait_option(self): - with get_remote_node(conn_params=conn_params) as node: - node.init().start(wait=False) - while True: - try: - node.stop(wait=False) - break - except ExecUtilException: - # it's ok to get this exception here since node - # could be not started yet - pass - - def test_replicate(self): - with get_remote_node(conn_params=conn_params) as node: - node.init(allow_streaming=True).start() - - with node.replicate().start() as replica: - res = replica.execute('select 1') - self.assertListEqual(res, [(1,)]) - - node.execute('create table test (val int)', commit=True) - - replica.catchup() - - res = node.execute('select * from test') - self.assertListEqual(res, []) - - @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') - def test_synchronous_replication(self): - with get_remote_node(conn_params=conn_params) as master: - old_version = not pg_version_ge('9.6') - - master.init(allow_streaming=True).start() - - if not old_version: - master.append_conf('synchronous_commit = remote_apply') - - # create standby - with master.replicate() as standby1, master.replicate() as standby2: - standby1.start() - standby2.start() - - # check formatting - self.assertEqual( - '1 ("{}", "{}")'.format(standby1.name, standby2.name), - str(First(1, (standby1, standby2)))) # yapf: disable - self.assertEqual( - 'ANY 1 ("{}", "{}")'.format(standby1.name, standby2.name), - str(Any(1, (standby1, standby2)))) # yapf: disable - - # set synchronous_standby_names - master.set_synchronous_standbys(First(2, [standby1, standby2])) - master.restart() - - # the following part of the test is only applicable to newer - # versions of PostgresQL - if not old_version: - master.safe_psql('create table abc(a int)') - - # Create a large transaction that will take some time to apply - # on standby to check that it applies synchronously - # (If set synchronous_commit to 'on' or other lower level then - # standby most likely won't catchup so fast and test will fail) - master.safe_psql( - 'insert into abc select generate_series(1, 1000000)') - res = standby1.safe_psql('select count(*) from abc') - self.assertEqual(res, b'1000000\n') - - @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') - def test_logical_replication(self): - with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: - node1.init(allow_logical=True) - node1.start() - node2.init().start() - - create_table = 'create table test (a int, b int)' - node1.safe_psql(create_table) - node2.safe_psql(create_table) - - # create publication / create subscription - pub = node1.publish('mypub') - sub = node2.subscribe(pub, 'mysub') - - node1.safe_psql('insert into test values (1, 1), (2, 2)') - - # wait until changes apply on subscriber and check them - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [(1, 1), (2, 2)]) - - # disable and put some new data - sub.disable() - node1.safe_psql('insert into test values (3, 3)') - - # enable and ensure that data successfully transfered - sub.enable() - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [(1, 1), (2, 2), (3, 3)]) - - # Add new tables. Since we added "all tables" to publication - # (default behaviour of publish() method) we don't need - # to explicitely perform pub.add_tables() - create_table = 'create table test2 (c char)' - node1.safe_psql(create_table) - node2.safe_psql(create_table) - sub.refresh() - - # put new data - node1.safe_psql('insert into test2 values (\'a\'), (\'b\')') - sub.catchup() - res = node2.execute('select * from test2') - self.assertListEqual(res, [('a',), ('b',)]) - - # drop subscription - sub.drop() - pub.drop() - - # create new publication and subscription for specific table - # (ommitting copying data as it's already done) - pub = node1.publish('newpub', tables=['test']) - sub = node2.subscribe(pub, 'newsub', copy_data=False) - - node1.safe_psql('insert into test values (4, 4)') - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [(1, 1), (2, 2), (3, 3), (4, 4)]) - - # explicitely add table - with self.assertRaises(ValueError): - pub.add_tables([]) # fail - pub.add_tables(['test2']) - node1.safe_psql('insert into test2 values (\'c\')') - sub.catchup() - res = node2.execute('select * from test2') - self.assertListEqual(res, [('a',), ('b',)]) - - @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') - def test_logical_catchup(self): - """ Runs catchup for 100 times to be sure that it is consistent """ - with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: - node1.init(allow_logical=True) - node1.start() - node2.init().start() - - create_table = 'create table test (key int primary key, val int); ' - node1.safe_psql(create_table) - node1.safe_psql('alter table test replica identity default') - node2.safe_psql(create_table) - - # create publication / create subscription - sub = node2.subscribe(node1.publish('mypub'), 'mysub') - - for i in range(0, 100): - node1.execute('insert into test values ({0}, {0})'.format(i)) - sub.catchup() - res = node2.execute('select * from test') - self.assertListEqual(res, [( - i, - i, - )]) - node1.execute('delete from test') - - @unittest.skipIf(pg_version_ge('10'), 'requires <10') - def test_logical_replication_fail(self): - with get_remote_node(conn_params=conn_params) as node: - with self.assertRaises(InitNodeException): - node.init(allow_logical=True) - - def test_replication_slots(self): - with get_remote_node(conn_params=conn_params) as node: - node.init(allow_streaming=True).start() - - with node.replicate(slot='slot1').start() as replica: - replica.execute('select 1') - - # cannot create new slot with the same name - with self.assertRaises(TestgresException): - node.replicate(slot='slot1') - - def test_incorrect_catchup(self): - with get_remote_node(conn_params=conn_params) as node: - node.init(allow_streaming=True).start() - - # node has no master, can't catch up - with self.assertRaises(TestgresException): - node.catchup() - - def test_promotion(self): - with get_remote_node(conn_params=conn_params) as master: - master.init().start() - master.safe_psql('create table abc(id serial)') - - with master.replicate().start() as replica: - master.stop() - replica.promote() - - # make standby becomes writable master - replica.safe_psql('insert into abc values (1)') - res = replica.safe_psql('select * from abc') - self.assertEqual(res, b'1\n') - - def test_dump(self): - query_create = 'create table test as select generate_series(1, 2) as val' - query_select = 'select * from test order by val asc' - - with get_remote_node(conn_params=conn_params).init().start() as node1: - - node1.execute(query_create) - for format in ['plain', 'custom', 'directory', 'tar']: - with removing(node1.dump(format=format)) as dump: - with get_remote_node(conn_params=conn_params).init().start() as node3: - if format == 'directory': - self.assertTrue(os_ops.isdir(dump)) - else: - self.assertTrue(os_ops.isfile(dump)) - # restore dump - node3.restore(filename=dump) - res = node3.execute(query_select) - self.assertListEqual(res, [(1,), (2,)]) - - def test_users(self): - with get_remote_node(conn_params=conn_params).init().start() as node: - node.psql('create role test_user login') - value = node.safe_psql('select 1', username='test_user') - self.assertEqual(b'1\n', value) - - def test_poll_query_until(self): - with get_remote_node(conn_params=conn_params) as node: - node.init().start() - - get_time = 'select extract(epoch from now())' - check_time = 'select extract(epoch from now()) - {} >= 5' - - start_time = node.execute(get_time)[0][0] - node.poll_query_until(query=check_time.format(start_time)) - end_time = node.execute(get_time)[0][0] - - self.assertTrue(end_time - start_time >= 5) - - # check 0 columns - with self.assertRaises(QueryException): - node.poll_query_until( - query='select from pg_catalog.pg_class limit 1') - - # check None, fail - with self.assertRaises(QueryException): - node.poll_query_until(query='create table abc (val int)') - - # check None, ok - node.poll_query_until(query='create table def()', - expected=None) # returns nothing - - # check 0 rows equivalent to expected=None - node.poll_query_until( - query='select * from pg_catalog.pg_class where true = false', - expected=None) - - # check arbitrary expected value, fail - with self.assertRaises(TimeoutException): - node.poll_query_until(query='select 3', - expected=1, - max_attempts=3, - sleep_time=0.01) - - # check arbitrary expected value, ok - node.poll_query_until(query='select 2', expected=2) - - # check timeout - with self.assertRaises(TimeoutException): - node.poll_query_until(query='select 1 > 2', - max_attempts=3, - sleep_time=0.01) - - # check ProgrammingError, fail - with self.assertRaises(testgres.ProgrammingError): - node.poll_query_until(query='dummy1') - - # check ProgrammingError, ok - with self.assertRaises(TimeoutException): - node.poll_query_until(query='dummy2', - max_attempts=3, - sleep_time=0.01, - suppress={testgres.ProgrammingError}) - - # check 1 arg, ok - node.poll_query_until('select true') - - def test_logging(self): - # FAIL - logfile = tempfile.NamedTemporaryFile('w', delete=True) - - log_conf = { - 'version': 1, - 'handlers': { - 'file': { - 'class': 'logging.FileHandler', - 'filename': logfile.name, - 'formatter': 'base_format', - 'level': logging.DEBUG, - }, - }, - 'formatters': { - 'base_format': { - 'format': '%(node)-5s: %(message)s', - }, - }, - 'root': { - 'handlers': ('file',), - 'level': 'DEBUG', - }, - } - - logging.config.dictConfig(log_conf) - - with scoped_config(use_python_logging=True): - node_name = 'master' - - with get_remote_node(name=node_name) as master: - master.init().start() - - # execute a dummy query a few times - for i in range(20): - master.execute('select 1') - time.sleep(0.01) - - # let logging worker do the job - time.sleep(0.1) - - # check that master's port is found - with open(logfile.name, 'r') as log: - lines = log.readlines() - self.assertTrue(any(node_name in s for s in lines)) - - # test logger after stop/start/restart - master.stop() - master.start() - master.restart() - self.assertTrue(master._logger.is_alive()) - - @unittest.skipUnless(util_exists('pgbench'), 'might be missing') - def test_pgbench(self): - with get_remote_node(conn_params=conn_params).init().start() as node: - # initialize pgbench DB and run benchmarks - node.pgbench_init(scale=2, foreign_keys=True, - options=['-q']).pgbench_run(time=2) - - # run TPC-B benchmark - proc = node.pgbench(stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - options=['-T3']) - out = proc.communicate()[0] - self.assertTrue(b'tps = ' in out) - - def test_pg_config(self): - # check same instances - a = get_pg_config() - b = get_pg_config() - self.assertEqual(id(a), id(b)) - - # save right before config change - c1 = get_pg_config() - # modify setting for this scope - with scoped_config(cache_pg_config=False) as config: - # sanity check for value - self.assertFalse(config.cache_pg_config) - - # save right after config change - c2 = get_pg_config() - - # check different instances after config change - self.assertNotEqual(id(c1), id(c2)) - - # check different instances - a = get_pg_config() - b = get_pg_config() - self.assertNotEqual(id(a), id(b)) - - def test_config_stack(self): - # no such option - with self.assertRaises(TypeError): - configure_testgres(dummy=True) - - # we have only 1 config in stack - with self.assertRaises(IndexError): - pop_config() - - d0 = TestgresConfig.cached_initdb_dir - d1 = 'dummy_abc' - d2 = 'dummy_def' - - with scoped_config(cached_initdb_dir=d1) as c1: - self.assertEqual(c1.cached_initdb_dir, d1) - - with scoped_config(cached_initdb_dir=d2) as c2: - stack_size = len(testgres.config.config_stack) - - # try to break a stack - with self.assertRaises(TypeError): - with scoped_config(dummy=True): - pass - - self.assertEqual(c2.cached_initdb_dir, d2) - self.assertEqual(len(testgres.config.config_stack), stack_size) - - self.assertEqual(c1.cached_initdb_dir, d1) - - self.assertEqual(TestgresConfig.cached_initdb_dir, d0) - - def test_unix_sockets(self): - with get_remote_node(conn_params=conn_params) as node: - node.init(unix_sockets=False, allow_streaming=True) - node.start() - - res_exec = node.execute('select 1') - res_psql = node.safe_psql('select 1') - self.assertEqual(res_exec, [(1,)]) - self.assertEqual(res_psql, b'1\n') - - with node.replicate().start() as r: - res_exec = r.execute('select 1') - res_psql = r.safe_psql('select 1') - self.assertEqual(res_exec, [(1,)]) - self.assertEqual(res_psql, b'1\n') - - def test_auto_name(self): - with get_remote_node(conn_params=conn_params).init(allow_streaming=True).start() as m: - with m.replicate().start() as r: - # check that nodes are running - self.assertTrue(m.status()) - self.assertTrue(r.status()) - - # check their names - self.assertNotEqual(m.name, r.name) - self.assertTrue('testgres' in m.name) - self.assertTrue('testgres' in r.name) - - def test_file_tail(self): - from testgres.utils import file_tail - - s1 = "the quick brown fox jumped over that lazy dog\n" - s2 = "abc\n" - s3 = "def\n" - - with tempfile.NamedTemporaryFile(mode='r+', delete=True) as f: - sz = 0 - while sz < 3 * 8192: - sz += len(s1) - f.write(s1) - f.write(s2) - f.write(s3) - - f.seek(0) - lines = file_tail(f, 3) - self.assertEqual(lines[0], s1) - self.assertEqual(lines[1], s2) - self.assertEqual(lines[2], s3) - - f.seek(0) - lines = file_tail(f, 1) - self.assertEqual(lines[0], s3) - - def test_isolation_levels(self): - with get_remote_node(conn_params=conn_params).init().start() as node: - with node.connect() as con: - # string levels - con.begin('Read Uncommitted').commit() - con.begin('Read Committed').commit() - con.begin('Repeatable Read').commit() - con.begin('Serializable').commit() - - # enum levels - con.begin(IsolationLevel.ReadUncommitted).commit() - con.begin(IsolationLevel.ReadCommitted).commit() - con.begin(IsolationLevel.RepeatableRead).commit() - con.begin(IsolationLevel.Serializable).commit() - - # check wrong level - with self.assertRaises(QueryException): - con.begin('Garbage').commit() - - def test_ports_management(self): - # check that no ports have been bound yet - self.assertEqual(len(bound_ports), 0) - - with get_remote_node(conn_params=conn_params) as node: - # check that we've just bound a port - self.assertEqual(len(bound_ports), 1) - - # check that bound_ports contains our port - port_1 = list(bound_ports)[0] - port_2 = node.port - self.assertEqual(port_1, port_2) - - # check that port has been freed successfully - self.assertEqual(len(bound_ports), 0) - - def test_exceptions(self): - str(StartNodeException('msg', [('file', 'lines')])) - str(ExecUtilException('msg', 'cmd', 1, 'out')) - str(QueryException('msg', 'query')) - - def test_version_management(self): - a = PgVer('10.0') - b = PgVer('10') - c = PgVer('9.6.5') - d = PgVer('15.0') - e = PgVer('15rc1') - f = PgVer('15beta4') - - self.assertTrue(a == b) - self.assertTrue(b > c) - self.assertTrue(a > c) - self.assertTrue(d > e) - self.assertTrue(e > f) - self.assertTrue(d > f) - - version = get_pg_version() - with get_remote_node(conn_params=conn_params) as node: - self.assertTrue(isinstance(version, six.string_types)) - self.assertTrue(isinstance(node.version, PgVer)) - self.assertEqual(node.version, PgVer(version)) - - def test_child_pids(self): - master_processes = [ - ProcessType.AutovacuumLauncher, - ProcessType.BackgroundWriter, - ProcessType.Checkpointer, - ProcessType.StatsCollector, - ProcessType.WalSender, - ProcessType.WalWriter, - ] - - if pg_version_ge('10'): - master_processes.append(ProcessType.LogicalReplicationLauncher) - - repl_processes = [ - ProcessType.Startup, - ProcessType.WalReceiver, - ] - - with get_remote_node(conn_params=conn_params).init().start() as master: - - # master node doesn't have a source walsender! - with self.assertRaises(TestgresException): - master.source_walsender - - with master.connect() as con: - self.assertGreater(con.pid, 0) - - with master.replicate().start() as replica: - - # test __str__ method - str(master.child_processes[0]) - - master_pids = master.auxiliary_pids - for ptype in master_processes: - self.assertIn(ptype, master_pids) - - replica_pids = replica.auxiliary_pids - for ptype in repl_processes: - self.assertIn(ptype, replica_pids) - - # there should be exactly 1 source walsender for replica - self.assertEqual(len(master_pids[ProcessType.WalSender]), 1) - pid1 = master_pids[ProcessType.WalSender][0] - pid2 = replica.source_walsender.pid - self.assertEqual(pid1, pid2) - - replica.stop() - - # there should be no walsender after we've stopped replica - with self.assertRaises(TestgresException): - replica.source_walsender - - def test_child_process_dies(self): - # test for FileNotFound exception during child_processes() function - with subprocess.Popen(["sleep", "60"]) as process: - self.assertEqual(process.poll(), None) - # collect list of processes currently running - children = psutil.Process(os.getpid()).children() - # kill a process, so received children dictionary becomes invalid - process.kill() - process.wait() - # try to handle children list -- missing processes will have ptype "ProcessType.Unknown" - [ProcessProxy(p) for p in children] - - -if __name__ == '__main__': - if os_ops.environ('ALT_CONFIG'): - suite = unittest.TestSuite() - - # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG) - suite.addTest(TestgresRemoteTests('test_pg_config')) - suite.addTest(TestgresRemoteTests('test_pg_ctl')) - suite.addTest(TestgresRemoteTests('test_psql')) - suite.addTest(TestgresRemoteTests('test_replicate')) - - print('Running tests for alternative config:') - for t in suite: - print(t) - print() - - runner = unittest.TextTestRunner() - runner.run(suite) - else: - unittest.main() diff --git a/tests/test_testgres_common.py b/tests/test_testgres_common.py new file mode 100644 index 00000000..cf203a67 --- /dev/null +++ b/tests/test_testgres_common.py @@ -0,0 +1,1590 @@ +from .helpers.global_data import PostgresNodeService +from .helpers.global_data import PostgresNodeServices +from .helpers.global_data import OsOperations +from .helpers.global_data import PortManager + +from testgres.node import PgVer +from testgres.node import PostgresNode +from testgres.node import PostgresNodeLogReader +from testgres.node import PostgresNodeUtils +from testgres.utils import get_pg_version2 +from testgres.utils import file_tail +from testgres.utils import get_bin_path2 +from testgres import ProcessType +from testgres import NodeStatus +from testgres import IsolationLevel + +# New name prevents to collect test-functions in TestgresException and fixes +# the problem with pytest warning. +from testgres import TestgresException as testgres_TestgresException + +from testgres import InitNodeException +from testgres import StartNodeException +from testgres import QueryException +from testgres import ExecUtilException +from testgres import TimeoutException +from testgres import InvalidOperationException +from testgres import BackupException +from testgres import ProgrammingError +from testgres import scoped_config +from testgres import First, Any + +from contextlib import contextmanager + +import pytest +import six +import logging +import time +import tempfile +import uuid +import os +import re +import subprocess +import typing + + +@contextmanager +def removing(os_ops: OsOperations, f): + assert isinstance(os_ops, OsOperations) + + try: + yield f + finally: + if os_ops.isfile(f): + os_ops.remove_file(f) + + elif os_ops.isdir(f): + os_ops.rmdirs(f, ignore_errors=True) + + +class TestTestgresCommon: + sm_node_svcs: typing.List[PostgresNodeService] = [ + PostgresNodeServices.sm_local, + PostgresNodeServices.sm_local2, + PostgresNodeServices.sm_remote, + ] + + @pytest.fixture( + params=sm_node_svcs, + ids=[descr.sign for descr in sm_node_svcs] + ) + def node_svc(self, request: pytest.FixtureRequest) -> PostgresNodeService: + assert isinstance(request, pytest.FixtureRequest) + assert isinstance(request.param, PostgresNodeService) + assert isinstance(request.param.os_ops, OsOperations) + assert isinstance(request.param.port_manager, PortManager) + return request.param + + def test_version_management(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + a = PgVer('10.0') + b = PgVer('10') + c = PgVer('9.6.5') + d = PgVer('15.0') + e = PgVer('15rc1') + f = PgVer('15beta4') + h = PgVer('15.3biha') + i = PgVer('15.3') + g = PgVer('15.3.1bihabeta1') + k = PgVer('15.3.1') + + assert (a == b) + assert (b > c) + assert (a > c) + assert (d > e) + assert (e > f) + assert (d > f) + assert (h > f) + assert (h == i) + assert (g == k) + assert (g > h) + + version = get_pg_version2(node_svc.os_ops) + + with __class__.helper__get_node(node_svc) as node: + assert (isinstance(version, six.string_types)) + assert (isinstance(node.version, PgVer)) + assert (node.version == PgVer(version)) + + def test_node_repr(self, node_svc: PostgresNodeService): + with __class__.helper__get_node(node_svc).init() as node: + pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" + assert re.match(pattern, str(node)) is not None + + def test_custom_init(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + # enable page checksums + node.init(initdb_params=['-k']).start() + + with __class__.helper__get_node(node_svc) as node: + node.init( + allow_streaming=True, + initdb_params=['--auth-local=reject', '--auth-host=reject']) + + hba_file = os.path.join(node.data_dir, 'pg_hba.conf') + lines = node.os_ops.readlines(hba_file) + + # check number of lines + assert (len(lines) >= 6) + + # there should be no trust entries at all + assert not (any('trust' in s for s in lines)) + + def test_double_init(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc).init() as node: + # can't initialize node more than once + with pytest.raises(expected_exception=InitNodeException): + node.init() + + def test_init_after_cleanup(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + node.init().start().execute('select 1') + node.cleanup() + node.init().start().execute('select 1') + + def test_init_unique_system_id(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + # this function exists in PostgreSQL 9.6+ + current_version = get_pg_version2(node_svc.os_ops) + + __class__.helper__skip_test_if_util_not_exist(node_svc.os_ops, "pg_resetwal") + __class__.helper__skip_test_if_pg_version_is_not_ge(current_version, '9.6') + + query = 'select system_identifier from pg_control_system()' + + with scoped_config(cache_initdb=False): + with __class__.helper__get_node(node_svc).init().start() as node0: + id0 = node0.execute(query)[0] + + with scoped_config(cache_initdb=True, + cached_initdb_unique=True) as config: + assert (config.cache_initdb) + assert (config.cached_initdb_unique) + + # spawn two nodes; ids must be different + with __class__.helper__get_node(node_svc).init().start() as node1, \ + __class__.helper__get_node(node_svc).init().start() as node2: + id1 = node1.execute(query)[0] + id2 = node2.execute(query)[0] + + # ids must increase + assert (id1 > id0) + assert (id2 > id1) + + def test_node_exit(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with pytest.raises(expected_exception=QueryException): + with __class__.helper__get_node(node_svc).init() as node: + base_dir = node.base_dir + node.safe_psql('select 1') + + # we should save the DB for "debugging" + assert (node_svc.os_ops.path_exists(base_dir)) + node_svc.os_ops.rmdirs(base_dir, ignore_errors=True) + + with __class__.helper__get_node(node_svc).init() as node: + base_dir = node.base_dir + + # should have been removed by default + assert not (node_svc.os_ops.path_exists(base_dir)) + + def test_double_start(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc).init().start() as node: + # can't start node more than once + node.start() + assert (node.is_started) + + def test_uninitialized_start(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + # node is not initialized yet + with pytest.raises(expected_exception=StartNodeException): + node.start() + + def test_restart(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + node.init().start() + + # restart, ok + res = node.execute('select 1') + assert (res == [(1,)]) + node.restart() + res = node.execute('select 2') + assert (res == [(2,)]) + + # restart, fail + with pytest.raises(expected_exception=StartNodeException): + node.append_conf('pg_hba.conf', 'DUMMY') + node.restart() + + def test_reload(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + node.init().start() + + # change client_min_messages and save old value + cmm_old = node.execute('show client_min_messages') + node.append_conf(client_min_messages='DEBUG1') + + # reload config + node.reload() + + # check new value + cmm_new = node.execute('show client_min_messages') + assert ('debug1' == cmm_new[0][0].lower()) + assert (cmm_old != cmm_new) + + def test_pg_ctl(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + node.init().start() + + status = node.pg_ctl(['status']) + assert ('PID' in status) + + def test_status(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + assert (NodeStatus.Running) + assert not (NodeStatus.Stopped) + assert not (NodeStatus.Uninitialized) + + # check statuses after each operation + with __class__.helper__get_node(node_svc) as node: + assert (node.pid == 0) + assert (node.status() == NodeStatus.Uninitialized) + + node.init() + + assert (node.pid == 0) + assert (node.status() == NodeStatus.Stopped) + + node.start() + + assert (node.pid != 0) + assert (node.status() == NodeStatus.Running) + + node.stop() + + assert (node.pid == 0) + assert (node.status() == NodeStatus.Stopped) + + node.cleanup() + + assert (node.pid == 0) + assert (node.status() == NodeStatus.Uninitialized) + + def test_child_pids(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + master_processes = [ + ProcessType.AutovacuumLauncher, + ProcessType.BackgroundWriter, + ProcessType.Checkpointer, + ProcessType.StatsCollector, + ProcessType.WalSender, + ProcessType.WalWriter, + ] + + postgresVersion = get_pg_version2(node_svc.os_ops) + + if __class__.helper__pg_version_ge(postgresVersion, '10'): + master_processes.append(ProcessType.LogicalReplicationLauncher) + + if __class__.helper__pg_version_ge(postgresVersion, '14'): + master_processes.remove(ProcessType.StatsCollector) + + repl_processes = [ + ProcessType.Startup, + ProcessType.WalReceiver, + ] + + def LOCAL__test_auxiliary_pids( + node: PostgresNode, + expectedTypes: typing.List[ProcessType] + ) -> typing.List[ProcessType]: + # returns list of the absence processes + assert node is not None + assert type(node) == PostgresNode # noqa: E721 + assert expectedTypes is not None + assert type(expectedTypes) == list # noqa: E721 + + pids = node.auxiliary_pids + assert pids is not None # noqa: E721 + assert type(pids) == dict # noqa: E721 + + result: typing.List[ProcessType] = list() + for ptype in expectedTypes: + if not (ptype in pids): + result.append(ptype) + return result + + def LOCAL__check_auxiliary_pids__multiple_attempts( + node: PostgresNode, + expectedTypes: typing.List[ProcessType]): + assert node is not None + assert type(node) == PostgresNode # noqa: E721 + assert expectedTypes is not None + assert type(expectedTypes) == list # noqa: E721 + + nAttempt = 0 + + while nAttempt < 5: + nAttempt += 1 + + logging.info("Test pids of [{0}] node. Attempt #{1}.".format( + node.name, + nAttempt + )) + + if nAttempt > 1: + time.sleep(1) + + absenceList = LOCAL__test_auxiliary_pids(node, expectedTypes) + assert absenceList is not None + assert type(absenceList) == list # noqa: E721 + if len(absenceList) == 0: + logging.info("Bingo!") + return + + logging.info("These processes are not found: {0}.".format(absenceList)) + continue + + raise Exception("Node {0} does not have the following processes: {1}.".format( + node.name, + absenceList + )) + + with __class__.helper__get_node(node_svc).init().start() as master: + + # master node doesn't have a source walsender! + with pytest.raises(expected_exception=testgres_TestgresException): + master.source_walsender + + with master.connect() as con: + assert (con.pid > 0) + + with master.replicate().start() as replica: + assert type(replica) == PostgresNode # noqa: E721 + + # test __str__ method + str(master.child_processes[0]) + + LOCAL__check_auxiliary_pids__multiple_attempts( + master, + master_processes) + + LOCAL__check_auxiliary_pids__multiple_attempts( + replica, + repl_processes) + + master_pids = master.auxiliary_pids + + # there should be exactly 1 source walsender for replica + assert (len(master_pids[ProcessType.WalSender]) == 1) + pid1 = master_pids[ProcessType.WalSender][0] + pid2 = replica.source_walsender.pid + assert (pid1 == pid2) + + replica.stop() + + # there should be no walsender after we've stopped replica + with pytest.raises(expected_exception=testgres_TestgresException): + replica.source_walsender + + def test_exceptions(self): + str(StartNodeException('msg', [('file', 'lines')])) + str(ExecUtilException('msg', 'cmd', 1, 'out')) + str(QueryException('msg', 'query')) + + def test_auto_name(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc).init(allow_streaming=True).start() as m: + with m.replicate().start() as r: + # check that nodes are running + assert (m.status()) + assert (r.status()) + + # check their names + assert (m.name != r.name) + assert ('testgres' in m.name) + assert ('testgres' in r.name) + + def test_file_tail(self): + s1 = "the quick brown fox jumped over that lazy dog\n" + s2 = "abc\n" + s3 = "def\n" + + with tempfile.NamedTemporaryFile(mode='r+', delete=True) as f: + sz = 0 + while sz < 3 * 8192: + sz += len(s1) + f.write(s1) + f.write(s2) + f.write(s3) + + f.seek(0) + lines = file_tail(f, 3) + assert (lines[0] == s1) + assert (lines[1] == s2) + assert (lines[2] == s3) + + f.seek(0) + lines = file_tail(f, 1) + assert (lines[0] == s3) + + def test_isolation_levels(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init().start() as node: + with node.connect() as con: + # string levels + con.begin('Read Uncommitted').commit() + con.begin('Read Committed').commit() + con.begin('Repeatable Read').commit() + con.begin('Serializable').commit() + + # enum levels + con.begin(IsolationLevel.ReadUncommitted).commit() + con.begin(IsolationLevel.ReadCommitted).commit() + con.begin(IsolationLevel.RepeatableRead).commit() + con.begin(IsolationLevel.Serializable).commit() + + # check wrong level + with pytest.raises(expected_exception=QueryException): + con.begin('Garbage').commit() + + def test_users(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init().start() as node: + node.psql('create role test_user login') + value = node.safe_psql('select 1', username='test_user') + value = __class__.helper__rm_carriage_returns(value) + assert (value == b'1\n') + + def test_poll_query_until(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + node.init().start() + + get_time = 'select extract(epoch from now())' + check_time = 'select extract(epoch from now()) - {} >= 5' + + start_time = node.execute(get_time)[0][0] + node.poll_query_until(query=check_time.format(start_time)) + end_time = node.execute(get_time)[0][0] + + assert (end_time - start_time >= 5) + + # check 0 columns + with pytest.raises(expected_exception=QueryException): + node.poll_query_until( + query='select from pg_catalog.pg_class limit 1') + + # check None, fail + with pytest.raises(expected_exception=QueryException): + node.poll_query_until(query='create table abc (val int)') + + # check None, ok + node.poll_query_until(query='create table def()', + expected=None) # returns nothing + + # check 0 rows equivalent to expected=None + node.poll_query_until( + query='select * from pg_catalog.pg_class where true = false', + expected=None) + + # check arbitrary expected value, fail + with pytest.raises(expected_exception=TimeoutException): + node.poll_query_until(query='select 3', + expected=1, + max_attempts=3, + sleep_time=0.01) + + # check arbitrary expected value, ok + node.poll_query_until(query='select 2', expected=2) + + # check timeout + with pytest.raises(expected_exception=TimeoutException): + node.poll_query_until(query='select 1 > 2', + max_attempts=3, + sleep_time=0.01) + + # check ProgrammingError, fail + with pytest.raises(expected_exception=ProgrammingError): + node.poll_query_until(query='dummy1') + + # check ProgrammingError, ok + with pytest.raises(expected_exception=(TimeoutException)): + node.poll_query_until(query='dummy2', + max_attempts=3, + sleep_time=0.01, + suppress={ProgrammingError}) + + # check 1 arg, ok + node.poll_query_until('select true') + + def test_logging(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + C_MAX_ATTEMPTS = 50 + # This name is used for testgres logging, too. + C_NODE_NAME = "testgres_tests." + __class__.__name__ + "test_logging-master-" + uuid.uuid4().hex + + logging.info("Node name is [{0}]".format(C_NODE_NAME)) + + with tempfile.NamedTemporaryFile('w', delete=True) as logfile: + formatter = logging.Formatter(fmt="%(node)-5s: %(message)s") + handler = logging.FileHandler(filename=logfile.name) + handler.formatter = formatter + logger = logging.getLogger(C_NODE_NAME) + assert logger is not None + assert len(logger.handlers) == 0 + + try: + # It disables to log on the root level + logger.propagate = False + logger.addHandler(handler) + + with scoped_config(use_python_logging=True): + with __class__.helper__get_node(node_svc, name=C_NODE_NAME) as master: + logging.info("Master node is initilizing") + master.init() + + logging.info("Master node is starting") + master.start() + + logging.info("Dummy query is executed a few times") + for _ in range(20): + master.execute('select 1') + time.sleep(0.01) + + # let logging worker do the job + time.sleep(0.1) + + logging.info("Master node log file is checking") + nAttempt = 0 + + while True: + assert nAttempt <= C_MAX_ATTEMPTS + if nAttempt == C_MAX_ATTEMPTS: + raise Exception("Test failed!") + + # let logging worker do the job + time.sleep(0.1) + + nAttempt += 1 + + logging.info("Attempt {0}".format(nAttempt)) + + # check that master's port is found + with open(logfile.name, 'r') as log: + lines = log.readlines() + + assert lines is not None + assert type(lines) == list # noqa: E721 + + def LOCAL__test_lines(): + for s in lines: + if any(C_NODE_NAME in s for s in lines): + logging.info("OK. We found the node_name in a line \"{0}\"".format(s)) + return True + return False + + if LOCAL__test_lines(): + break + + logging.info("Master node log file does not have an expected information.") + continue + + # test logger after stop/start/restart + logging.info("Master node is stopping...") + master.stop() + logging.info("Master node is staring again...") + master.start() + logging.info("Master node is restaring...") + master.restart() + assert (master._logger.is_alive()) + finally: + # It is a hack code to logging cleanup + logging._acquireLock() + assert logging.Logger.manager is not None + assert C_NODE_NAME in logging.Logger.manager.loggerDict.keys() + logging.Logger.manager.loggerDict.pop(C_NODE_NAME, None) + assert not (C_NODE_NAME in logging.Logger.manager.loggerDict.keys()) + assert not (handler in logging._handlers.values()) + logging._releaseLock() + # GO HOME! + return + + def test_psql(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init().start() as node: + + # check returned values (1 arg) + res = node.psql('select 1') + assert (__class__.helper__rm_carriage_returns(res) == (0, b'1\n', b'')) + + # check returned values (2 args) + res = node.psql('postgres', 'select 2') + assert (__class__.helper__rm_carriage_returns(res) == (0, b'2\n', b'')) + + # check returned values (named) + res = node.psql(query='select 3', dbname='postgres') + assert (__class__.helper__rm_carriage_returns(res) == (0, b'3\n', b'')) + + # check returned values (1 arg) + res = node.safe_psql('select 4') + assert (__class__.helper__rm_carriage_returns(res) == b'4\n') + + # check returned values (2 args) + res = node.safe_psql('postgres', 'select 5') + assert (__class__.helper__rm_carriage_returns(res) == b'5\n') + + # check returned values (named) + res = node.safe_psql(query='select 6', dbname='postgres') + assert (__class__.helper__rm_carriage_returns(res) == b'6\n') + + # check feeding input + node.safe_psql('create table horns (w int)') + node.safe_psql('copy horns from stdin (format csv)', + input=b"1\n2\n3\n\\.\n") + _sum = node.safe_psql('select sum(w) from horns') + assert (__class__.helper__rm_carriage_returns(_sum) == b'6\n') + + # check psql's default args, fails + with pytest.raises(expected_exception=QueryException): + r = node.psql() # raises! + logging.error("node.psql returns [{}]".format(r)) + + node.stop() + + # check psql on stopped node, fails + with pytest.raises(expected_exception=QueryException): + # [2025-04-03] This call does not raise exception! I do not know why. + r = node.safe_psql('select 1') # raises! + logging.error("node.safe_psql returns [{}]".format(r)) + + def test_psql__another_port(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node1: + with __class__.helper__get_node(node_svc).init() as node2: + node1.start() + node2.start() + assert node1.port != node2.port + assert node1.host == node2.host + + node1.stop() + + logging.info("test table in node2 is creating ...") + node2.safe_psql( + dbname="postgres", + query="create table test (id integer);" + ) + + logging.info("try to find test table through node1.psql ...") + res = node1.psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host=node2.host, + port=node2.port, + ) + assert (__class__.helper__rm_carriage_returns(res) == (0, b'1\n', b'')) + + def test_psql__another_bad_host(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node: + logging.info("try to execute node1.psql ...") + res = node.psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host="DUMMY_HOST_NAME", + port=node.port, + ) + + res2 = __class__.helper__rm_carriage_returns(res) + + assert res2[0] != 0 + assert b"DUMMY_HOST_NAME" in res[2] + + def test_safe_psql__another_port(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node1: + with __class__.helper__get_node(node_svc).init() as node2: + node1.start() + node2.start() + assert node1.port != node2.port + assert node1.host == node2.host + + node1.stop() + + logging.info("test table in node2 is creating ...") + node2.safe_psql( + dbname="postgres", + query="create table test (id integer);" + ) + + logging.info("try to find test table through node1.psql ...") + res = node1.safe_psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host=node2.host, + port=node2.port, + ) + assert (__class__.helper__rm_carriage_returns(res) == b'1\n') + + def test_safe_psql__another_bad_host(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init() as node: + logging.info("try to execute node1.psql ...") + + with pytest.raises(expected_exception=Exception) as x: + node.safe_psql( + dbname="postgres", + query="select count(*) from pg_class where relname='test'", + host="DUMMY_HOST_NAME", + port=node.port, + ) + + assert "DUMMY_HOST_NAME" in str(x.value) + + def test_safe_psql__expect_error(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init().start() as node: + err = node.safe_psql('select_or_not_select 1', expect_error=True) + assert (type(err) == str) # noqa: E721 + assert ('select_or_not_select' in err) + assert ('ERROR: syntax error at or near "select_or_not_select"' in err) + + # --------- + with pytest.raises( + expected_exception=InvalidOperationException, + match="^" + re.escape("Exception was expected, but query finished successfully: `select 1;`.") + "$" + ): + node.safe_psql("select 1;", expect_error=True) + + # --------- + res = node.safe_psql("select 1;", expect_error=False) + assert (__class__.helper__rm_carriage_returns(res) == b'1\n') + + def test_transactions(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc).init().start() as node: + + with node.connect() as con: + con.begin() + con.execute('create table test(val int)') + con.execute('insert into test values (1)') + con.commit() + + con.begin() + con.execute('insert into test values (2)') + res = con.execute('select * from test order by val asc') + assert (res == [(1, ), (2, )]) + con.rollback() + + con.begin() + res = con.execute('select * from test') + assert (res == [(1, )]) + con.rollback() + + con.begin() + con.execute('drop table test') + con.commit() + + def test_control_data(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + + # node is not initialized yet + with pytest.raises(expected_exception=ExecUtilException): + node.get_control_data() + + node.init() + data = node.get_control_data() + + # check returned dict + assert data is not None + assert (any('pg_control' in s for s in data.keys())) + + def test_backup_simple(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as master: + + # enable streaming for backups + master.init(allow_streaming=True) + + # node must be running + with pytest.raises(expected_exception=BackupException): + master.backup() + + # it's time to start node + master.start() + + # fill node with some data + master.psql('create table test as select generate_series(1, 4) i') + + with master.backup(xlog_method='stream') as backup: + with backup.spawn_primary().start() as slave: + res = slave.execute('select * from test order by i asc') + assert (res == [(1, ), (2, ), (3, ), (4, )]) + + def test_backup_multiple(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup1, \ + node.backup(xlog_method='fetch') as backup2: + assert (backup1.base_dir != backup2.base_dir) + + with node.backup(xlog_method='fetch') as backup: + with backup.spawn_primary('node1', destroy=False) as node1, \ + backup.spawn_primary('node2', destroy=False) as node2: + assert (node1.base_dir != node2.base_dir) + + def test_backup_exhaust(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup: + # exhaust backup by creating new node + with backup.spawn_primary(): + pass + + # now let's try to create one more node + with pytest.raises(expected_exception=BackupException): + backup.spawn_primary() + + def test_backup_wrong_xlog_method(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + node.init(allow_streaming=True).start() + + with pytest.raises( + expected_exception=BackupException, + match="^" + re.escape('Invalid xlog_method "wrong"') + "$" + ): + node.backup(xlog_method='wrong') + + def test_pg_ctl_wait_option(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + C_MAX_ATTEMPT = 5 + + nAttempt = 0 + + while True: + if nAttempt == C_MAX_ATTEMPT: + raise Exception("PostgresSQL did not start.") + + nAttempt += 1 + logging.info("------------------------ NODE #{}".format( + nAttempt + )) + + with __class__.helper__get_node(node_svc, port=12345) as node: + if self.impl__test_pg_ctl_wait_option(node_svc, node): + break + continue + + logging.info("OK. Test is passed. Number of attempts is {}".format( + nAttempt + )) + return + + def impl__test_pg_ctl_wait_option( + self, + node_svc: PostgresNodeService, + node: PostgresNode + ) -> None: + assert isinstance(node_svc, PostgresNodeService) + assert isinstance(node, PostgresNode) + assert node.status() == NodeStatus.Uninitialized + + C_MAX_ATTEMPTS = 50 + + node.init() + assert node.status() == NodeStatus.Stopped + + node_log_reader = PostgresNodeLogReader(node, from_beginnig=True) + + node.start(wait=False) + nAttempt = 0 + while True: + if PostgresNodeUtils.delect_port_conflict(node_log_reader): + logging.info("Node port {} conflicted with another PostgreSQL instance.".format( + node.port + )) + return False + + if nAttempt == C_MAX_ATTEMPTS: + # + # [2025-03-11] + # We have an unexpected problem with this test in CI + # Let's get an additional information about this test failure. + # + logging.error("Node was not stopped.") + if not node.os_ops.path_exists(node.pg_log_file): + logging.warning("Node log does not exist.") + else: + logging.info("Let's read node log file [{0}]".format(node.pg_log_file)) + logFileData = node.os_ops.read(node.pg_log_file, binary=False) + logging.info("Node log file content:\n{0}".format(logFileData)) + + raise Exception("Could not stop node.") + + nAttempt += 1 + + if nAttempt > 1: + logging.info("Wait 1 second.") + time.sleep(1) + logging.info("") + + logging.info("Try to stop node. Attempt #{0}.".format(nAttempt)) + + try: + node.stop(wait=False) + break + except ExecUtilException as e: + # it's ok to get this exception here since node + # could be not started yet + logging.info("Node is not stopped. Exception ({0}): {1}".format(type(e).__name__, e)) + continue + + logging.info("OK. Stop command was executed. Let's wait while our node will stop really.") + nAttempt = 0 + while True: + if nAttempt == C_MAX_ATTEMPTS: + raise Exception("Could not stop node.") + + nAttempt += 1 + if nAttempt > 1: + logging.info("Wait 1 second.") + time.sleep(1) + logging.info("") + + logging.info("Attempt #{0}.".format(nAttempt)) + s1 = node.status() + + if s1 == NodeStatus.Running: + continue + + if s1 == NodeStatus.Stopped: + break + + raise Exception("Unexpected node status: {0}.".format(s1)) + + logging.info("OK. Node is stopped.") + return True + + def test_replicate(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + node.init(allow_streaming=True).start() + + with node.replicate().start() as replica: + res = replica.execute('select 1') + assert (res == [(1, )]) + + node.execute('create table test (val int)', commit=True) + + replica.catchup() + + res = node.execute('select * from test') + assert (res == []) + + def test_synchronous_replication(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + current_version = get_pg_version2(node_svc.os_ops) + + __class__.helper__skip_test_if_pg_version_is_not_ge(current_version, "9.6") + + with __class__.helper__get_node(node_svc) as master: + old_version = not __class__.helper__pg_version_ge(current_version, '9.6') + + master.init(allow_streaming=True).start() + + if not old_version: + master.append_conf('synchronous_commit = remote_apply') + + # create standby + with master.replicate() as standby1, master.replicate() as standby2: + standby1.start() + standby2.start() + + # check formatting + assert ( + '1 ("{}", "{}")'.format(standby1.name, standby2.name) == str(First(1, (standby1, standby2))) + ) # yapf: disable + assert ( + 'ANY 1 ("{}", "{}")'.format(standby1.name, standby2.name) == str(Any(1, (standby1, standby2))) + ) # yapf: disable + + # set synchronous_standby_names + master.set_synchronous_standbys(First(2, [standby1, standby2])) + master.restart() + + # the following part of the test is only applicable to newer + # versions of PostgresQL + if not old_version: + master.safe_psql('create table abc(a int)') + + # Create a large transaction that will take some time to apply + # on standby to check that it applies synchronously + # (If set synchronous_commit to 'on' or other lower level then + # standby most likely won't catchup so fast and test will fail) + master.safe_psql( + 'insert into abc select generate_series(1, 1000000)') + res = standby1.safe_psql('select count(*) from abc') + assert (__class__.helper__rm_carriage_returns(res) == b'1000000\n') + + def test_logical_replication(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + current_version = get_pg_version2(node_svc.os_ops) + + __class__.helper__skip_test_if_pg_version_is_not_ge(current_version, "10") + + with __class__.helper__get_node(node_svc) as node1, __class__.helper__get_node(node_svc) as node2: + node1.init(allow_logical=True) + node1.start() + node2.init().start() + + create_table = 'create table test (a int, b int)' + node1.safe_psql(create_table) + node2.safe_psql(create_table) + + # create publication / create subscription + pub = node1.publish('mypub') + sub = node2.subscribe(pub, 'mysub') + + node1.safe_psql('insert into test values (1, 1), (2, 2)') + + # wait until changes apply on subscriber and check them + sub.catchup() + res = node2.execute('select * from test') + assert (res == [(1, 1), (2, 2)]) + + # disable and put some new data + sub.disable() + node1.safe_psql('insert into test values (3, 3)') + + # enable and ensure that data successfully transferred + sub.enable() + sub.catchup() + res = node2.execute('select * from test') + assert (res == [(1, 1), (2, 2), (3, 3)]) + + # Add new tables. Since we added "all tables" to publication + # (default behaviour of publish() method) we don't need + # to explicitly perform pub.add_tables() + create_table = 'create table test2 (c char)' + node1.safe_psql(create_table) + node2.safe_psql(create_table) + sub.refresh() + + # put new data + node1.safe_psql('insert into test2 values (\'a\'), (\'b\')') + sub.catchup() + res = node2.execute('select * from test2') + assert (res == [('a', ), ('b', )]) + + # drop subscription + sub.drop() + pub.drop() + + # create new publication and subscription for specific table + # (omitting copying data as it's already done) + pub = node1.publish('newpub', tables=['test']) + sub = node2.subscribe(pub, 'newsub', copy_data=False) + + node1.safe_psql('insert into test values (4, 4)') + sub.catchup() + res = node2.execute('select * from test') + assert (res == [(1, 1), (2, 2), (3, 3), (4, 4)]) + + # explicitly add table + with pytest.raises(expected_exception=ValueError): + pub.add_tables([]) # fail + pub.add_tables(['test2']) + node1.safe_psql('insert into test2 values (\'c\')') + sub.catchup() + res = node2.execute('select * from test2') + assert (res == [('a', ), ('b', )]) + + def test_logical_catchup(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + """ Runs catchup for 100 times to be sure that it is consistent """ + + current_version = get_pg_version2(node_svc.os_ops) + + __class__.helper__skip_test_if_pg_version_is_not_ge(current_version, "10") + + with __class__.helper__get_node(node_svc) as node1, __class__.helper__get_node(node_svc) as node2: + node1.init(allow_logical=True) + node1.start() + node2.init().start() + + create_table = 'create table test (key int primary key, val int); ' + node1.safe_psql(create_table) + node1.safe_psql('alter table test replica identity default') + node2.safe_psql(create_table) + + # create publication / create subscription + sub = node2.subscribe(node1.publish('mypub'), 'mysub') + + for i in range(0, 100): + node1.execute('insert into test values ({0}, {0})'.format(i)) + sub.catchup() + res = node2.execute('select * from test') + assert (res == [(i, i, )]) + node1.execute('delete from test') + + def test_logical_replication_fail(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + current_version = get_pg_version2(node_svc.os_ops) + + __class__.helper__skip_test_if_pg_version_is_ge(current_version, "10") + + with __class__.helper__get_node(node_svc) as node: + with pytest.raises(expected_exception=InitNodeException): + node.init(allow_logical=True) + + def test_replication_slots(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + node.init(allow_streaming=True).start() + + with node.replicate(slot='slot1').start() as replica: + replica.execute('select 1') + + # cannot create new slot with the same name + with pytest.raises(expected_exception=testgres_TestgresException): + node.replicate(slot='slot1') + + def test_incorrect_catchup(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as node: + node.init(allow_streaming=True).start() + + # node has no master, can't catch up + with pytest.raises(expected_exception=testgres_TestgresException): + node.catchup() + + def test_promotion(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + with __class__.helper__get_node(node_svc) as master: + master.init().start() + master.safe_psql('create table abc(id serial)') + + with master.replicate().start() as replica: + master.stop() + replica.promote() + + # make standby becomes writable master + replica.safe_psql('insert into abc values (1)') + res = replica.safe_psql('select * from abc') + assert (__class__.helper__rm_carriage_returns(res) == b'1\n') + + def test_dump(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + query_create = 'create table test as select generate_series(1, 2) as val' + query_select = 'select * from test order by val asc' + + with __class__.helper__get_node(node_svc).init().start() as node1: + + node1.execute(query_create) + for format in ['plain', 'custom', 'directory', 'tar']: + with removing(node_svc.os_ops, node1.dump(format=format)) as dump: + with __class__.helper__get_node(node_svc).init().start() as node3: + if format == 'directory': + assert (os.path.isdir(dump)) + else: + assert (os.path.isfile(dump)) + # restore dump + node3.restore(filename=dump) + res = node3.execute(query_select) + assert (res == [(1, ), (2, )]) + + def test_pgbench(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + __class__.helper__skip_test_if_util_not_exist(node_svc.os_ops, "pgbench") + + with __class__.helper__get_node(node_svc).init().start() as node: + # initialize pgbench DB and run benchmarks + node.pgbench_init( + scale=2, + foreign_keys=True, + options=['-q'] + ).pgbench_run(time=2) + + # run TPC-B benchmark + proc = node.pgbench(stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + options=['-T3']) + out = proc.communicate()[0] + assert (b'tps = ' in out) + + def test_unix_sockets(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + node.init(unix_sockets=False, allow_streaming=True) + node.start() + + res_exec = node.execute('select 1') + assert (res_exec == [(1,)]) + res_psql = node.safe_psql('select 1') + assert (res_psql == b'1\n') + + with node.replicate() as r: + assert type(r) == PostgresNode # noqa: E721 + r.start() + res_exec = r.execute('select 1') + assert (res_exec == [(1,)]) + res_psql = r.safe_psql('select 1') + assert (res_psql == b'1\n') + + def test_the_same_port(self, node_svc: PostgresNodeService): + assert isinstance(node_svc, PostgresNodeService) + + with __class__.helper__get_node(node_svc) as node: + node.init().start() + assert (node._should_free_port) + assert (type(node.port) == int) # noqa: E721 + node_port_copy = node.port + r = node.safe_psql("SELECT 1;") + assert (__class__.helper__rm_carriage_returns(r) == b'1\n') + + with __class__.helper__get_node(node_svc, port=node.port) as node2: + assert (type(node2.port) == int) # noqa: E721 + assert (node2.port == node.port) + assert not (node2._should_free_port) + + with pytest.raises( + expected_exception=StartNodeException, + match=re.escape("Cannot start node") + ): + node2.init().start() + + # node is still working + assert (node.port == node_port_copy) + assert (node._should_free_port) + r = node.safe_psql("SELECT 3;") + assert (__class__.helper__rm_carriage_returns(r) == b'3\n') + + class tagPortManagerProxy(PortManager): + m_PrevPortManager: PortManager + + m_DummyPortNumber: int + m_DummyPortMaxUsage: int + + m_DummyPortCurrentUsage: int + m_DummyPortTotalUsage: int + + def __init__(self, prevPortManager: PortManager, dummyPortNumber: int, dummyPortMaxUsage: int): + assert isinstance(prevPortManager, PortManager) + assert type(dummyPortNumber) == int # noqa: E721 + assert type(dummyPortMaxUsage) == int # noqa: E721 + assert dummyPortNumber >= 0 + assert dummyPortMaxUsage >= 0 + + super().__init__() + + self.m_PrevPortManager = prevPortManager + + self.m_DummyPortNumber = dummyPortNumber + self.m_DummyPortMaxUsage = dummyPortMaxUsage + + self.m_DummyPortCurrentUsage = 0 + self.m_DummyPortTotalUsage = 0 + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + assert self.m_DummyPortCurrentUsage == 0 + + assert self.m_PrevPortManager is not None + + def reserve_port(self) -> int: + assert type(self.m_DummyPortMaxUsage) == int # noqa: E721 + assert type(self.m_DummyPortTotalUsage) == int # noqa: E721 + assert type(self.m_DummyPortCurrentUsage) == int # noqa: E721 + assert self.m_DummyPortTotalUsage >= 0 + assert self.m_DummyPortCurrentUsage >= 0 + + assert self.m_DummyPortTotalUsage <= self.m_DummyPortMaxUsage + assert self.m_DummyPortCurrentUsage <= self.m_DummyPortTotalUsage + + assert self.m_PrevPortManager is not None + assert isinstance(self.m_PrevPortManager, PortManager) + + if self.m_DummyPortTotalUsage == self.m_DummyPortMaxUsage: + return self.m_PrevPortManager.reserve_port() + + self.m_DummyPortTotalUsage += 1 + self.m_DummyPortCurrentUsage += 1 + return self.m_DummyPortNumber + + def release_port(self, dummyPortNumber: int): + assert type(dummyPortNumber) == int # noqa: E721 + + assert type(self.m_DummyPortMaxUsage) == int # noqa: E721 + assert type(self.m_DummyPortTotalUsage) == int # noqa: E721 + assert type(self.m_DummyPortCurrentUsage) == int # noqa: E721 + assert self.m_DummyPortTotalUsage >= 0 + assert self.m_DummyPortCurrentUsage >= 0 + + assert self.m_DummyPortTotalUsage <= self.m_DummyPortMaxUsage + assert self.m_DummyPortCurrentUsage <= self.m_DummyPortTotalUsage + + assert self.m_PrevPortManager is not None + assert isinstance(self.m_PrevPortManager, PortManager) + + if self.m_DummyPortCurrentUsage > 0 and dummyPortNumber == self.m_DummyPortNumber: + assert self.m_DummyPortTotalUsage > 0 + self.m_DummyPortCurrentUsage -= 1 + return + + return self.m_PrevPortManager.release_port(dummyPortNumber) + + def test_port_rereserve_during_node_start(self, node_svc: PostgresNodeService): + assert type(node_svc) == PostgresNodeService # noqa: E721 + assert PostgresNode._C_MAX_START_ATEMPTS == 5 + + C_COUNT_OF_BAD_PORT_USAGE = 3 + + with __class__.helper__get_node(node_svc) as node1: + node1.init().start() + assert node1._should_free_port + assert type(node1.port) == int # noqa: E721 + node1_port_copy = node1.port + assert __class__.helper__rm_carriage_returns(node1.safe_psql("SELECT 1;")) == b'1\n' + + with __class__.tagPortManagerProxy(node_svc.port_manager, node1.port, C_COUNT_OF_BAD_PORT_USAGE) as proxy: + assert proxy.m_DummyPortNumber == node1.port + with __class__.helper__get_node(node_svc, port_manager=proxy) as node2: + assert node2._should_free_port + assert node2.port == node1.port + + node2.init().start() + + assert node2.port != node1.port + assert node2._should_free_port + assert proxy.m_DummyPortCurrentUsage == 0 + assert proxy.m_DummyPortTotalUsage == C_COUNT_OF_BAD_PORT_USAGE + assert node2.is_started + r = node2.safe_psql("SELECT 2;") + assert __class__.helper__rm_carriage_returns(r) == b'2\n' + + # node1 is still working + assert node1.port == node1_port_copy + assert node1._should_free_port + r = node1.safe_psql("SELECT 3;") + assert __class__.helper__rm_carriage_returns(r) == b'3\n' + + def test_port_conflict(self, node_svc: PostgresNodeService): + assert type(node_svc) == PostgresNodeService # noqa: E721 + assert PostgresNode._C_MAX_START_ATEMPTS > 1 + + C_COUNT_OF_BAD_PORT_USAGE = PostgresNode._C_MAX_START_ATEMPTS + + with __class__.helper__get_node(node_svc) as node1: + node1.init().start() + assert node1._should_free_port + assert type(node1.port) == int # noqa: E721 + node1_port_copy = node1.port + assert __class__.helper__rm_carriage_returns(node1.safe_psql("SELECT 1;")) == b'1\n' + + with __class__.tagPortManagerProxy(node_svc.port_manager, node1.port, C_COUNT_OF_BAD_PORT_USAGE) as proxy: + assert proxy.m_DummyPortNumber == node1.port + with __class__.helper__get_node(node_svc, port_manager=proxy) as node2: + assert node2._should_free_port + assert node2.port == node1.port + + with pytest.raises( + expected_exception=StartNodeException, + match=re.escape("Cannot start node after multiple attempts.") + ): + node2.init().start() + + assert node2.port == node1.port + assert node2._should_free_port + assert proxy.m_DummyPortCurrentUsage == 1 + assert proxy.m_DummyPortTotalUsage == C_COUNT_OF_BAD_PORT_USAGE + assert not node2.is_started + + # node2 must release our dummyPort (node1.port) + assert (proxy.m_DummyPortCurrentUsage == 0) + + # node1 is still working + assert node1.port == node1_port_copy + assert node1._should_free_port + r = node1.safe_psql("SELECT 3;") + assert __class__.helper__rm_carriage_returns(r) == b'3\n' + + def test_try_to_get_port_after_free_manual_port(self, node_svc: PostgresNodeService): + assert type(node_svc) == PostgresNodeService # noqa: E721 + + assert node_svc.port_manager is not None + assert isinstance(node_svc.port_manager, PortManager) + + with __class__.helper__get_node(node_svc) as node1: + assert node1 is not None + assert type(node1) == PostgresNode # noqa: E721 + assert node1.port is not None + assert type(node1.port) == int # noqa: E721 + with __class__.helper__get_node(node_svc, port=node1.port, port_manager=None) as node2: + assert node2 is not None + assert type(node1) == PostgresNode # noqa: E721 + assert node2 is not node1 + assert node2.port is not None + assert type(node2.port) == int # noqa: E721 + assert node2.port == node1.port + + logging.info("Release node2 port") + node2.free_port() + + logging.info("try to get node2.port...") + with pytest.raises( + InvalidOperationException, + match="^" + re.escape("PostgresNode port is not defined.") + "$" + ): + p = node2.port + assert p is None + + def test_try_to_start_node_after_free_manual_port(self, node_svc: PostgresNodeService): + assert type(node_svc) == PostgresNodeService # noqa: E721 + + assert node_svc.port_manager is not None + assert isinstance(node_svc.port_manager, PortManager) + + with __class__.helper__get_node(node_svc) as node1: + assert node1 is not None + assert type(node1) == PostgresNode # noqa: E721 + assert node1.port is not None + assert type(node1.port) == int # noqa: E721 + with __class__.helper__get_node(node_svc, port=node1.port, port_manager=None) as node2: + assert node2 is not None + assert type(node1) == PostgresNode # noqa: E721 + assert node2 is not node1 + assert node2.port is not None + assert type(node2.port) == int # noqa: E721 + assert node2.port == node1.port + + logging.info("Release node2 port") + node2.free_port() + + logging.info("node2 is trying to start...") + with pytest.raises( + InvalidOperationException, + match="^" + re.escape("Can't start PostgresNode. Port is not defined.") + "$" + ): + node2.start() + + @staticmethod + def helper__get_node( + node_svc: PostgresNodeService, + name: typing.Optional[str] = None, + port: typing.Optional[int] = None, + port_manager: typing.Optional[PortManager] = None + ) -> PostgresNode: + assert isinstance(node_svc, PostgresNodeService) + assert isinstance(node_svc.os_ops, OsOperations) + assert isinstance(node_svc.port_manager, PortManager) + + if port_manager is None: + port_manager = node_svc.port_manager + + return PostgresNode( + name, + port=port, + os_ops=node_svc.os_ops, + port_manager=port_manager if port is None else None + ) + + @staticmethod + def helper__skip_test_if_pg_version_is_not_ge(ver1: str, ver2: str): + assert type(ver1) == str # noqa: E721 + assert type(ver2) == str # noqa: E721 + if not __class__.helper__pg_version_ge(ver1, ver2): + pytest.skip('requires {0}+'.format(ver2)) + + @staticmethod + def helper__skip_test_if_pg_version_is_ge(ver1: str, ver2: str): + assert type(ver1) == str # noqa: E721 + assert type(ver2) == str # noqa: E721 + if __class__.helper__pg_version_ge(ver1, ver2): + pytest.skip('requires <{0}'.format(ver2)) + + @staticmethod + def helper__pg_version_ge(ver1: str, ver2: str) -> bool: + assert type(ver1) == str # noqa: E721 + assert type(ver2) == str # noqa: E721 + v1 = PgVer(ver1) + v2 = PgVer(ver2) + return v1 >= v2 + + @staticmethod + def helper__rm_carriage_returns(out): + """ + In Windows we have additional '\r' symbols in output. + Let's get rid of them. + """ + if isinstance(out, (int, float, complex)): + return out + + if isinstance(out, tuple): + return tuple(__class__.helper__rm_carriage_returns(item) for item in out) + + if isinstance(out, bytes): + return out.replace(b'\r', b'') + + assert type(out) == str # noqa: E721 + return out.replace('\r', '') + + @staticmethod + def helper__skip_test_if_util_not_exist(os_ops: OsOperations, name: str): + assert isinstance(os_ops, OsOperations) + assert type(name) == str # noqa: E721 + if not __class__.helper__util_exists(os_ops, name): + pytest.skip('might be missing') + + @staticmethod + def helper__util_exists(os_ops: OsOperations, util): + assert isinstance(os_ops, OsOperations) + + def good_properties(f): + return (os_ops.path_exists(f) and # noqa: W504 + os_ops.isfile(f) and # noqa: W504 + os_ops.is_executable(f)) # yapf: disable + + # try to resolve it + if good_properties(get_bin_path2(os_ops, util)): + return True + + # check if util is in PATH + for path in os_ops.environ("PATH").split(os.pathsep): + if good_properties(os.path.join(path, util)): + return True diff --git a/tests/test_testgres_local.py b/tests/test_testgres_local.py new file mode 100644 index 00000000..63e5f37e --- /dev/null +++ b/tests/test_testgres_local.py @@ -0,0 +1,414 @@ +# coding: utf-8 +import os +import re +import subprocess +import pytest +import psutil +import platform +import logging + +import testgres + +from testgres import StartNodeException +from testgres import ExecUtilException +from testgres import NodeApp +from testgres import scoped_config +from testgres import get_new_node +from testgres import get_bin_path +from testgres import get_pg_config +from testgres import get_pg_version + +# NOTE: those are ugly imports +from testgres.utils import bound_ports +from testgres.utils import PgVer +from testgres.node import ProcessProxy + + +def pg_version_ge(version): + cur_ver = PgVer(get_pg_version()) + min_ver = PgVer(version) + return cur_ver >= min_ver + + +def util_exists(util): + def good_properties(f): + return (os.path.exists(f) and # noqa: W504 + os.path.isfile(f) and # noqa: W504 + os.access(f, os.X_OK)) # yapf: disable + + # try to resolve it + if good_properties(get_bin_path(util)): + return True + + # check if util is in PATH + for path in os.environ["PATH"].split(os.pathsep): + if good_properties(os.path.join(path, util)): + return True + + +def rm_carriage_returns(out): + """ + In Windows we have additional '\r' symbols in output. + Let's get rid of them. + """ + if os.name == 'nt': + if isinstance(out, (int, float, complex)): + return out + elif isinstance(out, tuple): + return tuple(rm_carriage_returns(item) for item in out) + elif isinstance(out, bytes): + return out.replace(b'\r', b'') + else: + return out.replace('\r', '') + else: + return out + + +class TestTestgresLocal: + def test_pg_config(self): + # check same instances + a = get_pg_config() + b = get_pg_config() + assert (id(a) == id(b)) + + # save right before config change + c1 = get_pg_config() + + # modify setting for this scope + with scoped_config(cache_pg_config=False) as config: + # sanity check for value + assert not (config.cache_pg_config) + + # save right after config change + c2 = get_pg_config() + + # check different instances after config change + assert (id(c1) != id(c2)) + + # check different instances + a = get_pg_config() + b = get_pg_config() + assert (id(a) != id(b)) + + def test_ports_management(self): + assert bound_ports is not None + assert type(bound_ports) == set # noqa: E721 + + if len(bound_ports) != 0: + logging.warning("bound_ports is not empty: {0}".format(bound_ports)) + + stage0__bound_ports = bound_ports.copy() + + with get_new_node() as node: + assert bound_ports is not None + assert type(bound_ports) == set # noqa: E721 + + assert node.port is not None + assert type(node.port) == int # noqa: E721 + + logging.info("node port is {0}".format(node.port)) + + assert node.port in bound_ports + assert node.port not in stage0__bound_ports + + assert stage0__bound_ports <= bound_ports + assert len(stage0__bound_ports) + 1 == len(bound_ports) + + stage1__bound_ports = stage0__bound_ports.copy() + stage1__bound_ports.add(node.port) + + assert stage1__bound_ports == bound_ports + + # check that port has been freed successfully + assert bound_ports is not None + assert type(bound_ports) == set # noqa: E721 + assert bound_ports == stage0__bound_ports + + def test_child_process_dies(self): + # test for FileNotFound exception during child_processes() function + cmd = ["timeout", "60"] if os.name == 'nt' else ["sleep", "60"] + + nAttempt = 0 + + while True: + if nAttempt == 5: + raise Exception("Max attempt number is exceed.") + + nAttempt += 1 + + logging.info("Attempt #{0}".format(nAttempt)) + + with subprocess.Popen(cmd, shell=True) as process: # shell=True might be needed on Windows + r = process.poll() + + if r is not None: + logging.warning("process.pool() returns an unexpected result: {0}.".format(r)) + continue + + assert r is None + # collect list of processes currently running + children = psutil.Process(os.getpid()).children() + # kill a process, so received children dictionary becomes invalid + process.kill() + process.wait() + # try to handle children list -- missing processes will have ptype "ProcessType.Unknown" + [ProcessProxy(p) for p in children] + break + + def test_upgrade_node(self): + old_bin_dir = os.path.dirname(get_bin_path("pg_config")) + new_bin_dir = os.path.dirname(get_bin_path("pg_config")) + with get_new_node(prefix='node_old', bin_dir=old_bin_dir) as node_old: + node_old.init() + node_old.start() + node_old.stop() + with get_new_node(prefix='node_new', bin_dir=new_bin_dir) as node_new: + node_new.init(cached=False) + res = node_new.upgrade_from(old_node=node_old) + node_new.start() + assert (b'Upgrade Complete' in res) + + class tagPortManagerProxy: + sm_prev_testgres_reserve_port = None + sm_prev_testgres_release_port = None + + sm_DummyPortNumber = None + sm_DummyPortMaxUsage = None + + sm_DummyPortCurrentUsage = None + sm_DummyPortTotalUsage = None + + def __init__(self, dummyPortNumber, dummyPortMaxUsage): + assert type(dummyPortNumber) == int # noqa: E721 + assert type(dummyPortMaxUsage) == int # noqa: E721 + assert dummyPortNumber >= 0 + assert dummyPortMaxUsage >= 0 + + assert __class__.sm_prev_testgres_reserve_port is None + assert __class__.sm_prev_testgres_release_port is None + assert testgres.utils.reserve_port == testgres.utils.internal__reserve_port + assert testgres.utils.release_port == testgres.utils.internal__release_port + + __class__.sm_prev_testgres_reserve_port = testgres.utils.reserve_port + __class__.sm_prev_testgres_release_port = testgres.utils.release_port + + testgres.utils.reserve_port = __class__._proxy__reserve_port + testgres.utils.release_port = __class__._proxy__release_port + + assert testgres.utils.reserve_port == __class__._proxy__reserve_port + assert testgres.utils.release_port == __class__._proxy__release_port + + __class__.sm_DummyPortNumber = dummyPortNumber + __class__.sm_DummyPortMaxUsage = dummyPortMaxUsage + + __class__.sm_DummyPortCurrentUsage = 0 + __class__.sm_DummyPortTotalUsage = 0 + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + assert __class__.sm_DummyPortCurrentUsage == 0 + + assert __class__.sm_prev_testgres_reserve_port is not None + assert __class__.sm_prev_testgres_release_port is not None + + assert testgres.utils.reserve_port == __class__._proxy__reserve_port + assert testgres.utils.release_port == __class__._proxy__release_port + + testgres.utils.reserve_port = __class__.sm_prev_testgres_reserve_port + testgres.utils.release_port = __class__.sm_prev_testgres_release_port + + __class__.sm_prev_testgres_reserve_port = None + __class__.sm_prev_testgres_release_port = None + + @staticmethod + def _proxy__reserve_port(): + assert type(__class__.sm_DummyPortMaxUsage) == int # noqa: E721 + assert type(__class__.sm_DummyPortTotalUsage) == int # noqa: E721 + assert type(__class__.sm_DummyPortCurrentUsage) == int # noqa: E721 + assert __class__.sm_DummyPortTotalUsage >= 0 + assert __class__.sm_DummyPortCurrentUsage >= 0 + + assert __class__.sm_DummyPortTotalUsage <= __class__.sm_DummyPortMaxUsage + assert __class__.sm_DummyPortCurrentUsage <= __class__.sm_DummyPortTotalUsage + + assert __class__.sm_prev_testgres_reserve_port is not None + + if __class__.sm_DummyPortTotalUsage == __class__.sm_DummyPortMaxUsage: + return __class__.sm_prev_testgres_reserve_port() + + __class__.sm_DummyPortTotalUsage += 1 + __class__.sm_DummyPortCurrentUsage += 1 + return __class__.sm_DummyPortNumber + + @staticmethod + def _proxy__release_port(dummyPortNumber): + assert type(dummyPortNumber) == int # noqa: E721 + + assert type(__class__.sm_DummyPortMaxUsage) == int # noqa: E721 + assert type(__class__.sm_DummyPortTotalUsage) == int # noqa: E721 + assert type(__class__.sm_DummyPortCurrentUsage) == int # noqa: E721 + assert __class__.sm_DummyPortTotalUsage >= 0 + assert __class__.sm_DummyPortCurrentUsage >= 0 + + assert __class__.sm_DummyPortTotalUsage <= __class__.sm_DummyPortMaxUsage + assert __class__.sm_DummyPortCurrentUsage <= __class__.sm_DummyPortTotalUsage + + assert __class__.sm_prev_testgres_release_port is not None + + if __class__.sm_DummyPortCurrentUsage > 0 and dummyPortNumber == __class__.sm_DummyPortNumber: + assert __class__.sm_DummyPortTotalUsage > 0 + __class__.sm_DummyPortCurrentUsage -= 1 + return + + return __class__.sm_prev_testgres_release_port(dummyPortNumber) + + def test_port_rereserve_during_node_start(self): + assert testgres.PostgresNode._C_MAX_START_ATEMPTS == 5 + + C_COUNT_OF_BAD_PORT_USAGE = 3 + + with get_new_node() as node1: + node1.init().start() + assert (node1._should_free_port) + assert (type(node1.port) == int) # noqa: E721 + node1_port_copy = node1.port + assert (rm_carriage_returns(node1.safe_psql("SELECT 1;")) == b'1\n') + + with __class__.tagPortManagerProxy(node1.port, C_COUNT_OF_BAD_PORT_USAGE): + assert __class__.tagPortManagerProxy.sm_DummyPortNumber == node1.port + with get_new_node() as node2: + assert (node2._should_free_port) + assert (node2.port == node1.port) + + node2.init().start() + + assert (node2.port != node1.port) + assert (node2._should_free_port) + assert (__class__.tagPortManagerProxy.sm_DummyPortCurrentUsage == 0) + assert (__class__.tagPortManagerProxy.sm_DummyPortTotalUsage == C_COUNT_OF_BAD_PORT_USAGE) + assert (node2.is_started) + + assert (rm_carriage_returns(node2.safe_psql("SELECT 2;")) == b'2\n') + + # node1 is still working + assert (node1.port == node1_port_copy) + assert (node1._should_free_port) + assert (rm_carriage_returns(node1.safe_psql("SELECT 3;")) == b'3\n') + + def test_port_conflict(self): + assert testgres.PostgresNode._C_MAX_START_ATEMPTS > 1 + + C_COUNT_OF_BAD_PORT_USAGE = testgres.PostgresNode._C_MAX_START_ATEMPTS + + with get_new_node() as node1: + node1.init().start() + assert (node1._should_free_port) + assert (type(node1.port) == int) # noqa: E721 + node1_port_copy = node1.port + assert (rm_carriage_returns(node1.safe_psql("SELECT 1;")) == b'1\n') + + with __class__.tagPortManagerProxy(node1.port, C_COUNT_OF_BAD_PORT_USAGE): + assert __class__.tagPortManagerProxy.sm_DummyPortNumber == node1.port + with get_new_node() as node2: + assert (node2._should_free_port) + assert (node2.port == node1.port) + + with pytest.raises( + expected_exception=StartNodeException, + match=re.escape("Cannot start node after multiple attempts.") + ): + node2.init().start() + + assert (node2.port == node1.port) + assert (node2._should_free_port) + assert (__class__.tagPortManagerProxy.sm_DummyPortCurrentUsage == 1) + assert (__class__.tagPortManagerProxy.sm_DummyPortTotalUsage == C_COUNT_OF_BAD_PORT_USAGE) + assert not (node2.is_started) + + # node2 must release our dummyPort (node1.port) + assert (__class__.tagPortManagerProxy.sm_DummyPortCurrentUsage == 0) + + # node1 is still working + assert (node1.port == node1_port_copy) + assert (node1._should_free_port) + assert (rm_carriage_returns(node1.safe_psql("SELECT 3;")) == b'3\n') + + def test_simple_with_bin_dir(self): + with get_new_node() as node: + node.init().start() + bin_dir = node.bin_dir + + app = NodeApp() + with app.make_simple(base_dir=node.base_dir, bin_dir=bin_dir) as correct_bin_dir: + correct_bin_dir.slow_start() + correct_bin_dir.safe_psql("SELECT 1;") + correct_bin_dir.stop() + + while True: + try: + app.make_simple(base_dir=node.base_dir, bin_dir="wrong/path") + except FileNotFoundError: + break # Expected error + except ExecUtilException: + break # Expected error + + raise RuntimeError("Error was expected.") # We should not reach this + + return + + def test_set_auto_conf(self): + # elements contain [property id, value, storage value] + testData = [ + ["archive_command", + "cp '%p' \"/mnt/server/archivedir/%f\"", + "'cp \\'%p\\' \"/mnt/server/archivedir/%f\""], + ["log_line_prefix", + "'\n\r\t\b\\\"", + "'\\\'\\n\\r\\t\\b\\\\\""], + ["log_connections", + True, + "on"], + ["log_disconnections", + False, + "off"], + ["autovacuum_max_workers", + 3, + "3"] + ] + if pg_version_ge('12'): + testData.append(["restore_command", + 'cp "/mnt/server/archivedir/%f" \'%p\'', + "'cp \"/mnt/server/archivedir/%f\" \\'%p\\''"]) + + with get_new_node() as node: + node.init().start() + + options = {} + + for x in testData: + options[x[0]] = x[1] + + node.set_auto_conf(options) + node.stop() + node.slow_start() + + auto_conf_path = f"{node.data_dir}/postgresql.auto.conf" + with open(auto_conf_path, "r") as f: + content = f.read() + + for x in testData: + assert x[0] + " = " + x[2] in content + + @staticmethod + def helper__skip_test_if_util_not_exist(name: str): + assert type(name) == str # noqa: E721 + + if platform.system().lower() == "windows": + name2 = name + ".exe" + else: + name2 = name + + if not util_exists(name2): + pytest.skip('might be missing') diff --git a/tests/test_testgres_remote.py b/tests/test_testgres_remote.py new file mode 100755 index 00000000..6a8d068b --- /dev/null +++ b/tests/test_testgres_remote.py @@ -0,0 +1,190 @@ +# coding: utf-8 +import os + +import pytest +import logging + +from .helpers.global_data import PostgresNodeService +from .helpers.global_data import PostgresNodeServices + +import testgres + +from testgres.exceptions import InitNodeException +from testgres.exceptions import ExecUtilException + +from testgres.config import scoped_config +from testgres.config import testgres_config + +from testgres import get_bin_path +from testgres import get_pg_config + +# NOTE: those are ugly imports + + +def util_exists(util): + def good_properties(f): + return (testgres_config.os_ops.path_exists(f) and # noqa: W504 + testgres_config.os_ops.isfile(f) and # noqa: W504 + testgres_config.os_ops.is_executable(f)) # yapf: disable + + # try to resolve it + if good_properties(get_bin_path(util)): + return True + + # check if util is in PATH + for path in testgres_config.os_ops.environ("PATH").split(testgres_config.os_ops.pathsep): + if good_properties(os.path.join(path, util)): + return True + + +class TestTestgresRemote: + @pytest.fixture(autouse=True, scope="class") + def implicit_fixture(self): + cur_os_ops = PostgresNodeServices.sm_remote.os_ops + assert cur_os_ops is not None + + prev_ops = testgres_config.os_ops + assert prev_ops is not None + testgres_config.set_os_ops(os_ops=cur_os_ops) + assert testgres_config.os_ops is cur_os_ops + yield + assert testgres_config.os_ops is cur_os_ops + testgres_config.set_os_ops(os_ops=prev_ops) + assert testgres_config.os_ops is prev_ops + + def test_init__LANG_ะก(self): + # PBCKP-1744 + prev_LANG = os.environ.get("LANG") + + try: + os.environ["LANG"] = "C" + + with __class__.helper__get_node() as node: + node.init().start() + finally: + __class__.helper__restore_envvar("LANG", prev_LANG) + + def test_init__unk_LANG_and_LC_CTYPE(self): + # PBCKP-1744 + prev_LANG = os.environ.get("LANG") + prev_LANGUAGE = os.environ.get("LANGUAGE") + prev_LC_CTYPE = os.environ.get("LC_CTYPE") + prev_LC_COLLATE = os.environ.get("LC_COLLATE") + + try: + # TODO: Pass unkData through test parameter. + unkDatas = [ + ("UNKNOWN_LANG", "UNKNOWN_CTYPE"), + ("\"UNKNOWN_LANG\"", "\"UNKNOWN_CTYPE\""), + ("\\UNKNOWN_LANG\\", "\\UNKNOWN_CTYPE\\"), + ("\"UNKNOWN_LANG", "UNKNOWN_CTYPE\""), + ("\\UNKNOWN_LANG", "UNKNOWN_CTYPE\\"), + ("\\", "\\"), + ("\"", "\""), + ] + + errorIsDetected = False + + for unkData in unkDatas: + logging.info("----------------------") + logging.info("Unk LANG is [{0}]".format(unkData[0])) + logging.info("Unk LC_CTYPE is [{0}]".format(unkData[1])) + + os.environ["LANG"] = unkData[0] + os.environ.pop("LANGUAGE", None) + os.environ["LC_CTYPE"] = unkData[1] + os.environ.pop("LC_COLLATE", None) + + assert os.environ.get("LANG") == unkData[0] + assert not ("LANGUAGE" in os.environ.keys()) + assert os.environ.get("LC_CTYPE") == unkData[1] + assert not ("LC_COLLATE" in os.environ.keys()) + + assert os.getenv('LANG') == unkData[0] + assert os.getenv('LANGUAGE') is None + assert os.getenv('LC_CTYPE') == unkData[1] + assert os.getenv('LC_COLLATE') is None + + exc: ExecUtilException = None + with __class__.helper__get_node() as node: + try: + node.init() # IT RAISES! + except InitNodeException as e: + exc = e.__cause__ + assert exc is not None + assert isinstance(exc, ExecUtilException) + + if exc is None: + logging.warning("We expected an error!") + continue + + errorIsDetected = True + + assert isinstance(exc, ExecUtilException) + + errMsg = str(exc) + logging.info("Error message is {0}: {1}".format(type(exc).__name__, errMsg)) + + assert "warning: setlocale: LC_CTYPE: cannot change locale (" + unkData[1] + ")" in errMsg + assert "initdb: error: invalid locale settings; check LANG and LC_* environment variables" in errMsg + continue + + if not errorIsDetected: + pytest.xfail("All the bad data are processed without errors!") + + finally: + __class__.helper__restore_envvar("LANG", prev_LANG) + __class__.helper__restore_envvar("LANGUAGE", prev_LANGUAGE) + __class__.helper__restore_envvar("LC_CTYPE", prev_LC_CTYPE) + __class__.helper__restore_envvar("LC_COLLATE", prev_LC_COLLATE) + + def test_pg_config(self): + # check same instances + a = get_pg_config() + b = get_pg_config() + assert (id(a) == id(b)) + + # save right before config change + c1 = get_pg_config() + + # modify setting for this scope + with scoped_config(cache_pg_config=False) as config: + # sanity check for value + assert not (config.cache_pg_config) + + # save right after config change + c2 = get_pg_config() + + # check different instances after config change + assert (id(c1) != id(c2)) + + # check different instances + a = get_pg_config() + b = get_pg_config() + assert (id(a) != id(b)) + + @staticmethod + def helper__get_node(name=None): + svc = PostgresNodeServices.sm_remote + + assert isinstance(svc, PostgresNodeService) + assert isinstance(svc.os_ops, testgres.OsOperations) + assert isinstance(svc.port_manager, testgres.PortManager) + + return testgres.PostgresNode( + name, + os_ops=svc.os_ops, + port_manager=svc.port_manager) + + @staticmethod + def helper__restore_envvar(name, prev_value): + if prev_value is None: + os.environ.pop(name, None) + else: + os.environ[name] = prev_value + + @staticmethod + def helper__skip_test_if_util_not_exist(name: str): + assert type(name) == str # noqa: E721 + if not util_exists(name): + pytest.skip('might be missing') diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..39e9dda0 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,63 @@ +from .helpers.global_data import OsOpsDescr +from .helpers.global_data import OsOpsDescrs +from .helpers.global_data import OsOperations + +from testgres.utils import parse_pg_version +from testgres.utils import get_pg_config2 +from testgres import scoped_config + +import pytest +import typing + + +class TestUtils: + sm_os_ops_descrs: typing.List[OsOpsDescr] = [ + OsOpsDescrs.sm_local_os_ops_descr, + OsOpsDescrs.sm_remote_os_ops_descr + ] + + @pytest.fixture( + params=[descr.os_ops for descr in sm_os_ops_descrs], + ids=[descr.sign for descr in sm_os_ops_descrs] + ) + def os_ops(self, request: pytest.FixtureRequest) -> OsOperations: + assert isinstance(request, pytest.FixtureRequest) + assert isinstance(request.param, OsOperations) + return request.param + + def test_parse_pg_version(self): + # Linux Mint + assert parse_pg_version("postgres (PostgreSQL) 15.5 (Ubuntu 15.5-1.pgdg22.04+1)") == "15.5" + # Linux Ubuntu + assert parse_pg_version("postgres (PostgreSQL) 12.17") == "12.17" + # Windows + assert parse_pg_version("postgres (PostgreSQL) 11.4") == "11.4" + # Macos + assert parse_pg_version("postgres (PostgreSQL) 14.9 (Homebrew)") == "14.9" + + def test_get_pg_config2(self, os_ops: OsOperations): + assert isinstance(os_ops, OsOperations) + + # check same instances + a = get_pg_config2(os_ops, None) + b = get_pg_config2(os_ops, None) + assert (id(a) == id(b)) + + # save right before config change + c1 = get_pg_config2(os_ops, None) + + # modify setting for this scope + with scoped_config(cache_pg_config=False) as config: + # sanity check for value + assert not (config.cache_pg_config) + + # save right after config change + c2 = get_pg_config2(os_ops, None) + + # check different instances after config change + assert (id(c1) != id(c2)) + + # check different instances + a = get_pg_config2(os_ops, None) + b = get_pg_config2(os_ops, None) + assert (id(a) != id(b))