diff --git a/README.md b/README.md index 29b974dc..a2a0ec7e 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,7 @@ the configuration file, which means that they should be called before `append_co ### Remote mode Testgres supports the creation of PostgreSQL nodes on a remote host. This is useful when you want to run distributed tests involving multiple nodes spread across different machines. -To use this feature, you need to use the RemoteOperations class. +To use this feature, you need to use the RemoteOperations class. This feature is only supported with Linux. Here is an example of how you might set this up: ```python diff --git a/setup.py b/setup.py index 8cb0f70a..074de8a1 100755 --- a/setup.py +++ b/setup.py @@ -12,7 +12,6 @@ "six>=1.9.0", "psutil", "packaging", - "paramiko", "fabric", "sshtunnel" ] @@ -30,7 +29,7 @@ readme = f.read() setup( - version='1.9.0', + version='1.9.1', name='testgres', packages=['testgres', 'testgres.operations'], description='Testing utility for PostgreSQL and its extensions', diff --git a/testgres/__init__.py b/testgres/__init__.py index b63c7df1..383daf2d 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -46,6 +46,8 @@ First, \ Any +from .config import testgres_config + from .operations.os_ops import OsOperations, ConnectionParams from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations @@ -53,7 +55,7 @@ __all__ = [ "get_new_node", "get_remote_node", - "NodeBackup", + "NodeBackup", "testgres_config", "TestgresConfig", "configure_testgres", "scoped_config", "push_config", "pop_config", "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", "TestgresException", "ExecUtilException", "QueryException", "TimeoutException", "CatchUpException", "StartNodeException", "InitNodeException", "BackupException", diff --git a/testgres/cache.py b/testgres/cache.py index bf8658c9..21198e83 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -57,7 +57,9 @@ def call_initdb(initdb_dir, log=logfile): # our initdb caching mechanism breaks this contract. pg_control = os.path.join(data_dir, XLOG_CONTROL_FILE) system_id = generate_system_id() - os_ops.write(pg_control, system_id, truncate=True, binary=True, read_and_write=True) + cur_pg_control = os_ops.read(pg_control, binary=True) + new_pg_control = system_id + cur_pg_control[len(system_id):] + os_ops.write(pg_control, new_pg_control, truncate=True, binary=True, read_and_write=True) # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 89071282..318ae675 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -198,9 +198,15 @@ def touch(self, filename): with open(filename, "a"): os.utime(filename, None) - def read(self, filename, encoding=None): - with open(filename, "r", encoding=encoding) as file: - return file.read() + def read(self, filename, encoding=None, binary=False): + mode = "rb" if binary else "r" + with open(filename, mode) as file: + content = file.read() + if binary: + return content + if isinstance(content, bytes): + return content.decode(encoding or 'utf-8') + return content def readlines(self, filename, num_lines=0, binary=False, encoding=None): """ diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 6815c7f1..5d9bfe7e 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,13 +1,12 @@ +import locale +import logging import os +import subprocess import tempfile import time -from typing import Optional import sshtunnel -import paramiko -from paramiko import SSHClient - from ..exceptions import ExecUtilException from .os_ops import OsOperations, ConnectionParams @@ -16,6 +15,9 @@ sshtunnel.SSH_TIMEOUT = 5.0 sshtunnel.TUNNEL_TIMEOUT = 5.0 +ConsoleEncoding = locale.getdefaultlocale()[1] +if not ConsoleEncoding: + ConsoleEncoding = 'UTF-8' error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] @@ -31,33 +33,29 @@ def kill(self): def cmdline(self): command = "ps -p {} -o cmd --no-headers".format(self.pid) - stdin, stdout, stderr = self.ssh.exec_command(command) - cmdline = stdout.read().decode('utf-8').strip() + stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=ConsoleEncoding) + cmdline = stdout.strip() return cmdline.split() class RemoteOperations(OsOperations): def __init__(self, conn_params: ConnectionParams): + if os.name != "posix": + raise EnvironmentError("Remote operations are supported only on Linux!") + super().__init__(conn_params.username) self.conn_params = conn_params self.host = conn_params.host self.ssh_key = conn_params.ssh_key - self.ssh = self.ssh_connect() self.remote = True self.username = conn_params.username or self.get_user() - self.tunnel = None + self.add_known_host(self.host) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close_tunnel() - if getattr(self, 'ssh', None): - self.ssh.close() - - def __del__(self): - if getattr(self, 'ssh', None): - self.ssh.close() def close_tunnel(self): if getattr(self, 'tunnel', None): @@ -68,26 +66,17 @@ def close_tunnel(self): break time.sleep(0.5) - def ssh_connect(self) -> Optional[SSHClient]: - key = self._read_ssh_key() - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - return ssh - - def _read_ssh_key(self): + def add_known_host(self, host): + cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin()) try: - with open(self.ssh_key, "r") as f: - key_data = f.read() - if "BEGIN OPENSSH PRIVATE KEY" in key_data: - key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) - else: - key = paramiko.RSAKey.from_private_key_file(self.ssh_key) - return key - except FileNotFoundError: - raise ExecUtilException(message="No such file or directory: '{}'".format(self.ssh_key)) - except Exception as e: - ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) + subprocess.check_call( + cmd, + shell=True, + ) + logging.info("Successfully added %s to known_hosts." % host) + except subprocess.CalledProcessError as e: + raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd, + exit_code=e.returncode, out=e.stderr) def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None, @@ -97,49 +86,34 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa Args: - cmd (str): The command to be executed. """ - if self.ssh is None or not self.ssh.get_transport() or not self.ssh.get_transport().is_active(): - self.ssh = self.ssh_connect() - - if isinstance(cmd, list): - cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) - if input: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - stdin.write(input) - stdin.flush() - else: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() + if isinstance(cmd, str): + ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd] + elif isinstance(cmd, list): + ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd + process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + result, error = process.communicate(input) + exit_status = process.returncode if encoding: - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) - else: - result = stdout.read() - error = stderr.read() + result = result.decode(encoding) + error = error.decode(encoding) if expect_error: raise Exception(result, error) - if encoding: - error_found = exit_status != 0 or any( - marker.decode(encoding) in error for marker in error_markers) + if not error: + error_found = 0 else: error_found = exit_status != 0 or any( - marker in error for marker in error_markers) + marker in error for marker in [b'error', b'Permission denied', b'fatal', b'No such file or directory']) if error_found: - if exit_status == 0: - exit_status = 1 - if encoding: - message = "Utility exited with non-zero code. Error: {}".format(error.decode(encoding)) - else: + if isinstance(error, bytes): message = b"Utility exited with non-zero code. Error: " + error - raise ExecUtilException(message=message, - command=cmd, - exit_code=exit_status, - out=result) + else: + message = f"Utility exited with non-zero code. Error: {error}" + raise ExecUtilException(message=message, command=cmd, exit_code=exit_status, out=result) if verbose: return exit_status, result, error @@ -154,7 +128,7 @@ def environ(self, var_name: str) -> str: - var_name (str): The name of the environment variable. """ cmd = "echo ${}".format(var_name) - return self.exec_command(cmd, encoding='utf-8').strip() + return self.exec_command(cmd, encoding=ConsoleEncoding).strip() def find_executable(self, executable): search_paths = self.environ("PATH") @@ -185,11 +159,11 @@ def set_env(self, var_name: str, var_val: str): # Get environment variables def get_user(self): - return self.exec_command("echo $USER", encoding='utf-8').strip() + return self.exec_command("echo $USER", encoding=ConsoleEncoding).strip() def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd, encoding='utf-8').strip() + return self.exec_command(cmd, encoding=ConsoleEncoding).strip() # Work with dirs def makedirs(self, path, remove_existing=False): @@ -236,7 +210,7 @@ def listdir(self, path): return result.splitlines() def path_exists(self, path): - result = self.exec_command("test -e {}; echo $?".format(path), encoding='utf-8') + result = self.exec_command("test -e {}; echo $?".format(path), encoding=ConsoleEncoding) return int(result.strip()) == 0 @property @@ -257,22 +231,25 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. """ if prefix: - temp_dir = self.exec_command("mktemp -d {}XXXXX".format(prefix), encoding='utf-8') + command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"mktemp -d {prefix}XXXXX"] else: - temp_dir = self.exec_command("mktemp -d", encoding='utf-8') + command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", "mktemp -d"] - if temp_dir: + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + if result.returncode == 0: + temp_dir = result.stdout.strip() if not os.path.isabs(temp_dir): - temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + temp_dir = os.path.join('/home', self.username, temp_dir) return temp_dir else: - raise ExecUtilException("Could not create temporary directory.") + raise ExecUtilException(f"Could not create temporary directory. Error: {result.stderr}") def mkstemp(self, prefix=None): if prefix: - temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding='utf-8') + temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=ConsoleEncoding) else: - temp_dir = self.exec_command("mktemp", encoding='utf-8') + temp_dir = self.exec_command("mktemp", encoding=ConsoleEncoding) if temp_dir: if not os.path.isabs(temp_dir): @@ -289,20 +266,7 @@ def copytree(self, src, dst): return self.exec_command("cp -r {} {}".format(src, dst)) # Work with files - def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): - """ - Write data to a file on a remote host - - Args: - - filename (str): The file path where the data will be written. - - data (bytes or str): The data to be written to the file. - - truncate (bool): If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - - binary (bool): If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - - read_and_write (bool): If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). - """ + def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=ConsoleEncoding): mode = "wb" if binary else "w" if not truncate: mode = "ab" if binary else "a" @@ -311,35 +275,29 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: if not truncate: - with self.ssh_connect() as ssh: - sftp = ssh.open_sftp() - try: - sftp.get(filename, tmp_file.name) - tmp_file.seek(0, os.SEEK_END) - except FileNotFoundError: - pass # File does not exist yet, we'll create it - sftp.close() + scp_cmd = ['scp', '-i', self.ssh_key, f"{self.username}@{self.host}:{filename}", tmp_file.name] + subprocess.run(scp_cmd, check=False) # The file might not exist yet + tmp_file.seek(0, os.SEEK_END) + if isinstance(data, bytes) and not binary: data = data.decode(encoding) elif isinstance(data, str) and binary: data = data.encode(encoding) + if isinstance(data, list): - # ensure each line ends with a newline - data = [(s if isinstance(s, str) else s.decode('utf-8')).rstrip('\n') + '\n' for s in data] + data = [(s if isinstance(s, str) else s.decode(ConsoleEncoding)).rstrip('\n') + '\n' for s in data] tmp_file.writelines(data) else: tmp_file.write(data) + tmp_file.flush() - with self.ssh_connect() as ssh: - sftp = ssh.open_sftp() - remote_directory = os.path.dirname(filename) - try: - sftp.stat(remote_directory) - except IOError: - sftp.mkdir(remote_directory) - sftp.put(tmp_file.name, filename) - sftp.close() + scp_cmd = ['scp', '-i', self.ssh_key, tmp_file.name, f"{self.username}@{self.host}:{filename}"] + subprocess.run(scp_cmd, check=True) + + remote_directory = os.path.dirname(filename) + mkdir_cmd = ['ssh', '-i', self.ssh_key, f"{self.username}@{self.host}", f"mkdir -p {remote_directory}"] + subprocess.run(mkdir_cmd, check=True) os.remove(tmp_file.name) @@ -359,7 +317,7 @@ def read(self, filename, binary=False, encoding=None): result = self.exec_command(cmd, encoding=encoding) if not binary and result: - result = result.decode(encoding or 'utf-8') + result = result.decode(encoding or ConsoleEncoding) return result @@ -372,7 +330,7 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): result = self.exec_command(cmd, encoding=encoding) if not binary and result: - lines = result.decode(encoding or 'utf-8').splitlines() + lines = result.decode(encoding or ConsoleEncoding).splitlines() else: lines = result.splitlines() @@ -400,13 +358,18 @@ def kill(self, pid, signal): def get_pid(self): # Get current process id - return int(self.exec_command("echo $$", encoding='utf-8')) + return int(self.exec_command("echo $$", encoding=ConsoleEncoding)) def get_process_children(self, pid): - command = "pgrep -P {}".format(pid) - stdin, stdout, stderr = self.ssh.exec_command(command) - children = stdout.readlines() - return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] + command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"] + + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + if result.returncode == 0: + children = result.stdout.strip().splitlines() + return [PsUtilProcessProxy(self, int(child_pid.strip())) for child_pid in children] + else: + raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}") # Database control def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): @@ -424,18 +387,19 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s """ self.close_tunnel() self.tunnel = sshtunnel.open_tunnel( - (host, 22), # Remote server IP and SSH port - ssh_username=user or self.username, - ssh_pkey=ssh_key or self.ssh_key, - remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port - local_bind_address=('localhost', port) # Local machine IP and available port + (self.host, 22), # Remote server IP and SSH port + ssh_username=self.username, + ssh_pkey=self.ssh_key, + remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', 0) + # Local machine IP and available port (0 means it will pick any available port) ) - self.tunnel.start() try: + # Use localhost and self.tunnel.local_bind_port to connect conn = pglib.connect( - host=host, # change to 'localhost' because we're connecting through a local ssh tunnel + host='localhost', # Connect to localhost port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel database=dbname, user=user or self.username, diff --git a/testgres/utils.py b/testgres/utils.py index 5e12eba9..b7df70d1 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -118,11 +118,13 @@ def get_bin_path(filename): return filename -def get_pg_config(pg_config_path=None): +def get_pg_config(pg_config_path=None, os_ops=None): """ Return output of pg_config (provided that it is installed). NOTE: this function caches the result by default (see GlobalConfig). """ + if os_ops: + tconf.os_ops = os_ops def cache_pg_config_data(cmd): # execute pg_config and get the output @@ -146,7 +148,7 @@ def cache_pg_config_data(cmd): _pg_config_data = {} # return cached data - if _pg_config_data: + if not pg_config_path and _pg_config_data: return _pg_config_data # try specified pg_config path or PG_CONFIG diff --git a/tests/test_remote.py b/tests/test_remote.py index 3794349c..2e0f0676 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -17,9 +17,6 @@ def setup(self): 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') self.operations = RemoteOperations(conn_params) - yield - self.operations.__del__() - def test_exec_command_success(self): """ Test exec_command for successful command execution. diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index e8386383..44e77fbd 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -135,7 +135,6 @@ def test_init_after_cleanup(self): @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_init_unique_system_id(self): - # FAIL # this function exists in PostgreSQL 9.6+ query = 'select system_identifier from pg_control_system()'