Skip to content

Fix initdb error on Windows #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix initdb error on Windows
  • Loading branch information
vshepard committed Dec 14, 2023
commit edb5708ac4c2aafac9717911ee0ac7f3ea30e5df
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
readme = f.read()

setup(
version='1.9.2',
version='1.9.3',
name='testgres',
packages=['testgres', 'testgres.operations'],
description='Testing utility for PostgreSQL and its extensions',
url='https://github.com/postgrespro/testgres',
long_description=readme,
long_description_content_type='text/markdown',
license='PostgreSQL',
author='Ildar Musin',
author_email='[email protected]',
author='Postgres Professional',
author_email='[email protected]',
keywords=['test', 'testing', 'postgresql'],
install_requires=install_requires,
classifiers=[],
Expand Down
100 changes: 85 additions & 15 deletions testgres/operations/local_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import psutil

from ..exceptions import ExecUtilException
from .os_ops import ConnectionParams, OsOperations
from .os_ops import pglib
from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding

try:
from shutil import which as find_executable
Expand All @@ -22,6 +21,12 @@
error_markers = [b'error', b'Permission denied', b'fatal']


def has_errors(output):
if isinstance(output, str):
output = output.encode(get_default_encoding())
return any(marker in output for marker in error_markers)


class LocalOperations(OsOperations):
def __init__(self, conn_params=None):
if conn_params is None:
Expand All @@ -33,7 +38,38 @@ def __init__(self, conn_params=None):
self.remote = False
self.username = conn_params.username or self.get_user()

# Command execution
@staticmethod
def _run_command(cmd, shell, input, timeout, encoding, temp_file=None):
"""Execute a command and return the process."""
if temp_file is not None:
stdout = temp_file
stderr = subprocess.STDOUT
else:
stdout = subprocess.PIPE
stderr = subprocess.PIPE

process = subprocess.Popen(
cmd,
shell=shell,
stdin=subprocess.PIPE if input is not None else None,
stdout=stdout,
stderr=stderr,
)

try:
return process.communicate(input=input.encode(encoding) if input else None, timeout=timeout), process
except subprocess.TimeoutExpired:
process.kill()
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))

@staticmethod
def _raise_exec_exception(message, command, exit_code, output):
"""Raise an ExecUtilException."""
raise ExecUtilException(message=message.format(output),
command=command,
exit_code=exit_code,
out=output)

def exec_command(self, cmd, wait_exit=False, verbose=False,
expect_error=False, encoding=None, shell=False, text=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
Expand All @@ -56,16 +92,15 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
:return: The output of the subprocess.
"""
if os.name == 'nt':
with tempfile.NamedTemporaryFile() as buf:
process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT)
process.communicate()
buf.seek(0)
result = buf.read().decode(encoding)
return result
return self._exec_command_windows(cmd, wait_exit=wait_exit, verbose=verbose,
expect_error=expect_error, encoding=encoding, shell=shell, text=text,
input=input, stdin=stdin, stdout=stdout, stderr=stderr,
get_process=get_process, timeout=timeout)
else:
process = subprocess.Popen(
cmd,
shell=shell,
stdin=stdin,
stdout=stdout,
stderr=stderr,
)
Expand All @@ -79,7 +114,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
exit_status = process.returncode

error_found = exit_status != 0 or any(marker in error for marker in error_markers)
error_found = exit_status != 0 or has_errors(error)

if encoding:
result = result.decode(encoding)
Expand All @@ -91,15 +126,50 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
if exit_status != 0 or error_found:
if exit_status == 0:
exit_status = 1
raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error),
command=cmd,
exit_code=exit_status,
out=result)
self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, exit_status, result)
if verbose:
return exit_status, result, error
else:
return result

@staticmethod
def _process_output(process, encoding, temp_file=None):
"""Process the output of a command."""
if temp_file is not None:
temp_file.seek(0)
output = temp_file.read()
else:
output = process.stdout.read()

if encoding:
output = output.decode(encoding)

return output

def _exec_command_windows(self, cmd, wait_exit=False, verbose=False,
expect_error=False, encoding=None, shell=False, text=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
get_process=None, timeout=None):
with tempfile.NamedTemporaryFile(mode='w+b') as temp_file:
_, process = self._run_command(cmd, shell, input, timeout, encoding, temp_file)
if get_process:
return process
output = self._process_output(process, encoding, temp_file)

if process.returncode != 0 or has_errors(output):
if process.returncode == 0:
process.returncode = 1
if expect_error:
if verbose:
return process.returncode, output, output
else:
return output
else:
self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode,
output)

return (process.returncode, output, output) if verbose else output

# Environment setup
def environ(self, var_name):
return os.environ.get(var_name)
Expand Down Expand Up @@ -210,7 +280,7 @@ def read(self, filename, encoding=None, binary=False):
if binary:
return content
if isinstance(content, bytes):
return content.decode(encoding or 'utf-8')
return content.decode(encoding or get_default_encoding())
return content

def readlines(self, filename, num_lines=0, binary=False, encoding=None):
Expand Down
6 changes: 6 additions & 0 deletions testgres/operations/os_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import locale

try:
import psycopg2 as pglib # noqa: F401
except ImportError:
Expand All @@ -14,6 +16,10 @@ def __init__(self, host='127.0.0.1', ssh_key=None, username=None):
self.username = username


def get_default_encoding():
return locale.getdefaultlocale()[1] or 'UTF-8'


class OsOperations:
def __init__(self, username=None):
self.ssh_key = None
Expand Down
34 changes: 15 additions & 19 deletions testgres/operations/remote_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import locale
import logging
import os
import subprocess
Expand All @@ -15,12 +14,7 @@
raise ImportError("You must have psycopg2 or pg8000 modules installed")

from ..exceptions import ExecUtilException

from .os_ops import OsOperations, ConnectionParams

ConsoleEncoding = locale.getdefaultlocale()[1]
if not ConsoleEncoding:
ConsoleEncoding = 'UTF-8'
from .os_ops import OsOperations, ConnectionParams, get_default_encoding

error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory']

Expand All @@ -36,7 +30,7 @@ def kill(self):

def cmdline(self):
command = "ps -p {} -o cmd --no-headers".format(self.pid)
stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=ConsoleEncoding)
stdin, stdout, stderr = self.ssh.exec_command(command, verbose=True, encoding=get_default_encoding())
cmdline = stdout.strip()
return cmdline.split()

Expand Down Expand Up @@ -145,7 +139,7 @@ def environ(self, var_name: str) -> str:
- var_name (str): The name of the environment variable.
"""
cmd = "echo ${}".format(var_name)
return self.exec_command(cmd, encoding=ConsoleEncoding).strip()
return self.exec_command(cmd, encoding=get_default_encoding()).strip()

def find_executable(self, executable):
search_paths = self.environ("PATH")
Expand Down Expand Up @@ -176,11 +170,11 @@ def set_env(self, var_name: str, var_val: str):

# Get environment variables
def get_user(self):
return self.exec_command("echo $USER", encoding=ConsoleEncoding).strip()
return self.exec_command("echo $USER", encoding=get_default_encoding()).strip()

def get_name(self):
cmd = 'python3 -c "import os; print(os.name)"'
return self.exec_command(cmd, encoding=ConsoleEncoding).strip()
return self.exec_command(cmd, encoding=get_default_encoding()).strip()

# Work with dirs
def makedirs(self, path, remove_existing=False):
Expand Down Expand Up @@ -227,7 +221,7 @@ def listdir(self, path):
return result.splitlines()

def path_exists(self, path):
result = self.exec_command("test -e {}; echo $?".format(path), encoding=ConsoleEncoding)
result = self.exec_command("test -e {}; echo $?".format(path), encoding=get_default_encoding())
return int(result.strip()) == 0

@property
Expand Down Expand Up @@ -264,9 +258,9 @@ def mkdtemp(self, prefix=None):

def mkstemp(self, prefix=None):
if prefix:
temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=ConsoleEncoding)
temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding=get_default_encoding())
else:
temp_dir = self.exec_command("mktemp", encoding=ConsoleEncoding)
temp_dir = self.exec_command("mktemp", encoding=get_default_encoding())

if temp_dir:
if not os.path.isabs(temp_dir):
Expand All @@ -283,7 +277,9 @@ def copytree(self, src, dst):
return self.exec_command("cp -r {} {}".format(src, dst))

# Work with files
def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=ConsoleEncoding):
def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding=None):
if not encoding:
encoding = get_default_encoding()
mode = "wb" if binary else "w"
if not truncate:
mode = "ab" if binary else "a"
Expand All @@ -302,7 +298,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
data = data.encode(encoding)

if isinstance(data, list):
data = [(s if isinstance(s, str) else s.decode(ConsoleEncoding)).rstrip('\n') + '\n' for s in data]
data = [(s if isinstance(s, str) else s.decode(get_default_encoding())).rstrip('\n') + '\n' for s in data]
tmp_file.writelines(data)
else:
tmp_file.write(data)
Expand Down Expand Up @@ -334,7 +330,7 @@ def read(self, filename, binary=False, encoding=None):
result = self.exec_command(cmd, encoding=encoding)

if not binary and result:
result = result.decode(encoding or ConsoleEncoding)
result = result.decode(encoding or get_default_encoding())

return result

Expand All @@ -347,7 +343,7 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None):
result = self.exec_command(cmd, encoding=encoding)

if not binary and result:
lines = result.decode(encoding or ConsoleEncoding).splitlines()
lines = result.decode(encoding or get_default_encoding()).splitlines()
else:
lines = result.splitlines()

Expand Down Expand Up @@ -375,7 +371,7 @@ def kill(self, pid, signal):

def get_pid(self):
# Get current process id
return int(self.exec_command("echo $$", encoding=ConsoleEncoding))
return int(self.exec_command("echo $$", encoding=get_default_encoding()))

def get_process_children(self, pid):
command = ["ssh", "-i", self.ssh_key, f"{self.username}@{self.host}", f"pgrep -P {pid}"]
Expand Down
48 changes: 44 additions & 4 deletions testgres/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from __future__ import print_function

import os
import port_for
import random
import socket

import sys

from contextlib import contextmanager
from packaging.version import Version, InvalidVersion
import re

from port_for import PortForException
from six import iteritems

from .exceptions import ExecUtilException
Expand All @@ -37,13 +40,49 @@ def reserve_port():
"""
Generate a new port and add it to 'bound_ports'.
"""

port = port_for.select_random(exclude_ports=bound_ports)
port = select_random(exclude_ports=bound_ports)
bound_ports.add(port)

return port


def select_random(
ports=None,
exclude_ports=None,
) -> int:
"""
Return random unused port number.
Standard function from port_for does not work on Windows because of error
'port_for.exceptions.PortForException: Can't select a port'
We should update it.
"""
if ports is None:
ports = set(range(1024, 65535))

if exclude_ports is None:
exclude_ports = set()

ports.difference_update(set(exclude_ports))

sampled_ports = random.sample(tuple(ports), min(len(ports), 100))

for port in sampled_ports:
if is_port_free(port):
return port

raise PortForException("Can't select a port")


def is_port_free(port: int) -> bool:
"""Check if a port is free to use."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("", port))
return True
except OSError:
return False


def release_port(port):
"""
Free port provided by reserve_port().
Expand Down Expand Up @@ -80,7 +119,8 @@ def execute_utility(args, logfile=None, verbose=False):
lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n']
tconf.os_ops.write(filename=logfile, data=lines)
except IOError:
raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
raise ExecUtilException(
"Problem with writing to logfile `{}` during run command `{}`".format(logfile, args))
if verbose:
return exit_status, out, error
else:
Expand Down
Loading