From 03ff2ad03776bd1f96e8e84b6a330aaf6cf370ba Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 14:38:49 +0200 Subject: [PATCH 001/118] Support Python 3.10 and Postgres 15 --- .gitignore | 2 ++ LICENSE.txt | 2 +- docs/about.txt | 6 +++--- docs/announce.rst | 4 ++-- docs/conf.py | 2 +- docs/contents/changelog.rst | 4 ++++ docs/contents/install.rst | 2 +- docs/copyright.rst | 2 +- pg.py | 6 +++--- pgconn.c | 2 +- pgdb.py | 2 +- pginternal.c | 2 +- pglarge.c | 2 +- pgmodule.c | 2 +- pgnotice.c | 2 +- pgquery.c | 2 +- pgsource.c | 2 +- setup.py | 7 ++++--- tests/test_classic_connection.py | 6 +++--- tests/test_classic_dbwrapper.py | 2 +- tests/test_classic_functions.py | 2 +- tox.ini | 10 +++++----- 22 files changed, 40 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 67e3ae35..71300f9e 100644 --- a/.gitignore +++ b/.gitignore @@ -23,11 +23,13 @@ _build_doctrees/ docker-compose.yml Dockerfile Vagrantfile +Vagrantfile-* .coverage .tox/ .venv/ .vagrant/ +.vagrant-*/ Thumbs.db .DS_Store diff --git a/LICENSE.txt b/LICENSE.txt index c10a5870..eea706fe 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2022 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2023 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/docs/about.txt b/docs/about.txt index 3463b3b7..d1492061 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -5,7 +5,7 @@ powerful PostgreSQL features from Python. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2022 by the PyGreSQL team. + | Further modifications are copyright © 2009-2023 by the PyGreSQL team. | For licensing details, see the full :doc:`copyright`. **PostgreSQL** is a highly scalable, SQL compliant, open source @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.4 needs PostgreSQL 9.0 to 9.6 or 10 to 14, and -Python 2.7 or 3.5 to 3.10. If you need to support older PostgreSQL versions or +The current version PyGreSQL 5.2.4 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and +Python 2.7 or 3.5 to 3.11. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index 0c90d212..cadf376b 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -22,8 +22,8 @@ This version has been built and unit tested on: - openSUSE - Ubuntu - Windows 7 and 10 with both MinGW and Visual Studio - - PostgreSQL 9.0 to 9.6 and 10 to 14 (32 and 64bit) - - Python 2.7 and 3.5 to 3.10 (32 and 64bit) + - PostgreSQL 9.0 to 9.6 and 10 to 15 (32 and 64bit) + - Python 2.7 and 3.5 to 3.11 (32 and 64bit) | D'Arcy J.M. Cain | darcy@PyGreSQL.org diff --git a/docs/conf.py b/docs/conf.py index 6ea28189..6a9f87e0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,7 +61,7 @@ # General information about the project. project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2022, ' + author +copyright = '2023, ' + author # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 975ad682..0944b349 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,10 @@ ChangeLog ========= +Version 5.2.5 (to be released) +------------------------------ +- This version officially supports the new Python 3.11 and PostgreSQL 15. + Version 5.2.4 (2022-03-26) -------------------------- - Three more fixes in the `inserttable()` method of the `pg` module: diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 1b7ef55e..4ef323af 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -2.7 and 3.5 to 3.10, and PostgreSQL versions 9.0 to 9.6 and 10 to 14. +2.7 and 3.5 to 3.11, and PostgreSQL versions 9.0 to 9.6 and 10 to 15. PyGreSQL will be installed as three modules, a shared library called ``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure diff --git a/docs/copyright.rst b/docs/copyright.rst index 4c9aacc6..9a8113ec 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2022 by the PyGreSQL team. +Further modifications copyright (c) 2009-2023 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/pg.py b/pg.py index b03c3d71..fbd97725 100644 --- a/pg.py +++ b/pg.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2022 by the PyGreSQL Development Team +# Copyright (c) 2023 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. @@ -99,7 +99,7 @@ try: # noinspection PyUnresolvedReferences - from typing import Dict, List, Union + from typing import Dict, List, Union # noqa: F401 has_typing = True except ImportError: # Python < 3.5 has_typing = False @@ -1934,7 +1934,7 @@ def set_parameter(self, parameter, value=None, local=False): value = set(value) if len(value) == 1: value = value.pop() - if not(value is None or isinstance(value, basestring)): + if not (value is None or isinstance(value, basestring)): raise ValueError( 'A single value must be specified' ' when parameter is a set') diff --git a/pgconn.c b/pgconn.c index 6d12ede4..e8548d32 100644 --- a/pgconn.c +++ b/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgdb.py b/pgdb.py index 7b0eaefc..7eaf9cb0 100644 --- a/pgdb.py +++ b/pgdb.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2022 by the PyGreSQL Development Team +# Copyright (c) 2023 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. diff --git a/pginternal.c b/pginternal.c index 91c565be..6dcad8bc 100644 --- a/pginternal.c +++ b/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pglarge.c b/pglarge.c index ed8f1824..c080d658 100644 --- a/pglarge.c +++ b/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgmodule.c b/pgmodule.c index 07982aad..bbb4b0db 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgnotice.c b/pgnotice.c index 7e5b93c7..ae6b2b68 100644 --- a/pgnotice.c +++ b/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgquery.c b/pgquery.c index a04eb68b..852c848b 100644 --- a/pgquery.c +++ b/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pgsource.c b/pgsource.c index 4fa04365..053ad02f 100644 --- a/pgsource.c +++ b/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2022 by the PyGreSQL Development Team + * Copyright (c) 2023 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/setup.py b/setup.py index 3cfbe278..cdb20c4f 100755 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ # # PyGreSQL - a Python interface for the PostgreSQL database. # -# Copyright (c) 2022 by the PyGreSQL Development Team +# Copyright (c) 2023 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. @@ -26,8 +26,8 @@ * PostgreSQL pg_config tool (usually included in the devel package) (the Windows installer has it as part of the database server feature) -PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.10, -and PostgreSQL versions 9.0 to 9.6 and 10 to 14. +PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.11, +and PostgreSQL versions 9.0 to 9.6 and 10 to 15. Use as follows: python setup.py build_ext # to build the module @@ -252,6 +252,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 32c21870..bd423d91 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -195,7 +195,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 150000) + self.assertTrue(90000 <= server_version < 160000) def testAttributeSocket(self): socket = self.connection.socket @@ -871,7 +871,7 @@ def testGetresultUtf8(self): # pass the query as unicode try: v = self.c.query(q).getresult()[0][0] - except(pg.DataError, pg.NotSupportedError): + except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") v = None self.assertIsInstance(v, str) @@ -2623,7 +2623,7 @@ def testSetBool(self): finally: pg.set_bool(use_bool) self.assertIsInstance(r, str) - self.assertIs(r, 't') + self.assertEqual(r, 't') pg.set_bool(True) try: r = query("select true::bool").getresult()[0][0] diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index dfb2eecd..ca87a607 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -252,7 +252,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 150000) + self.assertTrue(90000 <= server_version < 160000) self.assertEqual(server_version, self.db.db.server_version) def testAttributeSocket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index d7c7a720..653fbb87 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -126,7 +126,7 @@ def testPqlibVersion(self): v = pg.get_pqlib_version() self.assertIsInstance(v, long) self.assertGreater(v, 90000) - self.assertLess(v, 150000) + self.assertLess(v, 160000) class TestParseArray(unittest.TestCase): diff --git a/tox.ini b/tox.ini index dd98fa29..67dee9fd 100644 --- a/tox.ini +++ b/tox.ini @@ -1,18 +1,18 @@ # config file for tox [tox] -envlist = py27,py3{5,6,7,8,9,10},flake8,docs +envlist = py27,py3{5,6,7,8,9,10,11},flake8,docs [testenv:flake8] -basepython = python3.9 -deps = flake8>=4,<5 +basepython = python3.11 +deps = flake8>=6,<7 commands = flake8 setup.py pg.py pgdb.py tests [testenv:docs] -basepython = python3.9 +basepython = python3.11 deps = - sphinx>=4.4,<5 + sphinx>=4.5,<5 cloud_sptheme>=1.10,<2 commands = sphinx-build -b html -nEW docs docs/_build/html From e142a51996cd098d1038f7945849e1201f5b04e0 Mon Sep 17 00:00:00 2001 From: justinpryzby Date: Sat, 26 Aug 2023 07:47:02 -0500 Subject: [PATCH 002/118] inserttable: test for errors and return number of tuples as str (#73) Contributed by: Justin Pryzby --- docs/contents/pg/connection.rst | 3 ++- pgconn.c | 17 +++++++++++++---- tests/test_classic_connection.py | 5 +++++ tests/test_classic_dbwrapper.py | 3 ++- tests/test_tutorial.py | 3 ++- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index a9fccfdf..c95adf59 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -494,7 +494,7 @@ inserttable -- insert an iterable into a table :param str table: the table name :param list values: iterable of row values, which must be lists or tuples :param list columns: list or tuple of column names - :rtype: None + :rtype: int :raises TypeError: invalid connection, bad argument type, or too many arguments :raises MemoryError: insert buffer could not be allocated :raises ValueError: unsupported values @@ -506,6 +506,7 @@ of the same size, containing the values for each inserted row. These may contain string, integer, long or double (real) values. ``columns`` is an optional tuple or list of column names to be passed on to the COPY command. +The number of rows affected is returned. .. warning:: diff --git a/pgconn.c b/pgconn.c index 6d12ede4..eb86abed 100644 --- a/pgconn.c +++ b/pgconn.c @@ -1012,7 +1012,7 @@ conn_inserttable(connObject *self, PyObject *args) Py_DECREF(iter_row); if (PyErr_Occurred()) { - PQerrorMessage(self->cnx); PyMem_Free(buffer); + PyMem_Free(buffer); return NULL; /* pass the iteration error */ } @@ -1026,9 +1026,18 @@ conn_inserttable(connObject *self, PyObject *args) PyMem_Free(buffer); - /* no error : returns nothing */ - Py_INCREF(Py_None); - return Py_None; + Py_BEGIN_ALLOW_THREADS + result = PQgetResult(self->cnx); + Py_END_ALLOW_THREADS + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); + PQclear(result); + return NULL; + } else { + long ntuples = atol(PQcmdTuples(result)); + PQclear(result); + return PyInt_FromLong(ntuples); + } } /* Get transaction state. */ diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 32c21870..1fea33be 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2023,6 +2023,11 @@ def testInserttableWithHugeListOfColumnNames(self): cols *= 2 self.assertRaises(MemoryError, self.c.inserttable, 'test', data, cols) + def testInserttableWithOutOfRangeData(self): + # try inserting data out of range for the column type + # Should raise a value error because of smallint out of range + self.assertRaises(ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) + def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), True, '2999-12-31', '11:59:59', 1e99, diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index dfb2eecd..bd481e09 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4227,10 +4227,11 @@ def testInserttableFromQuery(self): self.createTable('test_table_to', 'n integer, t timestamp') for i in range(1, 4): query("insert into test_table_from values ($1, now())", i) - self.db.inserttable( + n = self.db.inserttable( 'test_table_to', query("select n, t::text from test_table_from")) data_from = query("select * from test_table_from").getresult() data_to = query("select * from test_table_to").getresult() + self.assertEqual(n, 3) self.assertEqual([row[0] for row in data_from], [1, 2, 3]) self.assertEqual(data_from, data_to) diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index dbc93024..6f968560 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -48,7 +48,8 @@ def test_all_steps(self): self.assertEqual(r, {'name': 'banana', 'id': 2}) more_fruits = 'cherimaya durian eggfruit fig grapefruit'.split() data = list(enumerate(more_fruits, start=3)) - db.inserttable('fruits', data) + n = db.inserttable('fruits', data) + self.assertEqual(n, 5) q = db.query('select * from fruits') r = str(q).splitlines() self.assertEqual(r[0], 'id| name ') From 2871d27cad0a404b74b40012ccbed48ad500452f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 15:04:06 +0200 Subject: [PATCH 003/118] Update GitHub workflows --- .github/workflows/docs.yml | 12 ++++++------ .github/workflows/lint.yml | 8 ++++---- .github/workflows/tests.yml | 22 ++++++++++++---------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 32ea4e43..ec18c7ba 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -8,20 +8,20 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v1 - - name: Set up Python 3.9 - uses: actions/setup-python@v1 + - uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Install dependencies run: | sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=4.4,<5" + pip install "sphinx>=4.5,<5" pip install "cloud_sptheme>=1.10,<2" - name: Create docs with Sphinx run: | diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 437449a1..205d8b54 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,18 +7,18 @@ on: jobs: checks: name: Quality checks run - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: fail-fast: false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Install tox run: pip install tox - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 - name: Run quality checks run: tox -e flake8,docs timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b6eeecd1..e9269da7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,7 +9,7 @@ on: jobs: tests: name: Unit tests run - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 strategy: fail-fast: false @@ -22,15 +22,17 @@ jobs: - {python: "3.8", postgres: "12"} - {python: "3.9", postgres: "13"} - {python: "3.10", postgres: "14"} + - {python: "3.11", postgres: "15"} # Opposite extremes of the supported Py/PG range, other architecture - - {python: "2.7", postgres: "14", architecture: "x86"} - - {python: "3.5", postgres: "13", architecture: "x86"} - - {python: "3.6", postgres: "12", architecture: "x86"} - - {python: "3.7", postgres: "11", architecture: "x86"} - - {python: "3.8", postgres: "10", architecture: "x86"} - - {python: "3.9", postgres: "9.6", architecture: "x86"} - - {python: "3.10", postgres: "9.3", architecture: "x86"} + - {python: "2.7", postgres: "15", architecture: "x86"} + - {python: "3.5", postgres: "14", architecture: "x86"} + - {python: "3.6", postgres: "13", architecture: "x86"} + - {python: "3.7", postgres: "12", architecture: "x86"} + - {python: "3.8", postgres: "11", architecture: "x86"} + - {python: "3.9", postgres: "10", architecture: "x86"} + - {python: "3.10", postgres: "9.6", architecture: "x86"} + - {python: "3.11", postgres: "9.3", architecture: "x86"} env: PYGRESQL_DB: test @@ -54,10 +56,10 @@ jobs: --health-retries 5 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Install tox run: pip install tox - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Run tests From e29ca9ac798b440f8cb65abc82b17ae4759740b1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 15:10:24 +0200 Subject: [PATCH 004/118] Fix tox issue with passenv --- tox.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 67dee9fd..917e22c0 100644 --- a/tox.ini +++ b/tox.ini @@ -18,7 +18,9 @@ commands = sphinx-build -b html -nEW docs docs/_build/html [testenv] -passenv = PG* PYGRESQL_* +passenv = + PG* + PYGRESQL_* commands = python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size python -m unittest discover {posargs} From a2f51859d36a25cf36854d259689b1217f94c639 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 15:17:33 +0200 Subject: [PATCH 005/118] Remove desupported versions from workflow --- .github/workflows/tests.yml | 19 ++++++++----------- tests/test_classic_connection.py | 3 ++- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e9269da7..46eac7c0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,9 +15,9 @@ jobs: fail-fast: false matrix: include: - - {python: "2.7", postgres: "9.3"} - - {python: "3.5", postgres: "9.6"} - - {python: "3.6", postgres: "10"} + # - {python: "2.7", postgres: "9.3"} + # - {python: "3.5", postgres: "9.6"} + # - {python: "3.6", postgres: "10"} - {python: "3.7", postgres: "11"} - {python: "3.8", postgres: "12"} - {python: "3.9", postgres: "13"} @@ -25,14 +25,11 @@ jobs: - {python: "3.11", postgres: "15"} # Opposite extremes of the supported Py/PG range, other architecture - - {python: "2.7", postgres: "15", architecture: "x86"} - - {python: "3.5", postgres: "14", architecture: "x86"} - - {python: "3.6", postgres: "13", architecture: "x86"} - - {python: "3.7", postgres: "12", architecture: "x86"} - - {python: "3.8", postgres: "11", architecture: "x86"} - - {python: "3.9", postgres: "10", architecture: "x86"} - - {python: "3.10", postgres: "9.6", architecture: "x86"} - - {python: "3.11", postgres: "9.3", architecture: "x86"} + - {python: "3.7", postgres: "15", architecture: "x86"} + - {python: "3.8", postgres: "14", architecture: "x86"} + - {python: "3.9", postgres: "13", architecture: "x86"} + - {python: "3.10", postgres: "12", architecture: "x86"} + - {python: "3.11", postgres: "11", architecture: "x86"} env: PYGRESQL_DB: test diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 01932f39..a66af902 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2026,7 +2026,8 @@ def testInserttableWithHugeListOfColumnNames(self): def testInserttableWithOutOfRangeData(self): # try inserting data out of range for the column type # Should raise a value error because of smallint out of range - self.assertRaises(ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) + self.assertRaises( + ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) def testInserttableMaxValues(self): data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), From 538bb9f40d2e78c845036b4dc9cfaa56e621b707 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 21:14:18 +0200 Subject: [PATCH 006/118] Support generated columns in classic module (#83) --- docs/contents/changelog.rst | 7 ++ docs/contents/pg/db_wrapper.rst | 14 ++++ pg.py | 55 +++++++++++-- tests/test_classic_dbwrapper.py | 140 +++++++++++++++++++++++++++++++- 4 files changed, 207 insertions(+), 9 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 0944b349..a9b1b4fe 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -4,6 +4,13 @@ ChangeLog Version 5.2.5 (to be released) ------------------------------ - This version officially supports the new Python 3.11 and PostgreSQL 15. +- Two more improvements in the `inserttable()` method of the `pg` module + (thanks to Justin Pryzby for this contribution): + - error handling has been improved (#72) + - the method now returns the number of inserted rows (#73) +- Another improvement in the `pg` module (#83): + - generated columns can be requested with the `get_generated()` method + - generated columns are ignored by the insert, update and upsert method Version 5.2.4 (2022-03-26) -------------------------- diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index d2ef4e05..5d587f97 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -136,6 +136,20 @@ By default, only a limited number of simple types will be returned. You can get the registered types instead, if enabled by calling the :meth:`DB.use_regtypes` method. +get_generated -- get the generated columns of a table +----------------------------------------------------- + +.. method:: DB.get_generated(table) + + Get the generated columns of a table + + :param str table: name of table + :returns: an frozenset of column names + +Given the name of a table, digs out the set of generated columns. + +.. versionadded:: 5.2.5 + has_table_privilege -- check table privilege -------------------------------------------- diff --git a/pg.py b/pg.py index fbd97725..371c616b 100644 --- a/pg.py +++ b/pg.py @@ -1629,6 +1629,7 @@ def __init__(self, *args, **kw): self.dbname = db.db self._regtypes = False self._attnames = {} + self._generated = {} self._pkeys = {} self._privileges = {} self.adapter = Adapter(self) @@ -1657,6 +1658,17 @@ def __init__(self, *args, **kw): " WHERE a.attrelid OPERATOR(pg_catalog.=)" " %s::pg_catalog.regclass" " AND %s AND NOT a.attisdropped ORDER BY a.attnum") + if db.server_version < 100000: + self._query_generated = None + elif db.server_version < 120000: + self._query_generated = ( + "a.attidentity OPERATOR(pg_catalog.=) 'a'" + ) + else: + self._query_generated = ( + "(a.attidentity OPERATOR(pg_catalog.=) 'a' OR" + " a.attgenerated OPERATOR(pg_catalog.!=) '')" + ) db.set_cast_hook(self.dbtypes.typecast) # For debugging scripts, self.debug can be set # * to a string format specification (e.g. in CGI set to "%s
"), @@ -2130,7 +2142,7 @@ def get_relations(self, kinds=None, system=False): """Get list of relations in connected database of specified kinds. If kinds is None or empty, all kinds of relations are returned. - Otherwise kinds can be a string or sequence of type letters + Otherwise, kinds can be a string or sequence of type letters specifying which kind of relations you want to list. Set the system flag if you want to get the system relations as well. @@ -2190,6 +2202,32 @@ def get_attnames(self, table, with_oid=True, flush=False): attnames[table] = names # cache it return names + def get_generated(self, table, flush=False): + """Given the name of a table, dig out the set of generated columns. + + Returns a set of column names that are generated and unalterable. + + If flush is set, then the internal cache for generated columns will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + query_generated = self._query_generated + if not query_generated: + return frozenset() + generated = self._generated + if flush: + generated.clear() + self._do_debug('The generated cache has been flushed') + try: # cache lookup + names = generated[table] + except KeyError: # cache miss, check the database + q = "a.attnum OPERATOR(pg_catalog.>) 0 AND " + query_generated + q = self._query_attnames % (_quote_if_unqualified('$1', table), q) + names = self.db.query(q, (table,)).getresult() + names = frozenset(name[0] for name in names) + generated[table] = names # cache it + return names + def use_regtypes(self, regtypes=None): """Use registered type names instead of simplified type names.""" if regtypes is None: @@ -2307,8 +2345,8 @@ def insert(self, table, row=None, **kw): be passed as the first parameter. The other parameters are used for providing the data of the row that shall be inserted into the table. If a dictionary is supplied as the second parameter, it starts with - that. Otherwise it uses a blank dictionary. Either way the dictionary - is updated from the keywords. + that. Otherwise, it uses a blank dictionary. + Either way the dictionary is updated from the keywords. The dictionary is then reloaded with the values actually inserted in order to pick up values modified by rules, triggers, etc. @@ -2321,13 +2359,14 @@ def insert(self, table, row=None, **kw): if 'oid' in row: del row['oid'] # do not insert oid attnames = self.get_attnames(table) + generated = self.get_generated(table) qoid = _oid_key(table) if 'oid' in attnames else None params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier names, values = [], [] for n in attnames: - if n in row: + if n in row and n not in generated: names.append(col(n)) values.append(adapt(row[n], attnames[n])) if not names: @@ -2360,6 +2399,7 @@ def update(self, table, row=None, **kw): if table.endswith('*'): table = table[:-1].rstrip() # need parent table name attnames = self.get_attnames(table) + generated = self.get_generated(table) qoid = _oid_key(table) if 'oid' in attnames else None if row is None: row = {} @@ -2390,7 +2430,7 @@ def update(self, table, row=None, **kw): values = [] keyname = set(keyname) for n in attnames: - if n in row and n not in keyname: + if n in row and n not in keyname and n not in generated: values.append('%s = %s' % (col(n), adapt(row[n], attnames[n]))) if not values: return row @@ -2461,13 +2501,14 @@ def upsert(self, table, row=None, **kw): if 'oid' in kw: del kw['oid'] # do not update oid attnames = self.get_attnames(table) + generated = self.get_generated(table) qoid = _oid_key(table) if 'oid' in attnames else None params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier names, values = [], [] for n in attnames: - if n in row: + if n in row and n not in generated: names.append(col(n)) values.append(adapt(row[n], attnames[n])) names, values = ', '.join(names), ', '.join(values) @@ -2480,7 +2521,7 @@ def upsert(self, table, row=None, **kw): keyname = set(keyname) keyname.add('oid') for n in attnames: - if n not in keyname: + if n not in keyname and n not in generated: value = kw.get(n, n in row) if value: if not isinstance(value, basestring): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 3fa2db69..e97a23e2 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -184,8 +184,8 @@ def testAllDBAttributes(self): 'escape_literal', 'escape_string', 'fileno', 'get', 'get_as_dict', 'get_as_list', - 'get_attnames', 'get_cast_hook', - 'get_databases', 'get_notice_receiver', + 'get_attnames', 'get_cast_hook', 'get_databases', + 'get_generated', 'get_notice_receiver', 'get_parameter', 'get_relations', 'get_tables', 'getline', 'getlo', 'getnotify', 'has_table_privilege', 'host', @@ -1473,6 +1473,53 @@ def testGetAttnamesIsAttrDict(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') + def testGetGenerated(self): + get_generated = self.db.get_generated + server_version = self.db.server_version + if server_version >= 100000: + self.assertRaises(pg.ProgrammingError, + self.db.get_generated, 'does_not_exist') + self.assertRaises(pg.ProgrammingError, + self.db.get_generated, 'has.too.many.dots') + r = get_generated('test') + self.assertIsInstance(r, frozenset) + self.assertFalse(r) + if server_version >= 100000: + table = 'test_get_generated_1' + self.createTable( + table, + 'i int generated always as identity primary key,' + ' j int generated always as identity,' + ' k int generated by default as identity,' + ' n serial, m int') + r = get_generated(table) + self.assertIsInstance(r, frozenset) + self.assertEqual(r, {'i', 'j'}) + if server_version >= 120000: + table = 'test_get_generated_2' + self.createTable( + table, + 'n int, m int generated always as (n + 3) stored,' + ' i int generated always as identity,' + ' j int generated by default as identity') + r = get_generated(table) + self.assertIsInstance(r, frozenset) + self.assertEqual(r, {'m', 'i'}) + + def testGetGeneratedIsCached(self): + server_version = self.db.server_version + if server_version < 100000: + return + get_generated = self.db.get_generated + query = self.db.query + table = 'test_get_generated_2' + self.createTable(table, 'i int primary key') + self.assertFalse(get_generated(table)) + query('alter table %s alter column i' + ' add generated always as identity' % table) + self.assertFalse(get_generated(table)) + self.assertEqual(get_generated(table, flush=True), {'i'}) + def testHasTablePrivilege(self): can = self.db.has_table_privilege self.assertEqual(can('test'), True) @@ -1918,6 +1965,32 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')]) + def testInsertWithGeneratedColumns(self): + insert = self.db.insert + get = self.db.get + server_version = self.db.server_version + table = 'insert_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.createTable(table, table_def) + i, d = 35, 1001 + j = i + 7 + r = insert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r = get(table, d) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + def testUpdate(self): update = self.db.update query = self.db.query @@ -2089,6 +2162,38 @@ def testUpdateWithQuotedNames(self): self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') + def testUpdateWithGeneratedColumns(self): + update = self.db.update + get = self.db.get + query = self.db.query + server_version = self.db.server_version + table = 'update_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.createTable(table, table_def) + i, d = 35, 1001 + j = i + 7 + r = query('insert into %s (i, d) values (%d, %d)' % (table, i, d)) + self.assertEqual(r, '1') + r = get(table, d) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r['i'] += 1 + r = update(table, r) + i += 1 + if server_version >= 120000: + j += 1 + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + def testUpsert(self): upsert = self.db.upsert query = self.db.query @@ -2349,6 +2454,37 @@ def testUpsertWithQuotedNames(self): r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'No.')]) + def testUpsertWithGeneratedColumns(self): + upsert = self.db.upsert + get = self.db.get + server_version = self.db.server_version + table = 'upsert_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.createTable(table, table_def) + i, d = 35, 1001 + j = i + 7 + r = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r['i'] += 1 + r = upsert(table, r) + i += 1 + if server_version >= 120000: + j += 1 + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r = get(table, d) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + def testClear(self): clear = self.db.clear f = False if pg.get_bool() else 'f' From 8a358e118717fded765d4f1633e39895b6ca26e3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 26 Aug 2023 22:54:46 +0200 Subject: [PATCH 007/118] Add default typecast for sql_identifier --- docs/contents/changelog.rst | 2 ++ pg.py | 2 +- pgdb.py | 2 +- tests/test_classic_dbwrapper.py | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index a9b1b4fe..c55caaad 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -6,11 +6,13 @@ Version 5.2.5 (to be released) - This version officially supports the new Python 3.11 and PostgreSQL 15. - Two more improvements in the `inserttable()` method of the `pg` module (thanks to Justin Pryzby for this contribution): + - error handling has been improved (#72) - the method now returns the number of inserted rows (#73) - Another improvement in the `pg` module (#83): - generated columns can be requested with the `get_generated()` method - generated columns are ignored by the insert, update and upsert method +- Avoid internal query and error when casting the `sql_identifier` type (#82) Version 5.2.4 (2022-03-26) -------------------------- diff --git a/pg.py b/pg.py index 371c616b..181892ae 100644 --- a/pg.py +++ b/pg.py @@ -1066,7 +1066,7 @@ class Typecasts(dict): # (str functions are ignored but have been added for faster access) defaults = { 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, + 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, diff --git a/pgdb.py b/pgdb.py index 7eaf9cb0..d2d06f4d 100644 --- a/pgdb.py +++ b/pgdb.py @@ -555,7 +555,7 @@ class Typecasts(dict): # (str functions are ignored but have been added for faster access) defaults = { 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, + 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index e97a23e2..4246a7c3 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4899,7 +4899,7 @@ def testMunging(self): else: self.assertNotIn('oid(t)', r) - def testQeryInformationSchema(self): + def testQueryInformationSchema(self): q = "column_name" if self.db.server_version < 110000: q += "::text" # old version does not have sql_identifier array From 53e8f10d2eb6d2ab4179ec4c8fee4495cbd9852d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 27 Aug 2023 00:19:54 +0200 Subject: [PATCH 008/118] Fix multiple calls of getresult() after send_query() --- docs/contents/changelog.rst | 1 + pgquery.c | 13 +++++++++---- tests/test_classic_connection.py | 11 +++++++++++ tests/test_classic_dbwrapper.py | 17 +++++++++++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index c55caaad..57739271 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -13,6 +13,7 @@ Version 5.2.5 (to be released) - generated columns can be requested with the `get_generated()` method - generated columns are ignored by the insert, update and upsert method - Avoid internal query and error when casting the `sql_identifier` type (#82) +- Fix issue with multiple calls of `getresult()` after `send_query()` (#80) Version 5.2.4 (2022-03-26) -------------------------- diff --git a/pgquery.c b/pgquery.c index 852c848b..0d7ebc7d 100644 --- a/pgquery.c +++ b/pgquery.c @@ -139,8 +139,9 @@ _get_async_result(queryObject *self, int keep) { Py_END_ALLOW_THREADS if (!self->result) { /* end of result set, return None */ - Py_DECREF(self->pgcnx); - self->pgcnx = NULL; + self->max_row = 0; + self->num_fields = 0; + self->col_types = NULL; Py_INCREF(Py_None); return Py_None; } @@ -161,7 +162,7 @@ _get_async_result(queryObject *self, int keep) { } } else if (result == Py_None) { - /* It's would be confusing to return None here because the + /* It would be confusing to return None here because the caller has to call again until we return None. We can't just consume that final None because we don't know if there are additional statements following this one, so we return @@ -180,7 +181,12 @@ _get_async_result(queryObject *self, int keep) { Py_DECREF(self); return NULL; } + } else if (self->async == 2 && + !self->max_row && !self->num_fields && !self->col_types) { + Py_INCREF(Py_None); + return Py_None; } + /* return the query object itself as sentinel for a normal query result */ return (PyObject *)self; } @@ -722,7 +728,6 @@ query_namedresult(queryObject *self, PyObject *noargs) } if ((res_list = _get_async_result(self, 1)) == (PyObject *)self) { - res = PyObject_CallFunction(namediter, "(O)", self); if (!res) return NULL; if (PyList_Check(res)) return res; diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index a66af902..0152feb8 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -580,6 +580,17 @@ def testNamedresultAsync(self): self.assertEqual(v._fields, ('alias0',)) self.assertEqual(v.alias0, 0) self.assertIsNone(query.namedresult()) + self.assertIsNone(query.namedresult()) + + def testListFieldsAfterSecondGetResultAsync(self): + q = "select 1 as one" + query = self.c.send_query(q) + self.assertEqual(query.getresult(), [(1,)]) + self.assertEqual(query.listfields(), ('one',)) + self.assertIsNone(query.getresult()) + self.assertEqual(query.listfields(), ()) + self.assertIsNone(query.getresult()) + self.assertEqual(query.listfields(), ()) def testGet3Cols(self): q = "select 1,2,3" diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 4246a7c3..fac7e067 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -158,6 +158,23 @@ def testDeleteDb(self): self.assertRaises(pg.InternalError, db.close) del db + def testAsyncQueryBeforeDeletion(self): + db = DB() + query = db.send_query('select 1') + self.assertEqual(query.getresult(), [(1,)]) + self.assertIsNone(query.getresult()) + self.assertIsNone(query.getresult()) + del db + gc.collect() + + def testAsyncQueryAfterDeletion(self): + db = DB() + query = db.send_query('select 1') + del db + gc.collect() + self.assertIsNone(query.getresult()) + self.assertIsNone(query.getresult()) + class TestDBClassBasic(unittest.TestCase): """Test existence of the DB class wrapped pg connection methods.""" From c2a42905ecc84eefb58a0b9f4a08d4ea8839174b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 27 Aug 2023 16:02:09 +0200 Subject: [PATCH 009/118] Test both param styles with DB API 2 --- tests/test_dbapi20.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 94b7ab73..a03dca93 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -95,6 +95,19 @@ def test_percent_sign(self): cur.execute("select 'a %% sign'") self.assertEqual(cur.fetchone(), ('a % sign',)) + def test_paramstyles(self): + self.assertEqual(pgdb.paramstyle, 'pyformat') + con = self._connect() + cur = con.cursor() + # parameters can be passed as tuple + cur.execute("select %s, %s, %s", (123, 'abc', True)) + self.assertEqual(cur.fetchone(), (123, 'abc', True)) + # parameters can be passed as dict + cur.execute("select %(one)s, %(two)s, %(one)s, %(three)s", { + "one": 123, "two": "abc", "three": True + }) + self.assertEqual(cur.fetchone(), (123, 'abc', 123, True)) + def test_callproc_no_params(self): con = self._connect() cur = con.cursor() From 8699ace9aee8ff9ccba38695a757f34b2f9149ff Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 27 Aug 2023 18:07:24 +0200 Subject: [PATCH 010/118] Add test that inserttable does not miss failures --- tests/test_classic_connection.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 0152feb8..5c043701 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2185,6 +2185,20 @@ def testInsertTableBigRowSize(self): data = [(t,)] self.assertRaises(MemoryError, self.c.inserttable, 'test', data, ['t']) + def testInsertTableSmallIntOverflow(self): + rest_row = self.data[2][1:] + data = [(32000,) + rest_row] + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), data) + data = [(33000,) + rest_row] + try: + self.c.inserttable('test', data) + except ValueError as e: + self.assertIn( + 'value "33000" is out of range for type smallint', str(e)) + else: + self.assertFalse('expected an error') + class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" From e167bbcc32a5bf70865f31bdcd20cb1f6c827ece Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 12:43:31 +0200 Subject: [PATCH 011/118] Prepare patch release --- .bumpversion.cfg | 2 +- docs/about.txt | 2 +- docs/announce.rst | 4 ++-- docs/conf.py | 2 +- docs/contents/changelog.rst | 4 ++-- setup.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 31f5835a..3a654eda 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.4 +current_version = 5.2.5 commit = False tag = False diff --git a/docs/about.txt b/docs/about.txt index d1492061..c472a304 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,7 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.4 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and +The current version PyGreSQL 5.2.5 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and Python 2.7 or 3.5 to 3.11. If you need to support older PostgreSQL versions or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index cadf376b..a95cb949 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -3,10 +3,10 @@ PyGreSQL Announcements ====================== --------------------------------- -Release of PyGreSQL version 5.2.4 +Release of PyGreSQL version 5.2.5 --------------------------------- -Release 5.2.4 of PyGreSQL. +Release 5.2.5 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. diff --git a/docs/conf.py b/docs/conf.py index 6a9f87e0..1e9e4113 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ # built documents. # # The full version, including alpha/beta/rc tags. -version = release = '5.2.4' +version = release = '5.2.5' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 57739271..bc8322f4 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,8 +1,8 @@ ChangeLog ========= -Version 5.2.5 (to be released) ------------------------------- +Version 5.2.5 (2023-08-28) +-------------------------- - This version officially supports the new Python 3.11 and PostgreSQL 15. - Two more improvements in the `inserttable()` method of the `pg` module (thanks to Justin Pryzby for this contribution): diff --git a/setup.py b/setup.py index cdb20c4f..89e07f94 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.2.4 +"""Setup script for PyGreSQL version 5.2.5 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. It wraps the lower level C API library libpq @@ -52,7 +52,7 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2.4' +version = '5.2.5' if not (sys.version_info[:2] == (2, 7) or (3, 5) <= sys.version_info[:2] < (4, 0)): From f7bc4a36aa83b3e9ac2a6ae7d75ca2261bdbc8e3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 13:00:04 +0200 Subject: [PATCH 012/118] Fix tests for PostgreSQL < 9.5 --- tests/test_classic_dbwrapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index fac7e067..fa3cc655 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1526,7 +1526,7 @@ def testGetGenerated(self): def testGetGeneratedIsCached(self): server_version = self.db.server_version if server_version < 100000: - return + self.skipTest("database does not support generated columns") get_generated = self.db.get_generated query = self.db.query table = 'test_get_generated_2' @@ -2472,6 +2472,8 @@ def testUpsertWithQuotedNames(self): self.assertEqual(r, [(31, 9009, 'No.')]) def testUpsertWithGeneratedColumns(self): + if self.db.server_version < 90500: + self.skipTest('database does not support upsert') upsert = self.db.upsert get = self.db.get server_version = self.db.server_version From fd8748d5c35c606759e3caac1003947a46f5bd7e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 14:52:55 +0200 Subject: [PATCH 013/118] Add .readthedocs.yaml file --- .readthedocs.yaml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .readthedocs.yaml diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..9712e405 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,22 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# We recommend specifying your dependencies to enable reproducible builds: +# https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt From c5f1e58bc1cc2eeb7c6db6c0e61760b768815845 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 15:25:08 +0200 Subject: [PATCH 014/118] Do not use custom domain --- .github/workflows/docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index ec18c7ba..358659a1 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -33,6 +33,6 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} publish_branch: gh-pages publish_dir: docs/_build/html - cname: pygresql.org + # cname: pygresql.org enable_jekyll: false force_orphan: true From f3683fc3b39e534557317912863c0294a6c0b09a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 28 Aug 2023 16:27:35 +0200 Subject: [PATCH 015/118] Update Sphinx, get rid of outdated cloud theme --- .github/workflows/docs.yml | 5 +- MANIFEST.in | 3 +- README.rst | 7 +- docs/.gitignore | 1 - docs/Makefile | 198 ++--------------------- docs/_static/pygresql.css_t | 86 ---------- docs/_templates/layout.html | 58 ------- docs/community/source.rst | 6 +- docs/conf.py | 307 +++--------------------------------- docs/{toc.txt => index.rst} | 4 +- docs/make.bat | 250 ++--------------------------- docs/requirements.txt | 3 +- docs/start.txt | 15 -- tox.ini | 3 +- 14 files changed, 57 insertions(+), 889 deletions(-) delete mode 100644 docs/.gitignore delete mode 100644 docs/_static/pygresql.css_t delete mode 100644 docs/_templates/layout.html rename docs/{toc.txt => index.rst} (64%) delete mode 100644 docs/start.txt diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 358659a1..5a9ef894 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -3,7 +3,7 @@ name: Release PyGreSQL documentation on: push: branches: - - master + - main jobs: build: @@ -21,8 +21,7 @@ jobs: sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=4.5,<5" - pip install "cloud_sptheme>=1.10,<2" + pip install "sphinx>=7,<8" - name: Create docs with Sphinx run: | cd docs diff --git a/MANIFEST.in b/MANIFEST.in index 239841c7..9b263981 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -20,5 +20,4 @@ exclude docs/index.rst recursive-include docs/community *.rst recursive-include docs/contents *.rst recursive-include docs/download *.rst -recursive-include docs/_static *.css_t *.ico *.png -recursive-include docs/_templates *.html +recursive-include docs/_static *.ico *.png diff --git a/README.rst b/README.rst index 98bb30bb..a6054363 100644 --- a/README.rst +++ b/README.rst @@ -24,6 +24,7 @@ see the documentation. Documentation ------------- -The documentation is available at `pygresql.org `_. - -At mirror of the documentation can be found at `pygresql.readthedocs.io `_. +The documentation is available at +`pygresql.github.io/PyGreSQL/ `_ +and at `pygresql.readthedocs.io `_, +where you can also find the documentation for older versions. diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index 4a579446..00000000 --- a/docs/.gitignore +++ /dev/null @@ -1 +0,0 @@ -index.rst \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 0a1113c9..d4bb2cbb 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,192 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . BUILDDIR = _build -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext - +# Put it first so that "make" without argument is like "make help". help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " applehelp to make an Apple Help Book" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " texinfo to make Texinfo files" - @echo " info to make Texinfo files and run them through makeinfo" - @echo " gettext to make PO message catalogs" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " xml to make Docutils-native XML files" - @echo " pseudoxml to make pseudoxml-XML files for display purposes" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - @echo " coverage to run coverage check of the documentation (if enabled)" - -clean: - rm -rf $(BUILDDIR)/* - -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PyGreSQL.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PyGreSQL.qhc" - -applehelp: - $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp - @echo - @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." - @echo "N.B. You won't be able to view it unless you put it in" \ - "~/Library/Documentation/Help or install it in your application" \ - "bundle." - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/PyGreSQL" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PyGreSQL" - @echo "# devhelp" - -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." - -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -latexpdfja: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through platex and dvipdfmx..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." - -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." - -coverage: - $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage - @echo "Testing of coverage in the sources finished, look at the " \ - "results in $(BUILDDIR)/coverage/python.txt." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -xml: - $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml - @echo - @echo "Build finished. The XML files are in $(BUILDDIR)/xml." +.PHONY: help Makefile -pseudoxml: - $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml - @echo - @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/pygresql.css_t b/docs/_static/pygresql.css_t deleted file mode 100644 index a3bc4de2..00000000 --- a/docs/_static/pygresql.css_t +++ /dev/null @@ -1,86 +0,0 @@ -{% macro experimental(keyword, value) %} - {% if value %} - -moz-{{keyword}}: {{value}}; - -webkit-{{keyword}}: {{value}}; - -o-{{keyword}}: {{value}}; - -ms-{{keyword}}: {{value}}; - {{keyword}}: {{value}}; - {% endif %} -{% endmacro %} - -{% macro border_radius(value) -%} - {{experimental("border-radius", value)}} -{% endmacro %} - -{% macro box_shadow(value) -%} - {{experimental("box-shadow", value)}} -{% endmacro %} - -.pageheader.related { - text-align: left; - padding: 10px 15px; - border: 1px solid #eeeeee; - margin-bottom: 10px; - {{border_radius("1em 1em 1em 1em")}} - {% if theme_borderless_decor | tobool %} - border-top: 0; - border-bottom: 0; - {% endif %} -} - -.pageheader.related .logo { - font-size: 36px; - font-style: italic; - letter-spacing: 5px; - margin-right: 2em; -} - -.pageheader.related .logo { - font-size: 36px; - font-style: italic; - letter-spacing: 5px; - margin-right: 2em; -} - -.pageheader.related .logo a, .pageheader.related .logo a:hover { - background: transparent; - color: {{ theme_relbarlinkcolor }}; - border: none; - text-decoration: none; - text-shadow: none; - {{box_shadow("none")}} -} - -.pageheader.related ul { - float: right; - margin: 2px 1em; -} - -.pageheader.related li { - float: left; - margin: 0 0 0 10px; -} - -.pageheader.related li a { - padding: 8px 12px; -} - -.norelbar .subtitle { - font-size: 14px; - line-height: 18px; - font-weight: bold; - letter-spacing: 4px; - text-align: right; - padding: 0 1em; - margin-top: -9px; -} - -.relbar-top .related.norelbar { - height: 22px; - border-bottom: 14px solid #eeeeee; -} - -.relbar-bottom .related.norelbar { - height: 22px; - border-top: 14px solid #eeeeee; -} diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html deleted file mode 100644 index 1cb2ddee..00000000 --- a/docs/_templates/layout.html +++ /dev/null @@ -1,58 +0,0 @@ -{%- extends "cloud/layout.html" %} - -{% set css_files = css_files + ["_static/pygresql.css"] %} - -{# - This layout adds a page header above the standard layout. - It also removes the relbars from all pages that are not part - of the core documentation in the contents/ directory, - adapting the navigation bar (breadcrumb) appropriately. -#} - -{% set is_content = pagename.startswith(('contents/', 'genindex', 'modindex', 'py-', 'search')) %} -{% if is_content %} -{% set master_doc = 'contents/index' %} -{% set parents = parents[1:] %} -{% endif %} - -{% block header %} - - - -{% endblock %} - -{% block relbar1 -%} -{%- if is_content -%} - {{ super() }} -{% else %} -
-{%- endif -%} -{%- endblock %} - -{% block relbar2 -%} -{%- if is_content -%} - {{ super() }} -{%- else -%} -
-{%- endif -%} -{%- endblock %} - -{% block content -%} -{%- if is_content -%} -{{ super() }} -{%- else -%} -
{{ super() }}
-{%- endif -%} -{%- endblock %} diff --git a/docs/community/source.rst b/docs/community/source.rst index 224985fd..497f6280 100644 --- a/docs/community/source.rst +++ b/docs/community/source.rst @@ -4,12 +4,12 @@ Access to the source repository The source code of PyGreSQL is available as a `Git `_ repository on `GitHub `_. -The current master branch of the repository can be cloned with the command:: +The current main branch of the repository can be cloned with the command:: git clone https://github.com/PyGreSQL/PyGreSQL.git -You can also download the master branch as a -`zip archive `_. +You can also download the main branch as a +`zip archive `_. Contributions can be proposed as `pull requests `_ on GitHub. diff --git a/docs/conf.py b/docs/conf.py index 1e9e4113..933c4e38 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,90 +1,26 @@ -# -*- coding: utf-8 -*- +# Configuration file for the Sphinx documentation builder. # -# PyGreSQL documentation build configuration file. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -import sys -import os -import shlex -import shutil - -# Import Cloud theme (this will also automatically add the theme directory). -# Note: We add a navigation bar to the cloud them using a custom layout. -if os.environ.get('READTHEDOCS', None) == 'True': - # We cannot use our custom layout here, since RTD overrides layout.html. - use_cloud_theme = False -else: - try: - import cloud_sptheme - use_cloud_theme = True - except ImportError: - use_cloud_theme = False +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html -shutil.copyfile('start.txt' if use_cloud_theme else 'toc.txt', 'index.rst') - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc'] +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] if use_cloud_theme else [] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The encoding of source files. -#source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. project = 'PyGreSQL' author = 'The PyGreSQL team' copyright = '2023, ' + author -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The full version, including alpha/beta/rc tags. version = release = '5.2.5' -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None +language = 'en' -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -exclude_patterns = ['_build'] +extensions = ['sphinx.ext.autodoc'] + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # List of pages which are included in other pages and therefore should # not appear in the toctree. @@ -93,8 +29,6 @@ 'community/mailinglist.rst', 'community/source.rst', 'community/issues.rst', 'community/support.rst', 'community/homes.rst'] -if use_cloud_theme: - exclude_patterns += ['about.rst'] # ignore certain warnings # (references to some of the Python names do not resolve correctly) @@ -102,13 +36,14 @@ nitpick_ignore = [ ('py:' + t, n) for t, names in { 'attr': ('arraysize', 'error', 'sqlstate', 'DatabaseError.sqlstate'), - 'class': ('bool', 'bytes', 'callable', 'class', + 'class': ('bool', 'bytes', 'callable', 'callables', 'class', 'dict', 'float', 'function', 'int', 'iterable', 'list', 'object', 'set', 'str', 'tuple', 'False', 'True', 'None', - 'namedtuple', 'OrderedDict', 'decimal.Decimal', + 'namedtuple', 'namedtuples', + 'OrderedDict', 'decimal.Decimal', 'bytes/str', 'list of namedtuples', 'tuple of callables', - 'type of first field', + 'first field', 'type of first field', 'Notice', 'DATETIME'), 'data': ('defbase', 'defhost', 'defopt', 'defpasswd', 'defport', 'defuser'), @@ -125,217 +60,15 @@ 'obj': ('False', 'True', 'None') }.items() for n in names] -# The reST default role (used for this markup: `text`) for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -html_theme = 'cloud' if use_cloud_theme else 'default' - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -if use_cloud_theme: - html_theme_options = { - 'roottarget': 'contents/index', - 'defaultcollapsed': True, - 'shaded_decor': True} -else: - html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -html_theme_path = ['_themes'] +html_theme = 'alabaster' +html_static_path = ['_static'] -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". html_title = 'PyGreSQL %s' % version -if use_cloud_theme: - html_title += ' documentation' -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. html_logo = '_static/pygresql.png' - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. html_favicon = '_static/favicon.ico' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# Add any extra paths that contain custom files (such as robots.txt or -# .htaccess) here, relative to this directory. These files are copied -# directly to the root of the documentation. -#html_extra_path = [] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None - -# Language to be used for generating the HTML full-text search index. -# Sphinx supports the following languages: -# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' -# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -#html_search_language = 'en' - -# A dictionary with options for the search language support, empty by default. -# Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} - -# The name of a javascript file (relative to the configuration directory) that -# implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' - -# Output file base name for HTML help builder. -htmlhelp_basename = 'PyGreSQLdoc' - - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', - -# Latex figure (float) alignment -#'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'PyGreSQL.tex', 'PyGreSQL Documentation', - author, 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'pygresql', 'PyGreSQL Documentation', [author], 1) -] - -# If true, show URL addresses after external links. -#man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'PyGreSQL', u'PyGreSQL Documentation', - author, 'PyGreSQL', 'One line description of project.', - 'Miscellaneous'), -] - -# Documents to append as an appendix to all manuals. -#texinfo_appendices = [] - -# If false, no module index is generated. -#texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False diff --git a/docs/toc.txt b/docs/index.rst similarity index 64% rename from docs/toc.txt rename to docs/index.rst index 441021b4..c40103a8 100644 --- a/docs/toc.txt +++ b/docs/index.rst @@ -1,5 +1,3 @@ -.. PyGreSQL index page with toc (for use without cloud theme) - Welcome to PyGreSQL =================== @@ -11,4 +9,4 @@ Welcome to PyGreSQL announce download/index contents/index - community/index \ No newline at end of file + community/index diff --git a/docs/make.bat b/docs/make.bat index b8571b60..954237b9 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -1,62 +1,16 @@ @ECHO OFF +pushd %~dp0 + REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) +set SOURCEDIR=. set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - echo. coverage to run coverage check of the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -REM Check if sphinx-build is available and fallback to Python version if any -%SPHINXBUILD% 1>NUL 2>NUL -if errorlevel 9009 goto sphinx_python -goto sphinx_ok -:sphinx_python - -set SPHINXBUILD=python -m sphinx.__init__ -%SPHINXBUILD% 2> nul +%SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx @@ -65,199 +19,17 @@ if errorlevel 9009 ( echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ + echo.https://www.sphinx-doc.org/ exit /b 1 ) -:sphinx_ok - - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\PyGreSQL.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\PyGreSQL.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %~dp0 - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %~dp0 - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) +if "%1" == "" goto help -if "%1" == "coverage" ( - %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage - if errorlevel 1 exit /b 1 - echo. - echo.Testing of coverage in the sources finished, look at the ^ -results in %BUILDDIR%/coverage/python.txt. - goto end -) +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt index a59b8f44..9cd8b2f5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1 @@ -sphinx>=4.4,<5 -cloud_sptheme>=1.10,<2 +sphinx>=7,<8 diff --git a/docs/start.txt b/docs/start.txt deleted file mode 100644 index 5166896a..00000000 --- a/docs/start.txt +++ /dev/null @@ -1,15 +0,0 @@ -.. PyGreSQL index page without toc (for use with cloud theme) - -Welcome to PyGreSQL -=================== - -.. toctree:: - :hidden: - - copyright - announce - download/index - contents/index - community/index - -.. include:: about.txt \ No newline at end of file diff --git a/tox.ini b/tox.ini index 917e22c0..1199be00 100644 --- a/tox.ini +++ b/tox.ini @@ -12,8 +12,7 @@ commands = [testenv:docs] basepython = python3.11 deps = - sphinx>=4.5,<5 - cloud_sptheme>=1.10,<2 + sphinx>=7,<8 commands = sphinx-build -b html -nEW docs docs/_build/html From 89eaebaeecd6a94ed532d1eb04f32227cb6f018d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 14:02:07 +0200 Subject: [PATCH 016/118] Add dev container for VS Code Also start de-supporting older Python versions. --- .devcontainer/devcontainer.json | 63 +++++++++++++++++++++++ .devcontainer/provision.sh | 82 ++++++++++++++++++++++++++++++ .github/workflows/docs.yml | 52 +++++++++---------- .github/workflows/lint.yml | 6 ++- .github/workflows/tests.yml | 31 ++++++----- .vscode/settings.json | 6 +++ MANIFEST.in | 1 - tests/config.py | 21 +++++--- tests/test_classic_notification.py | 8 --- tox.ini | 4 +- 10 files changed, 212 insertions(+), 62 deletions(-) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .devcontainer/provision.sh create mode 100644 .vscode/settings.json diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..c1374910 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,63 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu +{ + "name": "PyGreSQL", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "dockerComposeFile": "docker-compose.yml", + "service": "dev", + "workspaceFolder": "/workspace", + "customizations": { + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "terminal.integrated.profiles.linux": { + "bash": { + "path": "/bin/bash" + } + }, + "sqltools.connections": [ + { + "name": "Container database", + "driver": "PostgreSQL", + "previewLimit": 50, + "server": "pg15", + "port": 5432, + "database": "test", + "username": "test", + "password": "test" + } + ], + "python.pythonPath": "/usr/local/bin/python", + "python.analysis.typeCheckingMode": "basic", + "python.testing.unittestEnabled": true, + "editor.formatOnSave": true, + "editor.renderWhitespace": "all", + "editor.rulers": [ + 79 + ] + }, + // Add the IDs of extensions you want installed when the container is created. + "extensions": [ + "ms-azuretools.vscode-docker", + "ms-python.python", + "ms-vscode.cpptools", + "mtxr.sqltools", + "njpwerner.autodocstring", + "redhat.vscode-yaml", + "eamodio.gitlens", + "streetsidesoftware.code-spell-checker", + "lextudio.restructuredtext" + ] + } + }, + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "bash /workspace/.devcontainer/provision.sh" + // Configure tool-specific properties. + // "customizations": {}, + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} \ No newline at end of file diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh new file mode 100644 index 00000000..5cea536f --- /dev/null +++ b/.devcontainer/provision.sh @@ -0,0 +1,82 @@ +#!/usr/bin/bash + +# install development environment for PyGreSQL + +export DEBIAN_FRONTEND=noninteractive + +sudo apt-get update +sudo apt-get -y upgrade + +# install base utilities and configure time zone + +sudo ln -fs /usr/share/zoneinfo/UTC /etc/localtime +sudo apt-get install -y apt-utils software-properties-common +sudo apt-get install -y tzdata +sudo dpkg-reconfigure --frontend noninteractive tzdata + +sudo apt-get install -y rpm wget zip + +# install all supported Python versions + +sudo add-apt-repository -y ppa:deadsnakes/ppa +sudo apt-get update + +sudo apt-get install -y python3.7 python3.7-dev python3.7-distutils +sudo apt-get install -y python3.8 python3.8-dev python3.8-distutils +sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils +sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils +sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils + +# install testing tool + +sudo apt-get install -y tox + +# install PostgreSQL client tools + +sudo apt-get install -y postgresql libpq-dev + +for pghost in pg10 pg12 pg14 pg15 +do + export PGHOST=$pghost + export PGDATABASE=postgres + export PGUSER=postgres + export PGPASSWORD=postgres + + createdb -E UTF8 -T template0 test + createdb -E SQL_ASCII -T template0 test_ascii + createdb -E LATIN1 -l C -T template0 test_latin1 + createdb -E LATIN9 -l C -T template0 test_latin9 + createdb -E ISO_8859_5 -l C -T template0 test_cyrillic + + psql -c "create user test with password 'test'" + + psql -c "grant create on database test to test" + psql -c "grant create on database test_ascii to test" + psql -c "grant create on database test_latin1 to test" + psql -c "grant create on database test_latin9 to test" + psql -c "grant create on database test_cyrillic to test" + + psql -c "grant create on schema public to test" test + psql -c "grant create on schema public to test" test_ascii + psql -c "grant create on schema public to test" test_latin1 + psql -c "grant create on schema public to test" test_latin9 + psql -c "grant create on schema public to test" test_cyrillic + + psql -c "create extension hstore" test + psql -c "create extension hstore" test_ascii + psql -c "create extension hstore" test_latin1 + psql -c "create extension hstore" test_latin9 + psql -c "create extension hstore" test_cyrillic +done + +export PGHOST=pg15 +export PGPORT=5432 +export PGDATABASE=test +export PGUSER=test +export PGPASSWORD=test + +export PYGRESQL_DB=$PGDATABASE +export PYGRESQL_HOST=$PGHOST +export PYGRESQL_PORT=$PGPORT +export PYGRESQL_USER=$PGUSER +export PYGRESQL_PASSWD=$PGPASSWORD diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5a9ef894..aae221a0 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,4 +1,4 @@ -name: Release PyGreSQL documentation +name: Publish PyGreSQL documentation on: push: @@ -7,31 +7,31 @@ on: jobs: build: - runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.11 - uses: actions/setup-python@v4 - with: - python-version: 3.11 - - name: Install dependencies - run: | - sudo apt install libpq-dev - python -m pip install --upgrade pip - pip install . - pip install "sphinx>=7,<8" - - name: Create docs with Sphinx - run: | - cd docs - make html - - name: Deploy docs to GitHub pages - uses: peaceiris/actions-gh-pages@v3 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_branch: gh-pages - publish_dir: docs/_build/html - # cname: pygresql.org - enable_jekyll: false - force_orphan: true + - name: CHeck out repository + uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v4 + with: + python-version: 3.11 + - name: Install dependencies + run: | + sudo apt install libpq-dev + python -m pip install --upgrade pip + pip install . + pip install "sphinx>=7,<8" + - name: Create docs with Sphinx + run: | + cd docs + make html + - name: Deploy docs to GitHub pages + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_branch: gh-pages + publish_dir: docs/_build/html + cname: pygresql.org + enable_jekyll: false + force_orphan: true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 205d8b54..54ae2fd3 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,10 +13,12 @@ jobs: fail-fast: false steps: - - uses: actions/checkout@v3 + - name: Check out repository + uses: actions/checkout@v3 - name: Install tox run: pip install tox - - uses: actions/setup-python@v4 + - name: Setup Python + uses: actions/setup-python@v4 with: python-version: 3.11 - name: Run quality checks diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 46eac7c0..ca8e4a36 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,27 +9,24 @@ on: jobs: tests: name: Unit tests run - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 strategy: fail-fast: false matrix: include: - # - {python: "2.7", postgres: "9.3"} - # - {python: "3.5", postgres: "9.6"} - # - {python: "3.6", postgres: "10"} - - {python: "3.7", postgres: "11"} - - {python: "3.8", postgres: "12"} - - {python: "3.9", postgres: "13"} - - {python: "3.10", postgres: "14"} - - {python: "3.11", postgres: "15"} + - { python: "3.7", postgres: "11" } + - { python: "3.8", postgres: "12" } + - { python: "3.9", postgres: "13" } + - { python: "3.10", postgres: "14" } + - { python: "3.11", postgres: "15" } # Opposite extremes of the supported Py/PG range, other architecture - - {python: "3.7", postgres: "15", architecture: "x86"} - - {python: "3.8", postgres: "14", architecture: "x86"} - - {python: "3.9", postgres: "13", architecture: "x86"} - - {python: "3.10", postgres: "12", architecture: "x86"} - - {python: "3.11", postgres: "11", architecture: "x86"} + - { python: "3.7", postgres: "15", architecture: "x86" } + - { python: "3.8", postgres: "14", architecture: "x86" } + - { python: "3.9", postgres: "13", architecture: "x86" } + - { python: "3.10", postgres: "12", architecture: "x86" } + - { python: "3.11", postgres: "11", architecture: "x86" } env: PYGRESQL_DB: test @@ -53,10 +50,12 @@ jobs: --health-retries 5 steps: - - uses: actions/checkout@v3 + - name: Check out repository + uses: actions/checkout@v3 - name: Install tox run: pip install tox - - uses: actions/setup-python@v4 + - name: Setup Python + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Run tests diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9ee86e71 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.autopep8" + }, + "python.formatting.provider": "none" +} \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 9b263981..e6e9e5a9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -9,7 +9,6 @@ include LICENSE.txt include tox.ini recursive-include tests *.py -exclude tests/LOCAL_PyGreSQL.py include docs/Makefile include docs/make.bat diff --git a/tests/config.py b/tests/config.py index a6082593..a6dcd3a3 100644 --- a/tests/config.py +++ b/tests/config.py @@ -4,8 +4,10 @@ from os import environ # We need a database to test against. -# If LOCAL_PyGreSQL.py exists, we will get our information from that. -# Otherwise, we use the defaults. + +# The connection parameters are taken from the usual PG* environment +# variables and can be overridden with PYGRESQL_* environment variables +# or values specified in the file .LOCAL_PyGreSQL or LOCAL_PyGreSQL.py. # The tests should be run with various PostgreSQL versions and databases # created with different encodings and locales. Particularly, make sure the @@ -13,11 +15,16 @@ # The current user must have create schema privilege on the database. -dbname = environ.get('PYGRESQL_DB', 'unittest') -dbhost = environ.get('PYGRESQL_HOST', None) -dbport = environ.get('PYGRESQL_PORT', 5432) -dbuser = environ.get('PYGRESQL_USER', None) -dbpasswd = environ.get('PYGRESQL_PASSWD', None) +get = environ.get + +dbname = get('PYGRESQL_DB', get('PGDATABASE')) +dbhost = get('PYGRESQL_HOST', get('PGHOST')) +dbport = get('PYGRESQL_PORT', get('PGPORT')) +dbuser = get('PYGRESQL_USER', get('PGUSER')) +dbpasswd = get('PYGRESQL_PASSWD', get('PGPASSWORD')) + +if dbport: + dbport = int(dbport) try: from .LOCAL_PyGreSQL import * # noqa: F401 diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 29e6921d..39f607df 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -21,14 +21,6 @@ debug = False # let DB wrapper print debugging output -try: - from .LOCAL_PyGreSQL import * # noqa: F401 -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * # noqa: F401 - except ImportError: - pass - def DB(): """Create a DB wrapper object connecting to the test database.""" diff --git a/tox.ini b/tox.ini index 1199be00..d48b44c7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py27,py3{5,6,7,8,9,10,11},flake8,docs +envlist = py3{7,8,9,10,11},flake8,docs [testenv:flake8] basepython = python3.11 @@ -22,4 +22,4 @@ passenv = PYGRESQL_* commands = python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size - python -m unittest discover {posargs} + python -m unittest {posargs:discover} From 5771ad75d98863fe97d69af248ba14855d165a7c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:38:24 +0200 Subject: [PATCH 017/118] Properly set dev environment variables --- .devcontainer/dev.env | 11 +++++++++++ .devcontainer/provision.sh | 12 ------------ 2 files changed, 11 insertions(+), 12 deletions(-) create mode 100644 .devcontainer/dev.env diff --git a/.devcontainer/dev.env b/.devcontainer/dev.env new file mode 100644 index 00000000..996ee8d2 --- /dev/null +++ b/.devcontainer/dev.env @@ -0,0 +1,11 @@ +PGHOST=pg15 +PGPORT=5432 +PGDATABASE=test +PGUSER=test +PGPASSWORD=test + +PYGRESQL_DB=test +PYGRESQL_HOST=pg15 +PYGRESQL_PORT=5432 +PYGRESQL_USER=test +PYGRESQL_PASSWD=test diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 5cea536f..b47abb8c 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -68,15 +68,3 @@ do psql -c "create extension hstore" test_latin9 psql -c "create extension hstore" test_cyrillic done - -export PGHOST=pg15 -export PGPORT=5432 -export PGDATABASE=test -export PGUSER=test -export PGPASSWORD=test - -export PYGRESQL_DB=$PGDATABASE -export PYGRESQL_HOST=$PGHOST -export PYGRESQL_PORT=$PGPORT -export PYGRESQL_USER=$PGUSER -export PYGRESQL_PASSWD=$PGPASSWORD From 6c03eba7ebd3f998ab9f08e7a4f21171221796e1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:40:26 +0200 Subject: [PATCH 018/118] Ignore VS Code settings --- .gitignore | 1 + .vscode/settings.json | 6 ------ 2 files changed, 1 insertion(+), 6 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 71300f9e..83732331 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ Thumbs.db .idea/ .vs/ +.vscode/ diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9ee86e71..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "[python]": { - "editor.defaultFormatter": "ms-python.autopep8" - }, - "python.formatting.provider": "none" -} \ No newline at end of file From 97d2a258879c1fb9e22754e6ea189eac917b109c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:46:31 +0200 Subject: [PATCH 019/118] Start desupporting old Python and Postgres versions --- docs/about.txt | 7 +++-- docs/announce.rst | 17 +++++------- docs/contents/install.rst | 2 +- docs/contents/pg/adaptation.rst | 4 +-- docs/contents/pg/connection.rst | 2 +- docs/contents/pgdb/adaptation.rst | 4 +-- pg.py | 13 ++------- pgdb.py | 14 +++------- setup.py | 45 ++++++++----------------------- tests/test_classic_connection.py | 3 +-- tests/test_classic_dbwrapper.py | 3 --- 11 files changed, 32 insertions(+), 82 deletions(-) diff --git a/docs/about.txt b/docs/about.txt index c472a304..04f615e1 100644 --- a/docs/about.txt +++ b/docs/about.txt @@ -36,7 +36,6 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL 5.2.5 needs PostgreSQL 9.0 to 9.6 or 10 to 15, and -Python 2.7 or 3.5 to 3.11. If you need to support older PostgreSQL versions or -older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that -still support them. +The current version PyGreSQL 6.0 needs PostgreSQL 10 to 15, and Python +3.7 to 3.11. If you need to support older PostgreSQL or Python versions, +you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst index a95cb949..d0a5f19c 100644 --- a/docs/announce.rst +++ b/docs/announce.rst @@ -2,11 +2,11 @@ PyGreSQL Announcements ====================== ---------------------------------- -Release of PyGreSQL version 5.2.5 ---------------------------------- +------------------------------- +Release of PyGreSQL version 6.0 +------------------------------- -Release 5.2.5 of PyGreSQL. +Release 6.0 of PyGreSQL. It is available at: https://pypi.org/project/PyGreSQL/. @@ -17,13 +17,10 @@ Please refer to `changelog.txt `_ for things that have changed in this version. This version has been built and unit tested on: - - NetBSD - - FreeBSD - - openSUSE - Ubuntu - - Windows 7 and 10 with both MinGW and Visual Studio - - PostgreSQL 9.0 to 9.6 and 10 to 15 (32 and 64bit) - - Python 2.7 and 3.5 to 3.11 (32 and 64bit) + - Windows 7 and 10 with Visual Studio + - PostgreSQL 10 to 15 (32 and 64bit) + - Python 3.7 to 3.11 (32 and 64bit) | D'Arcy J.M. Cain | darcy@PyGreSQL.org diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 4ef323af..d1926881 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -2.7 and 3.5 to 3.11, and PostgreSQL versions 9.0 to 9.6 and 10 to 15. +3.7 to 3.11, and PostgreSQL versions 10 to 15. PyGreSQL will be installed as three modules, a shared library called ``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst index 1cf44418..c5d0a795 100644 --- a/docs/contents/pg/adaptation.rst +++ b/docs/contents/pg/adaptation.rst @@ -26,7 +26,7 @@ PostgreSQL Python char, bpchar, name, text, varchar str bool bool bytea bytes -int2, int4, int8, oid, serial int [#int8]_ +int2, int4, int8, oid, serial int int2vector list of int float4, float8 float numeric, money Decimal @@ -45,8 +45,6 @@ record tuple Elements of arrays and records will also be converted accordingly. - .. [#int8] int8 is converted to long in Python 2 - .. [#array] The first element of the array will always be the first element of the Python list, no matter what the lower bound of the PostgreSQL array is. The information about the start index of the array (which is diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index c95adf59..d1c95213 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -730,7 +730,7 @@ the connection and its status. These attributes are: .. attribute:: Connection.server_version - the backend version (int, e.g. 90305 for 9.3.5) + the backend version (int, e.g. 150400 for 15.4) .. versionadded:: 4.0 diff --git a/docs/contents/pgdb/adaptation.rst b/docs/contents/pgdb/adaptation.rst index 0f9ad5a6..ebb36e5b 100644 --- a/docs/contents/pgdb/adaptation.rst +++ b/docs/contents/pgdb/adaptation.rst @@ -26,7 +26,7 @@ PostgreSQL Python char, bpchar, name, text, varchar str bool bool bytea bytes -int2, int4, int8, oid, serial int [#int8]_ +int2, int4, int8, oid, serial int int2vector list of int float4, float8 float numeric, money Decimal @@ -45,8 +45,6 @@ record tuple Elements of arrays and records will also be converted accordingly. - .. [#int8] int8 is converted to long in Python 2 - .. [#array] The first element of the array will always be the first element of the Python list, no matter what the lower bound of the PostgreSQL array is. The information about the start index of the array (which is diff --git a/pg.py b/pg.py index 181892ae..9d5a7e13 100644 --- a/pg.py +++ b/pg.py @@ -96,13 +96,7 @@ from re import compile as regex from json import loads as jsondecode, dumps as jsonencode from uuid import UUID - -try: - # noinspection PyUnresolvedReferences - from typing import Dict, List, Union # noqa: F401 - has_typing = True -except ImportError: # Python < 3.5 - has_typing = False +from typing import Dict, List, Union # noqa: F401 try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable long @@ -342,9 +336,6 @@ class _SimpleTypes(dict): bytes, unicode, basestring] } # type: Dict[str, List[Union[str, type]]] - if long is not int: # Python 2 has a separate long type - _type_aliases['num'].append(long) - # noinspection PyMissingConstructor def __init__(self): """Initialize type mapping.""" @@ -354,7 +345,7 @@ def __init__(self): self[key] = typ if isinstance(key, str): self['_%s' % key] = '%s[]' % typ - elif has_typing and not isinstance(key, tuple): + elif not isinstance(key, tuple): self[List[key]] = '%s[]' % typ @staticmethod diff --git a/pgdb.py b/pgdb.py index d2d06f4d..ccf848e9 100644 --- a/pgdb.py +++ b/pgdb.py @@ -1008,10 +1008,7 @@ def _quote(self, value): if not value: # exception for empty array return "'{}'" q = self._quote - try: - return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),) - except UnicodeEncodeError: # Python 2 with non-ascii values - return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),) + return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),) if isinstance(value, tuple): # Quote as a ROW constructor. This is better than using a record # literal because it carries the information that this is a record @@ -1019,10 +1016,7 @@ def _quote(self, value): # this usable with the IN syntax as well. It is only necessary # when the records has a single column which is not really useful. q = self._quote - try: - return '(%s)' % (','.join(str(q(v)) for v in value),) - except UnicodeEncodeError: # Python 2 with non-ascii values - return u'(%s)' % (','.join(unicode(q(v)) for v in value),) + return '(%s)' % (','.join(str(q(v)) for v in value),) try: # noinspection PyUnresolvedReferences value = value.__pg_repr__() except AttributeError: @@ -1472,8 +1466,8 @@ def __next__(self): raise StopIteration return res - # Note that since Python 2.6 the iterator protocol uses __next()__ - # instead of next(), we keep it only for backward compatibility of pgdb. + # Note that the iterator protocol now uses __next()__ instead of next(), + # but we keep it for backward compatibility of pgdb. next = __next__ @staticmethod diff --git a/setup.py b/setup.py index 89e07f94..fb5330e8 100755 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # # Please see the LICENSE.TXT file for specific restrictions. -"""Setup script for PyGreSQL version 5.2.5 +"""Setup script for PyGreSQL version 6.0 PyGreSQL is an open-source Python module that interfaces to a PostgreSQL database. It wraps the lower level C API library libpq @@ -26,8 +26,8 @@ * PostgreSQL pg_config tool (usually included in the devel package) (the Windows installer has it as part of the database server feature) -PyGreSQL currently supports Python versions 2.7 and 3.5 to 3.11, -and PostgreSQL versions 9.0 to 9.6 and 10 to 15. +PyGreSQL currently supports Python versions 3.7 to 3.11, +and PostgreSQL versions 10 to 15. Use as follows: python setup.py build_ext # to build the module @@ -52,10 +52,9 @@ from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.2.5' +version = '6.0' -if not (sys.version_info[:2] == (2, 7) - or (3, 5) <= sys.version_info[:2] < (4, 0)): +if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( "Sorry, PyGreSQL %s does not support this Python version" % version) @@ -84,7 +83,7 @@ def pg_version(): match = re.search(r'(\d+)\.(\d+)', pg_config('version')) if match: return tuple(map(int, match.groups())) - return 9, 0 + return 10, 0 pg_version = pg_version() @@ -146,7 +145,7 @@ def initialize_options(self): self.pqlib_info = None self.ssl_info = None self.memory_size = None - supported = pg_version >= (9, 0) + supported = pg_version >= (10, 0) if not supported: warnings.warn( "PyGreSQL does not support the installed PostgreSQL version.") @@ -162,33 +161,15 @@ def finalize_options(self): define_macros.append(('LARGE_OBJECTS', None)) if self.default_vars is None or self.default_vars: define_macros.append(('DEFAULT_VARS', None)) - wanted = self.escaping_funcs - supported = pg_version >= (9, 0) - if wanted or (wanted is None and supported): + if self.escaping_funcs is None or self.escaping_funcs: define_macros.append(('ESCAPING_FUNCS', None)) - if not supported: - warnings.warn( - "The installed PostgreSQL version" - " does not support the newer string escaping functions.") - wanted = self.pqlib_info - supported = pg_version >= (9, 1) - if wanted or (wanted is None and supported): + if self.pqlib_info is None or self.pqlib_info: define_macros.append(('PQLIB_INFO', None)) - if not supported: - warnings.warn( - "The installed PostgreSQL version" - " does not support PQLib info functions.") - wanted = self.ssl_info - supported = pg_version >= (9, 5) - if wanted or (wanted is None and supported): + if self.ssl_info is None or self.ssl_info: define_macros.append(('SSL_INFO', None)) - if not supported: - warnings.warn( - "The installed PostgreSQL version" - " does not support SSL info functions.") wanted = self.memory_size supported = pg_version >= (12, 0) - if wanted or (wanted is None and supported): + if (wanted is None and supported) or wanted: define_macros.append(('MEMORY_SIZE', None)) if not supported: warnings.warn( @@ -243,11 +224,7 @@ def finalize_options(self): "Operating System :: OS Independent", "Programming Language :: C", 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 5c043701..c43c2101 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -297,8 +297,7 @@ def testAllQueryMembers(self): members.remove('memsize') query_members = [ a for a in dir(query) - if not a.startswith('__') - and a != 'next'] # this is only needed in Python 2 + if not a.startswith('__')] self.assertEqual(members, query_members) def testMethodEndcopy(self): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index fa3cc655..fd09c9d5 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -5015,9 +5015,6 @@ def getLeaks(self, fut): gc.collect() objs[:] = gc.get_objects() objs[:] = [obj for obj in objs if id(obj) not in ids] - if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)): - # workaround for Python 3.5 issue 26811 - objs[:] = [obj for obj in objs if repr(obj) != '(,)'] self.assertEqual(len(objs), 0) def testLeaksWithClose(self): From db5374c893b3e9368e11caf2e1058d834ce1eca3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:01:01 +0000 Subject: [PATCH 020/118] Remove now unnecessary py3c shim --- docs/download/files.rst | 1 - pgconn.c | 58 ++++++++-------- pginternal.c | 56 +++++++--------- pglarge.c | 14 ++-- pgmodule.c | 75 ++++++++++----------- pgnotice.c | 6 +- pgquery.c | 27 ++++---- pgsource.c | 55 +++++++--------- py3c.h | 143 ---------------------------------------- 9 files changed, 134 insertions(+), 301 deletions(-) delete mode 100644 py3c.h diff --git a/docs/download/files.rst b/docs/download/files.rst index 4f4741fd..ec581bf0 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -12,7 +12,6 @@ pgquery.c the query object pgsource.c the source object pgtypes.h PostgreSQL type definitions -py3c.h Python 2/3 compatibility layer for the C extension pg.py the "classic" PyGreSQL module pgdb.py a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL diff --git a/pgconn.c b/pgconn.c index d39c9301..910f2212 100644 --- a/pgconn.c +++ b/pgconn.c @@ -26,7 +26,7 @@ conn_dealloc(connObject *self) static PyObject * conn_getattr(connObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* * Although we could check individually, there are only a few @@ -47,49 +47,49 @@ conn_getattr(connObject *self, PyObject *nameobj) char *r = PQhost(self->cnx); if (!r || r[0] == '/') /* Pg >= 9.6 can return a Unix socket path */ r = "localhost"; - return PyStr_FromString(r); + return PyUnicode_FromString(r); } /* postmaster port */ if (!strcmp(name, "port")) - return PyInt_FromLong(atol(PQport(self->cnx))); + return PyLong_FromLong(atol(PQport(self->cnx))); /* selected database */ if (!strcmp(name, "db")) - return PyStr_FromString(PQdb(self->cnx)); + return PyUnicode_FromString(PQdb(self->cnx)); /* selected options */ if (!strcmp(name, "options")) - return PyStr_FromString(PQoptions(self->cnx)); + return PyUnicode_FromString(PQoptions(self->cnx)); /* error (status) message */ if (!strcmp(name, "error")) - return PyStr_FromString(PQerrorMessage(self->cnx)); + return PyUnicode_FromString(PQerrorMessage(self->cnx)); /* connection status : 1 - OK, 0 - BAD */ if (!strcmp(name, "status")) - return PyInt_FromLong(PQstatus(self->cnx) == CONNECTION_OK ? 1 : 0); + return PyLong_FromLong(PQstatus(self->cnx) == CONNECTION_OK ? 1 : 0); /* provided user name */ if (!strcmp(name, "user")) - return PyStr_FromString(PQuser(self->cnx)); + return PyUnicode_FromString(PQuser(self->cnx)); /* protocol version */ if (!strcmp(name, "protocol_version")) - return PyInt_FromLong(PQprotocolVersion(self->cnx)); + return PyLong_FromLong(PQprotocolVersion(self->cnx)); /* backend version */ if (!strcmp(name, "server_version")) - return PyInt_FromLong(PQserverVersion(self->cnx)); + return PyLong_FromLong(PQserverVersion(self->cnx)); /* descriptor number of connection socket */ if (!strcmp(name, "socket")) { - return PyInt_FromLong(PQsocket(self->cnx)); + return PyLong_FromLong(PQsocket(self->cnx)); } /* PID of backend process */ if (!strcmp(name, "backend_pid")) { - return PyInt_FromLong(PQbackendPID(self->cnx)); + return PyLong_FromLong(PQbackendPID(self->cnx)); } /* whether the connection uses SSL */ @@ -183,7 +183,7 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) char *ret = PQcmdTuples(result); if (ret[0]) { /* return number of rows affected */ - PyObject *obj = PyStr_FromString(ret); + PyObject *obj = PyUnicode_FromString(ret); PQclear(result); return obj; } @@ -193,7 +193,7 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) } /* for a single insert, return the oid */ PQclear(result); - return PyInt_FromLong((long) oid); + return PyLong_FromLong((long) oid); } case PGRES_COPY_OUT: /* no data will be received */ case PGRES_COPY_IN: @@ -325,7 +325,7 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) return NULL; } *s++ = str_obj; - *p = PyStr_AsString(str_obj); + *p = PyUnicode_AsUTF8(str_obj); } } @@ -614,7 +614,7 @@ conn_getline(connObject *self, PyObject *noargs) } /* for backward compatibility, convert terminating newline to zero byte */ if (*line) line[strlen(line) - 1] = '\0'; - str = PyStr_FromString(line); + str = PyUnicode_FromString(line); PQfreemem(line); return str; } @@ -947,9 +947,9 @@ conn_inserttable(connObject *self, PyObject *args) Py_DECREF(s); } } - else if (PyInt_Check(item) || PyLong_Check(item)) { + else if (PyLong_Check(item)) { PyObject* s = PyObject_Str(item); - const char* t = PyStr_AsString(s); + const char* t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { *bufpt++ = *t++; --bufsiz; @@ -958,7 +958,7 @@ conn_inserttable(connObject *self, PyObject *args) } else { PyObject* s = PyObject_Repr(item); - const char* t = PyStr_AsString(s); + const char* t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { switch (*t) { @@ -1036,7 +1036,7 @@ conn_inserttable(connObject *self, PyObject *args) } else { long ntuples = atol(PQcmdTuples(result)); PQclear(result); - return PyInt_FromLong(ntuples); + return PyLong_FromLong(ntuples); } } @@ -1052,7 +1052,7 @@ conn_transaction(connObject *self, PyObject *noargs) return NULL; } - return PyInt_FromLong(PQtransactionStatus(self->cnx)); + return PyLong_FromLong(PQtransactionStatus(self->cnx)); } /* Get parameter setting. */ @@ -1079,7 +1079,7 @@ conn_parameter(connObject *self, PyObject *args) name = PQparameterStatus(self->cnx, name); if (name) - return PyStr_FromString(name); + return PyUnicode_FromString(name); /* unknown parameter, return None */ Py_INCREF(Py_None); @@ -1107,7 +1107,7 @@ conn_date_format(connObject *self, PyObject *noargs) self->date_format = fmt; /* cache the result */ } - return PyStr_FromString(fmt); + return PyUnicode_FromString(fmt); } #ifdef ESCAPING_FUNCS @@ -1450,7 +1450,7 @@ conn_cancel(connObject *self, PyObject *noargs) } /* request that the server abandon processing of the current command */ - return PyInt_FromLong((long) PQrequestCancel(self->cnx)); + return PyLong_FromLong((long) PQrequestCancel(self->cnx)); } /* Get connection socket. */ @@ -1465,7 +1465,7 @@ conn_fileno(connObject *self, PyObject *noargs) return NULL; } - return PyInt_FromLong((long) PQsocket(self->cnx)); + return PyLong_FromLong((long) PQsocket(self->cnx)); } /* Set external typecast callback function. */ @@ -1536,7 +1536,7 @@ conn_poll(connObject *self, PyObject *noargs) return NULL; } - return PyInt_FromLong(rc); + return PyLong_FromLong(rc); } /* Set notice receiver callback function. */ @@ -1632,7 +1632,7 @@ conn_get_notify(connObject *self, PyObject *noargs) else { PyObject *notify_result, *tmp; - if (!(tmp = PyStr_FromString(notify->relname))) { + if (!(tmp = PyUnicode_FromString(notify->relname))) { return NULL; } @@ -1642,7 +1642,7 @@ conn_get_notify(connObject *self, PyObject *noargs) PyTuple_SET_ITEM(notify_result, 0, tmp); - if (!(tmp = PyInt_FromLong(notify->be_pid))) { + if (!(tmp = PyLong_FromLong(notify->be_pid))) { Py_DECREF(notify_result); return NULL; } @@ -1650,7 +1650,7 @@ conn_get_notify(connObject *self, PyObject *noargs) PyTuple_SET_ITEM(notify_result, 1, tmp); /* extra exists even in old versions that did not support it */ - if (!(tmp = PyStr_FromString(notify->extra))) { + if (!(tmp = PyUnicode_FromString(notify->extra))) { Py_DECREF(notify_result); return NULL; } diff --git a/pginternal.c b/pginternal.c index 6dcad8bc..50181b0d 100644 --- a/pginternal.c +++ b/pginternal.c @@ -247,11 +247,10 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) break; default: /* PYGRES_TEXT */ -#if IS_PY3 obj = get_decoded_string(s, size, encoding); - if (!obj) /* cannot decode */ -#endif - obj = PyBytes_FromStringAndSize(s, size); + if (!obj) { /* cannot decode */ + obj = PyBytes_FromStringAndSize(s, size); + } } return obj; @@ -296,7 +295,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) *t++ = *s++; } *t = '\0'; - obj = PyInt_FromString(buf, NULL, 10); + obj = PyLong_FromString(buf, NULL, 10); break; case PYGRES_LONG: @@ -312,7 +311,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) break; case PYGRES_FLOAT: - tmp_obj = PyStr_FromStringAndSize(s, size); + tmp_obj = PyUnicode_FromStringAndSize(s, size); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -336,7 +335,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) obj = PyObject_CallFunction(decimal, "(s)", buf); } else { - tmp_obj = PyStr_FromString(buf); + tmp_obj = PyUnicode_FromString(buf); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); @@ -344,7 +343,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) break; case PYGRES_DECIMAL: - tmp_obj = PyStr_FromStringAndSize(s, size); + tmp_obj = PyUnicode_FromStringAndSize(s, size); obj = decimal ? PyObject_CallFunctionObjArgs( decimal, tmp_obj, NULL) : PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); @@ -353,7 +352,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) case PYGRES_BOOL: /* convert to bool only if bool_as_text is not set */ if (bool_as_text) { - obj = PyStr_FromString(*s == 't' ? "t" : "f"); + obj = PyUnicode_FromString(*s == 't' ? "t" : "f"); } else { obj = *s == 't' ? Py_True : Py_False; @@ -363,7 +362,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) default: /* other types should never be passed, use cast_sized_text */ - obj = PyStr_FromStringAndSize(s, size); + obj = PyUnicode_FromStringAndSize(s, size); } return obj; @@ -381,15 +380,12 @@ cast_unsized_simple(char *s, int type) switch (type) { /* this must be the PyGreSQL internal type */ case PYGRES_INT: - obj = PyInt_FromString(s, NULL, 10); - break; - case PYGRES_LONG: obj = PyLong_FromString(s, NULL, 10); break; case PYGRES_FLOAT: - tmp_obj = PyStr_FromString(s); + tmp_obj = PyUnicode_FromString(s); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -416,7 +412,7 @@ cast_unsized_simple(char *s, int type) obj = PyObject_CallFunction(decimal, "(s)", s); } else { - tmp_obj = PyStr_FromString(s); + tmp_obj = PyUnicode_FromString(s); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); } @@ -425,7 +421,7 @@ cast_unsized_simple(char *s, int type) case PYGRES_BOOL: /* convert to bool only if bool_as_text is not set */ if (bool_as_text) { - obj = PyStr_FromString(*s == 't' ? "t" : "f"); + obj = PyUnicode_FromString(*s == 't' ? "t" : "f"); } else { obj = *s == 't' ? Py_True : Py_False; @@ -435,7 +431,7 @@ cast_unsized_simple(char *s, int type) default: /* other types should never be passed, use cast_sized_text */ - obj = PyStr_FromString(s); + obj = PyUnicode_FromString(s); } return obj; @@ -613,12 +609,11 @@ cast_array(char *s, Py_ssize_t size, int encoding, element = cast_sized_simple(estr, esize, type); } else { /* external casting of base type */ -#if IS_PY3 element = encoding == pg_encoding_ascii ? NULL : get_decoded_string(estr, esize, encoding); - if (!element) /* no decoding necessary or possible */ -#endif - element = PyBytes_FromStringAndSize(estr, esize); + if (!element) { /* no decoding necessary or possible */ + element = PyBytes_FromStringAndSize(estr, esize); + } if (element && cast) { PyObject *tmp = element; element = PyObject_CallFunctionObjArgs( @@ -768,12 +763,11 @@ cast_record(char *s, Py_ssize_t size, int encoding, element = cast_sized_simple(estr, esize, etype); } else { /* external casting of base type */ -#if IS_PY3 element = encoding == pg_encoding_ascii ? NULL : get_decoded_string(estr, esize, encoding); - if (!element) /* no decoding necessary or possible */ -#endif - element = PyBytes_FromStringAndSize(estr, esize); + if (!element) { /* no decoding necessary or possible */ + element = PyBytes_FromStringAndSize(estr, esize); + } if (element && cast) { if (len) { PyObject *ecast = PySequence_GetItem(cast, i); @@ -1065,17 +1059,15 @@ set_error_msg_and_state(PyObject *type, { PyObject *err_obj, *msg_obj, *sql_obj = NULL; -#if IS_PY3 if (encoding == -1) /* unknown */ msg_obj = PyUnicode_DecodeLocale(msg, NULL); else msg_obj = get_decoded_string(msg, (Py_ssize_t) strlen(msg), encoding); if (!msg_obj) /* cannot decode */ -#endif - msg_obj = PyBytes_FromString(msg); + msg_obj = PyBytes_FromString(msg); if (sqlstate) { - sql_obj = PyStr_FromStringAndSize(sqlstate, 5); + sql_obj = PyUnicode_FromStringAndSize(sqlstate, 5); } else { Py_INCREF(Py_None); sql_obj = Py_None; @@ -1139,7 +1131,7 @@ get_ssl_attributes(PGconn *cnx) { const char *val = PQsslAttribute(cnx, *s); if (val) { - PyObject * val_obj = PyStr_FromString(val); + PyObject * val_obj = PyUnicode_FromString(val); PyDict_SetItemString(attr_dict, *s, val_obj); Py_DECREF(val_obj); @@ -1280,7 +1272,7 @@ format_result(const PGresult *res) /* create the footer */ sprintf(p, "(%d row%s)", m, m == 1 ? "" : "s"); /* return the result */ - result = PyStr_FromString(buffer); + result = PyUnicode_FromString(buffer); PyMem_Free(buffer); return result; } @@ -1293,7 +1285,7 @@ format_result(const PGresult *res) } } else - return PyStr_FromString("(nothing selected)"); + return PyUnicode_FromString("(nothing selected)"); } /* Internal function converting a Postgres datestyles to date formats. */ diff --git a/pglarge.c b/pglarge.c index c080d658..863e2ec9 100644 --- a/pglarge.c +++ b/pglarge.c @@ -31,7 +31,7 @@ large_str(largeObject *self) sprintf(str, self->lo_fd >= 0 ? "Opened large object, oid %ld" : "Closed large object, oid %ld", (long) self->lo_oid); - return PyStr_FromString(str); + return PyUnicode_FromString(str); } /* Check validity of large object. */ @@ -67,7 +67,7 @@ _check_lo_obj(largeObject *self, int level) static PyObject * large_getattr(largeObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* list postgreSQL large object fields */ @@ -85,7 +85,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* large object oid */ if (!strcmp(name, "oid")) { if (_check_lo_obj(self, 0)) - return PyInt_FromLong((long) self->lo_oid); + return PyLong_FromLong((long) self->lo_oid); PyErr_Clear(); Py_INCREF(Py_None); return Py_None; @@ -93,7 +93,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* error (status) message */ if (!strcmp(name, "error")) - return PyStr_FromString(PQerrorMessage(self->pgcnx->cnx)); + return PyUnicode_FromString(PQerrorMessage(self->pgcnx->cnx)); /* seeks name in methods (fallback) */ return PyObject_GenericGetAttr((PyObject *) self, nameobj); @@ -285,7 +285,7 @@ large_seek(largeObject *self, PyObject *args) } /* returns position */ - return PyInt_FromLong(ret); + return PyLong_FromLong(ret); } /* Get large object size. */ @@ -325,7 +325,7 @@ large_size(largeObject *self, PyObject *noargs) } /* returns size */ - return PyInt_FromLong(end); + return PyLong_FromLong(end); } /* Get large object cursor position. */ @@ -350,7 +350,7 @@ large_tell(largeObject *self, PyObject *noargs) } /* returns size */ - return PyInt_FromLong(start); + return PyLong_FromLong(start); } /* Export large object as unix file. */ diff --git a/pgmodule.c b/pgmodule.c index bbb4b0db..6adc79c0 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -19,9 +19,6 @@ /* The type definitions from */ #include "pgtypes.h" -/* Macros for single-source Python 2/3 compatibility */ -#include "py3c.h" - static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, *DataError, *NotSupportedError, @@ -237,7 +234,7 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) pghost = PyBytes_AsString(pg_default_host); if ((pgport == -1) && (pg_default_port != Py_None)) - pgport = (int) PyInt_AsLong(pg_default_port); + pgport = (int) PyLong_AsLong(pg_default_port); if ((!pgopt) && (pg_default_opt != Py_None)) pgopt = PyBytes_AsString(pg_default_opt); @@ -488,7 +485,7 @@ static PyObject * pg_get_datestyle(PyObject *self, PyObject *noargs) { if (date_format) { - return PyStr_FromString(date_format_to_style(date_format)); + return PyUnicode_FromString(date_format_to_style(date_format)); } else { Py_INCREF(Py_None); return Py_None; @@ -507,7 +504,7 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) if (decimal_point) { s[0] = decimal_point; s[1] = '\0'; - ret = PyStr_FromString(s); + ret = PyUnicode_FromString(s); } else { Py_INCREF(Py_None); ret = Py_None; @@ -804,7 +801,7 @@ pg_set_defhost(PyObject *self, PyObject *args) old = pg_default_host; if (tmp) { - pg_default_host = PyStr_FromString(tmp); + pg_default_host = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -847,7 +844,7 @@ pg_set_defbase(PyObject *self, PyObject *args) old = pg_default_base; if (tmp) { - pg_default_base = PyStr_FromString(tmp); + pg_default_base = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -890,7 +887,7 @@ pg_setdefopt(PyObject *self, PyObject *args) old = pg_default_opt; if (tmp) { - pg_default_opt = PyStr_FromString(tmp); + pg_default_opt = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -934,7 +931,7 @@ pg_set_defuser(PyObject *self, PyObject *args) old = pg_default_user; if (tmp) { - pg_default_user = PyStr_FromString(tmp); + pg_default_user = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -962,7 +959,7 @@ pg_set_defpasswd(PyObject *self, PyObject *args) } if (tmp) { - pg_default_passwd = PyStr_FromString(tmp); + pg_default_passwd = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -1006,7 +1003,7 @@ pg_set_defport(PyObject *self, PyObject *args) old = pg_default_port; if (port != -1) { - pg_default_port = PyInt_FromLong(port); + pg_default_port = PyLong_FromLong(port); } else { Py_INCREF(Py_None); @@ -1250,7 +1247,9 @@ static struct PyModuleDef moduleDef = { }; /* Initialization function for the module */ -MODULE_INIT_FUNC(_pg) +PyMODINIT_FUNC PyInit__pg(void); + +PyMODINIT_FUNC PyInit__pg(void) { PyObject *mod, *dict, *s; @@ -1259,18 +1258,10 @@ MODULE_INIT_FUNC(_pg) mod = PyModule_Create(&moduleDef); /* Initialize here because some Windows platforms get confused otherwise */ -#if IS_PY3 connType.tp_base = noticeType.tp_base = queryType.tp_base = sourceType.tp_base = &PyBaseObject_Type; #ifdef LARGE_OBJECTS largeType.tp_base = &PyBaseObject_Type; -#endif -#else - connType.ob_type = noticeType.ob_type = - queryType.ob_type = sourceType.ob_type = &PyType_Type; -#ifdef LARGE_OBJECTS - largeType.ob_type = &PyType_Type; -#endif #endif if (PyType_Ready(&connType) @@ -1288,10 +1279,10 @@ MODULE_INIT_FUNC(_pg) dict = PyModule_GetDict(mod); /* Exceptions as defined by DB-API 2.0 */ - Error = PyErr_NewException("pg.Error", PyExc_StandardError, NULL); + Error = PyErr_NewException("pg.Error", PyExc_Exception, NULL); PyDict_SetItemString(dict, "Error", Error); - Warning = PyErr_NewException("pg.Warning", PyExc_StandardError, NULL); + Warning = PyErr_NewException("pg.Warning", PyExc_Exception, NULL); PyDict_SetItemString(dict, "Warning", Warning); InterfaceError = PyErr_NewException( @@ -1339,39 +1330,39 @@ MODULE_INIT_FUNC(_pg) PyDict_SetItemString(dict, "MultipleResultsError", MultipleResultsError); /* Make the version available */ - s = PyStr_FromString(PyPgVersion); + s = PyUnicode_FromString(PyPgVersion); PyDict_SetItemString(dict, "version", s); PyDict_SetItemString(dict, "__version__", s); Py_DECREF(s); /* Result types for queries */ - PyDict_SetItemString(dict, "RESULT_EMPTY", PyInt_FromLong(RESULT_EMPTY)); - PyDict_SetItemString(dict, "RESULT_DML", PyInt_FromLong(RESULT_DML)); - PyDict_SetItemString(dict, "RESULT_DDL", PyInt_FromLong(RESULT_DDL)); - PyDict_SetItemString(dict, "RESULT_DQL", PyInt_FromLong(RESULT_DQL)); + PyDict_SetItemString(dict, "RESULT_EMPTY", PyLong_FromLong(RESULT_EMPTY)); + PyDict_SetItemString(dict, "RESULT_DML", PyLong_FromLong(RESULT_DML)); + PyDict_SetItemString(dict, "RESULT_DDL", PyLong_FromLong(RESULT_DDL)); + PyDict_SetItemString(dict, "RESULT_DQL", PyLong_FromLong(RESULT_DQL)); /* Transaction states */ - PyDict_SetItemString(dict, "TRANS_IDLE", PyInt_FromLong(PQTRANS_IDLE)); - PyDict_SetItemString(dict, "TRANS_ACTIVE", PyInt_FromLong(PQTRANS_ACTIVE)); - PyDict_SetItemString(dict, "TRANS_INTRANS", PyInt_FromLong(PQTRANS_INTRANS)); - PyDict_SetItemString(dict, "TRANS_INERROR", PyInt_FromLong(PQTRANS_INERROR)); - PyDict_SetItemString(dict, "TRANS_UNKNOWN", PyInt_FromLong(PQTRANS_UNKNOWN)); + PyDict_SetItemString(dict, "TRANS_IDLE", PyLong_FromLong(PQTRANS_IDLE)); + PyDict_SetItemString(dict, "TRANS_ACTIVE", PyLong_FromLong(PQTRANS_ACTIVE)); + PyDict_SetItemString(dict, "TRANS_INTRANS", PyLong_FromLong(PQTRANS_INTRANS)); + PyDict_SetItemString(dict, "TRANS_INERROR", PyLong_FromLong(PQTRANS_INERROR)); + PyDict_SetItemString(dict, "TRANS_UNKNOWN", PyLong_FromLong(PQTRANS_UNKNOWN)); /* Polling results */ - PyDict_SetItemString(dict, "POLLING_OK", PyInt_FromLong(PGRES_POLLING_OK)); - PyDict_SetItemString(dict, "POLLING_FAILED", PyInt_FromLong(PGRES_POLLING_FAILED)); - PyDict_SetItemString(dict, "POLLING_READING", PyInt_FromLong(PGRES_POLLING_READING)); - PyDict_SetItemString(dict, "POLLING_WRITING", PyInt_FromLong(PGRES_POLLING_WRITING)); + PyDict_SetItemString(dict, "POLLING_OK", PyLong_FromLong(PGRES_POLLING_OK)); + PyDict_SetItemString(dict, "POLLING_FAILED", PyLong_FromLong(PGRES_POLLING_FAILED)); + PyDict_SetItemString(dict, "POLLING_READING", PyLong_FromLong(PGRES_POLLING_READING)); + PyDict_SetItemString(dict, "POLLING_WRITING", PyLong_FromLong(PGRES_POLLING_WRITING)); #ifdef LARGE_OBJECTS /* Create mode for large objects */ - PyDict_SetItemString(dict, "INV_READ", PyInt_FromLong(INV_READ)); - PyDict_SetItemString(dict, "INV_WRITE", PyInt_FromLong(INV_WRITE)); + PyDict_SetItemString(dict, "INV_READ", PyLong_FromLong(INV_READ)); + PyDict_SetItemString(dict, "INV_WRITE", PyLong_FromLong(INV_WRITE)); /* Position flags for lo_lseek */ - PyDict_SetItemString(dict, "SEEK_SET", PyInt_FromLong(SEEK_SET)); - PyDict_SetItemString(dict, "SEEK_CUR", PyInt_FromLong(SEEK_CUR)); - PyDict_SetItemString(dict, "SEEK_END", PyInt_FromLong(SEEK_END)); + PyDict_SetItemString(dict, "SEEK_SET", PyLong_FromLong(SEEK_SET)); + PyDict_SetItemString(dict, "SEEK_CUR", PyLong_FromLong(SEEK_CUR)); + PyDict_SetItemString(dict, "SEEK_END", PyLong_FromLong(SEEK_END)); #endif /* LARGE_OBJECTS */ #ifdef DEFAULT_VARS diff --git a/pgnotice.c b/pgnotice.c index ae6b2b68..e079283c 100644 --- a/pgnotice.c +++ b/pgnotice.c @@ -13,7 +13,7 @@ static PyObject * notice_getattr(noticeObject *self, PyObject *nameobj) { PGresult const *res = self->res; - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); int fieldcode; if (!res) { @@ -35,7 +35,7 @@ notice_getattr(noticeObject *self, PyObject *nameobj) /* full message */ if (!strcmp(name, "message")) { - return PyStr_FromString(PQresultErrorMessage(res)); + return PyUnicode_FromString(PQresultErrorMessage(res)); } /* other possible fields */ @@ -51,7 +51,7 @@ notice_getattr(noticeObject *self, PyObject *nameobj) if (fieldcode) { char *s = PQresultErrorField(res, fieldcode); if (s) { - return PyStr_FromString(s); + return PyUnicode_FromString(s); } else { Py_INCREF(Py_None); return Py_None; diff --git a/pgquery.c b/pgquery.c index 0d7ebc7d..0923eb66 100644 --- a/pgquery.c +++ b/pgquery.c @@ -168,7 +168,7 @@ _get_async_result(queryObject *self, int keep) { are additional statements following this one, so we return an empty string where query() would return None. */ Py_DECREF(result); - result = PyStr_FromString(""); + result = PyUnicode_FromString(""); } return result; } @@ -266,7 +266,7 @@ static char query_ntuples__doc__[] = static PyObject * query_ntuples(queryObject *self, PyObject *noargs) { - return PyInt_FromLong(self->max_row); + return PyLong_FromLong(self->max_row); } /* List field names from query result. */ @@ -285,7 +285,7 @@ query_listfields(queryObject *self, PyObject *noargs) if (fieldstuple) { for (i = 0; i < self->num_fields; ++i) { name = PQfname(self->result, i); - str = PyStr_FromString(name); + str = PyUnicode_FromString(name); PyTuple_SET_ITEM(fieldstuple, i, str); } } @@ -317,7 +317,7 @@ query_fieldname(queryObject *self, PyObject *args) /* gets fields name and builds object */ name = PQfname(self->result, i); - return PyStr_FromString(name); + return PyUnicode_FromString(name); } /* Get field number from name in last result. */ @@ -343,7 +343,7 @@ query_fieldnum(queryObject *self, PyObject *args) return NULL; } - return PyInt_FromLong(num); + return PyLong_FromLong(num); } /* Build a tuple with info for query field with given number. */ @@ -353,10 +353,10 @@ _query_build_field_info(PGresult *res, int col_num) { info = PyTuple_New(4); if (info) { - PyTuple_SET_ITEM(info, 0, PyStr_FromString(PQfname(res, col_num))); - PyTuple_SET_ITEM(info, 1, PyInt_FromLong((long) PQftype(res, col_num))); - PyTuple_SET_ITEM(info, 2, PyInt_FromLong(PQfsize(res, col_num))); - PyTuple_SET_ITEM(info, 3, PyInt_FromLong(PQfmod(res, col_num))); + PyTuple_SET_ITEM(info, 0, PyUnicode_FromString(PQfname(res, col_num))); + PyTuple_SET_ITEM(info, 1, PyLong_FromLong((long) PQftype(res, col_num))); + PyTuple_SET_ITEM(info, 2, PyLong_FromLong(PQfsize(res, col_num))); + PyTuple_SET_ITEM(info, 3, PyLong_FromLong(PQfmod(res, col_num))); } return info; } @@ -383,13 +383,13 @@ query_fieldinfo(queryObject *self, PyObject *args) /* gets field number */ if (PyBytes_Check(field)) { num = PQfnumber(self->result, PyBytes_AsString(field)); - } else if (PyStr_Check(field)) { + } else if (PyUnicode_Check(field)) { PyObject *tmp = get_encoded_string(field, self->encoding); if (!tmp) return NULL; num = PQfnumber(self->result, PyBytes_AsString(tmp)); Py_DECREF(tmp); - } else if (PyInt_Check(field)) { - num = (int) PyInt_AsLong(field); + } else if (PyLong_Check(field)) { + num = (int) PyLong_AsLong(field); } else { PyErr_SetString(PyExc_TypeError, "Field should be given as column number or name"); @@ -980,8 +980,7 @@ static PyTypeObject queryType = { PyObject_GenericGetAttr, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT - |Py_TPFLAGS_HAVE_ITER, /* tp_flags */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ query__doc__, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ diff --git a/pgsource.c b/pgsource.c index 053ad02f..7b081273 100644 --- a/pgsource.c +++ b/pgsource.c @@ -28,10 +28,10 @@ source_str(sourceObject *self) return format_result(self->result); case RESULT_DDL: case RESULT_DML: - return PyStr_FromString(PQcmdStatus(self->result)); + return PyUnicode_FromString(PQcmdStatus(self->result)); case RESULT_EMPTY: default: - return PyStr_FromString("(empty PostgreSQL source object)"); + return PyUnicode_FromString("(empty PostgreSQL source object)"); } } @@ -65,7 +65,7 @@ _check_source_obj(sourceObject *self, int level) static PyObject * source_getattr(sourceObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* pg connection object */ if (!strcmp(name, "pgcnx")) { @@ -79,19 +79,19 @@ source_getattr(sourceObject *self, PyObject *nameobj) /* arraysize */ if (!strcmp(name, "arraysize")) - return PyInt_FromLong(self->arraysize); + return PyLong_FromLong(self->arraysize); /* resulttype */ if (!strcmp(name, "resulttype")) - return PyInt_FromLong(self->result_type); + return PyLong_FromLong(self->result_type); /* ntuples */ if (!strcmp(name, "ntuples")) - return PyInt_FromLong(self->max_row); + return PyLong_FromLong(self->max_row); /* nfields */ if (!strcmp(name, "nfields")) - return PyInt_FromLong(self->num_fields); + return PyLong_FromLong(self->num_fields); /* seeks name in methods (fallback) */ return PyObject_GenericGetAttr((PyObject *) self, nameobj); @@ -103,12 +103,12 @@ source_setattr(sourceObject *self, char *name, PyObject *v) { /* arraysize */ if (!strcmp(name, "arraysize")) { - if (!PyInt_Check(v)) { + if (!PyLong_Check(v)) { PyErr_SetString(PyExc_TypeError, "arraysize must be integer"); return -1; } - self->arraysize = PyInt_AsLong(v); + self->arraysize = PyLong_AsLong(v); return 0; } @@ -227,7 +227,7 @@ source_execute(sourceObject *self, PyObject *sql) self->result_type = RESULT_DDL; num_rows = -1; } - return PyInt_FromLong(num_rows); + return PyLong_FromLong(num_rows); } /* query failed */ @@ -272,7 +272,7 @@ source_oidstatus(sourceObject *self, PyObject *noargs) return Py_None; } - return PyInt_FromLong((long) oid); + return PyLong_FromLong((long) oid); } /* Fetch rows from last result. */ @@ -287,9 +287,7 @@ source_fetch(sourceObject *self, PyObject *args) PyObject *res_list; int i, k; long size; -#if IS_PY3 int encoding; -#endif /* checks validity */ if (!_check_source_obj(self, CHECK_RESULT | CHECK_DQL | CHECK_CNX)) { @@ -313,9 +311,7 @@ source_fetch(sourceObject *self, PyObject *args) /* allocate list for result */ if (!(res_list = PyList_New(0))) return NULL; -#if IS_PY3 encoding = self->encoding; -#endif /* builds result */ for (i = 0, k = self->current_row; i < size; ++i, ++k) { @@ -336,15 +332,14 @@ source_fetch(sourceObject *self, PyObject *args) else { char *s = PQgetvalue(self->result, k, j); Py_ssize_t size = PQgetlength(self->result, k, j); -#if IS_PY3 if (PQfformat(self->result, j) == 0) { /* textual format */ str = get_decoded_string(s, size, encoding); if (!str) /* cannot decode */ str = PyBytes_FromStringAndSize(s, size); } - else -#endif - str = PyBytes_FromStringAndSize(s, size); + else { + str = PyBytes_FromStringAndSize(s, size); + } } PyTuple_SET_ITEM(rowtuple, j, str); } @@ -531,7 +526,7 @@ source_putdata(sourceObject *self, PyObject *buffer) tmp = PQcmdTuples(result); num_rows = tmp[0] ? atol(tmp) : -1; - ret = PyInt_FromLong(num_rows); + ret = PyLong_FromLong(num_rows); } else { if (!errormsg) errormsg = PQerrorMessage(self->pgcnx->cnx); @@ -602,7 +597,7 @@ source_getdata(sourceObject *self, PyObject *args) tmp = PQcmdTuples(result); num_rows = tmp[0] ? atol(tmp) : -1; - ret = PyInt_FromLong(num_rows); + ret = PyLong_FromLong(num_rows); } else { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->pgcnx->cnx)); @@ -634,11 +629,11 @@ _source_fieldindex(sourceObject *self, PyObject *param, const char *usage) return -1; /* gets field number */ - if (PyStr_Check(param)) { + if (PyUnicode_Check(param)) { num = PQfnumber(self->result, PyBytes_AsString(param)); } - else if (PyInt_Check(param)) { - num = (int) PyInt_AsLong(param); + else if (PyLong_Check(param)) { + num = (int) PyLong_AsLong(param); } else { PyErr_SetString(PyExc_TypeError, usage); @@ -667,15 +662,15 @@ _source_buildinfo(sourceObject *self, int num) } /* affects field information */ - PyTuple_SET_ITEM(result, 0, PyInt_FromLong(num)); + PyTuple_SET_ITEM(result, 0, PyLong_FromLong(num)); PyTuple_SET_ITEM(result, 1, - PyStr_FromString(PQfname(self->result, num))); + PyUnicode_FromString(PQfname(self->result, num))); PyTuple_SET_ITEM(result, 2, - PyInt_FromLong((long) PQftype(self->result, num))); + PyLong_FromLong((long) PQftype(self->result, num))); PyTuple_SET_ITEM(result, 3, - PyInt_FromLong(PQfsize(self->result, num))); + PyLong_FromLong(PQfsize(self->result, num))); PyTuple_SET_ITEM(result, 4, - PyInt_FromLong(PQfmod(self->result, num))); + PyLong_FromLong(PQfmod(self->result, num))); return result; } @@ -751,7 +746,7 @@ source_field(sourceObject *self, PyObject *desc) return NULL; } - return PyStr_FromString( + return PyUnicode_FromString( PQgetvalue(self->result, self->current_row, num)); } diff --git a/py3c.h b/py3c.h deleted file mode 100644 index c137b191..00000000 --- a/py3c.h +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright (c) 2015, Red Hat, Inc. and/or its affiliates - * Licensed under the MIT license; see py3c.h - */ - -#ifndef _PY3C_COMPAT_H_ -#define _PY3C_COMPAT_H_ -#define PY_SSIZE_T_CLEAN -#include - -#if PY_MAJOR_VERSION >= 3 - -/***** Python 3 *****/ - -#define IS_PY3 1 - -/* Strings */ - -#define PyStr_Type PyUnicode_Type -#define PyStr_Check PyUnicode_Check -#define PyStr_CheckExact PyUnicode_CheckExact -#define PyStr_FromString PyUnicode_FromString -#define PyStr_FromStringAndSize PyUnicode_FromStringAndSize -#define PyStr_FromFormat PyUnicode_FromFormat -#define PyStr_FromFormatV PyUnicode_FromFormatV -#define PyStr_AsString PyUnicode_AsUTF8 -#define PyStr_Concat PyUnicode_Concat -#define PyStr_Format PyUnicode_Format -#define PyStr_InternInPlace PyUnicode_InternInPlace -#define PyStr_InternFromString PyUnicode_InternFromString -#define PyStr_Decode PyUnicode_Decode - -#define PyStr_AsUTF8String PyUnicode_AsUTF8String // returns PyBytes -#define PyStr_AsUTF8 PyUnicode_AsUTF8 -#define PyStr_AsUTF8AndSize PyUnicode_AsUTF8AndSize - -/* Ints */ - -#define PyInt_Type PyLong_Type -#define PyInt_Check PyLong_Check -#define PyInt_CheckExact PyLong_CheckExact -#define PyInt_FromString PyLong_FromString -#define PyInt_FromLong PyLong_FromLong -#define PyInt_FromSsize_t PyLong_FromSsize_t -#define PyInt_FromSize_t PyLong_FromSize_t -#define PyInt_AsLong PyLong_AsLong -#define PyInt_AS_LONG PyLong_AS_LONG -#define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask -#define PyInt_AsSsize_t PyLong_AsSsize_t - -/* Module init */ - -#define MODULE_INIT_FUNC(name) \ - PyMODINIT_FUNC PyInit_ ## name(void); \ - PyMODINIT_FUNC PyInit_ ## name(void) - -/* Other */ - -#define Py_TPFLAGS_HAVE_ITER 0 // not needed in Python 3 - -#define PyExc_StandardError PyExc_Exception // exists only in Python 2 - -#else - -/***** Python 2 *****/ - -#define IS_PY3 0 - -/* Strings */ - -#define PyStr_Type PyString_Type -#define PyStr_Check PyString_Check -#define PyStr_CheckExact PyString_CheckExact -#define PyStr_FromString PyString_FromString -#define PyStr_FromStringAndSize PyString_FromStringAndSize -#define PyStr_FromFormat PyString_FromFormat -#define PyStr_FromFormatV PyString_FromFormatV -#define PyStr_AsString PyString_AsString -#define PyStr_Format PyString_Format -#define PyStr_InternInPlace PyString_InternInPlace -#define PyStr_InternFromString PyString_InternFromString -#define PyStr_Decode PyString_Decode - -static inline PyObject *PyStr_Concat(PyObject *left, PyObject *right) { - PyObject *str = left; - Py_INCREF(left); // reference to old left will be stolen - PyString_Concat(&str, right); - if (str) { - return str; - } else { - return NULL; - } -} - -#define PyStr_AsUTF8String(str) (Py_INCREF(str), (str)) -#define PyStr_AsUTF8 PyString_AsString -#define PyStr_AsUTF8AndSize(pystr, sizeptr) \ - ((*sizeptr=PyString_Size(pystr)), PyString_AsString(pystr)) - -#define PyBytes_Type PyString_Type -#define PyBytes_Check PyString_Check -#define PyBytes_CheckExact PyString_CheckExact -#define PyBytes_FromString PyString_FromString -#define PyBytes_FromStringAndSize PyString_FromStringAndSize -#define PyBytes_FromFormat PyString_FromFormat -#define PyBytes_FromFormatV PyString_FromFormatV -#define PyBytes_Size PyString_Size -#define PyBytes_GET_SIZE PyString_GET_SIZE -#define PyBytes_AsString PyString_AsString -#define PyBytes_AS_STRING PyString_AS_STRING -#define PyBytes_AsStringAndSize PyString_AsStringAndSize -#define PyBytes_Concat PyString_Concat -#define PyBytes_ConcatAndDel PyString_ConcatAndDel -#define _PyBytes_Resize _PyString_Resize - -/* Floats */ - -#define PyFloat_FromString(str) PyFloat_FromString(str, NULL) - -/* Module init */ - -#define PyModuleDef_HEAD_INIT 0 - -typedef struct PyModuleDef { - int m_base; - const char* m_name; - const char* m_doc; - Py_ssize_t m_size; - PyMethodDef *m_methods; -} PyModuleDef; - -#define PyModule_Create(def) \ - Py_InitModule3((def)->m_name, (def)->m_methods, (def)->m_doc) - -#define MODULE_INIT_FUNC(name) \ - static PyObject *PyInit_ ## name(void); \ - void init ## name(void); \ - void init ## name(void) { PyInit_ ## name(); } \ - static PyObject *PyInit_ ## name(void) - - -#endif - -#endif From de8936f7ba56a2468aee8f37d1cb8c7dfa00fc55 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:04:50 +0000 Subject: [PATCH 021/118] Remove now unnecessary encoding comments --- tests/config.py | 1 - tests/test_classic.py | 1 - tests/test_classic_connection.py | 1 - tests/test_classic_dbwrapper.py | 1 - tests/test_classic_functions.py | 1 - tests/test_classic_largeobj.py | 1 - tests/test_classic_notification.py | 1 - tests/test_dbapi20.py | 1 - tests/test_dbapi20_copy.py | 1 - tests/test_tutorial.py | 1 - 10 files changed, 10 deletions(-) diff --git a/tests/config.py b/tests/config.py index a6dcd3a3..e6bf326c 100644 --- a/tests/config.py +++ b/tests/config.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- from os import environ diff --git a/tests/test_classic.py b/tests/test_classic.py index 727e4a86..3284d9ee 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- from __future__ import print_function diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c43c2101..8c7adc39 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index fd09c9d5..0843710d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 653fbb87..db450ec8 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index fc0464d5..3271686c 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 39f607df..6f94cebd 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index a03dca93..2d853f73 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- import gc import sys diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 47fc012a..d6fd1cfc 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the modern PyGreSQL interface. diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 6f968560..d9d1398b 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- from __future__ import print_function From 0fd08bb5d9330d90a8e018f781b57c44606041ec Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 17:07:18 +0000 Subject: [PATCH 022/118] Remove now unnecessary future statements --- pg.py | 2 -- pgdb.py | 2 -- tests/test_classic.py | 2 -- tests/test_tutorial.py | 2 -- 4 files changed, 8 deletions(-) diff --git a/pg.py b/pg.py index 9d5a7e13..b8c4fa08 100644 --- a/pg.py +++ b/pg.py @@ -20,8 +20,6 @@ For a DB-API 2 compliant interface use the newer pgdb module. """ -from __future__ import print_function, division - try: from _pg import * except ImportError as e: diff --git a/pgdb.py b/pgdb.py index ccf848e9..2919caf3 100644 --- a/pgdb.py +++ b/pgdb.py @@ -64,8 +64,6 @@ connection.close() # close the connection """ -from __future__ import print_function, division - try: from _pg import * except ImportError as e: diff --git a/tests/test_classic.py b/tests/test_classic.py index 3284d9ee..375bad3f 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,7 +1,5 @@ #!/usr/bin/python -from __future__ import print_function - import unittest from functools import partial diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index d9d1398b..0193165a 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -1,7 +1,5 @@ #!/usr/bin/python -from __future__ import print_function - import unittest from pg import DB From 60f9eb4a73c4f489e47b797bdfe5a8984a1cc985 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 20:08:04 +0200 Subject: [PATCH 023/118] Simplify pg modules assuming modern Python --- pg.py | 245 ++++++++++--------------------------------------------- pgdb.py | 246 ++++++++++---------------------------------------------- 2 files changed, 84 insertions(+), 407 deletions(-) diff --git a/pg.py b/pg.py index b8c4fa08..2edeb6e2 100644 --- a/pg.py +++ b/pg.py @@ -46,10 +46,9 @@ else: libpq += 'so' if e: - # note: we could use "raise from e" here in Python 3 raise ImportError( "Cannot import shared library for PyGreSQL,\n" - "probably because no %s is installed.\n%s" % (libpq, e)) + "probably because no %s is installed.\n%s" % (libpq, e)) from e __version__ = version @@ -85,165 +84,24 @@ import warnings import weakref -from datetime import date, time, datetime, timedelta, tzinfo +from datetime import date, time, datetime, timedelta from decimal import Decimal from math import isnan, isinf from collections import namedtuple, OrderedDict +from inspect import signature from operator import itemgetter -from functools import partial +from functools import lru_cache, partial from re import compile as regex from json import loads as jsondecode, dumps as jsonencode from uuid import UUID from typing import Dict, List, Union # noqa: F401 -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - unicode -except NameError: # Python >= 3.0 - unicode = str - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - basestring -except NameError: # Python >= 3.0 - basestring = (str, bytes) - -try: - from functools import lru_cache -except ImportError: # Python < 3.2 - from functools import update_wrapper - try: # noinspection PyCompatibility - from _thread import RLock - except ImportError: - class RLock: # for builds without threads - def __enter__(self): - pass - - def __exit__(self, exctype, excinst, exctb): - pass - - def lru_cache(maxsize=128): - """Simplified functools.lru_cache decorator for one argument.""" - - def decorator(function): - sentinel = object() - cache = {} - get = cache.get - lock = RLock() - root = [] - root_full = [root, False] - root[:] = [root, root, None, None] - - if maxsize == 0: - - def wrapper(arg): - res = function(arg) - return res - - elif maxsize is None: - - def wrapper(arg): - res = get(arg, sentinel) - if res is not sentinel: - return res - res = function(arg) - cache[arg] = res - return res - - else: - - def wrapper(arg): - with lock: - link = get(arg) - if link is not None: - root = root_full[0] - prv, nxt, _arg, res = link - prv[1] = nxt - nxt[0] = prv - last = root[0] - last[1] = root[0] = link - link[0] = last - link[1] = root - return res - res = function(arg) - with lock: - root, full = root_full - if arg in cache: - pass - elif full: - oldroot = root - oldroot[2] = arg - oldroot[3] = res - root = root_full[0] = oldroot[1] - oldarg = root[2] - oldres = root[3] # noqa F481 (keep reference) - root[2] = root[3] = None - del cache[oldarg] - cache[arg] = oldroot - else: - last = root[0] - link = [last, root, arg, res] - last[1] = root[0] = cache[arg] = link - if len(cache) >= maxsize: - root_full[1] = True - return res - - wrapper.__wrapped__ = function - return update_wrapper(wrapper, function) - - return decorator - # Auxiliary classes and functions that are independent of a DB connection: -try: # noinspection PyUnresolvedReferences - from inspect import signature -except ImportError: # Python < 3.3 - from inspect import getargspec - - def get_args(func): - return getargspec(func).args -else: - - def get_args(func): - return list(signature(func).parameters) - -try: - from datetime import timezone -except ImportError: # Python < 3.2 - - class timezone(tzinfo): - """Simple timezone implementation.""" - - def __init__(self, offset, name=None): - self.offset = offset - if not name: - minutes = self.offset.days * 1440 + self.offset.seconds // 60 - if minutes < 0: - hours, minutes = divmod(-minutes, 60) - hours = -hours - else: - hours, minutes = divmod(minutes, 60) - name = 'UTC%+03d:%02d' % (hours, minutes) - self.name = name - - def utcoffset(self, dt): - return self.offset +def get_args(func): + return list(signature(func).parameters) - def tzname(self, dt): - return self.name - - def dst(self, dt): - return None - - timezone.utc = timezone(timedelta(0), 'UTC') - - _has_timezone = False -else: - _has_timezone = True # time zones used in Postgres timestamptz output _timezones = dict(CET='+0100', EET='+0200', EST='-0500', @@ -259,14 +117,6 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def _get_timezone(tz): - tz = _timezone_as_offset(tz) - minutes = 60 * int(tz[1:3]) + int(tz[3:5]) - if tz[0] == '-': - minutes = -minutes - return timezone(timedelta(minutes=minutes), tz) - - def _oid_key(table): """Build oid key from a table name.""" return 'oid(%s)' % table @@ -285,7 +135,7 @@ class Hstore(dict): def _quote(cls, s): if s is None: return 'NULL' - if not isinstance(s, basestring): + if not isinstance(s, str): s = str(s) if not s: return '""' @@ -308,7 +158,7 @@ def __init__(self, obj, encode=None): def __str__(self): obj = self.obj - if isinstance(obj, basestring): + if isinstance(obj, str): return obj return self.encode(obj) @@ -330,8 +180,7 @@ class _SimpleTypes(dict): 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], 'num': ['numeric', Decimal], 'money': [], - 'text': ['bpchar', 'char', 'name', 'varchar', - bytes, unicode, basestring] + 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] } # type: Dict[str, List[Union[str, type]]] # noinspection PyMissingConstructor @@ -369,7 +218,7 @@ def _quote_if_unqualified(param, name): (could be a qualified name or just a name with a dot in it) and must be quoted manually by the caller. """ - if isinstance(name, basestring) and '.' not in name: + if isinstance(name, str) and '.' not in name: return 'quote_ident(%s)' % (param,) return param @@ -440,7 +289,7 @@ def __init__(self, db): @classmethod def _adapt_bool(cls, v): """Adapt a boolean parameter.""" - if isinstance(v, basestring): + if isinstance(v, str): if not v: return None v = v.lower() in cls._bool_true_values @@ -451,7 +300,7 @@ def _adapt_date(cls, v): """Adapt a date parameter.""" if not v: return None - if isinstance(v, basestring) and v.lower() in cls._date_literals: + if isinstance(v, str) and v.lower() in cls._date_literals: return Literal(v) return v @@ -472,7 +321,7 @@ def _adapt_json(self, v): """Adapt a json parameter.""" if not v: return None - if isinstance(v, basestring): + if isinstance(v, str): return v if isinstance(v, Json): return str(v) @@ -482,7 +331,7 @@ def _adapt_hstore(self, v): """Adapt a hstore parameter.""" if not v: return None - if isinstance(v, basestring): + if isinstance(v, str): return v if isinstance(v, Hstore): return str(v) @@ -494,7 +343,7 @@ def _adapt_uuid(self, v): """Adapt a UUID parameter.""" if not v: return None - if isinstance(v, basestring): + if isinstance(v, str): return v return str(v) @@ -523,7 +372,7 @@ def _adapt_bool_array(cls, v): return '{%s}' % ','.join(adapt(v) for v in v) if v is None: return 'null' - if isinstance(v, basestring): + if isinstance(v, str): if not v: return 'null' v = v.lower() in cls._bool_true_values @@ -558,7 +407,7 @@ def _adapt_json_array(self, v): return '{%s}' % ','.join(adapt(v) for v in v) if not v: return 'null' - if not isinstance(v, basestring): + if not isinstance(v, str): v = self.db.encode_json(v) if self._re_array_quote.search(v): v = '"%s"' % self._re_array_escape.sub(r'\\\1', v) @@ -642,11 +491,11 @@ def guess_simple_type(cls, value): return _simple_type_dict[type(value)] except KeyError: pass - if isinstance(value, basestring): + if isinstance(value, (bytes, str)): return 'text' if isinstance(value, bool): return 'bool' - if isinstance(value, (int, long)): + if isinstance(value, int): return 'int' if isinstance(value, float): return 'float' @@ -695,12 +544,10 @@ def adapt_inline(self, value, nested=False): if isinstance(value, Literal): return value if isinstance(value, Bytea): - value = self.db.escape_bytea(value) - if bytes is not str: # Python >= 3.0 - value = value.decode('ascii') + value = self.db.escape_bytea(value).decode('ascii') elif isinstance(value, (datetime, date, time, timedelta)): value = str(value) - if isinstance(value, basestring): + if isinstance(value, (bytes, str)): value = self.db.escape_string(value) return "'%s'" % value if isinstance(value, bool): @@ -711,7 +558,7 @@ def adapt_inline(self, value, nested=False): if isnan(value): return "'NaN'" return value - if isinstance(value, (int, long, Decimal)): + if isinstance(value, (int, Decimal)): return value if isinstance(value, list): q = self.adapt_inline @@ -767,7 +614,7 @@ def format_query(self, command, values=None, types=None, inline=False): else: add = params.add if types: - if isinstance(types, basestring): + if isinstance(types, str): types = types.split() if (not isinstance(types, (list, tuple)) or len(types) != len(values)): @@ -884,12 +731,9 @@ def cast_timetz(value): else: tz = '+0000' fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - if _has_timezone: - value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() - return datetime.strptime(value, fmt).timetz().replace( - tzinfo=_get_timezone(tz)) + value += _timezone_as_offset(tz) + fmt += '%z' + return datetime.strptime(value, fmt).timetz() def cast_timestamp(value, connection): @@ -944,12 +788,9 @@ def cast_timestamptz(value, connection): if len(value[0]) > 10: return datetime.max fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - if _has_timezone: - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) - return datetime.strptime(' '.join(value), ' '.join(fmt)).replace( - tzinfo=_get_timezone(tz)) + value.append(_timezone_as_offset(tz)) + fmt.append('%z') + return datetime.strptime(' '.join(value), ' '.join(fmt)) _re_interval_sql_standard = regex( @@ -1057,7 +898,7 @@ class Typecasts(dict): 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, 'float4': float, 'float8': float, 'numeric': cast_num, 'money': cast_money, @@ -1117,7 +958,7 @@ def get(self, typ, default=None): def set(self, typ, cast): """Set a typecast function for the specified database type(s).""" - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] if cast is None: for t in typ: @@ -1138,7 +979,7 @@ def reset(self, typ=None): if typ is None: self.clear() else: - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] for t in typ: self.pop(t, None) @@ -1151,7 +992,7 @@ def get_default(cls, typ): @classmethod def set_default(cls, typ, cast): """Set a default typecast function for the given database type(s).""" - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] defaults = cls.defaults if cast is None: @@ -1716,7 +1557,7 @@ def _do_debug(self, *args): """Print a debug message""" if self.debug: s = '\n'.join(str(arg) for arg in args) - if isinstance(self.debug, basestring): + if isinstance(self.debug, str): print(self.debug % s) elif hasattr(self.debug, 'write'): # noinspection PyCallingNonCallable @@ -1858,7 +1699,7 @@ def get_parameter(self, parameter): By passing the special name 'all' as the parameter, you can get a dict of all existing configuration parameters. """ - if isinstance(parameter, basestring): + if isinstance(parameter, str): parameter = [parameter] values = None elif isinstance(parameter, (list, tuple)): @@ -1875,7 +1716,7 @@ def get_parameter(self, parameter): params = {} if isinstance(values, dict) else [] for key in parameter: param = key.strip().lower() if isinstance( - key, basestring) else None + key, (bytes, str)) else None if not param: raise TypeError('Invalid parameter') if param == 'all': @@ -1923,7 +1764,7 @@ def set_parameter(self, parameter, value=None, local=False): have no effect if it is executed outside a transaction, since the transaction will end immediately. """ - if isinstance(parameter, basestring): + if isinstance(parameter, str): parameter = {parameter: value} elif isinstance(parameter, (list, tuple)): if isinstance(value, (list, tuple)): @@ -1935,7 +1776,7 @@ def set_parameter(self, parameter, value=None, local=False): value = set(value) if len(value) == 1: value = value.pop() - if not (value is None or isinstance(value, basestring)): + if not (value is None or isinstance(value, str)): raise ValueError( 'A single value must be specified' ' when parameter is a set') @@ -1953,7 +1794,7 @@ def set_parameter(self, parameter, value=None, local=False): params = {} for key, value in parameter.items(): param = key.strip().lower() if isinstance( - key, basestring) else None + key, str) else None if not param: raise TypeError('Invalid parameter') if param == 'all': @@ -2272,7 +2113,7 @@ def get(self, table, row, keyname=None): table = table[:-1].rstrip() attnames = self.get_attnames(table) qoid = _oid_key(table) if 'oid' in attnames else None - if keyname and isinstance(keyname, basestring): + if keyname and isinstance(keyname, str): keyname = (keyname,) if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: row['oid'] = row[qoid] @@ -2513,7 +2354,7 @@ def upsert(self, table, row=None, **kw): if n not in keyname and n not in generated: value = kw.get(n, n in row) if value: - if not isinstance(value, basestring): + if not isinstance(value, str): value = 'excluded.%s' % col(n) update.append('%s = %s' % (col(n), value)) if not values: @@ -2631,7 +2472,7 @@ def truncate(self, table, restart=False, cascade=False, only=False): can be specified after the table name to explicitly indicate that descendant tables are included. """ - if isinstance(table, basestring): + if isinstance(table, str): only = {table: only} table = [table] elif isinstance(table, (list, tuple)): @@ -2764,7 +2605,7 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, keyname = self.pkey(table, True) except (KeyError, ProgrammingError): raise _prg_error('Table %s has no primary key' % table) - if isinstance(keyname, basestring): + if isinstance(keyname, str): keyname = [keyname] elif not isinstance(keyname, (list, tuple)): raise KeyError('The keyname must be a string, list or tuple') diff --git a/pgdb.py b/pgdb.py index 2919caf3..85767e3a 100644 --- a/pgdb.py +++ b/pgdb.py @@ -90,10 +90,9 @@ else: libpq += 'so' if e: - # note: we could use "raise from e" here in Python 3 raise ImportError( "Cannot import shared library for PyGreSQL,\n" - "probably because no %s is installed.\n%s" % (libpq, e)) + "probably because no %s is installed.\n%s" % (libpq, e)) from e __version__ = version @@ -114,122 +113,20 @@ 'get_typecast', 'set_typecast', 'reset_typecast', 'version', '__version__'] -from datetime import date, time, datetime, timedelta, tzinfo +from datetime import date, time, datetime, timedelta from time import localtime from decimal import Decimal as StdDecimal from uuid import UUID as Uuid from math import isnan, isinf -try: # noinspection PyCompatibility - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable from collections import namedtuple -from functools import partial +from collections.abc import Iterable +from inspect import signature +from functools import lru_cache, partial from re import compile as regex from json import loads as jsondecode, dumps as jsonencode Decimal = StdDecimal -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - unicode -except NameError: # Python >= 3.0 - unicode = str - -try: # noinspection PyUnresolvedReferences,PyUnboundLocalVariable - basestring -except NameError: # Python >= 3.0 - basestring = (str, bytes) - -try: - from functools import lru_cache -except ImportError: # Python < 3.2 - from functools import update_wrapper - try: # noinspection PyCompatibility - from _thread import RLock - except ImportError: - class RLock: # for builds without threads - def __enter__(self): - pass - - def __exit__(self, exctype, excinst, exctb): - pass - - def lru_cache(maxsize=128): - """Simplified functools.lru_cache decorator for one argument.""" - - def decorator(function): - sentinel = object() - cache = {} - get = cache.get - lock = RLock() - root = [] - root_full = [root, False] - root[:] = [root, root, None, None] - - if maxsize == 0: - - def wrapper(arg): - res = function(arg) - return res - - elif maxsize is None: - - def wrapper(arg): - res = get(arg, sentinel) - if res is not sentinel: - return res - res = function(arg) - cache[arg] = res - return res - - else: - - def wrapper(arg): - with lock: - link = get(arg) - if link is not None: - root = root_full[0] - prv, nxt, _arg, res = link - prv[1] = nxt - nxt[0] = prv - last = root[0] - last[1] = root[0] = link - link[0] = last - link[1] = root - return res - res = function(arg) - with lock: - root, full = root_full - if arg in cache: - pass - elif full: - oldroot = root - oldroot[2] = arg - oldroot[3] = res - root = root_full[0] = oldroot[1] - oldarg = root[2] - oldres = root[3] # noqa F481 (keep reference) - root[2] = root[3] = None - del cache[oldarg] - cache[arg] = oldroot - else: - last = root[0] - link = [last, root, arg, res] - last[1] = root[0] = cache[arg] = link - if len(cache) >= maxsize: - root_full[1] = True - return res - - wrapper.__wrapped__ = function - return update_wrapper(wrapper, function) - - return decorator - # *** Module Constants *** @@ -249,51 +146,9 @@ def wrapper(arg): # *** Internal Type Handling *** -try: # noinspection PyUnresolvedReferences - from inspect import signature -except ImportError: # Python < 3.3 - from inspect import getargspec +def get_args(func): + return list(signature(func).parameters) - def get_args(func): - return getargspec(func).args -else: - - def get_args(func): - return list(signature(func).parameters) - -try: - from datetime import timezone -except ImportError: # Python < 3.2 - - class timezone(tzinfo): - """Simple timezone implementation.""" - - def __init__(self, offset, name=None): - self.offset = offset - if not name: - minutes = self.offset.days * 1440 + self.offset.seconds // 60 - if minutes < 0: - hours, minutes = divmod(-minutes, 60) - hours = -hours - else: - hours, minutes = divmod(minutes, 60) - name = 'UTC%+03d:%02d' % (hours, minutes) - self.name = name - - def utcoffset(self, dt): - return self.offset - - def tzname(self, dt): - return self.name - - def dst(self, dt): - return None - - timezone.utc = timezone(timedelta(0), 'UTC') - - _has_timezone = False -else: - _has_timezone = True # time zones used in Postgres timestamptz output _timezones = dict(CET='+0100', EET='+0200', EST='-0500', @@ -309,14 +164,6 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def _get_timezone(tz): - tz = _timezone_as_offset(tz) - minutes = 60 * int(tz[1:3]) + int(tz[3:5]) - if tz[0] == '-': - minutes = -minutes - return timezone(timedelta(minutes=minutes), tz) - - def decimal_type(decimal_type=None): """Get or set global type to be used for decimal values. @@ -385,12 +232,9 @@ def cast_timetz(value): else: tz = '+0000' fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - if _has_timezone: - value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() - return datetime.strptime(value, fmt).timetz().replace( - tzinfo=_get_timezone(tz)) + value += _timezone_as_offset(tz) + fmt += '%z' + return datetime.strptime(value, fmt).timetz() def cast_timestamp(value, connection): @@ -445,12 +289,9 @@ def cast_timestamptz(value, connection): if len(value[0]) > 10: return datetime.max fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - if _has_timezone: - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) - return datetime.strptime(' '.join(value), ' '.join(fmt)).replace( - tzinfo=_get_timezone(tz)) + value.append(_timezone_as_offset(tz)) + fmt.append('%z') + return datetime.strptime(' '.join(value), ' '.join(fmt)) _re_interval_sql_standard = regex( @@ -555,7 +396,7 @@ class Typecasts(dict): 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, 'float4': float, 'float8': float, 'numeric': Decimal, 'money': cast_money, @@ -611,7 +452,7 @@ def get(self, typ, default=None): def set(self, typ, cast): """Set a typecast function for the specified database type(s).""" - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] if cast is None: for t in typ: @@ -634,7 +475,7 @@ def reset(self, typ=None): self.clear() self.update(defaults) else: - if isinstance(typ, basestring): + if isinstance(typ, str): typ = [typ] for t in typ: cast = defaults.get(t) @@ -967,11 +808,9 @@ def _quote(self, value): return 'NULL' if isinstance(value, (Hstore, Json)): value = str(value) - if isinstance(value, basestring): + if isinstance(value, (bytes, str)): if isinstance(value, Binary): - value = self._cnx.escape_bytea(value) - if bytes is not str: # Python >= 3.0 - value = value.decode('ascii') + value = self._cnx.escape_bytea(value).decode('ascii') else: value = self._cnx.escape_string(value) return "'%s'" % (value,) @@ -981,7 +820,7 @@ def _quote(self, value): if isnan(value): return "'NaN'" return value - if isinstance(value, (int, long, Decimal, Literal)): + if isinstance(value, (int, Decimal, Literal)): return value if isinstance(value, datetime): if value.tzinfo: @@ -1237,10 +1076,10 @@ def copy_from(self, stream, table, input_type = bytes type_name = 'byte strings' else: - input_type = basestring + input_type = (bytes, str) type_name = 'strings' - if isinstance(stream, basestring): + if isinstance(stream, (bytes, str)): if not isinstance(stream, input_type): raise ValueError("The input must be %s" % (type_name,)) if not binary_format: @@ -1291,7 +1130,7 @@ def chunks(): def chunks(): yield read() - if not table or not isinstance(table, basestring): + if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): raise ValueError("Must specify a table, not a query") @@ -1302,13 +1141,13 @@ def chunks(): options = [] params = [] if format is not None: - if not isinstance(format, basestring): + if not isinstance(format, str): raise TypeError("The format option must be be a string") if format not in ('text', 'csv', 'binary'): raise ValueError("Invalid format") options.append('format %s' % (format,)) if sep is not None: - if not isinstance(sep, basestring): + if not isinstance(sep, str): raise TypeError("The sep option must be a string") if format == 'binary': raise ValueError( @@ -1319,12 +1158,12 @@ def chunks(): options.append('delimiter %s') params.append(sep) if null is not None: - if not isinstance(null, basestring): + if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') params.append(null) if columns: - if not isinstance(columns, basestring): + if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) operation.append('(%s)' % (columns,)) @@ -1375,7 +1214,7 @@ def copy_to(self, stream, table, write = stream.write except AttributeError: raise TypeError("Need an output stream to copy to") - if not table or not isinstance(table, basestring): + if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): if columns: @@ -1388,13 +1227,13 @@ def copy_to(self, stream, table, options = [] params = [] if format is not None: - if not isinstance(format, basestring): + if not isinstance(format, str): raise TypeError("The format option must be a string") if format not in ('text', 'csv', 'binary'): raise ValueError("Invalid format") options.append('format %s' % (format,)) if sep is not None: - if not isinstance(sep, basestring): + if not isinstance(sep, str): raise TypeError("The sep option must be a string") if binary_format: raise ValueError( @@ -1405,15 +1244,12 @@ def copy_to(self, stream, table, options.append('delimiter %s') params.append(sep) if null is not None: - if not isinstance(null, basestring): + if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') params.append(null) if decode is None: - if format == 'binary': - decode = False - else: - decode = str is unicode + decode = format != 'binary' else: if not isinstance(decode, (int, bool)): raise TypeError("The decode option must be a boolean") @@ -1421,7 +1257,7 @@ def copy_to(self, stream, table, raise ValueError( "The decode option is not allowed with binary format") if columns: - if not isinstance(columns, basestring): + if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) operation.append('(%s)' % (columns,)) @@ -1730,12 +1566,12 @@ class Type(frozenset): """ def __new__(cls, values): - if isinstance(values, basestring): + if isinstance(values, str): values = values.split() return super(Type, cls).__new__(cls, values) def __eq__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): if other.startswith('_'): other = other[1:] return other in self @@ -1743,7 +1579,7 @@ def __eq__(self, other): return super(Type, self).__eq__(other) def __ne__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): if other.startswith('_'): other = other[1:] return other not in self @@ -1755,13 +1591,13 @@ class ArrayType: """Type class for PostgreSQL array types.""" def __eq__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): return other.startswith('_') else: return isinstance(other, ArrayType) def __ne__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): return not other.startswith('_') else: return not isinstance(other, ArrayType) @@ -1774,7 +1610,7 @@ def __eq__(self, other): if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type == 'c' - elif isinstance(other, basestring): + elif isinstance(other, str): return other == 'record' else: return isinstance(other, RecordType) @@ -1783,7 +1619,7 @@ def __ne__(self, other): if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type != 'c' - elif isinstance(other, basestring): + elif isinstance(other, str): return other != 'record' else: return not isinstance(other, RecordType) @@ -1884,7 +1720,7 @@ class Hstore(dict): def _quote(cls, s): if s is None: return 'NULL' - if not isinstance(s, basestring): + if not isinstance(s, str): s = str(s) if not s: return '""' @@ -1908,7 +1744,7 @@ def __init__(self, obj, encode=None): def __str__(self): obj = self.obj - if isinstance(obj, basestring): + if isinstance(obj, str): return obj return self.encode(obj) From b1fcd3b61d053988aab4038d2723fb9aa8a747da Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 20:47:15 +0200 Subject: [PATCH 024/118] Simplify test modules assuming modern Python --- docs/contents/pgdb/cursor.rst | 2 +- tests/dbapi20.py | 28 ++---- tests/test_classic_connection.py | 143 +++++++++---------------------- tests/test_classic_dbwrapper.py | 72 +++++----------- tests/test_classic_functions.py | 18 +--- tests/test_dbapi20.py | 38 +++----- tests/test_dbapi20_copy.py | 17 ++-- 7 files changed, 87 insertions(+), 231 deletions(-) diff --git a/docs/contents/pgdb/cursor.rst b/docs/contents/pgdb/cursor.rst index 52d600e8..e1ed8b0f 100644 --- a/docs/contents/pgdb/cursor.rst +++ b/docs/contents/pgdb/cursor.rst @@ -295,7 +295,7 @@ specified, all of them will be copied. :param str null: the textual representation of the ``NULL`` value, can also be an empty string (the default is ``'\\N'``) :param bool decode: whether decoded strings shall be returned - for non-binary formats (the default is True in Python 3) + for non-binary formats (the default is ``True``) :param list column: an optional list of column names :returns: a generator if stream is set to ``None``, otherwise the cursor diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 2bb7e2b0..b793fbf2 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -1,4 +1,5 @@ #!/usr/bin/python + """Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. @@ -9,23 +10,6 @@ import unittest import time -try: # noinspection PyUnresolvedReferences - _BaseException = StandardError # noqa: F821 -except NameError: # Python >= 3.0 - _BaseException = Exception - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - - -def str2bytes(sval): - if str is not unicode and isinstance(sval, str): - # noinspection PyUnresolvedReferences - sval = sval.decode("latin1") - return sval.encode("latin1") # python 3 make unicode into bytes - class DatabaseAPI20Test(unittest.TestCase): """Test a database self.driver for DB API 2.0 compatibility. @@ -103,7 +87,7 @@ def tearDown(self): pass finally: con.close() - except _BaseException: + except Exception: pass def _connect(self): @@ -151,8 +135,8 @@ def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined hierarchy. sub = issubclass - self.assertTrue(sub(self.driver.Warning, _BaseException)) - self.assertTrue(sub(self.driver.Error, _BaseException)) + self.assertTrue(sub(self.driver.Warning, Exception)) + self.assertTrue(sub(self.driver.Error, Exception)) self.assertTrue(sub(self.driver.InterfaceError, self.driver.Error)) self.assertTrue(sub(self.driver.DatabaseError, self.driver.Error)) @@ -805,8 +789,8 @@ def test_Timestamp(self): self.assertEqual(str(t1), str(t2)) def test_Binary(self): - self.driver.Binary(str2bytes('Something')) - self.driver.Binary(str2bytes('')) + self.driver.Binary(b'Something') + self.driver.Binary(b'') def test_STRING(self): self.assertTrue(hasattr(self.driver, 'STRING'), diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 8c7adc39..0dedc5c4 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -15,31 +15,13 @@ import os from collections import namedtuple - -try: - # noinspection PyCompatibility - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable - +from collections.abc import Iterable from decimal import Decimal import pg # the module under test from .config import dbname, dbhost, dbport, dbuser, dbpasswd -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - -unicode_strings = str is not bytes - windows = os.name == 'nt' # There is a known a bug in libpq under Windows which can cause @@ -462,10 +444,10 @@ def testGetresult(self): def testGetresultLong(self): q = "select 9876543210" - result = long(9876543210) - self.assertIsInstance(result, long) + result = 9876543210 + self.assertIsInstance(result, int) v = self.c.query(q).getresult()[0][0] - self.assertIsInstance(v, long) + self.assertIsInstance(v, int) self.assertEqual(v, result) def testGetresultDecimal(self): @@ -506,10 +488,10 @@ def testDictresult(self): def testDictresultLong(self): q = "select 9876543210 as longjohnsilver" - result = long(9876543210) - self.assertIsInstance(result, long) + result = 9876543210 + self.assertIsInstance(result, int) v = self.c.query(q).dictresult()[0]['longjohnsilver'] - self.assertIsInstance(v, long) + self.assertIsInstance(v, int) self.assertEqual(v, result) def testDictresultDecimal(self): @@ -839,7 +821,7 @@ def testMemSize(self): query = self.c.query q = query("select repeat('foo!', 8)") size = q.memsize() - self.assertIsInstance(size, long) + self.assertIsInstance(size, int) self.assertGreaterEqual(size, 32) self.assertLess(size, 8000) q = query("select repeat('foo!', 2000)") @@ -875,8 +857,6 @@ def testDictresulAscii(self): def testGetresultUtf8(self): result = u'Hello, wörld & мир!' q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('utf8') # pass the query as unicode try: v = self.c.query(q).getresult()[0][0] @@ -894,8 +874,6 @@ def testGetresultUtf8(self): def testDictresultUtf8(self): result = u'Hello, wörld & мир!' q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('utf8') try: v = self.c.query(q).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): @@ -915,8 +893,6 @@ def testGetresultLatin1(self): self.skipTest("database does not support latin1") result = u'Hello, wörld!' q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('latin1') v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -932,8 +908,6 @@ def testDictresultLatin1(self): self.skipTest("database does not support latin1") result = u'Hello, wörld!' q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('latin1') v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -949,8 +923,6 @@ def testGetresultCyrillic(self): self.skipTest("database does not support cyrillic") result = u'Hello, мир!' q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('cyrillic') v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -966,8 +938,6 @@ def testDictresultCyrillic(self): self.skipTest("database does not support cyrillic") result = u'Hello, мир!' q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('cyrillic') v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -983,8 +953,6 @@ def testGetresultLatin9(self): self.skipTest("database does not support latin9") result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('latin9') v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1000,8 +968,6 @@ def testDictresultLatin9(self): self.skipTest("database does not support latin9") result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' q = u"select '%s' as menu" % result - if not unicode_strings: - result = result.encode('latin9') v = self.c.query(q).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1138,20 +1104,14 @@ def testQueryWithUnicodeParamsLatin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, wörld!',)]) - else: - self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) + self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'мир')) query('set client_encoding=iso_8859_1') r = query( "select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, wörld!',)]) - else: - self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) + self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', u'мир')) @@ -1173,10 +1133,7 @@ def testQueryWithUnicodeParamsCyrillic(self): ('Hello', u'wörld')) r = query( "select $1||', '||$2||'!'", ('Hello', u'мир')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, мир!',)]) - else: - self.assertEqual(r, [(u'Hello, мир!'.encode('cyrillic'),)]) + self.assertEqual(r, [('Hello, мир!',)]) query('set client_encoding=sql_ascii') self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", @@ -1337,7 +1294,7 @@ def testInt(self): self.assert_proper_cast(0, 'xid', int) def testLong(self): - self.assert_proper_cast(0, 'bigint', long) + self.assert_proper_cast(0, 'bigint', int) def testFloat(self): self.assert_proper_cast(0, 'float', float) @@ -1806,22 +1763,22 @@ def tearDown(self): self.c.close() data = [ - (-1, -1, long(-1), True, '1492-10-12', '08:30:00', + (-1, -1, -1, True, '1492-10-12', '08:30:00', -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), - (0, 0, long(0), False, '1607-04-14', '09:00:00', + (0, 0, 0, False, '1607-04-14', '09:00:00', 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), - (1, 1, long(1), True, '1801-03-04', '03:45:00', + (1, 1, 1, True, '1801-03-04', '03:45:00', 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), - (2, 2, long(2), False, '1903-12-17', '11:22:00', + (2, 2, 2, False, '1903-12-17', '11:22:00', 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] @classmethod def db_len(cls, s, encoding): # noinspection PyUnresolvedReferences if cls.has_encoding: - s = s if isinstance(s, unicode) else s.decode(encoding) + s = s if isinstance(s, str) else s.decode(encoding) else: - s = s.encode(encoding) if isinstance(s, unicode) else s + s = s.encode(encoding) if isinstance(s, str) else s return len(s) def get_back(self, encoding='utf-8'): @@ -1835,7 +1792,7 @@ def get_back(self, encoding='utf-8'): if row[1] is not None: # integer self.assertIsInstance(row[1], int) if row[2] is not None: # bigint - self.assertIsInstance(row[2], long) + self.assertIsInstance(row[2], int) if row[3] is not None: # boolean self.assertIsInstance(row[3], bool) if row[4] is not None: # date @@ -2039,7 +1996,7 @@ def testInserttableWithOutOfRangeData(self): ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) def testInserttableMaxValues(self): - data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), + data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, True, '2999-12-31', '11:59:59', 1e99, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, "1", "1234", "1234", "1234" * 100)] @@ -2054,16 +2011,15 @@ def testInserttableByteValues(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") row_bytes = tuple( - s.encode('utf-8') if isinstance(s, unicode) else s + s.encode('utf-8') if isinstance(s, str) else s for s in row_unicode) data = [row_bytes] * 2 self.c.inserttable('test', data) - if unicode_strings: - data = [row_unicode] * 2 + data = [row_unicode] * 2 self.assertEqual(self.get_back(), data) def testInserttableUnicodeUtf8(self): @@ -2074,16 +2030,11 @@ def testInserttableUnicodeUtf8(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple( - s.encode('utf-8') if isinstance(s, unicode) else s - for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back(), data) def testInserttableUnicodeLatin1(self): @@ -2095,22 +2046,17 @@ def testInserttableUnicodeLatin1(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode € sign with latin1 encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) row_unicode = tuple( - s.replace(u'€', u'¥') if isinstance(s, unicode) else s + s.replace(u'€', u'¥') if isinstance(s, str) else s for s in row_unicode) data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple( - s.encode('latin1') if isinstance(s, unicode) else s - for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back('latin1'), data) def testInserttableUnicodeLatin9(self): @@ -2123,16 +2069,11 @@ def testInserttableUnicodeLatin9(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple( - s.encode('latin9') if isinstance(s, unicode) else s - for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back('latin9'), data) def testInserttableNoEncoding(self): @@ -2140,7 +2081,7 @@ def testInserttableNoEncoding(self): # non-ascii chars do not fit in char(1) when there is no encoding c = u'€' if self.has_encoding else u'$' row_unicode = ( - 0, 0, long(0), False, u'1970-01-01', u'00:00:00', + 0, 0, 0, False, u'1970-01-01', u'00:00:00', 0.0, 0.0, 0.0, u'0.0', c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") data = [row_unicode] @@ -2164,7 +2105,7 @@ def __repr__(self): return s s = '1\'2"3\b4\f5\n6\r7\t8\b9\\0' - s1 = s.encode('ascii') if unicode_strings else s.decode('ascii') + s1 = s.encode('ascii') s2 = S() data = [(t,) for t in (s, s1, s2)] self.c.inserttable('test', data, ['t']) @@ -2596,7 +2537,7 @@ def testSetDecimal(self): pg.set_decimal(decimal_class) self.assertNotIsInstance(r, decimal_class) self.assertIsInstance(r, int) - self.assertEqual(r, int(3425)) + self.assertEqual(r, 3425) def testGetBool(self): use_bool = pg.get_bool() @@ -2725,10 +2666,7 @@ def testSetByteaEscaped(self): self.assertEqual(r, b'data') def testSetRowFactorySize(self): - try: - from functools import lru_cache - except ImportError: # Python < 3.2 - lru_cache = None + from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): @@ -2742,12 +2680,11 @@ def testSetRowFactorySize(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - if lru_cache: - info = pg._row_factory.cache_info() - self.assertEqual(info.maxsize, maxsize) - self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + info = pg._row_factory.cache_info() + self.assertEqual(info.maxsize, maxsize) + self.assertEqual(info.hits + info.misses, 6) + self.assertEqual( + info.hits, 0 if maxsize is not None and maxsize < 2 else 4) class TestStandaloneEscapeFunctions(unittest.TestCase): @@ -2783,13 +2720,13 @@ def testEscapeString(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f(u"das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u"das is'' käse".encode('utf-8')) r = f(u"that's cheesy") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"that''s cheesy") r = f(r"It's bad to have a \ inside.") self.assertEqual(r, r"It''s bad to have a \\ inside.") @@ -2801,13 +2738,13 @@ def testEscapeBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f(u"das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b"das is'' k\\\\303\\\\244se") r = f(u"that's cheesy") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"that''s cheesy") r = f(b'O\x00ps\xff!') self.assertEqual(r, b'O\\\\000ps\\\\377!') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 0843710d..25c3c11d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -21,6 +21,7 @@ from collections import OrderedDict from decimal import Decimal from datetime import date, time, datetime, timedelta +from io import StringIO from uuid import UUID from time import strftime from operator import itemgetter @@ -29,21 +30,6 @@ debug = False # let DB wrapper print debugging output -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - -if str is bytes: # noinspection PyCompatibility,PyUnresolvedReferences - from StringIO import StringIO -else: # Python >= 3.0 - from io import StringIO - windows = os.name == 'nt' # There is a known a bug in libpq under Windows which can cause @@ -523,13 +509,13 @@ def testEscapeLiteral(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b"'plain'") r = f(u"plain") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"'plain'") r = f(u"that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u"'that''s käse'".encode('utf-8')) r = f(u"that's käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"'that''s käse'") self.assertEqual(f(r"It's fine to have a \ inside."), r" E'It''s fine to have a \\ inside.'") @@ -542,13 +528,13 @@ def testEscapeIdentifier(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'"plain"') r = f(u"plain") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'"plain"') r = f(u"that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u'"that\'s käse"'.encode('utf-8')) r = f(u"that's käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'"that\'s käse"') self.assertEqual(f(r"It's fine to have a \ inside."), '"It\'s fine to have a \\ inside."') @@ -561,13 +547,13 @@ def testEscapeString(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b"plain") r = f(u"plain") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"plain") r = f(u"that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, u"that''s käse".encode('utf-8')) r = f(u"that's käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u"that''s käse") self.assertEqual(f(r"It's fine to have a \ inside."), r"It''s fine to have a \ inside.") @@ -580,13 +566,13 @@ def testEscapeBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x706c61696e') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'\\x706c61696e') r = f(u"das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x64617320697327206bc3a47365') r = f(u"das is' käse") - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'\\x64617320697327206bc3a47365') self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21') @@ -623,7 +609,7 @@ def testDecodeJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -1783,8 +1769,7 @@ def testInsert(self): (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), dict(i2=42, i4=123456, i8=9876543210), - dict(i2=2 ** 15 - 1, - i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)), + dict(i2=2 ** 15 - 1, i4=2 ** 31 - 1, i8=2 ** 63 - 1), dict(d=None), (dict(d=''), dict(d=None)), dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))), dict(f4=None, f8=None), dict(f4=0, f8=0), @@ -2519,9 +2504,9 @@ def testClear(self): r['a'] = r['f'] = r['n'] = 1 r['d'] = r['t'] = 'x' r['b'] = 't' - r['oid'] = long(1) + r['oid'] = 1 r = clear(table, r) - result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1)) + result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=1) self.assertEqual(r, result) def testClearWithQuotedNames(self): @@ -3455,7 +3440,7 @@ def testInsertGetJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3472,7 +3457,7 @@ def testInsertGetJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3525,7 +3510,7 @@ def testInsertGetJsonb(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3542,7 +3527,7 @@ def testInsertGetJsonb(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3578,8 +3563,7 @@ def testArray(self): data = dict( id=42, i2=[42, 1234, None, 0, -1], i4=[42, 123456789, None, 0, 1, -1], - i8=[long(42), long(123456789123456789), None, - long(0), long(1), long(-1)], + i8=[42, 123456789123456789, None, 0, 1, -1], d=[decimal(42), long_decimal, None, decimal(0), decimal(1), decimal(-1), -long_decimal], f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0, @@ -4053,10 +4037,7 @@ def testTimetz(self): timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): tz = '%+03d00' % timezones[timezone] - try: - tzinfo = datetime.strptime(tz, '%z').tzinfo - except ValueError: # Python < 3.2 - tzinfo = pg._get_timezone(tz) + tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) d = time(15, 9, 26, tzinfo=tzinfo) q = "select $1::timetz" @@ -4108,10 +4089,7 @@ def testTimestamptz(self): timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): tz = '%+03d00' % timezones[timezone] - try: - tzinfo = datetime.strptime(tz, '%z').tzinfo - except ValueError: # Python < 3.2 - tzinfo = pg._get_timezone(tz) + tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', 'SQL, MDY', 'SQL, DMY', 'German'): @@ -4546,20 +4524,14 @@ def testAdaptQueryTypedWithHstore(self): value = {'one': "it's fine", 'two': 2} sql, params = format_query("select %s", (value,), 'hstore') self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,), 'hstore') self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", [value], [pg.Hstore]) self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) def testAdaptQueryTypedWithUuid(self): @@ -4658,8 +4630,6 @@ def testAdaptQueryUntypedWithHstore(self): value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,)) self.assertEqual(sql, "select $1") - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - params[0] = ','.join(sorted(params[0].split(','))) self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) def testAdaptQueryUntypedDict(self): @@ -4729,8 +4699,6 @@ def testAdaptQueryInlineListWithHstore(self): format_query = self.adapter.format_query value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,), inline=True) - if sys.version_info[:2] < (3, 6): # Python < 3.6 has unsorted dict - sql = sql[:8] + ','.join(sorted(sql[8:-9].split(','))) + sql[-9:] self.assertEqual( sql, "select 'one=>\"it''s fine\",two=>2'::hstore") self.assertEqual(params, []) diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index db450ec8..282ec6df 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -18,16 +18,6 @@ from datetime import timedelta -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" @@ -123,8 +113,8 @@ def testDefBase(self): def testPqlibVersion(self): # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() - self.assertIsInstance(v, long) - self.assertGreater(v, 90000) + self.assertIsInstance(v, int) + self.assertGreater(v, 100000) self.assertLess(v, 160000) @@ -881,7 +871,7 @@ def testEscapeString(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f("that's cheese") self.assertIsInstance(r, str) @@ -893,7 +883,7 @@ def testEscapeBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(u'plain') - self.assertIsInstance(r, unicode) + self.assertIsInstance(r, str) self.assertEqual(r, u'plain') r = f("that's cheese") self.assertIsInstance(r, str) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 2d853f73..8505e518 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -4,7 +4,7 @@ import sys import unittest -from datetime import date, time, datetime, timedelta +from datetime import date, time, datetime, timedelta, timezone from uuid import UUID as Uuid import pgdb @@ -17,11 +17,6 @@ from .config import dbname, dbhost, dbport, dbuser, dbpasswd -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - class PgBitString: """Test object with a PostgreSQL representation as Bit String.""" @@ -492,7 +487,7 @@ def test_fetch_2_rows(self): self.assertIsInstance(row0[1], bytes) self.assertIsInstance(row0[2], bool) self.assertIsInstance(row0[3], int) - self.assertIsInstance(row0[4], long) + self.assertIsInstance(row0[4], int) self.assertIsInstance(row0[5], float) self.assertIsInstance(row0[6], Decimal) self.assertIsInstance(row0[7], Decimal) @@ -600,8 +595,8 @@ def test_datetime(self): "tz timetz, tsz timestamptz)" % table) for n in range(3): values = [dt.date(), dt.time(), dt, dt.time(), dt] - values[3] = values[3].replace(tzinfo=pgdb.timezone.utc) - values[4] = values[4].replace(tzinfo=pgdb.timezone.utc) + values[3] = values[3].replace(tzinfo=timezone.utc) + values[4] = values[4].replace(tzinfo=timezone.utc) if n == 0: # input as objects params = values if n == 1: # input as text @@ -609,7 +604,7 @@ def test_datetime(self): elif n == 2: # input using type helpers d = (dt.year, dt.month, dt.day) t = (dt.hour, dt.minute, dt.second, dt.microsecond) - z = (pgdb.timezone.utc,) + z = (timezone.utc,) params = [pgdb.Date(*d), pgdb.Time(*t), pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), pgdb.Timestamp(*(d + t + z))] @@ -1000,8 +995,6 @@ def test_unicode_with_utf8(self): output4 = cur.fetchone()[0] finally: con.close() - if str is bytes: # Python < 3.0 - s = s.encode('utf8') self.assertIsInstance(output1, str) self.assertEqual(output1, s) self.assertIsInstance(output2, str) @@ -1033,8 +1026,6 @@ def test_unicode_with_latin1(self): output4 = cur.fetchone()[0] finally: con.close() - if str is bytes: # Python < 3.0 - s = s.encode('latin1') self.assertIsInstance(output1, str) self.assertEqual(output1, s) self.assertIsInstance(output2, str) @@ -1347,10 +1338,7 @@ def test_no_close(self): self.assertEqual(row, data) def test_set_row_factory_size(self): - try: - from functools import lru_cache - except ImportError: # Python < 3.2 - lru_cache = None + from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] con = self._connect() cur = con.cursor() @@ -1366,12 +1354,11 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - if lru_cache: - info = pgdb._row_factory.cache_info() - self.assertEqual(info.maxsize, maxsize) - self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + info = pgdb._row_factory.cache_info() + self.assertEqual(info.maxsize, maxsize) + self.assertEqual(info.hits + info.misses, 6) + self.assertEqual( + info.hits, 0 if maxsize is not None and maxsize < 2 else 4) def test_memory_leaks(self): ids = set() @@ -1384,9 +1371,6 @@ def test_memory_leaks(self): gc.collect() objs[:] = gc.get_objects() objs[:] = [obj for obj in objs if id(obj) not in ids] - if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)): - # workaround for Python issue 26811 - objs[:] = [obj for obj in objs if repr(obj) != '(,)'] self.assertEqual(len(objs), 0) def test_cve_2018_1058(self): diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index d6fd1cfc..d8661251 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -11,10 +11,7 @@ import unittest -try: # noinspection PyCompatibility - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable +from collections.abc import Iterable import pgdb # the module under test @@ -29,15 +26,13 @@ class InputStream: def __init__(self, data): - if isinstance(data, unicode): + if isinstance(data, str): data = data.encode('utf-8') self.data = data or b'' self.sizes = [] def __str__(self): - data = self.data - if str is unicode: # Python >= 3.0 - data = data.decode('utf-8') + data = self.data.decode('utf-8') return data def __len__(self): @@ -60,16 +55,14 @@ def __init__(self): self.sizes = [] def __str__(self): - data = self.data - if str is unicode: # Python >= 3.0 - data = data.decode('utf-8') + data = self.data.decode('utf-8') return data def __len__(self): return len(self.data) def write(self, data): - if isinstance(data, unicode): + if isinstance(data, str): data = data.encode('utf-8') self.data += data self.sizes.append(len(data)) From 1f66d19e49acc36e63a0170ff0161e3f7ab3b4ea Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 20:47:39 +0200 Subject: [PATCH 025/118] Mention new version in README file --- .bumpversion.cfg | 2 +- README.rst | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 3a654eda..89aec55e 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 5.2.5 +current_version = 6.0 commit = False tag = False diff --git a/README.rst b/README.rst index a6054363..150effb5 100644 --- a/README.rst +++ b/README.rst @@ -9,7 +9,13 @@ PyGreSQL should run on most platforms where PostgreSQL and Python is running. It is based on the PyGres95 code written by Pascal Andre. D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -Starting with version 5.0, PyGreSQL also supports Python 3. + +The following Python versions are supported: + +* PyGreSQL 4.x and earlier: Python 2 only +* PyGreSQL 5.x: Python 2 and Python 3 +* PyGreSQL 6.x and newer: Python 3 only + Installation ------------ From fe83a9eb73dfe6288963615fadc79d0ab6d5e497 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 21:02:30 +0200 Subject: [PATCH 026/118] Simplify one more test module --- tests/test_dbapi20_copy.py | 77 ++++++++++---------------------------- 1 file changed, 20 insertions(+), 57 deletions(-) diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index d8661251..540ccf1e 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -17,11 +17,6 @@ from .config import dbname, dbhost, dbport, dbuser, dbpasswd -try: # noinspection PyUnboundLocalVariable,PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - class InputStream: @@ -248,26 +243,12 @@ def test_input_string_multiple_rows(self): self.check_table() self.check_rowcount() - if str is unicode: # Python >= 3.0 - - def test_input_bytes(self): - self.copy_from(b'42\tHello, world!') - self.assertEqual(self.table_data, [(42, 'Hello, world!')]) - self.truncate_table() - self.copy_from(self.data_text.encode('utf-8')) - self.check_table() - - else: # Python < 3.0 - - def test_input_unicode(self): - if not self.can_encode: - self.skipTest('database does not support utf8') - self.copy_from(u'43\tWürstel, Käse!') - self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')]) - self.truncate_table() - # noinspection PyUnresolvedReferences - self.copy_from(self.data_text.decode('utf-8')) - self.check_table() + def test_input_bytes(self): + self.copy_from(b'42\tHello, world!') + self.assertEqual(self.table_data, [(42, 'Hello, world!')]) + self.truncate_table() + self.copy_from(self.data_text.encode('utf-8')) + self.check_table() def test_input_iterable(self): self.copy_from(self.data_text.splitlines()) @@ -281,12 +262,10 @@ def test_input_iterable_with_newlines(self): self.copy_from('%s\n' % row for row in self.data_text.splitlines()) self.check_table() - if str is unicode: # Python >= 3.0 - - def test_input_iterable_bytes(self): - self.copy_from(row.encode('utf-8') - for row in self.data_text.splitlines()) - self.check_table() + def test_input_iterable_bytes(self): + self.copy_from(row.encode('utf-8') + for row in self.data_text.splitlines()) + self.check_table() def test_sep(self): stream = ('%d-%s' % row for row in self.data) @@ -437,28 +416,14 @@ def test_generator_with_schema_name(self): ret = self.cursor.copy_to(None, 'public.copytest') self.assertEqual(''.join(ret), self.data_text) - if str is unicode: # Python >= 3.0 - - def test_generator_bytes(self): - ret = self.copy_to(decode=False) - self.assertIsInstance(ret, Iterable) - rows = list(ret) - self.assertEqual(len(rows), 3) - rows = b''.join(rows) - self.assertIsInstance(rows, bytes) - self.assertEqual(rows, self.data_text.encode('utf-8')) - - else: # Python < 3.0 - - def test_generator_unicode(self): - ret = self.copy_to(decode=True) - self.assertIsInstance(ret, Iterable) - rows = list(ret) - self.assertEqual(len(rows), 3) - rows = ''.join(rows) - self.assertIsInstance(rows, unicode) - # noinspection PyUnresolvedReferences - self.assertEqual(rows, self.data_text.decode('utf-8')) + def test_generator_bytes(self): + ret = self.copy_to(decode=False) + self.assertIsInstance(ret, Iterable) + rows = list(ret) + self.assertEqual(len(rows), 3) + rows = b''.join(rows) + self.assertIsInstance(rows, bytes) + self.assertEqual(rows, self.data_text.encode('utf-8')) def test_rowcount_increment(self): ret = self.copy_to() @@ -470,7 +435,7 @@ def test_decode(self): ret_raw = b''.join(self.copy_to(decode=False)) ret_decoded = ''.join(self.copy_to(decode=True)) self.assertIsInstance(ret_raw, bytes) - self.assertIsInstance(ret_decoded, unicode) + self.assertIsInstance(ret_decoded, str) self.assertEqual(ret_decoded, ret_raw.decode('utf-8')) self.check_rowcount() @@ -556,9 +521,7 @@ def test_file(self): ret = self.copy_to(stream) self.assertIs(ret, self.cursor) self.assertEqual(str(stream), self.data_text) - data = self.data_text - if str is unicode: # Python >= 3.0 - data = data.encode('utf-8') + data = self.data_text.encode('utf-8') sizes = [len(row) + 1 for row in data.splitlines()] self.assertEqual(stream.sizes, sizes) self.check_rowcount() From 07da5d802bff07601d2c62b537b2a6214484285e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 21:53:58 +0200 Subject: [PATCH 027/118] Do not use wildcard imports --- pg.py | 35 +++++++++++++++++++++++++++++--- pgdb.py | 18 ++++++++++++++-- tests/test_classic.py | 5 ++++- tests/test_classic_connection.py | 1 - tests/test_dbapi20.py | 2 -- tests/test_dbapi20_copy.py | 2 +- 6 files changed, 53 insertions(+), 10 deletions(-) diff --git a/pg.py b/pg.py index 2edeb6e2..99e4aa62 100644 --- a/pg.py +++ b/pg.py @@ -21,7 +21,7 @@ """ try: - from _pg import * + from _pg import version except ImportError as e: import os libpq = 'libpq.' @@ -35,10 +35,11 @@ for path in paths: with os.add_dll_directory(os.path.abspath(path)): try: - from _pg import * + from _pg import version except ImportError: pass else: + del version e = None break if paths: @@ -49,6 +50,34 @@ raise ImportError( "Cannot import shared library for PyGreSQL,\n" "probably because no %s is installed.\n%s" % (libpq, e)) from e +else: + del version + +# import objects from extension module +from _pg import ( + Error, Warning, + DataError, DatabaseError, + IntegrityError, InterfaceError, InternalError, + InvalidResultError, MultipleResultsError, + NoResultError, NotSupportedError, + OperationalError, ProgrammingError, + INV_READ, INV_WRITE, + POLLING_OK, POLLING_FAILED, POLLING_READING, POLLING_WRITING, + SEEK_CUR, SEEK_END, SEEK_SET, + TRANS_ACTIVE, TRANS_IDLE, TRANS_INERROR, + TRANS_INTRANS, TRANS_UNKNOWN, + cast_array, cast_hstore, cast_record, + connect, escape_bytea, escape_string, unescape_bytea, + get_array, get_bool, get_bytea_escaped, + get_datestyle, get_decimal, get_decimal_point, + get_defbase, get_defhost, get_defopt, get_defport, get_defuser, + get_jsondecode, get_pqlib_version, + set_array, set_bool, set_bytea_escaped, + set_datestyle, set_decimal, set_decimal_point, + set_defbase, set_defhost, set_defopt, + set_defpasswd, set_defport, set_defuser, + set_jsondecode, set_query_helpers, + version) __version__ = version @@ -72,7 +101,7 @@ 'get_array', 'get_bool', 'get_bytea_escaped', 'get_datestyle', 'get_decimal', 'get_decimal_point', 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', 'get_typecast', + 'get_jsondecode', 'get_pqlib_version', 'get_typecast', 'set_array', 'set_bool', 'set_bytea_escaped', 'set_datestyle', 'set_decimal', 'set_decimal_point', 'set_defbase', 'set_defhost', 'set_defopt', diff --git a/pgdb.py b/pgdb.py index 85767e3a..4de78e15 100644 --- a/pgdb.py +++ b/pgdb.py @@ -65,7 +65,7 @@ """ try: - from _pg import * + from _pg import version except ImportError as e: import os libpq = 'libpq.' @@ -79,10 +79,11 @@ for path in paths: with os.add_dll_directory(os.path.abspath(path)): try: - from _pg import * + from _pg import version except ImportError: pass else: + del version e = None break if paths: @@ -93,6 +94,19 @@ raise ImportError( "Cannot import shared library for PyGreSQL,\n" "probably because no %s is installed.\n%s" % (libpq, e)) from e +else: + del version + +# import objects from extension module +from _pg import ( + Error, Warning, + DataError, DatabaseError, + IntegrityError, InterfaceError, InternalError, + NotSupportedError, OperationalError, ProgrammingError, + cast_array, cast_hstore, cast_record, + RESULT_DQL, + connect, unescape_bytea, + version) __version__ = version diff --git a/tests/test_classic.py b/tests/test_classic.py index 375bad3f..799cb6c7 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -6,7 +6,10 @@ from time import sleep from threading import Thread -from pg import * +from pg import ( + DB, NotificationHandler, + Error, DatabaseError, IntegrityError, + NotSupportedError, ProgrammingError) from .config import dbname, dbhost, dbport, dbuser, dbpasswd diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 0dedc5c4..068dd792 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2666,7 +2666,6 @@ def testSetByteaEscaped(self): self.assertEqual(r, b'data') def testSetRowFactorySize(self): - from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 8505e518..01a89247 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1,7 +1,6 @@ #!/usr/bin/python import gc -import sys import unittest from datetime import date, time, datetime, timedelta, timezone @@ -1338,7 +1337,6 @@ def test_no_close(self): self.assertEqual(row, data) def test_set_row_factory_size(self): - from functools import lru_cache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] con = self._connect() cur = con.cursor() diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 540ccf1e..769065ab 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -264,7 +264,7 @@ def test_input_iterable_with_newlines(self): def test_input_iterable_bytes(self): self.copy_from(row.encode('utf-8') - for row in self.data_text.splitlines()) + for row in self.data_text.splitlines()) self.check_table() def test_sep(self): From 3123f331e42c6408f11c1cdb49936b168ebaf847 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 22:31:36 +0200 Subject: [PATCH 028/118] Use some default parameters for testing --- tests/config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/config.py b/tests/config.py index e6bf326c..6e2ebd3c 100644 --- a/tests/config.py +++ b/tests/config.py @@ -16,9 +16,9 @@ get = environ.get -dbname = get('PYGRESQL_DB', get('PGDATABASE')) -dbhost = get('PYGRESQL_HOST', get('PGHOST')) -dbport = get('PYGRESQL_PORT', get('PGPORT')) +dbname = get('PYGRESQL_DB', get('PGDATABASE', 'test')) +dbhost = get('PYGRESQL_HOST', get('PGHOST', 'localhost')) +dbport = get('PYGRESQL_PORT', get('PGPORT', 5432)) dbuser = get('PYGRESQL_USER', get('PGUSER')) dbpasswd = get('PYGRESQL_PASSWD', get('PGPASSWORD')) From da5bf3eafbfa72c70351b3b8319d35cf97dd05fc Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 23:31:02 +0200 Subject: [PATCH 029/118] Remove most compiler options These only complicated things and nobody used them anyway. Kept only the memory size option because it is not available in PostgreSQL 10 which is still supported. --- docs/contents/install.rst | 26 ++++++-------------- docs/contents/pg/connection.rst | 37 +++++++++++------------------ docs/contents/pg/module.rst | 28 ++++++++++------------ pgconn.c | 26 -------------------- pginternal.c | 4 ---- pgmodule.c | 27 --------------------- pgquery.c | 11 ++++----- setup.py | 42 +++------------------------------ tox.ini | 2 +- 9 files changed, 43 insertions(+), 160 deletions(-) diff --git a/docs/contents/install.rst b/docs/contents/install.rst index d1926881..fd4f99b5 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -91,24 +91,24 @@ Now you should be ready to use PyGreSQL. You can also run the build step separately if you want to create a distribution to be installed on a different system or explicitly enable or disable certain -features. For instance, in order to build PyGreSQL without support for the SSL -info functions, run:: +features. For instance, in order to build PyGreSQL without support for the +memory size functions, run:: - python setup.py build_ext --no-ssl-info + python setup.py build_ext --no-memory-size By default, PyGreSQL is compiled with support for all features available in the installed PostgreSQL version, and you will get warnings for the features that are not supported in this version. You can also explicitly require a feature in order to get an error if it is not available, for instance: - python setup.py build_ext --ssl-info + python setup.py build_ext --memory-size You can find out all possible build options with:: python setup.py build_ext --help Alternatively, you can also use the corresponding C preprocessor macros like -``SSL_INFO`` directly (see the next section). +``MEMORY_SIZE`` directly (see the next section). Note that if you build PyGreSQL with support for newer features that are not available in the PQLib installed on the runtime system, you may get an error @@ -154,13 +154,7 @@ Stand-Alone Some options may be added to this line:: - -DDEFAULT_VARS default variables support - -DDIRECT_ACCESS direct access methods - -DLARGE_OBJECTS large object support - -DESCAPING_FUNCS support for newer escaping functions - -DPQLIB_INFO support PQLib information - -DSSL_INFO support SSL information - -DMEMORY_SIZE support memory size function + -DMEMORY_SIZE = support memory size function (PostgreSQL 12 or newer) On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. @@ -202,13 +196,7 @@ Built-in to Python interpreter Some options may be added to this line:: - -DDEFAULT_VARS default variables support - -DDIRECT_ACCESS direct access methods - -DLARGE_OBJECTS large object support - -DESCAPING_FUNCS support for newer escaping functions - -DPQLIB_INFO support PQLib information - -DSSL_INFO support SSL information - -DMEMORY_SIZE support memory size function + -DMEMORY_SIZE = support memory size function (PostgreSQL 12 or newer) On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index d1c95213..7fd44cca 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -13,17 +13,8 @@ significant parameters in function calls. Some methods give direct access to the connection socket. *Do not use them unless you really know what you are doing.* - If you prefer disabling them, - do not set the ``direct_access`` option in the Python setup file. - These methods are specified by the tag [DA]. - -.. note:: - - Some other methods give access to large objects - (refer to PostgreSQL user manual for more information about these). - If you want to forbid access to these from the module, - set the ``large_objects`` option in the Python setup file. - These methods are specified by the tag [LO]. + Some other methods give access to large objects. + Refer to the PostgreSQL user manual for more information about these. query -- execute a SQL command string ------------------------------------- @@ -605,8 +596,8 @@ attributes: .. versionadded:: 4.1 -putline -- write a line to the server socket [DA] -------------------------------------------------- +putline -- write a line to the server socket +-------------------------------------------- .. method:: Connection.putline(line) @@ -618,8 +609,8 @@ putline -- write a line to the server socket [DA] This method allows to directly write a string to the server socket. -getline -- get a line from server socket [DA] ---------------------------------------------- +getline -- get a line from server socket +---------------------------------------- .. method:: Connection.getline() @@ -633,8 +624,8 @@ getline -- get a line from server socket [DA] This method allows to directly read a string from the server socket. -endcopy -- synchronize client and server [DA] ---------------------------------------------- +endcopy -- synchronize client and server +---------------------------------------- .. method:: Connection.endcopy() @@ -647,8 +638,8 @@ endcopy -- synchronize client and server [DA] The use of direct access methods may desynchronize client and server. This method ensure that client and server will be synchronized. -locreate -- create a large object in the database [LO] ------------------------------------------------------- +locreate -- create a large object in the database +------------------------------------------------- .. method:: Connection.locreate(mode) @@ -665,8 +656,8 @@ by OR-ing the constants defined in the :mod:`pg` module (:const:`INV_READ`, and :const:`INV_WRITE`). Please refer to PostgreSQL user manual for a description of the mode values. -getlo -- build a large object from given oid [LO] -------------------------------------------------- +getlo -- build a large object from given oid +-------------------------------------------- .. method:: Connection.getlo(oid) @@ -681,8 +672,8 @@ getlo -- build a large object from given oid [LO] This method allows reusing a previously created large object through the :class:`LargeObject` interface, provided the user has its OID. -loimport -- import a file to a large object [LO] ------------------------------------------------- +loimport -- import a file to a large object +------------------------------------------- .. method:: Connection.loimport(name) diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index b122808b..9faa3754 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -10,9 +10,7 @@ the environment variables used by PostgreSQL. These "default variables" were designed to allow you to handle general connection parameters without heavy code in your programs. You can prompt the user for a value, put it in the default variable, and forget it, without -having to modify your environment. The support for default variables can be -disabled by not setting the ``default_vars`` option in the Python setup file. -Methods relative to this are specified by the tag [DV]. +having to modify your environment. All variables are set to ``None`` at module initialization, specifying that standard environment variables should be used. @@ -87,8 +85,8 @@ For example, version 9.1.2 will be returned as 90102. .. versionadded:: 5.2 (needs PostgreSQL >= 9.1) -get/set_defhost -- default server host [DV] -------------------------------------------- +get/set_defhost -- default server host +-------------------------------------- .. function:: get_defhost(host) @@ -117,8 +115,8 @@ If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default host. -get/set_defport -- default server port [DV] -------------------------------------------- +get/set_defport -- default server port +-------------------------------------- .. function:: get_defport() @@ -145,8 +143,8 @@ This methods sets the default port value for new connections. If -1 is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default port. -get/set_defopt -- default connection options [DV] --------------------------------------------------- +get/set_defopt -- default connection options +--------------------------------------------- .. function:: get_defopt() @@ -174,8 +172,8 @@ This methods sets the default connection options value for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default options. -get/set_defbase -- default database name [DV] ---------------------------------------------- +get/set_defbase -- default database name +---------------------------------------- .. function:: get_defbase() @@ -203,8 +201,8 @@ This method sets the default database name value for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default host. -get/set_defuser -- default database user [DV] ---------------------------------------------- +get/set_defuser -- default database user +---------------------------------------- .. function:: get_defuser() @@ -232,8 +230,8 @@ This method sets the default database user name for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default host. -get/set_defpasswd -- default database password [DV] ---------------------------------------------------- +get/set_defpasswd -- default database password +---------------------------------------------- .. function:: get_defpasswd() diff --git a/pgconn.c b/pgconn.c index 910f2212..c67e74dc 100644 --- a/pgconn.c +++ b/pgconn.c @@ -94,27 +94,17 @@ conn_getattr(connObject *self, PyObject *nameobj) /* whether the connection uses SSL */ if (!strcmp(name, "ssl_in_use")) { -#ifdef SSL_INFO if (PQsslInUse(self->cnx)) { Py_INCREF(Py_True); return Py_True; } else { Py_INCREF(Py_False); return Py_False; } -#else - set_error_msg(NotSupportedError, "SSL info functions not supported"); - return NULL; -#endif } /* SSL attributes */ if (!strcmp(name, "ssl_attributes")) { -#ifdef SSL_INFO return get_ssl_attributes(self->cnx); -#else - set_error_msg(NotSupportedError, "SSL info functions not supported"); - return NULL; -#endif } return PyObject_GenericGetAttr((PyObject *) self, nameobj); @@ -540,7 +530,6 @@ conn_describe_prepared(connObject *self, PyObject *args) return NULL; /* error */ } -#ifdef DIRECT_ACCESS static char conn_putline__doc__[] = "putline(line) -- send a line directly to the backend"; @@ -697,7 +686,6 @@ conn_is_non_blocking(connObject *self, PyObject *noargs) return PyBool_FromLong((long)rc); } -#endif /* DIRECT_ACCESS */ /* Insert table */ @@ -1110,8 +1098,6 @@ conn_date_format(connObject *self, PyObject *noargs) return PyUnicode_FromString(fmt); } -#ifdef ESCAPING_FUNCS - /* Escape literal */ static char conn_escape_literal__doc__[] = "escape_literal(str) -- escape a literal constant for use within SQL"; @@ -1202,8 +1188,6 @@ conn_escape_identifier(connObject *self, PyObject *string) return to_obj; } -#endif /* ESCAPING_FUNCS */ - /* Escape string */ static char conn_escape_string__doc__[] = "escape_string(str) -- escape a string for use within SQL"; @@ -1299,8 +1283,6 @@ conn_escape_bytea(connObject *self, PyObject *data) return to_obj; } -#ifdef LARGE_OBJECTS - /* Constructor for large objects (internal use only) */ static largeObject * large_new(connObject *pgcnx, Oid oid) @@ -1415,8 +1397,6 @@ conn_loimport(connObject *self, PyObject *args) return (PyObject *) large_new(self, lo_oid); } -#endif /* LARGE_OBJECTS */ - /* Reset connection. */ static char conn_reset__doc__[] = "reset() -- reset connection with current parameters\n\n" @@ -1724,18 +1704,15 @@ static struct PyMethodDef conn_methods[] = { {"date_format", (PyCFunction) conn_date_format, METH_NOARGS, conn_date_format__doc__}, -#ifdef ESCAPING_FUNCS {"escape_literal", (PyCFunction) conn_escape_literal, METH_O, conn_escape_literal__doc__}, {"escape_identifier", (PyCFunction) conn_escape_identifier, METH_O, conn_escape_identifier__doc__}, -#endif /* ESCAPING_FUNCS */ {"escape_string", (PyCFunction) conn_escape_string, METH_O, conn_escape_string__doc__}, {"escape_bytea", (PyCFunction) conn_escape_bytea, METH_O, conn_escape_bytea__doc__}, -#ifdef DIRECT_ACCESS {"putline", (PyCFunction) conn_putline, METH_VARARGS, conn_putline__doc__}, {"getline", (PyCFunction) conn_getline, @@ -1746,16 +1723,13 @@ static struct PyMethodDef conn_methods[] = { METH_VARARGS, conn_set_non_blocking__doc__}, {"is_non_blocking", (PyCFunction) conn_is_non_blocking, METH_NOARGS, conn_is_non_blocking__doc__}, -#endif /* DIRECT_ACCESS */ -#ifdef LARGE_OBJECTS {"locreate", (PyCFunction) conn_locreate, METH_VARARGS, conn_locreate__doc__}, {"getlo", (PyCFunction) conn_getlo, METH_VARARGS, conn_getlo__doc__}, {"loimport", (PyCFunction) conn_loimport, METH_VARARGS, conn_loimport__doc__}, -#endif /* LARGE_OBJECTS */ {NULL, NULL} /* sentinel */ }; diff --git a/pginternal.c b/pginternal.c index 50181b0d..61446f41 100644 --- a/pginternal.c +++ b/pginternal.c @@ -1115,8 +1115,6 @@ set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) set_error_msg_and_state(type, msg, encoding, sqlstate); } -#ifdef SSL_INFO - /* Get SSL attributes and values as a dictionary. */ static PyObject * get_ssl_attributes(PGconn *cnx) { @@ -1144,8 +1142,6 @@ get_ssl_attributes(PGconn *cnx) { return attr_dict; } -#endif /* SSL_INFO */ - /* Format result (mostly useful for debugging). Note: This is similar to the Postgres function PQprint(). PQprint() is not used because handing over a stream from Python to diff --git a/pgmodule.c b/pgmodule.c index 6adc79c0..f1335263 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -59,14 +59,12 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); /* MODULE GLOBAL VARIABLES */ -#ifdef DEFAULT_VARS static PyObject *pg_default_host; /* default database host */ static PyObject *pg_default_base; /* default database name */ static PyObject *pg_default_opt; /* default connection options */ static PyObject *pg_default_port; /* default connection port */ static PyObject *pg_default_user; /* default username */ static PyObject *pg_default_passwd; /* default password */ -#endif /* DEFAULT_VARS */ static PyObject *decimal = NULL, /* decimal type */ *dictiter = NULL, /* function for getting dict results */ @@ -160,7 +158,6 @@ typedef struct } queryObject; #define is_queryObject(v) (PyType(v) == &queryType) -#ifdef LARGE_OBJECTS typedef struct { PyObject_HEAD @@ -169,7 +166,6 @@ typedef struct int lo_fd; /* large object fd */ } largeObject; #define is_largeObject(v) (PyType(v) == &largeType) -#endif /* LARGE_OBJECTS */ /* Internal functions */ #include "pginternal.c" @@ -187,9 +183,7 @@ typedef struct #include "pgnotice.c" /* Large objects */ -#ifdef LARGE_OBJECTS #include "pglarge.c" -#endif /* MODULE FUNCTIONS */ @@ -228,7 +222,6 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return NULL; } -#ifdef DEFAULT_VARS /* handles defaults variables (for uninitialised vars) */ if ((!pghost) && (pg_default_host != Py_None)) pghost = PyBytes_AsString(pg_default_host); @@ -247,7 +240,6 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) if ((!pgpasswd) && (pg_default_passwd != Py_None)) pgpasswd = PyBytes_AsString(pg_default_passwd); -#endif /* DEFAULT_VARS */ if (!(conn_obj = PyObject_New(connObject, &connType))) { set_error_msg(InternalError, "Can't create new connection object"); @@ -309,8 +301,6 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return (PyObject *) conn_obj; } -#ifdef PQLIB_INFO - /* Get version of libpq that is being used */ static char pg_get_pqlib_version__doc__[] = "get_pqlib_version() -- get the version of libpq that is being used"; @@ -320,8 +310,6 @@ pg_get_pqlib_version(PyObject *self, PyObject *noargs) { return PyLong_FromLong(PQlibVersion()); } -#endif /* PQLIB_INFO */ - /* Escape string */ static char pg_escape_string__doc__[] = "escape_string(string) -- escape a string for use within SQL"; @@ -766,8 +754,6 @@ pg_set_jsondecode(PyObject *self, PyObject *func) return ret; } -#ifdef DEFAULT_VARS - /* Get default host. */ static char pg_get_defhost__doc__[] = "get_defhost() -- return default database host"; @@ -1012,7 +998,6 @@ pg_set_defport(PyObject *self, PyObject *args) return old; } -#endif /* DEFAULT_VARS */ /* Cast a string with a text representation of an array to a list. */ static char pg_cast_array__doc__[] = @@ -1216,7 +1201,6 @@ static struct PyMethodDef pg_methods[] = { METH_VARARGS|METH_KEYWORDS, pg_cast_record__doc__}, {"cast_hstore", (PyCFunction) pg_cast_hstore, METH_O, pg_cast_hstore__doc__}, -#ifdef DEFAULT_VARS {"get_defhost", pg_get_defhost, METH_NOARGS, pg_get_defhost__doc__}, {"set_defhost", pg_set_defhost, METH_VARARGS, pg_set_defhost__doc__}, {"get_defbase", pg_get_defbase, METH_NOARGS, pg_get_defbase__doc__}, @@ -1228,11 +1212,8 @@ static struct PyMethodDef pg_methods[] = { {"get_defuser", pg_get_defuser, METH_NOARGS, pg_get_defuser__doc__}, {"set_defuser", pg_set_defuser, METH_VARARGS, pg_set_defuser__doc__}, {"set_defpasswd", pg_set_defpasswd, METH_VARARGS, pg_set_defpasswd__doc__}, -#endif /* DEFAULT_VARS */ -#ifdef PQLIB_INFO {"get_pqlib_version", (PyCFunction) pg_get_pqlib_version, METH_NOARGS, pg_get_pqlib_version__doc__}, -#endif /* PQLIB_INFO */ {NULL, NULL} /* sentinel */ }; @@ -1260,17 +1241,13 @@ PyMODINIT_FUNC PyInit__pg(void) /* Initialize here because some Windows platforms get confused otherwise */ connType.tp_base = noticeType.tp_base = queryType.tp_base = sourceType.tp_base = &PyBaseObject_Type; -#ifdef LARGE_OBJECTS largeType.tp_base = &PyBaseObject_Type; -#endif if (PyType_Ready(&connType) || PyType_Ready(¬iceType) || PyType_Ready(&queryType) || PyType_Ready(&sourceType) -#ifdef LARGE_OBJECTS || PyType_Ready(&largeType) -#endif ) { return NULL; @@ -1354,7 +1331,6 @@ PyMODINIT_FUNC PyInit__pg(void) PyDict_SetItemString(dict, "POLLING_READING", PyLong_FromLong(PGRES_POLLING_READING)); PyDict_SetItemString(dict, "POLLING_WRITING", PyLong_FromLong(PGRES_POLLING_WRITING)); -#ifdef LARGE_OBJECTS /* Create mode for large objects */ PyDict_SetItemString(dict, "INV_READ", PyLong_FromLong(INV_READ)); PyDict_SetItemString(dict, "INV_WRITE", PyLong_FromLong(INV_WRITE)); @@ -1363,9 +1339,7 @@ PyMODINIT_FUNC PyInit__pg(void) PyDict_SetItemString(dict, "SEEK_SET", PyLong_FromLong(SEEK_SET)); PyDict_SetItemString(dict, "SEEK_CUR", PyLong_FromLong(SEEK_CUR)); PyDict_SetItemString(dict, "SEEK_END", PyLong_FromLong(SEEK_END)); -#endif /* LARGE_OBJECTS */ -#ifdef DEFAULT_VARS /* Prepare default values */ Py_INCREF(Py_None); pg_default_host = Py_None; @@ -1379,7 +1353,6 @@ PyMODINIT_FUNC PyInit__pg(void) pg_default_user = Py_None; Py_INCREF(Py_None); pg_default_passwd = Py_None; -#endif /* DEFAULT_VARS */ /* Store common pg encoding ids */ diff --git a/pgquery.c b/pgquery.c index 0923eb66..1196889a 100644 --- a/pgquery.c +++ b/pgquery.c @@ -246,18 +246,19 @@ query_next(queryObject *self, PyObject *noargs) return row_tuple; } -#ifdef MEMORY_SIZE - /* Get number of bytes allocated for PGresult object */ static char query_memsize__doc__[] = "memsize() -- return number of bytes allocated by query result"; static PyObject * query_memsize(queryObject *self, PyObject *noargs) { +#ifdef MEMORY_SIZE return PyLong_FromSize_t(PQresultMemorySize(self->result)); -} - +#else + set_error_msg(NotSupportedError, "Memory size functions not supported"); + return NULL; #endif /* MEMORY_SIZE */ +} /* Get number of rows. */ static char query_ntuples__doc__[] = @@ -949,10 +950,8 @@ static struct PyMethodDef query_methods[] = { METH_VARARGS, query_fieldinfo__doc__}, {"ntuples", (PyCFunction) query_ntuples, METH_NOARGS, query_ntuples__doc__}, -#ifdef MEMORY_SIZE {"memsize", (PyCFunction) query_memsize, METH_NOARGS, query_memsize__doc__}, -#endif /* MEMORY_SIZE */ {NULL, NULL} }; diff --git a/setup.py b/setup.py index fb5330e8..456e3b5e 100755 --- a/setup.py +++ b/setup.py @@ -104,31 +104,13 @@ class build_pg_ext(build_ext): user_options = build_ext.user_options + [ ('strict', None, "count all compiler warnings as errors"), - ('direct-access', None, "enable direct access functions"), - ('no-direct-access', None, "disable direct access functions"), - ('direct-access', None, "enable direct access functions"), - ('no-direct-access', None, "disable direct access functions"), - ('large-objects', None, "enable large object support"), - ('no-large-objects', None, "disable large object support"), - ('default-vars', None, "enable default variables use"), - ('no-default-vars', None, "disable default variables use"), - ('escaping-funcs', None, "enable string escaping functions"), - ('no-escaping-funcs', None, "disable string escaping functions"), - ('ssl-info', None, "use new ssl info functions"), - ('no-ssl-info', None, "do not use new ssl info functions"), - ('memory-size', None, "enable new memory size function"), - ('no-memory-size', None, "disable new memory size function")] + ('memory-size', None, "enable memory size function"), + ('no-memory-size', None, "disable memory size function")] boolean_options = build_ext.boolean_options + [ - 'strict', 'direct-access', 'large-objects', 'default-vars', - 'escaping-funcs', 'ssl-info', 'memory-size'] + 'strict', 'memory-size'] negative_opt = { - 'no-direct-access': 'direct-access', - 'no-large-objects': 'large-objects', - 'no-default-vars': 'default-vars', - 'no-escaping-funcs': 'escaping-funcs', - 'no-ssl-info': 'ssl-info', 'no-memory-size': 'memory-size'} def get_compiler(self): @@ -138,12 +120,6 @@ def get_compiler(self): def initialize_options(self): build_ext.initialize_options(self) self.strict = False - self.direct_access = None - self.large_objects = None - self.default_vars = None - self.escaping_funcs = None - self.pqlib_info = None - self.ssl_info = None self.memory_size = None supported = pg_version >= (10, 0) if not supported: @@ -155,18 +131,6 @@ def finalize_options(self): build_ext.finalize_options(self) if self.strict: extra_compile_args.append('-Werror') - if self.direct_access is None or self.direct_access: - define_macros.append(('DIRECT_ACCESS', None)) - if self.large_objects is None or self.large_objects: - define_macros.append(('LARGE_OBJECTS', None)) - if self.default_vars is None or self.default_vars: - define_macros.append(('DEFAULT_VARS', None)) - if self.escaping_funcs is None or self.escaping_funcs: - define_macros.append(('ESCAPING_FUNCS', None)) - if self.pqlib_info is None or self.pqlib_info: - define_macros.append(('PQLIB_INFO', None)) - if self.ssl_info is None or self.ssl_info: - define_macros.append(('SSL_INFO', None)) wanted = self.memory_size supported = pg_version >= (12, 0) if (wanted is None and supported) or wanted: diff --git a/tox.ini b/tox.ini index d48b44c7..23fb9379 100644 --- a/tox.ini +++ b/tox.ini @@ -21,5 +21,5 @@ passenv = PG* PYGRESQL_* commands = - python setup.py clean --all build_ext --force --inplace --strict --ssl-info --memory-size + python setup.py clean --all build_ext --force --inplace --strict --memory-size python -m unittest {posargs:discover} From 99dd2e07af24a299e39dbea25eee2ee3a09b3342 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 29 Aug 2023 23:52:50 +0200 Subject: [PATCH 030/118] Remove support for very old PostgreSQL versions --- docs/contents/pg/connection.rst | 4 +- docs/contents/pg/module.rst | 4 +- pg.py | 80 +++++++++------------------ pgdb.py | 17 ++---- tests/test_classic_connection.py | 8 +-- tests/test_classic_dbwrapper.py | 93 +++++--------------------------- 6 files changed, 50 insertions(+), 156 deletions(-) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 7fd44cca..237e25a8 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -749,10 +749,10 @@ the connection and its status. These attributes are: this is True if the connection uses SSL, False if not -.. versionadded:: 5.1 (needs PostgreSQL >= 9.5) +.. versionadded:: 5.1 .. attribute:: Connection.ssl_attributes SSL-related information about the connection (dict) -.. versionadded:: 5.1 (needs PostgreSQL >= 9.5) +.. versionadded:: 5.1 diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 9faa3754..203ada03 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -81,9 +81,9 @@ get_pqlib_version -- get the version of libpq The number is formed by converting the major, minor, and revision numbers of the libpq version into two-decimal-digit numbers and appending them together. -For example, version 9.1.2 will be returned as 90102. +For example, version 15.4 will be returned as 150400. -.. versionadded:: 5.2 (needs PostgreSQL >= 9.1) +.. versionadded:: 5.2 get/set_defhost -- default server host -------------------------------------- diff --git a/pg.py b/pg.py index 99e4aa62..896af911 100644 --- a/pg.py +++ b/pg.py @@ -1126,19 +1126,11 @@ def __init__(self, db): self._typecasts = Typecasts() self._typecasts.get_attnames = self.get_attnames self._typecasts.connection = self._db - if db.server_version < 80400: - # very old remote databases (not officially supported) - self._query_pg_type = ( - "SELECT oid, typname, oid::pg_catalog.regtype," - " typlen, typtype, null as typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::pg_catalog.regtype") - else: - self._query_pg_type = ( - "SELECT oid, typname, oid::pg_catalog.regtype," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::pg_catalog.regtype") + self._query_pg_type = ( + "SELECT oid, typname, oid::pg_catalog.regtype," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type" + " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") def add(self, oid, pgtype, regtype, typlen, typtype, category, delim, relid): @@ -1162,7 +1154,7 @@ def add(self, oid, pgtype, regtype, def __missing__(self, key): """Get the type info from the database if it is not cached.""" try: - q = self._query_pg_type % (_quote_if_unqualified('$1', key),) + q = self._query_pg_type.format(_quote_if_unqualified('$1', key)) res = self._db.query(q, (key,)).getresult() except ProgrammingError: res = None @@ -1493,33 +1485,17 @@ def __init__(self, *args, **kw): self._privileges = {} self.adapter = Adapter(self) self.dbtypes = DbTypes(self) - if db.server_version < 80400: - # very old remote databases (not officially supported) - self._query_attnames = ( - "SELECT a.attname," - " t.oid, t.typname, t.oid::pg_catalog.regtype," - " t.typlen, t.typtype, null as typcategory," - " t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=)" - " %s::pg_catalog.regclass" - " AND %s AND NOT a.attisdropped ORDER BY a.attnum") - else: - self._query_attnames = ( - "SELECT a.attname," - " t.oid, t.typname, t.oid::pg_catalog.regtype," - " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=)" - " %s::pg_catalog.regclass" - " AND %s AND NOT a.attisdropped ORDER BY a.attnum") - if db.server_version < 100000: - self._query_generated = None - elif db.server_version < 120000: + self._query_attnames = ( + "SELECT a.attname," + " t.oid, t.typname, t.oid::pg_catalog.regtype," + " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" + " FROM pg_catalog.pg_attribute a" + " JOIN pg_catalog.pg_type t" + " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" + " WHERE a.attrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND {} AND NOT a.attisdropped ORDER BY a.attnum") + if db.server_version < 120000: self._query_generated = ( "a.attidentity OPERATOR(pg_catalog.=) 'a'" ) @@ -2052,8 +2028,9 @@ def get_attnames(self, table, with_oid=True, flush=False): except KeyError: # cache miss, check the database q = "a.attnum OPERATOR(pg_catalog.>) 0" if with_oid: - q = "(%s OR a.attname OPERATOR(pg_catalog.=) 'oid')" % q - q = self._query_attnames % (_quote_if_unqualified('$1', table), q) + q = f"({q} OR a.attname OPERATOR(pg_catalog.=) 'oid')" + q = self._query_attnames.format( + _quote_if_unqualified('$1', table), q) names = self.db.query(q, (table,)).getresult() types = self.dbtypes names = ((name[0], types.add(*name[1:])) for name in names) @@ -2070,9 +2047,6 @@ def get_generated(self, table, flush=False): be flushed. This may be necessary after the database schema or the search path has been changed. """ - query_generated = self._query_generated - if not query_generated: - return frozenset() generated = self._generated if flush: generated.clear() @@ -2080,8 +2054,10 @@ def get_generated(self, table, flush=False): try: # cache lookup names = generated[table] except KeyError: # cache miss, check the database - q = "a.attnum OPERATOR(pg_catalog.>) 0 AND " + query_generated - q = self._query_attnames % (_quote_if_unqualified('$1', table), q) + q = "a.attnum OPERATOR(pg_catalog.>) 0" + q = f"{q} AND {self._query_generated}" + q = self._query_attnames.format( + _quote_if_unqualified('$1', table), q) names = self.db.query(q, (table,)).getresult() names = frozenset(name[0] for name in names) generated[table] = names # cache it @@ -2394,13 +2370,7 @@ def upsert(self, table, row=None, **kw): ' ON CONFLICT (%s) DO %s RETURNING %s') % ( self._escape_qualified_name(table), names, values, target, do, ret) self._do_debug(q, params) - try: - q = self.db.query(q, params) - except ProgrammingError: - if self.server_version < 90500: - raise _prg_error( - 'Upsert operation is not supported by PostgreSQL version') - raise # re-raise original error + q = self.db.query(q, params) res = q.dictresult() if res: # may be empty with "do nothing" for n, value in res[0].items(): diff --git a/pgdb.py b/pgdb.py index 4de78e15..f986242f 100644 --- a/pgdb.py +++ b/pgdb.py @@ -629,17 +629,10 @@ def __init__(self, cnx): self._typecasts = LocalTypecasts() self._typecasts.get_fields = self.get_fields self._typecasts.connection = cnx - if cnx.server_version < 80400: - # older remote databases (not officially supported) - self._query_pg_type = ( - "SELECT oid, typname," - " typlen, typtype, null as typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") - else: - self._query_pg_type = ( - "SELECT oid, typname," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") + self._query_pg_type = ( + "SELECT oid, typname," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") def __missing__(self, key): """Get the type info from the database if it is not cached.""" @@ -650,7 +643,7 @@ def __missing__(self, key): key = '"%s"' % (key,) oid = "'%s'::pg_catalog.regtype" % (self._escape_string(key),) try: - self._src.execute(self._query_pg_type % (oid,)) + self._src.execute(self._query_pg_type.format(oid)) except ProgrammingError: res = None else: diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 068dd792..c456b4ec 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -176,7 +176,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 160000) + self.assertTrue(100000 <= server_version < 160000) def testAttributeSocket(self): socket = self.connection.socket @@ -2704,11 +2704,7 @@ def setUpClass(cls): query = db.query query('set client_encoding=sql_ascii') query('set standard_conforming_strings=off') - try: - query('set bytea_output=escape') - except pg.ProgrammingError: - if db.server_version >= 90000: - raise # ignore for older server versions + query('set bytea_output=escape') db.close() cls.cls_set_up = True diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 25c3c11d..3d372ad3 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -254,7 +254,7 @@ def testAttributeProtocolVersion(self): def testAttributeServerVersion(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 160000) + self.assertTrue(100000 <= server_version < 160000) self.assertEqual(server_version, self.db.db.server_version) def testAttributeSocket(self): @@ -456,11 +456,7 @@ def setUp(self): query("set lc_monetary='C'") query("set datestyle='ISO,YMD'") query('set standard_conforming_strings=on') - try: - query('set bytea_output=hex') - except pg.ProgrammingError: - if self.db.server_version >= 90000: - raise # ignore for older server versions + query('set bytea_output=hex') def tearDown(self): self.doCleanups() @@ -1951,13 +1947,7 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd')]) r = dict(i4=5678, v4='efgh') - try: - insert('test_view', r) - except (pg.OperationalError, pg.NotSupportedError) as error: - if self.db.server_version < 90300: - # must setup rules in older PostgreSQL versions - self.skipTest('database cannot insert into view') - self.fail(str(error)) + insert('test_view', r) self.assertNotIn('i2', r) self.assertEqual(r['i4'], 5678) self.assertNotIn('i8', r) @@ -2203,12 +2193,7 @@ def testUpsert(self): table = 'upsert_test_table' self.createTable(table, 'n integer primary key, t text') s = dict(n=1, t='x') - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['t'], 'x') @@ -2296,12 +2281,7 @@ def testUpsertWithOids(self): self.assertIn('m', self.db.get_attnames('test_table', flush=True)) self.assertEqual('n', self.db.pkey('test_table', flush=True)) s = dict(n=2) - try: - r = upsert('test_table', s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert('test_table', s) self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertIsNone(r['m']) @@ -2366,12 +2346,7 @@ def testUpsertWithCompositeKey(self): self.createTable( table, 'n integer, m integer, t text, primary key (n, m)') s = dict(n=1, m=2, t='x') - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 2) @@ -2433,12 +2408,7 @@ def testUpsertWithQuotedNames(self): self.createTable(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['Prime!'], 31) self.assertEqual(r['much space'], 9009) @@ -2456,8 +2426,6 @@ def testUpsertWithQuotedNames(self): self.assertEqual(r, [(31, 9009, 'No.')]) def testUpsertWithGeneratedColumns(self): - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') upsert = self.db.upsert get = self.db.get server_version = self.db.server_version @@ -3378,12 +3346,7 @@ def testUpsertBytea(self): self.createTable('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" r = dict(n=7, data=s) - try: - r = self.db.upsert('bytea_test', r) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = self.db.upsert('bytea_test', r) self.assertIsInstance(r, dict) self.assertIn('n', r) self.assertEqual(r['n'], 7) @@ -3402,12 +3365,7 @@ def testUpsertBytea(self): self.assertIsNone(r['data']) def testInsertGetJson(self): - try: - self.createTable('json_test', 'n smallint primary key, data json') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + self.createTable('json_test', 'n smallint primary key, data json') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('json_test', n=0, data=None) @@ -3471,13 +3429,8 @@ def testInsertGetJson(self): self.assertEqual(r[0][0], r[1][0]) def testInsertGetJsonb(self): - try: - self.createTable('jsonb_test', - 'n smallint primary key, data jsonb') - except pg.ProgrammingError as error: - if self.db.server_version < 90400: - self.skipTest('database does not support jsonb') - self.fail(str(error)) + self.createTable('jsonb_test', + 'n smallint primary key, data jsonb') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('jsonb_test', n=0, data=None) @@ -3703,13 +3656,7 @@ def testArrayOfBytea(self): self.assertNotEqual(r['data'], data) def testArrayOfJson(self): - try: - self.createTable( - 'arraytest', 'id serial primary key, data json[]') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + self.createTable('arraytest', 'id serial primary key, data json[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3751,13 +3698,7 @@ def testArrayOfJson(self): self.assertEqual(r, '{NULL,NULL}') def testArrayOfJsonb(self): - try: - self.createTable( - 'arraytest', 'id serial primary key, data jsonb[]') - except pg.ProgrammingError as error: - if self.db.server_version < 90400: - self.skipTest('database does not support jsonb') - self.fail(str(error)) + self.createTable('arraytest', 'id serial primary key, data jsonb[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3941,13 +3882,7 @@ def testRecordInsertBytea(self): def testRecordInsertJson(self): query = self.db.query - try: - query('create type test_person_type as' - ' (name text, data json)') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + query('create type test_person_type as (name text, data json)') self.addCleanup(query, 'drop type test_person_type') self.createTable('test_person', 'person test_person_type', temporary=False) From 7e53673a2cfe541983740b01ba99bec210830acc Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 30 Aug 2023 01:07:00 +0200 Subject: [PATCH 031/118] Modernize string formatting in pg and pgdb --- pg.py | 192 +++++++++++++++++++++++++++++--------------------------- pgdb.py | 95 ++++++++++++++-------------- 2 files changed, 150 insertions(+), 137 deletions(-) diff --git a/pg.py b/pg.py index 896af911..50f22425 100644 --- a/pg.py +++ b/pg.py @@ -49,7 +49,7 @@ if e: raise ImportError( "Cannot import shared library for PyGreSQL,\n" - "probably because no %s is installed.\n%s" % (libpq, e)) from e + f"probably because no {libpq} is installed.\n{e}") from e else: del version @@ -148,7 +148,7 @@ def _timezone_as_offset(tz): def _oid_key(table): """Build oid key from a table name.""" - return 'oid(%s)' % table + return f'oid({table})' class Bytea(bytes): @@ -170,12 +170,12 @@ def _quote(cls, s): return '""' s = s.replace('"', '\\"') if cls._re_quote.search(s): - s = '"%s"' % s + s = f'"{s}"' return s def __str__(self): q = self._quote - return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items()) + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) class Json: @@ -220,9 +220,9 @@ def __init__(self): for key in keys: self[key] = typ if isinstance(key, str): - self['_%s' % key] = '%s[]' % typ + self[f'_{key}'] = f'{typ}[]' elif not isinstance(key, tuple): - self[List[key]] = '%s[]' % typ + self[List[key]] = f'{typ}[]' @staticmethod def __missing__(key): @@ -248,7 +248,7 @@ def _quote_if_unqualified(param, name): and must be quoted manually by the caller. """ if isinstance(name, str) and '.' not in name: - return 'quote_ident(%s)' % (param,) + return f'quote_ident({param})' return param @@ -266,7 +266,7 @@ def add(self, value, typ=None): if isinstance(value, Literal): return value self.append(value) - return '$%d' % len(self) + return f'${len(self)}' class Literal(str): @@ -366,7 +366,7 @@ def _adapt_hstore(self, v): return str(v) if isinstance(v, dict): return str(Hstore(v)) - raise TypeError('Hstore parameter %s has wrong type' % v) + raise TypeError(f'Hstore parameter {v} has wrong type') def _adapt_uuid(self, v): """Adapt a UUID parameter.""" @@ -381,14 +381,15 @@ def _adapt_text_array(cls, v): """Adapt a text type array parameter.""" if isinstance(v, list): adapt = cls._adapt_text_array - return '{%s}' % ','.join(adapt(v) for v in v) + return '{' + ','.join(adapt(v) for v in v) + '}' if v is None: return 'null' if not v: return '""' v = str(v) if cls._re_array_quote.search(v): - v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v) + v = cls._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' return v _adapt_date_array = _adapt_text_array @@ -398,7 +399,7 @@ def _adapt_bool_array(cls, v): """Adapt a boolean array parameter.""" if isinstance(v, list): adapt = cls._adapt_bool_array - return '{%s}' % ','.join(adapt(v) for v in v) + return '{' + ','.join(adapt(v) for v in v) + '}' if v is None: return 'null' if isinstance(v, str): @@ -412,7 +413,7 @@ def _adapt_num_array(cls, v): """Adapt a numeric array parameter.""" if isinstance(v, list): adapt = cls._adapt_num_array - return '{%s}' % ','.join(adapt(v) for v in v) + v = '{' + ','.join(adapt(v) for v in v) + '}' if not v and v != 0: return 'null' return str(v) @@ -433,20 +434,21 @@ def _adapt_json_array(self, v): """Adapt a json array parameter.""" if isinstance(v, list): adapt = self._adapt_json_array - return '{%s}' % ','.join(adapt(v) for v in v) + return '{' + ','.join(adapt(v) for v in v) + '}' if not v: return 'null' if not isinstance(v, str): v = self.db.encode_json(v) if self._re_array_quote.search(v): - v = '"%s"' % self._re_array_escape.sub(r'\\\1', v) + v = self._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' return v def _adapt_record(self, v, typ): """Adapt a record parameter with given type.""" typ = self.get_attnames(typ).values() if len(typ) != len(v): - raise TypeError('Record parameter %s has wrong size' % v) + raise TypeError(f'Record parameter {v} has wrong size') adapt = self.adapt value = [] for v, t in zip(v, typ): @@ -462,9 +464,11 @@ def _adapt_record(self, v, typ): else: v = str(v) if self._re_record_quote.search(v): - v = '"%s"' % self._re_record_escape.sub(r'\\\1', v) + v = self._re_record_escape.sub(r'\\\1', v) + v = f'"{v}"' value.append(v) - return '(%s)' % ','.join(value) + v = ','.join(value) + return f'({v})' def adapt(self, value, typ=None): """Adapt a value with known database type.""" @@ -483,10 +487,10 @@ def adapt(self, value, typ=None): value = self._adapt_record(value, typ) elif simple.endswith('[]'): if isinstance(value, list): - adapt = getattr(self, '_adapt_%s_array' % simple[:-2]) + adapt = getattr(self, f'_adapt_{simple[:-2]}_array') value = adapt(value) else: - adapt = getattr(self, '_adapt_%s' % simple) + adapt = getattr(self, f'_adapt_{simple}') value = adapt(value) return value @@ -541,7 +545,7 @@ def guess_simple_type(cls, value): if isinstance(value, UUID): return 'uuid' if isinstance(value, list): - return '%s[]' % (cls.guess_simple_base_type(value) or 'text',) + return (cls.guess_simple_base_type(value) or 'text') + '[]' if isinstance(value, tuple): simple_type = cls.simple_type guess = cls.guess_simple_type @@ -578,7 +582,7 @@ def adapt_inline(self, value, nested=False): value = str(value) if isinstance(value, (bytes, str)): value = self.db.escape_string(value) - return "'%s'" % value + return f"'{value}'" if isinstance(value, bool): return 'true' if value else 'false' if isinstance(value, float): @@ -591,21 +595,21 @@ def adapt_inline(self, value, nested=False): return value if isinstance(value, list): q = self.adapt_inline - s = '[%s]' if nested else 'ARRAY[%s]' - return s % ','.join(str(q(v, nested=True)) for v in value) + s = '[{}]' if nested else 'ARRAY[{}]' + return s.format(','.join(str(q(v, nested=True)) for v in value)) if isinstance(value, tuple): q = self.adapt_inline - return '(%s)' % ','.join(str(q(v)) for v in value) + return '({})'.format(','.join(str(q(v)) for v in value)) if isinstance(value, Json): value = self.db.escape_string(str(value)) - return "'%s'::json" % value + return f"'{value}'::json" if isinstance(value, Hstore): value = self.db.escape_string(str(value)) - return "'%s'::hstore" % value + return f"'{value}'::hstore" pg_repr = getattr(value, '__pg_repr__', None) if not pg_repr: raise InterfaceError( - 'Do not know how to adapt type %s' % type(value)) + f'Do not know how to adapt type {type(value)}') value = pg_repr() if isinstance(value, (tuple, list)): value = self.adapt_inline(value) @@ -903,7 +907,7 @@ def cast_interval(value): secs = -secs usecs = -usecs else: - raise ValueError('Cannot parse interval: %s' % value) + raise ValueError(f'Cannot parse interval: {value}') days += 365 * years + 30 * mons return timedelta(days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) @@ -946,7 +950,7 @@ def __missing__(self, typ): but returns None when no special cast function exists. """ if not isinstance(typ, str): - raise TypeError('Invalid type: %s' % typ) + raise TypeError('Invalid type: {typ}') cast = self.defaults.get(typ) if cast: # store default for faster access @@ -992,13 +996,13 @@ def set(self, typ, cast): if cast is None: for t in typ: self.pop(t, None) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) else: if not callable(cast): raise TypeError("Cast parameter must be callable") for t in typ: self[t] = self._add_connection(cast) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) def reset(self, typ=None): """Reset the typecasts for the specified type(s) to their defaults. @@ -1027,13 +1031,13 @@ def set_default(cls, typ, cast): if cast is None: for t in typ: defaults.pop(t, None) - defaults.pop('_%s' % t, None) + defaults.pop(f'_{t}', None) else: if not callable(cast): raise TypeError("Cast parameter must be callable") for t in typ: defaults[t] = cast - defaults.pop('_%s' % t, None) + defaults.pop(f'_{t}', None) # noinspection PyMethodMayBeStatic,PyUnusedLocal def get_attnames(self, typ): @@ -1159,7 +1163,7 @@ def __missing__(self, key): except ProgrammingError: res = None if not res: - raise KeyError('Type %s could not be found' % (key,)) + raise KeyError(f'Type {key} could not be found') res = res[0] typ = self.add(*res) self[typ.oid] = self[typ.pgtype] = typ @@ -1224,7 +1228,7 @@ def _row_factory(names): try: return namedtuple('Row', names, rename=True)._make except ValueError: # there is still a problem with the field names - names = ['column_%d' % (n,) for n in range(len(names))] + names = [f'column_{n}' for n in range(len(names))] return namedtuple('Row', names)._make @@ -1335,7 +1339,7 @@ def __init__(self, db, event, callback=None, """ self.db = db self.event = event - self.stop_event = stop_event or 'stop_%s' % event + self.stop_event = stop_event or f'stop_{event}' self.listening = False self.callback = callback if arg_dict is None: @@ -1356,15 +1360,15 @@ def close(self): def listen(self): """Start listening for the event and the stop event.""" if not self.listening: - self.db.query('listen "%s"' % self.event) - self.db.query('listen "%s"' % self.stop_event) + self.db.query(f'listen "{self.event}"') + self.db.query(f'listen "{self.stop_event}"') self.listening = True def unlisten(self): """Stop listening for the event and the stop event.""" if self.listening: - self.db.query('unlisten "%s"' % self.event) - self.db.query('unlisten "%s"' % self.stop_event) + self.db.query(f'unlisten "{self.event}"') + self.db.query(f'unlisten "{self.stop_event}"') self.listening = False def notify(self, db=None, stop=False, payload=None): @@ -1382,9 +1386,10 @@ def notify(self, db=None, stop=False, payload=None): if self.listening: if not db: db = self.db - q = 'notify "%s"' % (self.stop_event if stop else self.event) + event = self.stop_event if stop else self.event + q = f'notify "{event}"' if payload: - q += ", '%s'" % payload + q += f", '{payload}'" return db.query(q) def __call__(self): @@ -1420,8 +1425,9 @@ def __call__(self): if event not in (self.event, self.stop_event): self.unlisten() raise _db_error( - 'Listening for "%s" and "%s", but notified of "%s"' - % (self.event, self.stop_event, event)) + f'Listening for "{self.event}"' + f' and "{self.stop_event}",' + f' but notified of "{event}"') if event == self.stop_event: self.unlisten() self.arg_dict.update(pid=pid, event=event, extra=extra) @@ -1592,7 +1598,7 @@ def _make_bool(d): @staticmethod def _list_params(params): """Create a human readable parameter list.""" - return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1)) + return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) # Public methods @@ -1735,7 +1741,7 @@ def get_parameter(self, parameter): params.append(param) else: for param in params: - q = 'SHOW %s' % (param,) + q = f'SHOW {param}' value = self.db.query(q).singlescalar() if values is None: values = value @@ -1813,9 +1819,9 @@ def set_parameter(self, parameter, value=None, local=False): local = ' LOCAL' if local else '' for param, value in params.items(): if value is None: - q = 'RESET%s %s' % (local, param) + q = f'RESET{local} {param}' else: - q = 'SET%s %s TO %s' % (local, param, value) + q = f'SET{local} {param} TO {value}' self._do_debug(q) self.db.query(q) @@ -1919,7 +1925,9 @@ def delete_prepared(self, name=None): name. Note that prepared statements are also deallocated automatically when the current session ends. """ - q = "DEALLOCATE %s" % (name or 'ALL',) + if not name: + name = 'ALL' + q = f"DEALLOCATE {name}" self._do_debug(q) return self.db.query(q) @@ -1949,12 +1957,12 @@ def pkey(self, table, composite=False, flush=False): " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" " AND NOT a.attisdropped" " WHERE i.indrelid OPERATOR(pg_catalog.=)" - " %s::pg_catalog.regclass" - " AND i.indisprimary ORDER BY a.attnum") % ( - _quote_if_unqualified('$1', table),) + " {}::pg_catalog.regclass" + " AND i.indisprimary ORDER BY a.attnum").format( + _quote_if_unqualified('$1', table)) pkey = self.db.query(q, (table,)).getresult() if not pkey: - raise KeyError('Table %s has no primary key' % table) + raise KeyError(f'Table {table} has no primary key') # we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table if len(pkey) > 1: @@ -1984,18 +1992,18 @@ def get_relations(self, kinds=None, system=False): """ where = [] if kinds: - where.append("r.relkind IN (%s)" % - ','.join("'%s'" % k for k in kinds)) + where.append( + "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) if not system: where.append("s.nspname NOT SIMILAR" " TO 'pg/_%|information/_schema' ESCAPE '/'") - where = " WHERE %s" % ' AND '.join(where) if where else '' + where = " WHERE " + ' AND '.join(where) if where else '' q = ("SELECT pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" " FROM pg_catalog.pg_class r" " JOIN pg_catalog.pg_namespace s" - " ON s.oid OPERATOR(pg_catalog.=) r.relnamespace%s" - " ORDER BY s.nspname, r.relname") % where + f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" + " ORDER BY s.nspname, r.relname") return [r[0] for r in self.db.query(q).getresult()] def get_tables(self, system=False): @@ -2089,8 +2097,8 @@ def has_table_privilege(self, table, privilege='select', flush=False): try: # ask cache ret = privileges[table, privilege] except KeyError: # cache miss, ask the database - q = "SELECT pg_catalog.has_table_privilege(%s, $2)" % ( - _quote_if_unqualified('$1', table),) + q = "SELECT pg_catalog.has_table_privilege({}, $2)".format( + _quote_if_unqualified('$1', table)) q = self.db.query(q, (table, privilege)) ret = q.singlescalar() == self._make_bool(True) privileges[table, privilege] = ret # cache it @@ -2130,7 +2138,7 @@ def get(self, table, row, keyname=None): if qoid and isinstance(row, dict) and 'oid' in row: keyname = ('oid',) else: - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') else: # the table has a primary key # check whether all key columns have values if isinstance(row, dict) and not set(keyname).issubset(row): @@ -2151,22 +2159,23 @@ def get(self, table, row, keyname=None): adapt = params.add col = self.escape_identifier what = 'oid, *' if qoid else '*' - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( col(k), adapt(row[k], attnames[k])) for k in keyname) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] - q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % ( - what, self._escape_qualified_name(table), where) + t = self._escape_qualified_name(table) + q = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() if not res: # make where clause in error message better readable where = where.replace('OPERATOR(pg_catalog.=)', '=') - raise _db_error('No such record in %s\nwhere %s\nwith %s' % ( - table, where, self._list_params(params))) + raise _db_error( + f'No such record in {table}\nwhere {where}\nwith ' + + self._list_params(params)) for n, value in res[0].items(): if qoid and n == 'oid': n = qoid @@ -2208,8 +2217,8 @@ def insert(self, table, row=None, **kw): raise _prg_error('No column found that can be inserted') names, values = ', '.join(names), ', '.join(values) ret = 'oid, *' if qoid else '*' - q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % ( - self._escape_qualified_name(table), names, values, ret) + t = self._escape_qualified_name(table) + q = f'INSERT INTO {t} ({names}) VALUES ({values}) RETURNING {ret}' self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2249,14 +2258,14 @@ def update(self, table, row=None, **kw): try: keyname = self.pkey(table, True) except KeyError: # the table has no primary key - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( col(k), adapt(row[k], attnames[k])) for k in keyname) if 'oid' in row: if qoid: @@ -2266,13 +2275,14 @@ def update(self, table, row=None, **kw): keyname = set(keyname) for n in attnames: if n in row and n not in keyname and n not in generated: - values.append('%s = %s' % (col(n), adapt(row[n], attnames[n]))) + values.append('{} = {}'.format( + col(n), adapt(row[n], attnames[n]))) if not values: return row values = ', '.join(values) ret = 'oid, *' if qoid else '*' - q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % ( - self._escape_qualified_name(table), values, where, ret) + t = self._escape_qualified_name(table) + q = f'UPDATE {t} SET {values} WHERE {where} RETURNING {ret}' self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2350,7 +2360,7 @@ def upsert(self, table, row=None, **kw): try: keyname = self.pkey(table, True) except KeyError: - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') target = ', '.join(col(k) for k in keyname) update = [] keyname = set(keyname) @@ -2360,15 +2370,15 @@ def upsert(self, table, row=None, **kw): value = kw.get(n, n in row) if value: if not isinstance(value, str): - value = 'excluded.%s' % col(n) - update.append('%s = %s' % (col(n), value)) + value = f'excluded.{col(n)}' + update.append(f'{col(n)} = {value}') if not values: return row - do = 'update set %s' % ', '.join(update) if update else 'nothing' + do = 'update set ' + ', '.join(update) if update else 'nothing' ret = 'oid, *' if qoid else '*' - q = ('INSERT INTO %s AS included (%s) VALUES (%s)' - ' ON CONFLICT (%s) DO %s RETURNING %s') % ( - self._escape_qualified_name(table), names, values, target, do, ret) + t = self._escape_qualified_name(table) + q = (f'INSERT INTO {t} AS included ({names}) VALUES ({values})' + f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2435,21 +2445,21 @@ def delete(self, table, row=None, **kw): try: keyname = self.pkey(table, True) except KeyError: # the table has no primary key - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( col(k), adapt(row[k], attnames[k])) for k in keyname) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] - q = 'DELETE FROM %s WHERE %s' % ( - self._escape_qualified_name(table), where) + t = self._escape_qualified_name(table) + q = f'DELETE FROM {t} WHERE {where}' self._do_debug(q, params) res = self.db.query(q, params) return int(res) @@ -2499,7 +2509,7 @@ def truncate(self, table, restart=False, cascade=False, only=False): t = t[:-1].rstrip() t = self._escape_qualified_name(t) if u: - t = 'ONLY %s' % t + t = f'ONLY {t}' tables.append(t) q = ['TRUNCATE', ', '.join(tables)] if restart: @@ -2565,9 +2575,9 @@ def get_as_list(self, table, what=None, where=None, order = ', '.join(map(str, order)) q.extend(['ORDER BY', order]) if limit: - q.append('LIMIT %d' % limit) + q.append(f'LIMIT {limit}') if offset: - q.append('OFFSET %d' % offset) + q.append(f'OFFSET {offset}') q = ' '.join(q) self._do_debug(q) q = self.db.query(q) @@ -2603,7 +2613,7 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, try: keyname = self.pkey(table, True) except (KeyError, ProgrammingError): - raise _prg_error('Table %s has no primary key' % table) + raise _prg_error(f'Table {table} has no primary key') if isinstance(keyname, str): keyname = [keyname] elif not isinstance(keyname, (list, tuple)): @@ -2627,9 +2637,9 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, order = ', '.join(map(str, order)) q.extend(['ORDER BY', order]) if limit: - q.append('LIMIT %d' % limit) + q.append(f'LIMIT {limit}') if offset: - q.append('OFFSET %d' % offset) + q.append(f'OFFSET {offset}') q = ' '.join(q) self._do_debug(q) q = self.db.query(q) diff --git a/pgdb.py b/pgdb.py index f986242f..5e218b42 100644 --- a/pgdb.py +++ b/pgdb.py @@ -93,7 +93,7 @@ if e: raise ImportError( "Cannot import shared library for PyGreSQL,\n" - "probably because no %s is installed.\n%s" % (libpq, e)) from e + f"probably because no {libpq} is installed.\n{e}") from e else: del version @@ -389,7 +389,7 @@ def cast_interval(value): secs = -secs usecs = -usecs else: - raise ValueError('Cannot parse interval: %s' % value) + raise ValueError(f'Cannot parse interval: {value}') days += 365 * years + 30 * mons return timedelta(days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) @@ -429,7 +429,7 @@ def __missing__(self, typ): but returns None when no special cast function exists. """ if not isinstance(typ, str): - raise TypeError('Invalid type: %s' % typ) + raise TypeError(f'Invalid type: {typ}') cast = self.defaults.get(typ) if cast: # store default for faster access @@ -471,13 +471,13 @@ def set(self, typ, cast): if cast is None: for t in typ: self.pop(t, None) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) else: if not callable(cast): raise TypeError("Cast parameter must be callable") for t in typ: self[t] = self._add_connection(cast) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) def reset(self, typ=None): """Reset the typecasts for the specified type(s) to their defaults. @@ -495,7 +495,7 @@ def reset(self, typ=None): cast = defaults.get(t) if cast: self[t] = self._add_connection(cast) - t = '_%s' % t + t = f'_{t}' cast = defaults.get(t) if cast: self[t] = self._add_connection(cast) @@ -503,7 +503,7 @@ def reset(self, typ=None): self.pop(t, None) else: self.pop(t, None) - self.pop('_%s' % t, None) + self.pop(f'_{t}', None) def create_array_cast(self, basecast): """Create an array typecast for the given base cast.""" @@ -640,8 +640,8 @@ def __missing__(self, key): oid = key else: if '.' not in key and '"' not in key: - key = '"%s"' % (key,) - oid = "'%s'::pg_catalog.regtype" % (self._escape_string(key),) + key = f'"{key}"' + oid = f"'{self._escape_string(key)}'::pg_catalog.regtype" try: self._src.execute(self._query_pg_type.format(oid)) except ProgrammingError: @@ -649,7 +649,7 @@ def __missing__(self, key): else: res = self._src.fetch(1) if not res: - raise KeyError('Type %s could not be found' % (key,)) + raise KeyError(f'Type {key} could not be found') res = res[0] type_code = TypeCode.create( int(res[0]), res[1], int(res[2]), @@ -676,9 +676,9 @@ def get_fields(self, typ): self._src.execute( "SELECT attname, atttypid" " FROM pg_catalog.pg_attribute" - " WHERE attrelid OPERATOR(pg_catalog.=) %s" + f" WHERE attrelid OPERATOR(pg_catalog.=) {typ.relid}" " AND attnum OPERATOR(pg_catalog.>) 0" - " AND NOT attisdropped ORDER BY attnum" % (typ.relid,)) + " AND NOT attisdropped ORDER BY attnum") return [FieldInfo(name, self.get(int(oid))) for name, oid in self._src.fetch(-1)] @@ -761,7 +761,7 @@ def _row_factory(names): try: return namedtuple('Row', names, rename=True)._make except ValueError: # there is still a problem with the field names - names = ['column_%d' % (n,) for n in range(len(names))] + names = [f'column_{n}' for n in range(len(names))] return namedtuple('Row', names)._make @@ -820,7 +820,7 @@ def _quote(self, value): value = self._cnx.escape_bytea(value).decode('ascii') else: value = self._cnx.escape_string(value) - return "'%s'" % (value,) + return f"'{value}'" if isinstance(value, float): if isinf(value): return "'-Infinity'" if value < 0 else "'Infinity'" @@ -831,18 +831,18 @@ def _quote(self, value): return value if isinstance(value, datetime): if value.tzinfo: - return "'%s'::timestamptz" % (value,) - return "'%s'::timestamp" % (value,) + return f"'{value}'::timestamptz" + return f"'{value}'::timestamp" if isinstance(value, date): - return "'%s'::date" % (value,) + return f"'{value}'::date" if isinstance(value, time): if value.tzinfo: - return "'%s'::timetz" % (value,) - return "'%s'::time" % value + return f"'{value}'::timetz" + return f"'{value}'::time" if isinstance(value, timedelta): - return "'%s'::interval" % (value,) + return f"'{value}'::interval" if isinstance(value, Uuid): - return "'%s'::uuid" % (value,) + return f"'{value}'::uuid" if isinstance(value, list): # Quote value as an ARRAY constructor. This is better than using # an array literal because it carries the information that this is @@ -852,7 +852,8 @@ def _quote(self, value): if not value: # exception for empty array return "'{}'" q = self._quote - return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),) + v = ','.join(str(q(v)) for v in value) + return f'ARRAY[{v}]' if isinstance(value, tuple): # Quote as a ROW constructor. This is better than using a record # literal because it carries the information that this is a record @@ -860,12 +861,13 @@ def _quote(self, value): # this usable with the IN syntax as well. It is only necessary # when the records has a single column which is not really useful. q = self._quote - return '(%s)' % (','.join(str(q(v)) for v in value),) + v = ','.join(str(q(v)) for v in value) + return f'({v})' try: # noinspection PyUnresolvedReferences value = value.__pg_repr__() except AttributeError: raise InterfaceError( - 'Do not know how to adapt type %s' % (type(value),)) + f'Do not know how to adapt type {type(value)}') if isinstance(value, (tuple, list)): value = self._quote(value) return value @@ -979,10 +981,9 @@ def executemany(self, operation, seq_of_parameters): raise # database provides error message except Error as err: # noinspection PyTypeChecker - raise _db_error( - "Error in '%s': '%s' " % (sql, err), InterfaceError) + raise _db_error(f"Error in '{sql}': '{err}'", InterfaceError) except Exception as err: - raise _op_error("Internal error in '%s': %s" % (sql, err)) + raise _op_error(f"Internal error in '{sql}': {err}") # then initialize result raw count and description if self._src.resulttype == RESULT_DQL: self._description = True # fetch on demand @@ -1049,8 +1050,9 @@ def callproc(self, procname, parameters=None): The procedure may also provide a result set as output. These can be requested through the standard fetch methods of the cursor. """ - n = parameters and len(parameters) or 0 - query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s'])) + n = len(parameters) if parameters else 0 + s = ','.join(n * ['%s']) + query = f'select * from "{procname}"({s})' self.execute(query, parameters) return parameters @@ -1088,7 +1090,7 @@ def copy_from(self, stream, table, if isinstance(stream, (bytes, str)): if not isinstance(stream, input_type): - raise ValueError("The input must be %s" % (type_name,)) + raise ValueError(f"The input must be {type_name}") if not binary_format: if isinstance(stream, str): if not stream.endswith('\n'): @@ -1106,8 +1108,7 @@ def chunks(): for chunk in stream: if not isinstance(chunk, input_type): raise ValueError( - "Input stream must consist of %s" - % (type_name,)) + f"Input stream must consist of {type_name}") if isinstance(chunk, str): if not chunk.endswith('\n'): chunk += '\n' @@ -1144,7 +1145,7 @@ def chunks(): else: table = '.'.join(map( self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = ['copy %s' % (table,)] + operation = [f'copy {table}'] options = [] params = [] if format is not None: @@ -1152,7 +1153,7 @@ def chunks(): raise TypeError("The format option must be be a string") if format not in ('text', 'csv', 'binary'): raise ValueError("Invalid format") - options.append('format %s' % (format,)) + options.append(f'format {format}') if sep is not None: if not isinstance(sep, str): raise TypeError("The sep option must be a string") @@ -1173,10 +1174,11 @@ def chunks(): if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) - operation.append('(%s)' % (columns,)) + operation.append(f'({columns})') operation.append("from stdin") if options: - operation.append('(%s)' % (','.join(options),)) + options = ','.join(options) + operation.append(f'({options})') operation = ' '.join(operation) putdata = self._src.putdata @@ -1226,11 +1228,11 @@ def copy_to(self, stream, table, if table.lower().startswith('select '): if columns: raise ValueError("Columns must be specified in the query") - table = '(%s)' % (table,) + table = f'({table})' else: table = '.'.join(map( self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = ['copy %s' % (table,)] + operation = [f'copy {table}'] options = [] params = [] if format is not None: @@ -1238,7 +1240,7 @@ def copy_to(self, stream, table, raise TypeError("The format option must be a string") if format not in ('text', 'csv', 'binary'): raise ValueError("Invalid format") - options.append('format %s' % (format,)) + options.append(f'format {format}') if sep is not None: if not isinstance(sep, str): raise TypeError("The sep option must be a string") @@ -1267,11 +1269,12 @@ def copy_to(self, stream, table, if not isinstance(columns, str): columns = ','.join(map( self.connection._cnx.escape_identifier, columns)) - operation.append('(%s)' % (columns,)) + operation.append(f'({columns})') operation.append("to stdout") if options: - operation.append('(%s)' % (','.join(options),)) + options = ','.join(options) + operation.append(f'({options})') operation = ' '.join(operation) getdata = self._src.getdata @@ -1553,9 +1556,9 @@ def connect(dsn=None, for kw, value in kwargs: value = str(value) if not value or ' ' in value: - value = "'%s'" % (value.replace( - '\\', '\\\\').replace("'", "\\'")) - dbname.append('%s=%s' % (kw, value)) + value = value.replace('\\', '\\\\').replace("'", "\\'") + value = f"'{value}'" + dbname.append(f'{kw}={value}') dbname = ' '.join(dbname) # open the connection # noinspection PyArgumentList @@ -1734,12 +1737,12 @@ def _quote(cls, s): quote = cls._re_quote.search(s) s = cls._re_escape.sub(r'\\\1', s) if quote: - s = '"%s"' % (s,) + s = f'"{s}"' return s def __str__(self): q = self._quote - return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items()) + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) class Json: From a7fe116b61fe8d30c7e9595d7ca8820987bbb9ec Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 13:23:17 +0200 Subject: [PATCH 032/118] Modernize string formatting in tests --- tests/dbapi20.py | 212 ++++++++++------------ tests/test_classic.py | 8 +- tests/test_classic_connection.py | 154 ++++++++-------- tests/test_classic_dbwrapper.py | 270 ++++++++++++++--------------- tests/test_classic_functions.py | 31 ++-- tests/test_classic_largeobj.py | 20 +-- tests/test_classic_notification.py | 14 +- tests/test_dbapi20.py | 196 +++++++++++---------- tests/test_dbapi20_copy.py | 24 +-- tests/test_tutorial.py | 2 +- 10 files changed, 452 insertions(+), 479 deletions(-) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index b793fbf2..798bbc49 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -3,6 +3,8 @@ """Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. + +Some modernization of the code has been done by the PyGreSQL team. """ __version__ = '1.15.0' @@ -10,6 +12,8 @@ import unittest import time +from typing import Any, Dict, Tuple + class DatabaseAPI20Test(unittest.TestCase): """Test a database self.driver for DB API 2.0 compatibility. @@ -36,16 +40,16 @@ class mytest(dbapi20.DatabaseAPI20Test): # The self.driver module. This should be the module where the 'connect' # method is to be found - driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect + driver: Any = None + connect_args: Tuple = () # List of arguments to pass to connect + connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables - ddl1 = 'create table %sbooze (name varchar(20))' % (table_prefix,) - ddl2 = 'create table %sbarflys (name varchar(20), drink varchar(30))' % ( - table_prefix,) - xddl1 = 'drop table %sbooze' % (table_prefix,) - xddl2 = 'drop table %sbarflys' % (table_prefix,) + ddl1 = f'create table {table_prefix}booze (name varchar(20))' + ddl2 = (f'create table {table_prefix}barflys (name varchar(20),' + ' drink varchar(30))') + xddl1 = f'drop table {table_prefix}booze' + xddl2 = f'drop table {table_prefix}barflys' insert = 'insert' lowerfunc = 'lower' # Name of stored procedure to convert str to lowercase @@ -155,15 +159,15 @@ def test_ExceptionsAsConnectionAttributes(self): # by default. con = self._connect() drv = self.driver - self.assertTrue(con.Warning is drv.Warning) - self.assertTrue(con.Error is drv.Error) - self.assertTrue(con.InterfaceError is drv.InterfaceError) - self.assertTrue(con.DatabaseError is drv.DatabaseError) - self.assertTrue(con.OperationalError is drv.OperationalError) - self.assertTrue(con.IntegrityError is drv.IntegrityError) - self.assertTrue(con.InternalError is drv.InternalError) - self.assertTrue(con.ProgrammingError is drv.ProgrammingError) - self.assertTrue(con.NotSupportedError is drv.NotSupportedError) + self.assertIs(con.Warning, drv.Warning) + self.assertIs(con.Error, drv.Error) + self.assertIs(con.InterfaceError, drv.InterfaceError) + self.assertIs(con.DatabaseError, drv.DatabaseError) + self.assertIs(con.OperationalError, drv.OperationalError) + self.assertIs(con.IntegrityError, drv.IntegrityError) + self.assertIs(con.InternalError, drv.InternalError) + self.assertIs(con.ProgrammingError, drv.ProgrammingError) + self.assertIs(con.NotSupportedError, drv.NotSupportedError) def test_commit(self): con = self._connect() @@ -200,10 +204,9 @@ def test_cursor_isolation(self): cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) - cur1.execute("%s into %sbooze values ('Victoria Bitter')" % ( - self.insert, self.table_prefix - )) - cur2.execute("select name from %sbooze" % self.table_prefix) + cur1.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") + cur2.execute(f"select name from {self.table_prefix}booze") booze = cur2.fetchall() self.assertEqual(len(booze), 1) self.assertEqual(len(booze[0]), 1) @@ -220,7 +223,7 @@ def test_description(self): cur.description, 'cursor.description should be none after executing a' ' statement that can return no rows (such as DDL)') - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') self.assertEqual( len(cur.description), 1, 'cursor.description describes too many columns') @@ -232,8 +235,8 @@ def test_description(self): 'cursor.description[x][0] must return column name') self.assertEqual( cur.description[0][1], self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1]) + 'cursor.description[x][1] must return column type.' + f' Got: {cur.description[0][1]!r}') # Make sure self.description gets reset self.executeDDL2(cur) @@ -253,14 +256,13 @@ def test_rowcount(self): cur.rowcount, (-1, 0), # Bug #543885 'cursor.rowcount should be -1 or 0 after executing no-result' ' statements') - cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( - self.insert, self.table_prefix - )) + cur.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") self.assertIn( cur.rowcount, (-1, 1), 'cursor.rowcount should == number or rows inserted, or' ' set to -1 after executing an insert statement') - cur.execute("select name from %sbooze" % self.table_prefix) + cur.execute(f"select name from {self.table_prefix}booze") self.assertIn( cur.rowcount, (-1, 1), 'cursor.rowcount should == number of rows returned, or' @@ -325,47 +327,38 @@ def test_execute(self): def _paraminsert(self, cur): self.executeDDL2(cur) + table_prefix = self.table_prefix + insert = f"{self.insert} into {table_prefix}barflys values" cur.execute( - "%s into %sbarflys values ('Victoria Bitter'," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix)) + f"{insert} ('Victoria Bitter'," + " 'thi%s :may ca%(u)se? troub:1e')") self.assertIn(cur.rowcount, (-1, 1)) if self.driver.paramstyle == 'qmark': cur.execute( - "%s into %sbarflys values (?," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (?, 'thi%s :may ca%(u)se? troub:1e')", ("Cooper's",)) elif self.driver.paramstyle == 'numeric': cur.execute( - "%s into %sbarflys values (:1," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (:1, 'thi%s :may ca%(u)se? troub:1e')", ("Cooper's",)) elif self.driver.paramstyle == 'named': cur.execute( - "%s into %sbarflys values (:beer," - " 'thi%%s :may ca%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (:beer, 'thi%s :may ca%(u)se? troub:1e')", {'beer': "Cooper's"}) elif self.driver.paramstyle == 'format': cur.execute( - "%s into %sbarflys values (%%s," - " 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (%s, 'thi%%s :may ca%%(u)se? troub:1e')", ("Cooper's",)) elif self.driver.paramstyle == 'pyformat': cur.execute( - "%s into %sbarflys values (%%(beer)s," - " 'thi%%%%s :may ca%%%%(u)se? troub:1e')" % ( - self.insert, self.table_prefix), + f"{insert} (%(beer)s, 'thi%%s :may ca%%(u)se? troub:1e')", {'beer': "Cooper's"}) else: self.fail('Invalid paramstyle') self.assertIn(cur.rowcount, (-1, 1)) - cur.execute('select name, drink from %sbarflys' % self.table_prefix) + cur.execute(f'select name, drink from {table_prefix}barflys') res = cur.fetchall() self.assertEqual(len(res), 2, 'cursor.fetchall returned too few rows') beers = [res[0][0], res[1][0]] @@ -382,48 +375,38 @@ def _paraminsert(self, cur): self.assertEqual( res[0][1], trouble, 'cursor.fetchall retrieved incorrect data, or data inserted' - ' incorrectly. Got=%s, Expected=%s' % ( - repr(res[0][1]), repr(trouble))) + f' incorrectly. Got: {res[0][1]!r}, Expected: {trouble!r}') self.assertEqual( res[1][1], trouble, 'cursor.fetchall retrieved incorrect data, or data inserted' - ' incorrectly. Got=%s, Expected=%s' % ( - repr(res[1][1]), repr(trouble))) + f' incorrectly. Got: {res[1][1]!r}, Expected: {trouble!r}') def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) + table_prefix = self.table_prefix + insert = f'{self.insert} into {table_prefix}booze values' largs = [("Cooper's",), ("Boag's",)] margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}] if self.driver.paramstyle == 'qmark': - cur.executemany( - '%s into %sbooze values (?)' % ( - self.insert, self.table_prefix), largs) + cur.executemany(f'{insert} (?)', largs) elif self.driver.paramstyle == 'numeric': - cur.executemany( - '%s into %sbooze values (:1)' % ( - self.insert, self.table_prefix), largs) + cur.executemany(f'{insert} (:1)', largs) elif self.driver.paramstyle == 'named': - cur.executemany( - '%s into %sbooze values (:beer)' % ( - self.insert, self.table_prefix), margs) + cur.executemany(f'{insert} (:beer)', margs) elif self.driver.paramstyle == 'format': - cur.executemany( - '%s into %sbooze values (%%s)' % ( - self.insert, self.table_prefix), largs) + cur.executemany(f'{insert} (%s)', largs) elif self.driver.paramstyle == 'pyformat': - cur.executemany( - '%s into %sbooze values (%%(beer)s)' % ( - self.insert, self.table_prefix), margs) + cur.executemany(f'{insert} (%(beer)s)', margs) else: self.fail('Unknown paramstyle') self.assertIn( cur.rowcount, (-1, 2), 'insert using cursor.executemany set cursor.rowcount to' - ' incorrect value %r' % cur.rowcount) - cur.execute('select name from %sbooze' % self.table_prefix) + f' incorrect value {cur.rowcount!r}') + cur.execute(f'select name from {table_prefix}booze') res = cur.fetchall() self.assertEqual( len(res), 2, @@ -449,7 +432,7 @@ def test_fetchone(self): self.executeDDL1(cur) self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') self.assertIsNone( cur.fetchone(), 'cursor.fetchone should return None if a query retrieves' @@ -458,12 +441,12 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows - cur.execute("%s into %sbooze values ('Victoria Bitter')" % ( - self.insert, self.table_prefix - )) + cur.execute( + f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchone() self.assertEqual( len(r), 1, @@ -490,8 +473,7 @@ def test_fetchone(self): def _populate(self): """Return a list of SQL commands to setup the DB for fetching tests.""" populate = [ - "%s into %sbooze values ('%s')" % ( - self.insert, self.table_prefix, s) + f"{self.insert} into {self.table_prefix}booze values ('{s}')" for s in self.samples] return populate @@ -508,7 +490,7 @@ def test_fetchmany(self): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchmany() self.assertEqual( len(r), 1, @@ -532,7 +514,7 @@ def test_fetchmany(self): # Same as above, using cursor.arraysize cur.arraysize = 4 - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchmany() # Should get 4 rows self.assertEqual( len(r), 4, @@ -544,7 +526,7 @@ def test_fetchmany(self): self.assertIn(cur.rowcount, (-1, 6)) cur.arraysize = 6 - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') rows = cur.fetchmany() # Should get all rows self.assertIn(cur.rowcount, (-1, 6)) self.assertEqual(len(rows), 6) @@ -566,7 +548,7 @@ def test_fetchmany(self): self.assertIn(cur.rowcount, (-1, 6)) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}barflys') r = cur.fetchmany() # Should get empty sequence self.assertEqual( len(r), 0, @@ -594,7 +576,7 @@ def test_fetchall(self): # after executing a a statement that cannot return rows self.assertRaises(self.driver.Error, cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') rows = cur.fetchall() self.assertIn(cur.rowcount, (-1, len(self.samples))) self.assertEqual( @@ -613,7 +595,7 @@ def test_fetchall(self): self.assertIn(cur.rowcount, (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}barflys') rows = cur.fetchall() self.assertIn(cur.rowcount, (-1, 0)) self.assertEqual( @@ -632,7 +614,7 @@ def test_mixedfetch(self): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') rows1 = cur.fetchone() rows23 = cur.fetchmany(2) rows4 = cur.fetchone() @@ -676,51 +658,45 @@ def help_nextset_setUp(self, cur): def help_nextset_tearDown(self, cur): """Clean up after nextset test. - If cleaning up is needed after nextSetTest. + If cleaning up is needed after test_nextset. """ raise NotImplementedError('Helper not implemented') # cur.execute("drop procedure deleteme") - # example test implementation only def test_nextset(self): - con = self._connect() - try: - cur = con.cursor() - if not hasattr(cur, 'nextset'): - return - - try: - self.executeDDL1(cur) - for sql in self._populate(): - cur.execute(sql) - - self.help_nextset_setUp(cur) - - cur.callproc('deleteme') - number_of_rows = cur.fetchone() - self.assertEqual(number_of_rows[0], len(self.samples)) - self.assertTrue(cur.nextset()) - names = cur.fetchall() - self.assertEqual(len(names), len(self.samples)) - s = cur.nextset() - self.assertIsNone(s, 'No more return sets, should return None') - finally: - self.help_nextset_tearDown(cur) - - finally: - con.close() - - # noinspection PyRedeclaration - def test_nextset(self): # noqa: F811 + """Test the nextset functionality.""" raise NotImplementedError('Drivers need to override this test') + # example test implementation only: + # con = self._connect() + # try: + # cur = con.cursor() + # if not hasattr(cur, 'nextset'): + # return + # try: + # self.executeDDL1(cur) + # for sql in self._populate(): + # cur.execute(sql) + # self.help_nextset_setUp(cur) + # cur.callproc('deleteme') + # number_of_rows = cur.fetchone() + # self.assertEqual(number_of_rows[0], len(self.samples)) + # self.assertTrue(cur.nextset()) + # names = cur.fetchall() + # self.assertEqual(len(names), len(self.samples)) + # self.assertIsNone( + # cur.nextset(), 'No more return sets, should return None') + # finally: + # self.help_nextset_tearDown(cur) + # finally: + # con.close() def test_arraysize(self): # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() - self.assertTrue( - hasattr(cur, 'arraysize'), 'cursor.arraysize must be defined') + self.assertTrue(hasattr(cur, 'arraysize'), + 'cursor.arraysize must be defined') finally: con.close() @@ -756,9 +732,9 @@ def test_None(self): # inserting NULL to the second column, because some drivers might # need the first one to be primary key, which means it needs # to have a non-NULL value - cur.execute("%s into %sbarflys values ('a', NULL)" % ( - self.insert, self.table_prefix)) - cur.execute('select drink from %sbarflys' % self.table_prefix) + cur.execute(f"{self.insert} into {self.table_prefix}barflys" + " values ('a', NULL)") + cur.execute(f'select drink from {self.table_prefix}barflys') r = cur.fetchall() self.assertEqual(len(r), 1) self.assertEqual(len(r[0]), 1) diff --git a/tests/test_classic.py b/tests/test_classic.py index 799cb6c7..6319d5d5 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -48,11 +48,11 @@ def setUpClass(cls): except Exception: pass try: - db.query("DROP TABLE %s._test_schema" % (t,)) + db.query(f"DROP TABLE {t}._test_schema") except Exception: pass - db.query("CREATE TABLE %s._test_schema" - " (%s int PRIMARY KEY)" % (t, t)) + db.query(f"CREATE TABLE {t}._test_schema" + f" ({t} int PRIMARY KEY)") db.close() def setUp(self): @@ -60,7 +60,7 @@ def setUp(self): db = open_db() db.query("TRUNCATE TABLE _test_schema") for t in ('_test1', '_test2'): - db.query("TRUNCATE TABLE %s._test_schema" % t) + db.query(f"TRUNCATE TABLE {t}._test_schema") db.close() def test_invalid_name(self): diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c456b4ec..f7ca2a46 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -54,7 +54,7 @@ def testCanConnect(self): connection = connect() rc = connection.poll() except pg.Error as error: - self.fail('Cannot connect to database %s:\n%s' % (dbname, error)) + self.fail(f'Cannot connect to database {dbname}:\n{error}') self.assertEqual(rc, pg.POLLING_OK) self.assertIs(connection.is_non_blocking(), False) connection.set_non_blocking(True) @@ -74,7 +74,7 @@ def testCanConnectNoWait(self): while rc not in (pg.POLLING_OK, pg.POLLING_FAILED): rc = connection.poll() except pg.Error as error: - self.fail('Cannot connect to database %s:\n%s' % (dbname, error)) + self.fail(f'Cannot connect to database {dbname}:\n{error}') self.assertEqual(rc, pg.POLLING_OK) self.assertIs(connection.is_non_blocking(), False) connection.set_non_blocking(True) @@ -310,7 +310,7 @@ def testMethodReset(self): encoding = query('show client_encoding').getresult()[0][0].upper() changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8' self.assertNotEqual(encoding, changed_encoding) - self.connection.query("set client_encoding=%s" % changed_encoding) + self.connection.query(f"set client_encoding={changed_encoding}") new_encoding = query('show client_encoding').getresult()[0][0].upper() self.assertEqual(new_encoding, changed_encoding) self.connection.reset() @@ -459,7 +459,7 @@ def testGetresultDecimal(self): def testGetresultString(self): result = 'Hello, world!' - q = "select '%s'" % result + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -503,7 +503,7 @@ def testDictresultDecimal(self): def testDictresultString(self): result = 'Hello, world!' - q = "select '%s' as greeting" % result + q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -699,7 +699,7 @@ def testFieldInfoName(self): for field_num, info in enumerate(result): field_name = info[0] if field_num > 0: - field_name = '"%s"' % field_name + field_name = f'"{field_name}"' r = f(field_name) self.assertIsInstance(r, tuple) self.assertEqual(len(r), 4) @@ -841,28 +841,27 @@ def tearDown(self): self.c.close() def testGetresulAscii(self): - result = u'Hello, world!' - q = u"select '%s'" % result + result = 'Hello, world!' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) def testDictresulAscii(self): - result = u'Hello, world!' - q = u"select '%s' as greeting" % result + result = 'Hello, world!' + q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) def testGetresultUtf8(self): - result = u'Hello, wörld & мир!' - q = u"select '%s'" % result + result = 'Hello, wörld & мир!' + q = f"select '{result}'" # pass the query as unicode try: v = self.c.query(q).getresult()[0][0] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") - v = None self.assertIsInstance(v, str) self.assertEqual(v, result) q = q.encode('utf8') @@ -872,13 +871,12 @@ def testGetresultUtf8(self): self.assertEqual(v, result) def testDictresultUtf8(self): - result = u'Hello, wörld & мир!' - q = u"select '%s' as greeting" % result + result = 'Hello, wörld & мир!' + q = f"select '{result}' as greeting" try: v = self.c.query(q).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") - v = None self.assertIsInstance(v, str) self.assertEqual(v, result) q = q.encode('utf8') @@ -891,8 +889,8 @@ def testGetresultLatin1(self): self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - result = u'Hello, wörld!' - q = u"select '%s'" % result + result = 'Hello, wörld!' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -906,8 +904,8 @@ def testDictresultLatin1(self): self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - result = u'Hello, wörld!' - q = u"select '%s' as greeting" % result + result = 'Hello, wörld!' + q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -921,8 +919,8 @@ def testGetresultCyrillic(self): self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") - result = u'Hello, мир!' - q = u"select '%s'" % result + result = 'Hello, мир!' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -936,8 +934,8 @@ def testDictresultCyrillic(self): self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") - result = u'Hello, мир!' - q = u"select '%s' as greeting" % result + result = 'Hello, мир!' + q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -951,8 +949,8 @@ def testGetresultLatin9(self): self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") - result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = u"select '%s'" % result + result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' + q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -966,8 +964,8 @@ def testDictresultLatin9(self): self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") - result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = u"select '%s' as menu" % result + result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' + q = f"select '{result}' as menu" v = self.c.query(q).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1092,7 +1090,7 @@ def testQueryWithUnicodeParams(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertEqual( - query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult(), + query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult(), [('Hello, wörld!',)]) def testQueryWithUnicodeParamsLatin1(self): @@ -1103,22 +1101,22 @@ def testQueryWithUnicodeParamsLatin1(self): query("select 'wörld'").getresult()[0][0], 'wörld') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() + r = query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult() self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир')) + ('Hello', 'мир')) query('set client_encoding=iso_8859_1') r = query( - "select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() + "select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult() self.assertEqual(r, [('Hello, wörld!',)]) self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир')) + ('Hello', 'мир')) query('set client_encoding=sql_ascii') self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'wörld')) + ('Hello', 'wörld')) def testQueryWithUnicodeParamsCyrillic(self): query = self.c.query @@ -1130,14 +1128,14 @@ def testQueryWithUnicodeParamsCyrillic(self): self.skipTest("database does not support cyrillic") self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'wörld')) + ('Hello', 'wörld')) r = query( - "select $1||', '||$2||'!'", ('Hello', u'мир')).getresult() + "select $1||', '||$2||'!'", ('Hello', 'мир')).getresult() self.assertEqual(r, [('Hello, мир!',)]) query('set client_encoding=sql_ascii') self.assertRaises( UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир!')) + ('Hello', 'мир!')) def testQueryWithMixedParams(self): self.assertEqual( @@ -1264,7 +1262,7 @@ def tearDown(self): self.c.close() def assert_proper_cast(self, value, pgtype, pytype): - q = 'select $1::%s' % (pgtype,) + q = f'select $1::{pgtype}' try: r = self.c.query(q, (value,)).getresult()[0][0] except pg.ProgrammingError as e: @@ -1275,8 +1273,8 @@ def assert_proper_cast(self, value, pgtype, pytype): self.assertIsInstance(r, pytype) if isinstance(value, str): if not value or ' ' in value or '{' in value: - value = '"%s"' % value - value = '{%s}' % value + value = f'"{value}"' + value = f'{{{value}}}' r = self.c.query(q + '[]', (value,)).getresult()[0][0] if pgtype.startswith(('date', 'time', 'interval')): # arrays of these are casted by the DB wrapper only @@ -2009,11 +2007,11 @@ def testInserttableByteValues(self): except pg.DataError: self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "käse сыр pont-l'évêque") row_bytes = tuple( s.encode('utf-8') if isinstance(s, str) else s for s in row_unicode) @@ -2028,11 +2026,11 @@ def testInserttableUnicodeUtf8(self): except pg.DataError: self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "käse сыр pont-l'évêque") data = [row_unicode] * 2 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -2044,16 +2042,16 @@ def testInserttableUnicodeLatin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode € sign with latin1 encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) row_unicode = tuple( - s.replace(u'€', u'¥') if isinstance(s, str) else s + s.replace('€', '¥') if isinstance(s, str) else s for s in row_unicode) data = [row_unicode] * 2 self.c.inserttable('test', data) @@ -2067,11 +2065,11 @@ def testInserttableUnicodeLatin9(self): self.skipTest("database does not support latin9") return # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] * 2 self.c.inserttable('test', data) self.assertEqual(self.get_back('latin9'), data) @@ -2079,11 +2077,11 @@ def testInserttableUnicodeLatin9(self): def testInserttableNoEncoding(self): self.c.query("set client_encoding=sql_ascii") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' + c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) @@ -2174,7 +2172,7 @@ def testPutline(self): query("copy test from stdin") try: for i, v in data: - putline("%d\t%s\n" % (i, v)) + putline(f"{i}\t{v}\n") finally: self.c.endcopy() r = query("select * from test").getresult() @@ -2189,7 +2187,7 @@ def testPutlineBytesAndUnicode(self): self.skipTest('database does not support utf8') query("copy test from stdin") try: - putline(u"47\tkäse\n".encode('utf8')) + putline("47\tkäse\n".encode('utf8')) putline("35\twürstel\n") finally: self.c.endcopy() @@ -2208,7 +2206,7 @@ def testGetline(self): v = getline() if i < n: # noinspection PyStringFormat - self.assertEqual(v, '%d\t%s' % data[i]) + self.assertEqual(v, '{}\t{}'.format(*data[i])) elif i == n: self.assertIsNone(v) finally: @@ -2224,7 +2222,7 @@ def testGetlineBytesAndUnicode(self): query("select 'käse+würstel'") except (pg.DataError, pg.NotSupportedError): self.skipTest('database does not support utf8') - data = [(54, u'käse'.encode('utf8')), (73, u'würstel')] + data = [(54, 'käse'.encode('utf8')), (73, 'würstel')] self.c.inserttable('test', data) query("copy test to stdout") try: @@ -2405,7 +2403,7 @@ def testSetDecimalPoint(self): # first try with English localization (using the point) for lc in en_locales: try: - query("set lc_monetary='%s'" % lc) + query(f"set lc_monetary='{lc}'") except pg.DataError: pass else: @@ -2456,7 +2454,7 @@ def testSetDecimalPoint(self): # then try with German localization (using the comma) for lc in de_locales: try: - query("set lc_monetary='%s'" % lc) + query(f"set lc_monetary='{lc}'") except pg.DataError: pass else: @@ -2714,15 +2712,15 @@ def testEscapeString(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') - r = f(u"das is' käse".encode('utf-8')) + self.assertEqual(r, 'plain') + r = f("das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is'' käse".encode('utf-8')) - r = f(u"that's cheesy") + self.assertEqual(r, "das is'' käse".encode('utf-8')) + r = f("that's cheesy") self.assertIsInstance(r, str) - self.assertEqual(r, u"that''s cheesy") + self.assertEqual(r, "that''s cheesy") r = f(r"It's bad to have a \ inside.") self.assertEqual(r, r"It''s bad to have a \\ inside.") @@ -2732,15 +2730,15 @@ def testEscapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') - r = f(u"das is' käse".encode('utf-8')) + self.assertEqual(r, 'plain') + r = f("das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b"das is'' k\\\\303\\\\244se") - r = f(u"that's cheesy") + r = f("that's cheesy") self.assertIsInstance(r, str) - self.assertEqual(r, u"that''s cheesy") + self.assertEqual(r, "that''s cheesy") r = f(b'O\x00ps\xff!') self.assertEqual(r, b'O\\\\000ps\\\\377!') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 3d372ad3..79c962a4 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -466,15 +466,15 @@ def createTable(self, table, definition, temporary=True, oids=None, values=None): query = self.db.query if '"' not in table or '.' in table: - table = '"%s"' % table + table = f'"{table}"' if not temporary: - q = 'drop table if exists %s cascade' % table + q = f'drop table if exists {table} cascade' query(q) self.addCleanup(query, q) temporary = 'temporary table' if temporary else 'table' as_query = definition.startswith(('as ', 'AS ')) if not as_query and not definition.startswith('('): - definition = '(%s)' % definition + definition = f'({definition})' with_oids = 'with oids' if oids else ( 'without oids' if self.oids else '') q = ['create', temporary, table] @@ -488,8 +488,8 @@ def createTable(self, table, definition, for params in values: if not isinstance(params, (list, tuple)): params = [params] - values = ', '.join('$%d' % (n + 1) for n in range(len(params))) - q = "insert into %s values (%s)" % (table, values) + values = ', '.join(f'${n + 1}' for n in range(len(params))) + q = f"insert into {table} values ({values})" query(q, params) def testClassName(self): @@ -504,15 +504,15 @@ def testEscapeLiteral(self): r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"'plain'") - r = f(u"plain") + r = f("plain") self.assertIsInstance(r, str) - self.assertEqual(r, u"'plain'") - r = f(u"that's käse".encode('utf-8')) + self.assertEqual(r, "'plain'") + r = f("that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"'that''s käse'".encode('utf-8')) - r = f(u"that's käse") + self.assertEqual(r, "'that''s käse'".encode('utf-8')) + r = f("that's käse") self.assertIsInstance(r, str) - self.assertEqual(r, u"'that''s käse'") + self.assertEqual(r, "'that''s käse'") self.assertEqual(f(r"It's fine to have a \ inside."), r" E'It''s fine to have a \\ inside.'") self.assertEqual(f('No "quotes" must be escaped.'), @@ -523,15 +523,15 @@ def testEscapeIdentifier(self): r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b'"plain"') - r = f(u"plain") + r = f("plain") self.assertIsInstance(r, str) - self.assertEqual(r, u'"plain"') - r = f(u"that's käse".encode('utf-8')) + self.assertEqual(r, '"plain"') + r = f("that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u'"that\'s käse"'.encode('utf-8')) - r = f(u"that's käse") + self.assertEqual(r, '"that\'s käse"'.encode('utf-8')) + r = f("that's käse") self.assertIsInstance(r, str) - self.assertEqual(r, u'"that\'s käse"') + self.assertEqual(r, '"that\'s käse"') self.assertEqual(f(r"It's fine to have a \ inside."), '"It\'s fine to have a \\ inside."') self.assertEqual(f('All "quotes" must be escaped.'), @@ -542,15 +542,15 @@ def testEscapeString(self): r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"plain") - r = f(u"plain") + r = f("plain") self.assertIsInstance(r, str) - self.assertEqual(r, u"plain") - r = f(u"that's käse".encode('utf-8')) + self.assertEqual(r, "plain") + r = f("that's käse".encode('utf-8')) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"that''s käse".encode('utf-8')) - r = f(u"that's käse") + self.assertEqual(r, "that''s käse".encode('utf-8')) + r = f("that's käse") self.assertIsInstance(r, str) - self.assertEqual(r, u"that''s käse") + self.assertEqual(r, "that''s käse") self.assertEqual(f(r"It's fine to have a \ inside."), r"It''s fine to have a \ inside.") @@ -561,15 +561,15 @@ def testEscapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x706c61696e') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'\\x706c61696e') - r = f(u"das is' käse".encode('utf-8')) + self.assertEqual(r, '\\x706c61696e') + r = f("das is' käse".encode('utf-8')) self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x64617320697327206bc3a47365') - r = f(u"das is' käse") + r = f("das is' käse") self.assertIsInstance(r, str) - self.assertEqual(r, u'\\x64617320697327206bc3a47365') + self.assertEqual(r, '\\x64617320697327206bc3a47365') self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21') def testUnescapeBytea(self): @@ -577,15 +577,15 @@ def testUnescapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf8')) - r = f(u"das is' k\\303\\244se") + self.assertEqual(r, "das is' käse".encode('utf8')) + r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf8')) + self.assertEqual(r, "das is' käse".encode('utf8')) self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!') self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e') self.assertEqual(f(r'\\x746861742773206be47365'), @@ -848,7 +848,7 @@ def testCreateTable(self): values = [(2, "World!"), (1, "Hello")] self.createTable(table, "n smallint, t varchar", temporary=True, oids=False, values=values) - r = self.db.query('select t from "%s" order by n' % table).getresult() + r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") @@ -859,10 +859,10 @@ def testCreateTableWithOids(self): values = [(2, "World!"), (1, "Hello")] self.createTable(table, "n smallint, t varchar", temporary=True, oids=True, values=values) - r = self.db.query('select t from "%s" order by n' % table).getresult() + r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") - r = self.db.query('select oid from "%s" limit 1' % table).getresult() + r = self.db.query(f'select oid from "{table}" limit 1').getresult() self.assertIsInstance(r[0][0], int) def testQuery(self): @@ -1131,56 +1131,56 @@ def testPkey(self): pkey = self.db.pkey self.assertRaises(KeyError, pkey, 'test') for t in ('pkeytest', 'primary key test'): - self.createTable('%s0' % t, 'a smallint') - self.createTable('%s1' % t, 'b smallint primary key') - self.createTable('%s2' % t, 'c smallint, d smallint primary key') + self.createTable(f'{t}0', 'a smallint') + self.createTable(f'{t}1', 'b smallint primary key') + self.createTable(f'{t}2', 'c smallint, d smallint primary key') self.createTable( - '%s3' % t, + f'{t}3', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (f, h)') self.createTable( - '%s4' % t, + f'{t}4', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (h, f)') self.createTable( - '%s5' % t, 'more_than_one_letter varchar primary key') + f'{t}5', 'more_than_one_letter varchar primary key') self.createTable( - '%s6' % t, '"with space" date primary key') + f'{t}6', '"with space" date primary key') self.createTable( - '%s7' % t, + f'{t}7', 'a_very_long_column_name varchar, "with space" date, "42" int,' ' primary key (a_very_long_column_name, "with space", "42")') - self.assertRaises(KeyError, pkey, '%s0' % t) - self.assertEqual(pkey('%s1' % t), 'b') - self.assertEqual(pkey('%s1' % t, True), ('b',)) - self.assertEqual(pkey('%s1' % t, composite=False), 'b') - self.assertEqual(pkey('%s1' % t, composite=True), ('b',)) - self.assertEqual(pkey('%s2' % t), 'd') - self.assertEqual(pkey('%s2' % t, composite=True), ('d',)) - r = pkey('%s3' % t) + self.assertRaises(KeyError, pkey, f'{t}0') + self.assertEqual(pkey(f'{t}1'), 'b') + self.assertEqual(pkey(f'{t}1', True), ('b',)) + self.assertEqual(pkey(f'{t}1', composite=False), 'b') + self.assertEqual(pkey(f'{t}1', composite=True), ('b',)) + self.assertEqual(pkey(f'{t}2'), 'd') + self.assertEqual(pkey(f'{t}2', composite=True), ('d',)) + r = pkey(f'{t}3') self.assertIsInstance(r, tuple) self.assertEqual(r, ('f', 'h')) - r = pkey('%s3' % t, composite=False) + r = pkey(f'{t}3', composite=False) self.assertIsInstance(r, tuple) self.assertEqual(r, ('f', 'h')) - r = pkey('%s4' % t) + r = pkey(f'{t}4') self.assertIsInstance(r, tuple) self.assertEqual(r, ('h', 'f')) - self.assertEqual(pkey('%s5' % t), 'more_than_one_letter') - self.assertEqual(pkey('%s6' % t), 'with space') - r = pkey('%s7' % t) + self.assertEqual(pkey(f'{t}5'), 'more_than_one_letter') + self.assertEqual(pkey(f'{t}6'), 'with space') + r = pkey(f'{t}7') self.assertIsInstance(r, tuple) self.assertEqual(r, ( 'a_very_long_column_name', 'with space', '42')) # a newly added primary key will be detected - query('alter table "%s0" add primary key (a)' % t) - self.assertEqual(pkey('%s0' % t), 'a') + query(f'alter table "{t}0" add primary key (a)') + self.assertEqual(pkey(f'{t}0'), 'a') # a changed primary key will not be detected, # indicating that the internal cache is operating - query('alter table "%s1" rename column b to x' % t) - self.assertEqual(pkey('%s1' % t), 'b') + query(f'alter table "{t}1" rename column b to x') + self.assertEqual(pkey(f'{t}1'), 'b') # we get the changed primary key when the cache is flushed - self.assertEqual(pkey('%s1' % t, flush=True), 'x') + self.assertEqual(pkey(f'{t}1', flush=True), 'x') def testGetDatabases(self): databases = self.db.get_databases() @@ -1197,7 +1197,7 @@ def testGetTables(self): 'averyveryveryveryveryveryveryreallyreallylongtablename', 'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z') for t in tables: - self.db.query('drop table if exists "%s" cascade' % t) + self.db.query(f'drop table if exists "{t}" cascade') before_tables = get_tables() self.assertIsInstance(before_tables, list) for t in before_tables: @@ -1212,8 +1212,8 @@ def testGetTables(self): self.createTable(t, 'as select 0', temporary=False) current_tables = get_tables() new_tables = [t for t in current_tables if t not in before_tables] - expected_new_tables = ['public.%s' % ( - '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables] + expected_new_tables = ['public.' + ( + f'"{t}"' if ' ' in t or t != t.lower() else t) for t in tables] self.assertEqual(new_tables, expected_new_tables) self.doCleanups() after_tables = get_tables() @@ -1513,8 +1513,8 @@ def testGetGeneratedIsCached(self): table = 'test_get_generated_2' self.createTable(table, 'i int primary key') self.assertFalse(get_generated(table)) - query('alter table %s alter column i' - ' add generated always as identity' % table) + query(f'alter table {table} alter column i' + ' add generated always as identity') self.assertFalse(get_generated(table)) self.assertEqual(get_generated(table, flush=True), {'i'}) @@ -1573,8 +1573,8 @@ def testGet(self): r = get(table, s, ('n', 't')) self.assertIs(r, s) self.assertEqual(r, dict(n=1, t='x')) - query('alter table "%s" alter n set not null' % table) - query('alter table "%s" add primary key (n)' % table) + query(f'alter table "{table}" alter n set not null') + query(f'alter table "{table}" add primary key (n)') r = get(table, 2) self.assertIsInstance(r, dict) self.assertEqual(r, dict(n=2, t='y')) @@ -1605,7 +1605,7 @@ def testGetWithOids(self): self.assertRaises(pg.ProgrammingError, get, table, 2) self.assertRaises(KeyError, get, table, {}, 'oid') r = get(table, 2, 'n') - qoid = 'oid(%s)' % table + qoid = f'oid({table})' self.assertIn(qoid, r) oid = r[qoid] self.assertIsInstance(oid, int) @@ -1632,8 +1632,8 @@ def testGetWithOids(self): self.assertEqual(get(table, r, 'n')['t'], 'z') self.assertEqual(get(table, 1, 'n')['t'], 'x') self.assertEqual(get(table, r, 'oid')['t'], 'z') - query('alter table "%s" alter n set not null' % table) - query('alter table "%s" add primary key (n)' % table) + query(f'alter table "{table}" alter n set not null') + query(f'alter table "{table}" add primary key (n)') self.assertEqual(get(table, 3)['t'], 'z') self.assertEqual(get(table, 1)['t'], 'x') self.assertEqual(get(table, 2)['t'], 'y') @@ -1836,10 +1836,10 @@ def testInsert(self): ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S') expect['ts'] = ts self.assertEqual(data, expect) - data = query('select * from "%s"' % table).dictresult()[0] + data = query(f'select * from "{table}"').dictresult()[0] data = dict(item for item in data.items() if item[0] in expect) self.assertEqual(data, expect) - query('truncate table "%s"' % table) + query(f'truncate table "{table}"') def testInsertWithOids(self): if not self.oids: @@ -1923,7 +1923,7 @@ def testInsertWithQuotedNames(self): self.assertEqual(r['Prime!'], 11) self.assertEqual(r['much space'], 2002) self.assertEqual(r['Questions?'], 'What?') - r = query('select * from "%s" limit 2' % table).dictresult() + r = query(f'select * from "{table}" limit 2').dictresult() self.assertEqual(len(r), 1) r = r[0] self.assertEqual(r['Prime!'], 11) @@ -1995,7 +1995,7 @@ def testUpdate(self): r['t'] = 'u' s = update(table, r) self.assertEqual(s, r) - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'u') @@ -2091,7 +2091,7 @@ def testUpdateWithoutOid(self): r['t'] = 'u' s = update(table, r) self.assertEqual(s, r) - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'u') @@ -2107,20 +2107,20 @@ def testUpdateWithCompositeKey(self): self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertEqual(r['t'], 'd') - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'd') s.update(dict(n=4, t='e')) r = update(table, s) self.assertEqual(r['n'], 4) self.assertEqual(r['t'], 'e') - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'd') - q = 'select t from "%s" where n=4' % table + q = f'select t from "{table}" where n=4' r = query(q).getresult() self.assertEqual(len(r), 0) - query('drop table "%s"' % table) + query(f'drop table "{table}"') table = 'update_test_table_2' self.createTable(table, 'n integer, m integer, t text, primary key (n, m)', @@ -2129,7 +2129,7 @@ def testUpdateWithCompositeKey(self): self.assertRaises(KeyError, update, table, dict(n=2, t='b')) self.assertEqual(update(table, dict(n=2, m=2, t='x'))['t'], 'x') - q = 'select t from "%s" where n=2 order by m' % table + q = f'select t from "{table}" where n=2 order by m' r = [r[0] for r in query(q).getresult()] self.assertEqual(r, ['c', 'x']) @@ -2146,7 +2146,7 @@ def testUpdateWithQuotedNames(self): self.assertEqual(r['Prime!'], 13) self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') - r = query('select * from "%s" limit 2' % table).dictresult() + r = query(f'select * from "{table}" limit 2').dictresult() self.assertEqual(len(r), 1) r = r[0] self.assertEqual(r['Prime!'], 13) @@ -2173,7 +2173,7 @@ def testUpdateWithGeneratedColumns(self): self.createTable(table, table_def) i, d = 35, 1001 j = i + 7 - r = query('insert into %s (i, d) values (%d, %d)' % (table, i, d)) + r = query(f'insert into {table} (i, d) values ({i}, {d})') self.assertEqual(r, '1') r = get(table, d) self.assertIsInstance(r, dict) @@ -2202,7 +2202,7 @@ def testUpsert(self): self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertEqual(r['t'], 'y') - q = 'select n, t from "%s" order by n limit 3' % table + q = f'select n, t from "{table}" order by n limit 3' r = query(q).getresult() self.assertEqual(r, [(1, 'x'), (2, 'y')]) s.update(t='z') @@ -2357,7 +2357,7 @@ def testUpsertWithCompositeKey(self): self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 3) self.assertEqual(r['t'], 'y') - q = 'select n, m, t from "%s" order by n, m limit 3' % table + q = f'select n, m, t from "{table}" order by n, m limit 3' r = query(q).getresult() self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')]) s.update(t='z') @@ -2413,7 +2413,7 @@ def testUpsertWithQuotedNames(self): self.assertEqual(r['Prime!'], 31) self.assertEqual(r['much space'], 9009) self.assertEqual(r['Questions?'], 'Yes.') - q = 'select * from "%s" limit 2' % table + q = f'select * from "{table}" limit 2' r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'Yes.')]) s.update({'Questions?': 'No.'}) @@ -2506,7 +2506,7 @@ def testDelete(self): self.assertEqual(s, 1) s = delete(table, r) self.assertEqual(s, 0) - r = query('select * from "%s"' % table).dictresult() + r = query(f'select * from "{table}"').dictresult() self.assertEqual(len(r), 1) r = r[0] result = {'n': 2, 't': 'y'} @@ -2574,7 +2574,7 @@ def testDeleteWithOids(self): self.assertIn('m', self.db.get_attnames('test_table', flush=True)) self.assertEqual('n', self.db.pkey('test_table', flush=True)) for i in range(5): - query("insert into test_table values (%d, %d)" % (i + 1, i + 2)) + query(f"insert into test_table values ({i + 1}, {i + 2})") s = dict(m=2) self.assertRaises(KeyError, delete, 'test_table', s) s = dict(m=2, oid=oid) @@ -2625,10 +2625,10 @@ def testDeleteWithCompositeKey(self): values=enumerate('abc', start=1)) self.assertRaises(KeyError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) - r = query('select t from "%s" where n=2' % table).getresult() + r = query(f'select t from "{table}" where n=2').getresult() self.assertEqual(r, []) self.assertEqual(self.db.delete(table, dict(n=2)), 0) - r = query('select t from "%s" where n=3' % table).getresult()[0][0] + r = query(f'select t from "{table}" where n=3').getresult()[0][0] self.assertEqual(r, 'c') table = 'delete_test_table_2' self.createTable( @@ -2637,16 +2637,16 @@ def testDeleteWithCompositeKey(self): for n in range(3) for m in range(2)]) self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b')) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1) - r = [r[0] for r in query('select t from "%s" where n=2' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=2' + ' order by m').getresult()] self.assertEqual(r, ['c']) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0) - r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=3' + ' order by m').getresult()] self.assertEqual(r, ['e', 'f']) self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1) - r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=3' + f' order by m').getresult()] self.assertEqual(r, ['f']) def testDeleteWithQuotedNames(self): @@ -2660,12 +2660,12 @@ def testDeleteWithQuotedNames(self): r = {'Prime!': 17} r = delete(table, r) self.assertEqual(r, 0) - r = query('select count(*) from "%s"' % table).getresult() + r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 1) r = {'Prime!': 19} r = delete(table, r) self.assertEqual(r, 1) - r = query('select count(*) from "%s"' % table).getresult() + r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 0) def testDeleteReferenced(self): @@ -2718,7 +2718,7 @@ def testTempCrud(self): r = self.db.get(table, 2) self.assertEqual(r['t'], 'two') self.db.delete(table, r) - r = self.db.query('select n, t from %s order by 1' % table).getresult() + r = self.db.query(f'select n, t from {table} order by 1').getresult() self.assertEqual(r, [(1, 'one'), (3, 'three')]) def testTruncate(self): @@ -2798,16 +2798,16 @@ def testTruncateCascade(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) for n in range(3): - query("insert into test_parent (n) values (%d)" % n) - query("insert into test_child (n) values (%d)" % n) + query(f"insert into test_parent (n) values ({n})") + query(f"insert into test_child (n) values ({n})") r = query(q).getresult()[0] self.assertEqual(r, (3, 3)) truncate('test_parent', cascade=True) r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) for n in range(3): - query("insert into test_parent (n) values (%d)" % n) - query("insert into test_child (n) values (%d)" % n) + query(f"insert into test_parent (n) values ({n})") + query(f"insert into test_child (n) values ({n})") r = query(q).getresult()[0] self.assertEqual(r, (3, 3)) truncate('test_child') @@ -2859,8 +2859,8 @@ def testTruncateOnly(self): self.createTable('test_child_2', 'm smallint) inherits (test_parent_2') for t in '', '_2': for n in range(3): - query("insert into test_parent%s (n) values (1)" % t) - query("insert into test_child%s (n, m) values (2, 3)" % t) + query(f"insert into test_parent{t} (n) values (1)") + query(f"insert into test_child{t} (n, m) values (2, 3)") q = ("select (select count(*) from test_parent)," " (select count(*) from test_child)," " (select count(*) from test_parent_2)," @@ -2883,17 +2883,17 @@ def testTruncateQuoted(self): query = self.db.query table = "test table for truncate()" self.createTable(table, 'n smallint', temporary=False, values=[1] * 3) - q = 'select count(*) from "%s"' % table + q = f'select count(*) from "{table}"' r = query(q).getresult()[0][0] self.assertEqual(r, 3) truncate(table) r = query(q).getresult()[0][0] self.assertEqual(r, 0) for i in range(3): - query('insert into "%s" values (1)' % table) + query(f'insert into "{table}" values (1)') r = query(q).getresult()[0][0] self.assertEqual(r, 3) - truncate('public."%s"' % table) + truncate(f'public."{table}"') r = query(q).getresult()[0][0] self.assertEqual(r, 0) @@ -2975,10 +2975,10 @@ def testGetAsList(self): r = get_as_list(table, what='name', limit=1, scalar=True) self.assertIsInstance(r, list) self.assertEqual(r, expected[:1]) - query('alter table "%s" drop constraint "%s_pkey"' % (table, table)) + query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) names.insert(1, (1, 'Snowball')) - query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball')) + query(f'insert into "{table}" values ($1, $2)', (1, 'Snowball')) r = get_as_list(table) self.assertIsInstance(r, list) self.assertEqual(r, names) @@ -2990,7 +2990,7 @@ def testGetAsList(self): self.assertIsInstance(r, list) self.assertEqual(set(r), set(names)) # test with arbitrary from clause - from_table = '(select lower(name) as n2 from "%s") as t2' % table + from_table = f'(select lower(name) as n2 from "{table}") as t2' r = get_as_list(from_table) self.assertIsInstance(r, list) r = {row[0] for row in r} @@ -3157,7 +3157,7 @@ def testGetAsDict(self): self.assertEqual(r, expected) self.assertNotIsInstance(self, OrderedDict) # test with arbitrary from clause - from_table = '(select id, lower(name) as n2 from "%s") as t2' % table + from_table = f'(select id, lower(name) as n2 from "{table}") as t2' # primary key must be passed explicitly in this case self.assertRaises(pg.ProgrammingError, get_as_dict, from_table) r = get_as_dict(from_table, 'id') @@ -3165,7 +3165,7 @@ def testGetAsDict(self): expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors) self.assertEqual(r, expected) # test without a primary key - query('alter table "%s" drop constraint "%s_pkey"' % (table, table)) + query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) self.assertRaises(pg.ProgrammingError, get_as_dict, table) r = get_as_dict(table, keyname='id') @@ -3173,7 +3173,7 @@ def testGetAsDict(self): self.assertIsInstance(r, dict) self.assertEqual(r, expected) r = (1, '#007fff', 'Azure') - query('insert into "%s" values ($1, $2, $3)' % table, r) + query(f'insert into "{table}" values ($1, $2, $3)', r) # the last entry will win expected[1] = r[1:] r = get_as_dict(table, keyname='id') @@ -3971,7 +3971,7 @@ def testTimetz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): - tz = '%+03d00' % timezones[timezone] + tz = f'{timezones[timezone]:+03d}00' tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) d = time(15, 9, 26, tzinfo=tzinfo) @@ -4023,7 +4023,7 @@ def testTimestamptz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): - tz = '%+03d00' % timezones[timezone] + tz = f'{timezones[timezone]:+03d}00' tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', @@ -4174,7 +4174,7 @@ def testDbTypesTypecast(self): self.assertIs(dbtypes.get_typecast('int4'), int) self.assertNotIn('circle', dbtypes) self.assertIsNone(dbtypes.get_typecast('circle')) - squared_circle = lambda v: 'Squared Circle: %s' % v # noqa: E731 + squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 dbtypes.set_typecast('circle', squared_circle) self.assertIs(dbtypes.get_typecast('circle'), squared_circle) r = self.db.query("select '0,0,1'::circle").getresult()[0][0] @@ -4199,7 +4199,7 @@ def testGetSetTypeCast(self): self.assertIs(get_typecast('bool'), pg.cast_bool) cast_circle = get_typecast('circle') self.addCleanup(set_typecast, 'circle', cast_circle) - squared_circle = lambda v: 'Squared Circle: %s' % v # noqa: E731 + squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 self.assertNotIn('circle', dbtypes) set_typecast('circle', squared_circle) self.assertNotIn('circle', dbtypes) @@ -4698,23 +4698,23 @@ def setUpClass(cls): query = db.query for num_schema in range(5): if num_schema: - schema = "s%d" % num_schema - query("drop schema if exists %s cascade" % (schema,)) + schema = f"s{num_schema}" + query(f"drop schema if exists {schema} cascade") try: - query("create schema %s" % (schema,)) + query(f"create schema {schema}") except pg.ProgrammingError: raise RuntimeError( "The test user cannot create schemas.\n" - "Grant create on database %s to the user" - " for running these tests." % dbname) + f"Grant create on database {dbname} to the user" + " for running these tests.") else: schema = "public" - query("drop table if exists %s.t" % (schema,)) - query("drop table if exists %s.t%d" % (schema, num_schema)) - query("create table %s.t %s as select 1 as n, %d as d" - % (schema, cls.with_oids, num_schema)) - query("create table %s.t%d %s as select 1 as n, %d as d" - % (schema, num_schema, cls.with_oids, num_schema)) + query(f"drop table if exists {schema}.t") + query(f"drop table if exists {schema}.t{num_schema}") + query(f"create table {schema}.t {cls.with_oids}" + f" as select 1 as n, {num_schema} as d") + query(f"create table {schema}.t{num_schema} {cls.with_oids}" + f" as select 1 as n, {num_schema} as d") db.close() cls.cls_set_up = True @@ -4724,12 +4724,12 @@ def tearDownClass(cls): query = db.query for num_schema in range(5): if num_schema: - schema = "s%d" % num_schema - query("drop schema %s cascade" % (schema,)) + schema = f"s{num_schema}" + query(f"drop schema {schema} cascade") else: schema = "public" - query("drop table %s.t" % (schema,)) - query("drop table %s.t%d" % (schema, num_schema)) + query(f"drop table {schema}.t") + query(f"drop table {schema}.t{num_schema}") db.close() def setUp(self): @@ -4763,7 +4763,7 @@ def testGetAttnames(self): self.assertEqual(r, result) query("drop table if exists s3.t3m") self.addCleanup(query, "drop table s3.t3m") - query("create table s3.t3m %s as select 1 as m" % (self.with_oids,)) + query(f"create table s3.t3m {self.with_oids} as select 1 as m") result_m = {'m': 'int'} if self.with_oids: result_m['oid'] = 'int' @@ -4824,7 +4824,7 @@ def testQueryInformationSchema(self): q = "column_name" if self.db.server_version < 110000: q += "::text" # old version does not have sql_identifier array - q = "select array_agg(%s) from information_schema.columns" % q + q = f"select array_agg({q}) from information_schema.columns" q += " where table_schema in ('s1', 's2', 's3', 's4')" r = self.db.query(q).onescalar() self.assertIsInstance(r, list) diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 282ec6df..adddc8ce 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -278,7 +278,7 @@ def testParserNested(self): def testParserTooDeeplyNested(self): f = pg.cast_array for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '%sa,b,c%s' % ('{' * n, '}' * n) + r = '{' * n + 'a,b,c' + '}' * n if n > 16: # hard coded maximum depth self.assertRaises(ValueError, f, r) else: @@ -302,7 +302,7 @@ def testParserCast(self): self.assertEqual(f('{a}', str), ['a']) def cast(s): - return '%s is ok' % s + return f'{s} is ok' self.assertEqual(f('{a}', cast), ['a is ok']) def testParserDelim(self): @@ -528,7 +528,8 @@ def testParserNested(self): def testParserManyElements(self): f = pg.cast_record for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '(%s)' % ','.join(map(str, range(n))) + r = ','.join(map(str, range(n))) + r = f'({r})' r = f(r, int) self.assertEqual(r, tuple(range(n))) @@ -544,7 +545,7 @@ def testParserCastUniform(self): self.assertEqual(f('(a)', str), ('a',)) def cast(s): - return '%s is ok' % s + return f'{s} is ok' self.assertEqual(f('(a)', cast), ('a is ok',)) def testParserCastNonUniform(self): @@ -571,11 +572,11 @@ def testParserCastNonUniform(self): (1, 'a', 2, 'b', 3, 'c')) def cast1(s): - return '%s is ok' % s + return f'{s} is ok' self.assertEqual(f('(a)', [cast1]), ('a is ok',)) def cast2(s): - return 'and %s is ok, too' % s + return f'and {s} is ok, too' self.assertEqual( f('(a,b)', [cast1, cast2]), ('a is ok', 'and b is ok, too')) self.assertRaises(ValueError, f, '(a)', [cast1, cast2]) @@ -870,9 +871,9 @@ def testEscapeString(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') + self.assertEqual(r, 'plain') r = f("that's cheese") self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") @@ -882,9 +883,9 @@ def testEscapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, str) - self.assertEqual(r, u'plain') + self.assertEqual(r, 'plain') r = f("that's cheese") self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") @@ -894,18 +895,18 @@ def testUnescapeBytea(self): r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf-8')) - r = f(u"das is' k\\303\\244se") + self.assertEqual(r, "das is' käse".encode('utf-8')) + r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf-8')) + self.assertEqual(r, "das is' käse".encode('utf-8')) r = f(b'O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') - r = f(u'O\\000ps\\377!') + r = f('O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 3271686c..bdf3a613 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -38,7 +38,7 @@ def testLargeObjectIntConstants(self): try: value = getattr(pg, name) except AttributeError: - self.fail('Module constant %s is missing' % name) + self.fail(f'Module constant {name} is missing') self.assertIsInstance(value, int) @@ -187,10 +187,10 @@ def testStr(self): self.obj.write(data) oid = self.obj.oid r = str(self.obj) - self.assertEqual(r, 'Opened large object, oid %d' % oid) + self.assertEqual(r, f'Opened large object, oid {oid}') self.obj.close() r = str(self.obj) - self.assertEqual(r, 'Closed large object, oid %d' % oid) + self.assertEqual(r, f'Closed large object, oid {oid}') def testRepr(self): r = repr(self.obj) @@ -260,22 +260,22 @@ def testWrite(self): def testWriteLatin1Bytes(self): read = self.obj.read self.obj.open(pg.INV_WRITE) - self.obj.write(u'käse'.encode('latin1')) + self.obj.write('käse'.encode('latin1')) self.obj.close() self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('latin1'), u'käse') + self.assertEqual(r.decode('latin1'), 'käse') def testWriteUtf8Bytes(self): read = self.obj.read self.obj.open(pg.INV_WRITE) - self.obj.write(u'käse'.encode('utf8')) + self.obj.write('käse'.encode('utf8')) self.obj.close() self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), u'käse') + self.assertEqual(r.decode('utf8'), 'käse') def testWriteUtf8String(self): read = self.obj.read @@ -285,7 +285,7 @@ def testWriteUtf8String(self): self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), u'käse') + self.assertEqual(r.decode('utf8'), 'käse') def testSeek(self): seek = self.obj.seek @@ -367,7 +367,7 @@ def testUnlinkInexistent(self): unlink = self.obj.unlink self.obj.open(pg.INV_WRITE) self.obj.close() - self.pgcnx.query('select lo_unlink(%d)' % self.obj.oid) + self.pgcnx.query(f'select lo_unlink({self.obj.oid})') self.assertRaises(IOError, unlink) def testSize(self): @@ -446,7 +446,7 @@ def testExportInExistent(self): f = tempfile.NamedTemporaryFile() self.obj.open(pg.INV_WRITE) self.obj.close() - self.pgcnx.query('select lo_unlink(%d)' % self.obj.oid) + self.pgcnx.query(f'select lo_unlink({self.obj.oid})') self.assertRaises(IOError, export, f.name) f.close() diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 6f94cebd..dcc06382 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -88,7 +88,7 @@ def get_handler(self, event=None, arg_dict=None, stop_event=None): handler = self.db.notification_handler( event, callback, arg_dict, 0, stop_event) self.assertEqual(handler.event, event) - self.assertEqual(handler.stop_event, stop_event or 'stop_%s' % event) + self.assertEqual(handler.stop_event, stop_event or f'stop_{event}') self.assertIs(handler.callback, callback) if arg_dict is None: self.assertEqual(handler.arg_dict, {}) @@ -224,7 +224,7 @@ def start_handler(self, event=None, arg_dict=None, self.handler = handler self.assertIsInstance(handler, pg.NotificationHandler) self.assertEqual(handler.event, event) - self.assertEqual(handler.stop_event, stop_event or 'stop_%s' % event) + self.assertEqual(handler.stop_event, stop_event or f'stop_{event}') self.event = handler.event self.assertIs(handler.callback, callback) if arg_dict is None: @@ -277,9 +277,9 @@ def notify_query(self, stop=False, payload=None): if stop: event = self.handler.stop_event self.stopped = True - q = 'notify "%s"' % event + q = f'notify "{event}"' if payload: - q += ", '%s'" % payload + q += f", '{payload}'" arg_dict = self.arg_dict.copy() arg_dict.update(event=event, pid=1, extra=payload or '') self.db.query(q) @@ -370,14 +370,14 @@ def testNotifyQuotedNames(self): def testNotifyWithFivePayloads(self): self.start_handler('gimme_5', {'test': 'Gimme 5'}) for count in range(5): - self.notify_query(payload="Round %d" % count) + self.notify_query(payload=f"Round {count}") self.assertEqual(len(self.sent), 5) self.receive(stop=True) def testReceiveImmediately(self): self.start_handler('immediate', {'test': 'immediate'}) for count in range(3): - self.notify_query(payload="Round %d" % count) + self.notify_query(payload=f"Round {count}") self.receive() self.receive(stop=True) @@ -385,7 +385,7 @@ def testNotifyDistinctInTransaction(self): self.start_handler('test_transaction', {'transaction': True}) self.db.begin() for count in range(3): - self.notify_query(payload='Round %d' % count) + self.notify_query(payload=f'Round {count}') self.db.commit() self.receive(stop=True) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 01a89247..6062a4fa 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -32,7 +32,7 @@ class test_PyGreSQL(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () connect_kw_args = { - 'database': dbname, 'host': '%s:%d' % (dbhost or '', dbport or -1), + 'database': dbname, 'host': f"{dbhost or ''}:{dbport or -1}", 'user': dbuser, 'password': dbpasswd} lower_func = 'lower' # For stored procedure test @@ -164,7 +164,7 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): def row_factory(self, row): - return {'column %s' % desc[0]: value + return {f'column {desc[0]}': value for desc, value in zip(self.description, row)} con = self._connect() @@ -306,7 +306,7 @@ def test_description_fields(self): self.assertIsInstance(d, tuple) self.assertEqual(len(d), 7) self.assertIsInstance(d.name, str) - self.assertEqual(d.name, 'col%d' % i) + self.assertEqual(d.name, f'col{i}') self.assertIsInstance(d.type_code, str) self.assertEqual(d.type_code, c[0]) self.assertIsNone(d.display_size) @@ -382,7 +382,7 @@ def test_type_cache_typecast(self): cur = con.cursor() type_cache = con.type_cache self.assertIs(type_cache.get_typecast('int4'), int) - cast_int = lambda v: 'int(%s)' % v # noqa: E731 + cast_int = lambda v: f'int({v})' # noqa: E731 type_cache.set_typecast('int4', cast_int) query = 'select 2::int2, 4::int4, 8::int8' cur.execute(query) @@ -454,7 +454,7 @@ def test_fetch_2_rows(self): cur = con.cursor() cur.execute("set datestyle to iso") cur.execute( - "create table %s (" + f"create table {table} (" "stringtest varchar," "binarytest bytea," "booltest bool," @@ -467,16 +467,16 @@ def test_fetch_2_rows(self): "timetest time," "datetimetest timestamp," "intervaltest interval," - "rowidtest oid)" % table) + "rowidtest oid)") cur.execute("set standard_conforming_strings to on") for s in ('numeric', 'monetary', 'time'): - cur.execute("set lc_%s to 'C'" % s) + cur.execute(f"set lc_{s} to 'C'") for _i in range(2): cur.execute( - "insert into %s values (" - "%%s,%%s,%%s,%%s,%%s,%%s,%%s," - "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values) - cur.execute("select * from %s" % table) + f"insert into {table} values (" + "%s,%s,%s,%s,%s,%s,%s," + "'%s'::money,%s,%s,%s,%s,%s)", values) + cur.execute(f"select * from {table}") rows = cur.fetchall() self.assertEqual(len(rows), 2) row0 = rows[0] @@ -503,12 +503,12 @@ def test_integrity_error(self): try: cur = con.cursor() cur.execute("set client_min_messages = warning") - cur.execute("create table %s (i int primary key)" % table) - cur.execute("insert into %s values (1)" % table) - cur.execute("insert into %s values (2)" % table) + cur.execute(f"create table {table} (i int primary key)") + cur.execute(f"insert into {table} values (1)") + cur.execute(f"insert into {table} values (2)") self.assertRaises( pgdb.IntegrityError, cur.execute, - "insert into %s values (1)" % table) + f"insert into {table} values (1)") finally: con.close() @@ -517,11 +517,11 @@ def test_update_rowcount(self): con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (i int)" % table) - cur.execute("insert into %s values (1)" % table) - cur.execute("update %s set i=2 where i=2 returning i" % table) + cur.execute(f"create table {table} (i int)") + cur.execute(f"insert into {table} values (1)") + cur.execute(f"update {table} set i=2 where i=2 returning i") self.assertEqual(cur.rowcount, 0) - cur.execute("update %s set i=2 where i=1 returning i" % table) + cur.execute(f"update {table} set i=2 where i=1 returning i") self.assertEqual(cur.rowcount, 1) cur.close() # keep rowcount even if cursor is closed (needed by SQLAlchemy) @@ -552,10 +552,10 @@ def test_float(self): try: cur = con.cursor() cur.execute( - "create table %s (n smallint, floattest float)" % table) + f"create table {table} (n smallint, floattest float)") params = enumerate(values) - cur.executemany("insert into %s values (%%d,%%s)" % table, params) - cur.execute("select floattest from %s order by n" % table) + cur.executemany(f"insert into {table} values (%d,%s)", params) + cur.execute(f"select floattest from {table} order by n") rows = cur.fetchall() self.assertEqual(cur.description[0].type_code, pgdb.FLOAT) self.assertNotEqual(cur.description[0].type_code, pgdb.ARRAY) @@ -589,9 +589,9 @@ def test_datetime(self): try: cur = con.cursor() cur.execute("set timezone = UTC") - cur.execute("create table %s (" + cur.execute(f"create table {table} (" "d date, t time, ts timestamp," - "tz timetz, tsz timestamptz)" % table) + "tz timetz, tsz timestamptz)") for n in range(3): values = [dt.date(), dt.time(), dt, dt.time(), dt] values[3] = values[3].replace(tzinfo=timezone.utc) @@ -609,16 +609,16 @@ def test_datetime(self): pgdb.Timestamp(*(d + t + z))] for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy', 'sql, mdy', 'sql, dmy', 'german'): - cur.execute("set datestyle to %s" % datestyle) + cur.execute(f"set datestyle to {datestyle}") if n != 1: # noinspection PyUnboundLocalVariable cur.execute("select %s,%s,%s,%s,%s", params) row = cur.fetchone() self.assertEqual(row, tuple(values)) cur.execute( - "insert into %s" - " values (%%s,%%s,%%s,%%s,%%s)" % table, params) - cur.execute("select * from %s" % table) + f"insert into {table}" + " values (%s,%s,%s,%s,%s)", params) + cur.execute(f"select * from {table}") d = cur.description for i in range(5): self.assertEqual(d[i].type_code, pgdb.DATETIME) @@ -632,7 +632,7 @@ def test_datetime(self): self.assertEqual(d[4].type_code, pgdb.TIMESTAMP) row = cur.fetchone() self.assertEqual(row, tuple(values)) - cur.execute("truncate table %s" % table) + cur.execute(f"truncate table {table}") finally: con.close() @@ -642,23 +642,22 @@ def test_interval(self): con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (i interval)" % table) + cur.execute(f"create table {table} (i interval)") for n in range(3): if n == 0: # input as objects param = td if n == 1: # input as text - param = '%d days %d seconds %d microseconds ' % ( - td.days, td.seconds, td.microseconds) + param = (f'{td.days} days {td.seconds} seconds' + f' {td.microseconds} microseconds') elif n == 2: # input using type helpers param = pgdb.Interval( td.days, 0, 0, td.seconds, td.microseconds) for intervalstyle in ('sql_standard ', 'postgres', 'postgres_verbose', 'iso_8601'): - cur.execute("set intervalstyle to %s" % intervalstyle) + cur.execute(f"set intervalstyle to {intervalstyle}") # noinspection PyUnboundLocalVariable - cur.execute("insert into %s" - " values (%%s)" % table, [param]) - cur.execute("select * from %s" % table) + cur.execute(f"insert into {table} values (%s)", [param]) + cur.execute(f"select * from {table}") tc = cur.description[0].type_code self.assertEqual(tc, pgdb.DATETIME) self.assertNotEqual(tc, pgdb.STRING) @@ -667,7 +666,7 @@ def test_interval(self): self.assertEqual(tc, pgdb.INTERVAL) row = cur.fetchone() self.assertEqual(row, (td,)) - cur.execute("truncate table %s" % table) + cur.execute(f"truncate table {table}") finally: con.close() @@ -721,15 +720,15 @@ def test_insert_array(self): con = self._connect() try: cur = con.cursor() - cur.execute("create table %s" - " (n smallint, i int[], t text[][])" % table) + cur.execute( + f"create table {table} (n smallint, i int[], t text[][])") params = [(n, v[0], v[1]) for n, v in enumerate(values)] # Note that we must explicit casts because we are inserting # empty arrays. Otherwise this is not necessary. cur.executemany( - "insert into %s values" - " (%%d,%%s::int[],%%s::text[][])" % table, params) - cur.execute("select i, t from %s order by n" % table) + f"insert into {table} values" + " (%d,%s::int[],%s::text[][])", params) + cur.execute(f"select i, t from {table} order by n") d = cur.description self.assertEqual(d[0].type_code, pgdb.ARRAY) self.assertNotEqual(d[0].type_code, pgdb.RECORD) @@ -755,7 +754,7 @@ def test_select_array(self): self.assertEqual(row, values) def test_unicode_list_and_tuple(self): - value = (u'Käse', u'Würstchen') + value = ('Käse', 'Würstchen') con = self._connect() try: cur = con.cursor() @@ -780,11 +779,11 @@ def test_insert_record(self): con = self._connect() cur = con.cursor() try: - cur.execute("create type %s as (name varchar, age int)" % record) - cur.execute("create table %s (n smallint, r %s)" % (table, record)) + cur.execute(f"create type {record} as (name varchar, age int)") + cur.execute(f"create table {table} (n smallint, r {record})") params = enumerate(values) - cur.executemany("insert into %s values (%%d,%%s)" % table, params) - cur.execute("select r from %s order by n" % table) + cur.executemany(f"insert into {table} values (%d,%s)", params) + cur.execute(f"select r from {table} order by n") type_code = cur.description[0].type_code self.assertEqual(type_code, record) self.assertEqual(type_code, pgdb.RECORD) @@ -796,8 +795,8 @@ def test_insert_record(self): self.assertEqual(con.type_cache[columns[1].type], 'int4') rows = cur.fetchall() finally: - cur.execute('drop table %s' % table) - cur.execute('drop type %s' % record) + cur.execute(f'drop table {table}') + cur.execute(f'drop type {record}') con.close() self.assertEqual(len(rows), len(values)) rows = [row[0] for row in rows] @@ -832,9 +831,9 @@ def test_custom_type(self): cur = con.cursor() params = enumerate(values) # params have __pg_repr__ method cur.execute( - 'create table "%s" (n smallint, b bit varying(7))' % table) - cur.executemany("insert into %s values (%%s,%%s)" % table, params) - cur.execute("select * from %s" % table) + f'create table "{table}" (n smallint, b bit varying(7))') + cur.executemany(f"insert into {table} values (%s,%s)", params) + cur.execute(f"select * from {table}") rows = cur.fetchall() finally: con.close() @@ -845,7 +844,7 @@ def test_custom_type(self): params = (1, object()) # an object that cannot be handled self.assertRaises( pgdb.InterfaceError, cur.execute, - "insert into %s values (%%s,%%s)" % table, params) + f"insert into {table} values (%s,%s)", params) finally: con.close() @@ -887,7 +886,7 @@ def test_global_typecast(self): try: query = 'select 2::int2, 4::int4, 8::int8' self.assertIs(pgdb.get_typecast('int4'), int) - cast_int = lambda v: 'int(%s)' % v # noqa: E731 + cast_int = lambda v: f'int({v})' # noqa: E731 pgdb.set_typecast('int4', cast_int) con = self._connect() try: @@ -974,23 +973,23 @@ def test_set_typecast_for_arrays(self): def test_unicode_with_utf8(self): table = self.table_prefix + 'booze' - s = u"He wes Leovenaðes sone — liðe him be Drihten" + s = "He wes Leovenaðes sone — liðe him be Drihten" con = self._connect() cur = con.cursor() try: - cur.execute("create table %s (t text)" % table) + cur.execute(f"create table {table} (t text)") try: cur.execute("set client_encoding=utf8") - cur.execute(u"select '%s'" % s) + cur.execute(f"select '{s}'") except Exception: self.skipTest("database does not support utf8") output1 = cur.fetchone()[0] - cur.execute("insert into %s values (%%s)" % table, (s,)) - cur.execute("select * from %s" % table) + cur.execute(f"insert into {table} values (%s)", (s,)) + cur.execute(f"select * from {table}") output2 = cur.fetchone()[0] - cur.execute("select t = '%s' from %s" % (s, table)) + cur.execute(f"select t = '{s}' from {table}") output3 = cur.fetchone()[0] - cur.execute("select t = %%s from %s" % table, (s,)) + cur.execute(f"select t = %s from {table}", (s,)) output4 = cur.fetchone()[0] finally: con.close() @@ -1005,23 +1004,23 @@ def test_unicode_with_utf8(self): def test_unicode_with_latin1(self): table = self.table_prefix + 'booze' - s = u"Ehrt den König seine Würde, ehret uns der Hände Fleiß." + s = "Ehrt den König seine Würde, ehret uns der Hände Fleiß." con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (t text)" % table) + cur.execute(f"create table {table} (t text)") try: cur.execute("set client_encoding=latin1") - cur.execute(u"select '%s'" % s) + cur.execute(f"select '{s}'") except Exception: self.skipTest("database does not support latin1") output1 = cur.fetchone()[0] - cur.execute("insert into %s values (%%s)" % table, (s,)) - cur.execute("select * from %s" % table) + cur.execute(f"insert into {table} values (%s)", (s,)) + cur.execute(f"select * from {table}") output2 = cur.fetchone()[0] - cur.execute("select t = '%s' from %s" % (s, table)) + cur.execute(f"select t = '{s}' from {table}") output3 = cur.fetchone()[0] - cur.execute("select t = %%s from %s" % table, (s,)) + cur.execute(f"select t = %s from {table}", (s,)) output4 = cur.fetchone()[0] finally: con.close() @@ -1040,11 +1039,10 @@ def test_bool(self): con = self._connect() try: cur = con.cursor() - cur.execute( - "create table %s (n smallint, booltest bool)" % table) + cur.execute(f"create table {table} (n smallint, booltest bool)") params = enumerate(values) - cur.executemany("insert into %s values (%%s,%%s)" % table, params) - cur.execute("select booltest from %s order by n" % table) + cur.executemany(f"insert into {table} values (%s,%s)", params) + cur.execute(f"select booltest from {table} order by n") rows = cur.fetchall() self.assertEqual(cur.description[0].type_code, pgdb.BOOL) finally: @@ -1073,12 +1071,12 @@ def test_json(self): try: cur = con.cursor() try: - cur.execute("create table %s (jsontest json)" % table) + cur.execute(f"create table {table} (jsontest json)") except pgdb.ProgrammingError: self.skipTest('database does not support json') params = (pgdb.Json(inval),) - cur.execute("insert into %s values (%%s)" % table, params) - cur.execute("select jsontest from %s" % table) + cur.execute(f"insert into {table} values (%s)", params) + cur.execute(f"select jsontest from {table}") outval = cur.fetchone()[0] self.assertEqual(cur.description[0].type_code, pgdb.JSON) finally: @@ -1093,12 +1091,12 @@ def test_jsonb(self): try: cur = con.cursor() try: - cur.execute("create table %s (jsonbtest jsonb)" % table) + cur.execute(f"create table {table} (jsonbtest jsonb)") except pgdb.ProgrammingError: self.skipTest('database does not support jsonb') params = (pgdb.Json(inval),) - cur.execute("insert into %s values (%%s)" % table, params) - cur.execute("select jsonbtest from %s" % table) + cur.execute(f"insert into {table} values (%s)", params) + cur.execute(f"select jsonbtest from {table}") outval = cur.fetchone()[0] self.assertEqual(cur.description[0].type_code, pgdb.JSON) finally: @@ -1135,8 +1133,8 @@ def test_fetchall_with_various_sizes(self): for n in (1, 3, 5, 7, 10, 100, 1000): cur = con.cursor() try: - cur.execute('select n, n::text as s, n %% 2 = 1 as b' - ' from generate_series(1, %d) as s(n)' % n) + cur.execute('select n, n::text as s, n % 2 = 1 as b' + f' from generate_series(1, {n}) as s(n)') res = cur.fetchall() self.assertEqual(len(res), n, res) self.assertEqual(len(res[0]), 3) @@ -1212,13 +1210,13 @@ def test_transaction(self): con1.commit() con2 = self._connect() cur2 = con2.cursor() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertIsNone(cur2.fetchone()) - cur1.execute("insert into %s values('Schlafly')" % table) - cur2.execute("select name from %s" % table) + cur1.execute(f"insert into {table} values('Schlafly')") + cur2.execute(f"select name from {table}") self.assertIsNone(cur2.fetchone()) con1.commit() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertEqual(cur2.fetchone(), ('Schlafly',)) con2.close() con1.close() @@ -1231,10 +1229,10 @@ def test_autocommit(self): self.executeDDL1(cur1) con2 = self._connect() cur2 = con2.cursor() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertIsNone(cur2.fetchone()) - cur1.execute("insert into %s values('Shmaltz Pastrami')" % table) - cur2.execute("select name from %s" % table) + cur1.execute(f"insert into {table} values('Shmaltz Pastrami')") + cur2.execute(f"select name from {table}") self.assertEqual(cur2.fetchone(), ('Shmaltz Pastrami',)) con2.close() con1.close() @@ -1247,32 +1245,32 @@ def test_connection_as_contextmanager(self): try: cur = con.cursor() if autocommit: - cur.execute("truncate table %s" % table) + cur.execute(f"truncate table {table}") else: cur.execute( - "create table %s (n smallint check(n!=4))" % table) + f"create table {table} (n smallint check(n!=4))") with con: - cur.execute("insert into %s values (1)" % table) - cur.execute("insert into %s values (2)" % table) + cur.execute(f"insert into {table} values (1)") + cur.execute(f"insert into {table} values (2)") try: with con: - cur.execute("insert into %s values (3)" % table) - cur.execute("insert into %s values (4)" % table) + cur.execute(f"insert into {table} values (3)") + cur.execute(f"insert into {table} values (4)") except con.IntegrityError as error: self.assertTrue('check' in str(error).lower()) with con: - cur.execute("insert into %s values (5)" % table) - cur.execute("insert into %s values (6)" % table) + cur.execute(f"insert into {table} values (5)") + cur.execute(f"insert into {table} values (6)") try: with con: - cur.execute("insert into %s values (7)" % table) - cur.execute("insert into %s values (8)" % table) + cur.execute(f"insert into {table} values (7)") + cur.execute(f"insert into {table} values (8)") raise ValueError('transaction should rollback') except ValueError as error: self.assertEqual(str(error), 'transaction should rollback') with con: - cur.execute("insert into %s values (9)" % table) - cur.execute("select * from %s order by 1" % table) + cur.execute(f"insert into {table} values (9)") + cur.execute(f"select * from {table} order by 1") rows = cur.fetchall() rows = [row[0] for row in rows] finally: diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 769065ab..d461825c 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -100,7 +100,7 @@ class TestCopy(unittest.TestCase): @staticmethod def connect(): - host = '%s:%d' % (dbhost or '', dbport or -1) + host = f"{dbhost or ''}:{dbport or -1}" return pgdb.connect(database=dbname, host=host, user=dbuser, password=dbpasswd) @@ -163,11 +163,11 @@ def tearDown(self): @property def data_text(self): - return ''.join('%d\t%s\n' % row for row in self.data) + return ''.join('{}\t{}\n'.format(*row) for row in self.data) @property def data_csv(self): - return ''.join('%d,%s\n' % row for row in self.data) + return ''.join('{},{}\n'.format(*row) for row in self.data) def truncate_table(self): self.cursor.execute("truncate table copytest") @@ -259,7 +259,7 @@ def test_input_iterable_invalid(self): self.assertRaises(IOError, self.copy_from, [None]) def test_input_iterable_with_newlines(self): - self.copy_from('%s\n' % row for row in self.data_text.splitlines()) + self.copy_from(f'{row}\n' for row in self.data_text.splitlines()) self.check_table() def test_input_iterable_bytes(self): @@ -268,7 +268,7 @@ def test_input_iterable_bytes(self): self.check_table() def test_sep(self): - stream = ('%d-%s' % row for row in self.data) + stream = ('{}-{}'.format(*row) for row in self.data) self.copy_from(stream, sep='-') self.check_table() @@ -311,7 +311,7 @@ def test_csv(self): self.check_table() def test_csv_with_sep(self): - stream = ('%d;"%s"\n' % row for row in self.data) + stream = ('{};"{}"\n'.format(*row) for row in self.data) self.copy_from(stream, format='csv', sep=';') self.check_table() self.check_rowcount() @@ -326,7 +326,7 @@ def test_binary_with_sep(self): ValueError, self.copy_from, '', format='binary', sep='\t') def test_binary_with_unicode(self): - self.assertRaises(ValueError, self.copy_from, u'', format='binary') + self.assertRaises(ValueError, self.copy_from, '', format='binary') def test_query(self): self.assertRaises(ValueError, self.cursor.copy_from, '', "select null") @@ -441,10 +441,10 @@ def test_decode(self): def test_sep(self): ret = list(self.copy_to(sep='-')) - self.assertEqual(ret, ['%d-%s\n' % row for row in self.data]) + self.assertEqual(ret, ['{}-{}\n'.format(*row) for row in self.data]) def test_null(self): - data = ['%d\t%s\n' % row for row in self.data] + data = ['{}\t{}\n'.format(*row) for row in self.data] self.cursor.execute('insert into copytest values(4, null)') try: ret = list(self.copy_to()) @@ -457,8 +457,8 @@ def test_null(self): self.cursor.execute('delete from copytest where id=4') def test_columns(self): - data_id = ''.join('%d\n' % row[0] for row in self.data) - data_name = ''.join('%s\n' % row[1] for row in self.data) + data_id = ''.join(f'{row[0]}\n' for row in self.data) + data_name = ''.join(f'{row[1]}\n' for row in self.data) ret = ''.join(self.copy_to(columns='id')) self.assertEqual(ret, data_id) ret = ''.join(self.copy_to(columns=['id'])) @@ -513,7 +513,7 @@ def test_query(self): rows = list(ret) self.assertEqual(len(rows), 1) self.assertIsInstance(rows[0], str) - self.assertEqual(rows[0], '%s!\n' % self.data[1][1]) + self.assertEqual(rows[0], f'{self.data[1][1]}!\n') self.check_rowcount(1) def test_file(self): diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 0193165a..a497914b 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -108,7 +108,7 @@ class TestDbApi20Tutorial(unittest.TestCase): def setUp(self): """Setup test tables or empty them if they already exist.""" - host = '%s:%d' % (dbhost or '', dbport or -1) + host = f"{dbhost or ''}:{dbport or -1}" con = connect(database=dbname, host=host, user=dbuser, password=dbpasswd) cur = con.cursor() From b87e87bbde0d3f8e959d86820db4e15c9bd31b57 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 13:32:33 +0200 Subject: [PATCH 033/118] Add next row extension to DBAPI 20 conformance test --- tests/dbapi20.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 798bbc49..32045fa4 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -461,6 +461,48 @@ def test_fetchone(self): finally: con.close() + def test_next(self): + """Extension for getting the next row""" + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur, 'next'): + return + + # cursor.next should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error, cur.next) + + # cursor.next should raise an Error if called after + # executing a query that cannot return rows + self.executeDDL1(cur) + self.assertRaises(self.driver.Error, cur.next) + + # cursor.next should return None if a query retrieves no rows + cur.execute(f'select name from {self.table_prefix}booze') + self.assertRaises(StopIteration, cur.next) + self.assertIn(cur.rowcount, (-1, 0)) + + # cursor.next should raise an Error if called after + # executing a query that cannot return rows + cur.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") + self.assertRaises(self.driver.Error, cur.next) + + cur.execute(f'select name from {self.table_prefix}booze') + r = cur.next() + self.assertEqual( + len(r), 1, + 'cursor.fetchone should have retrieved a single row') + self.assertEqual( + r[0], 'Victoria Bitter', + 'cursor.next retrieved incorrect data') + # cursor.next should raise StopIteration if no more rows available + self.assertRaises(StopIteration, cur.next) + self.assertIn(cur.rowcount, (-1, 1)) + finally: + con.close() + samples = [ 'Carlton Cold', 'Carlton Draft', From 324b8fc9cae976e2b8e42e69aace59ff36fb2881 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 13:53:22 +0200 Subject: [PATCH 034/118] Some more string formatting modernization --- docs/conf.py | 2 +- docs/contents/pg/adaptation.rst | 4 ++-- docs/contents/pg/connection.rst | 2 +- docs/contents/pg/db_wrapper.rst | 2 +- docs/contents/pg/module.rst | 8 ++++---- docs/contents/pgdb/adaptation.rst | 4 ++-- docs/contents/postgres/advanced.rst | 6 +++--- setup.py | 6 +++--- 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 933c4e38..f5789d29 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -68,7 +68,7 @@ html_theme = 'alabaster' html_static_path = ['_static'] -html_title = 'PyGreSQL %s' % version +html_title = f'PyGreSQL {version}' html_logo = '_static/pygresql.png' html_favicon = '_static/favicon.ico' diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst index c5d0a795..de82cbfa 100644 --- a/docs/contents/pg/adaptation.rst +++ b/docs/contents/pg/adaptation.rst @@ -231,7 +231,7 @@ our values:: ... self.price = price ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... self.name, self.supplier_id, self.price) But when we try to insert an instance of this class in the same way, we @@ -246,7 +246,7 @@ PostgreSQL by adding a "magic" method with the name ``__pg_str__``, like so:: ... ... ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... self.name, self.supplier_id, self.price) ... ... def __pg_str__(self, typ): diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 237e25a8..1adf29d1 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -157,7 +157,7 @@ Examples:: s = con1.query("begin; set transaction isolation level repeatable read;" "select pg_export_snapshot();").single() con2.query("begin; set transaction isolation level repeatable read;" - "set transaction snapshot '%s'" % (s,)) + f"set transaction snapshot '{s}'") q1 = con1.send_query("select a,b,c from x where d=e") q2 = con2.send_query("select e,f from y where g") r1 = q1.getresult() diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 5d587f97..68d33c65 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -16,7 +16,7 @@ The preferred way to use this module is as follows:: for r in db.query( # just for example "SELECT foo, bar FROM foo_bar_table WHERE foo !~ bar" ).dictresult(): - print('%(foo)s %(bar)s' % r) + print('{foo} {bar}'.format(**r)) This class can be subclassed as in this example:: diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 203ada03..2dc26d5f 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -289,8 +289,8 @@ which takes connection properties into account. Example:: name = input("Name? ") - phone = con.query("select phone from employees where name='%s'" - % escape_string(name)).getresult() + phone = con.query("select phone from employees" + f" where name='{escape_string(name)}'").singlescalar() escape_bytea -- escape binary data for use within SQL ----------------------------------------------------- @@ -315,8 +315,8 @@ which takes connection properties into account. Example:: picture = open('garfield.gif', 'rb').read() - con.query("update pictures set img='%s' where name='Garfield'" - % escape_bytea(picture)) + con.query(f"update pictures set img='{escape_bytea(picture)}'" + " where name='Garfield'") unescape_bytea -- unescape data that has been retrieved as text --------------------------------------------------------------- diff --git a/docs/contents/pgdb/adaptation.rst b/docs/contents/pgdb/adaptation.rst index ebb36e5b..ac649a21 100644 --- a/docs/contents/pgdb/adaptation.rst +++ b/docs/contents/pgdb/adaptation.rst @@ -209,7 +209,7 @@ to hold our values, like this one:: ... self.price = price ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... self.name, self.supplier_id, self.price) But when we try to insert an instance of this class in the same way, we @@ -231,7 +231,7 @@ with the name ``__pg_repr__``, like this:: ... ... ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... self.name, self.supplier_id, self.price) ... ... def __pg_repr__(self): diff --git a/docs/contents/postgres/advanced.rst b/docs/contents/postgres/advanced.rst index e3e2ab10..d7627312 100644 --- a/docs/contents/postgres/advanced.rst +++ b/docs/contents/postgres/advanced.rst @@ -27,7 +27,7 @@ all data fields from cities):: ... "'Las Vegas', 2.583E+5, 2174", ... "'Mariposa', 1200, 1953"]), ... ('capitals', [ - ... "'Sacramento',3.694E+5,30,'CA'", + ... "'Sacramento', 3.694E+5,30, 'CA'", ... "'Madison', 1.913E+5, 845, 'WI'"])] Now, let's populate the tables:: @@ -37,11 +37,11 @@ Now, let's populate the tables:: ... "'Las Vegas', 2.583E+5, 2174" ... "'Mariposa', 1200, 1953"], ... 'capitals', [ - ... "'Sacramento',3.694E+5,30,'CA'", + ... "'Sacramento', 3.694E+5,30, 'CA'", ... "'Madison', 1.913E+5, 845, 'WI'"]] >>> for table, rows in data: ... for row in rows: - ... query("INSERT INTO %s VALUES (%s)" % (table, row)) + ... query(f"INSERT INTO {table} VALUES (row)") >>> print(query("SELECT * FROM cities")) name |population|altitude -------------+----------+-------- diff --git a/setup.py b/setup.py index 456e3b5e..08d43dae 100755 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( - "Sorry, PyGreSQL %s does not support this Python version" % version) + f"Sorry, PyGreSQL {version} does not support this Python version") # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the @@ -69,12 +69,12 @@ def pg_config(s): """Retrieve information about installed version of PostgreSQL.""" - f = os.popen('pg_config --%s' % s) + f = os.popen(f'pg_config --{s}') d = f.readline().strip() if f.close() is not None: raise Exception("pg_config tool is not available.") if not d: - raise Exception("Could not get %s information." % s) + raise Exception(f"Could not get {s} information.") return d From e3398d5ec3b919de4e81988615302421893124cc Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 14:03:26 +0200 Subject: [PATCH 035/118] Add more context to errors --- docs/contents/postgres/func.rst | 2 +- pg.py | 18 +++++++++--------- pgdb.py | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/contents/postgres/func.rst b/docs/contents/postgres/func.rst index 9d0f5967..3bfcfd98 100644 --- a/docs/contents/postgres/func.rst +++ b/docs/contents/postgres/func.rst @@ -62,7 +62,7 @@ Before we create more sophisticated functions, let's populate an EMP table:: ... "'Bill', 4200, 36, 'shoe'", ... "'Ginger', 4800, 30, 'candy'"] >>> for emp in emps: - ... query("INSERT INTO EMP VALUES (%s)" % emp) + ... query(f"INSERT INTO EMP VALUES ({emp})") Every INSERT statement will return a '1' indicating that it has inserted one row into the EMP table. diff --git a/pg.py b/pg.py index 50f22425..923d2743 100644 --- a/pg.py +++ b/pg.py @@ -950,7 +950,7 @@ def __missing__(self, typ): but returns None when no special cast function exists. """ if not isinstance(typ, str): - raise TypeError('Invalid type: {typ}') + raise TypeError(f'Invalid type: {typ}') cast = self.defaults.get(typ) if cast: # store default for faster access @@ -2257,8 +2257,8 @@ def update(self, table, row=None, **kw): else: # try using the primary key try: keyname = self.pkey(table, True) - except KeyError: # the table has no primary key - raise _prg_error(f'Table {table} has no primary key') + except KeyError as e: # the table has no primary key + raise _prg_error(f'Table {table} has no primary key') from e # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') @@ -2359,8 +2359,8 @@ def upsert(self, table, row=None, **kw): names, values = ', '.join(names), ', '.join(values) try: keyname = self.pkey(table, True) - except KeyError: - raise _prg_error(f'Table {table} has no primary key') + except KeyError as e: + raise _prg_error(f'Table {table} has no primary key') from e target = ', '.join(col(k) for k in keyname) update = [] keyname = set(keyname) @@ -2444,8 +2444,8 @@ def delete(self, table, row=None, **kw): else: # try using the primary key try: keyname = self.pkey(table, True) - except KeyError: # the table has no primary key - raise _prg_error(f'Table {table} has no primary key') + except KeyError as e: # the table has no primary key + raise _prg_error(f'Table {table} has no primary key') from e # check whether all key columns have values if not set(keyname).issubset(row): raise KeyError('Missing value for primary key in row') @@ -2612,8 +2612,8 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, if not keyname: try: keyname = self.pkey(table, True) - except (KeyError, ProgrammingError): - raise _prg_error(f'Table {table} has no primary key') + except (KeyError, ProgrammingError) as e: + raise _prg_error(f'Table {table} has no primary key') from e if isinstance(keyname, str): keyname = [keyname] elif not isinstance(keyname, (list, tuple)): diff --git a/pgdb.py b/pgdb.py index 5e218b42..44b6a83e 100644 --- a/pgdb.py +++ b/pgdb.py @@ -865,9 +865,9 @@ def _quote(self, value): return f'({v})' try: # noinspection PyUnresolvedReferences value = value.__pg_repr__() - except AttributeError: + except AttributeError as e: raise InterfaceError( - f'Do not know how to adapt type {type(value)}') + f'Do not know how to adapt type {type(value)}') from e if isinstance(value, (tuple, list)): value = self._quote(value) return value @@ -965,8 +965,8 @@ def executemany(self, operation, seq_of_parameters): self._src.execute(sql) except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't start transaction") + except Exception as e: + raise _op_error("Can't start transaction") from e else: self._dbcnx._tnx = True for parameters in seq_of_parameters: @@ -983,7 +983,7 @@ def executemany(self, operation, seq_of_parameters): # noinspection PyTypeChecker raise _db_error(f"Error in '{sql}': '{err}'", InterfaceError) except Exception as err: - raise _op_error(f"Internal error in '{sql}': {err}") + raise _op_error(f"Internal error in '{sql}': {err}") from err # then initialize result raw count and description if self._src.resulttype == RESULT_DQL: self._description = True # fetch on demand @@ -1027,7 +1027,7 @@ def fetchmany(self, size=None, keep=False): except DatabaseError: raise except Error as err: - raise _db_error(str(err)) + raise _db_error(str(err)) from err row_factory = self.row_factory coltypes = self.coltypes if len(result) > 5: From 816ec354723da1571a682e3215b06e4c6a9c7b1d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 20:01:48 +0200 Subject: [PATCH 036/118] Replace flake8 with ruff and use pyproject.toml Also fix some minor issues that were detected by ruff. Remove announcements from static docs and inline about page. --- .bumpversion.cfg | 16 ---- .devcontainer/devcontainer.json | 1 + .devcontainer/provision.sh | 10 ++- .flake8 | 4 - .github/workflows/lint.yml | 2 +- docs/about.rst | 42 +++++++++- docs/about.txt | 41 ---------- docs/announce.rst | 26 ------ docs/conf.py | 2 +- docs/contents/changelog.rst | 6 ++ docs/download/index.rst | 6 +- docs/index.rst | 1 - pg.py | 131 +++++++++++++++++++++---------- pgdb.py | 63 +++++++++------ pyproject.toml | 75 ++++++++++++++++++ setup.py | 89 +++------------------ tests/config.py | 4 +- tests/dbapi20.py | 5 +- tests/test_classic.py | 24 +++--- tests/test_classic_connection.py | 31 ++++---- tests/test_classic_dbwrapper.py | 43 +++++----- tests/test_classic_functions.py | 10 +-- tests/test_classic_largeobj.py | 16 ++-- tests/test_dbapi20.py | 2 +- tests/test_dbapi20_copy.py | 29 ++++--- tests/test_tutorial.py | 4 +- tox.ini | 17 +++- 27 files changed, 372 insertions(+), 328 deletions(-) delete mode 100644 .flake8 delete mode 100644 docs/about.txt delete mode 100644 docs/announce.rst create mode 100644 pyproject.toml diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 89aec55e..769d02cf 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -12,22 +12,6 @@ serialize = search = version = '{current_version}' replace = version = '{new_version}' -[bumpversion:file (head):setup.py] -search = PyGreSQL version {current_version} -replace = PyGreSQL version {new_version} - [bumpversion:file:docs/conf.py] search = version = release = '{current_version}' replace = version = release = '{new_version}' - -[bumpversion:file:docs/about.txt] -search = PyGreSQL {current_version} -replace = PyGreSQL {new_version} - -[bumpversion:file:docs/announce.rst] -search = PyGreSQL version {current_version} -replace = PyGreSQL version {new_version} - -[bumpversion:file (text):docs/announce.rst] -search = Release {current_version} of PyGreSQL -replace = Release {new_version} of PyGreSQL diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index c1374910..b9fbaaeb 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -45,6 +45,7 @@ "njpwerner.autodocstring", "redhat.vscode-yaml", "eamodio.gitlens", + "charliermarsh.ruff", "streetsidesoftware.code-spell-checker", "lextudio.restructuredtext" ] diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index b47abb8c..a42337b8 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -27,9 +27,15 @@ sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils -# install testing tool +# install build and testing tool -sudo apt-get install -y tox +python3.7 -m pip install build +python3.8 -m pip install build +python3.9 -m pip install build +python3.10 -m pip install build +python3.11 -m pip install build + +sudo apt-get install -y tox python3-poetry # install PostgreSQL client tools diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 3f6e0a3c..00000000 --- a/.flake8 +++ /dev/null @@ -1,4 +0,0 @@ -[flake8] -ignore = F403,F405,W503 -exclude = .git,.tox,.venv,build,dist,docs -max-line-length = 79 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 54ae2fd3..40f5299e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,5 +22,5 @@ jobs: with: python-version: 3.11 - name: Run quality checks - run: tox -e flake8,docs + run: tox -e ruff,docs timeout-minutes: 5 diff --git a/docs/about.rst b/docs/about.rst index 3e61d030..8235e5cc 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -1,4 +1,44 @@ About PyGreSQL ============== -.. include:: about.txt \ No newline at end of file +**PyGreSQL** is an *open-source* `Python `_ module +that interfaces to a `PostgreSQL `_ database. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. + + | This software is copyright © 1995, Pascal Andre. + | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. + | Further modifications are copyright © 2009-2023 by the PyGreSQL team. + | For licensing details, see the full :doc:`copyright`. + +**PostgreSQL** is a highly scalable, SQL compliant, open source +object-relational database management system. With more than 20 years +of development history, it is quickly becoming the de facto database +for enterprise level open source solutions. +Best of all, PostgreSQL's source code is available under the most liberal +open source license: the BSD license. + +**Python** Python is an interpreted, interactive, object-oriented +programming language. It is often compared to Tcl, Perl, Scheme or Java. +Python combines remarkable power with very clear syntax. It has modules, +classes, exceptions, very high level dynamic data types, and dynamic typing. +There are interfaces to many system calls and libraries, as well as to +various windowing systems (X11, Motif, Tk, Mac, MFC). New built-in modules +are easily written in C or C++. Python is also usable as an extension +language for applications that need a programmable interface. +The Python implementation is copyrighted but freely usable and distributable, +even for commercial use. + +**PyGreSQL** is a Python module that interfaces to a PostgreSQL database. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. + +PyGreSQL is developed and tested on a NetBSD system, but it also runs on +most other platforms where PostgreSQL and Python is running. It is based +on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). +D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with +version 2.0 and serves as the "BDFL" of PyGreSQL. + +The current version PyGreSQL |version| needs PostgreSQL 10 to 15, and Python +3.7 to 3.11. If you need to support older PostgreSQL or Python versions, +you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/about.txt b/docs/about.txt deleted file mode 100644 index 04f615e1..00000000 --- a/docs/about.txt +++ /dev/null @@ -1,41 +0,0 @@ -**PyGreSQL** is an *open-source* `Python `_ module -that interfaces to a `PostgreSQL `_ database. -It wraps the lower level C API library libpq to allow easy use of the -powerful PostgreSQL features from Python. - - | This software is copyright © 1995, Pascal Andre. - | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2023 by the PyGreSQL team. - | For licensing details, see the full :doc:`copyright`. - -**PostgreSQL** is a highly scalable, SQL compliant, open source -object-relational database management system. With more than 20 years -of development history, it is quickly becoming the de facto database -for enterprise level open source solutions. -Best of all, PostgreSQL's source code is available under the most liberal -open source license: the BSD license. - -**Python** Python is an interpreted, interactive, object-oriented -programming language. It is often compared to Tcl, Perl, Scheme or Java. -Python combines remarkable power with very clear syntax. It has modules, -classes, exceptions, very high level dynamic data types, and dynamic typing. -There are interfaces to many system calls and libraries, as well as to -various windowing systems (X11, Motif, Tk, Mac, MFC). New built-in modules -are easily written in C or C++. Python is also usable as an extension -language for applications that need a programmable interface. -The Python implementation is copyrighted but freely usable and distributable, -even for commercial use. - -**PyGreSQL** is a Python module that interfaces to a PostgreSQL database. -It wraps the lower level C API library libpq to allow easy use of the -powerful PostgreSQL features from Python. - -PyGreSQL is developed and tested on a NetBSD system, but it also runs on -most other platforms where PostgreSQL and Python is running. It is based -on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). -D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with -version 2.0 and serves as the "BDFL" of PyGreSQL. - -The current version PyGreSQL 6.0 needs PostgreSQL 10 to 15, and Python -3.7 to 3.11. If you need to support older PostgreSQL or Python versions, -you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/announce.rst b/docs/announce.rst deleted file mode 100644 index d0a5f19c..00000000 --- a/docs/announce.rst +++ /dev/null @@ -1,26 +0,0 @@ -====================== -PyGreSQL Announcements -====================== - -------------------------------- -Release of PyGreSQL version 6.0 -------------------------------- - -Release 6.0 of PyGreSQL. - -It is available at: https://pypi.org/project/PyGreSQL/. - -If you are running NetBSD, look in the packages directory under databases. -There is also a package in the FreeBSD ports collection. - -Please refer to `changelog.txt `_ -for things that have changed in this version. - -This version has been built and unit tested on: - - Ubuntu - - Windows 7 and 10 with Visual Studio - - PostgreSQL 10 to 15 (32 and 64bit) - - Python 3.7 to 3.11 (32 and 64bit) - -| D'Arcy J.M. Cain -| darcy@PyGreSQL.org diff --git a/docs/conf.py b/docs/conf.py index f5789d29..0f95ab1b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,7 @@ author = 'The PyGreSQL team' copyright = '2023, ' + author -version = release = '5.2.5' +version = release = '6.0' language = 'en' diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index bc8322f4..e2b68425 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 6.0 (to be released) +---------------------------- +- Removed support for Python versions older than 3.7 (released June 2017) + and PostgreSQL older than version 10 (released October 2017). +- Modernized code and tools for development, testing, linting and building. + Version 5.2.5 (2023-08-28) -------------------------- - This version officially supports the new Python 3.11 and PostgreSQL 15. diff --git a/docs/download/index.rst b/docs/download/index.rst index c4735826..88bf77b0 100644 --- a/docs/download/index.rst +++ b/docs/download/index.rst @@ -3,10 +3,8 @@ Download information .. include:: download.rst -News, Changes and Future Development ------------------------------------- - -See the :doc:`../announce` for current news. +Changes and Future Development +------------------------------ For a list of all changes in the current version |version| and in past versions, have a look at the :doc:`../contents/changelog`. diff --git a/docs/index.rst b/docs/index.rst index c40103a8..88292059 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,7 +6,6 @@ Welcome to PyGreSQL about copyright - announce download/index contents/index community/index diff --git a/pg.py b/pg.py index 923d2743..70de429e 100644 --- a/pg.py +++ b/pg.py @@ -22,7 +22,7 @@ try: from _pg import version -except ImportError as e: +except ImportError as e: # noqa: F841 import os libpq = 'libpq.' if os.name == 'nt': @@ -55,29 +55,69 @@ # import objects from extension module from _pg import ( - Error, Warning, - DataError, DatabaseError, - IntegrityError, InterfaceError, InternalError, - InvalidResultError, MultipleResultsError, - NoResultError, NotSupportedError, - OperationalError, ProgrammingError, - INV_READ, INV_WRITE, - POLLING_OK, POLLING_FAILED, POLLING_READING, POLLING_WRITING, - SEEK_CUR, SEEK_END, SEEK_SET, - TRANS_ACTIVE, TRANS_IDLE, TRANS_INERROR, - TRANS_INTRANS, TRANS_UNKNOWN, - cast_array, cast_hstore, cast_record, - connect, escape_bytea, escape_string, unescape_bytea, - get_array, get_bool, get_bytea_escaped, - get_datestyle, get_decimal, get_decimal_point, - get_defbase, get_defhost, get_defopt, get_defport, get_defuser, - get_jsondecode, get_pqlib_version, - set_array, set_bool, set_bytea_escaped, - set_datestyle, set_decimal, set_decimal_point, - set_defbase, set_defhost, set_defopt, - set_defpasswd, set_defport, set_defuser, - set_jsondecode, set_query_helpers, - version) + INV_READ, + INV_WRITE, + POLLING_FAILED, + POLLING_OK, + POLLING_READING, + POLLING_WRITING, + SEEK_CUR, + SEEK_END, + SEEK_SET, + TRANS_ACTIVE, + TRANS_IDLE, + TRANS_INERROR, + TRANS_INTRANS, + TRANS_UNKNOWN, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + InvalidResultError, + MultipleResultsError, + NoResultError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + escape_bytea, + escape_string, + get_array, + get_bool, + get_bytea_escaped, + get_datestyle, + get_decimal, + get_decimal_point, + get_defbase, + get_defhost, + get_defopt, + get_defport, + get_defuser, + get_jsondecode, + get_pqlib_version, + set_array, + set_bool, + set_bytea_escaped, + set_datestyle, + set_decimal, + set_decimal_point, + set_defbase, + set_defhost, + set_defopt, + set_defpasswd, + set_defport, + set_defuser, + set_jsondecode, + set_query_helpers, + unescape_bytea, + version, +) __version__ = version @@ -112,19 +152,18 @@ import select import warnings import weakref - -from datetime import date, time, datetime, timedelta +from collections import OrderedDict, namedtuple +from datetime import date, datetime, time, timedelta from decimal import Decimal -from math import isnan, isinf -from collections import namedtuple, OrderedDict +from functools import lru_cache, partial from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan from operator import itemgetter -from functools import lru_cache, partial from re import compile as regex -from json import loads as jsondecode, dumps as jsonencode -from uuid import UUID from typing import Dict, List, Union # noqa: F401 - +from uuid import UUID # Auxiliary classes and functions that are independent of a DB connection: @@ -174,6 +213,7 @@ def _quote(cls, s): return s def __str__(self): + """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -182,10 +222,12 @@ class Json: """Wrapper class for marking Json values.""" def __init__(self, obj, encode=None): + """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode def __str__(self): + """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): return obj @@ -313,6 +355,7 @@ class Adapter: _re_array_escape = _re_record_escape = regex(r'(["\\])') def __init__(self, db): + """Initialize the adapter object with the given connection.""" self.db = weakref.proxy(db) @classmethod @@ -1124,7 +1167,7 @@ class DbTypes(dict): def __init__(self, db): """Initialize type cache for connection.""" - super(DbTypes, self).__init__() + super().__init__() self._db = weakref.proxy(db) self._regtypes = False self._typecasts = Typecasts() @@ -1315,7 +1358,7 @@ def _prg_error(msg): # The notification handler -class NotificationHandler(object): +class NotificationHandler: """A PostgreSQL client-side asynchronous notification handler.""" def __init__(self, db, event, callback=None, @@ -1348,6 +1391,7 @@ def __init__(self, db, event, callback=None, self.timeout = timeout def __del__(self): + """Delete the notification handler.""" self.unlisten() def close(self): @@ -1440,7 +1484,10 @@ def __call__(self): def pgnotify(*args, **kw): - """Same as NotificationHandler, under the traditional name.""" + """Create a notification handler. + + Same as NotificationHandler, under the traditional name. + """ warnings.warn("pgnotify is deprecated, use NotificationHandler instead", DeprecationWarning, stacklevel=2) return NotificationHandler(*args, **kw) @@ -1454,7 +1501,7 @@ class DB: db = None # invalid fallback for underlying connection def __init__(self, *args, **kw): - """Create a new connection + """Create a new connection. You can pass either the connection parameters or an existing _pg or pgdb connection. This allows you to use the methods @@ -1519,6 +1566,7 @@ def __init__(self, *args, **kw): self.debug = None def __getattr__(self, name): + """Get the specified attritbute of the connection.""" # All undefined members are same as in underlying connection: if self.db: return getattr(self.db, name) @@ -1526,6 +1574,7 @@ def __getattr__(self, name): raise _int_error('Connection is not valid') def __dir__(self): + """List all attributes of the connection.""" # Custom dir function including the attributes of the connection: attrs = set(self.__class__.__dict__) attrs.update(self.__dict__) @@ -1547,6 +1596,7 @@ def __exit__(self, et, ev, tb): self.rollback() def __del__(self): + """Delete the connection.""" try: db = self.db except AttributeError: @@ -1565,7 +1615,7 @@ def __del__(self): # Auxiliary methods def _do_debug(self, *args): - """Print a debug message""" + """Print a debug message.""" if self.debug: s = '\n'.join(str(arg) for arg in args) if isinstance(self.debug, str): @@ -1918,7 +1968,7 @@ def describe_prepared(self, name=None): return self.db.describe_prepared(name) def delete_prepared(self, name=None): - """Delete a prepared SQL statement + """Delete a prepared SQL statement. This deallocates a previously prepared SQL statement with the given name, or deallocates all prepared statements if you do not specify a @@ -2275,8 +2325,7 @@ def update(self, table, row=None, **kw): keyname = set(keyname) for n in attnames: if n in row and n not in keyname and n not in generated: - values.append('{} = {}'.format( - col(n), adapt(row[n], attnames[n]))) + values.append(f'{col(n)} = {adapt(row[n], attnames[n])}') if not values: return row values = ', '.join(values) @@ -2294,7 +2343,7 @@ def update(self, table, row=None, **kw): return row def upsert(self, table, row=None, **kw): - """Insert a row into a database table with conflict resolution + """Insert a row into a database table with conflict resolution. This method inserts a row into a table, but instead of raising a ProgrammingError exception in case a row with the same primary key diff --git a/pgdb.py b/pgdb.py index 44b6a83e..f61522bb 100644 --- a/pgdb.py +++ b/pgdb.py @@ -66,7 +66,7 @@ try: from _pg import version -except ImportError as e: +except ImportError as e: # noqa: F841 import os libpq = 'libpq.' if os.name == 'nt': @@ -99,14 +99,24 @@ # import objects from extension module from _pg import ( - Error, Warning, - DataError, DatabaseError, - IntegrityError, InterfaceError, InternalError, - NotSupportedError, OperationalError, ProgrammingError, - cast_array, cast_hstore, cast_record, RESULT_DQL, - connect, unescape_bytea, - version) + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + unescape_bytea, + version, +) __version__ = version @@ -127,17 +137,18 @@ 'get_typecast', 'set_typecast', 'reset_typecast', 'version', '__version__'] -from datetime import date, time, datetime, timedelta -from time import localtime -from decimal import Decimal as StdDecimal -from uuid import UUID as Uuid -from math import isnan, isinf from collections import namedtuple from collections.abc import Iterable -from inspect import signature +from datetime import date, datetime, time, timedelta +from decimal import Decimal as StdDecimal from functools import lru_cache, partial +from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan from re import compile as regex -from json import loads as jsondecode, dumps as jsonencode +from time import localtime +from uuid import UUID as Uuid Decimal = StdDecimal @@ -623,7 +634,7 @@ class TypeCache(dict): def __init__(self, cnx): """Initialize type cache for connection.""" - super(TypeCache, self).__init__() + super().__init__() self._escape_string = cnx.escape_string self._src = cnx.source() self._typecasts = LocalTypecasts() @@ -726,7 +737,7 @@ class _quotedict(dict): def __getitem__(self, key): # noinspection PyUnresolvedReferences - return self.quote(super(_quotedict, self).__getitem__(key)) + return self.quote(super().__getitem__(key)) # *** Error Messages *** @@ -777,7 +788,7 @@ def set_row_factory_size(maxsize): # *** Cursor Object *** -class Cursor(object): +class Cursor: """Cursor object.""" def __init__(self, dbcnx): @@ -1369,7 +1380,7 @@ def build_row_factory(self): # *** Connection Objects *** -class Connection(object): +class Connection: """Connection object.""" # expose the exceptions as attributes on the connection object @@ -1576,25 +1587,28 @@ class Type(frozenset): """ def __new__(cls, values): + """Create new type object.""" if isinstance(values, str): values = values.split() - return super(Type, cls).__new__(cls, values) + return super().__new__(cls, values) def __eq__(self, other): + """Check whether types are considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other in self else: - return super(Type, self).__eq__(other) + return super().__eq__(other) def __ne__(self, other): + """Check whether types are not considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other not in self else: - return super(Type, self).__ne__(other) + return super().__ne__(other) class ArrayType: @@ -1741,6 +1755,7 @@ def _quote(cls, s): return s def __str__(self): + """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -1749,10 +1764,12 @@ class Json: """Construct a wrapper for holding an object serializable to JSON.""" def __init__(self, obj, encode=None): + """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode def __str__(self): + """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): return obj @@ -1763,9 +1780,11 @@ class Literal: """Construct a wrapper for holding a literal SQL string.""" def __init__(self, sql): + """Initialize literal SQL string.""" self.sql = sql def __str__(self): + """Return a printable representation of the SQL string.""" return self.sql __pg_repr__ = __str__ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b1a184cc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,75 @@ +[project] +name = "PyGreSQL" +version = "6.0" +requires-python = ">=3.7" +authors = [ + {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, + {name = "Christoph Zwerschke", email = "cito@online.de"}, +] +description = "Python PostgreSQL interfaces" +readme = "README.rst" +keywords = ["pygresql", "postgresql", "database", "api", "dbapi"] +classifiers = [ + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "License :: OSI Approved :: PostgreSQL License", + "Operating System :: OS Independent", + "Programming Language :: C", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +[project.license] +file = "LICENSE.txt" + +[project.urls] +homepage = "https://pygresql.github.io/" +documentation = "https://pygresql.github.io/contents/" +source = "https://github.com/PyGreSQL/PyGreSQL" +issues = "https://github.com/PyGreSQL/PyGreSQL/issues/" +changelog = "https://pygresql.github.io/contents/changelog.html" +download = "https://pygresql.github.io/download/" +"mailing list" = "https://mail.vex.net/mailman/listinfo/pygresql" + +[tool.ruff] +line-length = 79 +select = [ + "E", # pycodestyle + "F", # pyflakes + "UP", # pyupgrade + "D", # pydocstyle +] +exclude = [ + "__pycache__", + "__pypackages__", + ".git", + ".tox", + ".venv", + ".devcontainer", + ".vscode", + "docs", + "build", + "dist", + "local", + "venv", +] + +[tool.ruff.per-file-ignores] +"tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107"] + +[tool.setuptools] +py-modules = ["pg", "pgdb"] +license-files = ["LICENSE.txt"] + +[build-system] +requires = ["setuptools>=68", "wheel>=0.41"] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 08d43dae..29a84bf8 100755 --- a/setup.py +++ b/setup.py @@ -1,41 +1,11 @@ #!/usr/bin/python -# -# PyGreSQL - a Python interface for the PostgreSQL database. -# -# Copyright (c) 2023 by the PyGreSQL Development Team -# -# Please see the LICENSE.TXT file for specific restrictions. - -"""Setup script for PyGreSQL version 6.0 - -PyGreSQL is an open-source Python module that interfaces to a -PostgreSQL database. It wraps the lower level C API library libpq -to allow easy use of the powerful PostgreSQL features from Python. - -Authors and history: -* PyGreSQL written 1997 by D'Arcy J.M. Cain -* based on code written 1995 by Pascal Andre -* setup script created 2000 by Mark Alexander -* improved 2000 by Jeremy Hylton -* improved 2001 by Gerhard Haering -* improved 2006 to 2018 by Christoph Zwerschke - -Prerequisites to be installed: -* Python including devel package (header files and distutils) -* PostgreSQL libs and devel packages (header file of the libpq client) -* PostgreSQL pg_config tool (usually included in the devel package) - (the Windows installer has it as part of the database server feature) - -PyGreSQL currently supports Python versions 3.7 to 3.11, -and PostgreSQL versions 10 to 15. - -Use as follows: -python setup.py build_ext # to build the module -python setup.py install # to install it - -See docs.python.org/doc/install/ for more information on -using distutils to install Python programs. +"""Driver script for building PyGreSQL using setuptools. + +You can build the PyGreSQL distribution like this: + + pip install build + python -m build -C strict -C memory-size """ import os @@ -43,15 +13,12 @@ import re import sys import warnings -try: - from setuptools import setup -except ImportError: - from distutils.core import setup -from distutils.extension import Extension -from distutils.command.build_ext import build_ext from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext + version = '6.0' if not (3, 7) <= sys.version_info[:2] < (4, 0): @@ -63,10 +30,8 @@ # classic interface, and "pgdb" for the modern DB-API 2.0 interface. # These two top-level Python modules share the same C extension "_pg". -py_modules = ['pg', 'pgdb'] c_sources = ['pgmodule.c'] - def pg_config(s): """Retrieve information about installed version of PostgreSQL.""" f = os.popen(f'pg_config --{s}') @@ -118,6 +83,7 @@ def get_compiler(self): return self.compiler or get_default_compiler() def initialize_options(self): + """Initialize the supported options with default values.""" build_ext.initialize_options(self) self.strict = False self.memory_size = None @@ -157,45 +123,10 @@ def finalize_options(self): setup( name="PyGreSQL", version=version, - description="Python PostgreSQL Interfaces", - long_description=__doc__.split('\n\n', 2)[1], # first passage - long_description_content_type='text/plain', - keywords="pygresql postgresql database api dbapi", - author="D'Arcy J. M. Cain", - author_email="darcy@PyGreSQL.org", - url="http://www.pygresql.org", - download_url="http://www.pygresql.org/download/", - project_urls={ - "Documentation": "https://pygresql.org/contents/", - "Issue Tracker": "https://github.com/PyGreSQL/PyGreSQL/issues/", - "Mailing List": "https://mail.vex.net/mailman/listinfo/pygresql", - "Source Code": "https://github.com/PyGreSQL/PyGreSQL"}, - platforms=["any"], - license="PostgreSQL", - py_modules=py_modules, ext_modules=[Extension( '_pg', c_sources, include_dirs=include_dirs, library_dirs=library_dirs, define_macros=define_macros, undef_macros=undef_macros, libraries=libraries, extra_compile_args=extra_compile_args)], - zip_safe=False, cmdclass=dict(build_ext=build_pg_ext), - test_suite='tests.discover', - classifiers=[ - "Development Status :: 6 - Mature", - "Intended Audience :: Developers", - "License :: OSI Approved :: PostgreSQL License", - "Operating System :: OS Independent", - "Programming Language :: C", - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - "Programming Language :: SQL", - "Topic :: Database", - "Topic :: Database :: Front-Ends", - "Topic :: Software Development :: Libraries :: Python Modules"] ) diff --git a/tests/config.py b/tests/config.py index 6e2ebd3c..acd8559a 100644 --- a/tests/config.py +++ b/tests/config.py @@ -26,9 +26,9 @@ dbport = int(dbport) try: - from .LOCAL_PyGreSQL import * # noqa: F401 + from .LOCAL_PyGreSQL import * # noqa: F403 except (ImportError, ValueError): try: - from LOCAL_PyGreSQL import * # noqa: F401 + from LOCAL_PyGreSQL import * # noqa: F403 except ImportError: pass diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 32045fa4..e76e5fb9 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -9,9 +9,8 @@ __version__ = '1.15.0' -import unittest import time - +import unittest from typing import Any, Dict, Tuple @@ -462,7 +461,7 @@ def test_fetchone(self): con.close() def test_next(self): - """Extension for getting the next row""" + """Test extension for getting the next row.""" con = self._connect() try: cur = con.cursor() diff --git a/tests/test_classic.py b/tests/test_classic.py index 6319d5d5..d6763074 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,17 +1,21 @@ #!/usr/bin/python import unittest - from functools import partial -from time import sleep from threading import Thread +from time import sleep from pg import ( - DB, NotificationHandler, - Error, DatabaseError, IntegrityError, - NotSupportedError, ProgrammingError) + DB, + DatabaseError, + Error, + IntegrityError, + NotificationHandler, + NotSupportedError, + ProgrammingError, +) -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser def open_db(): @@ -28,7 +32,7 @@ class UtilityTest(unittest.TestCase): @classmethod def setUpClass(cls): - """Recreate test tables and schemas""" + """Recreate test tables and schemas.""" db = open_db() try: db.query("DROP VIEW _test_vschema") @@ -56,7 +60,7 @@ def setUpClass(cls): db.close() def setUp(self): - """Setup test tables or empty them if they already exist.""" + """Set up test tables or empty them if they already exist.""" db = open_db() db.query("TRUNCATE TABLE _test_schema") for t in ('_test1', '_test2'): @@ -64,12 +68,12 @@ def setUp(self): db.close() def test_invalid_name(self): - """Make sure that invalid table names are caught""" + """Make sure that invalid table names are caught.""" db = open_db() self.assertRaises(NotSupportedError, db.get_attnames, 'x.y.z') def test_schema(self): - """Does it differentiate the same table name in different schemas""" + """Check differentiation of same table name in different schemas.""" db = open_db() # see if they differentiate the table names properly self.assertEqual( diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index f7ca2a46..4436239d 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -9,18 +9,17 @@ These tests need a database to test against. """ -import unittest +import os import threading import time -import os - +import unittest from collections import namedtuple from collections.abc import Iterable from decimal import Decimal import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser windows = os.name == 'nt' @@ -284,7 +283,7 @@ def testAllQueryMembers(self): def testMethodEndcopy(self): try: self.connection.endcopy() - except IOError: + except OSError: pass def testMethodClose(self): @@ -864,7 +863,7 @@ def testGetresultUtf8(self): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('utf8') + q = q.encode() # pass the query as bytes v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) @@ -879,7 +878,7 @@ def testDictresultUtf8(self): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('utf8') + q = q.encode() v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -2013,7 +2012,7 @@ def testInserttableByteValues(self): 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "käse сыр pont-l'évêque") row_bytes = tuple( - s.encode('utf-8') if isinstance(s, str) else s + s.encode() if isinstance(s, str) else s for s in row_unicode) data = [row_bytes] * 2 self.c.inserttable('test', data) @@ -2098,7 +2097,7 @@ def testInserttableFromQuery(self): None, 'c', 'v4', None, 'text')]) def testInserttableSpecialChars(self): - class S(object): + class S: def __repr__(self): return s @@ -2187,7 +2186,7 @@ def testPutlineBytesAndUnicode(self): self.skipTest('database does not support utf8') query("copy test from stdin") try: - putline("47\tkäse\n".encode('utf8')) + putline("47\tkäse\n".encode()) putline("35\twürstel\n") finally: self.c.endcopy() @@ -2212,7 +2211,7 @@ def testGetline(self): finally: try: self.c.endcopy() - except IOError: + except OSError: pass def testGetlineBytesAndUnicode(self): @@ -2222,7 +2221,7 @@ def testGetlineBytesAndUnicode(self): query("select 'käse+würstel'") except (pg.DataError, pg.NotSupportedError): self.skipTest('database does not support utf8') - data = [(54, 'käse'.encode('utf8')), (73, 'würstel')] + data = [(54, 'käse'.encode()), (73, 'würstel')] self.c.inserttable('test', data) query("copy test to stdout") try: @@ -2236,7 +2235,7 @@ def testGetlineBytesAndUnicode(self): finally: try: self.c.endcopy() - except IOError: + except OSError: pass def testParameterChecks(self): @@ -2715,9 +2714,9 @@ def testEscapeString(self): r = f('plain') self.assertIsInstance(r, str) self.assertEqual(r, 'plain') - r = f("das is' käse".encode('utf-8')) + r = f("das is' käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is'' käse".encode('utf-8')) + self.assertEqual(r, "das is'' käse".encode()) r = f("that's cheesy") self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheesy") @@ -2733,7 +2732,7 @@ def testEscapeBytea(self): r = f('plain') self.assertIsInstance(r, str) self.assertEqual(r, 'plain') - r = f("das is' käse".encode('utf-8')) + r = f("das is' käse".encode()) self.assertIsInstance(r, bytes) self.assertEqual(r, b"das is'' k\\\\303\\\\244se") r = f("that's cheesy") diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 79c962a4..8e64949d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -9,24 +9,23 @@ These tests need a database to test against. """ -import unittest -import os -import sys import gc import json +import os +import sys import tempfile - -import pg # the module under test - +import unittest from collections import OrderedDict +from datetime import date, datetime, time, timedelta from decimal import Decimal -from datetime import date, time, datetime, timedelta from io import StringIO -from uuid import UUID -from time import strftime from operator import itemgetter +from time import strftime +from uuid import UUID + +import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser debug = False # let DB wrapper print debugging output @@ -337,7 +336,7 @@ def testMethodQueryDataError(self): def testMethodEndcopy(self): try: self.db.endcopy() - except IOError: + except OSError: pass def testMethodClose(self): @@ -507,9 +506,9 @@ def testEscapeLiteral(self): r = f("plain") self.assertIsInstance(r, str) self.assertEqual(r, "'plain'") - r = f("that's käse".encode('utf-8')) + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, "'that''s käse'".encode('utf-8')) + self.assertEqual(r, "'that''s käse'".encode()) r = f("that's käse") self.assertIsInstance(r, str) self.assertEqual(r, "'that''s käse'") @@ -526,9 +525,9 @@ def testEscapeIdentifier(self): r = f("plain") self.assertIsInstance(r, str) self.assertEqual(r, '"plain"') - r = f("that's käse".encode('utf-8')) + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, '"that\'s käse"'.encode('utf-8')) + self.assertEqual(r, '"that\'s käse"'.encode()) r = f("that's käse") self.assertIsInstance(r, str) self.assertEqual(r, '"that\'s käse"') @@ -545,9 +544,9 @@ def testEscapeString(self): r = f("plain") self.assertIsInstance(r, str) self.assertEqual(r, "plain") - r = f("that's käse".encode('utf-8')) + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, "that''s käse".encode('utf-8')) + self.assertEqual(r, "that''s käse".encode()) r = f("that's käse") self.assertIsInstance(r, str) self.assertEqual(r, "that''s käse") @@ -564,7 +563,7 @@ def testEscapeBytea(self): r = f('plain') self.assertIsInstance(r, str) self.assertEqual(r, '\\x706c61696e') - r = f("das is' käse".encode('utf-8')) + r = f("das is' käse".encode()) self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x64617320697327206bc3a47365') r = f("das is' käse") @@ -582,10 +581,10 @@ def testUnescapeBytea(self): self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf8')) + self.assertEqual(r, "das is' käse".encode()) r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf8')) + self.assertEqual(r, "das is' käse".encode()) self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!') self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e') self.assertEqual(f(r'\\x746861742773206be47365'), @@ -4320,11 +4319,11 @@ def setUpClass(cls): db = DB() cls.regtypes = not db.use_regtypes() db.close() - super(TestDBClassNonStdOpts, cls).setUpClass() + super().setUpClass() @classmethod def tearDownClass(cls): - super(TestDBClassNonStdOpts, cls).tearDownClass() + super().tearDownClass() cls.reset_option('jsondecode') cls.reset_option('bool') cls.reset_option('array') diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index adddc8ce..914450f5 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -9,15 +9,13 @@ These tests do not need a database to test against. """ -import unittest - import json import re +import unittest +from datetime import timedelta import pg # the module under test -from datetime import timedelta - class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" @@ -900,10 +898,10 @@ def testUnescapeBytea(self): self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf-8')) + self.assertEqual(r, "das is' käse".encode()) r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is' käse".encode('utf-8')) + self.assertEqual(r, "das is' käse".encode()) r = f(b'O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') r = f('O\\000ps\\377!') diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index bdf3a613..039ca51f 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -9,13 +9,13 @@ These tests need a database to test against. """ -import unittest -import tempfile import os +import tempfile +import unittest import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser windows = os.name == 'nt' @@ -151,11 +151,11 @@ def tearDown(self): if self.obj.oid: try: self.obj.close() - except (SystemError, IOError): + except (SystemError, OSError): pass try: self.obj.unlink() - except (SystemError, IOError): + except (SystemError, OSError): pass del self.obj try: @@ -270,12 +270,12 @@ def testWriteLatin1Bytes(self): def testWriteUtf8Bytes(self): read = self.obj.read self.obj.open(pg.INV_WRITE) - self.obj.write('käse'.encode('utf8')) + self.obj.write('käse'.encode()) self.obj.close() self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), 'käse') + self.assertEqual(r.decode(), 'käse') def testWriteUtf8String(self): read = self.obj.read @@ -285,7 +285,7 @@ def testWriteUtf8String(self): self.obj.open(pg.INV_READ) r = read(80) self.assertIsInstance(r, bytes) - self.assertEqual(r.decode('utf8'), 'käse') + self.assertEqual(r.decode(), 'käse') def testSeek(self): seek = self.obj.seek diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 6062a4fa..8522fbc3 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -24,7 +24,7 @@ def __init__(self, value): self.value = value def __pg_repr__(self): - return "B'{0:b}'".format(self.value) + return f"B'{self.value:b}'" class test_PyGreSQL(dbapi20.DatabaseAPI20Test): diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index d461825c..c4e8dd74 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -10,24 +10,23 @@ """ import unittest - from collections.abc import Iterable import pgdb # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class InputStream: def __init__(self, data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode() self.data = data or b'' self.sizes = [] def __str__(self): - data = self.data.decode('utf-8') + data = self.data.decode() return data def __len__(self): @@ -50,7 +49,7 @@ def __init__(self): self.sizes = [] def __str__(self): - data = self.data.decode('utf-8') + data = self.data.decode() return data def __len__(self): @@ -58,7 +57,7 @@ def __len__(self): def write(self, data): if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode() self.data += data self.sizes.append(len(data)) @@ -188,10 +187,10 @@ class TestCopyFrom(TestCopy): """Test the copy_from method.""" def tearDown(self): - super(TestCopyFrom, self).tearDown() + super().tearDown() self.setUp() self.truncate_table() - super(TestCopyFrom, self).tearDown() + super().tearDown() def copy_from(self, stream, **options): return self.cursor.copy_from(stream, 'copytest', **options) @@ -202,7 +201,7 @@ def data_file(self): def test_bad_params(self): call = self.cursor.copy_from - call('0\t', 'copytest'), self.cursor + call('0\t', 'copytest') call('1\t', 'copytest', format='text', sep='\t', null='', columns=['id', 'name']) self.assertRaises(TypeError, call) @@ -247,7 +246,7 @@ def test_input_bytes(self): self.copy_from(b'42\tHello, world!') self.assertEqual(self.table_data, [(42, 'Hello, world!')]) self.truncate_table() - self.copy_from(self.data_text.encode('utf-8')) + self.copy_from(self.data_text.encode()) self.check_table() def test_input_iterable(self): @@ -263,7 +262,7 @@ def test_input_iterable_with_newlines(self): self.check_table() def test_input_iterable_bytes(self): - self.copy_from(row.encode('utf-8') + self.copy_from(row.encode() for row in self.data_text.splitlines()) self.check_table() @@ -368,7 +367,7 @@ class TestCopyTo(TestCopy): @classmethod def setUpClass(cls): - super(TestCopyTo, cls).setUpClass() + super().setUpClass() con = cls.connect() cur = con.cursor() cur.execute("set client_encoding=utf8") @@ -423,7 +422,7 @@ def test_generator_bytes(self): self.assertEqual(len(rows), 3) rows = b''.join(rows) self.assertIsInstance(rows, bytes) - self.assertEqual(rows, self.data_text.encode('utf-8')) + self.assertEqual(rows, self.data_text.encode()) def test_rowcount_increment(self): ret = self.copy_to() @@ -436,7 +435,7 @@ def test_decode(self): ret_decoded = ''.join(self.copy_to(decode=True)) self.assertIsInstance(ret_raw, bytes) self.assertIsInstance(ret_decoded, str) - self.assertEqual(ret_decoded, ret_raw.decode('utf-8')) + self.assertEqual(ret_decoded, ret_raw.decode()) self.check_rowcount() def test_sep(self): @@ -521,7 +520,7 @@ def test_file(self): ret = self.copy_to(stream) self.assertIs(ret, self.cursor) self.assertEqual(str(stream), self.data_text) - data = self.data_text.encode('utf-8') + data = self.data_text.encode() sizes = [len(row) + 1 for row in data.splitlines()] self.assertEqual(stream.sizes, sizes) self.check_rowcount() diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index a497914b..1a43ab7d 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -12,7 +12,7 @@ class TestClassicTutorial(unittest.TestCase): """Test the First Steps Tutorial for the classic interface.""" def setUp(self): - """Setup test tables or empty them if they already exist.""" + """Set up test tables or empty them if they already exist.""" db = DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) db.query("set datestyle to 'iso'") db.query("set default_with_oids=false") @@ -107,7 +107,7 @@ class TestDbApi20Tutorial(unittest.TestCase): """Test the First Steps Tutorial for the DB-API 2.0 interface.""" def setUp(self): - """Setup test tables or empty them if they already exist.""" + """Set up test tables or empty them if they already exist.""" host = f"{dbhost or ''}:{dbport or -1}" con = connect(database=dbname, host=host, user=dbuser, password=dbpasswd) diff --git a/tox.ini b/tox.ini index 23fb9379..9ddc3a75 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,13 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},flake8,docs +envlist = py3{7,8,9,10,11},ruff,docs -[testenv:flake8] +[testenv:ruff] basepython = python3.11 -deps = flake8>=6,<7 +deps = ruff>=0.0.287 commands = - flake8 setup.py pg.py pgdb.py tests + ruff setup.py pg.py pgdb.py tests [testenv:docs] basepython = python3.11 @@ -16,6 +16,15 @@ deps = commands = sphinx-build -b html -nEW docs docs/_build/html +[testenv:build] +basepython = python3.11 +deps = + setuptools>=68 + wheel>=0.41 + build>=0.10 +commands = + python -m build -n -C strict -C memory-size + [testenv] passenv = PG* From 8e0859aa8a6b963a3ee7c9687f9e4123cf0bce59 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 22:02:12 +0200 Subject: [PATCH 037/118] Remove deprecated pgnotify function in pg module --- pg.py | 10 --------- tests/test_classic_notification.py | 34 ++---------------------------- 2 files changed, 2 insertions(+), 42 deletions(-) diff --git a/pg.py b/pg.py index 70de429e..6d4db899 100644 --- a/pg.py +++ b/pg.py @@ -1483,16 +1483,6 @@ def __call__(self): self.callback(None) -def pgnotify(*args, **kw): - """Create a notification handler. - - Same as NotificationHandler, under the traditional name. - """ - warnings.warn("pgnotify is deprecated, use NotificationHandler instead", - DeprecationWarning, stacklevel=2) - return NotificationHandler(*args, **kw) - - # The actual PostgreSQL database connection interface: class DB: diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index dcc06382..12d0dee8 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -11,12 +11,12 @@ import unittest import warnings -from time import sleep from threading import Thread +from time import sleep import pg # the module under test -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser debug = False # let DB wrapper print debugging output @@ -29,36 +29,6 @@ def DB(): return db -class TestPyNotifyAlias(unittest.TestCase): - """Test alternative ways of creating a NotificationHandler.""" - - def callback(self): - self.fail('Callback should not be called in this test') - - def testPgNotify(self): - db = DB() - arg_dict = {} - args = ('test_event', self.callback, arg_dict) - kwargs = dict(timeout=2, stop_event='test_stop') - with warnings.catch_warnings(record=True) as warn_msgs: - warnings.simplefilter("always") - # noinspection PyDeprecation - handler1 = pg.pgnotify(db, *args, **kwargs) - self.assertEqual(len(warn_msgs), 1) - warn_msg = warn_msgs[0] - self.assertTrue(issubclass(warn_msg.category, DeprecationWarning)) - self.assertIn('deprecated', str(warn_msg.message)) - self.assertIsInstance(handler1, pg.NotificationHandler) - handler2 = db.notification_handler(*args, **kwargs) - self.assertIsInstance(handler2, pg.NotificationHandler) - self.assertIs(handler1.db, handler2.db) - self.assertEqual(handler1.event, handler2.event) - self.assertIs(handler1.callback, handler2.callback) - self.assertIs(handler1.arg_dict, handler2.arg_dict) - self.assertEqual(handler1.timeout, handler2.timeout) - self.assertEqual(handler1.stop_event, handler2.stop_event) - - class TestSyncNotification(unittest.TestCase): """Test notification handler running in the same thread.""" From 7d055b415e25f25137e484caf8b244025eac4196 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 22:30:45 +0200 Subject: [PATCH 038/118] Remove deprecated pg.Query.ntuples method --- docs/contents/changelog.rst | 2 ++ docs/contents/pg/connection.rst | 8 ++++---- docs/contents/pg/query.rst | 16 ---------------- pgquery.c | 12 ------------ tests/test_classic_connection.py | 18 +----------------- tests/test_classic_notification.py | 1 - 6 files changed, 7 insertions(+), 50 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index e2b68425..67408993 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,6 +5,8 @@ Version 6.0 (to be released) ---------------------------- - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). +- Removed deprecated function `pg.pgnotify()`. +- Removed the deprecated method `ntuples()` of the `pg.Query` object. - Modernized code and tools for development, testing, linting and building. Version 5.2.5 (2023-08-28) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 1adf29d1..b175a2a0 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -114,10 +114,10 @@ result codes: if :meth:`Connection.query` returns `None`, the result-returning methods will return an empty string (`''`). It's still necessary to call a result-returning method until it returns `None`. -:meth:`Query.listfields`, :meth:`Query.fieldname`, :meth:`Query.fieldnum`, -and :meth:`Query.ntuples` only work after a call to a result-returning method -with a non-`None` return value. :meth:`Query.ntuples` returns only the number -of rows returned by the previous result-returning method. +:meth:`Query.listfields`, :meth:`Query.fieldname` and :meth:`Query.fieldnum` +only work after a call to a result-returning method with a non-``None`` return +value. Calling ``len()`` on a :class:`Query` object returns the number of rows +of the previous result-returning method. If multiple semi-colon-delimited statements are passed to :meth:`Connection.query`, only the results of the last statement are returned diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 9e2998f8..3232c115 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -400,22 +400,6 @@ negative value if it is of variable size, and a type-specific modifier value. .. versionadded:: 5.2 -ntuples -- return number of tuples in query object --------------------------------------------------- - -.. method:: Query.ntuples() - - Return number of tuples in query object - - :returns: number of tuples in :class:`Query` - :rtype: int - :raises TypeError: Too many arguments. - -This method returns the number of tuples in the query result. - -.. deprecated:: 5.1 - You can use the normal :func:`len` function instead. - memsize -- return number of bytes allocated by query result ----------------------------------------------------------- diff --git a/pgquery.c b/pgquery.c index 1196889a..194bfaa1 100644 --- a/pgquery.c +++ b/pgquery.c @@ -260,16 +260,6 @@ query_memsize(queryObject *self, PyObject *noargs) #endif /* MEMORY_SIZE */ } -/* Get number of rows. */ -static char query_ntuples__doc__[] = -"ntuples() -- return number of tuples returned by query"; - -static PyObject * -query_ntuples(queryObject *self, PyObject *noargs) -{ - return PyLong_FromLong(self->max_row); -} - /* List field names from query result. */ static char query_listfields__doc__[] = "listfields() -- List field names from result"; @@ -948,8 +938,6 @@ static struct PyMethodDef query_methods[] = { METH_NOARGS, query_listfields__doc__}, {"fieldinfo", (PyCFunction) query_fieldinfo, METH_VARARGS, query_fieldinfo__doc__}, - {"ntuples", (PyCFunction) query_ntuples, - METH_NOARGS, query_ntuples__doc__}, {"memsize", (PyCFunction) query_memsize, METH_NOARGS, query_memsize__doc__}, {NULL, NULL} diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 4436239d..dc7311c4 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -268,7 +268,7 @@ def testAllQueryMembers(self): query = self.connection.query("select true where false") members = ''' dictiter dictresult fieldinfo fieldname fieldnum getresult - listfields memsize namediter namedresult ntuples + listfields memsize namediter namedresult one onedict onenamed onescalar scalariter scalarresult single singledict singlenamed singlescalar '''.split() @@ -712,22 +712,6 @@ def testFieldInfoName(self): self.assertRaises(IndexError, f, -1) self.assertRaises(IndexError, f, 4) - def testNtuples(self): # deprecated - q = "select 1 where false" - r = self.c.query(q).ntuples() - self.assertIsInstance(r, int) - self.assertEqual(r, 0) - q = ("select 1 as a, 2 as b, 3 as c, 4 as d" - " union select 5 as a, 6 as b, 7 as c, 8 as d") - r = self.c.query(q).ntuples() - self.assertIsInstance(r, int) - self.assertEqual(r, 2) - q = ("select 1 union select 2 union select 3" - " union select 4 union select 5 union select 6") - r = self.c.query(q).ntuples() - self.assertIsInstance(r, int) - self.assertEqual(r, 6) - def testLen(self): q = "select 1 where false" self.assertEqual(len(self.c.query(q)), 0) diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 12d0dee8..13a341dd 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -10,7 +10,6 @@ """ import unittest -import warnings from threading import Thread from time import sleep From 536805c83e6c5e1ff047d609e71a41eceeadaba9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 22:59:03 +0200 Subject: [PATCH 039/118] Make sure import statements are sorted --- pg.py | 1 - pyproject.toml | 1 + tests/test_dbapi20.py | 20 +++++++------------- tests/test_tutorial.py | 2 +- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/pg.py b/pg.py index 6d4db899..f8dfb1be 100644 --- a/pg.py +++ b/pg.py @@ -150,7 +150,6 @@ 'version', '__version__'] import select -import warnings import weakref from collections import OrderedDict, namedtuple from datetime import date, datetime, time, timedelta diff --git a/pyproject.toml b/pyproject.toml index b1a184cc..dfe59c2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ line-length = 79 select = [ "E", # pycodestyle "F", # pyflakes + "I", # isort "UP", # pyupgrade "D", # pydocstyle ] diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 8522fbc3..380caf52 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -2,19 +2,13 @@ import gc import unittest - -from datetime import date, time, datetime, timedelta, timezone +from datetime import date, datetime, time, timedelta, timezone from uuid import UUID as Uuid import pgdb -try: - from . import dbapi20 -except (ImportError, ValueError, SystemError): - # noinspection PyUnresolvedReferences - import dbapi20 - -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from . import dbapi20 +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class PgBitString: @@ -27,7 +21,7 @@ def __pg_repr__(self): return f"B'{self.value:b}'" -class test_PyGreSQL(dbapi20.DatabaseAPI20Test): +class TestPgDb(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () @@ -38,7 +32,7 @@ class test_PyGreSQL(dbapi20.DatabaseAPI20Test): lower_func = 'lower' # For stored procedure test def setUp(self): - dbapi20.DatabaseAPI20Test.setUp(self) + super().setUp() try: con = self._connect() con.close() @@ -52,7 +46,7 @@ def setUp(self): db.query('create database ' + dbname) def tearDown(self): - dbapi20.DatabaseAPI20Test.tearDown(self) + super().tearDown() def test_version(self): v = pgdb.version @@ -542,7 +536,7 @@ def test_sqlstate(self): def test_float(self): nan, inf = float('nan'), float('inf') - from math import isnan, isinf + from math import isinf, isnan self.assertTrue(isnan(nan) and not isinf(nan)) self.assertTrue(isinf(inf) and not isnan(inf)) values = [0, 1, 0.03125, -42.53125, nan, inf, -inf, diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 1a43ab7d..3f76f39b 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -5,7 +5,7 @@ from pg import DB from pgdb import connect -from .config import dbname, dbhost, dbport, dbuser, dbpasswd +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class TestClassicTutorial(unittest.TestCase): From 56a034bde9f32d2d4d325eb6b139fdba0526d59f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 1 Sep 2023 23:54:56 +0200 Subject: [PATCH 040/118] Use PEP8 naming conventions for test methods --- pgdb.py | 22 +- pyproject.toml | 1 + setup.py | 2 +- tests/dbapi20.py | 68 ++-- tests/test_classic_connection.py | 444 ++++++++++----------- tests/test_classic_dbwrapper.py | 604 +++++++++++++++-------------- tests/test_classic_functions.py | 118 +++--- tests/test_classic_largeobj.py | 52 +-- tests/test_classic_notification.py | 46 +-- tests/test_dbapi20.py | 12 +- 10 files changed, 687 insertions(+), 682 deletions(-) diff --git a/pgdb.py b/pgdb.py index f61522bb..5752ac4d 100644 --- a/pgdb.py +++ b/pgdb.py @@ -148,7 +148,7 @@ from math import isinf, isnan from re import compile as regex from time import localtime -from uuid import UUID as Uuid +from uuid import UUID as Uuid # noqa: N811 Decimal = StdDecimal @@ -729,7 +729,7 @@ def row_caster(row): return row_caster -class _quotedict(dict): +class _QuoteDict(dict): """Dictionary with auto quoting of its items. The quote attribute must be set to the desired quote function. @@ -897,7 +897,7 @@ def _quoteparams(self, string, parameters): except (TypeError, ValueError): return string # silently accept unescaped quotes if isinstance(parameters, dict): - parameters = _quotedict(parameters) + parameters = _QuoteDict(parameters) parameters.quote = self._quote else: parameters = tuple(map(self._quote, parameters)) @@ -1687,34 +1687,35 @@ def __ne__(self, other): # Mandatory type helpers defined by DB-API 2 specs: -def Date(year, month, day): +def Date(year, month, day): # noqa: N802 """Construct an object holding a date value.""" return date(year, month, day) -def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): +def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): # noqa: N802 """Construct an object holding a time value.""" return time(hour, minute, second, microsecond, tzinfo) -def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0, +def Timestamp(year, month, day, # noqa: N802 + hour=0, minute=0, second=0, microsecond=0, tzinfo=None): """Construct an object holding a time stamp value.""" return datetime(year, month, day, hour, minute, second, microsecond, tzinfo) -def DateFromTicks(ticks): +def DateFromTicks(ticks): # noqa: N802 """Construct an object holding a date value from the given ticks value.""" return Date(*localtime(ticks)[:3]) -def TimeFromTicks(ticks): +def TimeFromTicks(ticks): # noqa: N802 """Construct an object holding a time value from the given ticks value.""" return Time(*localtime(ticks)[3:6]) -def TimestampFromTicks(ticks): +def TimestampFromTicks(ticks): # noqa: N802 """Construct an object holding a time stamp from the given ticks value.""" return Timestamp(*localtime(ticks)[:6]) @@ -1725,7 +1726,8 @@ class Binary(bytes): # Additional type helpers for PyGreSQL: -def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0): +def Interval(days, # noqa: N802 + hours=0, minutes=0, seconds=0, microseconds=0): """Construct an object holding a time interval value.""" return timedelta(days, hours=hours, minutes=minutes, seconds=seconds, microseconds=microseconds) diff --git a/pyproject.toml b/pyproject.toml index dfe59c2f..9603b825 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ select = [ "E", # pycodestyle "F", # pyflakes "I", # isort + "N", # pep8-naming "UP", # pyupgrade "D", # pydocstyle ] diff --git a/setup.py b/setup.py index 29a84bf8..09c6e2f8 100755 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ def pg_version(): extra_compile_args = ['-O2', '-funsigned-char', '-Wall', '-Wconversion'] -class build_pg_ext(build_ext): +class build_pg_ext(build_ext): # noqa: N801 """Customized build_ext command for PyGreSQL.""" description = "build the PyGreSQL C extension" diff --git a/tests/dbapi20.py b/tests/dbapi20.py index e76e5fb9..bb913475 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -55,10 +55,10 @@ class mytest(dbapi20.DatabaseAPI20Test): # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self, cursor): + def execute_ddl1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self, cursor): + def execute_ddl2(self, cursor): cursor.execute(self.ddl2) def setUp(self): @@ -134,7 +134,7 @@ def test_paramstyle(self): except AttributeError: self.fail("Driver doesn't define paramstyle") - def test_Exceptions(self): + def test_exceptions(self): # Make sure required exceptions exist, and are in the # defined hierarchy. sub = issubclass @@ -149,7 +149,7 @@ def test_Exceptions(self): self.assertTrue(sub(self.driver.ProgrammingError, self.driver.Error)) self.assertTrue(sub(self.driver.NotSupportedError, self.driver.Error)) - def test_ExceptionsAsConnectionAttributes(self): + def test_exceptions_as_connection_attributes(self): # OPTIONAL EXTENSION # Test for the optional DB API 2.0 extension, where the exceptions # are exposed as attributes on the Connection object @@ -202,7 +202,7 @@ def test_cursor_isolation(self): # the documented transaction isolation level cur1 = con.cursor() cur2 = con.cursor() - self.executeDDL1(cur1) + self.execute_ddl1(cur1) cur1.execute(f"{self.insert} into {self.table_prefix}booze" " values ('Victoria Bitter')") cur2.execute(f"select name from {self.table_prefix}booze") @@ -217,7 +217,7 @@ def test_description(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertIsNone( cur.description, 'cursor.description should be none after executing a' @@ -238,7 +238,7 @@ def test_description(self): f' Got: {cur.description[0][1]!r}') # Make sure self.description gets reset - self.executeDDL2(cur) + self.execute_ddl2(cur) self.assertIsNone( cur.description, 'cursor.description not being set to None when executing' @@ -250,7 +250,7 @@ def test_rowcount(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertIn( cur.rowcount, (-1, 0), # Bug #543885 'cursor.rowcount should be -1 or 0 after executing no-result' @@ -266,7 +266,7 @@ def test_rowcount(self): cur.rowcount, (-1, 1), 'cursor.rowcount should == number of rows returned, or' ' set to -1 after executing a select statement') - self.executeDDL2(cur) + self.execute_ddl2(cur) self.assertIn( cur.rowcount, (-1, 0), # Bug #543885 'cursor.rowcount should be -1 or 0 after executing no-result' @@ -303,7 +303,7 @@ def test_close(self): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error, self.executeDDL1, cur) + self.assertRaises(self.driver.Error, self.execute_ddl1, cur) # connection.commit should raise an Error if called after connection' # closed.' @@ -325,7 +325,7 @@ def test_execute(self): con.close() def _paraminsert(self, cur): - self.executeDDL2(cur) + self.execute_ddl2(cur) table_prefix = self.table_prefix insert = f"{self.insert} into {table_prefix}barflys values" cur.execute( @@ -384,7 +384,7 @@ def test_executemany(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) table_prefix = self.table_prefix insert = f'{self.insert} into {table_prefix}booze values' largs = [("Cooper's",), ("Boag's",)] @@ -428,7 +428,7 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called after # executing a query that cannot return rows - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertRaises(self.driver.Error, cur.fetchone) cur.execute(f'select name from {self.table_prefix}booze') @@ -474,7 +474,7 @@ def test_next(self): # cursor.next should raise an Error if called after # executing a query that cannot return rows - self.executeDDL1(cur) + self.execute_ddl1(cur) self.assertRaises(self.driver.Error, cur.next) # cursor.next should return None if a query retrieves no rows @@ -527,7 +527,7 @@ def test_fetchmany(self): # issuing a query self.assertRaises(self.driver.Error, cur.fetchmany, 4) - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) @@ -588,7 +588,7 @@ def test_fetchmany(self): ' called after the whole result set has been fetched') self.assertIn(cur.rowcount, (-1, 6)) - self.executeDDL2(cur) + self.execute_ddl2(cur) cur.execute(f'select name from {self.table_prefix}barflys') r = cur.fetchmany() # Should get empty sequence self.assertEqual( @@ -609,7 +609,7 @@ def test_fetchall(self): # as a select) self.assertRaises(self.driver.Error, cur.fetchall) - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) @@ -635,7 +635,7 @@ def test_fetchall(self): ' after the whole result set has been fetched') self.assertIn(cur.rowcount, (-1, len(self.samples))) - self.executeDDL2(cur) + self.execute_ddl2(cur) cur.execute(f'select name from {self.table_prefix}barflys') rows = cur.fetchall() self.assertIn(cur.rowcount, (-1, 0)) @@ -651,7 +651,7 @@ def test_mixedfetch(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) @@ -680,7 +680,7 @@ def test_mixedfetch(self): finally: con.close() - def help_nextset_setUp(self, cur): + def help_nextset_setup(self, cur): """Set up nextset test. Should create a procedure called deleteme that returns two result sets, @@ -696,7 +696,7 @@ def help_nextset_setUp(self, cur): # """ # cur.execute(sql) - def help_nextset_tearDown(self, cur): + def help_nextset_teardown(self, cur): """Clean up after nextset test. If cleaning up is needed after test_nextset. @@ -717,7 +717,7 @@ def test_nextset(self): # self.executeDDL1(cur) # for sql in self._populate(): # cur.execute(sql) - # self.help_nextset_setUp(cur) + # self.help_nextset_setup(cur) # cur.callproc('deleteme') # number_of_rows = cur.fetchone() # self.assertEqual(number_of_rows[0], len(self.samples)) @@ -727,7 +727,7 @@ def test_nextset(self): # self.assertIsNone( # cur.nextset(), 'No more return sets, should return None') # finally: - # self.help_nextset_tearDown(cur) + # self.help_nextset_teardown(cur) # finally: # con.close() @@ -765,11 +765,11 @@ def test_setoutputsize(self): # Real test for setoutputsize is driver dependant raise NotImplementedError('Driver needed to override this test') - def test_None(self): + def test_none(self): con = self._connect() try: cur = con.cursor() - self.executeDDL2(cur) + self.execute_ddl2(cur) # inserting NULL to the second column, because some drivers might # need the first one to be primary key, which means it needs # to have a non-NULL value @@ -783,21 +783,21 @@ def test_None(self): finally: con.close() - def test_Date(self): + def test_date(self): d1 = self.driver.Date(2002, 12, 25) d2 = self.driver.DateFromTicks( time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied self.assertEqual(str(d1), str(d2)) - def test_Time(self): + def test_time(self): t1 = self.driver.Time(13, 45, 30) t2 = self.driver.TimeFromTicks( time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied self.assertEqual(str(t1), str(t2)) - def test_Timestamp(self): + def test_timestamp(self): t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) t2 = self.driver.TimestampFromTicks( time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) @@ -805,26 +805,26 @@ def test_Timestamp(self): # Can we assume this? API doesn't specify, but it seems implied self.assertEqual(str(t1), str(t2)) - def test_Binary(self): + def test_binary_string(self): self.driver.Binary(b'Something') self.driver.Binary(b'') - def test_STRING(self): + def test_string_type(self): self.assertTrue(hasattr(self.driver, 'STRING'), 'module.STRING must be defined') - def test_BINARY(self): + def test_binary_type(self): self.assertTrue(hasattr(self.driver, 'BINARY'), 'module.BINARY must be defined.') - def test_NUMBER(self): + def test_number_type(self): self.assertTrue(hasattr(self.driver, 'NUMBER'), 'module.NUMBER must be defined.') - def test_DATETIME(self): + def test_datetime_type(self): self.assertTrue(hasattr(self.driver, 'DATETIME'), 'module.DATETIME must be defined.') - def test_ROWID(self): + def test_rowid_type(self): self.assertTrue(hasattr(self.driver, 'ROWID'), 'module.ROWID must be defined.') diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index dc7311c4..ed31bed8 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -48,7 +48,7 @@ def connect_nowait(): class TestCanConnect(unittest.TestCase): """Test whether a basic connection to PostgreSQL is possible.""" - def testCanConnect(self): + def test_can_connect(self): try: connection = connect() rc = connection.poll() @@ -65,7 +65,7 @@ def testCanConnect(self): except pg.Error: self.fail('Cannot close the database connection') - def testCanConnectNoWait(self): + def test_can_connect_no_wait(self): try: connection = connect_nowait() rc = connection.poll() @@ -104,21 +104,21 @@ def is_method(self, attribute): return False return callable(getattr(self.connection, attribute)) - def testClassName(self): + def test_class_name(self): self.assertEqual(self.connection.__class__.__name__, 'Connection') - def testModuleName(self): + def test_module_name(self): self.assertEqual(self.connection.__class__.__module__, 'pg') - def testStr(self): + def test_str(self): r = str(self.connection) self.assertTrue(r.startswith('= 120000: self.skipTest("database does not support tables with oids") query = self.c.query @@ -797,7 +797,7 @@ def testQueryWithOids(self): self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testMemSize(self): + def test_mem_size(self): # noinspection PyUnresolvedReferences if pg.get_pqlib_version() < 120000: self.skipTest("pqlib does not support memsize()") @@ -823,21 +823,21 @@ def setUp(self): def tearDown(self): self.c.close() - def testGetresulAscii(self): + def test_getresul_ascii(self): result = 'Hello, world!' q = f"select '{result}'" v = self.c.query(q).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresulAscii(self): + def test_dictresul_ascii(self): result = 'Hello, world!' q = f"select '{result}' as greeting" v = self.c.query(q).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultUtf8(self): + def test_getresult_utf8(self): result = 'Hello, wörld & мир!' q = f"select '{result}'" # pass the query as unicode @@ -853,7 +853,7 @@ def testGetresultUtf8(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultUtf8(self): + def test_dictresult_utf8(self): result = 'Hello, wörld & мир!' q = f"select '{result}' as greeting" try: @@ -867,7 +867,7 @@ def testDictresultUtf8(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultLatin1(self): + def test_getresult_latin1(self): try: self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): @@ -882,7 +882,7 @@ def testGetresultLatin1(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin1(self): + def test_dictresult_latin1(self): try: self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): @@ -897,7 +897,7 @@ def testDictresultLatin1(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultCyrillic(self): + def test_getresult_cyrillic(self): try: self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): @@ -912,7 +912,7 @@ def testGetresultCyrillic(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultCyrillic(self): + def test_dictresult_cyrillic(self): try: self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): @@ -927,7 +927,7 @@ def testDictresultCyrillic(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultLatin9(self): + def test_getresult_latin9(self): try: self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): @@ -942,7 +942,7 @@ def testGetresultLatin9(self): self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin9(self): + def test_dictresult_latin9(self): try: self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): @@ -968,7 +968,7 @@ def setUp(self): def tearDown(self): self.c.close() - def testQueryWithNoneParam(self): + def test_query_with_none_param(self): self.assertRaises(TypeError, self.c.query, "select $1", None) self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None) self.assertEqual( @@ -978,8 +978,9 @@ def testQueryWithNoneParam(self): self.assertEqual( self.c.query("select $1::text", [[None]]).getresult(), [(None,)]) - def testQueryWithBoolParams(self, bool_enabled=None): + def test_query_with_bool_params(self, bool_enabled=None): query = self.c.query + bool_enabled_default = None if bool_enabled is not None: bool_enabled_default = pg.get_bool() pg.set_bool(bool_enabled) @@ -1003,13 +1004,12 @@ def testQueryWithBoolParams(self, bool_enabled=None): self.assertEqual(query(q, (True,)).getresult(), r_true) finally: if bool_enabled is not None: - # noinspection PyUnboundLocalVariable pg.set_bool(bool_enabled_default) - def testQueryWithBoolParamsNotDefault(self): - self.testQueryWithBoolParams(bool_enabled=not pg.get_bool()) + def test_query_with_bool_params_not_default(self): + self.test_query_with_bool_params(bool_enabled=not pg.get_bool()) - def testQueryWithIntParams(self): + def test_query_with_int_params(self): query = self.c.query self.assertEqual(query("select 1+1").getresult(), [(2,)]) self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)]) @@ -1031,7 +1031,7 @@ def testQueryWithIntParams(self): query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))).getresult(), [(15,)]) - def testQueryWithStrParams(self): + def test_query_with_str_params(self): query = self.c.query self.assertEqual( query("select $1||', world!'", ('Hello',)).getresult(), @@ -1064,7 +1064,7 @@ def testQueryWithStrParams(self): ('Hello', 'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)]) - def testQueryWithUnicodeParams(self): + def test_query_with_unicode_params(self): query = self.c.query try: query('set client_encoding=utf8') @@ -1076,7 +1076,7 @@ def testQueryWithUnicodeParams(self): query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult(), [('Hello, wörld!',)]) - def testQueryWithUnicodeParamsLatin1(self): + def test_query_with_unicode_params_latin1(self): query = self.c.query try: query('set client_encoding=latin1') @@ -1101,7 +1101,7 @@ def testQueryWithUnicodeParamsLatin1(self): UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', 'wörld')) - def testQueryWithUnicodeParamsCyrillic(self): + def test_query_with_unicode_params_cyrillic(self): query = self.c.query try: query('set client_encoding=iso_8859_5') @@ -1120,7 +1120,7 @@ def testQueryWithUnicodeParamsCyrillic(self): UnicodeError, query, "select $1||', '||$2||'!'", ('Hello', 'мир!')) - def testQueryWithMixedParams(self): + def test_query_with_mixed_params(self): self.assertEqual( self.c.query( "select $1+2,$2||', world!'", (1, 'Hello')).getresult(), @@ -1131,17 +1131,17 @@ def testQueryWithMixedParams(self): (4711, None, 'Hello!')).getresult(), [(4711, None, 'Hello!')]) - def testQueryWithDuplicateParams(self): + def test_query_with_duplicate_params(self): self.assertRaises( pg.ProgrammingError, self.c.query, "select $1+$1", (1,)) self.assertRaises( pg.ProgrammingError, self.c.query, "select $1+$1", (1, 2)) - def testQueryWithZeroParams(self): + def test_query_with_zero_params(self): self.assertEqual( self.c.query("select 1+1", []).getresult(), [(2,)]) - def testQueryWithGarbage(self): + def test_query_with_garbage(self): garbage = r"'\{}+()-#[]oo324" self.assertEqual( self.c.query("select $1::text AS garbage", @@ -1159,38 +1159,38 @@ def setUp(self): def tearDown(self): self.c.close() - def testEmptyPreparedStatement(self): + def test_empty_prepared_statement(self): self.c.prepare('', '') self.assertRaises(ValueError, self.c.query_prepared, '') - def testInvalidPreparedStatement(self): + def test_invalid_prepared_statement(self): self.assertRaises(pg.ProgrammingError, self.c.prepare, '', 'bad') - def testDuplicatePreparedStatement(self): + def test_duplicate_prepared_statement(self): self.assertIsNone(self.c.prepare('q', 'select 1')) self.assertRaises(pg.ProgrammingError, self.c.prepare, 'q', 'select 2') - def testNonExistentPreparedStatement(self): + def test_non_existent_prepared_statement(self): self.assertRaises( pg.OperationalError, self.c.query_prepared, 'does-not-exist') - def testUnnamedQueryWithoutParams(self): + def test_unnamed_query_without_params(self): self.assertIsNone(self.c.prepare('', "select 'anon'")) self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)]) self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)]) - def testNamedQueryWithoutParams(self): + def test_named_query_without_params(self): self.assertIsNone(self.c.prepare('hello', "select 'world'")) self.assertEqual( self.c.query_prepared('hello').getresult(), [('world',)]) - def testMultipleNamedQueriesWithoutParams(self): + def test_multiple_named_queries_without_params(self): self.assertIsNone(self.c.prepare('query17', "select 17")) self.assertIsNone(self.c.prepare('query42', "select 42")) self.assertEqual(self.c.query_prepared('query17').getresult(), [(17,)]) self.assertEqual(self.c.query_prepared('query42').getresult(), [(42,)]) - def testUnnamedQueryWithParams(self): + def test_unnamed_query_with_params(self): self.assertIsNone(self.c.prepare('', "select $1 || ', ' || $2")) self.assertEqual( self.c.query_prepared('', ['hello', 'world']).getresult(), @@ -1199,7 +1199,7 @@ def testUnnamedQueryWithParams(self): self.assertEqual( self.c.query_prepared('', [17, -5, 29]).getresult(), [(42,)]) - def testMultipleNamedQueriesWithParams(self): + def test_multiple_named_queries_with_params(self): self.assertIsNone(self.c.prepare('q1', "select $1 || '!'")) self.assertIsNone(self.c.prepare('q2', "select $1 || '-' || $2")) self.assertEqual( @@ -1209,21 +1209,21 @@ def testMultipleNamedQueriesWithParams(self): self.c.query_prepared('q2', ['he', 'lo']).getresult(), [('he-lo',)]) - def testDescribeNonExistentQuery(self): + def test_describe_non_existent_query(self): self.assertRaises( pg.OperationalError, self.c.describe_prepared, 'does-not-exist') - def testDescribeUnnamedQuery(self): + def test_describe_unnamed_query(self): self.c.prepare('', "select 1::int, 'a'::char") r = self.c.describe_prepared('') self.assertEqual(r.listfields(), ('int4', 'bpchar')) - def testDescribeNamedQuery(self): + def test_describe_named_query(self): self.c.prepare('myquery', "select 1 as first, 2 as second") r = self.c.describe_prepared('myquery') self.assertEqual(r.listfields(), ('first', 'second')) - def testDescribeMultipleNamedQueries(self): + def test_describe_multiple_named_queries(self): self.c.prepare('query1', "select 1::int") self.c.prepare('query2', "select 1::int, 2::int") r = self.c.describe_prepared('query1') @@ -1267,36 +1267,36 @@ def assert_proper_cast(self, value, pgtype, pytype): self.assertEqual(len(r), 1) self.assertIsInstance(r[0], pytype) - def testInt(self): + def test_int(self): self.assert_proper_cast(0, 'int', int) self.assert_proper_cast(0, 'smallint', int) self.assert_proper_cast(0, 'oid', int) self.assert_proper_cast(0, 'cid', int) self.assert_proper_cast(0, 'xid', int) - def testLong(self): + def test_long(self): self.assert_proper_cast(0, 'bigint', int) - def testFloat(self): + def test_float(self): self.assert_proper_cast(0, 'float', float) self.assert_proper_cast(0, 'real', float) self.assert_proper_cast(0, 'double precision', float) self.assert_proper_cast('infinity', 'float', float) - def testNumeric(self): + def test_numeric(self): decimal = pg.get_decimal() self.assert_proper_cast(decimal(0), 'numeric', decimal) self.assert_proper_cast(decimal(0), 'decimal', decimal) - def testMoney(self): + def test_money(self): decimal = pg.get_decimal() self.assert_proper_cast(decimal('0'), 'money', decimal) - def testBool(self): + def test_bool(self): bool_type = bool if pg.get_bool() else str self.assert_proper_cast('f', 'bool', bool_type) - def testDate(self): + def test_date(self): self.assert_proper_cast('1956-01-31', 'date', str) self.assert_proper_cast('10:20:30', 'interval', str) self.assert_proper_cast('08:42:15', 'time', str) @@ -1304,16 +1304,16 @@ def testDate(self): self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str) self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str) - def testText(self): + def test_text(self): self.assert_proper_cast('', 'text', str) self.assert_proper_cast('', 'char', str) self.assert_proper_cast('', 'bpchar', str) self.assert_proper_cast('', 'varchar', str) - def testBytea(self): + def test_bytea(self): self.assert_proper_cast('', 'bytea', bytes) - def testJson(self): + def test_json(self): self.assert_proper_cast('{}', 'json', dict) @@ -1326,27 +1326,27 @@ def setUp(self): def tearDown(self): self.c.close() - def testLen(self): + def test_len(self): r = self.c.query("select generate_series(3,7)") self.assertEqual(len(r), 5) - def testGetItem(self): + def test_get_item(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(r[0], (7,)) self.assertEqual(r[1], (8,)) self.assertEqual(r[2], (9,)) - def testGetItemWithNegativeIndex(self): + def test_get_item_with_negative_index(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(r[-1], (9,)) self.assertEqual(r[-2], (8,)) self.assertEqual(r[-3], (7,)) - def testGetItemOutOfRange(self): + def test_get_item_out_of_range(self): r = self.c.query("select generate_series(7,9)") self.assertRaises(IndexError, r.__getitem__, 3) - def testIterate(self): + def test_iterate(self): r = self.c.query("select generate_series(3,5)") self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1354,29 +1354,29 @@ def testIterate(self): # noinspection PyUnresolvedReferences self.assertIsInstance(r[1], tuple) - def testIterateTwice(self): + def test_iterate_twice(self): r = self.c.query("select generate_series(3,5)") for i in range(2): self.assertEqual(list(r), [(3,), (4,), (5,)]) - def testIterateTwoColumns(self): + def test_iterate_two_columns(self): r = self.c.query("select 1,2 union select 3,4") self.assertIsInstance(r, Iterable) self.assertEqual(list(r), [(1, 2), (3, 4)]) - def testNext(self): + def test_next(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(next(r), (7,)) self.assertEqual(next(r), (8,)) self.assertEqual(next(r), (9,)) self.assertRaises(StopIteration, next, r) - def testContains(self): + def test_contains(self): r = self.c.query("select generate_series(7,9)") self.assertIn((8,), r) self.assertNotIn((5,), r) - def testDictIterate(self): + def test_dict_iterate(self): r = self.c.query("select generate_series(3,5) as n").dictiter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1384,7 +1384,7 @@ def testDictIterate(self): self.assertEqual(r, [dict(n=3), dict(n=4), dict(n=5)]) self.assertIsInstance(r[1], dict) - def testDictIterateTwoColumns(self): + def test_dict_iterate_two_columns(self): r = self.c.query( "select 1 as one, 2 as two" " union select 3 as one, 4 as two").dictiter() @@ -1392,19 +1392,19 @@ def testDictIterateTwoColumns(self): r = list(r) self.assertEqual(r, [dict(one=1, two=2), dict(one=3, two=4)]) - def testDictNext(self): + def test_dict_next(self): r = self.c.query("select generate_series(7,9) as n").dictiter() self.assertEqual(next(r), dict(n=7)) self.assertEqual(next(r), dict(n=8)) self.assertEqual(next(r), dict(n=9)) self.assertRaises(StopIteration, next, r) - def testDictContains(self): + def test_dict_contains(self): r = self.c.query("select generate_series(7,9) as n").dictiter() self.assertIn(dict(n=8), r) self.assertNotIn(dict(n=5), r) - def testNamedIterate(self): + def test_named_iterate(self): r = self.c.query("select generate_series(3,5) as number").namediter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1414,7 +1414,7 @@ def testNamedIterate(self): self.assertEqual(r[1]._fields, ('number',)) self.assertEqual(r[1].number, 4) - def testNamedIterateTwoColumns(self): + def test_named_iterate_two_columns(self): r = self.c.query( "select 1 as one, 2 as two" " union select 3 as one, 4 as two").namediter() @@ -1426,7 +1426,7 @@ def testNamedIterateTwoColumns(self): self.assertEqual(r[1]._fields, ('one', 'two')) self.assertEqual(r[1].two, 4) - def testNamedNext(self): + def test_named_next(self): r = self.c.query("select generate_series(7,9) as number").namediter() self.assertEqual(next(r), (7,)) self.assertEqual(next(r), (8,)) @@ -1435,12 +1435,12 @@ def testNamedNext(self): self.assertEqual(n.number, 9) self.assertRaises(StopIteration, next, r) - def testNamedContains(self): + def test_named_contains(self): r = self.c.query("select generate_series(7,9)").namediter() self.assertIn((8,), r) self.assertNotIn((5,), r) - def testScalarIterate(self): + def test_scalar_iterate(self): r = self.c.query("select generate_series(3,5)").scalariter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1448,20 +1448,20 @@ def testScalarIterate(self): self.assertEqual(r, [3, 4, 5]) self.assertIsInstance(r[1], int) - def testScalarIterateTwoColumns(self): + def test_scalar_iterate_two_columns(self): r = self.c.query("select 1, 2 union select 3, 4").scalariter() self.assertIsInstance(r, Iterable) r = list(r) self.assertEqual(r, [1, 3]) - def testScalarNext(self): + def test_scalar_next(self): r = self.c.query("select generate_series(7,9)").scalariter() self.assertEqual(next(r), 7) self.assertEqual(next(r), 8) self.assertEqual(next(r), 9) self.assertRaises(StopIteration, next, r) - def testScalarContains(self): + def test_scalar_contains(self): r = self.c.query("select generate_series(7,9)").scalariter() self.assertIn(8, r) self.assertNotIn(5, r) @@ -1476,46 +1476,46 @@ def setUp(self): def tearDown(self): self.c.close() - def testOneWithEmptyQuery(self): + def test_one_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.one()) - def testOneWithSingleRow(self): + def test_one_with_single_row(self): q = self.c.query("select 1, 2") r = q.one() self.assertIsInstance(r, tuple) self.assertEqual(r, (1, 2)) self.assertEqual(q.one(), None) - def testOneWithTwoRows(self): + def test_one_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") self.assertEqual(q.one(), (1, 2)) self.assertEqual(q.one(), (3, 4)) self.assertEqual(q.one(), None) - def testOneDictWithEmptyQuery(self): + def test_one_dict_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onedict()) - def testOneDictWithSingleRow(self): + def test_one_dict_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.onedict() self.assertIsInstance(r, dict) self.assertEqual(r, dict(one=1, two=2)) self.assertEqual(q.onedict(), None) - def testOneDictWithTwoRows(self): + def test_one_dict_with_two_rows(self): q = self.c.query( "select 1 as one, 2 as two union select 3 as one, 4 as two") self.assertEqual(q.onedict(), dict(one=1, two=2)) self.assertEqual(q.onedict(), dict(one=3, two=4)) self.assertEqual(q.onedict(), None) - def testOneNamedWithEmptyQuery(self): + def test_one_named_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onenamed()) - def testOneNamedWithSingleRow(self): + def test_one_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.onenamed() self.assertEqual(r._fields, ('one', 'two')) @@ -1524,7 +1524,7 @@ def testOneNamedWithSingleRow(self): self.assertEqual(r, (1, 2)) self.assertEqual(q.onenamed(), None) - def testOneNamedWithTwoRows(self): + def test_one_named_with_two_rows(self): q = self.c.query( "select 1 as one, 2 as two union select 3 as one, 4 as two") r = q.onenamed() @@ -1539,24 +1539,24 @@ def testOneNamedWithTwoRows(self): self.assertEqual(r, (3, 4)) self.assertEqual(q.onenamed(), None) - def testOneScalarWithEmptyQuery(self): + def test_one_scalar_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onescalar()) - def testOneScalarWithSingleRow(self): + def test_one_scalar_with_single_row(self): q = self.c.query("select 1, 2") r = q.onescalar() self.assertIsInstance(r, int) self.assertEqual(r, 1) self.assertEqual(q.onescalar(), None) - def testOneScalarWithTwoRows(self): + def test_one_scalar_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") self.assertEqual(q.onescalar(), 1) self.assertEqual(q.onescalar(), 3) self.assertEqual(q.onescalar(), None) - def testSingleWithEmptyQuery(self): + def test_single_with_empty_query(self): q = self.c.query("select 0 where false") try: q.single() @@ -1567,7 +1567,7 @@ def testSingleWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleWithSingleRow(self): + def test_single_with_single_row(self): q = self.c.query("select 1, 2") r = q.single() self.assertIsInstance(r, tuple) @@ -1576,7 +1576,7 @@ def testSingleWithSingleRow(self): self.assertIsInstance(r, tuple) self.assertEqual(r, (1, 2)) - def testSingleWithTwoRows(self): + def test_single_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.single() @@ -1587,7 +1587,7 @@ def testSingleWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleDictWithEmptyQuery(self): + def test_single_dict_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singledict() @@ -1598,7 +1598,7 @@ def testSingleDictWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleDictWithSingleRow(self): + def test_single_dict_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.singledict() self.assertIsInstance(r, dict) @@ -1607,7 +1607,7 @@ def testSingleDictWithSingleRow(self): self.assertIsInstance(r, dict) self.assertEqual(r, dict(one=1, two=2)) - def testSingleDictWithTwoRows(self): + def test_single_dict_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singledict() @@ -1618,7 +1618,7 @@ def testSingleDictWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleNamedWithEmptyQuery(self): + def test_single_named_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singlenamed() @@ -1629,7 +1629,7 @@ def testSingleNamedWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleNamedWithSingleRow(self): + def test_single_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.singlenamed() self.assertEqual(r._fields, ('one', 'two')) @@ -1642,7 +1642,7 @@ def testSingleNamedWithSingleRow(self): self.assertEqual(r.two, 2) self.assertEqual(r, (1, 2)) - def testSingleNamedWithTwoRows(self): + def test_single_named_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singlenamed() @@ -1653,7 +1653,7 @@ def testSingleNamedWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleScalarWithEmptyQuery(self): + def test_single_scalar_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singlescalar() @@ -1664,7 +1664,7 @@ def testSingleScalarWithEmptyQuery(self): self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleScalarWithSingleRow(self): + def test_single_scalar_with_single_row(self): q = self.c.query("select 1, 2") r = q.singlescalar() self.assertIsInstance(r, int) @@ -1673,7 +1673,7 @@ def testSingleScalarWithSingleRow(self): self.assertIsInstance(r, int) self.assertEqual(r, 1) - def testSingleScalarWithTwoRows(self): + def test_single_scalar_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singlescalar() @@ -1684,13 +1684,13 @@ def testSingleScalarWithTwoRows(self): self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testScalarResult(self): + def test_scalar_result(self): q = self.c.query("select 1, 2 union select 3, 4") r = q.scalarresult() self.assertIsInstance(r, list) self.assertEqual(r, [1, 3]) - def testScalarIter(self): + def test_scalar_iter(self): q = self.c.query("select 1, 2 union select 3, 4") r = q.scalariter() self.assertNotIsInstance(r, (list, tuple)) @@ -1809,22 +1809,22 @@ def get_back(self, encoding='utf-8'): data.append(row) return data - def testInserttable1Row(self): + def test_inserttable1_row(self): data = self.data[2:3] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttable4Rows(self): + def test_inserttable4_rows(self): data = self.data self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableFromTupleOfLists(self): + def test_inserttable_from_tuple_of_lists(self): data = tuple(list(row) for row in self.data) self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableWithDifferentRowSizes(self): + def test_inserttable_with_different_row_sizes(self): data = self.data[:-1] + [self.data[-1][:-1]] try: self.c.inserttable('test', data) @@ -1834,34 +1834,34 @@ def testInserttableWithDifferentRowSizes(self): else: self.assertFalse('expected an error') - def testInserttableFromSetofTuples(self): + def test_inserttable_from_setof_tuples(self): data = {row for row in self.data} self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromDictAsInterable(self): + def test_inserttable_from_dict_as_interable(self): data = {row: None for row in self.data} self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromDictKeys(self): + def test_inserttable_from_dict_keys(self): data = {row: None for row in self.data} keys = data.keys() self.c.inserttable('test', keys) self.assertEqual(self.get_back(), self.data) - def testInserttableFromDictValues(self): + def test_inserttable_from_dict_values(self): data = {i: row for i, row in enumerate(self.data)} values = data.values() self.c.inserttable('test', values) self.assertEqual(self.get_back(), self.data) - def testInserttableFromGeneratorOfTuples(self): + def test_inserttable_from_generator_of_tuples(self): data = (row for row in self.data) self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromListOfSets(self): + def test_inserttable_from_list_of_sets(self): data = [set(row) for row in self.data] try: self.c.inserttable('test', data) @@ -1871,14 +1871,14 @@ def testInserttableFromListOfSets(self): else: self.assertFalse('expected an error') - def testInserttableMultipleRows(self): + def test_inserttable_multiple_rows(self): num_rows = 100 data = self.data[2:3] * num_rows self.c.inserttable('test', data) r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) - def testInserttableMultipleCalls(self): + def test_inserttable_multiple_calls(self): num_rows = 10 data = self.data[2:3] for _i in range(num_rows): @@ -1886,23 +1886,23 @@ def testInserttableMultipleCalls(self): r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) - def testInserttableNullValues(self): + def test_inserttable_null_values(self): data = [(None,) * 14] * 100 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableNoColumn(self): + def test_inserttable_no_column(self): data = [()] * 10 self.c.inserttable('test', data, []) self.assertEqual(self.get_back(), []) - def testInserttableOnlyOneColumn(self): + def test_inserttable_only_one_column(self): data = [(42,)] * 50 self.c.inserttable('test', data, ['i4']) data = [tuple([42 if i == 1 else None for i in range(14)])] * 50 self.assertEqual(self.get_back(), data) - def testInserttableOnlyTwoColumns(self): + def test_inserttable_only_two_columns(self): data = [(bool(i % 2), i * .5) for i in range(20)] self.c.inserttable('test', data, ('b', 'f4')) # noinspection PyTypeChecker @@ -1910,12 +1910,12 @@ def testInserttableOnlyTwoColumns(self): + (None,) * 6 for i in range(20)] self.assertEqual(self.get_back(), data) - def testInserttableWithDottedTableName(self): + def test_inserttable_with_dotted_table_name(self): data = self.data self.c.inserttable('public.test', data) self.assertEqual(self.get_back(), data) - def testInserttableWithInvalidTableName(self): + def test_inserttable_with_invalid_table_name(self): data = [(42,)] # check that the table name is not inserted unescaped # (this would pass otherwise since there is a column named i4) @@ -1928,7 +1928,7 @@ def testInserttableWithInvalidTableName(self): # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i4']) - def testInserttableWithInvalidDataType(self): + def test_inserttable_with_invalid_data_type(self): try: self.c.inserttable('test', 42) except TypeError as e: @@ -1936,7 +1936,7 @@ def testInserttableWithInvalidDataType(self): else: self.assertFalse('expected an error') - def testInserttableWithInvalidColumnName(self): + def test_inserttable_with_invalid_column_name(self): data = [(2, 4)] # check that the column names are not inserted unescaped # (this would pass otherwise since there are columns i2 and i4) @@ -1950,7 +1950,7 @@ def testInserttableWithInvalidColumnName(self): # make sure that it works if parameters are passed properly self.c.inserttable('test', data, ['i2', 'i4']) - def testInserttableWithInvalidColumList(self): + def test_inserttable_with_invalid_colum_list(self): data = self.data try: self.c.inserttable('test', data, 'invalid') @@ -1960,7 +1960,7 @@ def testInserttableWithInvalidColumList(self): else: self.assertFalse('expected an error') - def testInserttableWithHugeListOfColumnNames(self): + def test_inserttable_with_huge_list_of_column_names(self): data = self.data # try inserting data with a huge list of column names cols = ['very_long_column_name'] * 2000 @@ -1970,13 +1970,13 @@ def testInserttableWithHugeListOfColumnNames(self): cols *= 2 self.assertRaises(MemoryError, self.c.inserttable, 'test', data, cols) - def testInserttableWithOutOfRangeData(self): + def test_inserttable_with_out_of_range_data(self): # try inserting data out of range for the column type # Should raise a value error because of smallint out of range self.assertRaises( ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) - def testInserttableMaxValues(self): + def test_inserttable_max_values(self): data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, True, '2999-12-31', '11:59:59', 1e99, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, @@ -1984,7 +1984,7 @@ def testInserttableMaxValues(self): self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableByteValues(self): + def test_inserttable_byte_values(self): try: self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'") except pg.DataError: @@ -2003,7 +2003,7 @@ def testInserttableByteValues(self): data = [row_unicode] * 2 self.assertEqual(self.get_back(), data) - def testInserttableUnicodeUtf8(self): + def test_inserttable_unicode_utf8(self): try: self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'") except pg.DataError: @@ -2018,7 +2018,7 @@ def testInserttableUnicodeUtf8(self): self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableUnicodeLatin1(self): + def test_inserttable_unicode_latin1(self): try: self.c.query("set client_encoding=latin1") self.c.query("select '¥'") @@ -2040,7 +2040,7 @@ def testInserttableUnicodeLatin1(self): self.c.inserttable('test', data) self.assertEqual(self.get_back('latin1'), data) - def testInserttableUnicodeLatin9(self): + def test_inserttable_unicode_latin9(self): try: self.c.query("set client_encoding=latin9") self.c.query("select '€'") @@ -2057,7 +2057,7 @@ def testInserttableUnicodeLatin9(self): self.c.inserttable('test', data) self.assertEqual(self.get_back('latin9'), data) - def testInserttableNoEncoding(self): + def test_inserttable_no_encoding(self): self.c.query("set client_encoding=sql_ascii") # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' @@ -2069,7 +2069,7 @@ def testInserttableNoEncoding(self): # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) - def testInserttableFromQuery(self): + def test_inserttable_from_query(self): data = self.c.query( "select 2::int2 as i2, 4::int4 as i4, 8::int8 as i8, true as b," "null as dt, null as ti, null as d," @@ -2080,7 +2080,7 @@ def testInserttableFromQuery(self): (2, 4, 8, True, None, None, None, 4.5, 8.5, None, 'c', 'v4', None, 'text')]) - def testInserttableSpecialChars(self): + def test_inserttable_special_chars(self): class S: def __repr__(self): return s @@ -2093,7 +2093,7 @@ def __repr__(self): self.assertEqual( self.c.query('select t from test').getresult(), [(s,)] * 3) - def testInsertTableBigRowSize(self): + def test_insert_table_big_row_size(self): # inserting rows with a size of up to 64k bytes should work t = '*' * 50000 data = [(t,)] @@ -2105,7 +2105,7 @@ def testInsertTableBigRowSize(self): data = [(t,)] self.assertRaises(MemoryError, self.c.inserttable, 'test', data, ['t']) - def testInsertTableSmallIntOverflow(self): + def test_insert_table_small_int_overflow(self): rest_row = self.data[2][1:] data = [(32000,) + rest_row] self.c.inserttable('test', data) @@ -2148,7 +2148,7 @@ def tearDown(self): self.c.query("truncate table test") self.c.close() - def testPutline(self): + def test_putline(self): putline = self.c.putline query = self.c.query data = list(enumerate("apple pear plum cherry banana".split())) @@ -2161,7 +2161,7 @@ def testPutline(self): r = query("select * from test").getresult() self.assertEqual(r, data) - def testPutlineBytesAndUnicode(self): + def test_putline_bytes_and_unicode(self): putline = self.c.putline query = self.c.query try: @@ -2177,7 +2177,7 @@ def testPutlineBytesAndUnicode(self): r = query("select * from test").getresult() self.assertEqual(r, [(47, 'käse'), (35, 'würstel')]) - def testGetline(self): + def test_getline(self): getline = self.c.getline query = self.c.query data = list(enumerate("apple banana pear plum strawberry".split())) @@ -2198,7 +2198,7 @@ def testGetline(self): except OSError: pass - def testGetlineBytesAndUnicode(self): + def test_getline_bytes_and_unicode(self): getline = self.c.getline query = self.c.query try: @@ -2222,7 +2222,7 @@ def testGetlineBytesAndUnicode(self): except OSError: pass - def testParameterChecks(self): + def test_parameter_checks(self): self.assertRaises(TypeError, self.c.putline) self.assertRaises(TypeError, self.c.getline, 'invalid') self.assertRaises(TypeError, self.c.endcopy, 'invalid') @@ -2238,7 +2238,7 @@ def tearDown(self): self.doCleanups() self.c.close() - def testGetNotify(self): + def test_get_notify(self): getnotify = self.c.getnotify query = self.c.query self.assertIsNone(getnotify()) @@ -2268,23 +2268,23 @@ def testGetNotify(self): finally: query('unlisten test_notify') - def testGetNoticeReceiver(self): + def test_get_notice_receiver(self): self.assertIsNone(self.c.get_notice_receiver()) - def testSetNoticeReceiver(self): + def test_set_notice_receiver(self): self.assertRaises(TypeError, self.c.set_notice_receiver, 42) self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid') self.assertIsNone(self.c.set_notice_receiver(lambda notice: None)) self.assertIsNone(self.c.set_notice_receiver(None)) - def testSetAndGetNoticeReceiver(self): + def test_set_and_get_notice_receiver(self): r = lambda notice: None # noqa: E731 self.assertIsNone(self.c.set_notice_receiver(r)) self.assertIs(self.c.get_notice_receiver(), r) self.assertIsNone(self.c.set_notice_receiver(None)) self.assertIsNone(self.c.get_notice_receiver()) - def testNoticeReceiver(self): + def test_notice_receiver(self): self.addCleanup(self.c.query, 'drop function bilbo_notice();') self.c.query('''create function bilbo_notice() returns void AS $$ begin @@ -2326,7 +2326,7 @@ def setUp(self): def tearDown(self): self.c.close() - def testGetDecimalPoint(self): + def test_get_decimal_point(self): point = pg.get_decimal_point() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal_point, point) @@ -2359,7 +2359,7 @@ def testGetDecimalPoint(self): pg.set_decimal_point(point) self.assertIsNone(r) - def testSetDecimalPoint(self): + def test_set_decimal_point(self): d = pg.Decimal point = pg.get_decimal_point() self.assertRaises(TypeError, pg.set_decimal_point) @@ -2483,7 +2483,7 @@ def testSetDecimalPoint(self): pg.set_decimal_point(point) self.assertEqual(r, bad_money) - def testGetDecimal(self): + def test_get_decimal(self): decimal_class = pg.get_decimal() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal, decimal_class) @@ -2497,7 +2497,7 @@ def testGetDecimal(self): r = pg.get_decimal() self.assertIs(r, decimal_class) - def testSetDecimal(self): + def test_set_decimal(self): decimal_class = pg.get_decimal() # error if no parameter is passed self.assertRaises(TypeError, pg.set_decimal) @@ -2520,7 +2520,7 @@ def testSetDecimal(self): self.assertIsInstance(r, int) self.assertEqual(r, 3425) - def testGetBool(self): + def test_get_bool(self): use_bool = pg.get_bool() # error if a parameter is passed self.assertRaises(TypeError, pg.get_bool, use_bool) @@ -2555,7 +2555,7 @@ def testGetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, True) - def testSetBool(self): + def test_set_bool(self): use_bool = pg.get_bool() # error if no parameter is passed self.assertRaises(TypeError, pg.set_bool) @@ -2583,7 +2583,7 @@ def testSetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, True) - def testGetByteEscaped(self): + def test_get_byte_escaped(self): bytea_escaped = pg.get_bytea_escaped() # error if a parameter is passed self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped) @@ -2618,7 +2618,7 @@ def testGetByteEscaped(self): self.assertIsInstance(r, bool) self.assertIs(r, False) - def testSetByteaEscaped(self): + def test_set_bytea_escaped(self): bytea_escaped = pg.get_bytea_escaped() # error if no parameter is passed self.assertRaises(TypeError, pg.set_bytea_escaped) @@ -2646,7 +2646,7 @@ def testSetByteaEscaped(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') - def testSetRowFactorySize(self): + def test_set_row_factory_size(self): queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): @@ -2689,7 +2689,7 @@ def setUpClass(cls): db.close() cls.cls_set_up = True - def testEscapeString(self): + def test_escape_string(self): self.assertTrue(self.cls_set_up) f = pg.escape_string r = f(b'plain') @@ -2707,7 +2707,7 @@ def testEscapeString(self): r = f(r"It's bad to have a \ inside.") self.assertEqual(r, r"It''s bad to have a \\ inside.") - def testEscapeBytea(self): + def test_escape_bytea(self): self.assertTrue(self.cls_set_up) f = pg.escape_bytea r = f(b'plain') diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 8e64949d..1f7b3aac 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -37,7 +37,7 @@ do_not_ask_for_host_reason = 'libpq issue on Windows' -def DB(): +def DB(): # noqa: N802 """Create a DB wrapper object connecting to the test database.""" db = pg.DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) if debug: @@ -52,7 +52,7 @@ class TestAttrDict(unittest.TestCase): cls = pg.AttrDict base = OrderedDict - def testInit(self): + def test_init(self): a = self.cls() self.assertIsInstance(a, self.base) self.assertEqual(a, self.base()) @@ -65,7 +65,7 @@ def testInit(self): self.assertIsInstance(a, self.base) self.assertEqual(a, self.base(items)) - def testIter(self): + def test_iter(self): a = self.cls() self.assertEqual(list(a), []) keys = ['id', 'name', 'age'] @@ -73,7 +73,7 @@ def testIter(self): a = self.cls(items) self.assertEqual(list(a), keys) - def testKeys(self): + def test_keys(self): a = self.cls() self.assertEqual(list(a.keys()), []) keys = ['id', 'name', 'age'] @@ -81,7 +81,7 @@ def testKeys(self): a = self.cls(items) self.assertEqual(list(a.keys()), keys) - def testValues(self): + def test_values(self): a = self.cls() self.assertEqual(list(a.values()), []) items = [('id', 'int'), ('name', 'text')] @@ -89,21 +89,21 @@ def testValues(self): a = self.cls(items) self.assertEqual(list(a.values()), values) - def testItems(self): + def test_items(self): a = self.cls() self.assertEqual(list(a.items()), []) items = [('id', 'int'), ('name', 'text')] a = self.cls(items) self.assertEqual(list(a.items()), items) - def testGet(self): + def test_get(self): a = self.cls([('id', 1)]) try: self.assertEqual(a['id'], 1) except KeyError: self.fail('AttrDict should be readable') - def testSet(self): + def test_set(self): a = self.cls() try: a['id'] = 1 @@ -112,7 +112,7 @@ def testSet(self): else: self.fail('AttrDict should be read-only') - def testDel(self): + def test_del(self): a = self.cls([('id', 1)]) try: del a['id'] @@ -121,7 +121,7 @@ def testDel(self): else: self.fail('AttrDict should be read-only') - def testWriteMethods(self): + def test_write_methods(self): a = self.cls([('id', 1)]) self.assertEqual(a['id'], 1) for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': @@ -132,17 +132,17 @@ def testWriteMethods(self): class TestDBClassInit(unittest.TestCase): """Test proper handling of errors when creating DB instances.""" - def testBadParams(self): + def test_bad_params(self): self.assertRaises(TypeError, pg.DB, invalid=True) # noinspection PyUnboundLocalVariable - def testDeleteDb(self): + def test_delete_db(self): db = DB() del db.db self.assertRaises(pg.InternalError, db.close) del db - def testAsyncQueryBeforeDeletion(self): + def test_async_query_before_deletion(self): db = DB() query = db.send_query('select 1') self.assertEqual(query.getresult(), [(1,)]) @@ -151,7 +151,7 @@ def testAsyncQueryBeforeDeletion(self): del db gc.collect() - def testAsyncQueryAfterDeletion(self): + def test_async_query_after_deletion(self): db = DB() query = db.send_query('select 1') del db @@ -172,7 +172,7 @@ def tearDown(self): except pg.InternalError: pass - def testAllDBAttributes(self): + def test_all_db_attributes(self): attributes = [ 'abort', 'adapter', 'backend_pid', 'begin', @@ -210,19 +210,19 @@ def testAllDBAttributes(self): db_attributes = [a for a in self.db.__dir__() if not a.startswith('_')] self.assertEqual(attributes, db_attributes) - def testAttributeDb(self): + def test_attribute_db(self): self.assertEqual(self.db.db.db, dbname) - def testAttributeDbname(self): + def test_attribute_dbname(self): self.assertEqual(self.db.dbname, dbname) - def testAttributeError(self): + def test_attribute_error(self): error = self.db.error self.assertTrue(not error or 'krb5_' in error) self.assertEqual(self.db.error, self.db.db.error) @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason) - def testAttributeHost(self): + def test_attribute_host(self): if dbhost and not dbhost.startswith('/'): host = dbhost else: @@ -231,61 +231,61 @@ def testAttributeHost(self): self.assertEqual(self.db.host, host) self.assertEqual(self.db.db.host, host) - def testAttributeOptions(self): + def test_attribute_options(self): no_options = '' options = self.db.options self.assertEqual(options, no_options) self.assertEqual(options, self.db.db.options) - def testAttributePort(self): + def test_attribute_port(self): def_port = 5432 port = self.db.port self.assertIsInstance(port, int) self.assertEqual(port, dbport or def_port) self.assertEqual(port, self.db.db.port) - def testAttributeProtocolVersion(self): + def test_attribute_protocol_version(self): protocol_version = self.db.protocol_version self.assertIsInstance(protocol_version, int) self.assertTrue(2 <= protocol_version < 4) self.assertEqual(protocol_version, self.db.db.protocol_version) - def testAttributeServerVersion(self): + def test_attribute_server_version(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) self.assertTrue(100000 <= server_version < 160000) self.assertEqual(server_version, self.db.db.server_version) - def testAttributeSocket(self): + def test_attribute_socket(self): socket = self.db.socket self.assertIsInstance(socket, int) self.assertGreaterEqual(socket, 0) - def testAttributeBackendPid(self): + def test_attribute_backend_pid(self): backend_pid = self.db.backend_pid self.assertIsInstance(backend_pid, int) self.assertGreaterEqual(backend_pid, 1) - def testAttributeSslInUse(self): + def test_attribute_ssl_in_use(self): ssl_in_use = self.db.ssl_in_use self.assertIsInstance(ssl_in_use, bool) self.assertFalse(ssl_in_use) - def testAttributeSslAttributes(self): + def test_attribute_ssl_attributes(self): ssl_attributes = self.db.ssl_attributes self.assertIsInstance(ssl_attributes, dict) self.assertEqual(ssl_attributes, { 'cipher': None, 'compression': None, 'key_bits': None, 'library': None, 'protocol': None}) - def testAttributeStatus(self): + def test_attribute_status(self): status_ok = 1 status = self.db.status self.assertIsInstance(status, int) self.assertEqual(status, status_ok) self.assertEqual(status, self.db.db.status) - def testAttributeUser(self): + def test_attribute_user(self): no_user = 'Deprecated facility' user = self.db.user self.assertTrue(user) @@ -293,29 +293,29 @@ def testAttributeUser(self): self.assertNotEqual(user, no_user) self.assertEqual(user, self.db.db.user) - def testMethodEscapeLiteral(self): + def test_method_escape_literal(self): self.assertEqual(self.db.escape_literal(''), "''") - def testMethodEscapeIdentifier(self): + def test_method_escape_identifier(self): self.assertEqual(self.db.escape_identifier(''), '""') - def testMethodEscapeString(self): + def test_method_escape_string(self): self.assertEqual(self.db.escape_string(''), '') - def testMethodEscapeBytea(self): + def test_method_escape_bytea(self): self.assertEqual(self.db.escape_bytea('').replace( '\\x', '').replace('\\', ''), '') - def testMethodUnescapeBytea(self): + def test_method_unescape_bytea(self): self.assertEqual(self.db.unescape_bytea(''), b'') - def testMethodDecodeJson(self): + def test_method_decode_json(self): self.assertEqual(self.db.decode_json('{}'), {}) - def testMethodEncodeJson(self): + def test_method_encode_json(self): self.assertEqual(self.db.encode_json({}), '{}') - def testMethodQuery(self): + def test_method_query(self): query = self.db.query query("select 1+1") query("select 1+$1+$2", 2, 3) @@ -323,23 +323,23 @@ def testMethodQuery(self): query("select 1+$1+$2", [2, 3]) query("select 1+$1", 1) - def testMethodQueryEmpty(self): + def test_method_query_empty(self): self.assertRaises(ValueError, self.db.query, '') - def testMethodQueryDataError(self): + def test_method_query_data_error(self): try: self.db.query("select 1/0") except pg.DataError as error: # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') - def testMethodEndcopy(self): + def test_method_endcopy(self): try: self.db.endcopy() except OSError: pass - def testMethodClose(self): + def test_method_close(self): self.db.close() try: self.db.reset() @@ -354,7 +354,7 @@ def testMethodClose(self): self.assertRaises(pg.InternalError, getattr, self.db, 'error') self.assertRaises(pg.InternalError, getattr, self.db, 'absent') - def testMethodReset(self): + def test_method_reset(self): con = self.db.db self.db.reset() self.assertIs(self.db.db, con) @@ -362,7 +362,7 @@ def testMethodReset(self): self.db.close() self.assertRaises(pg.InternalError, self.db.reset) - def testMethodReopen(self): + def test_method_reopen(self): con = self.db.db self.db.reopen() self.assertIsNot(self.db.db, con) @@ -374,7 +374,7 @@ def testMethodReopen(self): self.db.query("select 1+1") self.db.close() - def testExistingConnection(self): + def test_existing_connection(self): db = pg.DB(self.db.db) self.assertIsNotNone(db.db) self.assertEqual(self.db.db, db.db) @@ -391,7 +391,7 @@ def testExistingConnection(self): db = pg.DB(db=self.db.db) self.assertEqual(self.db.db, db.db) - def testExistingDbApi2Connection(self): + def test_existing_db_api2_connection(self): class DBApi2Con: @@ -461,7 +461,7 @@ def tearDown(self): self.doCleanups() self.db.close() - def createTable(self, table, definition, + def create_table(self, table, definition, temporary=True, oids=None, values=None): query = self.db.query if '"' not in table or '.' in table: @@ -491,14 +491,14 @@ def createTable(self, table, definition, q = f"insert into {table} values ({values})" query(q, params) - def testClassName(self): + def test_class_name(self): self.assertEqual(self.db.__class__.__name__, 'DB') - def testModuleName(self): + def test_module_name(self): self.assertEqual(self.db.__module__, 'pg') self.assertEqual(self.db.__class__.__module__, 'pg') - def testEscapeLiteral(self): + def test_escape_literal(self): f = self.db.escape_literal r = f(b"plain") self.assertIsInstance(r, bytes) @@ -517,7 +517,7 @@ def testEscapeLiteral(self): self.assertEqual(f('No "quotes" must be escaped.'), "'No \"quotes\" must be escaped.'") - def testEscapeIdentifier(self): + def test_escape_identifier(self): f = self.db.escape_identifier r = f(b"plain") self.assertIsInstance(r, bytes) @@ -536,7 +536,7 @@ def testEscapeIdentifier(self): self.assertEqual(f('All "quotes" must be escaped.'), '"All ""quotes"" must be escaped."') - def testEscapeString(self): + def test_escape_string(self): f = self.db.escape_string r = f(b"plain") self.assertIsInstance(r, bytes) @@ -553,7 +553,7 @@ def testEscapeString(self): self.assertEqual(f(r"It's fine to have a \ inside."), r"It''s fine to have a \ inside.") - def testEscapeBytea(self): + def test_escape_bytea(self): f = self.db.escape_bytea # note that escape_byte always returns hex output since Pg 9.0, # regardless of the bytea_output setting @@ -571,7 +571,7 @@ def testEscapeBytea(self): self.assertEqual(r, '\\x64617320697327206bc3a47365') self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21') - def testUnescapeBytea(self): + def test_unescape_bytea(self): f = self.db.unescape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) @@ -591,7 +591,7 @@ def testUnescapeBytea(self): b'\\x746861742773206be47365') self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21') - def testDecodeJson(self): + def test_decode_json(self): f = self.db.decode_json self.assertIsNone(f('null')) data = { @@ -610,7 +610,7 @@ def testDecodeJson(self): self.assertIsInstance(r['tags'], list) self.assertIsInstance(r['stock'], dict) - def testEncodeJson(self): + def test_encode_json(self): f = self.db.encode_json self.assertEqual(f(None), 'null') data = { @@ -623,7 +623,7 @@ def testEncodeJson(self): self.assertIsInstance(r, str) self.assertEqual(r, text) - def testGetParameter(self): + def test_get_parameter(self): f = self.db.get_parameter self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -660,14 +660,14 @@ def testGetParameter(self): self.assertIs(r, s) self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'}) - def testGetParameterServerVersion(self): + def test_get_parameter_server_version(self): r = self.db.get_parameter('server_version_num') self.assertIsInstance(r, str) s = self.db.server_version self.assertIsInstance(s, int) self.assertEqual(r, str(s)) - def testGetParameterAll(self): + def test_get_parameter_all(self): f = self.db.get_parameter r = f('all') self.assertIsInstance(r, dict) @@ -676,7 +676,7 @@ def testGetParameterAll(self): self.assertEqual(r['DateStyle'], 'ISO, YMD') self.assertEqual(r['bytea_output'], 'hex') - def testSetParameter(self): + def test_set_parameter(self): f = self.db.set_parameter g = self.db.get_parameter self.assertRaises(TypeError, f) @@ -720,7 +720,7 @@ def testSetParameter(self): self.assertEqual(g('standard_conforming_strings'), 'on') self.assertEqual(g('datestyle'), 'ISO, YMD') - def testResetParameter(self): + def test_reset_parameter(self): db = DB() f = db.set_parameter g = db.get_parameter @@ -761,7 +761,7 @@ def testResetParameter(self): self.assertEqual(g('standard_conforming_strings'), scs) db.close() - def testResetParameterAll(self): + def test_reset_parameter_all(self): db = DB() f = db.set_parameter self.assertRaises(ValueError, f, 'all', 0) @@ -782,7 +782,7 @@ def testResetParameterAll(self): self.assertEqual(g('standard_conforming_strings'), scs) db.close() - def testSetParameterLocal(self): + def test_set_parameter_local(self): f = self.db.set_parameter g = self.db.get_parameter self.assertEqual(g('standard_conforming_strings'), 'on') @@ -792,7 +792,7 @@ def testSetParameterLocal(self): self.db.end() self.assertEqual(g('standard_conforming_strings'), 'on') - def testSetParameterSession(self): + def test_set_parameter_session(self): f = self.db.set_parameter g = self.db.get_parameter self.assertEqual(g('standard_conforming_strings'), 'on') @@ -802,7 +802,7 @@ def testSetParameterSession(self): self.db.end() self.assertEqual(g('standard_conforming_strings'), 'off') - def testReset(self): + def test_reset(self): db = DB() default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' @@ -823,7 +823,7 @@ def testReset(self): self.assertEqual(r, default_datestyle) db.close() - def testReopen(self): + def test_reopen(self): db = DB() default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' @@ -842,21 +842,21 @@ def testReopen(self): self.assertEqual(r, default_datestyle) db.close() - def testCreateTable(self): + def test_create_table(self): table = 'test hello world' values = [(2, "World!"), (1, "Hello")] - self.createTable(table, "n smallint, t varchar", + self.create_table(table, "n smallint, t varchar", temporary=True, oids=False, values=values) r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") - def testCreateTableWithOids(self): + def test_create_table_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") table = 'test hello world' values = [(2, "World!"), (1, "Hello")] - self.createTable(table, "n smallint, t varchar", + self.create_table(table, "n smallint, t varchar", temporary=True, oids=True, values=values) r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) @@ -864,10 +864,10 @@ def testCreateTableWithOids(self): r = self.db.query(f'select oid from "{table}" limit 1').getresult() self.assertIsInstance(r[0][0], int) - def testQuery(self): + def test_query(self): query = self.db.query table = 'test_table' - self.createTable(table, "n integer", oids=False) + self.create_table(table, "n integer", oids=False) q = "insert into test_table values (1)" r = query(q) self.assertIsInstance(r, str) @@ -898,12 +898,12 @@ def testQuery(self): self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testQueryWithOids(self): + def test_query_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") query = self.db.query table = 'test_table' - self.createTable(table, "n integer", oids=True) + self.create_table(table, "n integer", oids=True) q = "insert into test_table values (1)" r = query(q) self.assertIsInstance(r, int) @@ -932,15 +932,15 @@ def testQueryWithOids(self): self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testMultipleQueries(self): + def test_multiple_queries(self): self.assertEqual(self.db.query( "create temporary table test_multi (n integer);" "insert into test_multi values (4711);" "select n from test_multi").getresult()[0][0], 4711) - def testQueryWithParams(self): + def test_query_with_params(self): query = self.db.query - self.createTable('test_table', 'n1 integer, n2 integer', oids=False) + self.create_table('test_table', 'n1 integer, n2 integer', oids=False) q = "insert into test_table values ($1, $2)" r = query(q, (1, 2)) self.assertEqual(r, '1') @@ -963,17 +963,17 @@ def testQueryWithParams(self): r = query(q, 4) self.assertEqual(r, '3') - def testEmptyQuery(self): + def test_empty_query(self): self.assertRaises(ValueError, self.db.query, '') - def testQueryDataError(self): + def test_query_data_error(self): try: self.db.query("select 1/0") except pg.DataError as error: # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') - def testQueryFormatted(self): + def test_query_formatted(self): f = self.db.query_formatted t = True if pg.get_bool() else 't' # test with tuple @@ -1001,7 +1001,7 @@ def testQueryFormatted(self): r = q.getresult()[0][0] self.assertEqual(r, 'alphabetagammadeltaepsilon') - def testQueryFormattedWithAny(self): + def test_query_formatted_with_any(self): f = self.db.query_formatted q = "select 2 = any(%s)" r = f(q, [[1, 3]]).getresult()[0][0] @@ -1013,7 +1013,7 @@ def testQueryFormattedWithAny(self): r = f(q, [[None]]).getresult()[0][0] self.assertIsNone(r) - def testQueryFormattedWithoutParams(self): + def test_query_formatted_without_params(self): f = self.db.query_formatted q = "select 42" r = f(q).getresult()[0][0] @@ -1025,19 +1025,19 @@ def testQueryFormattedWithoutParams(self): r = f(q, {}).getresult()[0][0] self.assertEqual(r, 42) - def testPrepare(self): + def test_prepare(self): p = self.db.prepare self.assertIsNone(p('my query', "select 'hello'")) self.assertIsNone(p('my other query', "select 'world'")) self.assertRaises( pg.ProgrammingError, p, 'my query', "select 'hello, too'") - def testPrepareUnnamed(self): + def test_prepare_unnamed(self): p = self.db.prepare self.assertIsNone(p('', "select null")) self.assertIsNone(p(None, "select null")) - def testQueryPreparedWithoutParams(self): + def test_query_prepared_without_params(self): f = self.db.query_prepared self.assertRaises(pg.OperationalError, f, 'q') p = self.db.prepare @@ -1048,7 +1048,7 @@ def testQueryPreparedWithoutParams(self): r = f('q2').getresult()[0][0] self.assertEqual(r, 42) - def testQueryPreparedWithParams(self): + def test_query_prepared_with_params(self): p = self.db.prepare p('sum', "select 1 + $1 + $2 + $3") p('cat', "select initcap($1) || ', ' || $2 || '!'") @@ -1058,7 +1058,7 @@ def testQueryPreparedWithParams(self): r = f('cat', 'hello', 'world').getresult()[0][0] self.assertEqual(r, 'Hello, world!') - def testQueryPreparedUnnamedWithOutParams(self): + def test_query_prepared_unnamed_with_out_params(self): f = self.db.query_prepared self.assertRaises(pg.OperationalError, f, None) self.assertRaises(pg.OperationalError, f, '') @@ -1076,7 +1076,7 @@ def testQueryPreparedUnnamedWithOutParams(self): r = f('').getresult()[0][0] self.assertEqual(r, 'none') - def testQueryPreparedUnnamedWithParams(self): + def test_query_prepared_unnamed_with_params(self): p = self.db.prepare p('', "select 1 + $1 + $2") f = self.db.query_prepared @@ -1091,13 +1091,13 @@ def testQueryPreparedUnnamedWithParams(self): r = f(None, 3, 4).getresult()[0][0] self.assertEqual(r, 9) - def testDescribePrepared(self): + def test_describe_prepared(self): self.db.prepare('count', "select 1 as first, 2 as second") f = self.db.describe_prepared r = f('count').listfields() self.assertEqual(r, ('first', 'second')) - def testDescribePreparedUnnamed(self): + def test_describe_prepared_unnamed(self): self.db.prepare('', "select null as anon") f = self.db.describe_prepared r = f().listfields() @@ -1107,7 +1107,7 @@ def testDescribePreparedUnnamed(self): r = f('').listfields() self.assertEqual(r, ('anon',)) - def testDeletePrepared(self): + def test_delete_prepared(self): f = self.db.delete_prepared f() e = pg.OperationalError @@ -1125,27 +1125,27 @@ def testDeletePrepared(self): self.assertRaises(e, f, 'q1') self.assertRaises(e, f, 'q2') - def testPkey(self): + def test_pkey(self): query = self.db.query pkey = self.db.pkey self.assertRaises(KeyError, pkey, 'test') for t in ('pkeytest', 'primary key test'): - self.createTable(f'{t}0', 'a smallint') - self.createTable(f'{t}1', 'b smallint primary key') - self.createTable(f'{t}2', 'c smallint, d smallint primary key') - self.createTable( + self.create_table(f'{t}0', 'a smallint') + self.create_table(f'{t}1', 'b smallint primary key') + self.create_table(f'{t}2', 'c smallint, d smallint primary key') + self.create_table( f'{t}3', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (f, h)') - self.createTable( + self.create_table( f'{t}4', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (h, f)') - self.createTable( + self.create_table( f'{t}5', 'more_than_one_letter varchar primary key') - self.createTable( + self.create_table( f'{t}6', '"with space" date primary key') - self.createTable( + self.create_table( f'{t}7', 'a_very_long_column_name varchar, "with space" date, "42" int,' ' primary key (a_very_long_column_name, "with space", "42")') @@ -1181,7 +1181,7 @@ def testPkey(self): # we get the changed primary key when the cache is flushed self.assertEqual(pkey(f'{t}1', flush=True), 'x') - def testGetDatabases(self): + def test_get_databases(self): databases = self.db.get_databases() self.assertIn('template0', databases) self.assertIn('template1', databases) @@ -1189,7 +1189,7 @@ def testGetDatabases(self): self.assertIn('postgres', databases) self.assertIn(dbname, databases) - def testGetTables(self): + def test_get_tables(self): get_tables = self.db.get_tables tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe', 'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321', @@ -1208,7 +1208,7 @@ def testGetTables(self): self.assertNotEqual(t, 'information_schema') self.assertFalse(t.startswith('pg_')) for t in tables: - self.createTable(t, 'as select 0', temporary=False) + self.create_table(t, 'as select 0', temporary=False) current_tables = get_tables() new_tables = [t for t in current_tables if t not in before_tables] expected_new_tables = ['public.' + ( @@ -1218,7 +1218,7 @@ def testGetTables(self): after_tables = get_tables() self.assertEqual(after_tables, before_tables) - def testGetSystemTables(self): + def test_get_system_tables(self): get_tables = self.db.get_tables result = get_tables() self.assertNotIn('pg_catalog.pg_class', result) @@ -1230,7 +1230,7 @@ def testGetSystemTables(self): self.assertIn('pg_catalog.pg_class', result) self.assertNotIn('information_schema.tables', result) - def testGetRelations(self): + def test_get_relations(self): get_relations = self.db.get_relations result = get_relations() self.assertIn('public.test', result) @@ -1248,7 +1248,7 @@ def testGetRelations(self): self.assertNotIn('public.test', result) self.assertNotIn('public.test_view', result) - def testGetSystemRelations(self): + def test_get_system_relations(self): get_relations = self.db.get_relations result = get_relations() self.assertNotIn('pg_catalog.pg_class', result) @@ -1260,7 +1260,7 @@ def testGetSystemRelations(self): self.assertIn('pg_catalog.pg_class', result) self.assertIn('information_schema.tables', result) - def testGetAttnames(self): + def test_get_attnames(self): get_attnames = self.db.get_attnames self.assertRaises(pg.ProgrammingError, self.db.get_attnames, 'does_not_exist') @@ -1278,7 +1278,7 @@ def testGetAttnames(self): i2='int', i4='int', i8='int', d='num', f4='float', f8='float', m='money', v4='text', c4='text', t='text')) - self.createTable('test_table', + self.create_table('test_table', 'n int, alpha smallint, beta bool,' ' gamma char(5), tau text, v varchar(3)') r = get_attnames('test_table') @@ -1292,10 +1292,10 @@ def testGetAttnames(self): n='int', alpha='int', beta='bool', gamma='text', tau='text', v='text')) - def testGetAttnamesWithQuotes(self): + def test_get_attnames_with_quotes(self): get_attnames = self.db.get_attnames table = 'test table for get_attnames()' - self.createTable( + self.create_table( table, '"Prime!" smallint, "much space" integer, "Questions?" text') r = get_attnames(table) @@ -1308,7 +1308,7 @@ def testGetAttnamesWithQuotes(self): self.assertEqual(r, { 'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'}) table = 'yet another test table for get_attnames()' - self.createTable(table, + self.create_table(table, 'a smallint, b integer, c bigint,' ' e numeric, f real, f2 double precision, m money,' ' x smallint, y smallint, z smallint,' @@ -1333,9 +1333,9 @@ def testGetAttnamesWithQuotes(self): 'u': 'text', 't': 'text', 'v': 'text', 'y': 'int', 'x': 'int', 'z': 'int'}) - def testGetAttnamesWithRegtypes(self): + def test_get_attnames_with_regtypes(self): get_attnames = self.db.get_attnames - self.createTable( + self.create_table( 'test_table', 'n int, alpha smallint, beta bool,' ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes @@ -1351,9 +1351,9 @@ def testGetAttnamesWithRegtypes(self): n='integer', alpha='smallint', beta='boolean', gamma='character', tau='text', v='character varying')) - def testGetAttnamesWithoutRegtypes(self): + def test_get_attnames_without_regtypes(self): get_attnames = self.db.get_attnames - self.createTable( + self.create_table( 'test_table', 'n int, alpha smallint, beta bool,' ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes @@ -1369,12 +1369,12 @@ def testGetAttnamesWithoutRegtypes(self): n='int', alpha='int', beta='bool', gamma='text', tau='text', v='text')) - def testGetAttnamesIsCached(self): + def test_get_attnames_is_cached(self): get_attnames = self.db.get_attnames int_type = 'integer' if self.regtypes else 'int' text_type = 'text' query = self.db.query - self.createTable('test_table', 'col int') + self.create_table('test_table', 'col int') r = get_attnames("test_table") self.assertIsInstance(r, dict) self.assertEqual(r, dict(col=int_type)) @@ -1395,7 +1395,7 @@ def testGetAttnamesIsCached(self): r = get_attnames("test_table", flush=True) self.assertEqual(r, dict()) - def testGetAttnamesIsOrdered(self): + def test_get_attnames_is_ordered(self): get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) self.assertIsInstance(r, OrderedDict) @@ -1414,7 +1414,7 @@ def testGetAttnamesIsOrdered(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable( + self.create_table( table, 'n int, alpha smallint, v varchar(3),' ' gamma char(5), tau text, beta bool') r = get_attnames(table) @@ -1434,8 +1434,8 @@ def testGetAttnamesIsOrdered(self): else: self.skipTest('OrderedDict is not supported') - def testGetAttnamesIsAttrDict(self): - AttrDict = pg.AttrDict + def test_get_attnames_is_attr_dict(self): + AttrDict = pg.AttrDict # noqa: N806 get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) self.assertIsInstance(r, AttrDict) @@ -1453,7 +1453,7 @@ def testGetAttnamesIsAttrDict(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable( + self.create_table( table, 'n int, alpha smallint, v varchar(3),' ' gamma char(5), tau text, beta bool') r = get_attnames(table) @@ -1470,7 +1470,7 @@ def testGetAttnamesIsAttrDict(self): r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') - def testGetGenerated(self): + def test_get_generated(self): get_generated = self.db.get_generated server_version = self.db.server_version if server_version >= 100000: @@ -1483,7 +1483,7 @@ def testGetGenerated(self): self.assertFalse(r) if server_version >= 100000: table = 'test_get_generated_1' - self.createTable( + self.create_table( table, 'i int generated always as identity primary key,' ' j int generated always as identity,' @@ -1494,7 +1494,7 @@ def testGetGenerated(self): self.assertEqual(r, {'i', 'j'}) if server_version >= 120000: table = 'test_get_generated_2' - self.createTable( + self.create_table( table, 'n int, m int generated always as (n + 3) stored,' ' i int generated always as identity,' @@ -1503,21 +1503,21 @@ def testGetGenerated(self): self.assertIsInstance(r, frozenset) self.assertEqual(r, {'m', 'i'}) - def testGetGeneratedIsCached(self): + def test_get_generated_is_cached(self): server_version = self.db.server_version if server_version < 100000: self.skipTest("database does not support generated columns") get_generated = self.db.get_generated query = self.db.query table = 'test_get_generated_2' - self.createTable(table, 'i int primary key') + self.create_table(table, 'i int primary key') self.assertFalse(get_generated(table)) query(f'alter table {table} alter column i' ' add generated always as identity') self.assertFalse(get_generated(table)) self.assertEqual(get_generated(table, flush=True), {'i'}) - def testHasTablePrivilege(self): + def test_has_table_privilege(self): can = self.db.has_table_privilege self.assertEqual(can('test'), True) self.assertEqual(can('test', 'select'), True) @@ -1538,13 +1538,13 @@ def testHasTablePrivilege(self): self.assertEqual(can('pg_views', 'select'), True) self.assertEqual(can('pg_views', 'delete'), False) - def testGet(self): + def test_get(self): get = self.db.get query = self.db.query table = 'get_test_table' self.assertRaises(TypeError, get) self.assertRaises(TypeError, get, table) - self.createTable(table, 'n integer, t text', + self.create_table(table, 'n integer, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) r = get(table, 2, 'n') @@ -1593,13 +1593,13 @@ def testGet(self): s.pop('n') self.assertRaises(KeyError, get, table, s) - def testGetWithOids(self): + def test_get_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") get = self.db.get query = self.db.query table = 'get_with_oid_test_table' - self.createTable(table, 'n integer, t text', oids=True, + self.create_table(table, 'n integer, t text', oids=True, values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) self.assertRaises(KeyError, get, table, {}, 'oid') @@ -1659,10 +1659,10 @@ def testGetWithOids(self): self.assertEqual(r['n'], 3) self.assertNotEqual(r[qoid], oid) - def testGetWithCompositeKey(self): + def test_get_with_composite_key(self): get = self.db.get table = 'get_test_table_1' - self.createTable( + self.create_table( table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertEqual(get(table, 2)['t'], 'b') @@ -1674,7 +1674,7 @@ def testGetWithCompositeKey(self): self.assertEqual(get(table, ('a',), ('t',))['n'], 1) self.assertEqual(get(table, ['c'], ['t'])['n'], 3) table = 'get_test_table_2' - self.createTable( + self.create_table( table, 'n integer, m integer, t text, primary key (n, m)', values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) for n in range(3) for m in range(2)]) @@ -1691,10 +1691,10 @@ def testGetWithCompositeKey(self): self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c') self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f') - def testGetWithQuotedNames(self): + def test_get_with_quoted_names(self): get = self.db.get table = 'test table for get()' - self.createTable( + self.create_table( table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(17, 1001, 'No!')]) @@ -1704,7 +1704,7 @@ def testGetWithQuotedNames(self): self.assertEqual(r['much space'], 1001) self.assertEqual(r['Questions?'], 'No!') - def testGetFromView(self): + def test_get_from_view(self): self.db.query('delete from test where i4=14') self.db.query('insert into test (i4, v4) values(' "14, 'abc4')") @@ -1712,10 +1712,10 @@ def testGetFromView(self): self.assertIn('v4', r) self.assertEqual(r['v4'], 'abc4') - def testGetLittleBobbyTables(self): + def test_get_little_bobby_tables(self): get = self.db.get query = self.db.query - self.createTable( + self.create_table( 'test_students', 'firstname varchar primary key, nickname varchar, grade char(2)', values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'), @@ -1748,13 +1748,13 @@ def testGetLittleBobbyTables(self): self.assertEqual(len(r), 3) self.assertEqual(r[1][2], 'D-') - def testInsert(self): + def test_insert(self): insert = self.db.insert query = self.db.query bool_on = pg.get_bool() decimal = pg.get_decimal() table = 'insert_test_table' - self.createTable( + self.create_table( table, 'i2 smallint, i4 integer, i8 bigint,' ' d numeric, f4 real, f8 double precision, m money,' ' v4 varchar(4), c4 char(4), t text,' @@ -1840,12 +1840,12 @@ def testInsert(self): self.assertEqual(data, expect) query(f'truncate table "{table}"') - def testInsertWithOids(self): + def test_insert_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") insert = self.db.insert query = self.db.query - self.createTable('test_table', 'n int', oids=True) + self.create_table('test_table', 'n int', oids=True) self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1) r = insert('test_table', n=1) self.assertIsInstance(r, dict) @@ -1910,11 +1910,11 @@ def testInsertWithOids(self): r = ' '.join(str(row[0]) for row in query(q).getresult()) self.assertEqual(r, '6 7') - def testInsertWithQuotedNames(self): + def test_insert_with_quoted_names(self): insert = self.db.insert query = self.db.query table = 'test table for insert()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} r = insert(table, r) @@ -1929,7 +1929,7 @@ def testInsertWithQuotedNames(self): self.assertEqual(r['much space'], 2002) self.assertEqual(r['Questions?'], 'What?') - def testInsertIntoView(self): + def test_insert_into_view(self): insert = self.db.insert query = self.db.query query("truncate table test") @@ -1955,7 +1955,7 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')]) - def testInsertWithGeneratedColumns(self): + def test_insert_with_generated_columns(self): insert = self.db.insert get = self.db.get server_version = self.db.server_version @@ -1971,7 +1971,7 @@ def testInsertWithGeneratedColumns(self): table_def += ', j int generated always as (i + 7) stored' else: table_def += ', j int not null default 42' - self.createTable(table, table_def) + self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 r = insert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) @@ -1981,13 +1981,13 @@ def testInsertWithGeneratedColumns(self): self.assertIsInstance(r, dict) self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) - def testUpdate(self): + def test_update(self): update = self.db.update query = self.db.query self.assertRaises(pg.ProgrammingError, update, 'test', i2=2, i4=4, i8=8) table = 'update_test_table' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) r = self.db.get(table, 2) @@ -1998,13 +1998,13 @@ def testUpdate(self): r = query(q).getresult()[0][0] self.assertEqual(r, 'u') - def testUpdateWithOids(self): + def test_update_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") update = self.db.update get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=[1]) + self.create_table('test_table', 'n int', oids=True, values=[1]) s = get('test_table', 1, 'n') self.assertIsInstance(s, dict) self.assertEqual(s['n'], 1) @@ -2078,13 +2078,13 @@ def testUpdateWithOids(self): r = query(q).getresult() self.assertEqual(r, [(1, 3), (4, 7)]) - def testUpdateWithoutOid(self): + def test_update_without_oid(self): update = self.db.update query = self.db.query self.assertRaises(pg.ProgrammingError, update, 'test', i2=2, i4=4, i8=8) table = 'update_test_table' - self.createTable(table, 'n integer primary key, t text', oids=False, + self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) r = self.db.get(table, 2) r['t'] = 'u' @@ -2094,11 +2094,11 @@ def testUpdateWithoutOid(self): r = query(q).getresult()[0][0] self.assertEqual(r, 'u') - def testUpdateWithCompositeKey(self): + def test_update_with_composite_key(self): update = self.db.update query = self.db.query table = 'update_test_table_1' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertRaises(KeyError, update, table, dict(t='b')) s = dict(n=2, t='d') @@ -2121,7 +2121,7 @@ def testUpdateWithCompositeKey(self): self.assertEqual(len(r), 0) query(f'drop table "{table}"') table = 'update_test_table_2' - self.createTable(table, + self.create_table(table, 'n integer, m integer, t text, primary key (n, m)', values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) for n in range(3) for m in range(2)]) @@ -2132,11 +2132,11 @@ def testUpdateWithCompositeKey(self): r = [r[0] for r in query(q).getresult()] self.assertEqual(r, ['c', 'x']) - def testUpdateWithQuotedNames(self): + def test_update_with_quoted_names(self): update = self.db.update query = self.db.query table = 'test table for update()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(13, 3003, 'Why!')]) r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} @@ -2152,7 +2152,7 @@ def testUpdateWithQuotedNames(self): self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') - def testUpdateWithGeneratedColumns(self): + def test_update_with_generated_columns(self): update = self.db.update get = self.db.get query = self.db.query @@ -2169,7 +2169,7 @@ def testUpdateWithGeneratedColumns(self): table_def += ', j int generated always as (i + 7) stored' else: table_def += ', j int not null default 42' - self.createTable(table, table_def) + self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 r = query(f'insert into {table} (i, d) values ({i}, {d})') @@ -2184,13 +2184,13 @@ def testUpdateWithGeneratedColumns(self): j += 1 self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) - def testUpsert(self): + def test_upsert(self): upsert = self.db.upsert query = self.db.query self.assertRaises(pg.ProgrammingError, upsert, 'test', i2=2, i4=4, i8=8) table = 'upsert_test_table' - self.createTable(table, 'n integer primary key, t text') + self.create_table(table, 'n integer primary key, t text') s = dict(n=1, t='x') r = upsert(table, s) self.assertIs(r, s) @@ -2257,13 +2257,13 @@ def testUpsert(self): r = query(q).getresult() self.assertEqual(r, [(1, 'x2'), (2, 'y3')]) - def testUpsertWithOids(self): + def test_upsert_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") upsert = self.db.upsert get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=[1]) + self.create_table('test_table', 'n int', oids=True, values=[1]) self.assertRaises(pg.ProgrammingError, upsert, 'test_table', dict(n=2)) r = get('test_table', 1, 'n') @@ -2338,11 +2338,11 @@ def testUpsertWithOids(self): q = query("select n, m from test_table order by n limit 3") self.assertEqual(q.getresult(), [(1, 5), (2, 10)]) - def testUpsertWithCompositeKey(self): + def test_upsert_with_composite_key(self): upsert = self.db.upsert query = self.db.query table = 'upsert_test_table_2' - self.createTable( + self.create_table( table, 'n integer, m integer, t text, primary key (n, m)') s = dict(n=1, m=2, t='x') r = upsert(table, s) @@ -2400,11 +2400,11 @@ def testUpsertWithCompositeKey(self): r = query(q).getresult() self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')]) - def testUpsertWithQuotedNames(self): + def test_upsert_with_quoted_names(self): upsert = self.db.upsert query = self.db.query table = 'test table for upsert()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} r = upsert(table, s) @@ -2424,7 +2424,7 @@ def testUpsertWithQuotedNames(self): r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'No.')]) - def testUpsertWithGeneratedColumns(self): + def test_upsert_with_generated_columns(self): upsert = self.db.upsert get = self.db.get server_version = self.db.server_version @@ -2440,7 +2440,7 @@ def testUpsertWithGeneratedColumns(self): table_def += ', j int generated always as (i + 7) stored' else: table_def += ', j int not null default 42' - self.createTable(table, table_def) + self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 r = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) @@ -2455,7 +2455,7 @@ def testUpsertWithGeneratedColumns(self): r = get(table, d) self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) - def testClear(self): + def test_clear(self): clear = self.db.clear f = False if pg.get_bool() else 'f' r = clear('test') @@ -2463,7 +2463,7 @@ def testClear(self): i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='') self.assertEqual(r, result) table = 'clear_test_table' - self.createTable( + self.create_table( table, 'n integer, f float, b boolean, d date, t text') r = clear(table) result = dict(n=0, f=0, b=f, d='', t='') @@ -2476,10 +2476,10 @@ def testClear(self): result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=1) self.assertEqual(r, result) - def testClearWithQuotedNames(self): + def test_clear_with_quoted_names(self): clear = self.db.clear table = 'test table for clear()' - self.createTable( + self.create_table( table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') r = clear(table) @@ -2488,13 +2488,13 @@ def testClearWithQuotedNames(self): self.assertEqual(r['much space'], 0) self.assertEqual(r['Questions?'], '') - def testDelete(self): + def test_delete(self): delete = self.db.delete query = self.db.query self.assertRaises(pg.ProgrammingError, delete, 'test', dict(i2=2, i4=4, i8=8)) table = 'delete_test_table' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) r = self.db.get(table, 1) @@ -2521,13 +2521,13 @@ def testDelete(self): s = delete(table, r) self.assertEqual(s, 0) - def testDeleteWithOids(self): + def test_delete_with_oids(self): if not self.oids: self.skipTest("database does not support tables with oids") delete = self.db.delete get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=range(1, 7)) + self.create_table('test_table', 'n int', oids=True, values=range(1, 7)) r = dict(n=3) self.assertRaises(pg.ProgrammingError, delete, 'test_table', r) s = get('test_table', 1, 'n') @@ -2617,10 +2617,10 @@ def testDeleteWithOids(self): self.assertEqual(r, 1) self.assertEqual(query(q).getresult()[0], (None, 0)) - def testDeleteWithCompositeKey(self): + def test_delete_with_composite_key(self): query = self.db.query table = 'delete_test_table_1' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertRaises(KeyError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) @@ -2630,7 +2630,7 @@ def testDeleteWithCompositeKey(self): r = query(f'select t from "{table}" where n=3').getresult()[0][0] self.assertEqual(r, 'c') table = 'delete_test_table_2' - self.createTable( + self.create_table( table, 'n integer, m integer, t text, primary key (n, m)', values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) for n in range(3) for m in range(2)]) @@ -2648,11 +2648,11 @@ def testDeleteWithCompositeKey(self): f' order by m').getresult()] self.assertEqual(r, ['f']) - def testDeleteWithQuotedNames(self): + def test_delete_with_quoted_names(self): delete = self.db.delete query = self.db.query table = 'test table for delete()' - self.createTable( + self.create_table( table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(19, 5005, 'Yes!')]) @@ -2667,12 +2667,12 @@ def testDeleteWithQuotedNames(self): r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 0) - def testDeleteReferenced(self): + def test_delete_referenced(self): delete = self.db.delete query = self.db.query - self.createTable( + self.create_table( 'test_parent', 'n smallint primary key', values=range(3)) - self.createTable( + self.create_table( 'test_child', 'n smallint primary key references test_parent', values=range(3)) q = ("select (select count(*) from test_parent)," @@ -2705,9 +2705,10 @@ def testDeleteReferenced(self): q = "select n from test_parent natural join test_child limit 2" self.assertEqual(query(q).getresult(), [(1,)]) - def testTempCrud(self): + def test_temp_crud(self): table = 'test_temp_table' - self.createTable(table, "n int primary key, t varchar", temporary=True) + self.create_table(table, "n int primary key, t varchar", + temporary=True) self.db.insert(table, dict(n=1, t='one')) self.db.insert(table, dict(n=2, t='too')) self.db.insert(table, dict(n=3, t='three')) @@ -2720,14 +2721,14 @@ def testTempCrud(self): r = self.db.query(f'select n, t from {table} order by 1').getresult() self.assertEqual(r, [(1, 'one'), (3, 'three')]) - def testTruncate(self): + def test_truncate(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, None) self.assertRaises(TypeError, truncate, 42) self.assertRaises(TypeError, truncate, dict(test_table=None)) query = self.db.query - self.createTable('test_table', 'n smallint', - temporary=False, values=[1] * 3) + self.create_table('test_table', 'n smallint', + temporary=False, values=[1] * 3) q = "select count(*) from test_table" r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -2741,7 +2742,7 @@ def testTruncate(self): truncate('public.test_table') r = query(q).getresult()[0][0] self.assertEqual(r, 0) - self.createTable('test_table_2', 'n smallint', temporary=True) + self.create_table('test_table_2', 'n smallint', temporary=True) for t in (list, tuple, set): for i in range(3): query("insert into test_table values (1)") @@ -2754,11 +2755,11 @@ def testTruncate(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - def testTruncateRestart(self): + def test_truncate_restart(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', restart='invalid') query = self.db.query - self.createTable('test_table', 'n serial, t text') + self.create_table('test_table', 'n serial, t text') for n in range(3): query("insert into test_table (t) values ('test')") q = "select count(n), min(n), max(n) from test_table" @@ -2779,13 +2780,13 @@ def testTruncateRestart(self): r = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) - def testTruncateCascade(self): + def test_truncate_cascade(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid') query = self.db.query - self.createTable('test_parent', 'n smallint primary key', + self.create_table('test_parent', 'n smallint primary key', values=range(3)) - self.createTable('test_child', + self.create_table('test_child', 'n smallint primary key references test_parent (n)', values=range(3)) q = ("select (select count(*) from test_parent)," @@ -2817,12 +2818,12 @@ def testTruncateCascade(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - def testTruncateOnly(self): + def test_truncate_only(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', only='invalid') query = self.db.query - self.createTable('test_parent', 'n smallint') - self.createTable('test_child', 'm smallint) inherits (test_parent') + self.create_table('test_parent', 'n smallint') + self.create_table('test_child', 'm smallint) inherits (test_parent') for n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") @@ -2854,8 +2855,9 @@ def testTruncateOnly(self): self.assertEqual(r, (0, 0)) self.assertRaises(ValueError, truncate, 'test_parent*', only=True) truncate('test_parent*', only=False) - self.createTable('test_parent_2', 'n smallint') - self.createTable('test_child_2', 'm smallint) inherits (test_parent_2') + self.create_table('test_parent_2', 'n smallint') + self.create_table('test_child_2', + 'm smallint) inherits (test_parent_2') for t in '', '_2': for n in range(3): query(f"insert into test_parent{t} (n) values (1)") @@ -2877,11 +2879,11 @@ def testTruncateOnly(self): ['test_parent*', 'test_child'], only=[True, False]) truncate(['test_parent*', 'test_child'], only=[False, True]) - def testTruncateQuoted(self): + def test_truncate_quoted(self): truncate = self.db.truncate query = self.db.query table = "test table for truncate()" - self.createTable(table, 'n smallint', temporary=False, values=[1] * 3) + self.create_table(table, 'n smallint', temporary=False, values=[1] * 3) q = f'select count(*) from "{table}"' r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -2897,7 +2899,7 @@ def testTruncateQuoted(self): self.assertEqual(r, 0) # noinspection PyUnresolvedReferences - def testGetAsList(self): + def test_get_as_list(self): get_as_list = self.db.get_as_list self.assertRaises(TypeError, get_as_list) self.assertRaises(TypeError, get_as_list, None) @@ -2908,7 +2910,7 @@ def testGetAsList(self): named = hasattr(r, 'colname') names = [(1, 'Homer'), (2, 'Marge'), (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')] - self.createTable( + self.create_table( table, 'id smallint primary key, name varchar', values=names) r = get_as_list(table) self.assertIsInstance(r, list) @@ -3010,7 +3012,7 @@ def testGetAsList(self): self.assertEqual(t, ('bart',)) # noinspection PyUnresolvedReferences - def testGetAsDict(self): + def test_get_as_dict(self): get_as_dict = self.db.get_as_dict self.assertRaises(TypeError, get_as_dict) self.assertRaises(TypeError, get_as_dict, None) @@ -3023,7 +3025,7 @@ def testGetAsDict(self): named = hasattr(r, 'colname') colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'), (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')] - self.createTable( + self.create_table( table, 'id smallint primary key, rgb char(7), name varchar', values=colors) # keyname must be string, list or tuple @@ -3178,9 +3180,9 @@ def testGetAsDict(self): r = get_as_dict(table, keyname='id') self.assertEqual(r, expected) - def testTransaction(self): + def test_transaction(self): query = self.db.query - self.createTable('test_table', 'n integer', temporary=False) + self.create_table('test_table', 'n integer', temporary=False) self.db.begin() query("insert into test_table values (1)") query("insert into test_table values (2)") @@ -3217,14 +3219,14 @@ def testTransaction(self): query, "insert into test_table values (0)") self.db.abort() - def testTransactionAliases(self): + def test_transaction_aliases(self): self.assertEqual(self.db.begin, self.db.start) self.assertEqual(self.db.commit, self.db.end) self.assertEqual(self.db.rollback, self.db.abort) - def testContextManager(self): + def test_context_manager(self): query = self.db.query - self.createTable('test_table', 'n integer check(n>0)') + self.create_table('test_table', 'n integer check(n>0)') with self.db: query("insert into test_table values (1)") query("insert into test_table values (2)") @@ -3249,9 +3251,9 @@ def testContextManager(self): "select * from test_table order by 1").getresult()] self.assertEqual(r, [1, 2, 5, 7]) - def testBytea(self): + def test_bytea(self): query = self.db.query - self.createTable('bytea_test', 'n smallint primary key, data bytea') + self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" r = self.db.escape_bytea(s) query('insert into bytea_test values(3, $1)', (r,)) @@ -3267,10 +3269,10 @@ def testBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, s) - def testInsertUpdateGetBytea(self): + def test_insert_update_get_bytea(self): query = self.db.query unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None - self.createTable('bytea_test', 'n smallint primary key, data bytea') + self.create_table('bytea_test', 'n smallint primary key, data bytea') # insert null value r = self.db.insert('bytea_test', n=0, data=None) self.assertIsInstance(r, dict) @@ -3341,8 +3343,8 @@ def testInsertUpdateGetBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, s) - def testUpsertBytea(self): - self.createTable('bytea_test', 'n smallint primary key, data bytea') + def test_upsert_bytea(self): + self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" r = dict(n=7, data=s) r = self.db.upsert('bytea_test', r) @@ -3363,8 +3365,8 @@ def testUpsertBytea(self): self.assertIn('data', r) self.assertIsNone(r['data']) - def testInsertGetJson(self): - self.createTable('json_test', 'n smallint primary key, data json') + def test_insert_get_json(self): + self.create_table('json_test', 'n smallint primary key, data json') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('json_test', n=0, data=None) @@ -3427,8 +3429,8 @@ def testInsertGetJson(self): self.assertIsInstance(r[0][0], str if jsondecode is None else dict) self.assertEqual(r[0][0], r[1][0]) - def testInsertGetJsonb(self): - self.createTable('jsonb_test', + def test_insert_get_jsonb(self): + self.create_table('jsonb_test', 'n smallint primary key, data jsonb') jsondecode = pg.get_jsondecode() # insert null value @@ -3485,9 +3487,9 @@ def testInsertGetJsonb(self): self.assertIsInstance(r['tags'], list) self.assertIsInstance(r['stock'], dict) - def testArray(self): + def test_array(self): returns_arrays = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'id smallint, i2 smallint[], i4 integer[], i8 bigint[],' ' d numeric[], f4 real[], f8 double precision[], m money[],' @@ -3545,10 +3547,10 @@ def testArray(self): else: self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}') - def testArrayLiteral(self): + def test_array_literal(self): insert = self.db.insert returns_arrays = pg.get_array() - self.createTable('arraytest', 'i int[], t text[]') + self.create_table('arraytest', 'i int[], t text[]') r = dict(i=[1, 2, 3], t=['a', 'b', 'c']) insert('arraytest', r) if returns_arrays: @@ -3565,8 +3567,8 @@ def testArrayLiteral(self): else: self.assertEqual(r['i'], '{1,2,3}') self.assertEqual(r['t'], '{a,b,c}') - L = pg.Literal - r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']")) + Lit = pg.Literal # noqa: N806 + r = dict(i=Lit("ARRAY[1, 2, 3]"), t=Lit("ARRAY['a', 'b', 'c']")) self.db.insert('arraytest', r) if returns_arrays: self.assertEqual(r['i'], [1, 2, 3]) @@ -3577,9 +3579,9 @@ def testArrayLiteral(self): r = dict(i="1, 2, 3", t="'a', 'b', 'c'") self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r) - def testArrayOfIds(self): + def test_array_of_ids(self): array_on = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'i serial primary key, c cid[], o oid[], x xid[]') r = self.db.get_attnames('arraytest') if self.regtypes: @@ -3601,9 +3603,9 @@ def testArrayOfIds(self): else: self.assertEqual(r['o'], '{21,22,23}') - def testArrayOfText(self): + def test_array_of_text(self): array_on = pg.get_array() - self.createTable('arraytest', 'id serial primary key, data text[]') + self.create_table('arraytest', 'id serial primary key, data text[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'text[]') data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"', @@ -3625,10 +3627,10 @@ def testArrayOfText(self): self.assertIsNone(r['data'][2]) # noinspection PyUnresolvedReferences - def testArrayOfBytea(self): + def test_array_of_bytea(self): array_on = pg.get_array() bytea_escaped = pg.get_bytea_escaped() - self.createTable('arraytest', 'id serial primary key, data bytea[]') + self.create_table('arraytest', 'id serial primary key, data bytea[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'bytea[]') data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"', @@ -3654,8 +3656,8 @@ def testArrayOfBytea(self): else: self.assertNotEqual(r['data'], data) - def testArrayOfJson(self): - self.createTable('arraytest', 'id serial primary key, data json[]') + def test_array_of_json(self): + self.create_table('arraytest', 'id serial primary key, data json[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3696,8 +3698,8 @@ def testArrayOfJson(self): else: self.assertEqual(r, '{NULL,NULL}') - def testArrayOfJsonb(self): - self.createTable('arraytest', 'id serial primary key, data jsonb[]') + def test_array_of_jsonb(self): + self.create_table('arraytest', 'id serial primary key, data jsonb[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3739,9 +3741,9 @@ def testArrayOfJsonb(self): self.assertEqual(r, '{NULL,NULL}') # noinspection PyUnresolvedReferences - def testDeepArray(self): + def test_deep_array(self): array_on = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'id serial primary key, data text[][][]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'text[]') @@ -3760,13 +3762,13 @@ def testDeepArray(self): self.assertTrue(r['data'].startswith('{{{"Hello,')) # noinspection PyUnresolvedReferences - def testInsertUpdateGetRecord(self): + def test_insert_update_get_record(self): query = self.db.query query('create type test_person_type as' ' (name varchar, age smallint, married bool,' ' weight real, salary money)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', + self.create_table('test_person', 'id serial primary key, person test_person_type', oids=False, temporary=False) attnames = self.db.get_attnames('test_person') @@ -3859,12 +3861,12 @@ def testInsertUpdateGetRecord(self): self.assertIsNone(r['person']) # noinspection PyUnresolvedReferences - def testRecordInsertBytea(self): + def test_record_insert_bytea(self): query = self.db.query query('create type test_person_type as' ' (name text, picture bytea)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] self.assertEqual(person_typ.attnames, @@ -3879,11 +3881,11 @@ def testRecordInsertBytea(self): self.assertEqual(p.picture, person[1]) self.assertIsInstance(p.picture, bytes) - def testRecordInsertJson(self): + def test_record_insert_json(self): query = self.db.query query('create type test_person_type as (name text, data json)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] self.assertEqual(person_typ.attnames, @@ -3902,12 +3904,12 @@ def testRecordInsertJson(self): self.assertIsInstance(p.data, dict) # noinspection PyUnresolvedReferences - def testRecordLiteral(self): + def test_record_literal(self): query = self.db.query query('create type test_person_type as' ' (name varchar, age smallint)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] if self.regtypes: @@ -3929,7 +3931,7 @@ def testRecordLiteral(self): self.assertEqual(p.age, 61) self.assertIsInstance(p.age, int) - def testDate(self): + def test_date(self): query = self.db.query for datestyle in ( 'ISO', 'Postgres, MDY', 'Postgres, DMY', @@ -3953,7 +3955,7 @@ def testDate(self): self.assertEqual(r[0], date.max) self.assertEqual(r[1], date.min) - def testTime(self): + def test_time(self): query = self.db.query d = time(15, 9, 26) q = "select $1::time" @@ -3966,7 +3968,7 @@ def testTime(self): self.assertIsInstance(r, time) self.assertEqual(r, d) - def testTimetz(self): + def test_timetz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): @@ -3984,7 +3986,7 @@ def testTimetz(self): self.assertIsInstance(r, time) self.assertEqual(r, d) - def testTimestamp(self): + def test_timestamp(self): query = self.db.query for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', 'SQL, MDY', 'SQL, DMY', 'German'): @@ -4018,7 +4020,7 @@ def testTimestamp(self): self.assertEqual(r[0], datetime.max) self.assertEqual(r[1], datetime.min) - def testTimestamptz(self): + def test_timestamptz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): @@ -4057,7 +4059,7 @@ def testTimestamptz(self): self.assertEqual(r[0], datetime.max) self.assertEqual(r[1], datetime.min) - def testInterval(self): + def test_interval(self): query = self.db.query for intervalstyle in ( 'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'): @@ -4077,7 +4079,7 @@ def testInterval(self): self.assertIsInstance(r, timedelta) self.assertEqual(r, d) - def testDateAndTimeArrays(self): + def test_date_and_time_arrays(self): dt = (date(2016, 3, 14), time(15, 9, 26)) q = "select ARRAY[$1::date], ARRAY[$2::time]" r = self.db.query(q, dt).getresult()[0] @@ -4086,7 +4088,7 @@ def testDateAndTimeArrays(self): self.assertIsInstance(r[1], list) self.assertEqual(r[1][0], dt[1]) - def testHstore(self): + def test_hstore(self): try: self.db.query("select 'k=>v'::hstore") except pg.DatabaseError: @@ -4103,14 +4105,14 @@ def testHstore(self): self.assertIsInstance(r, dict) self.assertEqual(r, d) - def testUuid(self): + def test_uuid(self): d = UUID('{12345678-1234-5678-1234-567812345678}') q = 'select $1::uuid' r = self.db.query(q, (d,)).getresult()[0][0] self.assertIsInstance(r, UUID) self.assertEqual(r, d) - def testDbTypesInfo(self): + def test_db_types_info(self): dbtypes = self.db.dbtypes self.assertIsInstance(dbtypes, dict) self.assertNotIn('numeric', dbtypes) @@ -4158,7 +4160,7 @@ def testDbTypesInfo(self): self.assertEqual(typlen.category, 'N') # numeric # noinspection PyUnresolvedReferences - def testDbTypesTypecast(self): + def test_db_types_typecast(self): dbtypes = self.db.dbtypes self.assertIsInstance(dbtypes, dict) self.assertNotIn('int4', dbtypes) @@ -4185,7 +4187,7 @@ def testDbTypesTypecast(self): dbtypes.reset_typecast('circle') self.assertIsNone(dbtypes.get_typecast('circle')) - def testGetSetTypeCast(self): + def test_get_set_type_cast(self): get_typecast = pg.get_typecast set_typecast = pg.set_typecast dbtypes = self.db.dbtypes @@ -4209,7 +4211,7 @@ def testGetSetTypeCast(self): set_typecast('circle', cast_circle) self.assertIs(get_typecast('circle'), cast_circle) - def testNotificationHandler(self): + def test_notification_handler(self): # the notification handler itself is tested separately f = self.db.notification_handler callback = lambda arg_dict: None # noqa: E731 @@ -4286,11 +4288,11 @@ def testNotificationHandler(self): self.db.reopen() self.assertIsNone(handler.db) - def testInserttableFromQuery(self): + def test_inserttable_from_query(self): # use inserttable() to copy from one table to another query = self.db.query - self.createTable('test_table_from', 'n integer, t timestamp') - self.createTable('test_table_to', 'n integer, t timestamp') + self.create_table('test_table_from', 'n integer, t timestamp') + self.create_table('test_table_to', 'n integer, t timestamp') for i in range(1, 4): query("insert into test_table_from values ($1, now())", i) n = self.db.inserttable( @@ -4355,7 +4357,7 @@ def tearDown(self): except pg.InternalError: pass - def testGuessSimpleType(self): + def test_guess_simple_type(self): f = self.adapter.guess_simple_type self.assertEqual(f(pg.Bytea(b'test')), 'bytea') self.assertEqual(f('string'), 'text') @@ -4376,7 +4378,7 @@ def testGuessSimpleType(self): self.assertEqual(list(r.attnames.values()), [ 'text', 'bool', 'int', 'float', 'int[]', 'bool[]']) - def testAdaptQueryTypedList(self): + def test_adapt_query_typed_list(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) self.assertRaises( @@ -4416,7 +4418,7 @@ def testAdaptQueryTypedList(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryTypedListWithTypesAsString(self): + def test_adapt_query_typed_list_with_types_as_string(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), 'int2') self.assertRaises( @@ -4427,7 +4429,7 @@ def testAdaptQueryTypedListWithTypesAsString(self): self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) - def testAdaptQueryTypedListWithTypesAsClasses(self): + def test_adapt_query_typed_list_with_types_as_classes(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), (int,)) self.assertRaises( @@ -4438,7 +4440,7 @@ def testAdaptQueryTypedListWithTypesAsClasses(self): self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) - def testAdaptQueryTypedListWithJson(self): + def test_adapt_query_typed_list_with_json(self): format_query = self.adapter.format_query value = {'test': [1, "it's fine", 3]} sql, params = format_query("select %s", (value,), 'json') @@ -4453,7 +4455,7 @@ def testAdaptQueryTypedListWithJson(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) - def testAdaptQueryTypedWithHstore(self): + def test_adapt_query_typed_with_hstore(self): format_query = self.adapter.format_query value = {'one': "it's fine", 'two': 2} sql, params = format_query("select %s", (value,), 'hstore') @@ -4468,7 +4470,7 @@ def testAdaptQueryTypedWithHstore(self): self.assertEqual(sql, "select $1") self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) - def testAdaptQueryTypedWithUuid(self): + def test_adapt_query_typed_with_uuid(self): format_query = self.adapter.format_query value = '12345678-1234-5678-1234-567812345678' sql, params = format_query("select %s", (value,), 'uuid') @@ -4483,7 +4485,7 @@ def testAdaptQueryTypedWithUuid(self): self.assertEqual(sql, "select $1") self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) - def testAdaptQueryTypedDict(self): + def test_adapt_query_typed_dict(self): format_query = self.adapter.format_query self.assertRaises( TypeError, format_query, @@ -4527,7 +4529,7 @@ def testAdaptQueryTypedDict(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryUntypedList(self): + def test_adapt_query_untyped_list(self): format_query = self.adapter.format_query values = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values) @@ -4552,21 +4554,21 @@ def testAdaptQueryUntypedList(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryUntypedListWithJson(self): + def test_adapt_query_untyped_list_with_json(self): format_query = self.adapter.format_query value = pg.Json({'test': [1, "it's fine", 3]}) sql, params = format_query("select %s", (value,)) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) - def testAdaptQueryUntypedWithHstore(self): + def test_adapt_query_untyped_with_hstore(self): format_query = self.adapter.format_query value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,)) self.assertEqual(sql, "select $1") self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) - def testAdaptQueryUntypedDict(self): + def test_adapt_query_untyped_dict(self): format_query = self.adapter.format_query values = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( @@ -4593,7 +4595,7 @@ def testAdaptQueryUntypedDict(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) - def testAdaptQueryInlineList(self): + def test_adapt_query_inline_list(self): format_query = self.adapter.format_query values = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values, inline=True) @@ -4621,7 +4623,7 @@ def testAdaptQueryInlineList(self): sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) - def testAdaptQueryInlineListWithJson(self): + def test_adapt_query_inline_list_with_json(self): format_query = self.adapter.format_query value = pg.Json({'test': [1, "it's fine", 3]}) sql, params = format_query("select %s", (value,), inline=True) @@ -4629,7 +4631,7 @@ def testAdaptQueryInlineListWithJson(self): sql, "select '{\"test\": [1, \"it''s fine\", 3]}'::json") self.assertEqual(params, []) - def testAdaptQueryInlineListWithHstore(self): + def test_adapt_query_inline_list_with_hstore(self): format_query = self.adapter.format_query value = pg.Hstore({'one': "it's fine", 'two': 2}) sql, params = format_query("select %s", (value,), inline=True) @@ -4637,7 +4639,7 @@ def testAdaptQueryInlineListWithHstore(self): sql, "select 'one=>\"it''s fine\",two=>2'::hstore") self.assertEqual(params, []) - def testAdaptQueryInlineDict(self): + def test_adapt_query_inline_dict(self): format_query = self.adapter.format_query values = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( @@ -4668,7 +4670,7 @@ def testAdaptQueryInlineDict(self): sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) - def testAdaptQueryWithPgRepr(self): + def test_adapt_query_with_pg_repr(self): format_query = self.adapter.format_query self.assertRaises(TypeError, format_query, '%s', object(), inline=True) @@ -4739,7 +4741,7 @@ def tearDown(self): self.doCleanups() self.db.close() - def testGetTables(self): + def test_get_tables(self): tables = self.db.get_tables() for num_schema in range(5): if num_schema: @@ -4750,7 +4752,7 @@ def testGetTables(self): schema + ".t" + str(num_schema)): self.assertIn(t, tables) - def testGetAttnames(self): + def test_get_attnames(self): get_attnames = self.db.get_attnames query = self.db.query result = {'d': 'int', 'n': 'int'} @@ -4774,10 +4776,10 @@ def testGetAttnames(self): r = get_attnames("t3m") self.assertEqual(r, result_m) - def testGet(self): + def test_get(self): get = self.db.get query = self.db.query - PrgError = pg.ProgrammingError + PrgError = pg.ProgrammingError # noqa: N806 self.assertEqual(get("t", 1, 'n')['d'], 0) self.assertEqual(get("t0", 1, 'n')['d'], 0) self.assertEqual(get("public.t", 1, 'n')['d'], 0) @@ -4798,7 +4800,7 @@ def testGet(self): self.assertEqual(get("t", 1, 'n')['d'], 1) self.assertEqual(get("s4.t4", 1, 'n')['d'], 4) - def testMunging(self): + def test_munging(self): get = self.db.get query = self.db.query r = get("t", 1, 'n') @@ -4819,7 +4821,7 @@ def testMunging(self): else: self.assertNotIn('oid(t)', r) - def testQueryInformationSchema(self): + def test_query_information_schema(self): q = "column_name" if self.db.server_version < 110000: q += "::text" # old version does not have sql_identifier array @@ -4853,30 +4855,30 @@ def send_queries(self): self.db.query("select 1") self.db.query("select 2") - def testDebugDefault(self): + def test_debug_default(self): if debug: self.assertEqual(self.db.debug, debug) else: self.assertIsNone(self.db.debug) - def testDebugIsFalse(self): + def test_debug_is_false(self): self.db.debug = False self.send_queries() self.assertEqual(self.get_output(), "") - def testDebugIsTrue(self): + def test_debug_is_true(self): self.db.debug = True self.send_queries() self.assertEqual(self.get_output(), "select 1\nselect 2\n") - def testDebugIsString(self): + def test_debug_is_string(self): self.db.debug = "Test with string: %s." self.send_queries() self.assertEqual( self.get_output(), "Test with string: select 1.\nTest with string: select 2.\n") - def testDebugIsFileLike(self): + def test_debug_is_file_like(self): with tempfile.TemporaryFile('w+') as debug_file: self.db.debug = debug_file self.send_queries() @@ -4885,7 +4887,7 @@ def testDebugIsFileLike(self): self.assertEqual(output, "select 1\nselect 2\n") self.assertEqual(self.get_output(), "") - def testDebugIsCallable(self): + def test_debug_is_callable(self): output = [] self.db.debug = output.append self.db.query("select 1") @@ -4893,7 +4895,7 @@ def testDebugIsCallable(self): self.assertEqual(output, ["select 1", "select 2"]) self.assertEqual(self.get_output(), "") - def testDebugMultipleArgs(self): + def test_debug_multiple_args(self): output = [] self.db.debug = output.append args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]] @@ -4905,7 +4907,7 @@ def testDebugMultipleArgs(self): class TestMemoryLeaks(unittest.TestCase): """Test that the DB class does not leak memory.""" - def getLeaks(self, fut): + def get_leaks(self, fut): ids = set() objs = [] add_ids = ids.update @@ -4918,20 +4920,20 @@ def getLeaks(self, fut): objs[:] = [obj for obj in objs if id(obj) not in ids] self.assertEqual(len(objs), 0) - def testLeaksWithClose(self): + def test_leaks_with_close(self): def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() db.close() - self.getLeaks(fut) + self.get_leaks(fut) - def testLeaksWithoutClose(self): + def test_leaks_without_close(self): def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() - self.getLeaks(fut) + self.get_leaks(fut) if __name__ == '__main__': diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 914450f5..5a49e9d2 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -20,59 +20,59 @@ class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" - def testhasPgError(self): + def testhas_pg_error(self): self.assertTrue(issubclass(pg.Error, Exception)) - def testhasPgWarning(self): + def testhas_pg_warning(self): self.assertTrue(issubclass(pg.Warning, Exception)) - def testhasPgInterfaceError(self): + def testhas_pg_interface_error(self): self.assertTrue(issubclass(pg.InterfaceError, pg.Error)) - def testhasPgDatabaseError(self): + def testhas_pg_database_error(self): self.assertTrue(issubclass(pg.DatabaseError, pg.Error)) - def testhasPgInternalError(self): + def testhas_pg_internal_error(self): self.assertTrue(issubclass(pg.InternalError, pg.DatabaseError)) - def testhasPgOperationalError(self): + def testhas_pg_operational_error(self): self.assertTrue(issubclass(pg.OperationalError, pg.DatabaseError)) - def testhasPgProgrammingError(self): + def testhas_pg_programming_error(self): self.assertTrue(issubclass(pg.ProgrammingError, pg.DatabaseError)) - def testhasPgIntegrityError(self): + def testhas_pg_integrity_error(self): self.assertTrue(issubclass(pg.IntegrityError, pg.DatabaseError)) - def testhasPgDataError(self): + def testhas_pg_data_error(self): self.assertTrue(issubclass(pg.DataError, pg.DatabaseError)) - def testhasPgNotSupportedError(self): + def testhas_pg_not_supported_error(self): self.assertTrue(issubclass(pg.NotSupportedError, pg.DatabaseError)) - def testhasPgInvalidResultError(self): + def testhas_pg_invalid_result_error(self): self.assertTrue(issubclass(pg.InvalidResultError, pg.DataError)) - def testhasPgNoResultError(self): + def testhas_pg_no_result_error(self): self.assertTrue(issubclass(pg.NoResultError, pg.InvalidResultError)) - def testhasPgMultipleResultsError(self): + def testhas_pg_multiple_results_error(self): self.assertTrue( issubclass(pg.MultipleResultsError, pg.InvalidResultError)) - def testhasConnect(self): + def testhas_connect(self): self.assertTrue(callable(pg.connect)) - def testhasEscapeString(self): + def testhas_escape_string(self): self.assertTrue(callable(pg.escape_string)) - def testhasEscapeBytea(self): + def testhas_escape_bytea(self): self.assertTrue(callable(pg.escape_bytea)) - def testhasUnescapeBytea(self): + def testhas_unescape_bytea(self): self.assertTrue(callable(pg.unescape_bytea)) - def testDefHost(self): + def test_def_host(self): d0 = pg.get_defhost() d1 = 'pgtesthost' pg.set_defhost(d1) @@ -80,7 +80,7 @@ def testDefHost(self): pg.set_defhost(d0) self.assertEqual(pg.get_defhost(), d0) - def testDefPort(self): + def test_def_port(self): d0 = pg.get_defport() d1 = 1234 pg.set_defport(d1) @@ -92,7 +92,7 @@ def testDefPort(self): d0 = None self.assertEqual(pg.get_defport(), d0) - def testDefOpt(self): + def test_def_opt(self): d0 = pg.get_defopt() d1 = '-h pgtesthost -p 1234' pg.set_defopt(d1) @@ -100,7 +100,7 @@ def testDefOpt(self): pg.set_defopt(d0) self.assertEqual(pg.get_defopt(), d0) - def testDefBase(self): + def test_def_base(self): d0 = pg.get_defbase() d1 = 'pgtestdb' pg.set_defbase(d1) @@ -108,7 +108,7 @@ def testDefBase(self): pg.set_defbase(d0) self.assertEqual(pg.get_defbase(), d0) - def testPqlibVersion(self): + def test_pqlib_version(self): # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() self.assertIsInstance(v, int) @@ -216,7 +216,7 @@ class TestParseArray(unittest.TestCase): ('[3:5]={{1,2,3},{4,5,6}}', int, ValueError), ('[1:1][-2:-1][3:5]={{1,2,3},{4,5,6}}', int, ValueError)] - def testParserParams(self): + def test_parser_params(self): f = pg.cast_array self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -235,13 +235,13 @@ def testParserParams(self): self.assertEqual(f('{}', str), []) self.assertEqual(f('{}', str, b';'), []) - def testParserSimple(self): + def test_parser_simple(self): r = pg.cast_array('{a,b,c}') self.assertIsInstance(r, list) self.assertEqual(len(r), 3) self.assertEqual(r, ['a', 'b', 'c']) - def testParserNested(self): + def test_parser_nested(self): f = pg.cast_array r = f('{{a,b,c}}') self.assertIsInstance(r, list) @@ -273,7 +273,7 @@ def testParserNested(self): r = r[0] self.assertEqual(r, 'abc') - def testParserTooDeeplyNested(self): + def test_parser_too_deeply_nested(self): f = pg.cast_array for n in 3, 5, 9, 12, 16, 32, 64, 256: r = '{' * n + 'a,b,c' + '}' * n @@ -288,7 +288,7 @@ def testParserTooDeeplyNested(self): self.assertEqual(len(r), 3) self.assertEqual(r, ['a', 'b', 'c']) - def testParserCast(self): + def test_parser_cast(self): f = pg.cast_array self.assertEqual(f('{1}'), ['1']) self.assertEqual(f('{1}', None), ['1']) @@ -303,7 +303,7 @@ def cast(s): return f'{s} is ok' self.assertEqual(f('{a}', cast), ['a is ok']) - def testParserDelim(self): + def test_parser_delim(self): f = pg.cast_array self.assertEqual(f('{1,2}'), ['1', '2']) self.assertEqual(f('{1,2}', delim=b','), ['1', '2']) @@ -311,7 +311,7 @@ def testParserDelim(self): self.assertEqual(f('{1;2}', delim=b';'), ['1', '2']) self.assertEqual(f('{1,2}', delim=b';'), ['1,2']) - def testParserWithData(self): + def test_parser_with_data(self): f = pg.cast_array for string, cast, expected in self.test_strings: if expected is ValueError: @@ -319,7 +319,7 @@ def testParserWithData(self): else: self.assertEqual(f(string, cast), expected) - def testParserWithoutCast(self): + def test_parser_without_cast(self): f = pg.cast_array for string, cast, expected in self.test_strings: @@ -330,7 +330,7 @@ def testParserWithoutCast(self): else: self.assertEqual(f(string), expected) - def testParserWithDifferentDelimiter(self): + def test_parser_with_different_delimiter(self): f = pg.cast_array def replace_comma(value): @@ -491,7 +491,7 @@ class TestParseRecord(unittest.TestCase): ('(fuzzy dice,"42","1.9375")', (str, int, float), ('fuzzy dice', 42, 1.9375))] - def testParserParams(self): + def test_parser_params(self): f = pg.cast_record self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -510,20 +510,20 @@ def testParserParams(self): self.assertEqual(f('()', str), (None,)) self.assertEqual(f('()', str, b';'), (None,)) - def testParserSimple(self): + def test_parser_simple(self): r = pg.cast_record('(a,b,c)') self.assertIsInstance(r, tuple) self.assertEqual(len(r), 3) self.assertEqual(r, ('a', 'b', 'c')) - def testParserNested(self): + def test_parser_nested(self): f = pg.cast_record self.assertRaises(ValueError, f, '((a,b,c))') self.assertRaises(ValueError, f, '((a,b),(c,d))') self.assertRaises(ValueError, f, '((a),(b),(c))') self.assertRaises(ValueError, f, '(((((((abc)))))))') - def testParserManyElements(self): + def test_parser_many_elements(self): f = pg.cast_record for n in 3, 5, 9, 12, 16, 32, 64, 256: r = ','.join(map(str, range(n))) @@ -531,7 +531,7 @@ def testParserManyElements(self): r = f(r, int) self.assertEqual(r, tuple(range(n))) - def testParserCastUniform(self): + def test_parser_cast_uniform(self): f = pg.cast_record self.assertEqual(f('(1)'), ('1',)) self.assertEqual(f('(1)', None), ('1',)) @@ -546,7 +546,7 @@ def cast(s): return f'{s} is ok' self.assertEqual(f('(a)', cast), ('a is ok',)) - def testParserCastNonUniform(self): + def test_parser_cast_non_uniform(self): f = pg.cast_record self.assertEqual(f('(1)', []), ('1',)) self.assertEqual(f('(1)', [None]), ('1',)) @@ -583,7 +583,7 @@ def cast2(s): f('(1,2,3,4,5,6)', [int, float, str, None, cast1, cast2]), (1, 2.0, '3', '4', '5 is ok', 'and 6 is ok, too')) - def testParserDelim(self): + def test_parser_delim(self): f = pg.cast_record self.assertEqual(f('(1,2)'), ('1', '2')) self.assertEqual(f('(1,2)', delim=b','), ('1', '2')) @@ -591,7 +591,7 @@ def testParserDelim(self): self.assertEqual(f('(1;2)', delim=b';'), ('1', '2')) self.assertEqual(f('(1,2)', delim=b';'), ('1,2',)) - def testParserWithData(self): + def test_parser_with_data(self): f = pg.cast_record for string, cast, expected in self.test_strings: if expected is ValueError: @@ -599,7 +599,7 @@ def testParserWithData(self): else: self.assertEqual(f(string, cast), expected) - def testParserWithoutCast(self): + def test_parser_without_cast(self): f = pg.cast_record for string, cast, expected in self.test_strings: @@ -610,7 +610,7 @@ def testParserWithoutCast(self): else: self.assertEqual(f(string), expected) - def testParserWithDifferentDelimiter(self): + def test_parser_with_different_delimiter(self): f = pg.cast_record def replace_comma(value): @@ -665,7 +665,7 @@ class TestParseHStore(unittest.TestCase): (r'k\=\>v=>"k=>v"', {'k=>v': 'k=>v'}), ('a\\,b=>a,b=>a', {'a,b': 'a', 'b': 'a'})] - def testParser(self): + def test_parser(self): f = pg.cast_hstore self.assertRaises(TypeError, f) @@ -842,7 +842,7 @@ class TestCastInterval(unittest.TestCase): '@ 10 mons 3 days -3 hours -55 mins -5.999993 secs ago', 'P-10M-3DT3H55M5.999993S'))] - def testCastInterval(self): + def test_cast_interval(self): for result, values in self.intervals: f = pg.cast_interval years, mons, days, hours, mins, secs, usecs = result @@ -864,7 +864,7 @@ class TestEscapeFunctions(unittest.TestCase): """ - def testEscapeString(self): + def test_escape_string(self): f = pg.escape_string r = f(b'plain') self.assertIsInstance(r, bytes) @@ -876,7 +876,7 @@ def testEscapeString(self): self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") - def testEscapeBytea(self): + def test_escape_bytea(self): f = pg.escape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) @@ -888,7 +888,7 @@ def testEscapeBytea(self): self.assertIsInstance(r, str) self.assertEqual(r, "that''s cheese") - def testUnescapeBytea(self): + def test_unescape_bytea(self): f = pg.unescape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) @@ -916,10 +916,10 @@ class TestConfigFunctions(unittest.TestCase): """ - def testGetDatestyle(self): + def test_get_datestyle(self): self.assertIsNone(pg.get_datestyle()) - def testSetDatestyle(self): + def test_set_datestyle(self): datestyle = pg.get_datestyle() try: pg.set_datestyle('ISO, YMD') @@ -939,12 +939,12 @@ def testSetDatestyle(self): finally: pg.set_datestyle(datestyle) - def testGetDecimalPoint(self): + def test_get_decimal_point(self): r = pg.get_decimal_point() self.assertIsInstance(r, str) self.assertEqual(r, '.') - def testSetDecimalPoint(self): + def test_set_decimal_point(self): point = pg.get_decimal_point() try: pg.set_decimal_point('*') @@ -957,11 +957,11 @@ def testSetDecimalPoint(self): self.assertIsInstance(r, str) self.assertEqual(r, point) - def testGetDecimal(self): + def test_get_decimal(self): r = pg.get_decimal() self.assertIs(r, pg.Decimal) - def testSetDecimal(self): + def test_set_decimal(self): decimal_class = pg.Decimal try: pg.set_decimal(int) @@ -972,12 +972,12 @@ def testSetDecimal(self): r = pg.get_decimal() self.assertIs(r, decimal_class) - def testGetBool(self): + def test_get_bool(self): r = pg.get_bool() self.assertIsInstance(r, bool) self.assertIs(r, True) - def testSetBool(self): + def test_set_bool(self): use_bool = pg.get_bool() try: pg.set_bool(False) @@ -995,12 +995,12 @@ def testSetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, use_bool) - def testGetByteaEscaped(self): + def test_get_bytea_escaped(self): r = pg.get_bytea_escaped() self.assertIsInstance(r, bool) self.assertIs(r, False) - def testSetByteaEscaped(self): + def test_set_bytea_escaped(self): bytea_escaped = pg.get_bytea_escaped() try: pg.set_bytea_escaped(True) @@ -1018,12 +1018,12 @@ def testSetByteaEscaped(self): self.assertIsInstance(r, bool) self.assertIs(r, bytea_escaped) - def testGetJsondecode(self): + def test_get_jsondecode(self): r = pg.get_jsondecode() self.assertTrue(callable(r)) self.assertIs(r, json.loads) - def testSetJsondecode(self): + def test_set_jsondecode(self): jsondecode = pg.get_jsondecode() try: pg.set_jsondecode(None) @@ -1042,7 +1042,7 @@ def testSetJsondecode(self): class TestModuleConstants(unittest.TestCase): """Test the existence of the documented module constants.""" - def testVersion(self): + def test_version(self): v = pg.version self.assertIsInstance(v, str) # make sure the version conforms to PEP440 diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 039ca51f..afe48a21 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -32,7 +32,7 @@ def connect(): class TestModuleConstants(unittest.TestCase): """Test the existence of the documented module constants.""" - def testLargeObjectIntConstants(self): + def test_large_object_int_constants(self): names = 'INV_READ INV_WRITE SEEK_SET SEEK_CUR SEEK_END'.split() for name in names: try: @@ -53,7 +53,7 @@ def tearDown(self): self.c.query('rollback') self.c.close() - def assertIsLargeObject(self, obj): + def assertIsLargeObject(self, obj): # noqa: N802 self.assertIsNotNone(obj) self.assertTrue(hasattr(obj, 'open')) self.assertTrue(hasattr(obj, 'close')) @@ -66,14 +66,14 @@ def assertIsLargeObject(self, obj): self.assertIsInstance(obj.error, str) self.assertFalse(obj.error) - def testLoCreate(self): + def test_lo_create(self): large_object = self.c.locreate(pg.INV_READ | pg.INV_WRITE) try: self.assertIsLargeObject(large_object) finally: del large_object - def testGetLo(self): + def test_get_lo(self): large_object = self.c.locreate(pg.INV_READ | pg.INV_WRITE) try: self.assertIsLargeObject(large_object) @@ -103,7 +103,7 @@ def testGetLo(self): self.assertIsInstance(r, bytes) self.assertEqual(r, data) - def testLoImport(self): + def test_lo_import(self): if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_import.txt' @@ -164,24 +164,24 @@ def tearDown(self): pass self.pgcnx.close() - def testClassName(self): + def test_class_name(self): self.assertEqual(self.obj.__class__.__name__, 'LargeObject') - def testModuleName(self): + def test_module_name(self): self.assertEqual(self.obj.__class__.__module__, 'pg') - def testOid(self): + def test_oid(self): self.assertIsInstance(self.obj.oid, int) self.assertNotEqual(self.obj.oid, 0) - def testPgcn(self): + def test_pgcn(self): self.assertIs(self.obj.pgcnx, self.pgcnx) - def testError(self): + def test_error(self): self.assertIsInstance(self.obj.error, str) self.assertEqual(self.obj.error, '') - def testStr(self): + def test_str(self): self.obj.open(pg.INV_WRITE) data = b'some object to be printed' self.obj.write(data) @@ -192,11 +192,11 @@ def testStr(self): r = str(self.obj) self.assertEqual(r, f'Closed large object, oid {oid}') - def testRepr(self): + def test_repr(self): r = repr(self.obj) self.assertTrue(r.startswith(' Date: Sat, 2 Sep 2023 01:10:05 +0200 Subject: [PATCH 041/118] Minor improvements using more ruff specific linting --- docs/contents/changelog.rst | 1 + docs/contents/pgdb/types.rst | 6 +-- pg.py | 39 +++++++++--------- pgdb.py | 69 ++++++++++++++++---------------- pyproject.toml | 1 + setup.py | 6 +-- tests/dbapi20.py | 8 ++-- tests/test_classic_connection.py | 9 +++-- tests/test_classic_functions.py | 9 +++-- tests/test_dbapi20.py | 5 ++- tests/test_dbapi20_copy.py | 14 ++++--- 11 files changed, 88 insertions(+), 79 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 67408993..d240daa2 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -7,6 +7,7 @@ Version 6.0 (to be released) and PostgreSQL older than version 10 (released October 2017). - Removed deprecated function `pg.pgnotify()`. - Removed the deprecated method `ntuples()` of the `pg.Query` object. +- Renamed `pgdb.Type` to `pgdb.DbType` to avoid confusion with `typing.Type`. - Modernized code and tools for development, testing, linting and building. Version 5.2.5 (2023-08-28) diff --git a/docs/contents/pgdb/types.rst b/docs/contents/pgdb/types.rst index f28e23f7..d739df32 100644 --- a/docs/contents/pgdb/types.rst +++ b/docs/contents/pgdb/types.rst @@ -101,15 +101,15 @@ Example for using a type constructor:: Type objects ------------ -.. class:: Type +.. class:: DbType The :attr:`Cursor.description` attribute returns information about each of the result columns of a query. The *type_code* must compare equal to one -of the :class:`Type` objects defined below. Type objects can be equal to +of the :class:`DbType` objects defined below. Type objects can be equal to more than one type code (e.g. :class:`DATETIME` is equal to the type codes for ``date``, ``time`` and ``timestamp`` columns). -The pgdb module exports the following :class:`Type` objects as part of the +The pgdb module exports the following :class:`DbType` objects as part of the DB-API 2 standard: .. object:: STRING diff --git a/pg.py b/pg.py index f8dfb1be..25dc16e7 100644 --- a/pg.py +++ b/pg.py @@ -20,6 +20,22 @@ For a DB-API 2 compliant interface use the newer pgdb module. """ +import select +import weakref +from collections import OrderedDict, namedtuple +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from functools import lru_cache, partial +from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan +from operator import itemgetter +from re import compile as regex +from types import MappingProxyType +from typing import ClassVar, Dict, List, Mapping, Type, Union +from uuid import UUID + try: from _pg import version except ImportError as e: # noqa: F841 @@ -149,21 +165,6 @@ 'set_jsondecode', 'set_query_helpers', 'set_typecast', 'version', '__version__'] -import select -import weakref -from collections import OrderedDict, namedtuple -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from operator import itemgetter -from re import compile as regex -from typing import Dict, List, Union # noqa: F401 -from uuid import UUID - # Auxiliary classes and functions that are independent of a DB connection: def get_args(func): @@ -239,7 +240,7 @@ class _SimpleTypes(dict): The corresponding Python types and simple names are also mapped. """ - _type_aliases = { + _type_aliases: Mapping[str, List[Union[str, type]]] = MappingProxyType({ 'bool': [bool], 'bytea': [Bytea], 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', @@ -251,13 +252,13 @@ class _SimpleTypes(dict): 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], 'num': ['numeric', Decimal], 'money': [], 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] - } # type: Dict[str, List[Union[str, type]]] + }) # noinspection PyMissingConstructor def __init__(self): """Initialize type mapping.""" for typ, keys in self._type_aliases.items(): - keys = [typ] + keys + keys = [typ, *keys] for key in keys: self[key] = typ if isinstance(key, str): @@ -969,7 +970,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults = { + defaults: ClassVar[Dict[str, Type]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, diff --git a/pgdb.py b/pgdb.py index 5752ac4d..00e57f02 100644 --- a/pgdb.py +++ b/pgdb.py @@ -64,6 +64,20 @@ connection.close() # close the connection """ +from collections import namedtuple +from collections.abc import Iterable +from datetime import date, datetime, time, timedelta +from decimal import Decimal as StdDecimal +from functools import lru_cache, partial +from inspect import signature +from json import dumps as jsonencode +from json import loads as jsondecode +from math import isinf, isnan +from re import compile as regex +from time import localtime +from typing import ClassVar, Dict, Type +from uuid import UUID as Uuid # noqa: N811 + try: from _pg import version except ImportError as e: # noqa: F841 @@ -137,19 +151,6 @@ 'get_typecast', 'set_typecast', 'reset_typecast', 'version', '__version__'] -from collections import namedtuple -from collections.abc import Iterable -from datetime import date, datetime, time, timedelta -from decimal import Decimal as StdDecimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from re import compile as regex -from time import localtime -from uuid import UUID as Uuid # noqa: N811 - Decimal = StdDecimal @@ -417,7 +418,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults = { + defaults: ClassVar[Dict[str, Type]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, @@ -1579,7 +1580,7 @@ def connect(dsn=None, # *** Types Handling *** -class Type(frozenset): +class DbType(frozenset): """Type class for a couple of PostgreSQL data types. PostgreSQL is object-oriented: types are dynamic. @@ -1651,30 +1652,30 @@ def __ne__(self, other): # Mandatory type objects defined by DB-API 2 specs: -STRING = Type('char bpchar name text varchar') -BINARY = Type('bytea') -NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money') -DATETIME = Type('date time timetz timestamp timestamptz interval' +STRING = DbType('char bpchar name text varchar') +BINARY = DbType('bytea') +NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money') +DATETIME = DbType('date time timetz timestamp timestamptz interval' ' abstime reltime') # these are very old -ROWID = Type('oid') +ROWID = DbType('oid') # Additional type objects (more specific): -BOOL = Type('bool') -SMALLINT = Type('int2') -INTEGER = Type('int2 int4 int8 serial') -LONG = Type('int8') -FLOAT = Type('float4 float8') -NUMERIC = Type('numeric') -MONEY = Type('money') -DATE = Type('date') -TIME = Type('time timetz') -TIMESTAMP = Type('timestamp timestamptz') -INTERVAL = Type('interval') -UUID = Type('uuid') -HSTORE = Type('hstore') -JSON = Type('json jsonb') +BOOL = DbType('bool') +SMALLINT = DbType('int2') +INTEGER = DbType('int2 int4 int8 serial') +LONG = DbType('int8') +FLOAT = DbType('float4 float8') +NUMERIC = DbType('numeric') +MONEY = DbType('money') +DATE = DbType('date') +TIME = DbType('time timetz') +TIMESTAMP = DbType('timestamp timestamptz') +INTERVAL = DbType('interval') +UUID = DbType('uuid') +HSTORE = DbType('hstore') +JSON = DbType('json jsonb') # Type object for arrays (also equate to their base types): diff --git a/pyproject.toml b/pyproject.toml index 9603b825..382b09ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ select = [ "N", # pep8-naming "UP", # pyupgrade "D", # pydocstyle + "RUF", # ruff ] exclude = [ "__pycache__", diff --git a/setup.py b/setup.py index 09c6e2f8..a52f315d 100755 --- a/setup.py +++ b/setup.py @@ -67,15 +67,15 @@ class build_pg_ext(build_ext): # noqa: N801 description = "build the PyGreSQL C extension" - user_options = build_ext.user_options + [ + user_options = [*build_ext.user_options, # noqa: RUF012 ('strict', None, "count all compiler warnings as errors"), ('memory-size', None, "enable memory size function"), ('no-memory-size', None, "disable memory size function")] - boolean_options = build_ext.boolean_options + [ + boolean_options = [*build_ext.boolean_options, # noqa: RUF012 'strict', 'memory-size'] - negative_opt = { + negative_opt = { # noqa: RUF012 'no-memory-size': 'memory-size'} def get_compiler(self): diff --git a/tests/dbapi20.py b/tests/dbapi20.py index bb913475..12a7647b 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -11,7 +11,7 @@ import time import unittest -from typing import Any, Dict, Tuple +from typing import Any, Mapping, Tuple class DatabaseAPI20Test(unittest.TestCase): @@ -41,7 +41,7 @@ class mytest(dbapi20.DatabaseAPI20Test): # method is to be found driver: Any = None connect_args: Tuple = () # List of arguments to pass to connect - connect_kw_args: Dict[str, Any] = {} # Keyword arguments for connect + connect_kw_args: Mapping[str, Any] = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables ddl1 = f'create table {table_prefix}booze (name varchar(20))' @@ -502,14 +502,14 @@ def test_next(self): finally: con.close() - samples = [ + samples = ( 'Carlton Cold', 'Carlton Draft', 'Mountain Goat', 'Redback', 'Victoria Bitter', 'XXXX' - ] + ) def _populate(self): """Return a list of SQL commands to setup the DB for fetching tests.""" diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index ed31bed8..440142c7 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -16,6 +16,7 @@ from collections import namedtuple from collections.abc import Iterable from decimal import Decimal +from typing import Sequence, Tuple import pg # the module under test @@ -1743,7 +1744,7 @@ def tearDown(self): self.c.query("truncate table test") self.c.close() - data = [ + data: Sequence[Tuple] = [ (-1, -1, -1, True, '1492-10-12', '08:30:00', -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), (0, 0, 0, False, '1607-04-14', '09:00:00', @@ -1825,7 +1826,7 @@ def test_inserttable_from_tuple_of_lists(self): self.assertEqual(self.get_back(), self.data) def test_inserttable_with_different_row_sizes(self): - data = self.data[:-1] + [self.data[-1][:-1]] + data = [*self.data[:-1], (self.data[-1][:-1],)] try: self.c.inserttable('test', data) except TypeError as e: @@ -2107,10 +2108,10 @@ def test_insert_table_big_row_size(self): def test_insert_table_small_int_overflow(self): rest_row = self.data[2][1:] - data = [(32000,) + rest_row] + data = [(32000, *rest_row)] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - data = [(33000,) + rest_row] + data = [(33000, *rest_row)] try: self.c.inserttable('test', data) except ValueError as e: diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 5a49e9d2..5babc816 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -13,6 +13,7 @@ import re import unittest from datetime import timedelta +from typing import Any, Sequence, Tuple, Type import pg # the module under test @@ -119,7 +120,7 @@ def test_pqlib_version(self): class TestParseArray(unittest.TestCase): """Test the array parser.""" - test_strings = [ + test_strings: Sequence[Tuple[str, Type, Any]] = [ ('', str, ValueError), ('{}', None, []), ('{}', str, []), @@ -353,7 +354,7 @@ def replace_comma(value): class TestParseRecord(unittest.TestCase): """Test the record parser.""" - test_strings = [ + test_strings: Sequence[Tuple[str, Type, Any]] = [ ('', None, ValueError), ('', str, ValueError), ('(', None, ValueError), @@ -634,7 +635,7 @@ def replace_comma(value): class TestParseHStore(unittest.TestCase): """Test the hstore parser.""" - test_strings = [ + test_strings: Sequence[Tuple[str, Any]] = [ ('', {}), ('=>', ValueError), ('""=>', ValueError), @@ -683,7 +684,7 @@ def test_parser(self): class TestCastInterval(unittest.TestCase): """Test the interval typecast function.""" - intervals = [ + intervals: Sequence[Tuple[Tuple[int, ...], Tuple[str, ...]]] = [ ((0, 0, 0, 1, 0, 0, 0), ('1:00:00', '01:00:00', '@ 1 hour', 'PT1H')), ((0, 0, 0, -1, 0, 0, 0), diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 8ea52a7b..9fd00165 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -3,6 +3,7 @@ import gc import unittest from datetime import date, datetime, time, timedelta, timezone +from typing import Any, Mapping from uuid import UUID as Uuid # noqa: N811 import pgdb @@ -25,7 +26,7 @@ class TestPgDb(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () - connect_kw_args = { + connect_kw_args: Mapping[str, Any] = { 'database': dbname, 'host': f"{dbhost or ''}:{dbport or -1}", 'user': dbuser, 'password': dbpasswd} @@ -1323,7 +1324,7 @@ def test_no_close(self): data = ('hello', 'world') con = self._connect() cur = con.cursor() - cur.build_row_factory = lambda: tuple # noqa: E731 + cur.build_row_factory = lambda: tuple cur.execute("select %s, %s", data) row = cur.fetchone() self.assertEqual(row, data) diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index c4e8dd74..170c33c1 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -11,6 +11,7 @@ import unittest from collections.abc import Iterable +from typing import Sequence, Tuple import pgdb # the module under test @@ -154,9 +155,10 @@ def tearDown(self): except Exception: pass - data = [(1935, 'Luciano Pavarotti'), - (1941, 'Plácido Domingo'), - (1946, 'José Carreras')] + data: Sequence[Tuple[int, str]] = [ + (1935, 'Luciano Pavarotti'), + (1941, 'Plácido Domingo'), + (1946, 'José Carreras')] can_encode = True @@ -447,11 +449,11 @@ def test_null(self): self.cursor.execute('insert into copytest values(4, null)') try: ret = list(self.copy_to()) - self.assertEqual(ret, data + ['4\t\\N\n']) + self.assertEqual(ret, [*data, '4\t\\N\n']) ret = list(self.copy_to(null='Nix')) - self.assertEqual(ret, data + ['4\tNix\n']) + self.assertEqual(ret, [*data, '4\tNix\n']) ret = list(self.copy_to(null='')) - self.assertEqual(ret, data + ['4\t\n']) + self.assertEqual(ret, [*data, '4\t\n']) finally: self.cursor.execute('delete from copytest where id=4') From 33859e51483e3b0fece3c63213ec4b6e1c6bca11 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 09:42:38 +0200 Subject: [PATCH 042/118] Add ruff for local testing when provisioning --- .devcontainer/provision.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index a42337b8..2f3651d6 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -35,7 +35,9 @@ python3.9 -m pip install build python3.10 -m pip install build python3.11 -m pip install build -sudo apt-get install -y tox python3-poetry +pip install ruff + +sudo apt-get install -y tox # install PostgreSQL client tools From 950d7d8e22034e711aa79d2f49f5a41904ed74e8 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 10:21:42 +0200 Subject: [PATCH 043/118] Add testing with flake8-bugbear --- pg.py | 11 ++++----- pgdb.py | 36 ++++++++++++++++-------------- pyproject.toml | 2 ++ setup.py | 6 +++-- tests/test_classic.py | 6 ++--- tests/test_classic_connection.py | 4 ++-- tests/test_classic_dbwrapper.py | 24 ++++++++++---------- tests/test_classic_functions.py | 4 ++-- tests/test_classic_notification.py | 24 ++++++++++---------- tests/test_dbapi20.py | 2 +- tests/test_dbapi20_copy.py | 4 ++-- 11 files changed, 65 insertions(+), 58 deletions(-) diff --git a/pg.py b/pg.py index 25dc16e7..2df124b0 100644 --- a/pg.py +++ b/pg.py @@ -494,7 +494,7 @@ def _adapt_record(self, v, typ): raise TypeError(f'Record parameter {v} has wrong size') adapt = self.adapt value = [] - for v, t in zip(v, typ): + for v, t in zip(v, typ): # noqa: B020 v = adapt(v, t) if v is None: v = '' @@ -1989,7 +1989,7 @@ def pkey(self, table, composite=False, flush=False): self._do_debug('The pkey cache has been flushed') try: # cache lookup pkey = pkeys[table] - except KeyError: # cache miss, check the database + except KeyError as e: # cache miss, check the database q = ("SELECT a.attname, a.attnum, i.indkey" " FROM pg_catalog.pg_index i" " JOIN pg_catalog.pg_attribute a" @@ -2002,7 +2002,7 @@ def pkey(self, table, composite=False, flush=False): _quote_if_unqualified('$1', table)) pkey = self.db.query(q, (table,)).getresult() if not pkey: - raise KeyError(f'Table {table} has no primary key') + raise KeyError(f'Table {table} has no primary key') from e # we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table if len(pkey) > 1: @@ -2173,12 +2173,13 @@ def get(self, table, row, keyname=None): if not keyname: try: # if keyname is not specified, try using the primary key keyname = self.pkey(table, True) - except KeyError: # the table has no primary key + except KeyError as e: # the table has no primary key # try using the oid instead if qoid and isinstance(row, dict) and 'oid' in row: keyname = ('oid',) else: - raise _prg_error(f'Table {table} has no primary key') + raise _prg_error( + f'Table {table} has no primary key') from e else: # the table has a primary key # check whether all key columns have values if isinstance(row, dict) and not set(keyname).issubset(row): diff --git a/pgdb.py b/pgdb.py index 00e57f02..74db29e9 100644 --- a/pgdb.py +++ b/pgdb.py @@ -993,7 +993,8 @@ def executemany(self, operation, seq_of_parameters): raise # database provides error message except Error as err: # noinspection PyTypeChecker - raise _db_error(f"Error in '{sql}': '{err}'", InterfaceError) + raise _db_error( + f"Error in '{sql}': '{err}'", InterfaceError) from err except Exception as err: raise _op_error(f"Internal error in '{sql}': {err}") from err # then initialize result raw count and description @@ -1090,9 +1091,10 @@ def copy_from(self, stream, table, binary_format = format == 'binary' try: read = stream.read - except AttributeError: + except AttributeError as e: if size: - raise ValueError("Size must only be set for file-like objects") + raise ValueError( + "Size must only be set for file-like objects") from e if binary_format: input_type = bytes type_name = 'byte strings' @@ -1102,7 +1104,7 @@ def copy_from(self, stream, table, if isinstance(stream, (bytes, str)): if not isinstance(stream, input_type): - raise ValueError(f"The input must be {type_name}") + raise ValueError(f"The input must be {type_name}") from e if not binary_format: if isinstance(stream, str): if not stream.endswith('\n'): @@ -1130,7 +1132,7 @@ def chunks(): yield chunk else: - raise TypeError("Need an input stream to copy from") + raise TypeError("Need an input stream to copy from") from e else: if size is None: size = 8192 @@ -1233,8 +1235,8 @@ def copy_to(self, stream, table, if stream is not None: try: write = stream.write - except AttributeError: - raise TypeError("Need an output stream to copy to") + except AttributeError as e: + raise TypeError("Need an output stream to copy to") from e if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): @@ -1405,8 +1407,8 @@ def __init__(self, cnx): self.autocommit = False try: self._cnx.source() - except Exception: - raise _op_error("Invalid connection") + except Exception as e: + raise _op_error("Invalid connection") from e def __enter__(self): """Enter the runtime context for the connection object. @@ -1420,8 +1422,8 @@ def __enter__(self): self._cnx.source().execute("BEGIN") except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't start transaction") + except Exception as e: + raise _op_error("Can't start transaction") from e else: self._tnx = True return self @@ -1466,8 +1468,8 @@ def commit(self): self._cnx.source().execute("COMMIT") except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't commit transaction") + except Exception as e: + raise _op_error("Can't commit transaction") from e else: raise _op_error("Connection has been closed") @@ -1480,8 +1482,8 @@ def rollback(self): self._cnx.source().execute("ROLLBACK") except DatabaseError: raise # database provides error message - except Exception: - raise _op_error("Can't rollback transaction") + except Exception as e: + raise _op_error("Can't rollback transaction") from e else: raise _op_error("Connection has been closed") @@ -1490,8 +1492,8 @@ def cursor(self): if self._cnx: try: return self.cursor_type(self) - except Exception: - raise _op_error("Invalid connection") + except Exception as e: + raise _op_error("Invalid connection") from e else: raise _op_error("Connection has been closed") diff --git a/pyproject.toml b/pyproject.toml index 382b09ca..2abfeb63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ download = "https://pygresql.github.io/download/" "mailing list" = "https://mail.vex.net/mailman/listinfo/pygresql" [tool.ruff] +target-version = "py37" line-length = 79 select = [ "E", # pycodestyle @@ -50,6 +51,7 @@ select = [ "UP", # pyupgrade "D", # pydocstyle "RUF", # ruff + "B", # bugbear ] exclude = [ "__pycache__", diff --git a/setup.py b/setup.py index a52f315d..6e4c5fd4 100755 --- a/setup.py +++ b/setup.py @@ -90,7 +90,8 @@ def initialize_options(self): supported = pg_version >= (10, 0) if not supported: warnings.warn( - "PyGreSQL does not support the installed PostgreSQL version.") + "PyGreSQL does not support the installed PostgreSQL version.", + stacklevel=2) def finalize_options(self): """Set final values for all build_pg options.""" @@ -104,7 +105,8 @@ def finalize_options(self): if not supported: warnings.warn( "The installed PostgreSQL version" - " does not support the memory size function.") + " does not support the memory size function.", + stacklevel=2) if sys.platform == 'win32': libraries[0] = 'lib' + libraries[0] if os.path.exists(os.path.join( diff --git a/tests/test_classic.py b/tests/test_classic.py index d6763074..18af07b2 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -215,7 +215,7 @@ def test_notify(self, options=None): thread.start() try: # Wait until the thread has started. - for n in range(500): + for _n in range(500): if target.listening: break sleep(0.01) @@ -237,7 +237,7 @@ def test_notify(self, options=None): if two_payloads: db2.commit() # Wait until the notification has been caught. - for n in range(500): + for _n in range(500): if arg_dict['called'] or self.notify_timeout: break sleep(0.01) @@ -256,7 +256,7 @@ def test_notify(self, options=None): db2.query("notify stop_event_1, 'payload 2'") db2.close() # Wait until the notification has been caught. - for n in range(500): + for _n in range(500): if arg_dict['called'] or self.notify_timeout: break sleep(0.01) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 440142c7..1fa9edb6 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1357,7 +1357,7 @@ def test_iterate(self): def test_iterate_twice(self): r = self.c.query("select generate_series(3,5)") - for i in range(2): + for _i in range(2): self.assertEqual(list(r), [(3,), (4,), (5,)]) def test_iterate_two_columns(self): @@ -2652,7 +2652,7 @@ def test_set_row_factory_size(self): query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): pg.set_row_factory_size(maxsize) - for i in range(3): + for _i in range(3): for q in queries: r = query(q).namedresult()[0] if q.endswith('abc'): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 1f7b3aac..c563d932 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -2735,7 +2735,7 @@ def test_truncate(self): truncate('test_table') r = query(q).getresult()[0][0] self.assertEqual(r, 0) - for i in range(3): + for _i in range(3): query("insert into test_table values (1)") r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -2744,7 +2744,7 @@ def test_truncate(self): self.assertEqual(r, 0) self.create_table('test_table_2', 'n smallint', temporary=True) for t in (list, tuple, set): - for i in range(3): + for _i in range(3): query("insert into test_table values (1)") query("insert into test_table_2 values (2)") q = ("select (select count(*) from test_table)," @@ -2760,7 +2760,7 @@ def test_truncate_restart(self): self.assertRaises(TypeError, truncate, 'test_table', restart='invalid') query = self.db.query self.create_table('test_table', 'n serial, t text') - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") q = "select count(n), min(n), max(n) from test_table" r = query(q).getresult()[0] @@ -2768,14 +2768,14 @@ def test_truncate_restart(self): truncate('test_table') r = query(q).getresult()[0] self.assertEqual(r, (0, None, None)) - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") r = query(q).getresult()[0] self.assertEqual(r, (3, 4, 6)) truncate('test_table', restart=True) r = query(q).getresult()[0] self.assertEqual(r, (0, None, None)) - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") r = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) @@ -2824,7 +2824,7 @@ def test_truncate_only(self): query = self.db.query self.create_table('test_parent', 'n smallint') self.create_table('test_child', 'm smallint) inherits (test_parent') - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") q = ("select (select count(*) from test_parent)," @@ -2834,7 +2834,7 @@ def test_truncate_only(self): truncate('test_parent') r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") r = query(q).getresult()[0] @@ -2842,7 +2842,7 @@ def test_truncate_only(self): truncate('test_parent*') r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") r = query(q).getresult()[0] @@ -2859,7 +2859,7 @@ def test_truncate_only(self): self.create_table('test_child_2', 'm smallint) inherits (test_parent_2') for t in '', '_2': - for n in range(3): + for _n in range(3): query(f"insert into test_parent{t} (n) values (1)") query(f"insert into test_child{t} (n, m) values (2, 3)") q = ("select (select count(*) from test_parent)," @@ -2890,7 +2890,7 @@ def test_truncate_quoted(self): truncate(table) r = query(q).getresult()[0][0] self.assertEqual(r, 0) - for i in range(3): + for _i in range(3): query(f'insert into "{table}" values (1)') r = query(q).getresult()[0][0] self.assertEqual(r, 3) @@ -4703,11 +4703,11 @@ def setUpClass(cls): query(f"drop schema if exists {schema} cascade") try: query(f"create schema {schema}") - except pg.ProgrammingError: + except pg.ProgrammingError as e: raise RuntimeError( "The test user cannot create schemas.\n" f"Grant create on database {dbname} to the user" - " for running these tests.") + " for running these tests.") from e else: schema = "public" query(f"drop table if exists {schema}.t") diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 5babc816..37606b13 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -267,7 +267,7 @@ def test_parser_nested(self): self.assertEqual(len(r), 1) self.assertEqual(r[0], 'b') r = f('{{{{{{{abc}}}}}}}') - for i in range(7): + for _i in range(7): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) # noinspection PyUnresolvedReferences @@ -282,7 +282,7 @@ def test_parser_too_deeply_nested(self): self.assertRaises(ValueError, f, r) else: r = f(r) - for i in range(n - 1): + for _i in range(n - 1): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) r = r[0] diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py index 9e56bd6d..552f1ea5 100755 --- a/tests/test_classic_notification.py +++ b/tests/test_classic_notification.py @@ -208,7 +208,7 @@ def start_handler(self, event=None, arg_dict=None, thread.start() self.stopped = timeout == 0 self.addCleanup(self.stop_handler) - for n in range(500): + for _n in range(500): if handler.listening: break sleep(0.01) @@ -255,7 +255,7 @@ def notify_query(self, stop=False, payload=None): self.sent.append(arg_dict) def wait(self): - for n in range(500): + for _n in range(500): if self.timeout: return False if len(self.received) >= len(self.sent): @@ -309,15 +309,15 @@ def test_notify_with_args(self): def test_notify_several_times(self): arg_dict = {'test': 1} self.start_handler(arg_dict=arg_dict) - for count in range(3): + for _n in range(3): self.notify_query() self.receive() arg_dict['test'] += 1 - for count in range(2): + for _n in range(2): self.notify_handler() self.receive() arg_dict['test'] += 1 - for count in range(3): + for _n in range(3): self.notify_query() self.receive(stop=True) @@ -338,30 +338,30 @@ def test_notify_quoted_names(self): def test_notify_with_five_payloads(self): self.start_handler('gimme_5', {'test': 'Gimme 5'}) - for count in range(5): - self.notify_query(payload=f"Round {count}") + for n in range(5): + self.notify_query(payload=f"Round {n}") self.assertEqual(len(self.sent), 5) self.receive(stop=True) def test_receive_immediately(self): self.start_handler('immediate', {'test': 'immediate'}) - for count in range(3): - self.notify_query(payload=f"Round {count}") + for n in range(3): + self.notify_query(payload=f"Round {n}") self.receive() self.receive(stop=True) def test_notify_distinct_in_transaction(self): self.start_handler('test_transaction', {'transaction': True}) self.db.begin() - for count in range(3): - self.notify_query(payload=f'Round {count}') + for n in range(3): + self.notify_query(payload=f'Round {n}') self.db.commit() self.receive(stop=True) def test_notify_same_in_transaction(self): self.start_handler('test_transaction', {'transaction': True}) self.db.begin() - for count in range(3): + for _n in range(3): self.notify_query() self.db.commit() # these same notifications may be delivered as one, diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 9fd00165..657e820c 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1335,7 +1335,7 @@ def test_set_row_factory_size(self): cur = con.cursor() for maxsize in (None, 0, 1, 2, 3, 10, 1024): pgdb.set_row_factory_size(maxsize) - for i in range(3): + for _i in range(3): for q in queries: cur.execute(q) r = cur.fetchone() diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 170c33c1..ca775001 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -181,7 +181,7 @@ def table_data(self): def check_table(self): self.assertEqual(self.table_data, self.data) - def check_rowcount(self, number=len(data)): + def check_rowcount(self, number=len(data)): # noqa: B008 self.assertEqual(self.cursor.rowcount, number) @@ -429,7 +429,7 @@ def test_generator_bytes(self): def test_rowcount_increment(self): ret = self.copy_to() self.assertIsInstance(ret, Iterable) - for n, row in enumerate(ret): + for n, _row in enumerate(ret): self.check_rowcount(n + 1) def test_decode(self): From d8033beee03c56a50bd6982a83129a41402d38bf Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 10:54:07 +0200 Subject: [PATCH 044/118] Add testing with flake8-bandit --- pg.py | 19 ++++++++++++------- pgdb.py | 4 ++-- pyproject.toml | 5 +++-- setup.py | 2 +- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pg.py b/pg.py index 2df124b0..ac86d480 100644 --- a/pg.py +++ b/pg.py @@ -1990,7 +1990,8 @@ def pkey(self, table, composite=False, flush=False): try: # cache lookup pkey = pkeys[table] except KeyError as e: # cache miss, check the database - q = ("SELECT a.attname, a.attnum, i.indkey" + q = ("SELECT" # noqa: S608 + " a.attname, a.attnum, i.indkey" " FROM pg_catalog.pg_index i" " JOIN pg_catalog.pg_attribute a" " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" @@ -2038,7 +2039,8 @@ def get_relations(self, kinds=None, system=False): where.append("s.nspname NOT SIMILAR" " TO 'pg/_%|information/_schema' ESCAPE '/'") where = " WHERE " + ' AND '.join(where) if where else '' - q = ("SELECT pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" + q = ("SELECT" # noqa: S608 + " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" " FROM pg_catalog.pg_class r" " JOIN pg_catalog.pg_namespace s" @@ -2207,7 +2209,7 @@ def get(self, table, row, keyname=None): row[qoid] = row['oid'] del row['oid'] t = self._escape_qualified_name(table) - q = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' + q = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2259,7 +2261,8 @@ def insert(self, table, row=None, **kw): names, values = ', '.join(names), ', '.join(values) ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = f'INSERT INTO {t} ({names}) VALUES ({values}) RETURNING {ret}' + q = (f'INSERT INTO {t} ({names})' # noqa: S608 + f' VALUES ({values}) RETURNING {ret}') self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2322,7 +2325,8 @@ def update(self, table, row=None, **kw): values = ', '.join(values) ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = f'UPDATE {t} SET {values} WHERE {where} RETURNING {ret}' + q = (f'UPDATE {t} SET {values}' # noqa: S608 + f' WHERE {where} RETURNING {ret}') self._do_debug(q, params) q = self.db.query(q, params) res = q.dictresult() @@ -2417,7 +2421,8 @@ def upsert(self, table, row=None, **kw): do = 'update set ' + ', '.join(update) if update else 'nothing' ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = (f'INSERT INTO {t} AS included ({names}) VALUES ({values})' + q = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 + f' VALUES ({values})' f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') self._do_debug(q, params) q = self.db.query(q, params) @@ -2499,7 +2504,7 @@ def delete(self, table, row=None, **kw): row[qoid] = row['oid'] del row['oid'] t = self._escape_qualified_name(table) - q = f'DELETE FROM {t} WHERE {where}' + q = f'DELETE FROM {t} WHERE {where}' # noqa: S608 self._do_debug(q, params) res = self.db.query(q, params) return int(res) diff --git a/pgdb.py b/pgdb.py index 74db29e9..22f99498 100644 --- a/pgdb.py +++ b/pgdb.py @@ -686,7 +686,7 @@ def get_fields(self, typ): if not typ.relid: return None # this type is not composite self._src.execute( - "SELECT attname, atttypid" + "SELECT attname, atttypid" # noqa: S608 " FROM pg_catalog.pg_attribute" f" WHERE attrelid OPERATOR(pg_catalog.=) {typ.relid}" " AND attnum OPERATOR(pg_catalog.>) 0" @@ -1065,7 +1065,7 @@ def callproc(self, procname, parameters=None): """ n = len(parameters) if parameters else 0 s = ','.join(n * ['%s']) - query = f'select * from "{procname}"({s})' + query = f'select * from "{procname}"({s})' # noqa: S608 self.execute(query, parameters) return parameters diff --git a/pyproject.toml b/pyproject.toml index 2abfeb63..f5927d43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,9 @@ select = [ "N", # pep8-naming "UP", # pyupgrade "D", # pydocstyle - "RUF", # ruff "B", # bugbear + "S", # bandit + "RUF", # ruff ] exclude = [ "__pycache__", @@ -69,7 +70,7 @@ exclude = [ ] [tool.ruff.per-file-ignores] -"tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107"] +"tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] [tool.setuptools] py-modules = ["pg", "pgdb"] diff --git a/setup.py b/setup.py index 6e4c5fd4..c20c9607 100755 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ def pg_config(s): """Retrieve information about installed version of PostgreSQL.""" - f = os.popen(f'pg_config --{s}') + f = os.popen(f'pg_config --{s}') # noqa: S605 d = f.readline().strip() if f.close() is not None: raise Exception("pg_config tool is not available.") From d844e8a6710d74671fafb976406014a3fbbd07c1 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 11:16:20 +0200 Subject: [PATCH 045/118] Add testing with flake8-simplify --- pg.py | 21 ++++++--------------- pgdb.py | 5 ++--- pyproject.toml | 1 + tests/config.py | 2 +- tests/dbapi20.py | 5 ++--- tests/test_classic.py | 17 +++++------------ tests/test_classic_connection.py | 28 +++++++++------------------- tests/test_classic_dbwrapper.py | 27 ++++++++------------------- tests/test_classic_largeobj.py | 21 ++++++++------------- tests/test_dbapi20_copy.py | 13 ++++--------- 10 files changed, 46 insertions(+), 94 deletions(-) diff --git a/pg.py b/pg.py index ac86d480..434dc906 100644 --- a/pg.py +++ b/pg.py @@ -23,6 +23,7 @@ import select import weakref from collections import OrderedDict, namedtuple +from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal from functools import lru_cache, partial @@ -1507,11 +1508,9 @@ def __init__(self, *args, **kw): if isinstance(db, DB): db = db.db else: - try: + with suppress(AttributeError): # noinspection PyUnresolvedReferences db = db._cnx - except AttributeError: - pass if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): db = connect(*args, **kw) self._db_args = args, kw @@ -1592,15 +1591,11 @@ def __del__(self): except AttributeError: db = None if db: - try: + with suppress(TypeError): # when already closed db.set_cast_hook(None) - except TypeError: - pass # probably already closed if self._closeable: - try: + with suppress(InternalError): # when already closed db.close() - except InternalError: - pass # probably already closed # Auxiliary methods @@ -1661,10 +1656,8 @@ def close(self): # Wraps shared library function so we can track state. db = self.db if db: - try: + with suppress(TypeError): # when already closed db.set_cast_hook(None) - except TypeError: - pass # probably already closed if self._closeable: db.close() self.db = None @@ -2611,10 +2604,8 @@ def get_as_list(self, table, what=None, where=None, try: order = self.pkey(table, True) except (KeyError, ProgrammingError): - try: + with suppress(KeyError, ProgrammingError): order = list(self.get_attnames(table)) - except (KeyError, ProgrammingError): - pass if order: if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) diff --git a/pgdb.py b/pgdb.py index 22f99498..2e48e39d 100644 --- a/pgdb.py +++ b/pgdb.py @@ -66,6 +66,7 @@ from collections import namedtuple from collections.abc import Iterable +from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal as StdDecimal from functools import lru_cache, partial @@ -1442,10 +1443,8 @@ def close(self): """Close the connection object.""" if self._cnx: if self._tnx: - try: + with suppress(DatabaseError): self.rollback() - except DatabaseError: - pass self._cnx.close() self._cnx = None else: diff --git a/pyproject.toml b/pyproject.toml index f5927d43..131308b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ select = [ "D", # pydocstyle "B", # bugbear "S", # bandit + "SIM", # simplify "RUF", # ruff ] exclude = [ diff --git a/tests/config.py b/tests/config.py index acd8559a..f6280548 100644 --- a/tests/config.py +++ b/tests/config.py @@ -28,7 +28,7 @@ try: from .LOCAL_PyGreSQL import * # noqa: F403 except (ImportError, ValueError): - try: + try: # noqa: SIM105 from LOCAL_PyGreSQL import * # noqa: F403 except ImportError: pass diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 12a7647b..d5f2938f 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -11,6 +11,7 @@ import time import unittest +from contextlib import suppress from typing import Any, Mapping, Tuple @@ -181,11 +182,9 @@ def test_rollback(self): # If rollback is defined, it should either work or throw # the documented exception if hasattr(con, 'rollback'): - try: + with suppress(self.driver.NotSupportedError): # noinspection PyCallingNonCallable con.rollback() - except self.driver.NotSupportedError: - pass def test_cursor(self): con = self._connect() diff --git a/tests/test_classic.py b/tests/test_classic.py index 18af07b2..a6f78197 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,6 +1,7 @@ #!/usr/bin/python import unittest +from contextlib import suppress from functools import partial from threading import Thread from time import sleep @@ -34,27 +35,19 @@ class UtilityTest(unittest.TestCase): def setUpClass(cls): """Recreate test tables and schemas.""" db = open_db() - try: + with suppress(Exception): db.query("DROP VIEW _test_vschema") - except Exception: - pass - try: + with suppress(Exception): db.query("DROP TABLE _test_schema") - except Exception: - pass db.query("CREATE TABLE _test_schema" " (_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") db.query("CREATE VIEW _test_vschema AS" " SELECT _test, 'abc'::text AS _test2 FROM _test_schema") for t in ('_test1', '_test2'): - try: + with suppress(Exception): db.query("CREATE SCHEMA " + t) - except Exception: - pass - try: + with suppress(Exception): db.query(f"DROP TABLE {t}._test_schema") - except Exception: - pass db.query(f"CREATE TABLE {t}._test_schema" f" ({t} int PRIMARY KEY)") db.close() diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 1fa9edb6..7d4409df 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -15,6 +15,7 @@ import unittest from collections import namedtuple from collections.abc import Iterable +from contextlib import suppress from decimal import Decimal from typing import Sequence, Tuple @@ -94,10 +95,8 @@ def setUp(self): self.connection = connect() def tearDown(self): - try: + with suppress(pg.InternalError): self.connection.close() - except pg.InternalError: - pass def is_method(self, attribute): """Check if given attribute on the connection is a method.""" @@ -152,10 +151,7 @@ def test_attribute_error(self): @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason) def test_attribute_host(self): - if dbhost and not dbhost.startswith('/'): - host = dbhost - else: - host = 'localhost' + host = dbhost if dbhost and not dbhost.startswith('/') else 'localhost' self.assertIsInstance(self.connection.host, str) self.assertEqual(self.connection.host, host) @@ -282,10 +278,8 @@ def test_all_query_members(self): self.assertEqual(members, query_members) def test_method_endcopy(self): - try: + with suppress(OSError): self.connection.endcopy() - except OSError: - pass def test_method_close(self): self.connection.close() @@ -1255,9 +1249,9 @@ def assert_proper_cast(self, value, pgtype, pytype): self.fail(str(e)) # noinspection PyUnboundLocalVariable self.assertIsInstance(r, pytype) - if isinstance(value, str): - if not value or ' ' in value or '{' in value: - value = f'"{value}"' + if isinstance(value, str) and ( + not value or ' ' in value or '{' in value): + value = f'"{value}"' value = f'{{{value}}}' r = self.c.query(q + '[]', (value,)).getresult()[0][0] if pgtype.startswith(('date', 'time', 'interval')): @@ -2194,10 +2188,8 @@ def test_getline(self): elif i == n: self.assertIsNone(v) finally: - try: + with suppress(OSError): self.c.endcopy() - except OSError: - pass def test_getline_bytes_and_unicode(self): getline = self.c.getline @@ -2218,10 +2210,8 @@ def test_getline_bytes_and_unicode(self): self.assertEqual(v, '73\twürstel') self.assertIsNone(getline()) finally: - try: + with suppress(OSError): self.c.endcopy() - except OSError: - pass def test_parameter_checks(self): self.assertRaises(TypeError, self.c.putline) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index c563d932..3884436f 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -16,6 +16,7 @@ import tempfile import unittest from collections import OrderedDict +from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal from io import StringIO @@ -167,10 +168,8 @@ def setUp(self): self.db = DB() def tearDown(self): - try: + with suppress(pg.InternalError): self.db.close() - except pg.InternalError: - pass def test_all_db_attributes(self): attributes = [ @@ -223,10 +222,7 @@ def test_attribute_error(self): @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason) def test_attribute_host(self): - if dbhost and not dbhost.startswith('/'): - host = dbhost - else: - host = 'localhost' + host = dbhost if dbhost and not dbhost.startswith('/') else 'localhost' self.assertIsInstance(self.db.host, str) self.assertEqual(self.db.host, host) self.assertEqual(self.db.db.host, host) @@ -334,10 +330,8 @@ def test_method_query_data_error(self): self.assertEqual(error.sqlstate, '22012') def test_method_endcopy(self): - try: + with suppress(OSError): self.db.endcopy() - except OSError: - pass def test_method_close(self): self.db.close() @@ -4352,10 +4346,8 @@ def setUp(self): self.adapter = self.db.adapter def tearDown(self): - try: + with suppress(pg.InternalError): self.db.close() - except pg.InternalError: - pass def test_guess_simple_type(self): f = self.adapter.guess_simple_type @@ -4744,12 +4736,9 @@ def tearDown(self): def test_get_tables(self): tables = self.db.get_tables() for num_schema in range(5): - if num_schema: - schema = "s" + str(num_schema) - else: - schema = "public" - for t in (schema + ".t", - schema + ".t" + str(num_schema)): + schema = 's' + str(num_schema) if num_schema else 'public' + for t in (schema + '.t', + schema + '.t' + str(num_schema)): self.assertIn(t, tables) def test_get_attnames(self): diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index afe48a21..7e5ad4a2 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -12,6 +12,7 @@ import os import tempfile import unittest +from contextlib import suppress import pg # the module under test @@ -107,7 +108,7 @@ def test_lo_import(self): if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_import.txt' - f = open(fname, 'wb') + f = open(fname, 'wb') # noqa: SIM115 else: f = tempfile.NamedTemporaryFile() fname = f.name @@ -115,7 +116,7 @@ def test_lo_import(self): f.write(data) if windows: f.close() - f = open(fname, 'rb') + f = open(fname, 'rb') # noqa: SIM115 else: f.flush() f.seek(0) @@ -149,19 +150,13 @@ def setUp(self): def tearDown(self): if self.obj.oid: - try: + with suppress(SystemError, OSError): self.obj.close() - except (SystemError, OSError): - pass - try: + with suppress(SystemError, OSError): self.obj.unlink() - except (SystemError, OSError): - pass del self.obj - try: + with suppress(SystemError): self.pgcnx.query('rollback') - except SystemError: - pass self.pgcnx.close() def test_class_name(self): @@ -420,7 +415,7 @@ def test_export(self): if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_export.txt' - f = open(fname, 'wb') + f = open(fname, 'wb') # noqa: SIM115 else: f = tempfile.NamedTemporaryFile() fname = f.name @@ -433,7 +428,7 @@ def test_export(self): export(fname) if windows: f.close() - f = open(fname, 'rb') + f = open(fname, 'rb') # noqa: SIM115 r = f.read() f.close() if windows: diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index ca775001..bcacd476 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -11,6 +11,7 @@ import unittest from collections.abc import Iterable +from contextlib import suppress from typing import Sequence, Tuple import pgdb # the module under test @@ -142,18 +143,12 @@ def setUp(self): self.cursor.execute("set client_encoding=utf8") def tearDown(self): - try: + with suppress(Exception): self.cursor.close() - except Exception: - pass - try: + with suppress(Exception): self.con.rollback() - except Exception: - pass - try: + with suppress(Exception): self.con.close() - except Exception: - pass data: Sequence[Tuple[int, str]] = [ (1935, 'Luciano Pavarotti'), From fadd20762006f3c08f55abe916c4a19274fd2fb5 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 11:31:43 +0200 Subject: [PATCH 046/118] Add some type hints --- pg.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pg.py b/pg.py index 434dc906..d29cb5c2 100644 --- a/pg.py +++ b/pg.py @@ -34,7 +34,7 @@ from operator import itemgetter from re import compile as regex from types import MappingProxyType -from typing import ClassVar, Dict, List, Mapping, Type, Union +from typing import Callable, ClassVar, Dict, List, Mapping, Type, Union from uuid import UUID try: @@ -298,6 +298,8 @@ def _quote_if_unqualified(param, name): class _ParameterList(list): """Helper class for building typed parameter lists.""" + adapt: Callable + def add(self, value, typ=None): """Typecast value with known database type and build parameter list. @@ -1149,6 +1151,18 @@ class DbType(str): attnames: attributes for composite types """ + oid: int + pgtype: str + regtype: str + simple: str + typlen: int + typtype: str + category: str + delim: str + relid: int + + _get_attnames: Callable + @property def attnames(self): """Get names and types of the fields of a composite type.""" From 47f19c189ca0deb98dd165489de8b3c02ee7b02b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 2 Sep 2023 15:06:31 +0200 Subject: [PATCH 047/118] Use clang-format for C files --- .clang-format | 25 ++ .devcontainer/provision.sh | 2 +- pgconn.c | 778 ++++++++++++++++++++----------------- pginternal.c | 527 +++++++++++++++---------- pglarge.c | 151 ++++--- pgmodule.c | 601 ++++++++++++++-------------- pgnotice.c | 73 ++-- pgquery.c | 435 +++++++++++---------- pgsource.c | 290 +++++++------- tox.ini | 9 +- 10 files changed, 1557 insertions(+), 1334 deletions(-) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..22f8603a --- /dev/null +++ b/.clang-format @@ -0,0 +1,25 @@ +# A clang-format style that approximates Python's PEP 7 +# Useful for IDE integration +# +# Based on Paul Ganssle's version at +# https://gist.github.com/pganssle/0e3a5f828b4d07d79447f6ced8e7e4db +BasedOnStyle: Google +AlwaysBreakAfterReturnType: All +AllowShortIfStatementsOnASingleLine: false +AlignAfterOpenBracket: Align +AlignTrailingComments: true +BreakBeforeBraces: Stroustrup +ColumnLimit: 79 +DerivePointerAlignment: false +IndentWidth: 4 +Language: Cpp +PointerAlignment: Right +ReflowComments: true +SpaceBeforeParens: ControlStatements +SpacesInParentheses: false +TabWidth: 4 +UseCRLF: false +UseTab: Never +StatementMacros: + - Py_BEGIN_ALLOW_THREADS + - Py_END_ALLOW_THREADS \ No newline at end of file diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 2f3651d6..c780e7df 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -37,7 +37,7 @@ python3.11 -m pip install build pip install ruff -sudo apt-get install -y tox +sudo apt-get install -y tox clang-format # install PostgreSQL client tools diff --git a/pgconn.c b/pgconn.c index c67e74dc..10e5b780 100644 --- a/pgconn.c +++ b/pgconn.c @@ -95,10 +95,12 @@ conn_getattr(connObject *self, PyObject *nameobj) /* whether the connection uses SSL */ if (!strcmp(name, "ssl_in_use")) { if (PQsslInUse(self->cnx)) { - Py_INCREF(Py_True); return Py_True; + Py_INCREF(Py_True); + return Py_True; } else { - Py_INCREF(Py_False); return Py_False; + Py_INCREF(Py_False); + return Py_False; } } @@ -107,7 +109,7 @@ conn_getattr(connObject *self, PyObject *nameobj) return get_ssl_attributes(self->cnx); } - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Check connection validity. */ @@ -123,7 +125,7 @@ _check_cnx_obj(connObject *self) /* Create source object. */ static char conn_source__doc__[] = -"source() -- create a new source object for this connection"; + "source() -- create a new source object for this connection"; static PyObject * conn_source(connObject *self, PyObject *noargs) @@ -147,13 +149,13 @@ conn_source(connObject *self, PyObject *noargs) source_obj->valid = 1; source_obj->arraysize = PG_ARRAYSIZE; - return (PyObject *) source_obj; + return (PyObject *)source_obj; } /* For a non-query result, set the appropriate error status, return the appropriate value, and free the result set. */ static PyObject * -_conn_non_query_result(int status, PGresult* result, PGconn *cnx) +_conn_non_query_result(int status, PGresult *result, PGconn *cnx) { switch (status) { case PGRES_EMPTY_QUERY: @@ -162,29 +164,27 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) case PGRES_BAD_RESPONSE: case PGRES_FATAL_ERROR: case PGRES_NONFATAL_ERROR: - set_error(ProgrammingError, "Cannot execute query", - cnx, result); + set_error(ProgrammingError, "Cannot execute query", cnx, result); break; - case PGRES_COMMAND_OK: - { /* INSERT, UPDATE, DELETE */ - Oid oid = PQoidValue(result); + case PGRES_COMMAND_OK: { /* INSERT, UPDATE, DELETE */ + Oid oid = PQoidValue(result); - if (oid == InvalidOid) { /* not a single insert */ - char *ret = PQcmdTuples(result); + if (oid == InvalidOid) { /* not a single insert */ + char *ret = PQcmdTuples(result); - if (ret[0]) { /* return number of rows affected */ - PyObject *obj = PyUnicode_FromString(ret); - PQclear(result); - return obj; - } + if (ret[0]) { /* return number of rows affected */ + PyObject *obj = PyUnicode_FromString(ret); PQclear(result); - Py_INCREF(Py_None); - return Py_None; + return obj; } - /* for a single insert, return the oid */ PQclear(result); - return PyLong_FromLong((long) oid); + Py_INCREF(Py_None); + return Py_None; } + /* for a single insert, return the oid */ + PQclear(result); + return PyLong_FromLong((long)oid); + } case PGRES_COPY_OUT: /* no data will be received */ case PGRES_COPY_IN: PQclear(result); @@ -196,15 +196,15 @@ _conn_non_query_result(int status, PGresult* result, PGconn *cnx) PQclear(result); return NULL; /* error detected on query */ - } +} /* Base method for execution of all different kinds of queries */ static PyObject * _conn_query(connObject *self, PyObject *args, int prepared, int async) { PyObject *query_str_obj, *param_obj = NULL; - PGresult* result; - queryObject* query_obj; + PGresult *result; + queryObject *query_obj; char *query; int encoding, status, nparms = 0; @@ -226,7 +226,8 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) } else if (PyUnicode_Check(query_str_obj)) { query_str_obj = get_encoded_string(query_str_obj, encoding); - if (!query_str_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!query_str_obj) + return NULL; /* pass the UnicodeEncodeError */ query = PyBytes_AsString(query_str_obj); } else { @@ -246,7 +247,7 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) Py_XDECREF(query_str_obj); return NULL; } - nparms = (int) PySequence_Fast_GET_SIZE(param_obj); + nparms = (int)PySequence_Fast_GET_SIZE(param_obj); /* if there's a single argument and it's a list or tuple, it * contains the positional arguments. */ @@ -255,7 +256,7 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) if (PyList_Check(first_obj) || PyTuple_Check(first_obj)) { Py_DECREF(param_obj); param_obj = PySequence_Fast(first_obj, NULL); - nparms = (int) PySequence_Fast_GET_SIZE(param_obj); + nparms = (int)PySequence_Fast_GET_SIZE(param_obj); } } } @@ -267,11 +268,13 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) const char **parms, **p; register int i; - str = (PyObject **) PyMem_Malloc((size_t) nparms * sizeof(*str)); - parms = (const char **) PyMem_Malloc((size_t) nparms * sizeof(*parms)); + str = (PyObject **)PyMem_Malloc((size_t)nparms * sizeof(*str)); + parms = (const char **)PyMem_Malloc((size_t)nparms * sizeof(*parms)); if (!str || !parms) { - PyMem_Free((void *) parms); PyMem_Free(str); - Py_XDECREF(query_str_obj); Py_XDECREF(param_obj); + PyMem_Free((void *)parms); + PyMem_Free(str); + Py_XDECREF(query_str_obj); + Py_XDECREF(param_obj); return PyErr_NoMemory(); } @@ -290,8 +293,11 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) else if (PyUnicode_Check(obj)) { PyObject *str_obj = get_encoded_string(obj, encoding); if (!str_obj) { - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } PyMem_Free(str); Py_XDECREF(query_str_obj); Py_XDECREF(param_obj); @@ -304,8 +310,11 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) else { PyObject *str_obj = PyObject_Str(obj); if (!str_obj) { - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } PyMem_Free(str); Py_XDECREF(query_str_obj); Py_XDECREF(param_obj); @@ -321,22 +330,25 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) Py_BEGIN_ALLOW_THREADS if (async) { - status = PQsendQueryParams(self->cnx, query, nparms, - NULL, (const char * const *)parms, NULL, NULL, 0); + status = + PQsendQueryParams(self->cnx, query, nparms, NULL, + (const char *const *)parms, NULL, NULL, 0); result = NULL; } else { - result = prepared ? - PQexecPrepared(self->cnx, query, nparms, - parms, NULL, NULL, 0) : - PQexecParams(self->cnx, query, nparms, - NULL, parms, NULL, NULL, 0); + result = prepared ? PQexecPrepared(self->cnx, query, nparms, parms, + NULL, NULL, 0) + : PQexecParams(self->cnx, query, nparms, NULL, + parms, NULL, NULL, 0); status = result != NULL; } Py_END_ALLOW_THREADS - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } PyMem_Free(str); } else { @@ -346,10 +358,9 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) result = NULL; } else { - result = prepared ? - PQexecPrepared(self->cnx, query, 0, - NULL, NULL, NULL, 0) : - PQexec(self->cnx, query); + result = prepared ? PQexecPrepared(self->cnx, query, 0, NULL, NULL, + NULL, 0) + : PQexec(self->cnx, query); status = result != NULL; } Py_END_ALLOW_THREADS @@ -399,14 +410,14 @@ _conn_query(connObject *self, PyObject *args, int prepared, int async) } } - return (PyObject *) query_obj; + return (PyObject *)query_obj; } /* Database query */ static char conn_query__doc__[] = -"query(sql, [arg]) -- create a new query object for this connection\n\n" -"You must pass the SQL (string) request and you can optionally pass\n" -"a tuple with positional parameters.\n"; + "query(sql, [arg]) -- create a new query object for this connection\n\n" + "You must pass the SQL (string) request and you can optionally pass\n" + "a tuple with positional parameters.\n"; static PyObject * conn_query(connObject *self, PyObject *args) @@ -416,9 +427,10 @@ conn_query(connObject *self, PyObject *args) /* Asynchronous database query */ static char conn_send_query__doc__[] = -"send_query(sql, [arg]) -- create a new asynchronous query for this connection\n\n" -"You must pass the SQL (string) request and you can optionally pass\n" -"a tuple with positional parameters.\n"; + "send_query(sql, [arg]) -- create a new asynchronous query for this " + "connection\n\n" + "You must pass the SQL (string) request and you can optionally pass\n" + "a tuple with positional parameters.\n"; static PyObject * conn_send_query(connObject *self, PyObject *args) @@ -428,9 +440,9 @@ conn_send_query(connObject *self, PyObject *args) /* Execute prepared statement. */ static char conn_query_prepared__doc__[] = -"query_prepared(name, [arg]) -- execute a prepared statement\n\n" -"You must pass the name (string) of the prepared statement and you can\n" -"optionally pass a tuple with positional parameters.\n"; + "query_prepared(name, [arg]) -- execute a prepared statement\n\n" + "You must pass the name (string) of the prepared statement and you can\n" + "optionally pass a tuple with positional parameters.\n"; static PyObject * conn_query_prepared(connObject *self, PyObject *args) @@ -440,9 +452,9 @@ conn_query_prepared(connObject *self, PyObject *args) /* Create prepared statement. */ static char conn_prepare__doc__[] = -"prepare(name, sql) -- create a prepared statement\n\n" -"You must pass the name (string) of the prepared statement and the\n" -"SQL (string) request for later execution.\n"; + "prepare(name, sql) -- create a prepared statement\n\n" + "You must pass the name (string) of the prepared statement and the\n" + "SQL (string) request for later execution.\n"; static PyObject * conn_prepare(connObject *self, PyObject *args) @@ -457,9 +469,8 @@ conn_prepare(connObject *self, PyObject *args) } /* reads args */ - if (!PyArg_ParseTuple(args, "s#s#", - &name, &name_length, &query, &query_length)) - { + if (!PyArg_ParseTuple(args, "s#s#", &name, &name_length, &query, + &query_length)) { PyErr_SetString(PyExc_TypeError, "Method prepare() takes two string arguments"); return NULL; @@ -474,8 +485,8 @@ conn_prepare(connObject *self, PyObject *args) Py_INCREF(Py_None); return Py_None; /* success */ } - set_error(ProgrammingError, "Cannot create prepared statement", - self->cnx, result); + set_error(ProgrammingError, "Cannot create prepared statement", self->cnx, + result); if (result) PQclear(result); return NULL; /* error */ @@ -483,8 +494,8 @@ conn_prepare(connObject *self, PyObject *args) /* Describe prepared statement. */ static char conn_describe_prepared__doc__[] = -"describe_prepared(name) -- describe a prepared statement\n\n" -"You must pass the name (string) of the prepared statement.\n"; + "describe_prepared(name) -- describe a prepared statement\n\n" + "You must pass the name (string) of the prepared statement.\n"; static PyObject * conn_describe_prepared(connObject *self, PyObject *args) @@ -521,17 +532,17 @@ conn_describe_prepared(connObject *self, PyObject *args) query_obj->max_row = PQntuples(result); query_obj->num_fields = PQnfields(result); query_obj->col_types = get_col_types(result, query_obj->num_fields); - return (PyObject *) query_obj; + return (PyObject *)query_obj; } set_error(ProgrammingError, "Cannot describe prepared statement", - self->cnx, result); + self->cnx, result); if (result) PQclear(result); return NULL; /* error */ } static char conn_putline__doc__[] = -"putline(line) -- send a line directly to the backend"; + "putline(line) -- send a line directly to the backend"; /* Direct access function: putline. */ static PyObject * @@ -554,12 +565,14 @@ conn_putline(connObject *self, PyObject *args) } /* send line to backend */ - ret = PQputCopyData(self->cnx, line, (int) line_length); + ret = PQputCopyData(self->cnx, line, (int)line_length); if (ret != 1) { - PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) : - "Line cannot be queued, wait for write-ready and try again"); + PyErr_SetString( + PyExc_IOError, + ret == -1 + ? PQerrorMessage(self->cnx) + : "Line cannot be queued, wait for write-ready and try again"); return NULL; - } Py_INCREF(Py_None); return Py_None; @@ -567,7 +580,7 @@ conn_putline(connObject *self, PyObject *args) /* Direct access function: getline. */ static char conn_getline__doc__[] = -"getline() -- get a line directly from the backend"; + "getline() -- get a line directly from the backend"; static PyObject * conn_getline(connObject *self, PyObject *noargs) @@ -586,15 +599,18 @@ conn_getline(connObject *self, PyObject *noargs) /* check result */ if (ret <= 0) { - if (line != NULL) PQfreemem(line); + if (line != NULL) + PQfreemem(line); if (ret == -1) { PQgetResult(self->cnx); Py_INCREF(Py_None); return Py_None; } - PyErr_SetString(PyExc_MemoryError, - ret == -2 ? PQerrorMessage(self->cnx) : - "No line available, wait for read-ready and try again"); + PyErr_SetString( + PyExc_MemoryError, + ret == -2 + ? PQerrorMessage(self->cnx) + : "No line available, wait for read-ready and try again"); return NULL; } if (line == NULL) { @@ -602,7 +618,8 @@ conn_getline(connObject *self, PyObject *noargs) return Py_None; } /* for backward compatibility, convert terminating newline to zero byte */ - if (*line) line[strlen(line) - 1] = '\0'; + if (*line) + line[strlen(line) - 1] = '\0'; str = PyUnicode_FromString(line); PQfreemem(line); return str; @@ -610,7 +627,7 @@ conn_getline(connObject *self, PyObject *noargs) /* Direct access function: end copy. */ static char conn_endcopy__doc__[] = -"endcopy() -- synchronize client and server"; + "endcopy() -- synchronize client and server"; static PyObject * conn_endcopy(connObject *self, PyObject *noargs) @@ -624,11 +641,11 @@ conn_endcopy(connObject *self, PyObject *noargs) /* end direct copy */ ret = PQputCopyEnd(self->cnx, NULL); - if (ret != 1) - { - PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) : - "Termination message cannot be queued," - " wait for write-ready and try again"); + if (ret != 1) { + PyErr_SetString(PyExc_IOError, + ret == -1 ? PQerrorMessage(self->cnx) + : "Termination message cannot be queued," + " wait for write-ready and try again"); return NULL; } Py_INCREF(Py_None); @@ -637,7 +654,7 @@ conn_endcopy(connObject *self, PyObject *noargs) /* Direct access function: set blocking status. */ static char conn_set_non_blocking__doc__[] = -"set_non_blocking() -- set the non-blocking status of the connection"; + "set_non_blocking() -- set the non-blocking status of the connection"; static PyObject * conn_set_non_blocking(connObject *self, PyObject *args) @@ -666,7 +683,7 @@ conn_set_non_blocking(connObject *self, PyObject *args) /* Direct access function: get blocking status. */ static char conn_is_non_blocking__doc__[] = -"is_non_blocking() -- report the blocking status of the connection"; + "is_non_blocking() -- report the blocking status of the connection"; static PyObject * conn_is_non_blocking(connObject *self, PyObject *noargs) @@ -687,12 +704,11 @@ conn_is_non_blocking(connObject *self, PyObject *noargs) return PyBool_FromLong((long)rc); } - /* Insert table */ static char conn_inserttable__doc__[] = -"inserttable(table, data, [columns]) -- insert iterable into table\n\n" -"The fields in the iterable must be in the same order as in the table\n" -"or in the list or tuple of columns if one is specified.\n"; + "inserttable(table, data, [columns]) -- insert iterable into table\n\n" + "The fields in the iterable must be in the same order as in the table\n" + "or in the list or tuple of columns if one is specified.\n"; static PyObject * conn_inserttable(connObject *self, PyObject *args) @@ -718,8 +734,7 @@ conn_inserttable(connObject *self, PyObject *args) } /* checks list type */ - if (!(iter_row = PyObject_GetIter(rows))) - { + if (!(iter_row = PyObject_GetIter(rows))) { PyErr_SetString( PyExc_TypeError, "Method inserttable() expects an iterable as second argument"); @@ -728,31 +743,36 @@ conn_inserttable(connObject *self, PyObject *args) m = PySequence_Check(rows) ? PySequence_Size(rows) : -1; if (!m) { /* no rows specified, nothing to do */ - Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None; + Py_DECREF(iter_row); + Py_INCREF(Py_None); + return Py_None; } /* checks columns type */ if (columns) { if (!(PyTuple_Check(columns) || PyList_Check(columns))) { - PyErr_SetString( - PyExc_TypeError, - "Method inserttable() expects a tuple or a list" - " as third argument"); + PyErr_SetString(PyExc_TypeError, + "Method inserttable() expects a tuple or a list" + " as third argument"); return NULL; } n = PySequence_Fast_GET_SIZE(columns); if (!n) { /* no columns specified, nothing to do */ - Py_DECREF(iter_row); Py_INCREF(Py_None); return Py_None; + Py_DECREF(iter_row); + Py_INCREF(Py_None); + return Py_None; } - } else { + } + else { n = -1; /* number of columns not yet known */ } /* allocate buffer */ if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) { - Py_DECREF(iter_row); return PyErr_NoMemory(); + Py_DECREF(iter_row); + return PyErr_NoMemory(); } encoding = PQclientEncoding(self->cnx); @@ -760,22 +780,26 @@ conn_inserttable(connObject *self, PyObject *args) /* starts query */ bufpt = buffer; bufmax = bufpt + MAX_BUFFER_SIZE; - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "copy "); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "copy "); s = table; do { - t = strchr(s, '.'); if (!t) t = s + strlen(s); - table = PQescapeIdentifier(self->cnx, s, (size_t) (t - s)); + t = strchr(s, '.'); + if (!t) + t = s + strlen(s); + table = PQescapeIdentifier(self->cnx, s, (size_t)(t - s)); if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), "%s", table); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s", table); PQfreemem(table); - s = t; if (*s && bufpt < bufmax) *bufpt++ = *s++; + s = t; + if (*s && bufpt < bufmax) + *bufpt++ = *s++; } while (*s); if (columns) { /* adds a string like f" ({','.join(columns)})" */ if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), " ("); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " ("); for (j = 0; j < n; ++j) { PyObject *obj = PySequence_Fast_GET_ITEM(columns, j); Py_ssize_t slen; @@ -787,29 +811,33 @@ conn_inserttable(connObject *self, PyObject *args) else if (PyUnicode_Check(obj)) { obj = get_encoded_string(obj, encoding); if (!obj) { - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } - } else { + } + else { PyErr_SetString( PyExc_TypeError, "The third argument must contain only strings"); - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); return NULL; } PyBytes_AsStringAndSize(obj, &col, &slen); - col = PQescapeIdentifier(self->cnx, col, (size_t) slen); + col = PQescapeIdentifier(self->cnx, col, (size_t)slen); Py_DECREF(obj); if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t) (bufmax - bufpt), - "%s%s", col, j == n - 1 ? ")" : ","); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s%s", col, + j == n - 1 ? ")" : ","); PQfreemem(col); } } if (bufpt < bufmax) - snprintf(bufpt, (size_t) (bufmax - bufpt), " from stdin"); - if (bufpt >= bufmax) { - PyMem_Free(buffer); Py_DECREF(iter_row); + snprintf(bufpt, (size_t)(bufmax - bufpt), " from stdin"); + if (bufpt >= bufmax) { + PyMem_Free(buffer); + Py_DECREF(iter_row); return PyErr_NoMemory(); } @@ -818,7 +846,8 @@ conn_inserttable(connObject *self, PyObject *args) Py_END_ALLOW_THREADS if (!result || PQresultStatus(result) != PGRES_COPY_IN) { - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); return NULL; } @@ -827,12 +856,15 @@ conn_inserttable(connObject *self, PyObject *args) /* feed table */ for (i = 0; m < 0 || i < m; ++i) { - - if (!(columns = PyIter_Next(iter_row))) break; + if (!(columns = PyIter_Next(iter_row))) + break; if (!(PyTuple_Check(columns) || PyList_Check(columns))) { - PQputCopyEnd(self->cnx, "Invalid arguments"); PyMem_Free(buffer); - Py_DECREF(columns); Py_DECREF(columns); Py_DECREF(iter_row); + PQputCopyEnd(self->cnx, "Invalid arguments"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(columns); + Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, "The second argument must contain tuples or lists"); @@ -842,9 +874,12 @@ conn_inserttable(connObject *self, PyObject *args) j = PySequence_Fast_GET_SIZE(columns); if (n < 0) { n = j; - } else if (j != n) { - PQputCopyEnd(self->cnx, "Invalid arguments"); PyMem_Free(buffer); - Py_DECREF(columns); Py_DECREF(iter_row); + } + else if (j != n) { + PQputCopyEnd(self->cnx, "Invalid arguments"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, "The second arg must contain sequences of the same size"); @@ -857,7 +892,8 @@ conn_inserttable(connObject *self, PyObject *args) for (j = 0; j < n; ++j) { if (j) { - *bufpt++ = '\t'; --bufsiz; + *bufpt++ = '\t'; + --bufsiz; } item = PySequence_Fast_GET_ITEM(columns, j); @@ -865,37 +901,43 @@ conn_inserttable(connObject *self, PyObject *args) /* convert item to string and append to buffer */ if (item == Py_None) { if (bufsiz > 2) { - *bufpt++ = '\\'; *bufpt++ = 'N'; + *bufpt++ = '\\'; + *bufpt++ = 'N'; bufsiz -= 2; } else bufsiz = 0; } else if (PyBytes_Check(item)) { - const char* t = PyBytes_AsString(item); + const char *t = PyBytes_AsString(item); while (*t && bufsiz) { switch (*t) { case '\\': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= '\\'; + if (--bufsiz) + *bufpt++ = '\\'; break; case '\t': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 't'; + if (--bufsiz) + *bufpt++ = 't'; break; case '\r': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'r'; + if (--bufsiz) + *bufpt++ = 'r'; break; case '\n': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'n'; + if (--bufsiz) + *bufpt++ = 'n'; break; default: - *bufpt ++= *t; + *bufpt++ = *t; } - ++t; --bufsiz; + ++t; + --bufsiz; } } else if (PyUnicode_Check(item)) { @@ -903,83 +945,97 @@ conn_inserttable(connObject *self, PyObject *args) if (!s) { PQputCopyEnd(self->cnx, "Encoding error"); PyMem_Free(buffer); - Py_DECREF(item); Py_DECREF(columns); Py_DECREF(iter_row); + Py_DECREF(item); + Py_DECREF(columns); + Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } else { - const char* t = PyBytes_AsString(s); + const char *t = PyBytes_AsString(s); while (*t && bufsiz) { switch (*t) { case '\\': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= '\\'; + if (--bufsiz) + *bufpt++ = '\\'; break; case '\t': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 't'; + if (--bufsiz) + *bufpt++ = 't'; break; case '\r': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'r'; + if (--bufsiz) + *bufpt++ = 'r'; break; case '\n': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'n'; + if (--bufsiz) + *bufpt++ = 'n'; break; default: - *bufpt ++= *t; + *bufpt++ = *t; } - ++t; --bufsiz; + ++t; + --bufsiz; } Py_DECREF(s); } } else if (PyLong_Check(item)) { - PyObject* s = PyObject_Str(item); - const char* t = PyUnicode_AsUTF8(s); + PyObject *s = PyObject_Str(item); + const char *t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { - *bufpt++ = *t++; --bufsiz; + *bufpt++ = *t++; + --bufsiz; } Py_DECREF(s); } else { - PyObject* s = PyObject_Repr(item); - const char* t = PyUnicode_AsUTF8(s); + PyObject *s = PyObject_Repr(item); + const char *t = PyUnicode_AsUTF8(s); while (*t && bufsiz) { switch (*t) { case '\\': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= '\\'; + if (--bufsiz) + *bufpt++ = '\\'; break; case '\t': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 't'; + if (--bufsiz) + *bufpt++ = 't'; break; case '\r': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'r'; + if (--bufsiz) + *bufpt++ = 'r'; break; case '\n': *bufpt++ = '\\'; - if (--bufsiz) *bufpt ++= 'n'; + if (--bufsiz) + *bufpt++ = 'n'; break; default: - *bufpt ++= *t; + *bufpt++ = *t; } - ++t; --bufsiz; + ++t; + --bufsiz; } Py_DECREF(s); } if (bufsiz <= 0) { - PQputCopyEnd(self->cnx, "Memory error"); PyMem_Free(buffer); - Py_DECREF(columns); Py_DECREF(iter_row); + PQputCopyEnd(self->cnx, "Memory error"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(iter_row); return PyErr_NoMemory(); } - } Py_DECREF(columns); @@ -987,13 +1043,14 @@ conn_inserttable(connObject *self, PyObject *args) *bufpt++ = '\n'; /* sends data */ - ret = PQputCopyData(self->cnx, buffer, (int) (bufpt - buffer)); + ret = PQputCopyData(self->cnx, buffer, (int)(bufpt - buffer)); if (ret != 1) { - char *errormsg = ret == - 1 ? - PQerrorMessage(self->cnx) : "Data cannot be queued"; + char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"; PyErr_SetString(PyExc_IOError, errormsg); PQputCopyEnd(self->cnx, errormsg); - PyMem_Free(buffer); Py_DECREF(iter_row); + PyMem_Free(buffer); + Py_DECREF(iter_row); return NULL; } } @@ -1006,8 +1063,8 @@ conn_inserttable(connObject *self, PyObject *args) ret = PQputCopyEnd(self->cnx, NULL); if (ret != 1) { - PyErr_SetString(PyExc_IOError, ret == -1 ? - PQerrorMessage(self->cnx) : "Data cannot be queued"); + PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"); PyMem_Free(buffer); return NULL; } @@ -1021,7 +1078,8 @@ conn_inserttable(connObject *self, PyObject *args) PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); PQclear(result); return NULL; - } else { + } + else { long ntuples = atol(PQcmdTuples(result)); PQclear(result); return PyLong_FromLong(ntuples); @@ -1030,7 +1088,7 @@ conn_inserttable(connObject *self, PyObject *args) /* Get transaction state. */ static char conn_transaction__doc__[] = -"transaction() -- return the current transaction status"; + "transaction() -- return the current transaction status"; static PyObject * conn_transaction(connObject *self, PyObject *noargs) @@ -1045,7 +1103,7 @@ conn_transaction(connObject *self, PyObject *noargs) /* Get parameter setting. */ static char conn_parameter__doc__[] = -"parameter(name) -- look up a current parameter setting"; + "parameter(name) -- look up a current parameter setting"; static PyObject * conn_parameter(connObject *self, PyObject *args) @@ -1076,7 +1134,7 @@ conn_parameter(connObject *self, PyObject *args) /* Get current date format. */ static char conn_date_format__doc__[] = -"date_format() -- return the current date format"; + "date_format() -- return the current date format"; static PyObject * conn_date_format(connObject *self, PyObject *noargs) @@ -1100,18 +1158,18 @@ conn_date_format(connObject *self, PyObject *noargs) /* Escape literal */ static char conn_escape_literal__doc__[] = -"escape_literal(str) -- escape a literal constant for use within SQL"; + "escape_literal(str) -- escape a literal constant for use within SQL"; static PyObject * conn_escape_literal(connObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -1119,7 +1177,8 @@ conn_escape_literal(connObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -1129,15 +1188,15 @@ conn_escape_literal(connObject *self, PyObject *string) return NULL; } - to = PQescapeLiteral(self->cnx, from, (size_t) from_length); + to = PQescapeLiteral(self->cnx, from, (size_t)from_length); to_length = strlen(to); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); if (to) PQfreemem(to); return to_obj; @@ -1145,18 +1204,18 @@ conn_escape_literal(connObject *self, PyObject *string) /* Escape identifier */ static char conn_escape_identifier__doc__[] = -"escape_identifier(str) -- escape an identifier for use within SQL"; + "escape_identifier(str) -- escape an identifier for use within SQL"; static PyObject * conn_escape_identifier(connObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -1164,7 +1223,8 @@ conn_escape_identifier(connObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -1174,15 +1234,15 @@ conn_escape_identifier(connObject *self, PyObject *string) return NULL; } - to = PQescapeIdentifier(self->cnx, from, (size_t) from_length); + to = PQescapeIdentifier(self->cnx, from, (size_t)from_length); to_length = strlen(to); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); if (to) PQfreemem(to); return to_obj; @@ -1190,18 +1250,18 @@ conn_escape_identifier(connObject *self, PyObject *string) /* Escape string */ static char conn_escape_string__doc__[] = -"escape_string(str) -- escape a string for use within SQL"; + "escape_string(str) -- escape a string for use within SQL"; static PyObject * conn_escape_string(connObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -1209,49 +1269,50 @@ conn_escape_string(connObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_string() expects a string as argument"); + PyErr_SetString(PyExc_TypeError, + "Method escape_string() expects a string as argument"); return NULL; } - to_length = 2 * (size_t) from_length + 1; - if ((Py_ssize_t) to_length < from_length) { /* overflow */ - to_length = (size_t) from_length; - from_length = (from_length - 1)/2; + to_length = 2 * (size_t)from_length + 1; + if ((Py_ssize_t)to_length < from_length) { /* overflow */ + to_length = (size_t)from_length; + from_length = (from_length - 1) / 2; } - to = (char *) PyMem_Malloc(to_length); - to_length = PQescapeStringConn(self->cnx, - to, from, (size_t) from_length, NULL); + to = (char *)PyMem_Malloc(to_length); + to_length = + PQescapeStringConn(self->cnx, to, from, (size_t)from_length, NULL); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); PyMem_Free(to); return to_obj; } /* Escape bytea */ static char conn_escape_bytea__doc__[] = -"escape_bytea(data) -- escape binary data for use within SQL as type bytea"; + "escape_bytea(data) -- escape binary data for use within SQL as type " + "bytea"; static PyObject * conn_escape_bytea(connObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); @@ -1259,25 +1320,25 @@ conn_escape_bytea(connObject *self, PyObject *data) else if (PyUnicode_Check(data)) { encoding = PQclientEncoding(self->cnx); tmp_obj = get_encoded_string(data, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_bytea() expects a string as argument"); + PyErr_SetString(PyExc_TypeError, + "Method escape_bytea() expects a string as argument"); return NULL; } - to = (char *) PQescapeByteaConn(self->cnx, - (unsigned char *) from, (size_t) from_length, &to_length); + to = (char *)PQescapeByteaConn(self->cnx, (unsigned char *)from, + (size_t)from_length, &to_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length - 1); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length - 1); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length - 1, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length - 1, encoding); if (to) PQfreemem(to); return to_obj; @@ -1303,7 +1364,7 @@ large_new(connObject *pgcnx, Oid oid) /* Create large object. */ static char conn_locreate__doc__[] = -"locreate(mode) -- create a new large object in the database"; + "locreate(mode) -- create a new large object in the database"; static PyObject * conn_locreate(connObject *self, PyObject *args) @@ -1330,12 +1391,12 @@ conn_locreate(connObject *self, PyObject *args) return NULL; } - return (PyObject *) large_new(self, lo_oid); + return (PyObject *)large_new(self, lo_oid); } /* Init from already known oid. */ static char conn_getlo__doc__[] = -"getlo(oid) -- create a large object instance for the specified oid"; + "getlo(oid) -- create a large object instance for the specified oid"; static PyObject * conn_getlo(connObject *self, PyObject *args) @@ -1355,19 +1416,19 @@ conn_getlo(connObject *self, PyObject *args) return NULL; } - lo_oid = (Oid) oid; + lo_oid = (Oid)oid; if (lo_oid == 0) { PyErr_SetString(PyExc_ValueError, "The object oid can't be null"); return NULL; } /* creates object */ - return (PyObject *) large_new(self, lo_oid); + return (PyObject *)large_new(self, lo_oid); } /* Import unix file. */ static char conn_loimport__doc__[] = -"loimport(name) -- create a new large object from specified file"; + "loimport(name) -- create a new large object from specified file"; static PyObject * conn_loimport(connObject *self, PyObject *args) @@ -1394,14 +1455,14 @@ conn_loimport(connObject *self, PyObject *args) return NULL; } - return (PyObject *) large_new(self, lo_oid); + return (PyObject *)large_new(self, lo_oid); } /* Reset connection. */ static char conn_reset__doc__[] = -"reset() -- reset connection with current parameters\n\n" -"All derived queries and large objects derived from this connection\n" -"will not be usable after this call.\n"; + "reset() -- reset connection with current parameters\n\n" + "All derived queries and large objects derived from this connection\n" + "will not be usable after this call.\n"; static PyObject * conn_reset(connObject *self, PyObject *noargs) @@ -1419,7 +1480,7 @@ conn_reset(connObject *self, PyObject *noargs) /* Cancel current command. */ static char conn_cancel__doc__[] = -"cancel() -- abandon processing of the current command"; + "cancel() -- abandon processing of the current command"; static PyObject * conn_cancel(connObject *self, PyObject *noargs) @@ -1430,12 +1491,12 @@ conn_cancel(connObject *self, PyObject *noargs) } /* request that the server abandon processing of the current command */ - return PyLong_FromLong((long) PQrequestCancel(self->cnx)); + return PyLong_FromLong((long)PQrequestCancel(self->cnx)); } /* Get connection socket. */ static char conn_fileno__doc__[] = -"fileno() -- return database connection socket file handle"; + "fileno() -- return database connection socket file handle"; static PyObject * conn_fileno(connObject *self, PyObject *noargs) @@ -1445,12 +1506,12 @@ conn_fileno(connObject *self, PyObject *noargs) return NULL; } - return PyLong_FromLong((long) PQsocket(self->cnx)); + return PyLong_FromLong((long)PQsocket(self->cnx)); } /* Set external typecast callback function. */ static char conn_set_cast_hook__doc__[] = -"set_cast_hook(func) -- set a fallback typecast function"; + "set_cast_hook(func) -- set a fallback typecast function"; static PyObject * conn_set_cast_hook(connObject *self, PyObject *func) @@ -1460,12 +1521,15 @@ conn_set_cast_hook(connObject *self, PyObject *func) if (func == Py_None) { Py_XDECREF(self->cast_hook); self->cast_hook = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(self->cast_hook); + Py_XINCREF(func); + Py_XDECREF(self->cast_hook); self->cast_hook = func; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -1478,12 +1542,13 @@ conn_set_cast_hook(connObject *self, PyObject *func) /* Get notice receiver callback function. */ static char conn_get_cast_hook__doc__[] = -"get_cast_hook() -- get the fallback typecast function"; + "get_cast_hook() -- get the fallback typecast function"; static PyObject * conn_get_cast_hook(connObject *self, PyObject *noargs) { - PyObject *ret = self->cast_hook;; + PyObject *ret = self->cast_hook; + ; if (!ret) ret = Py_None; @@ -1494,7 +1559,7 @@ conn_get_cast_hook(connObject *self, PyObject *noargs) /* Get asynchronous connection state. */ static char conn_poll__doc__[] = -"poll() -- Completes an asynchronous connection"; + "poll() -- Completes an asynchronous connection"; static PyObject * conn_poll(connObject *self, PyObject *noargs) @@ -1521,7 +1586,7 @@ conn_poll(connObject *self, PyObject *noargs) /* Set notice receiver callback function. */ static char conn_set_notice_receiver__doc__[] = -"set_notice_receiver(func) -- set the current notice receiver"; + "set_notice_receiver(func) -- set the current notice receiver"; static PyObject * conn_set_notice_receiver(connObject *self, PyObject *func) @@ -1531,13 +1596,16 @@ conn_set_notice_receiver(connObject *self, PyObject *func) if (func == Py_None) { Py_XDECREF(self->notice_receiver); self->notice_receiver = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(self->notice_receiver); + Py_XINCREF(func); + Py_XDECREF(self->notice_receiver); self->notice_receiver = func; PQsetNoticeReceiver(self->cnx, notice_receiver, self); - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -1550,7 +1618,7 @@ conn_set_notice_receiver(connObject *self, PyObject *func) /* Get notice receiver callback function. */ static char conn_get_notice_receiver__doc__[] = -"get_notice_receiver() -- get the current notice receiver"; + "get_notice_receiver() -- get the current notice receiver"; static PyObject * conn_get_notice_receiver(connObject *self, PyObject *noargs) @@ -1566,9 +1634,9 @@ conn_get_notice_receiver(connObject *self, PyObject *noargs) /* Close without deleting. */ static char conn_close__doc__[] = -"close() -- close connection\n\n" -"All instances of the connection object and derived objects\n" -"(queries and large objects) can no longer be used after this call.\n"; + "close() -- close connection\n\n" + "All instances of the connection object and derived objects\n" + "(queries and large objects) can no longer be used after this call.\n"; static PyObject * conn_close(connObject *self, PyObject *noargs) @@ -1590,7 +1658,7 @@ conn_close(connObject *self, PyObject *noargs) /* Get asynchronous notify. */ static char conn_get_notify__doc__[] = -"getnotify() -- get database notify for this connection"; + "getnotify() -- get database notify for this connection"; static PyObject * conn_get_notify(connObject *self, PyObject *noargs) @@ -1649,87 +1717,74 @@ conn_dir(connObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sssssssssssss]", - "host", "port", "db", "options", "error", "status", "user", - "protocol_version", "server_version", "socket", "backend_pid", - "ssl_in_use", "ssl_attributes"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sssssssssssss]", "host", "port", + "db", "options", "error", "status", "user", + "protocol_version", "server_version", "socket", + "backend_pid", "ssl_in_use", "ssl_attributes"); return attrs; } /* Connection object methods */ static struct PyMethodDef conn_methods[] = { - {"__dir__", (PyCFunction) conn_dir, METH_NOARGS, NULL}, - - {"source", (PyCFunction) conn_source, - METH_NOARGS, conn_source__doc__}, - {"query", (PyCFunction) conn_query, - METH_VARARGS, conn_query__doc__}, - {"send_query", (PyCFunction) conn_send_query, - METH_VARARGS, conn_send_query__doc__}, - {"query_prepared", (PyCFunction) conn_query_prepared, - METH_VARARGS, conn_query_prepared__doc__}, - {"prepare", (PyCFunction) conn_prepare, - METH_VARARGS, conn_prepare__doc__}, - {"describe_prepared", (PyCFunction) conn_describe_prepared, - METH_VARARGS, conn_describe_prepared__doc__}, - {"poll", (PyCFunction) conn_poll, - METH_NOARGS, conn_poll__doc__}, - {"reset", (PyCFunction) conn_reset, - METH_NOARGS, conn_reset__doc__}, - {"cancel", (PyCFunction) conn_cancel, - METH_NOARGS, conn_cancel__doc__}, - {"close", (PyCFunction) conn_close, - METH_NOARGS, conn_close__doc__}, - {"fileno", (PyCFunction) conn_fileno, - METH_NOARGS, conn_fileno__doc__}, - {"get_cast_hook", (PyCFunction) conn_get_cast_hook, - METH_NOARGS, conn_get_cast_hook__doc__}, - {"set_cast_hook", (PyCFunction) conn_set_cast_hook, - METH_O, conn_set_cast_hook__doc__}, - {"get_notice_receiver", (PyCFunction) conn_get_notice_receiver, - METH_NOARGS, conn_get_notice_receiver__doc__}, - {"set_notice_receiver", (PyCFunction) conn_set_notice_receiver, - METH_O, conn_set_notice_receiver__doc__}, - {"getnotify", (PyCFunction) conn_get_notify, - METH_NOARGS, conn_get_notify__doc__}, - {"inserttable", (PyCFunction) conn_inserttable, - METH_VARARGS, conn_inserttable__doc__}, - {"transaction", (PyCFunction) conn_transaction, - METH_NOARGS, conn_transaction__doc__}, - {"parameter", (PyCFunction) conn_parameter, - METH_VARARGS, conn_parameter__doc__}, - {"date_format", (PyCFunction) conn_date_format, - METH_NOARGS, conn_date_format__doc__}, - - {"escape_literal", (PyCFunction) conn_escape_literal, - METH_O, conn_escape_literal__doc__}, - {"escape_identifier", (PyCFunction) conn_escape_identifier, - METH_O, conn_escape_identifier__doc__}, - {"escape_string", (PyCFunction) conn_escape_string, - METH_O, conn_escape_string__doc__}, - {"escape_bytea", (PyCFunction) conn_escape_bytea, - METH_O, conn_escape_bytea__doc__}, - - {"putline", (PyCFunction) conn_putline, - METH_VARARGS, conn_putline__doc__}, - {"getline", (PyCFunction) conn_getline, - METH_NOARGS, conn_getline__doc__}, - {"endcopy", (PyCFunction) conn_endcopy, - METH_NOARGS, conn_endcopy__doc__}, - {"set_non_blocking", (PyCFunction) conn_set_non_blocking, - METH_VARARGS, conn_set_non_blocking__doc__}, - {"is_non_blocking", (PyCFunction) conn_is_non_blocking, - METH_NOARGS, conn_is_non_blocking__doc__}, - - {"locreate", (PyCFunction) conn_locreate, - METH_VARARGS, conn_locreate__doc__}, - {"getlo", (PyCFunction) conn_getlo, - METH_VARARGS, conn_getlo__doc__}, - {"loimport", (PyCFunction) conn_loimport, - METH_VARARGS, conn_loimport__doc__}, + {"__dir__", (PyCFunction)conn_dir, METH_NOARGS, NULL}, + + {"source", (PyCFunction)conn_source, METH_NOARGS, conn_source__doc__}, + {"query", (PyCFunction)conn_query, METH_VARARGS, conn_query__doc__}, + {"send_query", (PyCFunction)conn_send_query, METH_VARARGS, + conn_send_query__doc__}, + {"query_prepared", (PyCFunction)conn_query_prepared, METH_VARARGS, + conn_query_prepared__doc__}, + {"prepare", (PyCFunction)conn_prepare, METH_VARARGS, conn_prepare__doc__}, + {"describe_prepared", (PyCFunction)conn_describe_prepared, METH_VARARGS, + conn_describe_prepared__doc__}, + {"poll", (PyCFunction)conn_poll, METH_NOARGS, conn_poll__doc__}, + {"reset", (PyCFunction)conn_reset, METH_NOARGS, conn_reset__doc__}, + {"cancel", (PyCFunction)conn_cancel, METH_NOARGS, conn_cancel__doc__}, + {"close", (PyCFunction)conn_close, METH_NOARGS, conn_close__doc__}, + {"fileno", (PyCFunction)conn_fileno, METH_NOARGS, conn_fileno__doc__}, + {"get_cast_hook", (PyCFunction)conn_get_cast_hook, METH_NOARGS, + conn_get_cast_hook__doc__}, + {"set_cast_hook", (PyCFunction)conn_set_cast_hook, METH_O, + conn_set_cast_hook__doc__}, + {"get_notice_receiver", (PyCFunction)conn_get_notice_receiver, METH_NOARGS, + conn_get_notice_receiver__doc__}, + {"set_notice_receiver", (PyCFunction)conn_set_notice_receiver, METH_O, + conn_set_notice_receiver__doc__}, + {"getnotify", (PyCFunction)conn_get_notify, METH_NOARGS, + conn_get_notify__doc__}, + {"inserttable", (PyCFunction)conn_inserttable, METH_VARARGS, + conn_inserttable__doc__}, + {"transaction", (PyCFunction)conn_transaction, METH_NOARGS, + conn_transaction__doc__}, + {"parameter", (PyCFunction)conn_parameter, METH_VARARGS, + conn_parameter__doc__}, + {"date_format", (PyCFunction)conn_date_format, METH_NOARGS, + conn_date_format__doc__}, + + {"escape_literal", (PyCFunction)conn_escape_literal, METH_O, + conn_escape_literal__doc__}, + {"escape_identifier", (PyCFunction)conn_escape_identifier, METH_O, + conn_escape_identifier__doc__}, + {"escape_string", (PyCFunction)conn_escape_string, METH_O, + conn_escape_string__doc__}, + {"escape_bytea", (PyCFunction)conn_escape_bytea, METH_O, + conn_escape_bytea__doc__}, + + {"putline", (PyCFunction)conn_putline, METH_VARARGS, conn_putline__doc__}, + {"getline", (PyCFunction)conn_getline, METH_NOARGS, conn_getline__doc__}, + {"endcopy", (PyCFunction)conn_endcopy, METH_NOARGS, conn_endcopy__doc__}, + {"set_non_blocking", (PyCFunction)conn_set_non_blocking, METH_VARARGS, + conn_set_non_blocking__doc__}, + {"is_non_blocking", (PyCFunction)conn_is_non_blocking, METH_NOARGS, + conn_is_non_blocking__doc__}, + + {"locreate", (PyCFunction)conn_locreate, METH_VARARGS, + conn_locreate__doc__}, + {"getlo", (PyCFunction)conn_getlo, METH_VARARGS, conn_getlo__doc__}, + {"loimport", (PyCFunction)conn_loimport, METH_VARARGS, + conn_loimport__doc__}, {NULL, NULL} /* sentinel */ }; @@ -1738,32 +1793,31 @@ static char conn__doc__[] = "PostgreSQL connection object"; /* Connection type definition */ static PyTypeObject connType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Connection", /* tp_name */ - sizeof(connObject), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor) conn_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - (getattrofunc) conn_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - conn__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - conn_methods, /* tp_methods */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.Connection", /* tp_name */ + sizeof(connObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)conn_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + (getattrofunc)conn_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + conn__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + conn_methods, /* tp_methods */ }; diff --git a/pginternal.c b/pginternal.c index 61446f41..124661c1 100644 --- a/pginternal.c +++ b/pginternal.c @@ -37,8 +37,8 @@ get_decoded_string(const char *str, Py_ssize_t size, int encoding) if (encoding == pg_encoding_ascii) return PyUnicode_DecodeASCII(str, size, "strict"); /* encoding name should be properly translated to Python here */ - return PyUnicode_Decode(str, size, - pg_encoding_to_char(encoding), "strict"); + return PyUnicode_Decode(str, size, pg_encoding_to_char(encoding), + "strict"); } static PyObject * @@ -52,7 +52,7 @@ get_encoded_string(PyObject *unicode_obj, int encoding) return PyUnicode_AsASCIIString(unicode_obj); /* encoding name should be properly translated to Python here */ return PyUnicode_AsEncodedString(unicode_obj, - pg_encoding_to_char(encoding), "strict"); + pg_encoding_to_char(encoding), "strict"); } /* Helper functions */ @@ -64,7 +64,7 @@ get_type(Oid pgtype) int t; switch (pgtype) { - /* simple types */ + /* simple types */ case INT2OID: case INT4OID: @@ -113,7 +113,7 @@ get_type(Oid pgtype) t = PYGRES_TEXT; break; - /* array types */ + /* array types */ case INT2ARRAYOID: case INT4ARRAYOID: @@ -137,8 +137,9 @@ get_type(Oid pgtype) break; case MONEYARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((decimal_point ? - PYGRES_MONEY : PYGRES_TEXT) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((decimal_point ? PYGRES_MONEY : PYGRES_TEXT) | + PYGRES_ARRAY); break; case BOOLARRAYOID: @@ -146,14 +147,16 @@ get_type(Oid pgtype) break; case BYTEAARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((bytea_escaped ? - PYGRES_TEXT : PYGRES_BYTEA) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((bytea_escaped ? PYGRES_TEXT : PYGRES_BYTEA) | + PYGRES_ARRAY); break; case JSONARRAYOID: case JSONBARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((jsondecode ? - PYGRES_JSON : PYGRES_TEXT) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((jsondecode ? PYGRES_JSON : PYGRES_TEXT) | + PYGRES_ARRAY); break; case BPCHARARRAYOID: @@ -178,8 +181,8 @@ get_col_types(PGresult *result, int nfields) { int *types, *t, j; - if (!(types = PyMem_Malloc(sizeof(int) * (size_t) nfields))) { - return (int*) PyErr_NoMemory(); + if (!(types = PyMem_Malloc(sizeof(int) * (size_t)nfields))) { + return (int *)PyErr_NoMemory(); } for (j = 0, t = types; j < nfields; ++j) { @@ -199,8 +202,8 @@ cast_bytea_text(char *s) size_t str_len; /* this function should not be called when bytea_escaped is set */ - tmp_str = (char *) PQunescapeBytea((unsigned char*) s, &str_len); - obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t) str_len); + tmp_str = (char *)PQunescapeBytea((unsigned char *)s, &str_len); + obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t)str_len); if (tmp_str) { PQfreemem(tmp_str); } @@ -221,16 +224,18 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) case PYGRES_BYTEA: /* this type should not be passed when bytea_escaped is set */ /* we need to add a null byte */ - tmp_str = (char *) PyMem_Malloc((size_t) size + 1); + tmp_str = (char *)PyMem_Malloc((size_t)size + 1); if (!tmp_str) { return PyErr_NoMemory(); } - memcpy(tmp_str, s, (size_t) size); - s = tmp_str; *(s + size) = '\0'; - tmp_str = (char *) PQunescapeBytea((unsigned char*) s, &str_len); + memcpy(tmp_str, s, (size_t)size); + s = tmp_str; + *(s + size) = '\0'; + tmp_str = (char *)PQunescapeBytea((unsigned char *)s, &str_len); PyMem_Free(s); - if (!tmp_str) return PyErr_NoMemory(); - obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t) str_len); + if (!tmp_str) + return PyErr_NoMemory(); + obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t)str_len); if (tmp_str) { PQfreemem(tmp_str); } @@ -246,7 +251,7 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) } break; - default: /* PYGRES_TEXT */ + default: /* PYGRES_TEXT */ obj = get_decoded_string(s, size, encoding); if (!obj) { /* cannot decode */ obj = PyBytes_FromStringAndSize(s, size); @@ -288,8 +293,8 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) case PYGRES_INT: n = sizeof(buf) / sizeof(buf[0]) - 1; - if ((int) size < n) { - n = (int) size; + if ((int)size < n) { + n = (int)size; } for (i = 0, t = buf; i < n; ++i) { *t++ = *s++; @@ -300,8 +305,8 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) case PYGRES_LONG: n = sizeof(buf) / sizeof(buf[0]) - 1; - if ((int) size < n) { - n = (int) size; + if ((int)size < n) { + n = (int)size; } for (i = 0, t = buf; i < n; ++i) { *t++ = *s++; @@ -338,14 +343,14 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) tmp_obj = PyUnicode_FromString(buf); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); - } break; case PYGRES_DECIMAL: tmp_obj = PyUnicode_FromStringAndSize(s, size); - obj = decimal ? PyObject_CallFunctionObjArgs( - decimal, tmp_obj, NULL) : PyFloat_FromString(tmp_obj); + obj = decimal + ? PyObject_CallFunctionObjArgs(decimal, tmp_obj, NULL) + : PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -404,7 +409,8 @@ cast_unsized_simple(char *s, int type) buf[j++] = '-'; } } - buf[j] = '\0'; s = buf; + buf[j] = '\0'; + s = buf; /* FALLTHROUGH */ /* no break here */ case PYGRES_DECIMAL: @@ -438,11 +444,10 @@ cast_unsized_simple(char *s, int type) } /* Quick case insensitive check if given sized string is null. */ -#define STR_IS_NULL(s, n) (n == 4 && \ - (s[0] == 'n' || s[0] == 'N') && \ - (s[1] == 'u' || s[1] == 'U') && \ - (s[2] == 'l' || s[2] == 'L') && \ - (s[3] == 'l' || s[3] == 'L')) +#define STR_IS_NULL(s, n) \ + (n == 4 && (s[0] == 'n' || s[0] == 'N') && \ + (s[1] == 'u' || s[1] == 'U') && (s[2] == 'l' || s[2] == 'L') && \ + (s[3] == 'l' || s[3] == 'L')) /* Cast string s with size and encoding to a Python list, using the input and output syntax for arrays. @@ -450,8 +455,8 @@ cast_unsized_simple(char *s, int type) The parameter delim specifies the delimiter for the elements, since some types do not use the default delimiter of a comma. */ static PyObject * -cast_array(char *s, Py_ssize_t size, int encoding, - int type, PyObject *cast, char delim) +cast_array(char *s, Py_ssize_t size, int encoding, int type, PyObject *cast, + char delim) { PyObject *result, *stack[MAX_ARRAY_DEPTH]; char *end = s + size, *t; @@ -459,12 +464,13 @@ cast_array(char *s, Py_ssize_t size, int encoding, if (type) { type &= ~PYGRES_ARRAY; /* get the base type */ - if (!type) type = PYGRES_TEXT; + if (!type) + type = PYGRES_TEXT; } if (!delim) { delim = ','; } - else if (delim == '{' || delim =='}' || delim=='\\') { + else if (delim == '{' || delim == '}' || delim == '\\') { PyErr_SetString(PyExc_ValueError, "Invalid array delimiter"); return NULL; } @@ -475,20 +481,28 @@ cast_array(char *s, Py_ssize_t size, int encoding, int valid; for (valid = 0; !valid;) { - if (s == end || *s++ != '[') break; + if (s == end || *s++ != '[') + break; while (s != end && *s == ' ') ++s; - if (s != end && (*s == '+' || *s == '-')) ++s; - if (s == end || *s < '0' || *s > '9') break; + if (s != end && (*s == '+' || *s == '-')) + ++s; + if (s == end || *s < '0' || *s > '9') + break; while (s != end && *s >= '0' && *s <= '9') ++s; - if (s == end || *s++ != ':') break; - if (s != end && (*s == '+' || *s == '-')) ++s; - if (s == end || *s < '0' || *s > '9') break; + if (s == end || *s++ != ':') + break; + if (s != end && (*s == '+' || *s == '-')) + ++s; + if (s == end || *s < '0' || *s > '9') + break; while (s != end && *s >= '0' && *s <= '9') ++s; - if (s == end || *s++ != ']') break; + if (s == end || *s++ != ']') + break; while (s != end && *s == ' ') ++s; ++ranges; if (s != end && *s == '=') { - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); valid = 1; } } @@ -498,7 +512,8 @@ cast_array(char *s, Py_ssize_t size, int encoding, } } for (t = s, depth = 0; t != end && (*t == '{' || *t == ' '); ++t) { - if (*t == '{') ++depth; + if (*t == '{') + ++depth; } if (!depth) { PyErr_SetString(PyExc_ValueError, @@ -516,30 +531,40 @@ cast_array(char *s, Py_ssize_t size, int encoding, } depth--; /* next level of parsing */ result = PyList_New(0); - if (!result) return NULL; - do ++s; while (s != end && *s == ' '); + if (!result) + return NULL; + do ++s; + while (s != end && *s == ' '); /* everything is set up, start parsing the array */ while (s != end) { if (*s == '}') { PyObject *subresult; - if (!level) break; /* top level array ended */ - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + if (!level) + break; /* top level array ended */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ if (*s == delim) { - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ if (*s != '{') { PyErr_SetString(PyExc_ValueError, "Subarray expected but not found"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } - else if (*s != '}') break; /* error */ + else if (*s != '}') + break; /* error */ subresult = result; result = stack[--level]; if (PyList_Append(result, subresult)) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else if (level == depth) { /* we expect elements at this level */ @@ -551,40 +576,48 @@ cast_array(char *s, Py_ssize_t size, int encoding, if (*s == '{') { PyErr_SetString(PyExc_ValueError, "Subarray found where not expected"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } if (*s == '"') { /* quoted element */ estr = ++s; while (s != end && *s != '"') { if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; escaped = 1; } ++s; } esize = s - estr; - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); } else { /* unquoted element */ estr = s; /* can contain blanks inside */ - while (s != end && *s != '"' && - *s != '{' && *s != '}' && *s != delim) - { + while (s != end && *s != '"' && *s != '{' && *s != '}' && + *s != delim) { if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; escaped = 1; } ++s; } - t = s; while (t > estr && *(t - 1) == ' ') --t; + t = s; + while (t > estr && *(t - 1) == ' ') --t; if (!(esize = t - estr)) { - s = end; break; /* error */ + s = end; + break; /* error */ } if (STR_IS_NULL(estr, esize)) /* NULL gives None */ estr = NULL; } - if (s == end) break; /* error */ + if (s == end) + break; /* error */ if (estr) { if (escaped) { char *r; @@ -592,12 +625,14 @@ cast_array(char *s, Py_ssize_t size, int encoding, /* create unescaped string */ t = estr; - estr = (char *) PyMem_Malloc((size_t) esize); + estr = (char *)PyMem_Malloc((size_t)esize); if (!estr) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } for (i = 0, r = estr; i < esize; ++i) { - if (*t == '\\') ++t, ++i; + if (*t == '\\') + ++t, ++i; *r++ = *t++; } esize = r - estr; @@ -609,58 +644,73 @@ cast_array(char *s, Py_ssize_t size, int encoding, element = cast_sized_simple(estr, esize, type); } else { /* external casting of base type */ - element = encoding == pg_encoding_ascii ? NULL : - get_decoded_string(estr, esize, encoding); + element = encoding == pg_encoding_ascii + ? NULL + : get_decoded_string(estr, esize, encoding); if (!element) { /* no decoding necessary or possible */ element = PyBytes_FromStringAndSize(estr, esize); } if (element && cast) { PyObject *tmp = element; - element = PyObject_CallFunctionObjArgs( - cast, element, NULL); + element = + PyObject_CallFunctionObjArgs(cast, element, NULL); Py_DECREF(tmp); } } - if (escaped) PyMem_Free(estr); + if (escaped) + PyMem_Free(estr); if (!element) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { - Py_INCREF(Py_None); element = Py_None; + Py_INCREF(Py_None); + element = Py_None; } if (PyList_Append(result, element)) { - Py_DECREF(element); Py_DECREF(result); return NULL; + Py_DECREF(element); + Py_DECREF(result); + return NULL; } Py_DECREF(element); if (*s == delim) { - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ } - else if (*s != '}') break; /* error */ + else if (*s != '}') + break; /* error */ } else { /* we expect arrays at this level */ if (*s != '{') { PyErr_SetString(PyExc_ValueError, "Subarray must start with a left brace"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ stack[level++] = result; - if (!(result = PyList_New(0))) return NULL; + if (!(result = PyList_New(0))) + return NULL; } } if (s == end || *s != '}') { - PyErr_SetString(PyExc_ValueError, - "Unexpected end of array"); - Py_DECREF(result); return NULL; + PyErr_SetString(PyExc_ValueError, "Unexpected end of array"); + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); if (s != end) { PyErr_SetString(PyExc_ValueError, "Unexpected characters after end of array"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } return result; } @@ -672,8 +722,8 @@ cast_array(char *s, Py_ssize_t size, int encoding, The parameter delim can specify a delimiter for the elements, although composite types always use a comma as delimiter. */ static PyObject * -cast_record(char *s, Py_ssize_t size, int encoding, - int *type, PyObject *cast, Py_ssize_t len, char delim) +cast_record(char *s, Py_ssize_t size, int encoding, int *type, PyObject *cast, + Py_ssize_t len, char delim) { PyObject *result, *ret; char *end = s + size, *t; @@ -682,7 +732,7 @@ cast_record(char *s, Py_ssize_t size, int encoding, if (!delim) { delim = ','; } - else if (delim == '(' || delim ==')' || delim=='\\') { + else if (delim == '(' || delim == ')' || delim == '\\') { PyErr_SetString(PyExc_ValueError, "Invalid record delimiter"); return NULL; } @@ -695,14 +745,16 @@ cast_record(char *s, Py_ssize_t size, int encoding, return NULL; } result = PyList_New(0); - if (!result) return NULL; + if (!result) + return NULL; i = 0; /* everything is set up, start parsing the record */ while (++s != end) { PyObject *element; if (*s == ')' || *s == delim) { - Py_INCREF(Py_None); element = Py_None; + Py_INCREF(Py_None); + element = Py_None; } else { char *estr; @@ -711,32 +763,40 @@ cast_record(char *s, Py_ssize_t size, int encoding, estr = s; quoted = *s == '"'; - if (quoted) ++s; + if (quoted) + ++s; esize = 0; while (s != end) { if (!quoted && (*s == ')' || *s == delim)) break; if (*s == '"') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; if (!(quoted && *s == '"')) { - quoted = !quoted; continue; + quoted = !quoted; + continue; } } if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; } ++s, ++esize; } - if (s == end) break; /* error */ + if (s == end) + break; /* error */ if (estr + esize != s) { char *r; escaped = 1; /* create unescaped string */ t = estr; - estr = (char *) PyMem_Malloc((size_t) esize); + estr = (char *)PyMem_Malloc((size_t)esize); if (!estr) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } quoted = 0; r = estr; @@ -744,10 +804,12 @@ cast_record(char *s, Py_ssize_t size, int encoding, if (*t == '"') { ++t; if (!(quoted && *t == '"')) { - quoted = !quoted; continue; + quoted = !quoted; + continue; } } - if (*t == '\\') ++t; + if (*t == '\\') + ++t; *r++ = *t++; } } @@ -755,16 +817,17 @@ cast_record(char *s, Py_ssize_t size, int encoding, int etype = type[i]; if (etype & PYGRES_ARRAY) - element = cast_array( - estr, esize, encoding, etype, NULL, 0); + element = + cast_array(estr, esize, encoding, etype, NULL, 0); else if (etype & PYGRES_TEXT) element = cast_sized_text(estr, esize, encoding, etype); else element = cast_sized_simple(estr, esize, etype); } else { /* external casting of base type */ - element = encoding == pg_encoding_ascii ? NULL : - get_decoded_string(estr, esize, encoding); + element = encoding == pg_encoding_ascii + ? NULL + : get_decoded_string(estr, esize, encoding); if (!element) { /* no decoding necessary or possible */ element = PyBytes_FromStringAndSize(estr, esize); } @@ -781,46 +844,58 @@ cast_record(char *s, Py_ssize_t size, int encoding, } } else { - Py_DECREF(element); element = NULL; + Py_DECREF(element); + element = NULL; } } else { PyObject *tmp = element; - element = PyObject_CallFunctionObjArgs( - cast, element, NULL); + element = + PyObject_CallFunctionObjArgs(cast, element, NULL); Py_DECREF(tmp); } } } - if (escaped) PyMem_Free(estr); + if (escaped) + PyMem_Free(estr); if (!element) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } if (PyList_Append(result, element)) { - Py_DECREF(element); Py_DECREF(result); return NULL; + Py_DECREF(element); + Py_DECREF(result); + return NULL; } Py_DECREF(element); - if (len) ++i; - if (*s != delim) break; /* no next record */ + if (len) + ++i; + if (*s != delim) + break; /* no next record */ if (len && i >= len) { PyErr_SetString(PyExc_ValueError, "Too many columns"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } if (s == end || *s != ')') { PyErr_SetString(PyExc_ValueError, "Unexpected end of record"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); if (s != end) { PyErr_SetString(PyExc_ValueError, "Unexpected characters after end of record"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } if (len && i < len) { PyErr_SetString(PyExc_ValueError, "Too few columns"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } ret = PyList_AsTuple(result); @@ -846,94 +921,116 @@ cast_hstore(char *s, Py_ssize_t size, int encoding) int quoted; while (s != end && *s == ' ') ++s; - if (s == end) break; + if (s == end) + break; quoted = *s == '"'; if (quoted) { key = ++s; while (s != end) { - if (*s == '"') break; + if (*s == '"') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++key_esc; } ++s; } if (s == end) { PyErr_SetString(PyExc_ValueError, "Unterminated quote"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { key = s; while (s != end) { - if (*s == '=' || *s == ' ') break; + if (*s == '=' || *s == ' ') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++key_esc; } ++s; } if (s == key) { PyErr_SetString(PyExc_ValueError, "Missing key"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } size = s - key - key_esc; if (key_esc) { char *r = key, *t; - key = (char *) PyMem_Malloc((size_t) size); + key = (char *)PyMem_Malloc((size_t)size); if (!key) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } t = key; while (r != s) { if (*r == '\\') { - ++r; if (r == s) break; + ++r; + if (r == s) + break; } *t++ = *r++; } } key_obj = cast_sized_text(key, size, encoding, PYGRES_TEXT); - if (key_esc) PyMem_Free(key); + if (key_esc) + PyMem_Free(key); if (!key_obj) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - if (quoted) ++s; + if (quoted) + ++s; while (s != end && *s == ' ') ++s; if (s == end || *s++ != '=' || s == end || *s++ != '>') { PyErr_SetString(PyExc_ValueError, "Invalid characters after key"); - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } while (s != end && *s == ' ') ++s; quoted = *s == '"'; if (quoted) { val = ++s; while (s != end) { - if (*s == '"') break; + if (*s == '"') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++val_esc; } ++s; } if (s == end) { PyErr_SetString(PyExc_ValueError, "Unterminated quote"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { val = s; while (s != end) { - if (*s == ',' || *s == ' ') break; + if (*s == ',' || *s == ' ') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++val_esc; } ++s; } if (s == val) { PyErr_SetString(PyExc_ValueError, "Missing value"); - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } if (STR_IS_NULL(val, s - val)) val = NULL; @@ -942,46 +1039,59 @@ cast_hstore(char *s, Py_ssize_t size, int encoding) size = s - val - val_esc; if (val_esc) { char *r = val, *t; - val = (char *) PyMem_Malloc((size_t) size); + val = (char *)PyMem_Malloc((size_t)size); if (!val) { - Py_DECREF(key_obj); Py_DECREF(result); + Py_DECREF(key_obj); + Py_DECREF(result); return PyErr_NoMemory(); } t = val; while (r != s) { if (*r == '\\') { - ++r; if (r == s) break; + ++r; + if (r == s) + break; } *t++ = *r++; } } val_obj = cast_sized_text(val, size, encoding, PYGRES_TEXT); - if (val_esc) PyMem_Free(val); + if (val_esc) + PyMem_Free(val); if (!val_obj) { - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } } else { - Py_INCREF(Py_None); val_obj = Py_None; + Py_INCREF(Py_None); + val_obj = Py_None; } - if (quoted) ++s; + if (quoted) + ++s; while (s != end && *s == ' ') ++s; if (s != end) { if (*s++ != ',') { PyErr_SetString(PyExc_ValueError, "Invalid characters after val"); - Py_DECREF(key_obj); Py_DECREF(val_obj); - Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(val_obj); + Py_DECREF(result); + return NULL; } while (s != end && *s == ' ') ++s; if (s == end) { PyErr_SetString(PyExc_ValueError, "Missing entry"); - Py_DECREF(key_obj); Py_DECREF(val_obj); - Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(val_obj); + Py_DECREF(result); + return NULL; } } PyDict_SetItem(result, key_obj, val_obj); - Py_DECREF(key_obj); Py_DECREF(val_obj); + Py_DECREF(key_obj); + Py_DECREF(val_obj); } return result; } @@ -1054,15 +1164,15 @@ get_error_type(const char *sqlstate) /* Set database error message and sqlstate attribute. */ static void -set_error_msg_and_state(PyObject *type, - const char *msg, int encoding, const char *sqlstate) +set_error_msg_and_state(PyObject *type, const char *msg, int encoding, + const char *sqlstate) { PyObject *err_obj, *msg_obj, *sql_obj = NULL; if (encoding == -1) /* unknown */ msg_obj = PyUnicode_DecodeLocale(msg, NULL); else - msg_obj = get_decoded_string(msg, (Py_ssize_t) strlen(msg), encoding); + msg_obj = get_decoded_string(msg, (Py_ssize_t)strlen(msg), encoding); if (!msg_obj) /* cannot decode */ msg_obj = PyBytes_FromString(msg); @@ -1070,7 +1180,8 @@ set_error_msg_and_state(PyObject *type, sql_obj = PyUnicode_FromStringAndSize(sqlstate, 5); } else { - Py_INCREF(Py_None); sql_obj = Py_None; + Py_INCREF(Py_None); + sql_obj = Py_None; } err_obj = PyObject_CallFunctionObjArgs(type, msg_obj, NULL); @@ -1095,7 +1206,7 @@ set_error_msg(PyObject *type, const char *msg) /* Set database error from connection and/or result. */ static void -set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) +set_error(PyObject *type, const char *msg, PGconn *cnx, PGresult *result) { char *sqlstate = NULL; int encoding = pg_encoding_ascii; @@ -1109,7 +1220,8 @@ set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) } if (result) { sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); - if (sqlstate) type = get_error_type(sqlstate); + if (sqlstate) + type = get_error_type(sqlstate); } set_error_msg_and_state(type, msg, encoding, sqlstate); @@ -1117,9 +1229,10 @@ set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) /* Get SSL attributes and values as a dictionary. */ static PyObject * -get_ssl_attributes(PGconn *cnx) { +get_ssl_attributes(PGconn *cnx) +{ PyObject *attr_dict = NULL; - const char * const *s; + const char *const *s; if (!(attr_dict = PyDict_New())) { return NULL; @@ -1129,7 +1242,7 @@ get_ssl_attributes(PGconn *cnx) { const char *val = PQsslAttribute(cnx, *s); if (val) { - PyObject * val_obj = PyUnicode_FromString(val); + PyObject *val_obj = PyUnicode_FromString(val); PyDict_SetItemString(attr_dict, *s, val_obj); Py_DECREF(val_obj); @@ -1153,10 +1266,10 @@ format_result(const PGresult *res) const int n = PQnfields(res); if (n > 0) { - char * const aligns = (char *) PyMem_Malloc( - (unsigned int) n * sizeof(char)); - size_t * const sizes = (size_t *) PyMem_Malloc( - (unsigned int) n * sizeof(size_t)); + char *const aligns = + (char *)PyMem_Malloc((unsigned int)n * sizeof(char)); + size_t *const sizes = + (size_t *)PyMem_Malloc((unsigned int)n * sizeof(size_t)); if (aligns && sizes) { const int m = PQntuples(res); @@ -1166,7 +1279,7 @@ format_result(const PGresult *res) /* calculate sizes and alignments */ for (j = 0; j < n; ++j) { - const char * const s = PQfname(res, j); + const char *const s = PQfname(res, j); const int format = PQfformat(res, j); sizes[j] = s ? strlen(s) : 0; @@ -1202,9 +1315,9 @@ format_result(const PGresult *res) if (aligns[j]) { const int k = PQgetlength(res, i, j); - if (sizes[j] < (size_t) k) + if (sizes[j] < (size_t)k) /* value must fit */ - sizes[j] = (size_t) k; + sizes[j] = (size_t)k; } } } @@ -1212,23 +1325,23 @@ format_result(const PGresult *res) /* size of one row */ for (j = 0; j < n; ++j) size += sizes[j] + 1; /* times number of rows incl. heading */ - size *= (size_t) m + 2; + size *= (size_t)m + 2; /* plus size of footer */ size += 40; /* is the buffer size that needs to be allocated */ - buffer = (char *) PyMem_Malloc(size); + buffer = (char *)PyMem_Malloc(size); if (buffer) { char *p = buffer; PyObject *result; /* create the header */ for (j = 0; j < n; ++j) { - const char * const s = PQfname(res, j); + const char *const s = PQfname(res, j); const size_t k = sizes[j]; - const size_t h = (k - (size_t) strlen(s)) / 2; + const size_t h = (k - (size_t)strlen(s)) / 2; - sprintf(p, "%*s", (int) h, ""); - sprintf(p + h, "%-*s", (int) (k - h), s); + sprintf(p, "%*s", (int)h, ""); + sprintf(p + h, "%-*s", (int)(k - h), s); p += k; if (j + 1 < n) *p++ = '|'; @@ -1237,8 +1350,7 @@ format_result(const PGresult *res) for (j = 0; j < n; ++j) { size_t k = sizes[j]; - while (k--) - *p++ = '-'; + while (k--) *p++ = '-'; if (j + 1 < n) *p++ = '+'; } @@ -1250,11 +1362,11 @@ format_result(const PGresult *res) const size_t k = sizes[j]; if (align) { - sprintf(p, align == 'r' ? "%*s" : "%-*s", (int) k, + sprintf(p, align == 'r' ? "%*s" : "%-*s", (int)k, PQgetvalue(res, i, j)); } else { - sprintf(p, "%-*s", (int) k, + sprintf(p, "%-*s", (int)k, PQgetisnull(res, i, j) ? "" : ""); } p += k; @@ -1264,7 +1376,8 @@ format_result(const PGresult *res) *p++ = '\n'; } /* free memory */ - PyMem_Free(aligns); PyMem_Free(sizes); + PyMem_Free(aligns); + PyMem_Free(sizes); /* create the footer */ sprintf(p, "(%d row%s)", m, m == 1 ? "" : "s"); /* return the result */ @@ -1273,11 +1386,15 @@ format_result(const PGresult *res) return result; } else { - PyMem_Free(aligns); PyMem_Free(sizes); return PyErr_NoMemory(); + PyMem_Free(aligns); + PyMem_Free(sizes); + return PyErr_NoMemory(); } } else { - PyMem_Free(aligns); PyMem_Free(sizes); return PyErr_NoMemory(); + PyMem_Free(aligns); + PyMem_Free(sizes); + return PyErr_NoMemory(); } } else @@ -1288,28 +1405,31 @@ format_result(const PGresult *res) static const char * date_style_to_format(const char *s) { - static const char *formats[] = - { - "%Y-%m-%d", /* 0 = ISO */ - "%m-%d-%Y", /* 1 = Postgres, MDY */ - "%d-%m-%Y", /* 2 = Postgres, DMY */ - "%m/%d/%Y", /* 3 = SQL, MDY */ - "%d/%m/%Y", /* 4 = SQL, DMY */ - "%d.%m.%Y" /* 5 = German */ + static const char *formats[] = { + "%Y-%m-%d", /* 0 = ISO */ + "%m-%d-%Y", /* 1 = Postgres, MDY */ + "%d-%m-%Y", /* 2 = Postgres, DMY */ + "%m/%d/%Y", /* 3 = SQL, MDY */ + "%d/%m/%Y", /* 4 = SQL, DMY */ + "%d.%m.%Y" /* 5 = German */ }; switch (s ? *s : 'I') { case 'P': /* Postgres */ s = strchr(s + 1, ','); - if (s) do ++s; while (*s && *s == ' '); + if (s) + do ++s; + while (*s && *s == ' '); return formats[s && *s == 'D' ? 2 : 1]; case 'S': /* SQL */ s = strchr(s + 1, ','); - if (s) do ++s; while (*s && *s == ' '); + if (s) + do ++s; + while (*s && *s == ' '); return formats[s && *s == 'D' ? 4 : 3]; case 'G': /* German */ return formats[5]; - default: /* ISO */ + default: /* ISO */ return formats[0]; /* ISO is the default */ } } @@ -1318,14 +1438,13 @@ date_style_to_format(const char *s) static const char * date_format_to_style(const char *s) { - static const char *datestyle[] = - { - "ISO, YMD", /* 0 = %Y-%m-%d */ - "Postgres, MDY", /* 1 = %m-%d-%Y */ - "Postgres, DMY", /* 2 = %d-%m-%Y */ - "SQL, MDY", /* 3 = %m/%d/%Y */ - "SQL, DMY", /* 4 = %d/%m/%Y */ - "German, DMY" /* 5 = %d.%m.%Y */ + static const char *datestyle[] = { + "ISO, YMD", /* 0 = %Y-%m-%d */ + "Postgres, MDY", /* 1 = %m-%d-%Y */ + "Postgres, DMY", /* 2 = %d-%m-%Y */ + "SQL, MDY", /* 3 = %m/%d/%Y */ + "SQL, DMY", /* 4 = %d/%m/%Y */ + "German, DMY" /* 5 = %d.%m.%Y */ }; switch (s ? s[1] : 'Y') { @@ -1355,7 +1474,7 @@ static void notice_receiver(void *arg, const PGresult *res) { PyGILState_STATE gstate = PyGILState_Ensure(); - connObject *self = (connObject*) arg; + connObject *self = (connObject *)arg; PyObject *func = self->notice_receiver; if (func) { @@ -1367,7 +1486,7 @@ notice_receiver(void *arg, const PGresult *res) } else { Py_INCREF(Py_None); - notice = (noticeObject *)(void *) Py_None; + notice = (noticeObject *)(void *)Py_None; } ret = PyObject_CallFunction(func, "(O)", notice); Py_XDECREF(ret); diff --git a/pglarge.c b/pglarge.c index 863e2ec9..77455361 100644 --- a/pglarge.c +++ b/pglarge.c @@ -28,9 +28,10 @@ static PyObject * large_str(largeObject *self) { char str[80]; - sprintf(str, self->lo_fd >= 0 ? - "Opened large object, oid %ld" : - "Closed large object, oid %ld", (long) self->lo_oid); + sprintf(str, + self->lo_fd >= 0 ? "Opened large object, oid %ld" + : "Closed large object, oid %ld", + (long)self->lo_oid); return PyUnicode_FromString(str); } @@ -75,7 +76,7 @@ large_getattr(largeObject *self, PyObject *nameobj) if (!strcmp(name, "pgcnx")) { if (_check_lo_obj(self, 0)) { Py_INCREF(self->pgcnx); - return (PyObject *) (self->pgcnx); + return (PyObject *)(self->pgcnx); } PyErr_Clear(); Py_INCREF(Py_None); @@ -85,7 +86,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* large object oid */ if (!strcmp(name, "oid")) { if (_check_lo_obj(self, 0)) - return PyLong_FromLong((long) self->lo_oid); + return PyLong_FromLong((long)self->lo_oid); PyErr_Clear(); Py_INCREF(Py_None); return Py_None; @@ -96,7 +97,7 @@ large_getattr(largeObject *self, PyObject *nameobj) return PyUnicode_FromString(PQerrorMessage(self->pgcnx->cnx)); /* seeks name in methods (fallback) */ - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Get the list of large object attributes. */ @@ -105,17 +106,16 @@ large_dir(largeObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sss]", "oid", "pgcnx", "error"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sss]", "oid", "pgcnx", "error"); return attrs; } /* Open large object. */ static char large_open__doc__[] = -"open(mode) -- open access to large object with specified mode\n\n" -"The mode must be one of INV_READ, INV_WRITE (module level constants).\n"; + "open(mode) -- open access to large object with specified mode\n\n" + "The mode must be one of INV_READ, INV_WRITE (module level constants).\n"; static PyObject * large_open(largeObject *self, PyObject *args) @@ -148,7 +148,7 @@ large_open(largeObject *self, PyObject *args) /* Close large object. */ static char large_close__doc__[] = -"close() -- close access to large object data"; + "close() -- close access to large object data"; static PyObject * large_close(largeObject *self, PyObject *noargs) @@ -172,8 +172,8 @@ large_close(largeObject *self, PyObject *noargs) /* Read from large object. */ static char large_read__doc__[] = -"read(size) -- read from large object to sized string\n\n" -"Object must be opened in read mode before calling this method.\n"; + "read(size) -- read from large object to sized string\n\n" + "Object must be opened in read mode before calling this method.\n"; static PyObject * large_read(largeObject *self, PyObject *args) @@ -200,11 +200,11 @@ large_read(largeObject *self, PyObject *args) } /* allocate buffer and runs read */ - buffer = PyBytes_FromStringAndSize((char *) NULL, size); + buffer = PyBytes_FromStringAndSize((char *)NULL, size); if ((size = lo_read(self->pgcnx->cnx, self->lo_fd, - PyBytes_AS_STRING((PyBytesObject *) (buffer)), (size_t) size)) == -1) - { + PyBytes_AS_STRING((PyBytesObject *)(buffer)), + (size_t)size)) == -1) { PyErr_SetString(PyExc_IOError, "Error while reading"); Py_XDECREF(buffer); return NULL; @@ -217,8 +217,8 @@ large_read(largeObject *self, PyObject *args) /* Write to large object. */ static char large_write__doc__[] = -"write(string) -- write sized string to large object\n\n" -"Object must be opened in read mode before calling this method.\n"; + "write(string) -- write sized string to large object\n\n" + "Object must be opened in read mode before calling this method.\n"; static PyObject * large_write(largeObject *self, PyObject *args) @@ -241,8 +241,7 @@ large_write(largeObject *self, PyObject *args) /* sends query */ if ((size = lo_write(self->pgcnx->cnx, self->lo_fd, buffer, - (size_t) bufsize)) != bufsize) - { + (size_t)bufsize)) != bufsize) { PyErr_SetString(PyExc_IOError, "Buffer truncated during write"); return NULL; } @@ -254,9 +253,9 @@ large_write(largeObject *self, PyObject *args) /* Go to position in large object. */ static char large_seek__doc__[] = -"seek(offset, whence) -- move to specified position\n\n" -"Object must be opened before calling this method. The whence option\n" -"can be SEEK_SET, SEEK_CUR or SEEK_END (module level constants).\n"; + "seek(offset, whence) -- move to specified position\n\n" + "Object must be opened before calling this method. The whence option\n" + "can be SEEK_SET, SEEK_CUR or SEEK_END (module level constants).\n"; static PyObject * large_seek(largeObject *self, PyObject *args) @@ -277,9 +276,8 @@ large_seek(largeObject *self, PyObject *args) } /* sends query */ - if ((ret = lo_lseek( - self->pgcnx->cnx, self->lo_fd, offset, whence)) == -1) - { + if ((ret = lo_lseek(self->pgcnx->cnx, self->lo_fd, offset, whence)) == + -1) { PyErr_SetString(PyExc_IOError, "Error while moving cursor"); return NULL; } @@ -290,8 +288,8 @@ large_seek(largeObject *self, PyObject *args) /* Get large object size. */ static char large_size__doc__[] = -"size() -- return large object size\n\n" -"The object must be opened before calling this method.\n"; + "size() -- return large object size\n\n" + "The object must be opened before calling this method.\n"; static PyObject * large_size(largeObject *self, PyObject *noargs) @@ -316,9 +314,8 @@ large_size(largeObject *self, PyObject *noargs) } /* move back to start position */ - if ((start = lo_lseek( - self->pgcnx->cnx, self->lo_fd, start, SEEK_SET)) == -1) - { + if ((start = lo_lseek(self->pgcnx->cnx, self->lo_fd, start, SEEK_SET)) == + -1) { PyErr_SetString(PyExc_IOError, "Error while moving back to first position"); return NULL; @@ -330,8 +327,8 @@ large_size(largeObject *self, PyObject *noargs) /* Get large object cursor position. */ static char large_tell__doc__[] = -"tell() -- give current position in large object\n\n" -"The object must be opened before calling this method.\n"; + "tell() -- give current position in large object\n\n" + "The object must be opened before calling this method.\n"; static PyObject * large_tell(largeObject *self, PyObject *noargs) @@ -355,8 +352,8 @@ large_tell(largeObject *self, PyObject *noargs) /* Export large object as unix file. */ static char large_export__doc__[] = -"export(filename) -- export large object data to specified file\n\n" -"The object must be closed when calling this method.\n"; + "export(filename) -- export large object data to specified file\n\n" + "The object must be closed when calling this method.\n"; static PyObject * large_export(largeObject *self, PyObject *args) @@ -387,8 +384,8 @@ large_export(largeObject *self, PyObject *args) /* Delete a large object. */ static char large_unlink__doc__[] = -"unlink() -- destroy large object\n\n" -"The object must be closed when calling this method.\n"; + "unlink() -- destroy large object\n\n" + "The object must be closed when calling this method.\n"; static PyObject * large_unlink(largeObject *self, PyObject *noargs) @@ -411,51 +408,49 @@ large_unlink(largeObject *self, PyObject *noargs) /* Large object methods */ static struct PyMethodDef large_methods[] = { - {"__dir__", (PyCFunction) large_dir, METH_NOARGS, NULL}, - {"open", (PyCFunction) large_open, METH_VARARGS, large_open__doc__}, - {"close", (PyCFunction) large_close, METH_NOARGS, large_close__doc__}, - {"read", (PyCFunction) large_read, METH_VARARGS, large_read__doc__}, - {"write", (PyCFunction) large_write, METH_VARARGS, large_write__doc__}, - {"seek", (PyCFunction) large_seek, METH_VARARGS, large_seek__doc__}, - {"size", (PyCFunction) large_size, METH_NOARGS, large_size__doc__}, - {"tell", (PyCFunction) large_tell, METH_NOARGS, large_tell__doc__}, - {"export",(PyCFunction) large_export, METH_VARARGS, large_export__doc__}, - {"unlink",(PyCFunction) large_unlink, METH_NOARGS, large_unlink__doc__}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)large_dir, METH_NOARGS, NULL}, + {"open", (PyCFunction)large_open, METH_VARARGS, large_open__doc__}, + {"close", (PyCFunction)large_close, METH_NOARGS, large_close__doc__}, + {"read", (PyCFunction)large_read, METH_VARARGS, large_read__doc__}, + {"write", (PyCFunction)large_write, METH_VARARGS, large_write__doc__}, + {"seek", (PyCFunction)large_seek, METH_VARARGS, large_seek__doc__}, + {"size", (PyCFunction)large_size, METH_NOARGS, large_size__doc__}, + {"tell", (PyCFunction)large_tell, METH_NOARGS, large_tell__doc__}, + {"export", (PyCFunction)large_export, METH_VARARGS, large_export__doc__}, + {"unlink", (PyCFunction)large_unlink, METH_NOARGS, large_unlink__doc__}, + {NULL, NULL}}; static char large__doc__[] = "PostgreSQL large object"; /* Large object type definition */ static PyTypeObject largeType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.LargeObject", /* tp_name */ - sizeof(largeObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.LargeObject", /* tp_name */ + sizeof(largeObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) large_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) large_str, /* tp_str */ - (getattrofunc) large_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - large__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - large_methods, /* tp_methods */ + (destructor)large_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)large_str, /* tp_str */ + (getattrofunc)large_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + large__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + large_methods, /* tp_methods */ }; diff --git a/pgmodule.c b/pgmodule.c index f1335263..628de9ec 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -12,7 +12,6 @@ #define PY_SSIZE_T_CLEAN #include - #include #include @@ -20,9 +19,9 @@ #include "pgtypes.h" static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, - *InternalError, *OperationalError, *ProgrammingError, - *IntegrityError, *DataError, *NotSupportedError, - *InvalidResultError, *NoResultError, *MultipleResultsError; + *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, + *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, + *MultipleResultsError; #define _TOSTRING(x) #x #define TOSTRING(x) _TOSTRING(x) @@ -36,23 +35,23 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #define PG_ARRAYSIZE 1 /* Flags for object validity checks */ -#define CHECK_OPEN 1 -#define CHECK_CLOSE 2 -#define CHECK_CNX 4 +#define CHECK_OPEN 1 +#define CHECK_CLOSE 2 +#define CHECK_CNX 4 #define CHECK_RESULT 8 -#define CHECK_DQL 16 +#define CHECK_DQL 16 /* Query result types */ #define RESULT_EMPTY 1 -#define RESULT_DML 2 -#define RESULT_DDL 3 -#define RESULT_DQL 4 +#define RESULT_DML 2 +#define RESULT_DDL 3 +#define RESULT_DQL 4 /* Flags for move methods */ #define QUERY_MOVEFIRST 1 -#define QUERY_MOVELAST 2 -#define QUERY_MOVENEXT 3 -#define QUERY_MOVEPREV 4 +#define QUERY_MOVELAST 2 +#define QUERY_MOVENEXT 3 +#define QUERY_MOVEPREV 4 #define MAX_BUFFER_SIZE 65536 /* maximum transaction size */ #define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ @@ -67,16 +66,17 @@ static PyObject *pg_default_user; /* default username */ static PyObject *pg_default_passwd; /* default password */ static PyObject *decimal = NULL, /* decimal type */ - *dictiter = NULL, /* function for getting dict results */ - *namediter = NULL, /* function for getting named results */ - *namednext = NULL, /* function for getting one named result */ + *dictiter = NULL, /* function for getting dict results */ + *namediter = NULL, /* function for getting named results */ + *namednext = NULL, /* function for getting one named result */ *scalariter = NULL, /* function for getting scalar results */ - *jsondecode = NULL; /* function for decoding json strings */ + *jsondecode = + NULL; /* function for decoding json strings */ static const char *date_format = NULL; /* date format that is always assumed */ -static char decimal_point = '.'; /* decimal point used in money values */ -static int bool_as_text = 0; /* whether bool shall be returned as text */ -static int array_as_text = 0; /* whether arrays shall be returned as text */ -static int bytea_escaped = 0; /* whether bytea shall be returned escaped */ +static char decimal_point = '.'; /* decimal point used in money values */ +static int bool_as_text = 0; /* whether bool shall be returned as text */ +static int array_as_text = 0; /* whether arrays shall be returned as text */ +static int bytea_escaped = 0; /* whether bytea shall be returned escaped */ static int pg_encoding_utf8 = 0; static int pg_encoding_latin1 = 0; @@ -106,65 +106,56 @@ OBJECTS static PyTypeObject connType, sourceType, queryType, noticeType, largeType; /* Forward static declarations */ -static void notice_receiver(void *, const PGresult *); +static void +notice_receiver(void *, const PGresult *); /* Object declarations */ -typedef struct -{ - PyObject_HEAD - int valid; /* validity flag */ - PGconn *cnx; /* Postgres connection handle */ - const char *date_format; /* date format derived from datestyle */ - PyObject *cast_hook; /* external typecast method */ - PyObject *notice_receiver; /* current notice receiver */ -} connObject; +typedef struct { + PyObject_HEAD int valid; /* validity flag */ + PGconn *cnx; /* Postgres connection handle */ + const char *date_format; /* date format derived from datestyle */ + PyObject *cast_hook; /* external typecast method */ + PyObject *notice_receiver; /* current notice receiver */ +} connObject; #define is_connObject(v) (PyType(v) == &connType) -typedef struct -{ - PyObject_HEAD - int valid; /* validity flag */ +typedef struct { + PyObject_HEAD int valid; /* validity flag */ connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int encoding; /* client encoding */ - int result_type; /* result type (DDL/DML/DQL) */ - long arraysize; /* array size for fetch method */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ -} sourceObject; + PGresult *result; /* result content */ + int encoding; /* client encoding */ + int result_type; /* result type (DDL/DML/DQL) */ + long arraysize; /* array size for fetch method */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ +} sourceObject; #define is_sourceObject(v) (PyType(v) == &sourceType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - PGresult const *res; /* an error or warning */ -} noticeObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult const *res; /* an error or warning */ +} noticeObject; #define is_noticeObject(v) (PyType(v) == ¬iceType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int async; /* flag for asynchronous queries */ - int encoding; /* client encoding */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ - int *col_types; /* PyGreSQL column types */ -} queryObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult *result; /* result content */ + int async; /* flag for asynchronous queries */ + int encoding; /* client encoding */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ + int *col_types; /* PyGreSQL column types */ +} queryObject; #define is_queryObject(v) (PyType(v) == &queryType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - Oid lo_oid; /* large object oid */ - int lo_fd; /* large object fd */ -} largeObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + Oid lo_oid; /* large object oid */ + int lo_fd; /* large object fd */ +} largeObject; #define is_largeObject(v) (PyType(v) == &largeType) /* Internal functions */ @@ -189,22 +180,22 @@ typedef struct /* Connect to a database. */ static char pg_connect__doc__[] = -"connect(dbname, host, port, opt, user, passwd, wait) -- connect to a PostgreSQL database\n\n" -"The connection uses the specified parameters (optional, keywords aware).\n"; + "connect(dbname, host, port, opt, user, passwd, wait) -- connect to a " + "PostgreSQL database\n\n" + "The connection uses the specified parameters (optional, keywords " + "aware).\n"; static PyObject * pg_connect(PyObject *self, PyObject *args, PyObject *dict) { - static const char *kwlist[] = - { - "dbname", "host", "port", "opt", "user", "passwd", "nowait", NULL - }; + static const char *kwlist[] = {"dbname", "host", "port", "opt", + "user", "passwd", "nowait", NULL}; char *pghost, *pgopt, *pgdbname, *pguser, *pgpasswd; int pgport = -1, nowait = 0, nkw = 0; char port_buffer[20]; const char *keywords[sizeof(kwlist) / sizeof(*kwlist) + 1], - *values[sizeof(kwlist) / sizeof(*kwlist) + 1]; + *values[sizeof(kwlist) / sizeof(*kwlist) + 1]; connObject *conn_obj; pghost = pgopt = pgdbname = pguser = pgpasswd = NULL; @@ -215,10 +206,9 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) * don't declare kwlist as const char *kwlist[] then it complains when * I try to assign all those constant strings to it. */ - if (!PyArg_ParseTupleAndKeywords( - args, dict, "|zzizzzi", (char**)kwlist, - &pgdbname, &pghost, &pgport, &pgopt, &pguser, &pgpasswd, &nowait)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "|zzizzzi", (char **)kwlist, + &pgdbname, &pghost, &pgport, &pgopt, + &pguser, &pgpasswd, &nowait)) { return NULL; } @@ -227,7 +217,7 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) pghost = PyBytes_AsString(pg_default_host); if ((pgport == -1) && (pg_default_port != Py_None)) - pgport = (int) PyLong_AsLong(pg_default_port); + pgport = (int)PyLong_AsLong(pg_default_port); if ((!pgopt) && (pg_default_opt != Py_None)) pgopt = PyBytes_AsString(pg_default_opt); @@ -252,33 +242,27 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) conn_obj->cast_hook = NULL; conn_obj->notice_receiver = NULL; - if (pghost) - { + if (pghost) { keywords[nkw] = "host"; values[nkw++] = pghost; } - if (pgopt) - { + if (pgopt) { keywords[nkw] = "options"; values[nkw++] = pgopt; } - if (pgdbname) - { + if (pgdbname) { keywords[nkw] = "dbname"; values[nkw++] = pgdbname; } - if (pguser) - { + if (pguser) { keywords[nkw] = "user"; values[nkw++] = pguser; } - if (pgpasswd) - { + if (pgpasswd) { keywords[nkw] = "password"; values[nkw++] = pgpasswd; } - if (pgport != -1) - { + if (pgport != -1) { memset(port_buffer, 0, sizeof(port_buffer)); sprintf(port_buffer, "%d", pgport); @@ -288,8 +272,8 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) keywords[nkw] = values[nkw] = NULL; Py_BEGIN_ALLOW_THREADS - conn_obj->cnx = nowait ? PQconnectStartParams(keywords, values, 1) : - PQconnectdbParams(keywords, values, 1); + conn_obj->cnx = nowait ? PQconnectStartParams(keywords, values, 1) + : PQconnectdbParams(keywords, values, 1); Py_END_ALLOW_THREADS if (PQstatus(conn_obj->cnx) == CONNECTION_BAD) { @@ -298,32 +282,33 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return NULL; } - return (PyObject *) conn_obj; + return (PyObject *)conn_obj; } /* Get version of libpq that is being used */ static char pg_get_pqlib_version__doc__[] = -"get_pqlib_version() -- get the version of libpq that is being used"; + "get_pqlib_version() -- get the version of libpq that is being used"; static PyObject * -pg_get_pqlib_version(PyObject *self, PyObject *noargs) { +pg_get_pqlib_version(PyObject *self, PyObject *noargs) +{ return PyLong_FromLong(PQlibVersion()); } /* Escape string */ static char pg_escape_string__doc__[] = -"escape_string(string) -- escape a string for use within SQL"; + "escape_string(string) -- escape a string for use within SQL"; static PyObject * pg_escape_string(PyObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -331,7 +316,8 @@ pg_escape_string(PyObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = pg_encoding_ascii; tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -340,38 +326,39 @@ pg_escape_string(PyObject *self, PyObject *string) return NULL; } - to_length = 2 * (size_t) from_length + 1; - if ((Py_ssize_t ) to_length < from_length) { /* overflow */ - to_length = (size_t) from_length; - from_length = (from_length - 1)/2; + to_length = 2 * (size_t)from_length + 1; + if ((Py_ssize_t)to_length < from_length) { /* overflow */ + to_length = (size_t)from_length; + from_length = (from_length - 1) / 2; } - to = (char *) PyMem_Malloc(to_length); - to_length = (size_t) PQescapeString(to, from, (size_t) from_length); + to = (char *)PyMem_Malloc(to_length); + to_length = (size_t)PQescapeString(to, from, (size_t)from_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); PyMem_Free(to); return to_obj; } /* Escape bytea */ static char pg_escape_bytea__doc__[] = -"escape_bytea(data) -- escape binary data for use within SQL as type bytea"; + "escape_bytea(data) -- escape binary data for use within SQL as type " + "bytea"; static PyObject * pg_escape_bytea(PyObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); @@ -379,7 +366,8 @@ pg_escape_bytea(PyObject *self, PyObject *data) else if (PyUnicode_Check(data)) { encoding = pg_encoding_ascii; tmp_obj = get_encoded_string(data, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -388,15 +376,15 @@ pg_escape_bytea(PyObject *self, PyObject *data) return NULL; } - to = (char *) PQescapeBytea( - (unsigned char*) from, (size_t) from_length, &to_length); + to = (char *)PQescapeBytea((unsigned char *)from, (size_t)from_length, + &to_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length - 1); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length - 1); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length - 1, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length - 1, encoding); if (to) PQfreemem(to); return to_obj; @@ -404,24 +392,25 @@ pg_escape_bytea(PyObject *self, PyObject *data) /* Unescape bytea */ static char pg_unescape_bytea__doc__[] = -"unescape_bytea(string) -- unescape bytea data retrieved as text"; + "unescape_bytea(string) -- unescape bytea data retrieved as text"; static PyObject * pg_unescape_bytea(PyObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); } else if (PyUnicode_Check(data)) { tmp_obj = get_encoded_string(data, pg_encoding_ascii); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -431,13 +420,14 @@ pg_unescape_bytea(PyObject *self, PyObject *data) return NULL; } - to = (char *) PQunescapeBytea((unsigned char*) from, &to_length); + to = (char *)PQunescapeBytea((unsigned char *)from, &to_length); Py_XDECREF(tmp_obj); - if (!to) return PyErr_NoMemory(); + if (!to) + return PyErr_NoMemory(); - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); PQfreemem(to); return to_obj; @@ -445,7 +435,7 @@ pg_unescape_bytea(PyObject *self, PyObject *data) /* Set fixed datestyle. */ static char pg_set_datestyle__doc__[] = -"set_datestyle(style) -- set which style is assumed"; + "set_datestyle(style) -- set which style is assumed"; static PyObject * pg_set_datestyle(PyObject *self, PyObject *args) @@ -462,12 +452,13 @@ pg_set_datestyle(PyObject *self, PyObject *args) date_format = datestyle ? date_style_to_format(datestyle) : NULL; - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } /* Get fixed datestyle. */ static char pg_get_datestyle__doc__[] = -"get_datestyle() -- get which date style is assumed"; + "get_datestyle() -- get which date style is assumed"; static PyObject * pg_get_datestyle(PyObject *self, PyObject *noargs) @@ -476,13 +467,14 @@ pg_get_datestyle(PyObject *self, PyObject *noargs) return PyUnicode_FromString(date_format_to_style(date_format)); } else { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } } /* Get decimal point. */ static char pg_get_decimal_point__doc__[] = -"get_decimal_point() -- get decimal point to be used for money values"; + "get_decimal_point() -- get decimal point to be used for money values"; static PyObject * pg_get_decimal_point(PyObject *self, PyObject *noargs) @@ -491,11 +483,13 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) char s[2]; if (decimal_point) { - s[0] = decimal_point; s[1] = '\0'; + s[0] = decimal_point; + s[1] = '\0'; ret = PyUnicode_FromString(s); } else { - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } return ret; @@ -503,7 +497,7 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) /* Set decimal point. */ static char pg_set_decimal_point__doc__[] = -"set_decimal_point(char) -- set decimal point to be used for money values"; + "set_decimal_point(char) -- set decimal point to be used for money values"; static PyObject * pg_set_decimal_point(PyObject *self, PyObject *args) @@ -515,13 +509,14 @@ pg_set_decimal_point(PyObject *self, PyObject *args) if (PyArg_ParseTuple(args, "z", &s)) { if (!s) s = "\0"; - else if (*s && (*(s+1) || !strchr(".,;: '*/_`|", *s))) + else if (*s && (*(s + 1) || !strchr(".,;: '*/_`|", *s))) s = NULL; } if (s) { decimal_point = *s; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -533,7 +528,7 @@ pg_set_decimal_point(PyObject *self, PyObject *args) /* Get decimal type. */ static char pg_get_decimal__doc__[] = -"get_decimal() -- get the decimal type to be used for numeric values"; + "get_decimal() -- get the decimal type to be used for numeric values"; static PyObject * pg_get_decimal(PyObject *self, PyObject *noargs) @@ -548,7 +543,7 @@ pg_get_decimal(PyObject *self, PyObject *noargs) /* Set decimal type. */ static char pg_set_decimal__doc__[] = -"set_decimal(cls) -- set a decimal type to be used for numeric values"; + "set_decimal(cls) -- set a decimal type to be used for numeric values"; static PyObject * pg_set_decimal(PyObject *self, PyObject *cls) @@ -556,12 +551,17 @@ pg_set_decimal(PyObject *self, PyObject *cls) PyObject *ret = NULL; if (cls == Py_None) { - Py_XDECREF(decimal); decimal = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_XDECREF(decimal); + decimal = NULL; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(cls)) { - Py_XINCREF(cls); Py_XDECREF(decimal); decimal = cls; - Py_INCREF(Py_None); ret = Py_None; + Py_XINCREF(cls); + Py_XDECREF(decimal); + decimal = cls; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -574,7 +574,7 @@ pg_set_decimal(PyObject *self, PyObject *cls) /* Get usage of bool values. */ static char pg_get_bool__doc__[] = -"get_bool() -- check whether boolean values are converted to bool"; + "get_bool() -- check whether boolean values are converted to bool"; static PyObject * pg_get_bool(PyObject *self, PyObject *noargs) @@ -589,7 +589,7 @@ pg_get_bool(PyObject *self, PyObject *noargs) /* Set usage of bool values. */ static char pg_set_bool__doc__[] = -"set_bool(on) -- set whether boolean values should be converted to bool"; + "set_bool(on) -- set whether boolean values should be converted to bool"; static PyObject * pg_set_bool(PyObject *self, PyObject *args) @@ -600,7 +600,8 @@ pg_set_bool(PyObject *self, PyObject *args) /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { bool_as_text = i ? 0 : 1; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString( @@ -613,7 +614,7 @@ pg_set_bool(PyObject *self, PyObject *args) /* Get conversion of arrays to lists. */ static char pg_get_array__doc__[] = -"get_array() -- check whether arrays are converted as lists"; + "get_array() -- check whether arrays are converted as lists"; static PyObject * pg_get_array(PyObject *self, PyObject *noargs) @@ -628,18 +629,19 @@ pg_get_array(PyObject *self, PyObject *noargs) /* Set conversion of arrays to lists. */ static char pg_set_array__doc__[] = -"set_array(on) -- set whether arrays should be converted to lists"; + "set_array(on) -- set whether arrays should be converted to lists"; static PyObject * -pg_set_array(PyObject* self, PyObject* args) +pg_set_array(PyObject *self, PyObject *args) { - PyObject* ret = NULL; + PyObject *ret = NULL; int i; /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { array_as_text = i ? 0 : 1; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString( @@ -652,7 +654,7 @@ pg_set_array(PyObject* self, PyObject* args) /* Check whether bytea values are unescaped. */ static char pg_get_bytea_escaped__doc__[] = -"get_bytea_escaped() -- check whether bytea will be returned escaped"; + "get_bytea_escaped() -- check whether bytea will be returned escaped"; static PyObject * pg_get_bytea_escaped(PyObject *self, PyObject *noargs) @@ -667,7 +669,7 @@ pg_get_bytea_escaped(PyObject *self, PyObject *noargs) /* Set usage of bool values. */ static char pg_set_bytea_escaped__doc__[] = -"set_bytea_escaped(on) -- set whether bytea will be returned escaped"; + "set_bytea_escaped(on) -- set whether bytea will be returned escaped"; static PyObject * pg_set_bytea_escaped(PyObject *self, PyObject *args) @@ -678,7 +680,8 @@ pg_set_bytea_escaped(PyObject *self, PyObject *args) /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { bytea_escaped = i ? 1 : 0; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -692,18 +695,15 @@ pg_set_bytea_escaped(PyObject *self, PyObject *args) /* set query helper functions (not part of public API) */ static char pg_set_query_helpers__doc__[] = -"set_query_helpers(*helpers) -- set internal query helper functions"; + "set_query_helpers(*helpers) -- set internal query helper functions"; static PyObject * pg_set_query_helpers(PyObject *self, PyObject *args) { /* gets arguments */ - if (!PyArg_ParseTuple(args, "O!O!O!O!", - &PyFunction_Type, &dictiter, - &PyFunction_Type, &namediter, - &PyFunction_Type, &namednext, - &PyFunction_Type, &scalariter)) - { + if (!PyArg_ParseTuple(args, "O!O!O!O!", &PyFunction_Type, &dictiter, + &PyFunction_Type, &namediter, &PyFunction_Type, + &namednext, &PyFunction_Type, &scalariter)) { return NULL; } @@ -713,7 +713,7 @@ pg_set_query_helpers(PyObject *self, PyObject *args) /* Get json decode function. */ static char pg_get_jsondecode__doc__[] = -"get_jsondecode() -- get the function used for decoding json results"; + "get_jsondecode() -- get the function used for decoding json results"; static PyObject * pg_get_jsondecode(PyObject *self, PyObject *noargs) @@ -730,7 +730,8 @@ pg_get_jsondecode(PyObject *self, PyObject *noargs) /* Set json decode function. */ static char pg_set_jsondecode__doc__[] = -"set_jsondecode(func) -- set a function to be used for decoding json results"; + "set_jsondecode(func) -- set a function to be used for decoding json " + "results"; static PyObject * pg_set_jsondecode(PyObject *self, PyObject *func) @@ -738,12 +739,17 @@ pg_set_jsondecode(PyObject *self, PyObject *func) PyObject *ret = NULL; if (func == Py_None) { - Py_XDECREF(jsondecode); jsondecode = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_XDECREF(jsondecode); + jsondecode = NULL; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(jsondecode); jsondecode = func; - Py_INCREF(Py_None); ret = Py_None; + Py_XINCREF(func); + Py_XDECREF(jsondecode); + jsondecode = func; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -756,7 +762,7 @@ pg_set_jsondecode(PyObject *self, PyObject *func) /* Get default host. */ static char pg_get_defhost__doc__[] = -"get_defhost() -- return default database host"; + "get_defhost() -- return default database host"; static PyObject * pg_get_defhost(PyObject *self, PyObject *noargs) @@ -767,7 +773,8 @@ pg_get_defhost(PyObject *self, PyObject *noargs) /* Set default host. */ static char pg_set_defhost__doc__[] = -"set_defhost(string) -- set default database host and return previous value"; + "set_defhost(string) -- set default database host and return previous " + "value"; static PyObject * pg_set_defhost(PyObject *self, PyObject *args) @@ -799,7 +806,7 @@ pg_set_defhost(PyObject *self, PyObject *args) /* Get default database. */ static char pg_get_defbase__doc__[] = -"get_defbase() -- return default database name"; + "get_defbase() -- return default database name"; static PyObject * pg_get_defbase(PyObject *self, PyObject *noargs) @@ -810,7 +817,8 @@ pg_get_defbase(PyObject *self, PyObject *noargs) /* Set default database. */ static char pg_set_defbase__doc__[] = -"set_defbase(string) -- set default database name and return previous value"; + "set_defbase(string) -- set default database name and return previous " + "value"; static PyObject * pg_set_defbase(PyObject *self, PyObject *args) @@ -842,7 +850,7 @@ pg_set_defbase(PyObject *self, PyObject *args) /* Get default options. */ static char pg_get_defopt__doc__[] = -"get_defopt() -- return default database options"; + "get_defopt() -- return default database options"; static PyObject * pg_get_defopt(PyObject *self, PyObject *noargs) @@ -853,7 +861,7 @@ pg_get_defopt(PyObject *self, PyObject *noargs) /* Set default options. */ static char pg_set_defopt__doc__[] = -"set_defopt(string) -- set default options and return previous value"; + "set_defopt(string) -- set default options and return previous value"; static PyObject * pg_setdefopt(PyObject *self, PyObject *args) @@ -885,7 +893,7 @@ pg_setdefopt(PyObject *self, PyObject *args) /* Get default username. */ static char pg_get_defuser__doc__[] = -"get_defuser() -- return default database username"; + "get_defuser() -- return default database username"; static PyObject * pg_get_defuser(PyObject *self, PyObject *noargs) @@ -897,7 +905,7 @@ pg_get_defuser(PyObject *self, PyObject *noargs) /* Set default username. */ static char pg_set_defuser__doc__[] = -"set_defuser(name) -- set default username and return previous value"; + "set_defuser(name) -- set default username and return previous value"; static PyObject * pg_set_defuser(PyObject *self, PyObject *args) @@ -929,7 +937,7 @@ pg_set_defuser(PyObject *self, PyObject *args) /* Set default password. */ static char pg_set_defpasswd__doc__[] = -"set_defpasswd(password) -- set default database password"; + "set_defpasswd(password) -- set default database password"; static PyObject * pg_set_defpasswd(PyObject *self, PyObject *args) @@ -958,7 +966,7 @@ pg_set_defpasswd(PyObject *self, PyObject *args) /* Get default port. */ static char pg_get_defport__doc__[] = -"get_defport() -- return default database port"; + "get_defport() -- return default database port"; static PyObject * pg_get_defport(PyObject *self, PyObject *noargs) @@ -969,7 +977,7 @@ pg_get_defport(PyObject *self, PyObject *noargs) /* Set default port. */ static char pg_set_defport__doc__[] = -"set_defport(port) -- set default port and return previous value"; + "set_defport(port) -- set default port and return previous value"; static PyObject * pg_set_defport(PyObject *self, PyObject *args) @@ -1001,7 +1009,7 @@ pg_set_defport(PyObject *self, PyObject *args) /* Cast a string with a text representation of an array to a list. */ static char pg_cast_array__doc__[] = -"cast_array(string, cast=None, delim=',') -- cast a string as an array"; + "cast_array(string, cast=None, delim=',') -- cast a string as an array"; PyObject * pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) @@ -1012,10 +1020,8 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) Py_ssize_t size; int encoding; - if (!PyArg_ParseTupleAndKeywords( - args, dict, "O|Oc", - (char**) kwlist, &string_obj, &cast_obj, &delim)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc", (char **)kwlist, + &string_obj, &cast_obj, &delim)) { return NULL; } @@ -1026,7 +1032,8 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) } else if (PyUnicode_Check(string_obj)) { string_obj = PyUnicode_AsUTF8String(string_obj); - if (!string_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!string_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(string_obj, &string, &size); encoding = pg_encoding_utf8; } @@ -1056,7 +1063,7 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) /* Cast a string with a text representation of a record to a tuple. */ static char pg_cast_record__doc__[] = -"cast_record(string, cast=None, delim=',') -- cast a string as a record"; + "cast_record(string, cast=None, delim=',') -- cast a string as a record"; PyObject * pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) @@ -1067,10 +1074,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) Py_ssize_t size, len; int encoding; - if (!PyArg_ParseTupleAndKeywords( - args, dict, "O|Oc", - (char**) kwlist, &string_obj, &cast_obj, &delim)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc", (char **)kwlist, + &string_obj, &cast_obj, &delim)) { return NULL; } @@ -1081,7 +1086,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) } else if (PyUnicode_Check(string_obj)) { string_obj = PyUnicode_AsUTF8String(string_obj); - if (!string_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!string_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(string_obj, &string, &size); encoding = pg_encoding_utf8; } @@ -1096,7 +1102,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) len = 0; } else if (cast_obj == Py_None) { - cast_obj = NULL; len = 0; + cast_obj = NULL; + len = 0; } else if (PyTuple_Check(cast_obj) || PyList_Check(cast_obj)) { len = PySequence_Size(cast_obj); @@ -1120,7 +1127,7 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) /* Cast a string with a text representation of an hstore to a dict. */ static char pg_cast_hstore__doc__[] = -"cast_hstore(string) -- cast a string as an hstore"; + "cast_hstore(string) -- cast a string as an hstore"; PyObject * pg_cast_hstore(PyObject *self, PyObject *string) @@ -1136,7 +1143,8 @@ pg_cast_hstore(PyObject *self, PyObject *string) } else if (PyUnicode_Check(string)) { tmp_obj = PyUnicode_AsUTF8String(string); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &s, &size); encoding = pg_encoding_utf8; } @@ -1157,50 +1165,47 @@ pg_cast_hstore(PyObject *self, PyObject *string) /* The list of functions defined in the module */ static struct PyMethodDef pg_methods[] = { - {"connect", (PyCFunction) pg_connect, - METH_VARARGS|METH_KEYWORDS, pg_connect__doc__}, - {"escape_string", (PyCFunction) pg_escape_string, - METH_O, pg_escape_string__doc__}, - {"escape_bytea", (PyCFunction) pg_escape_bytea, - METH_O, pg_escape_bytea__doc__}, - {"unescape_bytea", (PyCFunction) pg_unescape_bytea, - METH_O, pg_unescape_bytea__doc__}, - {"get_datestyle", (PyCFunction) pg_get_datestyle, - METH_NOARGS, pg_get_datestyle__doc__}, - {"set_datestyle", (PyCFunction) pg_set_datestyle, - METH_VARARGS, pg_set_datestyle__doc__}, - {"get_decimal_point", (PyCFunction) pg_get_decimal_point, - METH_NOARGS, pg_get_decimal_point__doc__}, - {"set_decimal_point", (PyCFunction) pg_set_decimal_point, - METH_VARARGS, pg_set_decimal_point__doc__}, - {"get_decimal", (PyCFunction) pg_get_decimal, - METH_NOARGS, pg_get_decimal__doc__}, - {"set_decimal", (PyCFunction) pg_set_decimal, - METH_O, pg_set_decimal__doc__}, - {"get_bool", (PyCFunction) pg_get_bool, - METH_NOARGS, pg_get_bool__doc__}, - {"set_bool", (PyCFunction) pg_set_bool, - METH_VARARGS, pg_set_bool__doc__}, - {"get_array", (PyCFunction) pg_get_array, - METH_NOARGS, pg_get_array__doc__}, - {"set_array", (PyCFunction) pg_set_array, - METH_VARARGS, pg_set_array__doc__}, - {"set_query_helpers", (PyCFunction) pg_set_query_helpers, - METH_VARARGS, pg_set_query_helpers__doc__}, - {"get_bytea_escaped", (PyCFunction) pg_get_bytea_escaped, - METH_NOARGS, pg_get_bytea_escaped__doc__}, - {"set_bytea_escaped", (PyCFunction) pg_set_bytea_escaped, - METH_VARARGS, pg_set_bytea_escaped__doc__}, - {"get_jsondecode", (PyCFunction) pg_get_jsondecode, - METH_NOARGS, pg_get_jsondecode__doc__}, - {"set_jsondecode", (PyCFunction) pg_set_jsondecode, - METH_O, pg_set_jsondecode__doc__}, - {"cast_array", (PyCFunction) pg_cast_array, - METH_VARARGS|METH_KEYWORDS, pg_cast_array__doc__}, - {"cast_record", (PyCFunction) pg_cast_record, - METH_VARARGS|METH_KEYWORDS, pg_cast_record__doc__}, - {"cast_hstore", (PyCFunction) pg_cast_hstore, - METH_O, pg_cast_hstore__doc__}, + {"connect", (PyCFunction)pg_connect, METH_VARARGS | METH_KEYWORDS, + pg_connect__doc__}, + {"escape_string", (PyCFunction)pg_escape_string, METH_O, + pg_escape_string__doc__}, + {"escape_bytea", (PyCFunction)pg_escape_bytea, METH_O, + pg_escape_bytea__doc__}, + {"unescape_bytea", (PyCFunction)pg_unescape_bytea, METH_O, + pg_unescape_bytea__doc__}, + {"get_datestyle", (PyCFunction)pg_get_datestyle, METH_NOARGS, + pg_get_datestyle__doc__}, + {"set_datestyle", (PyCFunction)pg_set_datestyle, METH_VARARGS, + pg_set_datestyle__doc__}, + {"get_decimal_point", (PyCFunction)pg_get_decimal_point, METH_NOARGS, + pg_get_decimal_point__doc__}, + {"set_decimal_point", (PyCFunction)pg_set_decimal_point, METH_VARARGS, + pg_set_decimal_point__doc__}, + {"get_decimal", (PyCFunction)pg_get_decimal, METH_NOARGS, + pg_get_decimal__doc__}, + {"set_decimal", (PyCFunction)pg_set_decimal, METH_O, + pg_set_decimal__doc__}, + {"get_bool", (PyCFunction)pg_get_bool, METH_NOARGS, pg_get_bool__doc__}, + {"set_bool", (PyCFunction)pg_set_bool, METH_VARARGS, pg_set_bool__doc__}, + {"get_array", (PyCFunction)pg_get_array, METH_NOARGS, pg_get_array__doc__}, + {"set_array", (PyCFunction)pg_set_array, METH_VARARGS, + pg_set_array__doc__}, + {"set_query_helpers", (PyCFunction)pg_set_query_helpers, METH_VARARGS, + pg_set_query_helpers__doc__}, + {"get_bytea_escaped", (PyCFunction)pg_get_bytea_escaped, METH_NOARGS, + pg_get_bytea_escaped__doc__}, + {"set_bytea_escaped", (PyCFunction)pg_set_bytea_escaped, METH_VARARGS, + pg_set_bytea_escaped__doc__}, + {"get_jsondecode", (PyCFunction)pg_get_jsondecode, METH_NOARGS, + pg_get_jsondecode__doc__}, + {"set_jsondecode", (PyCFunction)pg_set_jsondecode, METH_O, + pg_set_jsondecode__doc__}, + {"cast_array", (PyCFunction)pg_cast_array, METH_VARARGS | METH_KEYWORDS, + pg_cast_array__doc__}, + {"cast_record", (PyCFunction)pg_cast_record, METH_VARARGS | METH_KEYWORDS, + pg_cast_record__doc__}, + {"cast_hstore", (PyCFunction)pg_cast_hstore, METH_O, + pg_cast_hstore__doc__}, {"get_defhost", pg_get_defhost, METH_NOARGS, pg_get_defhost__doc__}, {"set_defhost", pg_set_defhost, METH_VARARGS, pg_set_defhost__doc__}, {"get_defbase", pg_get_defbase, METH_NOARGS, pg_get_defbase__doc__}, @@ -1212,25 +1217,26 @@ static struct PyMethodDef pg_methods[] = { {"get_defuser", pg_get_defuser, METH_NOARGS, pg_get_defuser__doc__}, {"set_defuser", pg_set_defuser, METH_VARARGS, pg_set_defuser__doc__}, {"set_defpasswd", pg_set_defpasswd, METH_VARARGS, pg_set_defpasswd__doc__}, - {"get_pqlib_version", (PyCFunction) pg_get_pqlib_version, - METH_NOARGS, pg_get_pqlib_version__doc__}, + {"get_pqlib_version", (PyCFunction)pg_get_pqlib_version, METH_NOARGS, + pg_get_pqlib_version__doc__}, {NULL, NULL} /* sentinel */ }; static char pg__doc__[] = "Python interface to PostgreSQL DB"; static struct PyModuleDef moduleDef = { - PyModuleDef_HEAD_INIT, - "_pg", /* m_name */ - pg__doc__, /* m_doc */ - -1, /* m_size */ - pg_methods /* m_methods */ + PyModuleDef_HEAD_INIT, "_pg", /* m_name */ + pg__doc__, /* m_doc */ + -1, /* m_size */ + pg_methods /* m_methods */ }; /* Initialization function for the module */ -PyMODINIT_FUNC PyInit__pg(void); +PyMODINIT_FUNC +PyInit__pg(void); -PyMODINIT_FUNC PyInit__pg(void) +PyMODINIT_FUNC +PyInit__pg(void) { PyObject *mod, *dict, *s; @@ -1239,17 +1245,13 @@ PyMODINIT_FUNC PyInit__pg(void) mod = PyModule_Create(&moduleDef); /* Initialize here because some Windows platforms get confused otherwise */ - connType.tp_base = noticeType.tp_base = - queryType.tp_base = sourceType.tp_base = &PyBaseObject_Type; + connType.tp_base = noticeType.tp_base = queryType.tp_base = + sourceType.tp_base = &PyBaseObject_Type; largeType.tp_base = &PyBaseObject_Type; - if (PyType_Ready(&connType) - || PyType_Ready(¬iceType) - || PyType_Ready(&queryType) - || PyType_Ready(&sourceType) - || PyType_Ready(&largeType) - ) - { + if (PyType_Ready(&connType) || PyType_Ready(¬iceType) || + PyType_Ready(&queryType) || PyType_Ready(&sourceType) || + PyType_Ready(&largeType)) { return NULL; } @@ -1262,48 +1264,45 @@ PyMODINIT_FUNC PyInit__pg(void) Warning = PyErr_NewException("pg.Warning", PyExc_Exception, NULL); PyDict_SetItemString(dict, "Warning", Warning); - InterfaceError = PyErr_NewException( - "pg.InterfaceError", Error, NULL); + InterfaceError = PyErr_NewException("pg.InterfaceError", Error, NULL); PyDict_SetItemString(dict, "InterfaceError", InterfaceError); - DatabaseError = PyErr_NewException( - "pg.DatabaseError", Error, NULL); + DatabaseError = PyErr_NewException("pg.DatabaseError", Error, NULL); PyDict_SetItemString(dict, "DatabaseError", DatabaseError); - InternalError = PyErr_NewException( - "pg.InternalError", DatabaseError, NULL); + InternalError = + PyErr_NewException("pg.InternalError", DatabaseError, NULL); PyDict_SetItemString(dict, "InternalError", InternalError); - OperationalError = PyErr_NewException( - "pg.OperationalError", DatabaseError, NULL); + OperationalError = + PyErr_NewException("pg.OperationalError", DatabaseError, NULL); PyDict_SetItemString(dict, "OperationalError", OperationalError); - ProgrammingError = PyErr_NewException( - "pg.ProgrammingError", DatabaseError, NULL); + ProgrammingError = + PyErr_NewException("pg.ProgrammingError", DatabaseError, NULL); PyDict_SetItemString(dict, "ProgrammingError", ProgrammingError); - IntegrityError = PyErr_NewException( - "pg.IntegrityError", DatabaseError, NULL); + IntegrityError = + PyErr_NewException("pg.IntegrityError", DatabaseError, NULL); PyDict_SetItemString(dict, "IntegrityError", IntegrityError); - DataError = PyErr_NewException( - "pg.DataError", DatabaseError, NULL); + DataError = PyErr_NewException("pg.DataError", DatabaseError, NULL); PyDict_SetItemString(dict, "DataError", DataError); - NotSupportedError = PyErr_NewException( - "pg.NotSupportedError", DatabaseError, NULL); + NotSupportedError = + PyErr_NewException("pg.NotSupportedError", DatabaseError, NULL); PyDict_SetItemString(dict, "NotSupportedError", NotSupportedError); - InvalidResultError = PyErr_NewException( - "pg.InvalidResultError", DataError, NULL); + InvalidResultError = + PyErr_NewException("pg.InvalidResultError", DataError, NULL); PyDict_SetItemString(dict, "InvalidResultError", InvalidResultError); - NoResultError = PyErr_NewException( - "pg.NoResultError", InvalidResultError, NULL); + NoResultError = + PyErr_NewException("pg.NoResultError", InvalidResultError, NULL); PyDict_SetItemString(dict, "NoResultError", NoResultError); - MultipleResultsError = PyErr_NewException( - "pg.MultipleResultsError", InvalidResultError, NULL); + MultipleResultsError = PyErr_NewException("pg.MultipleResultsError", + InvalidResultError, NULL); PyDict_SetItemString(dict, "MultipleResultsError", MultipleResultsError); /* Make the version available */ @@ -1320,16 +1319,24 @@ PyMODINIT_FUNC PyInit__pg(void) /* Transaction states */ PyDict_SetItemString(dict, "TRANS_IDLE", PyLong_FromLong(PQTRANS_IDLE)); - PyDict_SetItemString(dict, "TRANS_ACTIVE", PyLong_FromLong(PQTRANS_ACTIVE)); - PyDict_SetItemString(dict, "TRANS_INTRANS", PyLong_FromLong(PQTRANS_INTRANS)); - PyDict_SetItemString(dict, "TRANS_INERROR", PyLong_FromLong(PQTRANS_INERROR)); - PyDict_SetItemString(dict, "TRANS_UNKNOWN", PyLong_FromLong(PQTRANS_UNKNOWN)); + PyDict_SetItemString(dict, "TRANS_ACTIVE", + PyLong_FromLong(PQTRANS_ACTIVE)); + PyDict_SetItemString(dict, "TRANS_INTRANS", + PyLong_FromLong(PQTRANS_INTRANS)); + PyDict_SetItemString(dict, "TRANS_INERROR", + PyLong_FromLong(PQTRANS_INERROR)); + PyDict_SetItemString(dict, "TRANS_UNKNOWN", + PyLong_FromLong(PQTRANS_UNKNOWN)); /* Polling results */ - PyDict_SetItemString(dict, "POLLING_OK", PyLong_FromLong(PGRES_POLLING_OK)); - PyDict_SetItemString(dict, "POLLING_FAILED", PyLong_FromLong(PGRES_POLLING_FAILED)); - PyDict_SetItemString(dict, "POLLING_READING", PyLong_FromLong(PGRES_POLLING_READING)); - PyDict_SetItemString(dict, "POLLING_WRITING", PyLong_FromLong(PGRES_POLLING_WRITING)); + PyDict_SetItemString(dict, "POLLING_OK", + PyLong_FromLong(PGRES_POLLING_OK)); + PyDict_SetItemString(dict, "POLLING_FAILED", + PyLong_FromLong(PGRES_POLLING_FAILED)); + PyDict_SetItemString(dict, "POLLING_READING", + PyLong_FromLong(PGRES_POLLING_READING)); + PyDict_SetItemString(dict, "POLLING_WRITING", + PyLong_FromLong(PGRES_POLLING_WRITING)); /* Create mode for large objects */ PyDict_SetItemString(dict, "INV_READ", PyLong_FromLong(INV_READ)); diff --git a/pgnotice.c b/pgnotice.c index e079283c..0252a56f 100644 --- a/pgnotice.c +++ b/pgnotice.c @@ -25,7 +25,7 @@ notice_getattr(noticeObject *self, PyObject *nameobj) if (!strcmp(name, "pgcnx")) { if (self->pgcnx && _check_cnx_obj(self->pgcnx)) { Py_INCREF(self->pgcnx); - return (PyObject *) self->pgcnx; + return (PyObject *)self->pgcnx; } else { Py_INCREF(Py_None); @@ -54,11 +54,12 @@ notice_getattr(noticeObject *self, PyObject *nameobj) return PyUnicode_FromString(s); } else { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } } - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Get the list of notice attributes. */ @@ -67,10 +68,9 @@ notice_dir(noticeObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[ssssss]", - "pgcnx", "severity", "message", "primary", "detail", "hint"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[ssssss]", "pgcnx", "severity", + "message", "primary", "detail", "hint"); return attrs; } @@ -84,41 +84,38 @@ notice_str(noticeObject *self) /* Notice object methods */ static struct PyMethodDef notice_methods[] = { - {"__dir__", (PyCFunction) notice_dir, METH_NOARGS, NULL}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)notice_dir, METH_NOARGS, NULL}, {NULL, NULL}}; static char notice__doc__[] = "PostgreSQL notice object"; /* Notice type definition */ static PyTypeObject noticeType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Notice", /* tp_name */ - sizeof(noticeObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.Notice", /* tp_name */ + sizeof(noticeObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - 0, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) notice_str, /* tp_str */ - (getattrofunc) notice_getattr, /* tp_getattro */ - PyObject_GenericSetAttr, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - notice__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - notice_methods, /* tp_methods */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)notice_str, /* tp_str */ + (getattrofunc)notice_getattr, /* tp_getattro */ + PyObject_GenericSetAttr, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + notice__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + notice_methods, /* tp_methods */ }; diff --git a/pgquery.c b/pgquery.c index 194bfaa1..6346497d 100644 --- a/pgquery.c +++ b/pgquery.c @@ -37,7 +37,7 @@ query_len(PyObject *self) PyObject *tmp; Py_ssize_t len; - tmp = PyLong_FromLong(((queryObject*) self)->max_row); + tmp = PyLong_FromLong(((queryObject *)self)->max_row); len = PyLong_AsSsize_t(tmp); Py_DECREF(tmp); return len; @@ -64,18 +64,18 @@ _query_value_in_column(queryObject *self, int column) /* cast the string representation into a Python object */ if (type & PYGRES_ARRAY) return cast_array(s, - PQgetlength(self->result, self->current_row, column), - self->encoding, type, NULL, 0); + PQgetlength(self->result, self->current_row, column), + self->encoding, type, NULL, 0); if (type == PYGRES_BYTEA) return cast_bytea_text(s); if (type == PYGRES_OTHER) return cast_other(s, - PQgetlength(self->result, self->current_row, column), - self->encoding, - PQftype(self->result, column), self->pgcnx->cast_hook); + PQgetlength(self->result, self->current_row, column), + self->encoding, PQftype(self->result, column), + self->pgcnx->cast_hook); if (type & PYGRES_TEXT) - return cast_sized_text(s, - PQgetlength(self->result, self->current_row, column), + return cast_sized_text( + s, PQgetlength(self->result, self->current_row, column), self->encoding, type); return cast_unsized_simple(s, type); } @@ -94,7 +94,8 @@ _query_row_as_tuple(queryObject *self) for (j = 0; j < self->num_fields; ++j) { PyObject *val = _query_value_in_column(self, j); if (!val) { - Py_DECREF(row_tuple); return NULL; + Py_DECREF(row_tuple); + return NULL; } PyTuple_SET_ITEM(row_tuple, j, val); } @@ -108,7 +109,8 @@ _query_row_as_tuple(queryObject *self) If this is a normal query result, the query itself will be returned, otherwise a result value will be returned that shall be passed on. */ static PyObject * -_get_async_result(queryObject *self, int keep) { +_get_async_result(queryObject *self, int keep) +{ int fetch = 0; if (self->async) { @@ -118,7 +120,8 @@ _get_async_result(queryObject *self, int keep) { /* mark query as fetched, do not fetch again */ self->async = 2; } - } else if (!keep) { + } + else if (!keep) { self->async = 1; } } @@ -147,8 +150,8 @@ _get_async_result(queryObject *self, int keep) { } if ((status = PQresultStatus(self->result)) != PGRES_TUPLES_OK) { - PyObject* result = _conn_non_query_result( - status, self->result, self->pgcnx->cnx); + PyObject *result = + _conn_non_query_result(status, self->result, self->pgcnx->cnx); self->result = NULL; /* since this has been already cleared */ if (!result) { /* Raise an error. We need to call PQgetResult() to clear the @@ -181,8 +184,9 @@ _get_async_result(queryObject *self, int keep) { Py_DECREF(self); return NULL; } - } else if (self->async == 2 && - !self->max_row && !self->num_fields && !self->col_types) { + } + else if (self->async == 2 && !self->max_row && !self->num_fields && + !self->col_types) { Py_INCREF(Py_None); return Py_None; } @@ -195,14 +199,14 @@ _get_async_result(queryObject *self, int keep) { static PyObject * query_getitem(PyObject *self, Py_ssize_t i) { - queryObject *q = (queryObject *) self; + queryObject *q = (queryObject *)self; PyObject *tmp; long row; if ((tmp = _get_async_result(q, 0)) != (PyObject *)self) return tmp; - tmp = PyLong_FromSize_t((size_t) i); + tmp = PyLong_FromSize_t((size_t)i); row = PyLong_AsLong(tmp); Py_DECREF(tmp); @@ -211,13 +215,14 @@ query_getitem(PyObject *self, Py_ssize_t i) return NULL; } - q->current_row = (int) row; + q->current_row = (int)row; return _query_row_as_tuple(q); } /* __iter__() method of the queryObject: Returns the default iterator yielding rows as tuples. */ -static PyObject* query_iter(queryObject *self) +static PyObject * +query_iter(queryObject *self) { PyObject *res; @@ -226,7 +231,7 @@ static PyObject* query_iter(queryObject *self) self->current_row = 0; Py_INCREF(self); - return (PyObject*) self; + return (PyObject *)self; } /* __next__() method of the queryObject: @@ -242,13 +247,14 @@ query_next(queryObject *self, PyObject *noargs) } row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; + if (row_tuple) + ++self->current_row; return row_tuple; } /* Get number of bytes allocated for PGresult object */ static char query_memsize__doc__[] = -"memsize() -- return number of bytes allocated by query result"; + "memsize() -- return number of bytes allocated by query result"; static PyObject * query_memsize(queryObject *self, PyObject *noargs) { @@ -262,7 +268,7 @@ query_memsize(queryObject *self, PyObject *noargs) /* List field names from query result. */ static char query_listfields__doc__[] = -"listfields() -- List field names from result"; + "listfields() -- List field names from result"; static PyObject * query_listfields(queryObject *self, PyObject *noargs) @@ -285,7 +291,7 @@ query_listfields(queryObject *self, PyObject *noargs) /* Get field name from number in last result. */ static char query_fieldname__doc__[] = -"fieldname(num) -- return name of field from result from its position"; + "fieldname(num) -- return name of field from result from its position"; static PyObject * query_fieldname(queryObject *self, PyObject *args) @@ -313,7 +319,7 @@ query_fieldname(queryObject *self, PyObject *args) /* Get field number from name in last result. */ static char query_fieldnum__doc__[] = -"fieldnum(name) -- return position in query for field from its name"; + "fieldnum(name) -- return position in query for field from its name"; static PyObject * query_fieldnum(queryObject *self, PyObject *args) @@ -339,13 +345,15 @@ query_fieldnum(queryObject *self, PyObject *args) /* Build a tuple with info for query field with given number. */ static PyObject * -_query_build_field_info(PGresult *res, int col_num) { +_query_build_field_info(PGresult *res, int col_num) +{ PyObject *info; info = PyTuple_New(4); if (info) { PyTuple_SET_ITEM(info, 0, PyUnicode_FromString(PQfname(res, col_num))); - PyTuple_SET_ITEM(info, 1, PyLong_FromLong((long) PQftype(res, col_num))); + PyTuple_SET_ITEM(info, 1, + PyLong_FromLong((long)PQftype(res, col_num))); PyTuple_SET_ITEM(info, 2, PyLong_FromLong(PQfsize(res, col_num))); PyTuple_SET_ITEM(info, 3, PyLong_FromLong(PQfmod(res, col_num))); } @@ -354,7 +362,7 @@ _query_build_field_info(PGresult *res, int col_num) { /* Get information on one or all fields of the query result. */ static char query_fieldinfo__doc__[] = -"fieldinfo([name]) -- return information about field(s) in query result"; + "fieldinfo([name]) -- return information about field(s) in query result"; static PyObject * query_fieldinfo(queryObject *self, PyObject *args) @@ -374,14 +382,18 @@ query_fieldinfo(queryObject *self, PyObject *args) /* gets field number */ if (PyBytes_Check(field)) { num = PQfnumber(self->result, PyBytes_AsString(field)); - } else if (PyUnicode_Check(field)) { + } + else if (PyUnicode_Check(field)) { PyObject *tmp = get_encoded_string(field, self->encoding); - if (!tmp) return NULL; + if (!tmp) + return NULL; num = PQfnumber(self->result, PyBytes_AsString(tmp)); Py_DECREF(tmp); - } else if (PyLong_Check(field)) { - num = (int) PyLong_AsLong(field); - } else { + } + else if (PyLong_Check(field)) { + num = (int)PyLong_AsLong(field); + } + else { PyErr_SetString(PyExc_TypeError, "Field should be given as column number or name"); return NULL; @@ -407,13 +419,12 @@ query_fieldinfo(queryObject *self, PyObject *args) return result; } - /* Retrieve one row from the result as a tuple. */ static char query_one__doc__[] = -"one() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a tuple of fields.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; + "one() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a tuple of fields.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_one(queryObject *self, PyObject *noargs) @@ -421,13 +432,14 @@ query_one(queryObject *self, PyObject *noargs) PyObject *row_tuple; if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; + if (row_tuple) + ++self->current_row; } return row_tuple; @@ -435,11 +447,13 @@ query_one(queryObject *self, PyObject *noargs) /* Retrieve the single row from the result as a tuple. */ static char query_single__doc__[] = -"single() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as a tuple of fields.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "single() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as a tuple of fields.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_single(queryObject *self, PyObject *noargs) @@ -447,7 +461,6 @@ query_single(queryObject *self, PyObject *noargs) PyObject *row_tuple; if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->max_row != 1) { if (self->max_row) set_error_msg(MultipleResultsError, "Multiple results found"); @@ -458,7 +471,8 @@ query_single(queryObject *self, PyObject *noargs) self->current_row = 0; row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; + if (row_tuple) + ++self->current_row; } return row_tuple; @@ -466,9 +480,9 @@ query_single(queryObject *self, PyObject *noargs) /* Retrieve the last query result as a list of tuples. */ static char query_getresult__doc__[] = -"getresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a tuple of fields\n" -"in the order returned by the server.\n"; + "getresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a tuple of fields\n" + "in the order returned by the server.\n"; static PyObject * query_getresult(queryObject *self, PyObject *noargs) @@ -477,7 +491,6 @@ query_getresult(queryObject *self, PyObject *noargs) int i; if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - if (!(result_list = PyList_New(self->max_row))) { return NULL; } @@ -486,7 +499,8 @@ query_getresult(queryObject *self, PyObject *noargs) PyObject *row_tuple = query_next(self, noargs); if (!row_tuple) { - Py_DECREF(result_list); return NULL; + Py_DECREF(result_list); + return NULL; } PyList_SET_ITEM(result_list, i, row_tuple); } @@ -510,7 +524,8 @@ _query_row_as_dict(queryObject *self) PyObject *val = _query_value_in_column(self, j); if (!val) { - Py_DECREF(row_dict); return NULL; + Py_DECREF(row_dict); + return NULL; } PyDict_SetItemString(row_dict, PQfname(self->result, j), val); Py_DECREF(val); @@ -531,17 +546,18 @@ query_next_dict(queryObject *self, PyObject *noargs) } row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; + if (row_dict) + ++self->current_row; return row_dict; } /* Retrieve one row from the result as a dictionary. */ static char query_onedict__doc__[] = -"onedict() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a dictionary with\n" -"the field names used as the keys.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; + "onedict() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a dictionary with\n" + "the field names used as the keys.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_onedict(queryObject *self, PyObject *noargs) @@ -549,13 +565,14 @@ query_onedict(queryObject *self, PyObject *noargs) PyObject *row_dict; if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; + if (row_dict) + ++self->current_row; } return row_dict; @@ -563,12 +580,14 @@ query_onedict(queryObject *self, PyObject *noargs) /* Retrieve the single row from the result as a dictionary. */ static char query_singledict__doc__[] = -"singledict() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as a dictionary with\n" -"the field names used as the keys.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "singledict() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as a dictionary with\n" + "the field names used as the keys.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_singledict(queryObject *self, PyObject *noargs) @@ -576,7 +595,6 @@ query_singledict(queryObject *self, PyObject *noargs) PyObject *row_dict; if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { - if (self->max_row != 1) { if (self->max_row) set_error_msg(MultipleResultsError, "Multiple results found"); @@ -587,7 +605,8 @@ query_singledict(queryObject *self, PyObject *noargs) self->current_row = 0; row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; + if (row_dict) + ++self->current_row; } return row_dict; @@ -595,9 +614,9 @@ query_singledict(queryObject *self, PyObject *noargs) /* Retrieve the last query result as a list of dictionaries. */ static char query_dictresult__doc__[] = -"dictresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a dictionary with\n" -"the field names used as the keys.\n"; + "dictresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a dictionary with\n" + "the field names used as the keys.\n"; static PyObject * query_dictresult(queryObject *self, PyObject *noargs) @@ -606,7 +625,6 @@ query_dictresult(queryObject *self, PyObject *noargs) int i; if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - if (!(result_list = PyList_New(self->max_row))) { return NULL; } @@ -615,7 +633,8 @@ query_dictresult(queryObject *self, PyObject *noargs) PyObject *row_dict = query_next_dict(self, noargs); if (!row_dict) { - Py_DECREF(result_list); return NULL; + Py_DECREF(result_list); + return NULL; } PyList_SET_ITEM(result_list, i, row_dict); } @@ -626,9 +645,9 @@ query_dictresult(queryObject *self, PyObject *noargs) /* Retrieve last result as iterator of dictionaries. */ static char query_dictiter__doc__[] = -"dictiter() -- Get the result of a query\n\n" -"The result is returned as an iterator of rows, each one a a dictionary\n" -"with the field names used as the keys.\n"; + "dictiter() -- Get the result of a query\n\n" + "The result is returned as an iterator of rows, each one a a dictionary\n" + "with the field names used as the keys.\n"; static PyObject * query_dictiter(queryObject *self, PyObject *noargs) @@ -647,10 +666,10 @@ query_dictiter(queryObject *self, PyObject *noargs) /* Retrieve one row from the result as a named tuple. */ static char query_onenamed__doc__[] = -"onenamed() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a named tuple of fields.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; + "onenamed() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a named tuple of fields.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_onenamed(queryObject *self, PyObject *noargs) @@ -665,7 +684,8 @@ query_onenamed(queryObject *self, PyObject *noargs) return res; if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } return PyObject_CallFunction(namednext, "(O)", self); @@ -673,11 +693,14 @@ query_onenamed(queryObject *self, PyObject *noargs) /* Retrieve the single row from the result as a tuple. */ static char query_singlenamed__doc__[] = -"singlenamed() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as named tuple of fields.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "singlenamed() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as named tuple of " + "fields.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_singlenamed(queryObject *self, PyObject *noargs) @@ -705,9 +728,10 @@ query_singlenamed(queryObject *self, PyObject *noargs) /* Retrieve last result as list of named tuples. */ static char query_namedresult__doc__[] = -"namedresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a named tuple of fields\n" -"in the order returned by the server.\n"; + "namedresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a named tuple of " + "fields\n" + "in the order returned by the server.\n"; static PyObject * query_namedresult(queryObject *self, PyObject *noargs) @@ -720,8 +744,10 @@ query_namedresult(queryObject *self, PyObject *noargs) if ((res_list = _get_async_result(self, 1)) == (PyObject *)self) { res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (PyList_Check(res)) return res; + if (!res) + return NULL; + if (PyList_Check(res)) + return res; res_list = PySequence_List(res); Py_DECREF(res); } @@ -731,9 +757,9 @@ query_namedresult(queryObject *self, PyObject *noargs) /* Retrieve last result as iterator of named tuples. */ static char query_namediter__doc__[] = -"namediter() -- Get the result of a query\n\n" -"The result is returned as an iterator of rows, each one a named tuple\n" -"of fields in the order returned by the server.\n"; + "namediter() -- Get the result of a query\n\n" + "The result is returned as an iterator of rows, each one a named tuple\n" + "of fields in the order returned by the server.\n"; static PyObject * query_namediter(queryObject *self, PyObject *noargs) @@ -745,11 +771,12 @@ query_namediter(queryObject *self, PyObject *noargs) } if ((res_iter = _get_async_result(self, 1)) == (PyObject *)self) { - res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (!PyList_Check(res)) return res; - res_iter = (Py_TYPE(res)->tp_iter)((PyObject *) self); + if (!res) + return NULL; + if (!PyList_Check(res)) + return res; + res_iter = (Py_TYPE(res)->tp_iter)((PyObject *)self); Py_DECREF(res); } @@ -758,9 +785,9 @@ query_namediter(queryObject *self, PyObject *noargs) /* Retrieve the last query result as a list of scalar values. */ static char query_scalarresult__doc__[] = -"scalarresult() -- Get query result as scalars\n\n" -"The result is returned as a list of scalar values where the values\n" -"are the first fields of the rows in the order returned by the server.\n"; + "scalarresult() -- Get query result as scalars\n\n" + "The result is returned as a list of scalar values where the values\n" + "are the first fields of the rows in the order returned by the server.\n"; static PyObject * query_scalarresult(queryObject *self, PyObject *noargs) @@ -768,7 +795,6 @@ query_scalarresult(queryObject *self, PyObject *noargs) PyObject *result_list; if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { - if (!self->num_fields) { set_error_msg(ProgrammingError, "No fields in result"); return NULL; @@ -778,14 +804,13 @@ query_scalarresult(queryObject *self, PyObject *noargs) return NULL; } - for (self->current_row = 0; - self->current_row < self->max_row; - ++self->current_row) - { + for (self->current_row = 0; self->current_row < self->max_row; + ++self->current_row) { PyObject *value = _query_value_in_column(self, 0); if (!value) { - Py_DECREF(result_list); return NULL; + Py_DECREF(result_list); + return NULL; } PyList_SET_ITEM(result_list, self->current_row, value); } @@ -796,9 +821,9 @@ query_scalarresult(queryObject *self, PyObject *noargs) /* Retrieve the last query result as iterator of scalar values. */ static char query_scalariter__doc__[] = -"scalariter() -- Get query result as scalars\n\n" -"The result is returned as an iterator of scalar values where the values\n" -"are the first fields of the rows in the order returned by the server.\n"; + "scalariter() -- Get query result as scalars\n\n" + "The result is returned as an iterator of scalar values where the values\n" + "are the first fields of the rows in the order returned by the server.\n"; static PyObject * query_scalariter(queryObject *self, PyObject *noargs) @@ -822,10 +847,12 @@ query_scalariter(queryObject *self, PyObject *noargs) /* Retrieve one result as scalar value. */ static char query_onescalar__doc__[] = -"onescalar() -- Get one scalar value from the result of a query\n\n" -"Returns the first field of the next row from the result as a scalar value.\n" -"This method can be called multiple times to return more rows as scalars.\n" -"It returns None if the result does not contain one more row.\n"; + "onescalar() -- Get one scalar value from the result of a query\n\n" + "Returns the first field of the next row from the result as a scalar " + "value.\n" + "This method can be called multiple times to return more rows as " + "scalars.\n" + "It returns None if the result does not contain one more row.\n"; static PyObject * query_onescalar(queryObject *self, PyObject *noargs) @@ -833,18 +860,19 @@ query_onescalar(queryObject *self, PyObject *noargs) PyObject *value; if ((value = _get_async_result(self, 0)) == (PyObject *)self) { - if (!self->num_fields) { set_error_msg(ProgrammingError, "No fields in result"); return NULL; } if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } value = _query_value_in_column(self, 0); - if (value) ++self->current_row; + if (value) + ++self->current_row; } return value; @@ -852,11 +880,14 @@ query_onescalar(queryObject *self, PyObject *noargs) /* Retrieves the single row from the result as a tuple. */ static char query_singlescalar__doc__[] = -"singlescalar() -- Get scalar value from single result of a query\n\n" -"Returns the first field of the next row from the result as a scalar value.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; + "singlescalar() -- Get scalar value from single result of a query\n\n" + "Returns the first field of the next row from the result as a scalar " + "value.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; static PyObject * query_singlescalar(queryObject *self, PyObject *noargs) @@ -864,7 +895,6 @@ query_singlescalar(queryObject *self, PyObject *noargs) PyObject *value; if ((value = _get_async_result(self, 0)) == (PyObject *)self) { - if (!self->num_fields) { set_error_msg(ProgrammingError, "No fields in result"); return NULL; @@ -880,7 +910,8 @@ query_singlescalar(queryObject *self, PyObject *noargs) self->current_row = 0; value = _query_value_in_column(self, 0); - if (value) ++self->current_row; + if (value) + ++self->current_row; } return value; @@ -888,92 +919,86 @@ query_singlescalar(queryObject *self, PyObject *noargs) /* Query sequence protocol methods */ static PySequenceMethods query_sequence_methods = { - (lenfunc) query_len, /* sq_length */ - 0, /* sq_concat */ - 0, /* sq_repeat */ - (ssizeargfunc) query_getitem, /* sq_item */ - 0, /* sq_ass_item */ - 0, /* sq_contains */ - 0, /* sq_inplace_concat */ - 0, /* sq_inplace_repeat */ + (lenfunc)query_len, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)query_getitem, /* sq_item */ + 0, /* sq_ass_item */ + 0, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ }; /* Query object methods */ static struct PyMethodDef query_methods[] = { - {"getresult", (PyCFunction) query_getresult, - METH_NOARGS, query_getresult__doc__}, - {"dictresult", (PyCFunction) query_dictresult, - METH_NOARGS, query_dictresult__doc__}, - {"dictiter", (PyCFunction) query_dictiter, - METH_NOARGS, query_dictiter__doc__}, - {"namedresult", (PyCFunction) query_namedresult, - METH_NOARGS, query_namedresult__doc__}, - {"namediter", (PyCFunction) query_namediter, - METH_NOARGS, query_namediter__doc__}, - {"one", (PyCFunction) query_one, - METH_NOARGS, query_one__doc__}, - {"single", (PyCFunction) query_single, - METH_NOARGS, query_single__doc__}, - {"onedict", (PyCFunction) query_onedict, - METH_NOARGS, query_onedict__doc__}, - {"singledict", (PyCFunction) query_singledict, - METH_NOARGS, query_singledict__doc__}, - {"onenamed", (PyCFunction) query_onenamed, - METH_NOARGS, query_onenamed__doc__}, - {"singlenamed", (PyCFunction) query_singlenamed, - METH_NOARGS, query_singlenamed__doc__}, - {"scalarresult", (PyCFunction) query_scalarresult, - METH_NOARGS, query_scalarresult__doc__}, - {"scalariter", (PyCFunction) query_scalariter, - METH_NOARGS, query_scalariter__doc__}, - {"onescalar", (PyCFunction) query_onescalar, - METH_NOARGS, query_onescalar__doc__}, - {"singlescalar", (PyCFunction) query_singlescalar, - METH_NOARGS, query_singlescalar__doc__}, - {"fieldname", (PyCFunction) query_fieldname, - METH_VARARGS, query_fieldname__doc__}, - {"fieldnum", (PyCFunction) query_fieldnum, - METH_VARARGS, query_fieldnum__doc__}, - {"listfields", (PyCFunction) query_listfields, - METH_NOARGS, query_listfields__doc__}, - {"fieldinfo", (PyCFunction) query_fieldinfo, - METH_VARARGS, query_fieldinfo__doc__}, - {"memsize", (PyCFunction) query_memsize, - METH_NOARGS, query_memsize__doc__}, - {NULL, NULL} -}; + {"getresult", (PyCFunction)query_getresult, METH_NOARGS, + query_getresult__doc__}, + {"dictresult", (PyCFunction)query_dictresult, METH_NOARGS, + query_dictresult__doc__}, + {"dictiter", (PyCFunction)query_dictiter, METH_NOARGS, + query_dictiter__doc__}, + {"namedresult", (PyCFunction)query_namedresult, METH_NOARGS, + query_namedresult__doc__}, + {"namediter", (PyCFunction)query_namediter, METH_NOARGS, + query_namediter__doc__}, + {"one", (PyCFunction)query_one, METH_NOARGS, query_one__doc__}, + {"single", (PyCFunction)query_single, METH_NOARGS, query_single__doc__}, + {"onedict", (PyCFunction)query_onedict, METH_NOARGS, query_onedict__doc__}, + {"singledict", (PyCFunction)query_singledict, METH_NOARGS, + query_singledict__doc__}, + {"onenamed", (PyCFunction)query_onenamed, METH_NOARGS, + query_onenamed__doc__}, + {"singlenamed", (PyCFunction)query_singlenamed, METH_NOARGS, + query_singlenamed__doc__}, + {"scalarresult", (PyCFunction)query_scalarresult, METH_NOARGS, + query_scalarresult__doc__}, + {"scalariter", (PyCFunction)query_scalariter, METH_NOARGS, + query_scalariter__doc__}, + {"onescalar", (PyCFunction)query_onescalar, METH_NOARGS, + query_onescalar__doc__}, + {"singlescalar", (PyCFunction)query_singlescalar, METH_NOARGS, + query_singlescalar__doc__}, + {"fieldname", (PyCFunction)query_fieldname, METH_VARARGS, + query_fieldname__doc__}, + {"fieldnum", (PyCFunction)query_fieldnum, METH_VARARGS, + query_fieldnum__doc__}, + {"listfields", (PyCFunction)query_listfields, METH_NOARGS, + query_listfields__doc__}, + {"fieldinfo", (PyCFunction)query_fieldinfo, METH_VARARGS, + query_fieldinfo__doc__}, + {"memsize", (PyCFunction)query_memsize, METH_NOARGS, query_memsize__doc__}, + {NULL, NULL}}; static char query__doc__[] = "PyGreSQL query object"; /* Query type definition */ static PyTypeObject queryType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Query", /* tp_name */ - sizeof(queryObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.Query", /* tp_name */ + sizeof(queryObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) query_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - &query_sequence_methods, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) query_str, /* tp_str */ - PyObject_GenericGetAttr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - query__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - (getiterfunc) query_iter, /* tp_iter */ - (iternextfunc) query_next, /* tp_iternext */ - query_methods, /* tp_methods */ + (destructor)query_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + &query_sequence_methods, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)query_str, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + query__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + (getiterfunc)query_iter, /* tp_iter */ + (iternextfunc)query_next, /* tp_iternext */ + query_methods, /* tp_methods */ }; diff --git a/pgsource.c b/pgsource.c index 7b081273..73c9a52b 100644 --- a/pgsource.c +++ b/pgsource.c @@ -71,7 +71,7 @@ source_getattr(sourceObject *self, PyObject *nameobj) if (!strcmp(name, "pgcnx")) { if (_check_source_obj(self, 0)) { Py_INCREF(self->pgcnx); - return (PyObject *) (self->pgcnx); + return (PyObject *)(self->pgcnx); } Py_INCREF(Py_None); return Py_None; @@ -94,7 +94,7 @@ source_getattr(sourceObject *self, PyObject *nameobj) return PyLong_FromLong(self->num_fields); /* seeks name in methods (fallback) */ - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Set source object attributes. */ @@ -119,8 +119,9 @@ source_setattr(sourceObject *self, char *name, PyObject *v) /* Close object. */ static char source_close__doc__[] = -"close() -- close query object without deleting it\n\n" -"All instances of the query object can no longer be used after this call.\n"; + "close() -- close query object without deleting it\n\n" + "All instances of the query object can no longer be used after this " + "call.\n"; static PyObject * source_close(sourceObject *self, PyObject *noargs) @@ -141,15 +142,15 @@ source_close(sourceObject *self, PyObject *noargs) /* Database query. */ static char source_execute__doc__[] = -"execute(sql) -- execute a SQL statement (string)\n\n" -"On success, this call returns the number of affected rows, or None\n" -"for DQL (SELECT, ...) statements. The fetch (fetch(), fetchone()\n" -"and fetchall()) methods can be used to get result rows.\n"; + "execute(sql) -- execute a SQL statement (string)\n\n" + "On success, this call returns the number of affected rows, or None\n" + "for DQL (SELECT, ...) statements. The fetch (fetch(), fetchone()\n" + "and fetchall()) methods can be used to get result rows.\n"; static PyObject * source_execute(sourceObject *self, PyObject *sql) { - PyObject *tmp_obj = NULL; /* auxiliary string object */ + PyObject *tmp_obj = NULL; /* auxiliary string object */ char *query; int encoding; @@ -165,7 +166,8 @@ source_execute(sourceObject *self, PyObject *sql) } else if (PyUnicode_Check(sql)) { tmp_obj = get_encoded_string(sql, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ query = PyBytes_AsString(tmp_obj); } else { @@ -205,30 +207,29 @@ source_execute(sourceObject *self, PyObject *sql) /* checks result status */ switch (PQresultStatus(self->result)) { /* query succeeded */ - case PGRES_TUPLES_OK: /* DQL: returns None (DB-SIG compliant) */ + case PGRES_TUPLES_OK: /* DQL: returns None (DB-SIG compliant) */ self->result_type = RESULT_DQL; self->max_row = PQntuples(self->result); self->num_fields = PQnfields(self->result); Py_INCREF(Py_None); return Py_None; - case PGRES_COMMAND_OK: /* other requests */ + case PGRES_COMMAND_OK: /* other requests */ case PGRES_COPY_OUT: - case PGRES_COPY_IN: - { - long num_rows; - char *tmp; - - tmp = PQcmdTuples(self->result); - if (tmp[0]) { - self->result_type = RESULT_DML; - num_rows = atol(tmp); - } - else { - self->result_type = RESULT_DDL; - num_rows = -1; - } - return PyLong_FromLong(num_rows); + case PGRES_COPY_IN: { + long num_rows; + char *tmp; + + tmp = PQcmdTuples(self->result); + if (tmp[0]) { + self->result_type = RESULT_DML; + num_rows = atol(tmp); } + else { + self->result_type = RESULT_DDL; + num_rows = -1; + } + return PyLong_FromLong(num_rows); + } /* query failed */ case PGRES_EMPTY_QUERY: @@ -238,7 +239,7 @@ source_execute(sourceObject *self, PyObject *sql) case PGRES_FATAL_ERROR: case PGRES_NONFATAL_ERROR: set_error(ProgrammingError, "Cannot execute command", - self->pgcnx->cnx, self->result); + self->pgcnx->cnx, self->result); break; default: set_error_msg(InternalError, @@ -254,7 +255,7 @@ source_execute(sourceObject *self, PyObject *sql) /* Get oid status for last query (valid for INSERTs, 0 for other). */ static char source_oidstatus__doc__[] = -"oidstatus() -- return oid of last inserted row (if available)"; + "oidstatus() -- return oid of last inserted row (if available)"; static PyObject * source_oidstatus(sourceObject *self, PyObject *noargs) @@ -272,14 +273,14 @@ source_oidstatus(sourceObject *self, PyObject *noargs) return Py_None; } - return PyLong_FromLong((long) oid); + return PyLong_FromLong((long)oid); } /* Fetch rows from last result. */ static char source_fetch__doc__[] = -"fetch(num) -- return the next num rows from the last result in a list\n\n" -"If num parameter is omitted arraysize attribute value is used.\n" -"If size equals -1, all rows are fetched.\n"; + "fetch(num) -- return the next num rows from the last result in a list\n\n" + "If num parameter is omitted arraysize attribute value is used.\n" + "If size equals -1, all rows are fetched.\n"; static PyObject * source_fetch(sourceObject *self, PyObject *args) @@ -309,7 +310,8 @@ source_fetch(sourceObject *self, PyObject *args) } /* allocate list for result */ - if (!(res_list = PyList_New(0))) return NULL; + if (!(res_list = PyList_New(0))) + return NULL; encoding = self->encoding; @@ -319,7 +321,8 @@ source_fetch(sourceObject *self, PyObject *args) int j; if (!(rowtuple = PyTuple_New(self->num_fields))) { - Py_DECREF(res_list); return NULL; + Py_DECREF(res_list); + return NULL; } for (j = 0; j < self->num_fields; ++j) { @@ -345,7 +348,9 @@ source_fetch(sourceObject *self, PyObject *args) } if (PyList_Append(res_list, rowtuple)) { - Py_DECREF(rowtuple); Py_DECREF(res_list); return NULL; + Py_DECREF(rowtuple); + Py_DECREF(res_list); + return NULL; } Py_DECREF(rowtuple); } @@ -387,7 +392,7 @@ _source_move(sourceObject *self, int move) /* Move to first result row. */ static char source_movefirst__doc__[] = -"movefirst() -- move to first result row"; + "movefirst() -- move to first result row"; static PyObject * source_movefirst(sourceObject *self, PyObject *noargs) @@ -397,7 +402,7 @@ source_movefirst(sourceObject *self, PyObject *noargs) /* Move to last result row. */ static char source_movelast__doc__[] = -"movelast() -- move to last valid result row"; + "movelast() -- move to last valid result row"; static PyObject * source_movelast(sourceObject *self, PyObject *noargs) @@ -406,8 +411,7 @@ source_movelast(sourceObject *self, PyObject *noargs) } /* Move to next result row. */ -static char source_movenext__doc__[] = -"movenext() -- move to next result row"; +static char source_movenext__doc__[] = "movenext() -- move to next result row"; static PyObject * source_movenext(sourceObject *self, PyObject *noargs) @@ -417,7 +421,7 @@ source_movenext(sourceObject *self, PyObject *noargs) /* Move to previous result row. */ static char source_moveprev__doc__[] = -"moveprev() -- move to previous result row"; + "moveprev() -- move to previous result row"; static PyObject * source_moveprev(sourceObject *self, PyObject *noargs) @@ -427,17 +431,17 @@ source_moveprev(sourceObject *self, PyObject *noargs) /* Put copy data. */ static char source_putdata__doc__[] = -"putdata(buffer) -- send data to server during copy from stdin"; + "putdata(buffer) -- send data to server during copy from stdin"; static PyObject * source_putdata(sourceObject *self, PyObject *buffer) { - PyObject *tmp_obj = NULL; /* an auxiliary object */ - char *buf; /* the buffer as encoded string */ - Py_ssize_t nbytes; /* length of string */ - char *errormsg = NULL; /* error message */ - int res; /* direct result of the operation */ - PyObject *ret; /* return value */ + PyObject *tmp_obj = NULL; /* an auxiliary object */ + char *buf; /* the buffer as encoded string */ + Py_ssize_t nbytes; /* length of string */ + char *errormsg = NULL; /* error message */ + int res; /* direct result of the operation */ + PyObject *ret; /* return value */ /* checks validity */ if (!_check_source_obj(self, CHECK_CNX)) { @@ -459,9 +463,10 @@ source_putdata(sourceObject *self, PyObject *buffer) } else if (PyUnicode_Check(buffer)) { /* or pass a unicode string */ - tmp_obj = get_encoded_string( - buffer, PQclientEncoding(self->pgcnx->cnx)); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + tmp_obj = + get_encoded_string(buffer, PQclientEncoding(self->pgcnx->cnx)); + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &buf, &nbytes); } else if (PyErr_GivenExceptionMatches(buffer, PyExc_BaseException)) { @@ -470,10 +475,11 @@ source_putdata(sourceObject *self, PyObject *buffer) if (PyUnicode_Check(tmp_obj)) { PyObject *obj = tmp_obj; - tmp_obj = get_encoded_string( - obj, PQclientEncoding(self->pgcnx->cnx)); + tmp_obj = + get_encoded_string(obj, PQclientEncoding(self->pgcnx->cnx)); Py_DECREF(obj); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ } errormsg = PyBytes_AsString(tmp_obj); buf = NULL; @@ -487,8 +493,7 @@ source_putdata(sourceObject *self, PyObject *buffer) /* checks validity */ if (!_check_source_obj(self, CHECK_CNX | CHECK_RESULT) || - PQresultStatus(self->result) != PGRES_COPY_IN) - { + PQresultStatus(self->result) != PGRES_COPY_IN) { PyErr_SetString(PyExc_IOError, "Connection is invalid or not in copy_in state"); Py_XDECREF(tmp_obj); @@ -496,7 +501,7 @@ source_putdata(sourceObject *self, PyObject *buffer) } if (buf) { - res = nbytes ? PQputCopyData(self->pgcnx->cnx, buf, (int) nbytes) : 1; + res = nbytes ? PQputCopyData(self->pgcnx->cnx, buf, (int)nbytes) : 1; } else { res = PQputCopyEnd(self->pgcnx->cnx, errormsg); @@ -513,7 +518,7 @@ source_putdata(sourceObject *self, PyObject *buffer) ret = Py_None; Py_INCREF(ret); } - else { /* copy is done */ + else { /* copy is done */ PGresult *result; /* final result of the operation */ Py_BEGIN_ALLOW_THREADS; @@ -529,7 +534,8 @@ source_putdata(sourceObject *self, PyObject *buffer) ret = PyLong_FromLong(num_rows); } else { - if (!errormsg) errormsg = PQerrorMessage(self->pgcnx->cnx); + if (!errormsg) + errormsg = PQerrorMessage(self->pgcnx->cnx); PyErr_SetString(PyExc_IOError, errormsg); ret = NULL; } @@ -544,15 +550,15 @@ source_putdata(sourceObject *self, PyObject *buffer) /* Get copy data. */ static char source_getdata__doc__[] = -"getdata(decode) -- receive data to server during copy to stdout"; + "getdata(decode) -- receive data to server during copy to stdout"; static PyObject * source_getdata(sourceObject *self, PyObject *args) { - int *decode = 0; /* decode flag */ - char *buffer; /* the copied buffer as encoded byte string */ - Py_ssize_t nbytes; /* length of the byte string */ - PyObject *ret; /* return value */ + int *decode = 0; /* decode flag */ + char *buffer; /* the copied buffer as encoded byte string */ + Py_ssize_t nbytes; /* length of the byte string */ + PyObject *ret; /* return value */ /* checks validity */ if (!_check_source_obj(self, CHECK_CNX)) { @@ -570,8 +576,7 @@ source_getdata(sourceObject *self, PyObject *args) /* checks validity */ if (!_check_source_obj(self, CHECK_CNX | CHECK_RESULT) || - PQresultStatus(self->result) != PGRES_COPY_OUT) - { + PQresultStatus(self->result) != PGRES_COPY_OUT) { PyErr_SetString(PyExc_IOError, "Connection is invalid or not in copy_out state"); return NULL; @@ -584,7 +589,7 @@ source_getdata(sourceObject *self, PyObject *args) return NULL; } - if (nbytes == -1) { /* copy is done */ + if (nbytes == -1) { /* copy is done */ PGresult *result; /* final result of the operation */ Py_BEGIN_ALLOW_THREADS; @@ -609,9 +614,9 @@ source_getdata(sourceObject *self, PyObject *args) self->result_type = RESULT_EMPTY; } else { /* a row has been returned */ - ret = decode ? get_decoded_string( - buffer, nbytes, PQclientEncoding(self->pgcnx->cnx)) : - PyBytes_FromStringAndSize(buffer, nbytes); + ret = decode ? get_decoded_string(buffer, nbytes, + PQclientEncoding(self->pgcnx->cnx)) + : PyBytes_FromStringAndSize(buffer, nbytes); PQfreemem(buffer); } @@ -633,7 +638,7 @@ _source_fieldindex(sourceObject *self, PyObject *param, const char *usage) num = PQfnumber(self->result, PyBytes_AsString(param)); } else if (PyLong_Check(param)) { - num = (int) PyLong_AsLong(param); + num = (int)PyLong_AsLong(param); } else { PyErr_SetString(PyExc_TypeError, usage); @@ -664,20 +669,18 @@ _source_buildinfo(sourceObject *self, int num) /* affects field information */ PyTuple_SET_ITEM(result, 0, PyLong_FromLong(num)); PyTuple_SET_ITEM(result, 1, - PyUnicode_FromString(PQfname(self->result, num))); + PyUnicode_FromString(PQfname(self->result, num))); PyTuple_SET_ITEM(result, 2, - PyLong_FromLong((long) PQftype(self->result, num))); - PyTuple_SET_ITEM(result, 3, - PyLong_FromLong(PQfsize(self->result, num))); - PyTuple_SET_ITEM(result, 4, - PyLong_FromLong(PQfmod(self->result, num))); + PyLong_FromLong((long)PQftype(self->result, num))); + PyTuple_SET_ITEM(result, 3, PyLong_FromLong(PQfsize(self->result, num))); + PyTuple_SET_ITEM(result, 4, PyLong_FromLong(PQfmod(self->result, num))); return result; } /* Lists fields info. */ static char source_listinfo__doc__[] = -"listinfo() -- get information for all fields (position, name, type oid)"; + "listinfo() -- get information for all fields (position, name, type oid)"; static PyObject * source_listInfo(sourceObject *self, PyObject *noargs) @@ -710,7 +713,7 @@ source_listInfo(sourceObject *self, PyObject *noargs) /* List fields information for last result. */ static char source_fieldinfo__doc__[] = -"fieldinfo(desc) -- get specified field info (position, name, type oid)"; + "fieldinfo(desc) -- get specified field info (position, name, type oid)"; static PyObject * source_fieldinfo(sourceObject *self, PyObject *desc) @@ -719,9 +722,9 @@ source_fieldinfo(sourceObject *self, PyObject *desc) /* checks args and validity */ if ((num = _source_fieldindex( - self, desc, - "Method fieldinfo() needs a string or integer as argument")) == -1) - { + self, desc, + "Method fieldinfo() needs a string or integer as argument")) == + -1) { return NULL; } @@ -731,7 +734,7 @@ source_fieldinfo(sourceObject *self, PyObject *desc) /* Retrieve field value. */ static char source_field__doc__[] = -"field(desc) -- return specified field value"; + "field(desc) -- return specified field value"; static PyObject * source_field(sourceObject *self, PyObject *desc) @@ -740,9 +743,8 @@ source_field(sourceObject *self, PyObject *desc) /* checks args and validity */ if ((num = _source_fieldindex( - self, desc, - "Method field() needs a string or integer as argument")) == -1) - { + self, desc, + "Method field() needs a string or integer as argument")) == -1) { return NULL; } @@ -756,78 +758,70 @@ source_dir(connObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sssss]", - "pgcnx", "arraysize", "resulttype", "ntuples", "nfields"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sssss]", "pgcnx", "arraysize", + "resulttype", "ntuples", "nfields"); return attrs; } /* Source object methods */ static PyMethodDef source_methods[] = { - {"__dir__", (PyCFunction) source_dir, METH_NOARGS, NULL}, - - {"close", (PyCFunction) source_close, - METH_NOARGS, source_close__doc__}, - {"execute", (PyCFunction) source_execute, - METH_O, source_execute__doc__}, - {"oidstatus", (PyCFunction) source_oidstatus, - METH_NOARGS, source_oidstatus__doc__}, - {"fetch", (PyCFunction) source_fetch, - METH_VARARGS, source_fetch__doc__}, - {"movefirst", (PyCFunction) source_movefirst, - METH_NOARGS, source_movefirst__doc__}, - {"movelast", (PyCFunction) source_movelast, - METH_NOARGS, source_movelast__doc__}, - {"movenext", (PyCFunction) source_movenext, - METH_NOARGS, source_movenext__doc__}, - {"moveprev", (PyCFunction) source_moveprev, - METH_NOARGS, source_moveprev__doc__}, - {"putdata", (PyCFunction) source_putdata, - METH_O, source_putdata__doc__}, - {"getdata", (PyCFunction) source_getdata, - METH_VARARGS, source_getdata__doc__}, - {"field", (PyCFunction) source_field, - METH_O, source_field__doc__}, - {"fieldinfo", (PyCFunction) source_fieldinfo, - METH_O, source_fieldinfo__doc__}, - {"listinfo", (PyCFunction) source_listInfo, - METH_NOARGS, source_listinfo__doc__}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)source_dir, METH_NOARGS, NULL}, + + {"close", (PyCFunction)source_close, METH_NOARGS, source_close__doc__}, + {"execute", (PyCFunction)source_execute, METH_O, source_execute__doc__}, + {"oidstatus", (PyCFunction)source_oidstatus, METH_NOARGS, + source_oidstatus__doc__}, + {"fetch", (PyCFunction)source_fetch, METH_VARARGS, source_fetch__doc__}, + {"movefirst", (PyCFunction)source_movefirst, METH_NOARGS, + source_movefirst__doc__}, + {"movelast", (PyCFunction)source_movelast, METH_NOARGS, + source_movelast__doc__}, + {"movenext", (PyCFunction)source_movenext, METH_NOARGS, + source_movenext__doc__}, + {"moveprev", (PyCFunction)source_moveprev, METH_NOARGS, + source_moveprev__doc__}, + {"putdata", (PyCFunction)source_putdata, METH_O, source_putdata__doc__}, + {"getdata", (PyCFunction)source_getdata, METH_VARARGS, + source_getdata__doc__}, + {"field", (PyCFunction)source_field, METH_O, source_field__doc__}, + {"fieldinfo", (PyCFunction)source_fieldinfo, METH_O, + source_fieldinfo__doc__}, + {"listinfo", (PyCFunction)source_listInfo, METH_NOARGS, + source_listinfo__doc__}, + {NULL, NULL}}; static char source__doc__[] = "PyGreSQL source object"; /* Source type definition */ static PyTypeObject sourceType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pgdb.Source", /* tp_name */ - sizeof(sourceObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pgdb.Source", /* tp_name */ + sizeof(sourceObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) source_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - (setattrfunc) source_setattr, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) source_str, /* tp_str */ - (getattrofunc) source_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - source__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - source_methods, /* tp_methods */ + (destructor)source_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + (setattrfunc)source_setattr, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)source_str, /* tp_str */ + (getattrofunc)source_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + source__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + source_methods, /* tp_methods */ }; diff --git a/tox.ini b/tox.ini index 9ddc3a75..37b3a39d 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},ruff,docs +envlist = py3{7,8,9,10,11},ruff,cformat,docs [testenv:ruff] basepython = python3.11 @@ -9,6 +9,13 @@ deps = ruff>=0.0.287 commands = ruff setup.py pg.py pgdb.py tests +[testenv:cformat] +basepython = python3.11 +allowlist_externals = + sh +commands = + sh -c "! (clang-format --style=file -n *.c 2>&1 | tee /dev/tty | grep format-violations)" + [testenv:docs] basepython = python3.11 deps = From a6c38643170a7037bd93c3290921e632051f1b67 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 3 Sep 2023 15:47:13 +0200 Subject: [PATCH 048/118] Add type hints to the pg module --- docs/contents/pg/db_wrapper.rst | 18 +- docs/contents/pg/query.rst | 4 +- pg.py | 1115 ++++++++++++++++-------------- pgdb.py | 172 ++--- pgmodule.c | 8 +- pyproject.toml | 15 + tests/config.py | 6 +- tests/dbapi20.py | 7 +- tests/test_classic_connection.py | 92 +-- tests/test_classic_dbwrapper.py | 272 ++++---- tests/test_classic_functions.py | 54 +- tests/test_classic_largeobj.py | 3 + tests/test_dbapi20_copy.py | 6 +- tests/test_tutorial.py | 3 +- tox.ini | 6 + 15 files changed, 969 insertions(+), 812 deletions(-) diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 68d33c65..64710456 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -58,7 +58,7 @@ pkey -- return the primary key of a table Return the primary key of a table :param str table: name of table - :returns: Name of the field which is the primary key of the table + :returns: Name of the field that is the primary key of the table :rtype: str :raises KeyError: the table does not have a primary key @@ -67,6 +67,22 @@ returned as strings unless you set the composite flag. Composite primary keys are always represented as tuples. Note that this raises a KeyError if the table does not have a primary key. +pkeys -- return the primary keys of a table +------------------------------------------- + +.. method:: DB.pkeys(table) + + Return the primary keys of a table as a tuple + + :param str table: name of table + :returns: Names of the fields that are the primary keys of the table + :rtype: tuple + :raises KeyError: the table does not have a primary key + +This method returns the primary keys of a table as a tuple, i.e. +single primary keys are also returned as a tuple with one item. +Note that this raises a KeyError if the table does not have a primary key. + get_databases -- get list of databases in the system ---------------------------------------------------- diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 3232c115..fcee193f 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -336,10 +336,10 @@ listfields -- list field names of query result List field names of query result :returns: field names - :rtype: list + :rtype: tuple :raises TypeError: too many parameters -This method returns the list of field names defined for the query result. +This method returns the tuple of field names defined for the query result. The fields are in the same order as the result values. fieldname, fieldnum -- field name/number conversion diff --git a/pg.py b/pg.py index d29cb5c2..11aaf90a 100644 --- a/pg.py +++ b/pg.py @@ -20,9 +20,11 @@ For a DB-API 2 compliant interface use the newer pgdb module. """ +from __future__ import annotations + import select import weakref -from collections import OrderedDict, namedtuple +from collections import namedtuple from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -34,7 +36,18 @@ from operator import itemgetter from re import compile as regex from types import MappingProxyType -from typing import Callable, ClassVar, Dict, List, Mapping, Type, Union +from typing import ( + Any, + Callable, + ClassVar, + Generator, + Iterator, + List, + Mapping, + NamedTuple, + Sequence, + TypeVar, +) from uuid import UUID try: @@ -49,15 +62,16 @@ if os.path.exists(os.path.join(path, libpq))] if sys.version_info >= (3, 8): # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore for path in paths: - with os.add_dll_directory(os.path.abspath(path)): + with add_dll_dir(os.path.abspath(path)): try: - from _pg import version + from _pg import version # type: ignore except ImportError: pass else: del version - e = None + e = None # type: ignore break if paths: libpq = 'compatible ' + libpq @@ -86,6 +100,7 @@ TRANS_INERROR, TRANS_INTRANS, TRANS_UNKNOWN, + Connection, DatabaseError, DataError, Error, @@ -98,6 +113,7 @@ NotSupportedError, OperationalError, ProgrammingError, + Query, Warning, cast_array, cast_hstore, @@ -148,6 +164,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', + 'Query', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', @@ -164,21 +181,24 @@ 'set_defbase', 'set_defhost', 'set_defopt', 'set_defpasswd', 'set_defport', 'set_defuser', 'set_jsondecode', 'set_query_helpers', 'set_typecast', - 'version', '__version__'] + 'version', '__version__', +] # Auxiliary classes and functions that are independent of a DB connection: -def get_args(func): +def get_args(func: Callable) -> list: return list(signature(func).parameters) # time zones used in Postgres timestamptz output -_timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') +_timezones: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} -def _timezone_as_offset(tz): +def _timezone_as_offset(tz: str) -> str: if tz.startswith(('+', '-')): if len(tz) < 5: return tz + '00' @@ -186,7 +206,7 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def _oid_key(table): +def _oid_key(table: str) -> str: """Build oid key from a table name.""" return f'oid({table})' @@ -201,7 +221,7 @@ class Hstore(dict): _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') @classmethod - def _quote(cls, s): + def _quote(cls, s: Any) -> str: if s is None: return 'NULL' if not isinstance(s, str): @@ -213,7 +233,7 @@ def _quote(cls, s): s = f'"{s}"' return s - def __str__(self): + def __str__(self) -> str: """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -222,12 +242,12 @@ def __str__(self): class Json: """Wrapper class for marking Json values.""" - def __init__(self, obj, encode=None): + def __init__(self, obj: Any, encode: Callable | None = None) -> None: """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode - def __str__(self): + def __str__(self) -> str: """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): @@ -241,7 +261,7 @@ class _SimpleTypes(dict): The corresponding Python types and simple names are also mapped. """ - _type_aliases: Mapping[str, List[Union[str, type]]] = MappingProxyType({ + _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ 'bool': [bool], 'bytea': [Bytea], 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', @@ -256,7 +276,7 @@ class _SimpleTypes(dict): }) # noinspection PyMissingConstructor - def __init__(self): + def __init__(self) -> None: """Initialize type mapping.""" for typ, keys in self._type_aliases.items(): keys = [typ, *keys] @@ -265,24 +285,24 @@ def __init__(self): if isinstance(key, str): self[f'_{key}'] = f'{typ}[]' elif not isinstance(key, tuple): - self[List[key]] = f'{typ}[]' + self[List[key]] = f'{typ}[]' # type: ignore @staticmethod - def __missing__(key): + def __missing__(key: str) -> str: """Unmapped types are interpreted as text.""" return 'text' - def get_type_dict(self): + def get_type_dict(self) -> dict[type, str]: """Get a plain dictionary of only the types.""" - return dict((key, typ) for key, typ in self.items() - if not isinstance(key, (str, tuple))) + return {key: typ for key, typ in self.items() + if not isinstance(key, (str, tuple))} _simpletypes = _SimpleTypes() _simple_type_dict = _simpletypes.get_type_dict() -def _quote_if_unqualified(param, name): +def _quote_if_unqualified(param: str, name: int | str) -> str: """Quote parameter representing a qualified name. Puts a quote_ident() call around the given parameter unless @@ -300,7 +320,7 @@ class _ParameterList(list): adapt: Callable - def add(self, value, typ=None): + def add(self, value: Any, typ:Any = None) -> str: """Typecast value with known database type and build parameter list. If this is a literal value, it will be returned as is. Otherwise, a @@ -318,29 +338,29 @@ class Literal(str): """Wrapper class for marking literal SQL values.""" -class AttrDict(OrderedDict): +class AttrDict(dict): """Simple read-only ordered dictionary for storing attribute names.""" - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any) -> None: self._read_only = False - OrderedDict.__init__(self, *args, **kw) + super().__init__(*args, **kw) self._read_only = True error = self._read_only_error - self.clear = self.update = error - self.pop = self.setdefault = self.popitem = error + self.clear = self.update = error # type: ignore + self.pop = self.setdefault = self.popitem = error # type: ignore - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if self._read_only: self._read_only_error() - OrderedDict.__setitem__(self, key, value) + super().__setitem__(key, value) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: if self._read_only: self._read_only_error() - OrderedDict.__delitem__(self, key) + super().__delitem__(key) @staticmethod - def _read_only_error(*args, **kw): + def _read_only_error(*_args: Any, **_kw: Any) -> Any: raise TypeError('This object is read-only') @@ -357,12 +377,12 @@ class Adapter: _re_record_quote = regex(r'[(,"\\]') _re_array_escape = _re_record_escape = regex(r'(["\\])') - def __init__(self, db): + def __init__(self, db: DB): """Initialize the adapter object with the given connection.""" self.db = weakref.proxy(db) @classmethod - def _adapt_bool(cls, v): + def _adapt_bool(cls, v: Any) -> str | None: """Adapt a boolean parameter.""" if isinstance(v, str): if not v: @@ -371,7 +391,7 @@ def _adapt_bool(cls, v): return 't' if v else 'f' @classmethod - def _adapt_date(cls, v): + def _adapt_date(cls, v: Any) -> Any: """Adapt a date parameter.""" if not v: return None @@ -380,7 +400,7 @@ def _adapt_date(cls, v): return v @staticmethod - def _adapt_num(v): + def _adapt_num(v: Any) -> Any: """Adapt a numeric parameter.""" if not v and v != 0: return None @@ -388,11 +408,11 @@ def _adapt_num(v): _adapt_int = _adapt_float = _adapt_money = _adapt_num - def _adapt_bytea(self, v): + def _adapt_bytea(self, v: Any) -> str: """Adapt a bytea parameter.""" return self.db.escape_bytea(v) - def _adapt_json(self, v): + def _adapt_json(self, v: Any) -> str | None: """Adapt a json parameter.""" if not v: return None @@ -402,7 +422,7 @@ def _adapt_json(self, v): return str(v) return self.db.encode_json(v) - def _adapt_hstore(self, v): + def _adapt_hstore(self, v: Any) -> str | None: """Adapt a hstore parameter.""" if not v: return None @@ -414,7 +434,7 @@ def _adapt_hstore(self, v): return str(Hstore(v)) raise TypeError(f'Hstore parameter {v} has wrong type') - def _adapt_uuid(self, v): + def _adapt_uuid(self, v: Any) -> str | None: """Adapt a UUID parameter.""" if not v: return None @@ -423,7 +443,7 @@ def _adapt_uuid(self, v): return str(v) @classmethod - def _adapt_text_array(cls, v): + def _adapt_text_array(cls, v: Any) -> str: """Adapt a text type array parameter.""" if isinstance(v, list): adapt = cls._adapt_text_array @@ -441,7 +461,7 @@ def _adapt_text_array(cls, v): _adapt_date_array = _adapt_text_array @classmethod - def _adapt_bool_array(cls, v): + def _adapt_bool_array(cls, v: Any) -> str: """Adapt a boolean array parameter.""" if isinstance(v, list): adapt = cls._adapt_bool_array @@ -455,7 +475,7 @@ def _adapt_bool_array(cls, v): return 't' if v else 'f' @classmethod - def _adapt_num_array(cls, v): + def _adapt_num_array(cls, v: Any) -> str: """Adapt a numeric array parameter.""" if isinstance(v, list): adapt = cls._adapt_num_array @@ -467,7 +487,7 @@ def _adapt_num_array(cls, v): _adapt_int_array = _adapt_float_array = _adapt_money_array = \ _adapt_num_array - def _adapt_bytea_array(self, v): + def _adapt_bytea_array(self, v: Any) -> bytes: """Adapt a bytea array parameter.""" if isinstance(v, list): return b'{' + b','.join( @@ -476,7 +496,7 @@ def _adapt_bytea_array(self, v): return b'null' return self.db.escape_bytea(v).replace(b'\\', b'\\\\') - def _adapt_json_array(self, v): + def _adapt_json_array(self, v: Any) -> str: """Adapt a json array parameter.""" if isinstance(v, list): adapt = self._adapt_json_array @@ -490,7 +510,7 @@ def _adapt_json_array(self, v): v = f'"{v}"' return v - def _adapt_record(self, v, typ): + def _adapt_record(self, v: Any, typ: Any) -> str: """Adapt a record parameter with given type.""" typ = self.get_attnames(typ).values() if len(typ) != len(v): @@ -516,7 +536,7 @@ def _adapt_record(self, v, typ): v = ','.join(value) return f'({v})' - def adapt(self, value, typ=None): + def adapt(self, value: Any, typ: Any = None) -> str: """Adapt a value with known database type.""" if value is not None and not isinstance(value, Literal): if typ: @@ -541,14 +561,14 @@ def adapt(self, value, typ=None): return value @staticmethod - def simple_type(name): + def simple_type(name: str) -> DbType: """Create a simple database type with given attribute names.""" typ = DbType(name) typ.simple = name return typ @staticmethod - def get_simple_name(typ): + def get_simple_name(typ: Any) -> str: """Get the simple name of a database type.""" if isinstance(typ, DbType): # noinspection PyUnresolvedReferences @@ -556,14 +576,14 @@ def get_simple_name(typ): return _simpletypes[typ] @staticmethod - def get_attnames(typ): + def get_attnames(typ: Any) -> dict[str, dict[str, str]]: """Get the attribute names of a composite database type.""" if isinstance(typ, DbType): return typ.attnames return {} @classmethod - def guess_simple_type(cls, value): + def guess_simple_type(cls, value: Any) -> str | None: """Try to guess which database type the given value has.""" # optimize for most frequent types try: @@ -597,16 +617,17 @@ def guess_simple_type(cls, value): guess = cls.guess_simple_type # noinspection PyUnusedLocal - def get_attnames(self): - return AttrDict((str(n + 1), simple_type(guess(v))) + def get_attnames(self: DbType) -> AttrDict: + return AttrDict((str(n + 1), simple_type(guess(v) or 'text')) for n, v in enumerate(value)) typ = simple_type('record') typ._get_attnames = get_attnames return typ + return None @classmethod - def guess_simple_base_type(cls, value): + def guess_simple_base_type(cls, value: Any) -> str | None: """Try to guess the base type of a given array.""" for v in value: if isinstance(v, list): @@ -615,8 +636,9 @@ def guess_simple_base_type(cls, value): typ = cls.guess_simple_type(v) if typ: return typ + return None - def adapt_inline(self, value, nested=False): + def adapt_inline(self, value: Any, nested: bool=False) -> Any: """Adapt a value that is put into the SQL and needs to be quoted.""" if value is None: return 'NULL' @@ -661,7 +683,7 @@ def adapt_inline(self, value, nested=False): value = self.adapt_inline(value) return value - def parameter_list(self): + def parameter_list(self) -> _ParameterList: """Return a parameter list for parameters with known database types. The list has an add(value, typ) method that will build up the @@ -671,7 +693,11 @@ def parameter_list(self): params.adapt = self.adapt return params - def format_query(self, command, values=None, types=None, inline=False): + def format_query(self, command: str, + values: list | tuple | dict | None = None, + types: list | tuple | dict | None = None, + inline: bool=False + ) -> tuple[str, _ParameterList]: """Format a database query using the given values and types. The optional types describe the values and must be passed as a list, @@ -681,15 +707,15 @@ def format_query(self, command, values=None, types=None, inline=False): If inline is set to True, then parameters will be passed inline together with the query string. """ + params = self.parameter_list() if not values: - return command, [] + return command, params if inline and types: raise ValueError('Typed parameters must be sent separately') - params = self.parameter_list() if isinstance(values, (list, tuple)): if inline: adapt = self.adapt_inline - literals = [adapt(value) for value in values] + seq_literals = [adapt(value) for value in values] else: add = params.add if types: @@ -698,52 +724,51 @@ def format_query(self, command, values=None, types=None, inline=False): if (not isinstance(types, (list, tuple)) or len(types) != len(values)): raise TypeError('The values and types do not match') - literals = [add(value, typ) - for value, typ in zip(values, types)] + seq_literals = [add(value, typ) + for value, typ in zip(values, types)] else: - literals = [add(value) for value in values] - command %= tuple(literals) + seq_literals = [add(value) for value in values] + command %= tuple(seq_literals) elif isinstance(values, dict): # we want to allow extra keys in the dictionary, # so we first must find the values actually used in the command used_values = {} - literals = dict.fromkeys(values, '') + map_literals = dict.fromkeys(values, '') for key in values: - del literals[key] + del map_literals[key] try: - command % literals + command % map_literals except KeyError: - used_values[key] = values[key] - literals[key] = '' - values = used_values + used_values[key] = values[key] # pyright: ignore + map_literals[key] = '' if inline: adapt = self.adapt_inline - literals = {key: adapt(value) - for key, value in values.items()} + map_literals = {key: adapt(value) + for key, value in used_values.items()} else: add = params.add if types: if not isinstance(types, dict): raise TypeError('The values and types do not match') - literals = {key: add(values[key], types.get(key)) - for key in sorted(values)} + map_literals = {key: add(used_values[key], types.get(key)) + for key in sorted(used_values)} else: - literals = {key: add(values[key]) - for key in sorted(values)} - command %= literals + map_literals = {key: add(used_values[key]) + for key in sorted(used_values)} + command %= map_literals else: raise TypeError('The values must be passed as tuple, list or dict') return command, params -def cast_bool(value): +def cast_bool(value: str) -> Any: """Cast a boolean value.""" if not get_bool(): return value return value[0] == 't' -def cast_json(value): +def cast_json(value: str) -> Any: """Cast a JSON value.""" cast = get_jsondecode() if not cast: @@ -751,12 +776,12 @@ def cast_json(value): return cast(value) -def cast_num(value): +def cast_num(value: str) -> Any: """Cast a numeric value.""" return (get_decimal() or float)(value) -def cast_money(value): +def cast_money(value: str) -> Any: """Cast a money value.""" point = get_decimal_point() if not point: @@ -768,12 +793,12 @@ def cast_money(value): return (get_decimal() or float)(value) -def cast_int2vector(value): +def cast_int2vector(value: str) -> list[int]: """Cast an int2vector value.""" return [int(v) for v in value.split()] -def cast_date(value, connection): +def cast_date(value: str, connection: DB) -> Any: """Cast a date value.""" # The output format depends on the server setting DateStyle. The default # setting ISO and the setting for German are actually unambiguous. The @@ -783,93 +808,93 @@ def cast_date(value, connection): return date.min if value == 'infinity': return date.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return date.min - value = value[0] + value = values[0] if len(value) > 10: return date.max - fmt = connection.date_format() - return datetime.strptime(value, fmt).date() + format = connection.date_format() + return datetime.strptime(value, format).date() -def cast_time(value): +def cast_time(value: str) -> Any: """Cast a time value.""" - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, fmt).time() + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, format).time() _re_timezone = regex('(.*)([+-].*)') -def cast_timetz(value): +def cast_timetz(value: str) -> Any: """Cast a timetz value.""" - tz = _re_timezone.match(value) - if tz: - value, tz = tz.groups() + m = _re_timezone.match(value) + if m: + value, tz = m.groups() else: tz = '+0000' - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() + format += '%z' + return datetime.strptime(value, format).timetz() -def cast_timestamp(value, connection): +def cast_timestamp(value: str, connection: DB) -> Any: """Cast a timestamp value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:5] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] else: - if len(value[0]) > 10: + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) -def cast_timestamptz(value, connection): +def cast_timestamptz(value: str, connection: DB) -> Any: """Cast a timestamptz value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - value, tz = value[:-1], value[-1] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] else: - if fmt.startswith('%Y-'): - tz = _re_timezone.match(value[1]) - if tz: - value[1], tz = tz.groups() + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() else: tz = '+0000' else: - value, tz = value[:-1], value[-1] - if len(value[0]) > 10: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(_timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) _re_interval_sql_standard = regex( @@ -900,37 +925,37 @@ def cast_timestamptz(value, connection): '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') -def cast_interval(value): +def cast_interval(value: str) -> timedelta: """Cast an interval value.""" # The output format depends on the server setting IntervalStyle, but it's # not necessary to consult this setting to parse it. It's faster to just # check all possible formats, and there is no ambiguity here. m = _re_interval_iso_8601.match(value) if m: - m = [d or '0' for d in m.groups()] - secs_ago = m.pop(5) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = -secs usecs = -usecs else: m = _re_interval_postgres_verbose.match(value) if m: - m, ago = [d or '0' for d in m.groups()[:8]], m.group(9) - secs_ago = m.pop(5) == '-' - m = [-int(d) for d in m] if ago else [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = - secs usecs = -usecs else: m = _re_interval_postgres.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if hours_ago: hours = -hours mins = -mins @@ -939,11 +964,11 @@ def cast_interval(value): else: m = _re_interval_sql_standard.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - years_ago = m.pop(0) == '-' - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if years_ago: years = -years mons = -mons @@ -973,7 +998,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults: ClassVar[Dict[str, Type]] = { + defaults: ClassVar[dict[str, Callable]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, @@ -985,11 +1010,11 @@ class Typecasts(dict): 'time': cast_time, 'timetz': cast_timetz, 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, 'int2vector': cast_int2vector, 'uuid': UUID, - 'anyarray': cast_array, 'record': cast_record} + 'anyarray': cast_array, 'record': cast_record} # pyright: ignore - connection = None # will be set in a connection specific instance + connection: DB | None = None # set in a connection specific instance - def __missing__(self, typ): + def __missing__(self, typ: Any) -> Callable | None: """Create a cast function if it is not cached. Note that this class never raises a KeyError, @@ -997,7 +1022,7 @@ def __missing__(self, typ): """ if not isinstance(typ, str): raise TypeError(f'Invalid type: {typ}') - cast = self.defaults.get(typ) + cast: Callable | None = self.defaults.get(typ) if cast: # store default for faster access cast = self._add_connection(cast) @@ -1016,7 +1041,7 @@ def __missing__(self, typ): return cast @staticmethod - def _needs_connection(func): + def _needs_connection(func: Callable) -> bool: """Check if a typecast function needs a connection argument.""" try: args = get_args(func) @@ -1025,17 +1050,17 @@ def _needs_connection(func): else: return 'connection' in args[1:] - def _add_connection(self, cast): + def _add_connection(self, cast: Callable) -> Callable: """Add a connection argument to the typecast function if necessary.""" if not self.connection or not self._needs_connection(cast): return cast return partial(cast, connection=self.connection) - def get(self, typ, default=None): + def get(self, typ: Any, default: Any = None) -> Any: """Get the typecast function for the given database type.""" return self[typ] or default - def set(self, typ, cast): + def set(self, typ: Any, cast: Callable) -> None: """Set a typecast function for the specified database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1050,7 +1075,7 @@ def set(self, typ, cast): self[t] = self._add_connection(cast) self.pop(f'_{t}', None) - def reset(self, typ=None): + def reset(self, typ: Any = None) -> None: """Reset the typecasts for the specified type(s) to their defaults. When no type is specified, all typecasts will be reset. @@ -1064,12 +1089,12 @@ def reset(self, typ=None): self.pop(t, None) @classmethod - def get_default(cls, typ): + def get_default(cls, typ: Any) -> Any: """Get the default typecast function for the given database type.""" return cls.defaults.get(typ) @classmethod - def set_default(cls, typ, cast): + def set_default(cls, typ: Any, cast: Callable | None) -> None: """Set a default typecast function for the given database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1086,46 +1111,47 @@ def set_default(cls, typ, cast): defaults.pop(f'_{t}', None) # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_attnames(self, typ): + def get_attnames(self, typ: Any) -> AttrDict: """Return the fields for the given record type. This method will be replaced with the get_attnames() method of DbTypes. """ - return {} + return AttrDict() # noinspection PyMethodMayBeStatic - def dateformat(self): + def dateformat(self) -> str: """Return the current date format. This method will be replaced with the dateformat() method of DbTypes. """ return '%Y-%m-%d' - def create_array_cast(self, basecast): + def create_array_cast(self, basecast: Callable) -> Callable: """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] - def cast(v): + def cast(v: Any) -> Callable: return cast_array(v, basecast) return cast - def create_record_cast(self, name, fields, casts): + def create_record_cast(self, name: str, fields: AttrDict, + casts: list[Callable]) -> Callable: """Create a named record typecast for the given fields and casts.""" cast_record = self['record'] - record = namedtuple(name, fields) + record = namedtuple(name, fields) # type: ignore - def cast(v): + def cast(v: Any) -> record: # noinspection PyArgumentList return record(*cast_record(v, casts)) return cast -def get_typecast(typ): +def get_typecast(typ: Any) -> Callable | None: """Get the global typecast function for the given database type(s).""" return Typecasts.get_default(typ) -def set_typecast(typ, cast): +def set_typecast(typ: Any, cast: Callable | None) -> None: """Set a global typecast function for the given database type(s). Note that connections cache cast functions. To be sure a global change @@ -1161,10 +1187,10 @@ class DbType(str): delim: str relid: int - _get_attnames: Callable + _get_attnames: Callable[[DbType], AttrDict] @property - def attnames(self): + def attnames(self) -> AttrDict: """Get names and types of the fields of a composite type.""" # noinspection PyUnresolvedReferences return self._get_attnames(self) @@ -1180,13 +1206,13 @@ class DbTypes(dict): _num_types = frozenset('int float num money int2 int4 int8' ' float4 float8 numeric money'.split()) - def __init__(self, db): + def __init__(self, db: DB) -> None: """Initialize type cache for connection.""" super().__init__() self._db = weakref.proxy(db) self._regtypes = False self._typecasts = Typecasts() - self._typecasts.get_attnames = self.get_attnames + self._typecasts.get_attnames = self.get_attnames # type: ignore self._typecasts.connection = self._db self._query_pg_type = ( "SELECT oid, typname, oid::pg_catalog.regtype," @@ -1194,8 +1220,9 @@ def __init__(self, db): " FROM pg_catalog.pg_type" " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") - def add(self, oid, pgtype, regtype, - typlen, typtype, category, delim, relid): + def add(self, oid: int, pgtype: str, regtype: str, + typlen: int, typtype: str, category: str, delim: str, relid: int + ) -> DbType: """Create a PostgreSQL type name with additional info.""" if oid in self: return self[oid] @@ -1210,14 +1237,14 @@ def add(self, oid, pgtype, regtype, typ.category = category typ.delim = delim typ.relid = relid - typ._get_attnames = self.get_attnames + typ._get_attnames = self.get_attnames # type: ignore return typ - def __missing__(self, key): + def __missing__(self, key: int | str) -> DbType: """Get the type info from the database if it is not cached.""" try: - q = self._query_pg_type.format(_quote_if_unqualified('$1', key)) - res = self._db.query(q, (key,)).getresult() + cmd = self._query_pg_type.format(_quote_if_unqualified('$1', key)) + res = self._db.query(cmd, (key,)).getresult() except ProgrammingError: res = None if not res: @@ -1227,14 +1254,14 @@ def __missing__(self, key): self[typ.oid] = self[typ.pgtype] = typ return typ - def get(self, key, default=None): + def get(self, key: int | str, default: Any = None) -> Any: """Get the type even if it is not cached.""" try: return self[key] except KeyError: return default - def get_attnames(self, typ): + def get_attnames(self, typ: Any) -> AttrDict | None: """Get names and types of the fields of a composite type.""" if not isinstance(typ, DbType): typ = self.get(typ) @@ -1244,19 +1271,19 @@ def get_attnames(self, typ): return None return self._db.get_attnames(typ.relid, with_oid=False) - def get_typecast(self, typ): + def get_typecast(self, typ: Any) -> Callable: """Get the typecast function for the given database type.""" return self._typecasts.get(typ) - def set_typecast(self, typ, cast): + def set_typecast(self, typ: Any, cast: Callable) -> None: """Set a typecast function for the specified database type(s).""" self._typecasts.set(typ, cast) - def reset_typecast(self, typ=None): + def reset_typecast(self, typ: Any = None) -> None: """Reset the typecast function for the specified database type(s).""" self._typecasts.reset(typ) - def typecast(self, value, typ): + def typecast(self, value: Any, typ: Any) -> Callable | None: """Cast the given value according to the given database type.""" if value is None: # for NULL values, no typecast is necessary @@ -1272,25 +1299,22 @@ def typecast(self, value, typ): return cast(value) -_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') - - # The result rows for database operations are returned as named tuples # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. # noinspection PyUnresolvedReferences @lru_cache(maxsize=1024) -def _row_factory(names): +def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: """Get a namedtuple factory for row results with the given names.""" try: - return namedtuple('Row', names, rename=True)._make + return namedtuple('Row', names, rename=True)._make # type: ignore except ValueError: # there is still a problem with the field names names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make + return namedtuple('Row', names)._make # type: ignore -def set_row_factory_size(maxsize): +def set_row_factory_size(maxsize: int | None) -> None: """Change the size of the namedtuple factory cache. If maxsize is set to None, the cache can grow without bound. @@ -1302,26 +1326,26 @@ def set_row_factory_size(maxsize): # Helper functions used by the query object -def _dictiter(q): +def _dictiter(q: Query) -> Generator[dict[str, Any], None, None]: """Get query result as an iterator of dictionaries.""" - fields = q.listfields() + fields: tuple[str, ...] = q.listfields() for r in q: yield dict(zip(fields, r)) -def _namediter(q): +def _namediter(q: Query) -> Generator[NamedTuple, None, None]: """Get query result as an iterator of named tuples.""" row = _row_factory(q.listfields()) for r in q: yield row(r) -def _namednext(q): +def _namednext(q: Query) -> NamedTuple: """Get next row from query result as a named tuple.""" return _row_factory(q.listfields())(next(q)) -def _scalariter(q): +def _scalariter(q: Query) -> Generator[Any, None, None]: """Get query result as an iterator of scalar values.""" for r in q: yield r[0] @@ -1330,36 +1354,41 @@ def _scalariter(q): class _MemoryQuery: """Class that embodies a given query result.""" - def __init__(self, result, fields): + result: Any + fields: tuple[str, ...] + + def __init__(self, result: Any, fields: Sequence[str]) -> None: """Create query from given result rows and field names.""" self.result = result self.fields = tuple(fields) - def listfields(self): + def listfields(self) -> tuple[str, ...]: """Return the stored field names of this query.""" return self.fields - def getresult(self): + def getresult(self) -> Any: """Return the stored result of this query.""" return self.result - def __iter__(self): + def __iter__(self) -> Iterator[Any]: return iter(self.result) -def _db_error(msg, cls=DatabaseError): +E = TypeVar('E', bound=DatabaseError) + +def _db_error(msg: str, cls:type[E] = DatabaseError) -> type[E]: """Return DatabaseError with empty sqlstate attribute.""" error = cls(msg) error.sqlstate = None return error -def _int_error(msg): +def _int_error(msg: str) -> InternalError: """Return InternalError.""" return _db_error(msg, InternalError) -def _prg_error(msg): +def _prg_error(msg: str) -> ProgrammingError: """Return ProgrammingError.""" return _db_error(msg, ProgrammingError) @@ -1376,8 +1405,10 @@ def _prg_error(msg): class NotificationHandler: """A PostgreSQL client-side asynchronous notification handler.""" - def __init__(self, db, event, callback=None, - arg_dict=None, timeout=None, stop_event=None): + def __init__(self, db: DB, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None): """Initialize the notification handler. You must pass a PyGreSQL database connection, the name of an @@ -1395,7 +1426,7 @@ def __init__(self, db, event, callback=None, the handler to stop listening as stop_event. By default, it will be the event name prefixed with 'stop_'. """ - self.db = db + self.db: DB | None = db self.event = event self.stop_event = stop_event or f'stop_{event}' self.listening = False @@ -1405,32 +1436,35 @@ def __init__(self, db, event, callback=None, self.arg_dict = arg_dict self.timeout = timeout - def __del__(self): + def __del__(self) -> None: """Delete the notification handler.""" self.unlisten() - def close(self): + def close(self) -> None: """Stop listening and close the connection.""" if self.db: self.unlisten() self.db.close() self.db = None - def listen(self): + def listen(self) -> None: """Start listening for the event and the stop event.""" - if not self.listening: - self.db.query(f'listen "{self.event}"') - self.db.query(f'listen "{self.stop_event}"') + db = self.db + if db and not self.listening: + db.query(f'listen "{self.event}"') + db.query(f'listen "{self.stop_event}"') self.listening = True - def unlisten(self): + def unlisten(self) -> None: """Stop listening for the event and the stop event.""" - if self.listening: - self.db.query(f'unlisten "{self.event}"') - self.db.query(f'unlisten "{self.stop_event}"') + db = self.db + if db and self.listening: + db.query(f'unlisten "{self.event}"') + db.query(f'unlisten "{self.stop_event}"') self.listening = False - def notify(self, db=None, stop=False, payload=None): + def notify(self, db: DB | None = None, stop: bool = False, + payload: str | None = None) -> None: """Generate a notification. Optionally, you can pass a payload with the notification. @@ -1445,13 +1479,15 @@ def notify(self, db=None, stop=False, payload=None): if self.listening: if not db: db = self.db + if not db: + return event = self.stop_event if stop else self.event - q = f'notify "{event}"' + cmd = f'notify "{event}"' if payload: - q += f", '{payload}'" - return db.query(q) + cmd += f", '{payload}'" + return db.query(cmd) - def __call__(self): + def __call__(self) -> None: """Invoke the notification handler. The handler is a loop that listens for notifications on the event @@ -1469,14 +1505,15 @@ def __call__(self): Note: If you run this loop in another thread, don't use the same database connection for database operations in the main thread. """ + if not self.db: + return self.listen() poll = self.timeout == 0 - if not poll: - rlist = [self.db.fileno()] - while self.listening: + rlist = [] if poll else [self.db.fileno()] + while self.db and self.listening: # noinspection PyUnboundLocalVariable if poll or select.select(rlist, [], [], self.timeout)[0]: - while self.listening: + while self.db and self.listening: notice = self.db.getnotify() if not notice: # no more messages break @@ -1503,9 +1540,9 @@ def __call__(self): class DB: """Wrapper class for the _pg connection type.""" - db = None # invalid fallback for underlying connection + db: Connection | None = None # invalid fallback for underlying connection - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any) -> None: """Create a new connection. You can pass either the connection parameters or an existing @@ -1535,10 +1572,10 @@ def __init__(self, *args, **kw): self.db = db self.dbname = db.db self._regtypes = False - self._attnames = {} - self._generated = {} - self._pkeys = {} - self._privileges = {} + self._attnames: dict[str, AttrDict] = {} + self._generated: dict[str, frozenset[str]] = {} + self._pkeys: dict[str, str | tuple[str, ...]] = {} + self._privileges: dict[tuple[str, str], bool] = {} self.adapter = Adapter(self) self.dbtypes = DbTypes(self) self._query_attnames = ( @@ -1566,9 +1603,9 @@ def __init__(self, *args, **kw): # * to a file object to write debug statements or # * to a callable object which takes a string argument # * to any other true value to just print debug statements - self.debug = None + self.debug: Any = None - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Get the specified attritbute of the connection.""" # All undefined members are same as in underlying connection: if self.db: @@ -1576,7 +1613,7 @@ def __getattr__(self, name): else: raise _int_error('Connection is not valid') - def __dir__(self): + def __dir__(self) -> list[str]: """List all attributes of the connection.""" # Custom dir function including the attributes of the connection: attrs = set(self.__class__.__dict__) @@ -1586,19 +1623,20 @@ def __dir__(self): # Context manager methods - def __enter__(self): + def __enter__(self) -> DB: """Enter the runtime context. This will start a transaction.""" self.begin() return self - def __exit__(self, et, ev, tb): + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: """Exit the runtime context. This will end the transaction.""" if et is None and ev is None and tb is None: self.commit() else: self.rollback() - def __del__(self): + def __del__(self) -> None: """Delete the connection.""" try: db = self.db @@ -1613,7 +1651,7 @@ def __del__(self): # Auxiliary methods - def _do_debug(self, *args): + def _do_debug(self, *args: Any) -> None: """Print a debug message.""" if self.debug: s = '\n'.join(str(arg) for arg in args) @@ -1627,7 +1665,7 @@ def _do_debug(self, *args): else: print(s) - def _escape_qualified_name(self, s): + def _escape_qualified_name(self, s: str) -> str: """Escape a qualified name. Escapes the name for use as an SQL identifier, unless the @@ -1640,15 +1678,23 @@ def _escape_qualified_name(self, s): return s @staticmethod - def _make_bool(d): + def _make_bool(d: Any) -> bool | str: """Get boolean value corresponding to d.""" return bool(d) if get_bool() else ('t' if d else 'f') @staticmethod - def _list_params(params): + def _list_params(params: Sequence) -> str: """Create a human readable parameter list.""" return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) + @property + def _valid_db(self) -> Connection: + """Get underlying connection and make sure it is not closed.""" + db = self.db + if not db: + raise _int_error('Connection already closed') + return db + # Public methods # escape_string and escape_bytea exist as methods, @@ -1656,46 +1702,38 @@ def _list_params(params): unescape_bytea = staticmethod(unescape_bytea) @staticmethod - def decode_json(s): + def decode_json(s: str) -> Any: """Decode a JSON string coming from the database.""" return (get_jsondecode() or jsondecode)(s) @staticmethod - def encode_json(d): + def encode_json(d: Any) -> str: """Encode a JSON string for use within SQL.""" return jsonencode(d) - def close(self): + def close(self) -> None: """Close the database connection.""" # Wraps shared library function so we can track state. - db = self.db - if db: - with suppress(TypeError): # when already closed - db.set_cast_hook(None) - if self._closeable: - db.close() - self.db = None - else: - raise _int_error('Connection already closed') + db = self._valid_db + with suppress(TypeError): # when already closed + db.set_cast_hook(None) + if self._closeable: + db.close() + self.db = None - def reset(self): + def reset(self) -> None: """Reset connection with current parameters. All derived queries and large objects derived from this connection will not be usable after this call. + """ + self._valid_db.reset() - """ - if self.db: - self.db.reset() - else: - raise _int_error('Connection already closed') - - def reopen(self): + def reopen(self) -> None: """Reopen connection to the database. Used in case we need another connection to the same database. Note that we can still reopen a database that we have closed. - """ # There is no such shared library function. if self._closeable: @@ -1708,7 +1746,7 @@ def reopen(self): else: self.db = self._db_args - def begin(self, mode=None): + def begin(self, mode: str | None = None) -> None: """Begin a transaction.""" qstr = 'BEGIN' if mode: @@ -1717,13 +1755,13 @@ def begin(self, mode=None): start = begin - def commit(self): + def commit(self) -> None: """Commit the current transaction.""" return self.query('COMMIT') end = commit - def rollback(self, name=None): + def rollback(self, name: str | None = None) -> None: """Roll back the current transaction.""" qstr = 'ROLLBACK' if name: @@ -1732,15 +1770,18 @@ def rollback(self, name=None): abort = rollback - def savepoint(self, name): + def savepoint(self, name: str) -> None: """Define a new savepoint within the current transaction.""" return self.query('SAVEPOINT ' + name) - def release(self, name): + def release(self, name: str) -> None: """Destroy a previously defined savepoint.""" return self.query('RELEASE ' + name) - def get_parameter(self, parameter): + def get_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any] + ) -> str | list[str] | dict[str, str]: """Get the value of a run-time parameter. If the parameter is a string, the return value will also be a string @@ -1757,6 +1798,7 @@ def get_parameter(self, parameter): By passing the special name 'all' as the parameter, you can get a dict of all existing configuration parameters. """ + values: Any if isinstance(parameter, str): parameter = [parameter] values = None @@ -1771,25 +1813,26 @@ def get_parameter(self, parameter): 'The parameter must be a string, list, set or dict') if not parameter: raise TypeError('No parameter has been specified') - params = {} if isinstance(values, dict) else [] - for key in parameter: - param = key.strip().lower() if isinstance( - key, (bytes, str)) else None + query = self._valid_db.query + params: Any = {} if isinstance(values, dict) else [] + for param_key in parameter: + param = param_key.strip().lower() if isinstance( + param_key, (bytes, str)) else None if not param: raise TypeError('Invalid parameter') if param == 'all': - q = 'SHOW ALL' - values = self.db.query(q).getresult() + cmd = 'SHOW ALL' + values = query(cmd).getresult() values = {value[0]: value[1] for value in values} break - if isinstance(values, dict): - params[param] = key + if isinstance(params, dict): + params[param] = param_key else: params.append(param) else: for param in params: - q = f'SHOW {param}' - value = self.db.query(q).singlescalar() + cmd = f'SHOW {param}' + value = query(cmd).singlescalar() if values is None: values = value elif isinstance(values, list): @@ -1798,7 +1841,12 @@ def get_parameter(self, parameter): values[params[param]] = value return values - def set_parameter(self, parameter, value=None, local=False): + def set_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any], + value: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str]| None = None, + local: bool = False) -> None: """Set the value of a run-time parameter. If the parameter and the value are strings, the run-time parameter @@ -1833,7 +1881,7 @@ def set_parameter(self, parameter, value=None, local=False): if isinstance(value, (list, tuple, set, frozenset)): value = set(value) if len(value) == 1: - value = value.pop() + value = next(iter(value)) if not (value is None or isinstance(value, str)): raise ValueError( 'A single value must be specified' @@ -1849,30 +1897,28 @@ def set_parameter(self, parameter, value=None, local=False): 'The parameter must be a string, list, set or dict') if not parameter: raise TypeError('No parameter has been specified') - params = {} - for key, value in parameter.items(): - param = key.strip().lower() if isinstance( - key, str) else None + params: dict[str, str | None] = {} + for param, param_value in parameter.items(): + param = param.strip().lower() if isinstance(param, str) else None if not param: raise TypeError('Invalid parameter') if param == 'all': - if value is not None: + if param_value is not None: raise ValueError( - 'A value must ot be specified' + 'A value must not be specified' " when parameter is 'all'") params = {'all': None} break - params[param] = value - local = ' LOCAL' if local else '' - for param, value in params.items(): - if value is None: - q = f'RESET{local} {param}' - else: - q = f'SET{local} {param} TO {value}' - self._do_debug(q) - self.db.query(q) - - def query(self, command, *args): + params[param] = param_value + local_clause = ' LOCAL' if local else '' + for param, param_value in params.items(): + cmd = (f'RESET{local_clause} {param}' + if param_value is None else + f'SET{local_clause} {param} TO {param_value}') + self._do_debug(cmd) + self._valid_db.query(cmd) + + def query(self, command: str, *args: Any) -> Query: """Execute a SQL command string. This method simply sends a SQL query to the database. If the query is @@ -1892,16 +1938,17 @@ def query(self, command, *args): values can also be given as a single list or tuple argument. """ # Wraps shared library function for debugging. - if not self.db: - raise _int_error('Connection is not valid') + db = self._valid_db if args: self._do_debug(command, args) - return self.db.query(command, args) + return db.query(command, args) self._do_debug(command) - return self.db.query(command) + return db.query(command) - def query_formatted(self, command, - parameters=None, types=None, inline=False): + def query_formatted(self, command: str, + parameters: tuple | list | dict | None = None, + types: tuple | list | dict | None = None, + inline: bool =False) -> Query: """Execute a formatted SQL command string. Similar to query, but using Python format placeholders of the form @@ -1916,24 +1963,23 @@ def query_formatted(self, command, return self.query(*self.adapter.format_query( command, parameters, types, inline)) - def query_prepared(self, name, *args): + def query_prepared(self, name: str, *args: Any) -> Query: """Execute a prepared SQL statement. This works like the query() method, except that instead of passing the SQL command, you pass the name of a prepared statement. If you pass an empty name, the unnamed statement will be executed. """ - if not self.db: - raise _int_error('Connection is not valid') if name is None: name = '' + db = self._valid_db if args: self._do_debug('EXECUTE', name, args) - return self.db.query_prepared(name, args) + return db.query_prepared(name, args) self._do_debug('EXECUTE', name) - return self.db.query_prepared(name) + return db.query_prepared(name) - def prepare(self, name, command): + def prepare(self, name: str, command: str) -> Query: """Create a prepared SQL statement. This creates a prepared statement for the given command with the @@ -1946,14 +1992,12 @@ def prepare(self, name, command): named queries, since unnamed queries have a limited lifetime and can be automatically replaced or destroyed by various operations. """ - if not self.db: - raise _int_error('Connection is not valid') if name is None: name = '' self._do_debug('prepare', name, command) - return self.db.prepare(name, command) + return self._valid_db.prepare(name, command) - def describe_prepared(self, name=None): + def describe_prepared(self, name: str | None = None) -> Query: """Describe a prepared SQL statement. This method returns a Query object describing the result columns of @@ -1962,9 +2006,9 @@ def describe_prepared(self, name=None): """ if name is None: name = '' - return self.db.describe_prepared(name) + return self._valid_db.describe_prepared(name) - def delete_prepared(self, name=None): + def delete_prepared(self, name: str | None = None) -> Query: """Delete a prepared SQL statement. This deallocates a previously prepared SQL statement with the given @@ -1974,12 +2018,13 @@ def delete_prepared(self, name=None): """ if not name: name = 'ALL' - q = f"DEALLOCATE {name}" - self._do_debug(q) - return self.db.query(q) + cmd = f"DEALLOCATE {name}" + self._do_debug(cmd) + return self._valid_db.query(cmd) - def pkey(self, table, composite=False, flush=False): - """Get or set the primary key of a table. + def pkey(self, table: str, composite: bool = False, flush: bool = False + ) -> str | tuple[str, ...]: + """Get the primary key of a table. Single primary keys are returned as strings unless you set the composite flag. Composite primary keys are always @@ -1997,26 +2042,26 @@ def pkey(self, table, composite=False, flush=False): try: # cache lookup pkey = pkeys[table] except KeyError as e: # cache miss, check the database - q = ("SELECT" # noqa: S608 - " a.attname, a.attnum, i.indkey" - " FROM pg_catalog.pg_index i" - " JOIN pg_catalog.pg_attribute a" - " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" - " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" - " AND NOT a.attisdropped" - " WHERE i.indrelid OPERATOR(pg_catalog.=)" - " {}::pg_catalog.regclass" - " AND i.indisprimary ORDER BY a.attnum").format( - _quote_if_unqualified('$1', table)) - pkey = self.db.query(q, (table,)).getresult() + cmd = ("SELECT" # noqa: S608 + " a.attname, a.attnum, i.indkey" + " FROM pg_catalog.pg_index i" + " JOIN pg_catalog.pg_attribute a" + " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" + " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" + " AND NOT a.attisdropped" + " WHERE i.indrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND i.indisprimary ORDER BY a.attnum").format( + _quote_if_unqualified('$1', table)) + pkey = self._valid_db.query(cmd, (table,)).getresult() if not pkey: raise KeyError(f'Table {table} has no primary key') from e # we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table if len(pkey) > 1: indkey = pkey[0][2] - pkey = sorted(pkey, key=lambda row: indkey.index(row[1])) - pkey = tuple(row[0] for row in pkey) + pkey = tuple(row[0] for row in sorted( + pkey, key=lambda row: indkey.index(row[1]))) else: pkey = pkey[0][0] pkeys[table] = pkey # cache it @@ -2024,12 +2069,20 @@ def pkey(self, table, composite=False, flush=False): pkey = (pkey,) return pkey - def get_databases(self): + def pkeys(self, table: str) -> tuple[str, ...]: + """Get the primary key of a table as a tuple. + + Same as pkey() with 'composite' set to True. + """ + return self.pkey(table, True) # type: ignore + + def get_databases(self) -> list[str]: """Get list of databases in the system.""" - return [s[0] for s in self.db.query( + return [r[0] for r in self._valid_db.query( 'SELECT datname FROM pg_catalog.pg_database').getresult()] - def get_relations(self, kinds=None, system=False): + def get_relations(self, kinds: str | Sequence[str] | None = None, + system: bool = False) -> list[str]: """Get list of relations in connected database of specified kinds. If kinds is None or empty, all kinds of relations are returned. @@ -2038,31 +2091,32 @@ def get_relations(self, kinds=None, system=False): Set the system flag if you want to get the system relations as well. """ - where = [] + where_parts = [] if kinds: - where.append( + where_parts.append( "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) if not system: - where.append("s.nspname NOT SIMILAR" - " TO 'pg/_%|information/_schema' ESCAPE '/'") - where = " WHERE " + ' AND '.join(where) if where else '' - q = ("SELECT" # noqa: S608 - " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" - " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" - " FROM pg_catalog.pg_class r" - " JOIN pg_catalog.pg_namespace s" - f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" - " ORDER BY s.nspname, r.relname") - return [r[0] for r in self.db.query(q).getresult()] - - def get_tables(self, system=False): + where_parts.append("s.nspname NOT SIMILAR" + " TO 'pg/_%|information/_schema' ESCAPE '/'") + where = " WHERE " + ' AND '.join(where_parts) if where_parts else '' + cmd = ("SELECT" # noqa: S608 + " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" + " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" + " FROM pg_catalog.pg_class r" + " JOIN pg_catalog.pg_namespace s" + f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" + " ORDER BY s.nspname, r.relname") + return [r[0] for r in self._valid_db.query(cmd).getresult()] + + def get_tables(self, system: bool = False) -> list[str]: """Return list of tables in connected database. Set the system flag if you want to get the system tables as well. """ return self.get_relations('r', system) - def get_attnames(self, table, with_oid=True, flush=False): + def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False + ) -> AttrDict: """Given the name of a table, dig out the set of attribute names. Returns a read-only dictionary of attribute names (the names are @@ -2083,19 +2137,18 @@ def get_attnames(self, table, with_oid=True, flush=False): try: # cache lookup names = attnames[table] except KeyError: # cache miss, check the database - q = "a.attnum OPERATOR(pg_catalog.>) 0" + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" if with_oid: - q = f"({q} OR a.attname OPERATOR(pg_catalog.=) 'oid')" - q = self._query_attnames.format( - _quote_if_unqualified('$1', table), q) - names = self.db.query(q, (table,)).getresult() + cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" + cmd = self._query_attnames.format( + _quote_if_unqualified('$1', table), cmd) + names = self._valid_db.query(cmd, (table,)).getresult() types = self.dbtypes - names = ((name[0], types.add(*name[1:])) for name in names) - names = AttrDict(names) + names = AttrDict((name[0], types.add(*name[1:])) for name in names) attnames[table] = names # cache it return names - def get_generated(self, table, flush=False): + def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: """Given the name of a table, dig out the set of generated columns. Returns a set of column names that are generated and unalterable. @@ -2111,28 +2164,28 @@ def get_generated(self, table, flush=False): try: # cache lookup names = generated[table] except KeyError: # cache miss, check the database - q = "a.attnum OPERATOR(pg_catalog.>) 0" - q = f"{q} AND {self._query_generated}" - q = self._query_attnames.format( - _quote_if_unqualified('$1', table), q) - names = self.db.query(q, (table,)).getresult() + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + cmd = f"{cmd} AND {self._query_generated}" + cmd = self._query_attnames.format( + _quote_if_unqualified('$1', table), cmd) + names = self._valid_db.query(cmd, (table,)).getresult() names = frozenset(name[0] for name in names) generated[table] = names # cache it return names - def use_regtypes(self, regtypes=None): + def use_regtypes(self, regtypes: bool | None = None) -> bool: """Use registered type names instead of simplified type names.""" if regtypes is None: return self.dbtypes._regtypes - else: - regtypes = bool(regtypes) - if regtypes != self.dbtypes._regtypes: - self.dbtypes._regtypes = regtypes - self._attnames.clear() - self.dbtypes.clear() - return regtypes - - def has_table_privilege(self, table, privilege='select', flush=False): + regtypes = bool(regtypes) + if regtypes != self.dbtypes._regtypes: + self.dbtypes._regtypes = regtypes + self._attnames.clear() + self.dbtypes.clear() + return regtypes + + def has_table_privilege(self, table: str, privilege: str = 'select', + flush: bool = False) -> bool: """Check whether current user has specified table privilege. If flush is set, then the internal cache for table privileges will @@ -2146,14 +2199,15 @@ def has_table_privilege(self, table, privilege='select', flush=False): try: # ask cache ret = privileges[table, privilege] except KeyError: # cache miss, ask the database - q = "SELECT pg_catalog.has_table_privilege({}, $2)".format( + cmd = "SELECT pg_catalog.has_table_privilege({}, $2)".format( _quote_if_unqualified('$1', table)) - q = self.db.query(q, (table, privilege)) - ret = q.singlescalar() == self._make_bool(True) + query = self._valid_db.query(cmd, (table, privilege)) + ret = query.singlescalar() == self._make_bool(True) privileges[table, privilege] = ret # cache it return ret - def get(self, table, row, keyname=None): + def get(self, table: str, row: Any, + keyname: str | tuple[str, ...] | None = None) -> dict[str, Any]: """Get a row from a database table or view. This method is the basic mechanism to get a single row. It assumes @@ -2181,7 +2235,7 @@ def get(self, table, row, keyname=None): row['oid'] = row[qoid] if not keyname: try: # if keyname is not specified, try using the primary key - keyname = self.pkey(table, True) + keyname = self.pkeys(table) except KeyError as e: # the table has no primary key # try using the oid instead if qoid and isinstance(row, dict) and 'oid' in row: @@ -2216,10 +2270,10 @@ def get(self, table, row, keyname=None): row[qoid] = row['oid'] del row['oid'] t = self._escape_qualified_name(table) - q = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if not res: # make where clause in error message better readable where = where.replace('OPERATOR(pg_catalog.=)', '=') @@ -2232,7 +2286,8 @@ def get(self, table, row, keyname=None): row[n] = value return row - def insert(self, table, row=None, **kw): + def insert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: """Insert a row into a database table. This method inserts a row into a table. The name of the table must @@ -2258,21 +2313,21 @@ def insert(self, table, row=None, **kw): params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - names, values = [], [] + name_list, value_list = [], [] for n in attnames: if n in row and n not in generated: - names.append(col(n)) - values.append(adapt(row[n], attnames[n])) - if not names: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + if not name_list: raise _prg_error('No column found that can be inserted') - names, values = ', '.join(names), ', '.join(values) + names, values = ', '.join(name_list), ', '.join(value_list) ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = (f'INSERT INTO {t} ({names})' # noqa: S608 - f' VALUES ({values}) RETURNING {ret}') - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = (f'INSERT INTO {t} ({names})' # noqa: S608 + f' VALUES ({values}) RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if res: # this should always be true for n, value in res[0].items(): if qoid and n == 'oid': @@ -2280,7 +2335,8 @@ def insert(self, table, row=None, **kw): row[n] = value return row - def update(self, table, row=None, **kw): + def update(self, table: str, row: dict[str, Any] | None = None, **kw : Any + ) -> dict[str, Any]: """Update an existing row in a database table. Similar to insert, but updates an existing row. The update is based @@ -2304,39 +2360,40 @@ def update(self, table, row=None, **kw): if qoid and qoid in row and 'oid' not in row: row['oid'] = row[qoid] if qoid and 'oid' in row: # try using the oid - keyname = ('oid',) + keynames: tuple[str, ...] = ('oid',) + keyset = set(keynames) else: # try using the primary key try: - keyname = self.pkey(table, True) + keynames = self.pkeys(table) except KeyError as e: # the table has no primary key raise _prg_error(f'Table {table} has no primary key') from e + keyset = set(keynames) # check whether all key columns have values - if not set(keyname).issubset(row): + if not keyset.issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keyname) + col(k), adapt(row[k], attnames[k])) for k in keynames) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] - values = [] - keyname = set(keyname) + values_list = [] for n in attnames: - if n in row and n not in keyname and n not in generated: - values.append(f'{col(n)} = {adapt(row[n], attnames[n])}') - if not values: + if n in row and n not in keyset and n not in generated: + values_list.append(f'{col(n)} = {adapt(row[n], attnames[n])}') + if not values_list: return row - values = ', '.join(values) + values = ', '.join(values_list) ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = (f'UPDATE {t} SET {values}' # noqa: S608 - f' WHERE {where} RETURNING {ret}') - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = (f'UPDATE {t} SET {values}' # noqa: S608 + f' WHERE {where} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if res: # may be empty when row does not exist for n, value in res[0].items(): if qoid and n == 'oid': @@ -2344,7 +2401,8 @@ def update(self, table, row=None, **kw): row[n] = value return row - def upsert(self, table, row=None, **kw): + def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: """Insert a row into a database table with conflict resolution. This method inserts a row into a table, but instead of raising a @@ -2402,22 +2460,22 @@ def upsert(self, table, row=None, **kw): params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - names, values = [], [] + name_list, value_list = [], [] for n in attnames: if n in row and n not in generated: - names.append(col(n)) - values.append(adapt(row[n], attnames[n])) - names, values = ', '.join(names), ', '.join(values) + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + names, values = ', '.join(name_list), ', '.join(value_list) try: - keyname = self.pkey(table, True) + keynames = self.pkeys(table) except KeyError as e: raise _prg_error(f'Table {table} has no primary key') from e - target = ', '.join(col(k) for k in keyname) + target = ', '.join(col(k) for k in keynames) update = [] - keyname = set(keyname) - keyname.add('oid') + keyset = set(keynames) + keyset.add('oid') for n in attnames: - if n not in keyname and n not in generated: + if n not in keyset and n not in generated: value = kw.get(n, n in row) if value: if not isinstance(value, str): @@ -2428,12 +2486,12 @@ def upsert(self, table, row=None, **kw): do = 'update set ' + ', '.join(update) if update else 'nothing' ret = 'oid, *' if qoid else '*' t = self._escape_qualified_name(table) - q = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 - f' VALUES ({values})' - f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() + cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 + f' VALUES ({values})' + f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() if res: # may be empty with "do nothing" for n, value in res[0].items(): if qoid and n == 'oid': @@ -2443,7 +2501,8 @@ def upsert(self, table, row=None, **kw): self.get(table, row) return row - def clear(self, table, row=None): + def clear(self, table: str, row: dict[str, Any] | None = None + ) -> dict[str, Any]: """Clear all the attributes to values determined by the types. Numeric types are set to 0, Booleans are set to false, and everything @@ -2467,7 +2526,8 @@ def clear(self, table, row=None): row[n] = '' return row - def delete(self, table, row=None, **kw): + def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> int: """Delete an existing row in a database table. This method deletes the row from a table. It deletes based on the @@ -2492,31 +2552,33 @@ def delete(self, table, row=None, **kw): if qoid and qoid in row and 'oid' not in row: row['oid'] = row[qoid] if qoid and 'oid' in row: # try using the oid - keyname = ('oid',) + keynames: tuple[str, ...] = ('oid',) else: # try using the primary key try: - keyname = self.pkey(table, True) + keynames = self.pkeys(table) except KeyError as e: # the table has no primary key raise _prg_error(f'Table {table} has no primary key') from e # check whether all key columns have values - if not set(keyname).issubset(row): + if not set(keynames).issubset(row): raise KeyError('Missing value for primary key in row') params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keyname) + col(k), adapt(row[k], attnames[k])) for k in keynames) if 'oid' in row: if qoid: row[qoid] = row['oid'] del row['oid'] t = self._escape_qualified_name(table) - q = f'DELETE FROM {t} WHERE {where}' # noqa: S608 - self._do_debug(q, params) - res = self.db.query(q, params) + cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 + self._do_debug(cmd, params) + res = self._valid_db.query(cmd, params) return int(res) - def truncate(self, table, restart=False, cascade=False, only=False): + def truncate(self, table: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str], restart: bool = False, + cascade: bool = False, only: bool = False) -> Query: """Empty a table or set of tables. This method quickly removes all rows from the given table or set @@ -2528,21 +2590,21 @@ def truncate(self, table, restart=False, cascade=False, only=False): If restart is set to True, sequences owned by columns of the truncated table(s) are automatically restarted. If cascade is set to True, it also truncates all tables that have foreign-key references to any of - the named tables. If the parameter only is not set to True, all the + the named tables. If the parameter 'only' is not set to True, all the descendant tables (if any) will also be truncated. Optionally, a '*' can be specified after the table name to explicitly indicate that descendant tables are included. """ if isinstance(table, str): - only = {table: only} + table_only = {table: only} table = [table] elif isinstance(table, (list, tuple)): if isinstance(only, (list, tuple)): - only = dict(zip(table, only)) + table_only = dict(zip(table, only)) else: - only = dict.fromkeys(table, only) + table_only = dict.fromkeys(table, only) elif isinstance(table, (set, frozenset)): - only = dict.fromkeys(table, only) + table_only = dict.fromkeys(table, only) else: raise TypeError('The table must be a string, list or set') if not (restart is None or isinstance(restart, (bool, int))): @@ -2551,7 +2613,7 @@ def truncate(self, table, restart=False, cascade=False, only=False): raise TypeError('Invalid type for the cascade option') tables = [] for t in table: - u = only.get(t) + u = table_only.get(t) if not (u is None or isinstance(u, (bool, int))): raise TypeError('Invalid type for the only option') if t.endswith('*'): @@ -2563,17 +2625,21 @@ def truncate(self, table, restart=False, cascade=False, only=False): if u: t = f'ONLY {t}' tables.append(t) - q = ['TRUNCATE', ', '.join(tables)] + cmd_parts = ['TRUNCATE', ', '.join(tables)] if restart: - q.append('RESTART IDENTITY') + cmd_parts.append('RESTART IDENTITY') if cascade: - q.append('CASCADE') - q = ' '.join(q) - self._do_debug(q) - return self.db.query(q) - - def get_as_list(self, table, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): + cmd_parts.append('CASCADE') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def get_as_list(self, table: str, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> list: """Get a table as a list. This gets a convenient representation of the table as a list @@ -2585,16 +2651,18 @@ def get_as_list(self, table, what=None, where=None, The parameter 'what' can restrict the query to only return a subset of the table columns. It can be a string, list or a tuple. + The parameter 'where' can restrict the query to only return a subset of the table rows. It can be a string, list or a tuple - of SQL expressions that all need to be fulfilled. The parameter - 'order' specifies the ordering of the rows. It can also be a - other string, list or a tuple. If no ordering is specified, - the result will be ordered by the primary key(s) or all columns - if no primary key exists. You can set 'order' to False if you - don't care about the ordering. The parameters 'limit' and 'offset' - can be integers specifying the maximum number of rows returned - and a number of rows skipped over. + of SQL expressions that all need to be fulfilled. + + The parameter 'order' specifies the ordering of the rows. It can + also be a string, list or a tuple. If no ordering is specified, + the result will be ordered by the primary key(s) or all columns if + no primary key exists. You can set 'order' to False if you don't + care about the ordering. The parameters 'limit' and 'offset' can be + integers specifying the maximum number of rows returned and a number + of rows skipped over. If you set the 'scalar' option to True, then instead of the named tuples you will get the first items of these tuples. @@ -2609,35 +2677,40 @@ def get_as_list(self, table, what=None, where=None, order = what else: what = '*' - q = ['SELECT', what, 'FROM', table] + cmd_parts = ['SELECT', what, 'FROM', table] if where: if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) - q.extend(['WHERE', where]) + cmd_parts.extend(['WHERE', where]) if order is None: try: - order = self.pkey(table, True) + order = self.pkeys(table) except (KeyError, ProgrammingError): with suppress(KeyError, ProgrammingError): order = list(self.get_attnames(table)) if order: if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) - q.extend(['ORDER BY', order]) + cmd_parts.extend(['ORDER BY', order]) if limit: - q.append(f'LIMIT {limit}') + cmd_parts.append(f'LIMIT {limit}') if offset: - q.append(f'OFFSET {offset}') - q = ' '.join(q) - self._do_debug(q) - q = self.db.query(q) - res = q.namedresult() + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.namedresult() if res and scalar: res = [row[0] for row in res] return res - def get_as_dict(self, table, keyname=None, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): + def get_as_dict(self, table: str, + keyname: str | list[str] | tuple[str, ...] | None = None, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> dict: """Get a table as a dictionary. This method is similar to get_as_list(), but returns the table @@ -2652,7 +2725,7 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, be set as a string, list or a tuple. If the Python version supports it, the dictionary will be an - OrderedDict using the order specified with the 'order' parameter + dict using the order specified with the 'order' parameter or the key column(s) if not specified. You can set 'order' to False if you don't care about the ordering. In this case the returned dictionary will be an ordinary one. @@ -2661,12 +2734,14 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, raise TypeError('The table name is missing') if not keyname: try: - keyname = self.pkey(table, True) + keyname = self.pkeys(table) except (KeyError, ProgrammingError) as e: raise _prg_error(f'Table {table} has no primary key') from e if isinstance(keyname, str): - keyname = [keyname] - elif not isinstance(keyname, (list, tuple)): + keynames: list[str] | tuple[str, ...] = (keyname,) + elif isinstance(keyname, (list, tuple)): + keynames = keyname + else: raise KeyError('The keyname must be a string, list or tuple') if what: if isinstance(what, (list, tuple)): @@ -2675,64 +2750,68 @@ def get_as_dict(self, table, keyname=None, what=None, where=None, order = what else: what = '*' - q = ['SELECT', what, 'FROM', table] + cmd_parts = ['SELECT', what, 'FROM', table] if where: if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) - q.extend(['WHERE', where]) + cmd_parts.extend(['WHERE', where]) if order is None: order = keyname if order: if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) - q.extend(['ORDER BY', order]) + cmd_parts.extend(['ORDER BY', order]) if limit: - q.append(f'LIMIT {limit}') + cmd_parts.append(f'LIMIT {limit}') if offset: - q.append(f'OFFSET {offset}') - q = ' '.join(q) - self._do_debug(q) - q = self.db.query(q) - res = q.getresult() - cls = OrderedDict if order else dict + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.getresult() if not res: - return cls() - keyset = set(keyname) - fields = q.listfields() + return {} + keyset = set(keynames) + fields = query.listfields() if not keyset.issubset(fields): raise KeyError('Missing keyname in row') - keyind, rowind = [], [] + key_index: list[int] = [] + row_index: list[int] = [] for i, f in enumerate(fields): - (keyind if f in keyset else rowind).append(i) - keytuple = len(keyind) > 1 - getkey = itemgetter(*keyind) - keys = map(getkey, res) + (key_index if f in keyset else row_index).append(i) + key_tuple = len(key_index) > 1 + get_key = itemgetter(*key_index) + keys = map(get_key, res) if scalar: - rowind = rowind[:1] - rowtuple = False + row_index = row_index[:1] + row_is_tuple = False else: - rowtuple = len(rowind) > 1 - if scalar or rowtuple: - getrow = itemgetter(*rowind) + row_is_tuple = len(row_index) > 1 + if scalar or row_is_tuple: + get_row: Callable[[tuple], tuple] = itemgetter( # pyright: ignore + *row_index) else: - rowind = rowind[0] + frst_index = row_index[0] - def getrow(row): - return row[rowind], # tuple with one item + def get_row(row : tuple) -> tuple: + return row[frst_index], # tuple with one item - rowtuple = True - rows = map(getrow, res) - if keytuple or rowtuple: - if keytuple: - keys = _namediter(_MemoryQuery(keys, keyname)) - if rowtuple: + row_is_tuple = True + rows = map(get_row, res) + if key_tuple or row_is_tuple: + if key_tuple: + keys = _namediter(_MemoryQuery(keys, keynames)) # type: ignore + if row_is_tuple: fields = [f for f in fields if f not in keyset] - rows = _namediter(_MemoryQuery(rows, fields)) + rows = _namediter(_MemoryQuery(rows, fields)) # type: ignore # noinspection PyArgumentList - return cls(zip(keys, rows)) + return dict(zip(keys, rows)) - def notification_handler(self, event, callback, - arg_dict=None, timeout=None, stop_event=None): + def notification_handler(self, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None + ) -> NotificationHandler: """Get notification handler that will run the given callback.""" return NotificationHandler(self, event, callback, arg_dict, timeout, stop_event) diff --git a/pgdb.py b/pgdb.py index 2e48e39d..df23bbfd 100644 --- a/pgdb.py +++ b/pgdb.py @@ -64,6 +64,8 @@ connection.close() # close the connection """ +from __future__ import annotations + from collections import namedtuple from collections.abc import Iterable from contextlib import suppress @@ -76,7 +78,7 @@ from math import isinf, isnan from re import compile as regex from time import localtime -from typing import ClassVar, Dict, Type +from typing import Callable, ClassVar from uuid import UUID as Uuid # noqa: N811 try: @@ -91,15 +93,16 @@ if os.path.exists(os.path.join(path, libpq))] if sys.version_info >= (3, 8): # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore for path in paths: - with os.add_dll_directory(os.path.abspath(path)): + with add_dll_dir(os.path.abspath(path)): try: - from _pg import version + from _pg import version # type: ignore except ImportError: pass else: del version - e = None + e = None # type: ignore break if paths: libpq = 'compatible ' + libpq @@ -140,7 +143,7 @@ 'Date', 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', 'Binary', 'Interval', 'Uuid', - 'Hstore', 'Json', 'Literal', 'Type', + 'Hstore', 'Json', 'Literal', 'DbType', 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', 'SMALLINT', 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', 'DATE', 'TIME', 'TIMESTAMP', 'INTERVAL', @@ -150,9 +153,10 @@ 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', 'apilevel', 'connect', 'paramstyle', 'threadsafety', 'get_typecast', 'set_typecast', 'reset_typecast', - 'version', '__version__'] + 'version', '__version__', +] -Decimal = StdDecimal +Decimal: type = StdDecimal # *** Module Constants *** @@ -173,17 +177,19 @@ # *** Internal Type Handling *** -def get_args(func): +def get_args(func: Callable) -> list: return list(signature(func).parameters) # time zones used in Postgres timestamptz output -_timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') +_timezones: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} -def _timezone_as_offset(tz): +def _timezone_as_offset(tz: str) -> str: if tz.startswith(('+', '-')): if len(tz) < 5: return tz + '00' @@ -191,7 +197,7 @@ def _timezone_as_offset(tz): return _timezones.get(tz, '+0000') -def decimal_type(decimal_type=None): +def decimal_type(decimal_type: type | None = None): """Get or set global type to be used for decimal values. Note that connections cache cast functions. To be sure a global change @@ -204,25 +210,25 @@ def decimal_type(decimal_type=None): return Decimal -def cast_bool(value): +def cast_bool(value: str) -> bool | None: """Cast boolean value in database format to bool.""" if value: return value[0] in ('t', 'T') -def cast_money(value): +def cast_money(value: str) -> Decimal | None: # pyright: ignore """Cast money value in database format to Decimal.""" if value: value = value.replace('(', '-') return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) -def cast_int2vector(value): +def cast_int2vector(value: str) -> list[int]: """Cast an int2vector value.""" return [int(v) for v in value.split()] -def cast_date(value, connection): +def cast_date(value: str, connection) -> date: """Cast a date value.""" # The output format depends on the server setting DateStyle. The default # setting ISO and the setting for German are actually unambiguous. The @@ -232,17 +238,17 @@ def cast_date(value, connection): return date.min if value == 'infinity': return date.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return date.min - value = value[0] + value = values[0] if len(value) > 10: return date.max - fmt = connection.date_format() - return datetime.strptime(value, fmt).date() + format = connection.date_format() + return datetime.strptime(value, format).date() -def cast_time(value): +def cast_time(value: str) -> time: """Cast a time value.""" fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' return datetime.strptime(value, fmt).time() @@ -251,74 +257,74 @@ def cast_time(value): _re_timezone = regex('(.*)([+-].*)') -def cast_timetz(value): +def cast_timetz(value: str) -> time: """Cast a timetz value.""" - tz = _re_timezone.match(value) - if tz: - value, tz = tz.groups() + m = _re_timezone.match(value) + if m: + value, tz = m.groups() else: tz = '+0000' - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() + format += '%z' + return datetime.strptime(value, format).timetz() -def cast_timestamp(value, connection): +def cast_timestamp(value: str, connection) -> datetime: """Cast a timestamp value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:5] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] else: - if len(value[0]) > 10: + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) -def cast_timestamptz(value, connection): +def cast_timestamptz(value: str, connection) -> datetime: """Cast a timestamptz value.""" if value == '-infinity': return datetime.min if value == 'infinity': return datetime.max - value = value.split() - if value[-1] == 'BC': + values = value.split() + if values[-1] == 'BC': return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:] - if len(value[3]) > 4: + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - value, tz = value[:-1], value[-1] + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] else: - if fmt.startswith('%Y-'): - tz = _re_timezone.match(value[1]) - if tz: - value[1], tz = tz.groups() + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() else: tz = '+0000' else: - value, tz = value[:-1], value[-1] - if len(value[0]) > 10: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(_timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) _re_interval_sql_standard = regex( @@ -349,37 +355,37 @@ def cast_timestamptz(value, connection): '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') -def cast_interval(value): +def cast_interval(value: str) -> timedelta: """Cast an interval value.""" # The output format depends on the server setting IntervalStyle, but it's # not necessary to consult this setting to parse it. It's faster to just # check all possible formats, and there is no ambiguity here. m = _re_interval_iso_8601.match(value) if m: - m = [d or '0' for d in m.groups()] - secs_ago = m.pop(5) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = -secs usecs = -usecs else: m = _re_interval_postgres_verbose.match(value) if m: - m, ago = [d or '0' for d in m.groups()[:8]], m.group(9) - secs_ago = m.pop(5) == '-' - m = [-int(d) for d in m] if ago else [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if secs_ago: secs = - secs usecs = -usecs else: m = _re_interval_postgres.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if hours_ago: hours = -hours mins = -mins @@ -388,11 +394,11 @@ def cast_interval(value): else: m = _re_interval_sql_standard.match(value) if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - years_ago = m.pop(0) == '-' - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d if years_ago: years = -years mons = -mons @@ -419,7 +425,7 @@ class Typecasts(dict): # the default cast functions # (str functions are ignored but have been added for faster access) - defaults: ClassVar[Dict[str, Type]] = { + defaults: ClassVar[dict[str, Callable]] = { 'char': str, 'bpchar': str, 'name': str, 'text': str, 'varchar': str, 'sql_identifier': str, 'bool': cast_bool, 'bytea': unescape_bytea, @@ -759,10 +765,6 @@ def _op_error(msg): # *** Row Tuples *** - -_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') - - # The result rows for database operations are returned as named tuples # by default. Since creating namedtuple classes is a somewhat expensive # operation, we cache up to 1024 of these classes by default. diff --git a/pgmodule.c b/pgmodule.c index 628de9ec..64e769f6 100644 --- a/pgmodule.c +++ b/pgmodule.c @@ -21,7 +21,7 @@ static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, - *MultipleResultsError; + *MultipleResultsError, *Connection, *Query; #define _TOSTRING(x) #x #define TOSTRING(x) _TOSTRING(x) @@ -1305,6 +1305,12 @@ PyInit__pg(void) InvalidResultError, NULL); PyDict_SetItemString(dict, "MultipleResultsError", MultipleResultsError); + /* Types */ + Connection = (PyObject *)&connType; + PyDict_SetItemString(dict, "Connection", Connection); + Query = (PyObject *)&queryType; + PyDict_SetItemString(dict, "Query", Query); + /* Make the version available */ s = PyUnicode_FromString(PyPgVersion); PyDict_SetItemString(dict, "version", s); diff --git a/pyproject.toml b/pyproject.toml index 131308b8..1016b433 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,21 @@ exclude = [ [tool.ruff.per-file-ignores] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] +[tool.mypy] +python_version = "3.11" +check_untyped_defs = true +no_implicit_optional = true +strict_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +disallow_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "tests.*" +] +disallow_untyped_defs = false + [tool.setuptools] py-modules = ["pg", "pgdb"] license-files = ["LICENSE.txt"] diff --git a/tests/config.py b/tests/config.py index f6280548..0b15f62e 100644 --- a/tests/config.py +++ b/tests/config.py @@ -26,9 +26,9 @@ dbport = int(dbport) try: - from .LOCAL_PyGreSQL import * # noqa: F403 + from .LOCAL_PyGreSQL import * # type: ignore # noqa except (ImportError, ValueError): - try: # noqa: SIM105 - from LOCAL_PyGreSQL import * # noqa: F403 + try: # noqa + from LOCAL_PyGreSQL import * # type: ignore # noqa except ImportError: pass diff --git a/tests/dbapi20.py b/tests/dbapi20.py index d5f2938f..f72d99f7 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -7,13 +7,14 @@ Some modernization of the code has been done by the PyGreSQL team. """ -__version__ = '1.15.0' +from __future__ import annotations import time import unittest from contextlib import suppress -from typing import Any, Mapping, Tuple +from typing import Any, Mapping +__version__ = '1.15.0' class DatabaseAPI20Test(unittest.TestCase): """Test a database self.driver for DB API 2.0 compatibility. @@ -41,7 +42,7 @@ class mytest(dbapi20.DatabaseAPI20Test): # The self.driver module. This should be the module where the 'connect' # method is to be found driver: Any = None - connect_args: Tuple = () # List of arguments to pass to connect + connect_args: tuple = () # List of arguments to pass to connect connect_kw_args: Mapping[str, Any] = {} # Keyword arguments for connect table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 7d4409df..242fdbb5 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -9,6 +9,8 @@ These tests need a database to test against. """ +from __future__ import annotations + import os import threading import time @@ -17,7 +19,7 @@ from collections.abc import Iterable from contextlib import suppress from decimal import Decimal -from typing import Sequence, Tuple +from typing import Sequence import pg # the module under test @@ -532,9 +534,9 @@ def test_namedresult_with_good_fieldnames(self): self.assertEqual(v._fields, ('snake_case_alias', 'CamelCaseAlias')) def test_namedresult_with_bad_fieldnames(self): - r = namedtuple('Bad', ['?'] * 6, rename=True) + t = namedtuple('Bad', ['?'] * 6, rename=True) # type: ignore # noinspection PyUnresolvedReferences - fields = r._fields + fields = t._fields q = ('select 3 as "0alias", 4 as _alias, 5 as "alias$", 6 as "alias?",' ' 7 as "kebap-case-alias", 8 as break, 9 as and_a_good_one') result = [tuple(range(3, 10))] @@ -820,45 +822,44 @@ def tearDown(self): def test_getresul_ascii(self): result = 'Hello, world!' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) def test_dictresul_ascii(self): result = 'Hello, world!' - q = f"select '{result}' as greeting" - v = self.c.query(q).dictresult()[0]['greeting'] + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) def test_getresult_utf8(self): result = 'Hello, wörld & мир!' - q = f"select '{result}'" + cmd = f"select '{result}'" # pass the query as unicode try: - v = self.c.query(q).getresult()[0][0] + v = self.c.query(cmd).getresult()[0][0] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode() - # pass the query as bytes - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode() + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) def test_dictresult_utf8(self): result = 'Hello, wörld & мир!' - q = f"select '{result}' as greeting" + cmd = f"select '{result}' as greeting" try: - v = self.c.query(q).dictresult()[0]['greeting'] + v = self.c.query(cmd).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode() - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode() + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -868,12 +869,12 @@ def test_getresult_latin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") result = 'Hello, wörld!' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin1') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('latin1') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -883,12 +884,12 @@ def test_dictresult_latin1(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") result = 'Hello, wörld!' - q = f"select '{result}' as greeting" - v = self.c.query(q).dictresult()[0]['greeting'] + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin1') - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode('latin1') + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -898,12 +899,12 @@ def test_getresult_cyrillic(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") result = 'Hello, мир!' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('cyrillic') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('cyrillic') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -913,12 +914,12 @@ def test_dictresult_cyrillic(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") result = 'Hello, мир!' - q = f"select '{result}' as greeting" - v = self.c.query(q).dictresult()[0]['greeting'] + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('cyrillic') - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode('cyrillic') + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -928,12 +929,12 @@ def test_getresult_latin9(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = f"select '{result}'" - v = self.c.query(q).getresult()[0][0] + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin9') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('latin9') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -943,12 +944,12 @@ def test_dictresult_latin9(self): except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = f"select '{result}' as menu" - v = self.c.query(q).dictresult()[0]['menu'] + cmd = f"select '{result}' as menu" + v = self.c.query(cmd).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin9') - v = self.c.query(q).dictresult()[0]['menu'] + cmd_bytes = cmd.encode('latin9') + v = self.c.query(cmd_bytes).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -1698,6 +1699,7 @@ class TestInserttable(unittest.TestCase): """Test inserttable method.""" cls_set_up = False + has_encoding = False @classmethod def setUpClass(cls): @@ -1738,7 +1740,7 @@ def tearDown(self): self.c.query("truncate table test") self.c.close() - data: Sequence[Tuple] = [ + data: Sequence[tuple] = [ (-1, -1, -1, True, '1492-10-12', '08:30:00', -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), (0, 0, 0, False, '1607-04-14', '09:00:00', @@ -1868,7 +1870,7 @@ def test_inserttable_from_list_of_sets(self): def test_inserttable_multiple_rows(self): num_rows = 100 - data = self.data[2:3] * num_rows + data = list(self.data[2:3]) * num_rows self.c.inserttable('test', data) r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) @@ -1892,13 +1894,13 @@ def test_inserttable_no_column(self): self.assertEqual(self.get_back(), []) def test_inserttable_only_one_column(self): - data = [(42,)] * 50 + data: list[tuple] = [(42,)] * 50 self.c.inserttable('test', data, ['i4']) data = [tuple([42 if i == 1 else None for i in range(14)])] * 50 self.assertEqual(self.get_back(), data) def test_inserttable_only_two_columns(self): - data = [(bool(i % 2), i * .5) for i in range(20)] + data: list[tuple] = [(bool(i % 2), i * .5) for i in range(20)] self.c.inserttable('test', data, ('b', 'f4')) # noinspection PyTypeChecker data = [(None,) * 3 + (bool(i % 2),) + (None,) * 3 + (i * .5,) @@ -2021,7 +2023,7 @@ def test_inserttable_unicode_latin1(self): self.skipTest("database does not support latin1") # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' - row_unicode = ( + row_unicode: tuple = ( 0, 0, 0, False, '1970-01-01', '00:00:00', 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 3884436f..31aec400 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -9,19 +9,21 @@ These tests need a database to test against. """ +from __future__ import annotations + import gc import json import os import sys import tempfile import unittest -from collections import OrderedDict from contextlib import suppress from datetime import date, datetime, time, timedelta from decimal import Decimal from io import StringIO from operator import itemgetter from time import strftime +from typing import Any, ClassVar from uuid import UUID import pg # the module under test @@ -51,20 +53,19 @@ class TestAttrDict(unittest.TestCase): """Test the simple ordered dictionary for attribute names.""" cls = pg.AttrDict - base = OrderedDict def test_init(self): a = self.cls() - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base()) + self.assertIsInstance(a, dict) + self.assertEqual(a, {}) items = [('id', 'int'), ('name', 'text')] a = self.cls(items) - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base(items)) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) iteritems = iter(items) a = self.cls(iteritems) - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base(items)) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) def test_iter(self): a = self.cls() @@ -127,7 +128,7 @@ def test_write_methods(self): self.assertEqual(a['id'], 1) for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': method = getattr(a, method) - self.assertRaises(TypeError, method, a) + self.assertRaises(TypeError, method, a) # type: ignore class TestDBClassInit(unittest.TestCase): @@ -193,7 +194,7 @@ def test_all_db_attributes(self): 'locreate', 'loimport', 'notification_handler', 'options', - 'parameter', 'pkey', 'poll', 'port', + 'parameter', 'pkey', 'pkeys', 'poll', 'port', 'prepare', 'protocol_version', 'putline', 'query', 'query_formatted', 'query_prepared', 'release', 'reopen', 'reset', 'rollback', @@ -416,11 +417,12 @@ class TestDBClass(unittest.TestCase): cls_set_up = False regtypes = None + supports_oids = False @classmethod def setUpClass(cls): db = DB() - cls.oids = db.server_version < 120000 + cls.supports_oids = db.server_version < 120000 db.query("drop table if exists test cascade") db.query("create table test (" "i2 smallint, i4 integer, i8 bigint," @@ -469,21 +471,21 @@ def create_table(self, table, definition, if not as_query and not definition.startswith('('): definition = f'({definition})' with_oids = 'with oids' if oids else ( - 'without oids' if self.oids else '') - q = ['create', temporary, table] + 'without oids' if self.supports_oids else '') + cmd_parts = ['create', temporary, table] if as_query: - q.extend([with_oids, definition]) + cmd_parts.extend([with_oids, definition]) else: - q.extend([definition, with_oids]) - q = ' '.join(q) - query(q) + cmd_parts.extend([definition, with_oids]) + cmd = ' '.join(cmd_parts) + query(cmd) if values: for params in values: if not isinstance(params, (list, tuple)): params = [params] values = ', '.join(f'${n + 1}' for n in range(len(params))) - q = f"insert into {table} values ({values})" - query(q, params) + cmd = f"insert into {table} values ({values})" + query(cmd, params) def test_class_name(self): self.assertEqual(self.db.__class__.__name__, 'DB') @@ -494,7 +496,7 @@ def test_module_name(self): def test_escape_literal(self): f = self.db.escape_literal - r = f(b"plain") + r: Any = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"'plain'") r = f("plain") @@ -846,7 +848,7 @@ def test_create_table(self): self.assertEqual(r, "Hello, World!") def test_create_table_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") table = 'test hello world' values = [(2, "World!"), (1, "Hello")] @@ -893,7 +895,7 @@ def test_query(self): self.assertEqual(r, '5') def test_query_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") query = self.db.query table = 'test_table' @@ -1175,6 +1177,16 @@ def test_pkey(self): # we get the changed primary key when the cache is flushed self.assertEqual(pkey(f'{t}1', flush=True), 'x') + def test_pkeys(self): + pkeys = self.db.pkeys + t = 'pkeys_test_' + self.create_table(f'{t}0', 'a int') + self.create_table(f'{t}1', 'a int primary key, b int') + self.create_table(f'{t}2', 'a int, b int, c int, primary key (a, c)') + self.assertRaises(KeyError, pkeys, f'{t}0') + self.assertEqual(pkeys(f'{t}1'), ('a',)) + self.assertEqual(pkeys(f'{t}2'), ('a', 'c')) + def test_get_databases(self): databases = self.db.get_databases() self.assertIn('template0', databases) @@ -1194,11 +1206,11 @@ def test_get_tables(self): before_tables = get_tables() self.assertIsInstance(before_tables, list) for t in before_tables: - t = t.split('.', 1) - self.assertGreaterEqual(len(t), 2) - if len(t) > 2: - self.assertTrue(t[1].startswith('"')) - t = t[0] + s = t.split('.', 1) + self.assertGreaterEqual(len(s), 2) + if len(s) > 2: + self.assertTrue(s[1].startswith('"')) + t = s[0] self.assertNotEqual(t, 'information_schema') self.assertFalse(t.startswith('pg_')) for t in tables: @@ -1392,41 +1404,37 @@ def test_get_attnames_is_cached(self): def test_get_attnames_is_ordered(self): get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) if self.regtypes: - self.assertEqual(r, OrderedDict([ - ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'), - ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'), - ('m', 'money'), ('v4', 'character varying'), - ('c4', 'character'), ('t', 'text')])) + self.assertEqual(r, { + 'i2': 'smallint', 'i4': 'integer', 'i8': 'bigint', + 'd': 'numeric', 'f4': 'real', 'f8': 'double precision', + 'm': 'money', 'v4': 'character varying', + 'c4': 'character', 't': 'text'}) else: - self.assertEqual(r, OrderedDict([ - ('i2', 'int'), ('i4', 'int'), ('i8', 'int'), - ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'), - ('v4', 'text'), ('c4', 'text'), ('t', 'text')])) - if OrderedDict is not dict: - r = ' '.join(list(r.keys())) - self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') + self.assertEqual(r, { + 'i2': 'int', 'i4': 'int', 'i8': 'int', + 'd': 'num', 'f4': 'float', 'f8': 'float', 'm': 'money', + 'v4': 'text', 'c4': 'text', 't': 'text'}) + r = ' '.join(list(r.keys())) + self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' self.create_table( table, 'n int, alpha smallint, v varchar(3),' ' gamma char(5), tau text, beta bool') r = get_attnames(table) - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) if self.regtypes: - self.assertEqual(r, OrderedDict([ - ('n', 'integer'), ('alpha', 'smallint'), - ('v', 'character varying'), ('gamma', 'character'), - ('tau', 'text'), ('beta', 'boolean')])) - else: - self.assertEqual(r, OrderedDict([ - ('n', 'int'), ('alpha', 'int'), ('v', 'text'), - ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')])) - if OrderedDict is not dict: - r = ' '.join(list(r.keys())) - self.assertEqual(r, 'n alpha v gamma tau beta') + self.assertEqual(r, { + 'n': 'integer', 'alpha': 'smallint', + 'v': 'character varying', 'gamma': 'character', + 'tau': 'text', 'beta': 'boolean'}) else: - self.skipTest('OrderedDict is not supported') + self.assertEqual(r, { + 'n': 'int', 'alpha': 'int', 'v': 'text', + 'gamma': 'text', 'tau': 'text', 'beta': 'bool'}) + r = ' '.join(list(r.keys())) + self.assertEqual(r, 'n alpha v gamma tau beta') def test_get_attnames_is_attr_dict(self): AttrDict = pg.AttrDict # noqa: N806 @@ -1541,7 +1549,7 @@ def test_get(self): self.create_table(table, 'n integer, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) - r = get(table, 2, 'n') + r: Any = get(table, 2, 'n') self.assertIsInstance(r, dict) self.assertEqual(r, dict(n=2, t='y')) r = get(table, 1, 'n') @@ -1554,7 +1562,7 @@ def test_get(self): self.assertRaises(pg.DatabaseError, get, table, 4, 'n') self.assertRaises(pg.DatabaseError, get, table, 'y') self.assertRaises(pg.DatabaseError, get, table, 2, 't') - s = dict(n=3) + s: dict = dict(n=3) self.assertRaises(pg.ProgrammingError, get, table, s) r = get(table, s, 'n') self.assertIs(r, s) @@ -1588,7 +1596,7 @@ def test_get(self): self.assertRaises(KeyError, get, table, s) def test_get_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") get = self.db.get query = self.db.query @@ -1753,7 +1761,7 @@ def test_insert(self): ' d numeric, f4 real, f8 double precision, m money,' ' v4 varchar(4), c4 char(4), t text,' ' b boolean, ts timestamp') - tests = [ + tests: list[dict | tuple[dict, dict]] = [ dict(i2=None, i4=None, i8=None), (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), @@ -1798,8 +1806,8 @@ def test_insert(self): dict(ts='current_timestamp')] for test in tests: if isinstance(test, dict): - data = test - change = {} + data: dict = test + change: dict = {} else: data, change = test expect = data.copy() @@ -1835,7 +1843,7 @@ def test_insert(self): query(f'truncate table "{table}"') def test_insert_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") insert = self.db.insert query = self.db.query @@ -1910,7 +1918,7 @@ def test_insert_with_quoted_names(self): table = 'test table for insert()' self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') - r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} + r: Any = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} r = insert(table, r) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 11) @@ -1928,7 +1936,7 @@ def test_insert_into_view(self): query = self.db.query query("truncate table test") q = 'select * from test_view order by i4 limit 3' - r = query(q).getresult() + r: Any = query(q).getresult() self.assertEqual(r, []) r = dict(i4=1234, v4='abcd') insert('test', r) @@ -1993,7 +2001,7 @@ def test_update(self): self.assertEqual(r, 'u') def test_update_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") update = self.db.update get = self.db.get @@ -2133,7 +2141,7 @@ def test_update_with_quoted_names(self): self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(13, 3003, 'Why!')]) - r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} + r: Any = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} r = update(table, r) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 13) @@ -2166,7 +2174,7 @@ def test_update_with_generated_columns(self): self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 - r = query(f'insert into {table} (i, d) values ({i}, {d})') + r: Any = query(f'insert into {table} (i, d) values ({i}, {d})') self.assertEqual(r, '1') r = get(table, d) self.assertIsInstance(r, dict) @@ -2185,8 +2193,8 @@ def test_upsert(self): 'test', i2=2, i4=4, i8=8) table = 'upsert_test_table' self.create_table(table, 'n integer primary key, t text') - s = dict(n=1, t='x') - r = upsert(table, s) + s: dict = dict(n=1, t='x') + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['t'], 'x') @@ -2252,7 +2260,7 @@ def test_upsert(self): self.assertEqual(r, [(1, 'x2'), (2, 'y3')]) def test_upsert_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") upsert = self.db.upsert get = self.db.get @@ -2260,7 +2268,7 @@ def test_upsert_with_oids(self): self.create_table('test_table', 'n int', oids=True, values=[1]) self.assertRaises(pg.ProgrammingError, upsert, 'test_table', dict(n=2)) - r = get('test_table', 1, 'n') + r: Any = get('test_table', 1, 'n') self.assertIsInstance(r, dict) self.assertEqual(r['n'], 1) qoid = 'oid(test_table)' @@ -2338,8 +2346,8 @@ def test_upsert_with_composite_key(self): table = 'upsert_test_table_2' self.create_table( table, 'n integer, m integer, t text, primary key (n, m)') - s = dict(n=1, m=2, t='x') - r = upsert(table, s) + s: dict = dict(n=1, m=2, t='x') + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 2) @@ -2400,8 +2408,8 @@ def test_upsert_with_quoted_names(self): table = 'test table for upsert()' self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') - s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} - r = upsert(table, s) + s: dict = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['Prime!'], 31) self.assertEqual(r['much space'], 9009) @@ -2437,7 +2445,7 @@ def test_upsert_with_generated_columns(self): self.create_table(table, table_def) i, d = 35, 1001 j = i + 7 - r = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + r: Any = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) self.assertIsInstance(r, dict) self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) r['i'] += 1 @@ -2452,7 +2460,7 @@ def test_upsert_with_generated_columns(self): def test_clear(self): clear = self.db.clear f = False if pg.get_bool() else 'f' - r = clear('test') + r: Any = clear('test') result = dict( i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='') self.assertEqual(r, result) @@ -2491,8 +2499,8 @@ def test_delete(self): self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) - r = self.db.get(table, 1) - s = delete(table, r) + r: Any = self.db.get(table, 1) + s: Any = delete(table, r) self.assertEqual(s, 1) r = self.db.get(table, 3) s = delete(table, r) @@ -2516,15 +2524,15 @@ def test_delete(self): self.assertEqual(s, 0) def test_delete_with_oids(self): - if not self.oids: + if not self.supports_oids: self.skipTest("database does not support tables with oids") delete = self.db.delete get = self.db.get query = self.db.query self.create_table('test_table', 'n int', oids=True, values=range(1, 7)) - r = dict(n=3) + r: Any = dict(n=3) self.assertRaises(pg.ProgrammingError, delete, 'test_table', r) - s = get('test_table', 1, 'n') + s: Any = get('test_table', 1, 'n') qoid = 'oid(test_table)' self.assertIn(qoid, s) r = delete('test_table', s) @@ -2618,7 +2626,7 @@ def test_delete_with_composite_key(self): values=enumerate('abc', start=1)) self.assertRaises(KeyError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) - r = query(f'select t from "{table}" where n=2').getresult() + r: Any = query(f'select t from "{table}" where n=2').getresult() self.assertEqual(r, []) self.assertEqual(self.db.delete(table, dict(n=2)), 0) r = query(f'select t from "{table}" where n=3').getresult()[0][0] @@ -2650,7 +2658,7 @@ def test_delete_with_quoted_names(self): table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(19, 5005, 'Yes!')]) - r = {'Prime!': 17} + r: Any = {'Prime!': 17} r = delete(table, r) self.assertEqual(r, 0) r = query(f'select count(*) from "{table}"').getresult() @@ -2676,7 +2684,7 @@ def test_delete_referenced(self): delete, 'test_parent', None, n=2) self.assertRaises(pg.IntegrityError, delete, 'test_parent *', None, n=2) - r = delete('test_child', None, n=2) + r: Any = delete('test_child', None, n=2) self.assertEqual(r, 1) self.assertEqual(query(q).getresult()[0], (3, 2)) r = delete('test_parent', None, n=2) @@ -2706,7 +2714,7 @@ def test_temp_crud(self): self.db.insert(table, dict(n=1, t='one')) self.db.insert(table, dict(n=2, t='too')) self.db.insert(table, dict(n=3, t='three')) - r = self.db.get(table, 2) + r: Any = self.db.get(table, 2) self.assertEqual(r['t'], 'too') self.db.update(table, dict(n=2, t='two')) r = self.db.get(table, 2) @@ -2724,7 +2732,7 @@ def test_truncate(self): self.create_table('test_table', 'n smallint', temporary=False, values=[1] * 3) q = "select count(*) from test_table" - r = query(q).getresult()[0][0] + r: Any = query(q).getresult()[0][0] self.assertEqual(r, 3) truncate('test_table') r = query(q).getresult()[0][0] @@ -2757,7 +2765,7 @@ def test_truncate_restart(self): for _n in range(3): query("insert into test_table (t) values ('test')") q = "select count(n), min(n), max(n) from test_table" - r = query(q).getresult()[0] + r: Any = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) truncate('test_table') r = query(q).getresult()[0] @@ -2785,7 +2793,7 @@ def test_truncate_cascade(self): values=range(3)) q = ("select (select count(*) from test_parent)," " (select count(*) from test_child)") - r = query(q).getresult()[0] + r: Any = query(q).getresult()[0] self.assertEqual(r, (3, 3)) self.assertRaises(pg.NotSupportedError, truncate, 'test_parent') truncate(['test_parent', 'test_child']) @@ -2899,7 +2907,7 @@ def test_get_as_list(self): self.assertRaises(TypeError, get_as_list, None) query = self.db.query table = 'test_aslist' - r = query('select 1 as colname').namedresult()[0] + r: Any = query('select 1 as colname').namedresult()[0] self.assertIsInstance(r, tuple) named = hasattr(r, 'colname') names = [(1, 'Homer'), (2, 'Marge'), @@ -2918,7 +2926,7 @@ def test_get_as_list(self): self.assertEqual(t._asdict(), dict(id=n[0], name=n[1])) r = get_as_list(table, what='name') self.assertIsInstance(r, list) - expected = sorted((row[1],) for row in names) + expected: Any = sorted((row[1],) for row in names) self.assertEqual(r, expected) r = get_as_list(table, what='name, id') self.assertIsInstance(r, list) @@ -3029,8 +3037,8 @@ def test_get_as_dict(self): self.assertRaises(KeyError, get_as_dict, table, keyname='rgb', what='name') r = get_as_dict(table) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], row[1:]) for row in colors) + self.assertIsInstance(r, dict) + expected: Any = {row[0]: row[1:] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, int) @@ -3045,9 +3053,9 @@ def test_get_as_dict(self): self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1])) self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname='rgb') - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[1], (row[0], row[2])) - for row in sorted(colors, key=itemgetter(1))) + self.assertIsInstance(r, dict) + expected = {row[1]: (row[0], row[2]) + for row in sorted(colors, key=itemgetter(1))} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, str) @@ -3063,8 +3071,8 @@ def test_get_as_dict(self): self.assertEqual(row._asdict(), dict(id=t[0], name=t[1])) self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb']) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[:2], row[2:]) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[:2]: row[2:] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, tuple) @@ -3084,8 +3092,8 @@ def test_get_as_dict(self): self.assertEqual(row._asdict(), dict(name=t[0])) self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[:2], row[2]) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[:2]: row[2] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, tuple) @@ -3097,9 +3105,9 @@ def test_get_as_dict(self): self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict( - (row[1], row[2]) for row in sorted(colors, key=itemgetter(1))) + self.assertIsInstance(r, dict) + expected = {row[1]: row[2] + for row in sorted(colors, key=itemgetter(1))} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, str) @@ -3111,8 +3119,8 @@ def test_get_as_dict(self): self.assertEqual(r.keys(), expected.keys()) r = get_as_dict( table, what='id, name', where="rgb like '#b%'", scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], row[2]) for row in colors[1:3]) + self.assertIsInstance(r, dict) + expected = {row[0]: row[2] for row in colors[1:3]} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, int) @@ -3140,31 +3148,31 @@ def test_get_as_dict(self): self.assertEqual(len(r), 1) self.assertEqual(r[4][1], 'Desert') r = get_as_dict(table, order='id desc') - expected = OrderedDict((row[0], row[1:]) for row in reversed(colors)) + expected = {row[0]: row[1:] for row in reversed(colors)} self.assertEqual(r, expected) r = get_as_dict(table, where='id > 5') - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) self.assertEqual(len(r), 0) # test with unordered query expected = {row[0]: row[1:] for row in colors} r = get_as_dict(table, order=False) self.assertIsInstance(r, dict) self.assertEqual(r, expected) - self.assertNotIsInstance(self, OrderedDict) + self.assertNotIsInstance(self, dict) # test with arbitrary from clause from_table = f'(select id, lower(name) as n2 from "{table}") as t2' # primary key must be passed explicitly in this case self.assertRaises(pg.ProgrammingError, get_as_dict, from_table) r = get_as_dict(from_table, 'id') - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[0]: (row[2].lower(),) for row in colors} self.assertEqual(r, expected) # test without a primary key query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) self.assertRaises(pg.ProgrammingError, get_as_dict, table) r = get_as_dict(table, keyname='id') - expected = OrderedDict((row[0], row[1:]) for row in colors) + expected = {row[0]: row[1:] for row in colors} self.assertIsInstance(r, dict) self.assertEqual(r, expected) r = (1, '#007fff', 'Azure') @@ -3783,14 +3791,17 @@ def test_insert_update_get_record(self): name='text', age='int', married='bool', weight='float', salary='money')) decimal = pg.get_decimal() + bool_class: type + t: bool | str + f: bool | str if pg.get_bool(): bool_class = bool t, f = True, False else: bool_class = str t, f = 't', 'f' - person = ('John Doe', 61, t, 99.5, decimal('93456.75')) - r = self.db.insert('test_person', None, person=person) + person: tuple = ('John Doe', 61, t, 99.5, decimal('93456.75')) + r: Any = self.db.insert('test_person', None, person=person) self.assertEqual(r['id'], 1) p = r['person'] self.assertIsInstance(p, tuple) @@ -4301,9 +4312,11 @@ def test_inserttable_from_query(self): class TestDBClassNonStdOpts(TestDBClass): """Test the methods of the DB class with non-standard global options.""" + saved_options: ClassVar[dict[str, Any]] = {} + @classmethod def setUpClass(cls): - cls.saved_options = {} + cls.saved_options.clear() cls.set_option('decimal', float) not_bool = not pg.get_bool() cls.set_option('bool', not_bool) @@ -4375,8 +4388,8 @@ def test_adapt_query_typed_list(self): self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) self.assertRaises( TypeError, format_query, '%s,%s', (1,), ('int2', 'int2')) - values = (3, 7.5, 'hello', True) - types = ('int4', 'float4', 'text', 'bool') + values: list | tuple = (3, 7.5, 'hello', True) + types: list | tuple = ('int4', 'float4', 'text', 'bool') sql, params = format_query("select %s,%s,%s,%s", values, types) self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) @@ -4434,7 +4447,7 @@ def test_adapt_query_typed_list_with_types_as_classes(self): def test_adapt_query_typed_list_with_json(self): format_query = self.adapter.format_query - value = {'test': [1, "it's fine", 3]} + value: Any = {'test': [1, "it's fine", 3]} sql, params = format_query("select %s", (value,), 'json') self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) @@ -4449,7 +4462,7 @@ def test_adapt_query_typed_list_with_json(self): def test_adapt_query_typed_with_hstore(self): format_query = self.adapter.format_query - value = {'one': "it's fine", 'two': 2} + value: Any = {'one': "it's fine", 'two': 2} sql, params = format_query("select %s", (value,), 'hstore') self.assertEqual(sql, "select $1") self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) @@ -4464,7 +4477,7 @@ def test_adapt_query_typed_with_hstore(self): def test_adapt_query_typed_with_uuid(self): format_query = self.adapter.format_query - value = '12345678-1234-5678-1234-567812345678' + value: Any = '12345678-1234-5678-1234-567812345678' sql, params = format_query("select %s", (value,), 'uuid') self.assertEqual(sql, "select $1") self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) @@ -4482,8 +4495,8 @@ def test_adapt_query_typed_dict(self): self.assertRaises( TypeError, format_query, '%s,%s', dict(i1=1, i2=2), dict(i1='int2')) - values = dict(i=3, f=7.5, t='hello', b=True) - types = dict(i='int4', f='float4', t='text', b='bool') + values: dict = dict(i=3, f=7.5, t='hello', b=True) + types: dict = dict(i='int4', f='float4', t='text', b='bool') sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values, types) self.assertEqual(sql, 'select $3,$2,$4,$1') @@ -4523,7 +4536,7 @@ def test_adapt_query_typed_dict(self): def test_adapt_query_untyped_list(self): format_query = self.adapter.format_query - values = (3, 7.5, 'hello', True) + values: list | tuple = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values) self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) @@ -4562,7 +4575,7 @@ def test_adapt_query_untyped_with_hstore(self): def test_adapt_query_untyped_dict(self): format_query = self.adapter.format_query - values = dict(i=3, f=7.5, t='hello', b=True) + values: dict = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values) self.assertEqual(sql, 'select $3,$2,$4,$1') @@ -4589,7 +4602,7 @@ def test_adapt_query_untyped_dict(self): def test_adapt_query_inline_list(self): format_query = self.adapter.format_query - values = (3, 7.5, 'hello', True) + values: list | tuple = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values, inline=True) self.assertEqual(sql, "select 3,7.5,'hello',true") self.assertEqual(params, []) @@ -4633,7 +4646,7 @@ def test_adapt_query_inline_list_with_hstore(self): def test_adapt_query_inline_dict(self): format_query = self.adapter.format_query - values = dict(i=3, f=7.5, t='hello', b=True) + values: dict = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True) self.assertEqual(sql, "select 3,7.5,'hello',true") @@ -4683,6 +4696,7 @@ class TestSchemas(unittest.TestCase): """Test correct handling of schemas (namespaces).""" cls_set_up = False + with_oids = "" @classmethod def setUpClass(cls): @@ -4823,11 +4837,11 @@ def test_query_information_schema(self): class TestDebug(unittest.TestCase): """Test the debug attribute of the DB class.""" - + def setUp(self): self.db = DB() self.query = self.db.query - self.debug = self.db.debug + self.debug = self.db.debug # type: ignore self.output = StringIO() self.stdout, sys.stdout = sys.stdout, self.output @@ -4877,7 +4891,7 @@ def test_debug_is_file_like(self): self.assertEqual(self.get_output(), "") def test_debug_is_callable(self): - output = [] + output: list[str] = [] self.db.debug = output.append self.db.query("select 1") self.db.query("select 2") @@ -4885,7 +4899,7 @@ def test_debug_is_callable(self): self.assertEqual(self.get_output(), "") def test_debug_multiple_args(self): - output = [] + output: list[str] = [] self.db.debug = output.append args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]] self.db._do_debug(*args) @@ -4897,8 +4911,8 @@ class TestMemoryLeaks(unittest.TestCase): """Test that the DB class does not leak memory.""" def get_leaks(self, fut): - ids = set() - objs = [] + ids: set = set() + objs: list = [] add_ids = ids.update gc.collect() objs[:] = gc.get_objects() diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 37606b13..33c2f6f9 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -9,11 +9,13 @@ These tests do not need a database to test against. """ +from __future__ import annotations + import json import re import unittest from datetime import timedelta -from typing import Any, Sequence, Tuple, Type +from typing import Any, Sequence import pg # the module under test @@ -21,56 +23,64 @@ class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" - def testhas_pg_error(self): + def test_has_pg_error(self): self.assertTrue(issubclass(pg.Error, Exception)) - def testhas_pg_warning(self): + def test_has_pg_warning(self): self.assertTrue(issubclass(pg.Warning, Exception)) - def testhas_pg_interface_error(self): + def test_has_pg_interface_error(self): self.assertTrue(issubclass(pg.InterfaceError, pg.Error)) - def testhas_pg_database_error(self): + def test_has_pg_database_error(self): self.assertTrue(issubclass(pg.DatabaseError, pg.Error)) - def testhas_pg_internal_error(self): + def test_has_pg_internal_error(self): self.assertTrue(issubclass(pg.InternalError, pg.DatabaseError)) - def testhas_pg_operational_error(self): + def test_has_pg_operational_error(self): self.assertTrue(issubclass(pg.OperationalError, pg.DatabaseError)) - def testhas_pg_programming_error(self): + def test_has_pg_programming_error(self): self.assertTrue(issubclass(pg.ProgrammingError, pg.DatabaseError)) - def testhas_pg_integrity_error(self): + def test_has_pg_integrity_error(self): self.assertTrue(issubclass(pg.IntegrityError, pg.DatabaseError)) - def testhas_pg_data_error(self): + def test_has_pg_data_error(self): self.assertTrue(issubclass(pg.DataError, pg.DatabaseError)) - def testhas_pg_not_supported_error(self): + def test_has_pg_not_supported_error(self): self.assertTrue(issubclass(pg.NotSupportedError, pg.DatabaseError)) - def testhas_pg_invalid_result_error(self): + def test_has_pg_invalid_result_error(self): self.assertTrue(issubclass(pg.InvalidResultError, pg.DataError)) - def testhas_pg_no_result_error(self): + def test_has_pg_no_result_error(self): self.assertTrue(issubclass(pg.NoResultError, pg.InvalidResultError)) - def testhas_pg_multiple_results_error(self): + def test_has_pg_multiple_results_error(self): self.assertTrue( issubclass(pg.MultipleResultsError, pg.InvalidResultError)) - def testhas_connect(self): + def test_has_connection_type(self): + self.assertIsInstance(pg.Connection, type) + self.assertEqual(pg.Connection.__name__, 'Connection') + + def test_has_query_type(self): + self.assertIsInstance(pg.Query, type) + self.assertEqual(pg.Query.__name__, 'Query') + + def test_has_connect(self): self.assertTrue(callable(pg.connect)) - def testhas_escape_string(self): + def test_has_escape_string(self): self.assertTrue(callable(pg.escape_string)) - def testhas_escape_bytea(self): + def test_has_escape_bytea(self): self.assertTrue(callable(pg.escape_bytea)) - def testhas_unescape_bytea(self): + def test_has_unescape_bytea(self): self.assertTrue(callable(pg.unescape_bytea)) def test_def_host(self): @@ -120,7 +130,7 @@ def test_pqlib_version(self): class TestParseArray(unittest.TestCase): """Test the array parser.""" - test_strings: Sequence[Tuple[str, Type, Any]] = [ + test_strings: Sequence[tuple[str, type | None, Any]] = [ ('', str, ValueError), ('{}', None, []), ('{}', str, []), @@ -354,7 +364,7 @@ def replace_comma(value): class TestParseRecord(unittest.TestCase): """Test the record parser.""" - test_strings: Sequence[Tuple[str, Type, Any]] = [ + test_strings: Sequence[tuple[str, type | tuple[type, ...] | None, Any]] = [ ('', None, ValueError), ('', str, ValueError), ('(', None, ValueError), @@ -635,7 +645,7 @@ def replace_comma(value): class TestParseHStore(unittest.TestCase): """Test the hstore parser.""" - test_strings: Sequence[Tuple[str, Any]] = [ + test_strings: Sequence[tuple[str, Any]] = [ ('', {}), ('=>', ValueError), ('""=>', ValueError), @@ -684,7 +694,7 @@ def test_parser(self): class TestCastInterval(unittest.TestCase): """Test the interval typecast function.""" - intervals: Sequence[Tuple[Tuple[int, ...], Tuple[str, ...]]] = [ + intervals: Sequence[tuple[tuple[int, ...], tuple[str, ...]]] = [ ((0, 0, 0, 1, 0, 0, 0), ('1:00:00', '01:00:00', '@ 1 hour', 'PT1H')), ((0, 0, 0, -1, 0, 0, 0), diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 7e5ad4a2..4fb8773c 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -13,6 +13,7 @@ import tempfile import unittest from contextlib import suppress +from typing import Any import pg # the module under test @@ -105,6 +106,7 @@ def test_get_lo(self): self.assertEqual(r, data) def test_lo_import(self): + f : Any if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_import.txt' @@ -412,6 +414,7 @@ def test_export(self): self.assertRaises(TypeError, export) self.assertRaises(TypeError, export, 0) self.assertRaises(TypeError, export, 'invalid', 0) + f: Any if windows: # NamedTemporaryFiles don't work well here fname = 'temp_test_pg_largeobj_export.txt' diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index bcacd476..09211718 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -9,10 +9,12 @@ These tests need a database to test against. """ +from __future__ import annotations # + import unittest from collections.abc import Iterable from contextlib import suppress -from typing import Sequence, Tuple +from typing import Sequence import pgdb # the module under test @@ -150,7 +152,7 @@ def tearDown(self): with suppress(Exception): self.con.close() - data: Sequence[Tuple[int, str]] = [ + data: Sequence[tuple[int, str]] = [ (1935, 'Luciano Pavarotti'), (1941, 'Plácido Domingo'), (1946, 'José Carreras')] diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 3f76f39b..c28fbefc 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -1,6 +1,7 @@ #!/usr/bin/python import unittest +from typing import Any from pg import DB from pgdb import connect @@ -29,7 +30,7 @@ def tearDown(self): def test_all_steps(self): db = self.db - r = db.get_tables() + r: Any = db.get_tables() self.assertIsInstance(r, list) self.assertIn('public.fruits', r) r = db.get_attnames('fruits') diff --git a/tox.ini b/tox.ini index 37b3a39d..7e52747d 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,12 @@ deps = ruff>=0.0.287 commands = ruff setup.py pg.py pgdb.py tests +[testenv:mypy] +basepython = python3.11 +deps = mypy>=1.5.1 +commands = + mypy setup.py pg.py pgdb.py tests + [testenv:cformat] basepython = python3.11 allowlist_externals = From 08c43d8fef8a2bc23284d09a0b5d637e961d042a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 00:55:43 +0200 Subject: [PATCH 049/118] Add type hints for the pgdb module --- docs/contents/changelog.rst | 3 +- docs/contents/pg/db_wrapper.rst | 2 + pg.py | 46 +-- pgdb.py | 613 +++++++++++++++++--------------- pgsource.c | 3 +- tests/dbapi20.py | 6 +- tests/test_classic_dbwrapper.py | 4 +- tests/test_dbapi20.py | 116 +++--- tests/test_dbapi20_copy.py | 35 +- tests/test_tutorial.py | 2 +- 10 files changed, 458 insertions(+), 372 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index d240daa2..6afc68dd 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,8 +5,9 @@ Version 6.0 (to be released) ---------------------------- - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). +- Added method `pkeys()` to the `pg.DB` object. - Removed deprecated function `pg.pgnotify()`. -- Removed the deprecated method `ntuples()` of the `pg.Query` object. +- Removed deprecated method `ntuples()` of the `pg.Query` object. - Renamed `pgdb.Type` to `pgdb.DbType` to avoid confusion with `typing.Type`. - Modernized code and tools for development, testing, linting and building. diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 64710456..ea4f71c1 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -83,6 +83,8 @@ This method returns the primary keys of a table as a tuple, i.e. single primary keys are also returned as a tuple with one item. Note that this raises a KeyError if the table does not have a primary key. +.. versionadded:: 6.0 + get_databases -- get list of databases in the system ---------------------------------------------------- diff --git a/pg.py b/pg.py index 11aaf90a..45f8ae46 100644 --- a/pg.py +++ b/pg.py @@ -242,7 +242,8 @@ def __str__(self) -> str: class Json: """Wrapper class for marking Json values.""" - def __init__(self, obj: Any, encode: Callable | None = None) -> None: + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode @@ -1014,7 +1015,7 @@ class Typecasts(dict): connection: DB | None = None # set in a connection specific instance - def __missing__(self, typ: Any) -> Callable | None: + def __missing__(self, typ: str) -> Callable | None: """Create a cast function if it is not cached. Note that this class never raises a KeyError, @@ -1047,8 +1048,7 @@ def _needs_connection(func: Callable) -> bool: args = get_args(func) except (TypeError, ValueError): return False - else: - return 'connection' in args[1:] + return 'connection' in args[1:] def _add_connection(self, cast: Callable) -> Callable: """Add a connection argument to the typecast function if necessary.""" @@ -1056,11 +1056,12 @@ def _add_connection(self, cast: Callable) -> Callable: return cast return partial(cast, connection=self.connection) - def get(self, typ: Any, default: Any = None) -> Any: + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: """Get the typecast function for the given database type.""" return self[typ] or default - def set(self, typ: Any, cast: Callable) -> None: + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: """Set a typecast function for the specified database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1075,7 +1076,7 @@ def set(self, typ: Any, cast: Callable) -> None: self[t] = self._add_connection(cast) self.pop(f'_{t}', None) - def reset(self, typ: Any = None) -> None: + def reset(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecasts for the specified type(s) to their defaults. When no type is specified, all typecasts will be reset. @@ -1089,12 +1090,13 @@ def reset(self, typ: Any = None) -> None: self.pop(t, None) @classmethod - def get_default(cls, typ: Any) -> Any: + def get_default(cls, typ: str) -> Any: """Get the default typecast function for the given database type.""" return cls.defaults.get(typ) @classmethod - def set_default(cls, typ: Any, cast: Callable | None) -> None: + def set_default(cls, typ: str | Sequence[str], + cast: Callable | None) -> None: """Set a default typecast function for the given database type(s).""" if isinstance(typ, str): typ = [typ] @@ -1130,7 +1132,7 @@ def create_array_cast(self, basecast: Callable) -> Callable: """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] - def cast(v: Any) -> Callable: + def cast(v: Any) -> list: return cast_array(v, basecast) return cast @@ -1146,12 +1148,12 @@ def cast(v: Any) -> record: return cast -def get_typecast(typ: Any) -> Callable | None: - """Get the global typecast function for the given database type(s).""" +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" return Typecasts.get_default(typ) -def set_typecast(typ: Any, cast: Callable | None) -> None: +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: """Set a global typecast function for the given database type(s). Note that connections cache cast functions. To be sure a global change @@ -1254,7 +1256,8 @@ def __missing__(self, key: int | str) -> DbType: self[typ.oid] = self[typ.pgtype] = typ return typ - def get(self, key: int | str, default: Any = None) -> Any: + def get(self, key: int | str, # type: ignore + default: DbType | None = None) -> DbType | None: """Get the type even if it is not cached.""" try: return self[key] @@ -1271,27 +1274,27 @@ def get_attnames(self, typ: Any) -> AttrDict | None: return None return self._db.get_attnames(typ.relid, with_oid=False) - def get_typecast(self, typ: Any) -> Callable: + def get_typecast(self, typ: Any) -> Callable | None: """Get the typecast function for the given database type.""" return self._typecasts.get(typ) - def set_typecast(self, typ: Any, cast: Callable) -> None: + def set_typecast(self, typ: str | Sequence[str], cast: Callable) -> None: """Set a typecast function for the specified database type(s).""" self._typecasts.set(typ, cast) - def reset_typecast(self, typ: Any = None) -> None: + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecast function for the specified database type(s).""" self._typecasts.reset(typ) - def typecast(self, value: Any, typ: Any) -> Callable | None: + def typecast(self, value: Any, typ: str) -> Any: """Cast the given value according to the given database type.""" if value is None: # for NULL values, no typecast is necessary return None if not isinstance(typ, DbType): - typ = self.get(typ) - if typ: - typ = typ.pgtype + db_type = self.get(typ) + if db_type: + typ = db_type.pgtype cast = self.get_typecast(typ) if typ else None if not cast or cast is str: # no typecast is necessary @@ -1373,6 +1376,7 @@ def getresult(self) -> Any: def __iter__(self) -> Iterator[Any]: return iter(self.result) +# Error messages E = TypeVar('E', bound=DatabaseError) diff --git a/pgdb.py b/pgdb.py index df23bbfd..332ca3d0 100644 --- a/pgdb.py +++ b/pgdb.py @@ -69,7 +69,7 @@ from collections import namedtuple from collections.abc import Iterable from contextlib import suppress -from datetime import date, datetime, time, timedelta +from datetime import date, datetime, time, timedelta, tzinfo from decimal import Decimal as StdDecimal from functools import lru_cache, partial from inspect import signature @@ -78,7 +78,16 @@ from math import isinf, isnan from re import compile as regex from time import localtime -from typing import Callable, ClassVar +from typing import ( + Any, + Callable, + ClassVar, + Generator, + Mapping, + NamedTuple, + Sequence, + TypeVar, +) from uuid import UUID as Uuid # noqa: N811 try: @@ -131,10 +140,15 @@ cast_array, cast_hstore, cast_record, - connect, unescape_bytea, version, ) +from _pg import ( + Connection as Cnx, # base connection +) +from _pg import ( + connect as get_cnx, # get base connection +) __version__ = version @@ -197,7 +211,7 @@ def _timezone_as_offset(tz: str) -> str: return _timezones.get(tz, '+0000') -def decimal_type(decimal_type: type | None = None): +def decimal_type(decimal_type: type | None = None) -> type: """Get or set global type to be used for decimal values. Note that connections cache cast functions. To be sure a global change @@ -212,15 +226,15 @@ def decimal_type(decimal_type: type | None = None): def cast_bool(value: str) -> bool | None: """Cast boolean value in database format to bool.""" - if value: - return value[0] in ('t', 'T') + return value[0] in ('t', 'T') if value else None -def cast_money(value: str) -> Decimal | None: # pyright: ignore +def cast_money(value: str) -> StdDecimal | None: """Cast money value in database format to Decimal.""" - if value: - value = value.replace('(', '-') - return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) + if not value: + return None + value = value.replace('(', '-') + return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) def cast_int2vector(value: str) -> list[int]: @@ -228,7 +242,7 @@ def cast_int2vector(value: str) -> list[int]: return [int(v) for v in value.split()] -def cast_date(value: str, connection) -> date: +def cast_date(value: str, cnx: Cnx) -> date: """Cast a date value.""" # The output format depends on the server setting DateStyle. The default # setting ISO and the setting for German are actually unambiguous. The @@ -244,7 +258,7 @@ def cast_date(value: str, connection) -> date: value = values[0] if len(value) > 10: return date.max - format = connection.date_format() + format = cnx.date_format() return datetime.strptime(value, format).date() @@ -270,7 +284,7 @@ def cast_timetz(value: str) -> time: return datetime.strptime(value, format).timetz() -def cast_timestamp(value: str, connection) -> datetime: +def cast_timestamp(value: str, cnx: Cnx) -> datetime: """Cast a timestamp value.""" if value == '-infinity': return datetime.min @@ -279,7 +293,7 @@ def cast_timestamp(value: str, connection) -> datetime: values = value.split() if values[-1] == 'BC': return datetime.min - format = connection.date_format() + format = cnx.date_format() if format.endswith('-%Y') and len(values) > 2: values = values[1:5] if len(values[3]) > 4: @@ -293,7 +307,7 @@ def cast_timestamp(value: str, connection) -> datetime: return datetime.strptime(' '.join(values), ' '.join(formats)) -def cast_timestamptz(value: str, connection) -> datetime: +def cast_timestamptz(value: str, cnx: Cnx) -> datetime: """Cast a timestamptz value.""" if value == '-infinity': return datetime.min @@ -302,7 +316,7 @@ def cast_timestamptz(value: str, connection) -> datetime: values = value.split() if values[-1] == 'BC': return datetime.min - format = connection.date_format() + format = cnx.date_format() if format.endswith('-%Y') and len(values) > 2: values = values[1:] if len(values[3]) > 4: @@ -439,9 +453,9 @@ class Typecasts(dict): 'int2vector': cast_int2vector, 'uuid': Uuid, 'anyarray': cast_array, 'record': cast_record} - connection = None # will be set in local connection specific instances + cnx: Cnx | None = None # for local connection specific instances - def __missing__(self, typ): + def __missing__(self, typ: str) -> Callable | None: """Create a cast function if it is not cached. Note that this class never raises a KeyError, @@ -464,26 +478,26 @@ def __missing__(self, typ): return cast @staticmethod - def _needs_connection(func): + def _needs_connection(func: Callable) -> bool: """Check if a typecast function needs a connection argument.""" try: args = get_args(func) except (TypeError, ValueError): return False - else: - return 'connection' in args[1:] + return 'cnx' in args[1:] - def _add_connection(self, cast): + def _add_connection(self, cast: Callable) -> Callable: """Add a connection argument to the typecast function if necessary.""" - if not self.connection or not self._needs_connection(cast): + if not self.cnx or not self._needs_connection(cast): return cast - return partial(cast, connection=self.connection) + return partial(cast, cnx=self.cnx) - def get(self, typ, default=None): + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: """Get the typecast function for the given database type.""" return self[typ] or default - def set(self, typ, cast): + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: """Set a typecast function for the specified database type(s).""" if isinstance(typ, str): typ = [typ] @@ -498,7 +512,7 @@ def set(self, typ, cast): self[t] = self._add_connection(cast) self.pop(f'_{t}', None) - def reset(self, typ=None): + def reset(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecasts for the specified type(s) to their defaults. When no type is specified, all typecasts will be reset. @@ -524,20 +538,21 @@ def reset(self, typ=None): self.pop(t, None) self.pop(f'_{t}', None) - def create_array_cast(self, basecast): + def create_array_cast(self, basecast: Callable) -> Callable: """Create an array typecast for the given base cast.""" cast_array = self['anyarray'] - def cast(v): + def cast(v: Any) -> list: return cast_array(v, basecast) return cast - def create_record_cast(self, name, fields, casts): + def create_record_cast(self, name: str, fields: Sequence[str], + casts: Sequence[str]) -> Callable: """Create a named record typecast for the given fields and casts.""" cast_record = self['record'] - record = namedtuple(name, fields) + record = namedtuple(name, fields) # type: ignore - def cast(v): + def cast(v: Any) -> record: # noinspection PyArgumentList return record(*cast_record(v, casts)) return cast @@ -546,12 +561,12 @@ def cast(v): _typecasts = Typecasts() # this is the global typecast dictionary -def get_typecast(typ): - """Get the global typecast function for the given database type(s).""" +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" return _typecasts.get(typ) -def set_typecast(typ, cast): +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: """Set a global typecast function for the given database type(s). Note that connections cache cast functions. To be sure a global change @@ -560,7 +575,7 @@ def set_typecast(typ, cast): _typecasts.set(typ, cast) -def reset_typecast(typ=None): +def reset_typecast(typ: str | Sequence[str] | None = None) -> None: """Reset the global typecasts for the given type(s) to their default. When no type is specified, all typecasts will be reset. @@ -576,10 +591,11 @@ class LocalTypecasts(Typecasts): defaults = _typecasts - connection = None # will be set in a connection specific instance + cnx: Cnx | None = None # set in connection specific instances - def __missing__(self, typ): + def __missing__(self, typ: str) -> Callable | None: """Create a cast function if it is not cached.""" + cast: Callable | None if typ.startswith('_'): base_cast = self[typ[1:]] cast = self.create_array_cast(base_cast) @@ -594,13 +610,13 @@ def __missing__(self, typ): fields = self.get_fields(typ) if fields: casts = [self[field.type] for field in fields] - fields = [field.name for field in fields] - cast = self.create_record_cast(typ, fields, casts) + field_names = [field.name for field in fields] + cast = self.create_record_cast(typ, field_names, casts) self[typ] = cast return cast # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_fields(self, typ): + def get_fields(self, typ: str) -> list[FieldInfo]: """Return the fields for the given record type. This method will be replaced with a method that looks up the fields @@ -616,9 +632,17 @@ class TypeCode(str): but carry some additional information. """ + oid: int + len: int + type: str + category: str + delim: str + relid: int + # noinspection PyShadowingBuiltins @classmethod - def create(cls, oid, name, len, type, category, delim, relid): + def create(cls, oid: int, name: str, len: int, type: str, category: str, + delim: str, relid: int) -> TypeCode: """Create a type code for a PostgreSQL data type.""" self = cls(name) self.oid = oid @@ -640,21 +664,22 @@ class TypeCache(dict): important information on the associated database type. """ - def __init__(self, cnx): + def __init__(self, cnx: Cnx) -> None: """Initialize type cache for connection.""" super().__init__() self._escape_string = cnx.escape_string self._src = cnx.source() self._typecasts = LocalTypecasts() - self._typecasts.get_fields = self.get_fields - self._typecasts.connection = cnx + self._typecasts.get_fields = self.get_fields # type: ignore + self._typecasts.cnx = cnx self._query_pg_type = ( "SELECT oid, typname," " typlen, typtype, typcategory, typdelim, typrelid" " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") - def __missing__(self, key): + def __missing__(self, key: int | str) -> TypeCode: """Get the type info from the database if it is not cached.""" + oid: int | str if isinstance(key, int): oid = key else: @@ -677,43 +702,48 @@ def __missing__(self, key): self[type_code.oid] = self[str(type_code)] = type_code return type_code - def get(self, key, default=None): + def get(self, key: int | str, # type: ignore + default: TypeCode | None = None) -> TypeCode | None: """Get the type even if it is not cached.""" try: return self[key] except KeyError: return default - def get_fields(self, typ): + def get_fields(self, typ: int | str | TypeCode) -> list[FieldInfo] | None: """Get the names and types of the fields of composite types.""" - if not isinstance(typ, TypeCode): - typ = self.get(typ) - if not typ: + if isinstance(typ, TypeCode): + relid = typ.relid + else: + type_code = self.get(typ) + if not type_code: return None - if not typ.relid: + relid = type_code.relid + if not relid: return None # this type is not composite self._src.execute( "SELECT attname, atttypid" # noqa: S608 " FROM pg_catalog.pg_attribute" - f" WHERE attrelid OPERATOR(pg_catalog.=) {typ.relid}" + f" WHERE attrelid OPERATOR(pg_catalog.=) {relid}" " AND attnum OPERATOR(pg_catalog.>) 0" " AND NOT attisdropped ORDER BY attnum") return [FieldInfo(name, self.get(int(oid))) for name, oid in self._src.fetch(-1)] - def get_typecast(self, typ): + def get_typecast(self, typ: str) -> Callable | None: """Get the typecast function for the given database type.""" return self._typecasts[typ] - def set_typecast(self, typ, cast): + def set_typecast(self, typ: str | Sequence[str], + cast: Callable | None) -> None: """Set a typecast function for the specified database type(s).""" self._typecasts.set(typ, cast) - def reset_typecast(self, typ=None): + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: """Reset the typecast function for the specified database type(s).""" self._typecasts.reset(typ) - def typecast(self, value, typ): + def typecast(self, value: Any, typ: str) -> Any: """Cast the given value according to the given database type.""" if value is None: # for NULL values, no typecast is necessary @@ -724,13 +754,13 @@ def typecast(self, value, typ): return value return cast(value) - def get_row_caster(self, types): + def get_row_caster(self, types: Sequence[str]) -> Callable: """Get a typecast function for a complete row of values.""" typecasts = self._typecasts casts = [typecasts[typ] for typ in types] casts = [cast if cast is not str else None for cast in casts] - def row_caster(row): + def row_caster(row: Sequence) -> Sequence: return [value if cast is None or value is None else cast(value) for cast, value in zip(casts, row)] @@ -743,22 +773,26 @@ class _QuoteDict(dict): The quote attribute must be set to the desired quote function. """ - def __getitem__(self, key): + quote: Callable[[str], str] + + def __getitem__(self, key: str) -> str: # noinspection PyUnresolvedReferences return self.quote(super().__getitem__(key)) # *** Error Messages *** +E = TypeVar('E', bound=DatabaseError) + -def _db_error(msg, cls=DatabaseError): +def _db_error(msg: str, cls:type[E] = DatabaseError) -> type[E]: """Return DatabaseError with empty sqlstate attribute.""" error = cls(msg) error.sqlstate = None return error -def _op_error(msg): +def _op_error(msg: str) -> OperationalError: """Return OperationalError.""" return _db_error(msg, OperationalError) @@ -771,16 +805,16 @@ def _op_error(msg): # noinspection PyUnresolvedReferences @lru_cache(maxsize=1024) -def _row_factory(names): +def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: """Get a namedtuple factory for row results with the given names.""" try: - return namedtuple('Row', names, rename=True)._make + return namedtuple('Row', names, rename=True)._make # type: ignore except ValueError: # there is still a problem with the field names names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make + return namedtuple('Row', names)._make # type: ignore -def set_row_factory_size(maxsize): +def set_row_factory_size(maxsize: int | None) -> None: """Change the size of the namedtuple factory cache. If maxsize is set to None, the cache can grow without bound. @@ -795,46 +829,51 @@ def set_row_factory_size(maxsize): class Cursor: """Cursor object.""" - def __init__(self, dbcnx): + def __init__(self, connection: Connection) -> None: """Create a cursor object for the database connection.""" - self.connection = self._dbcnx = dbcnx - self._cnx = dbcnx._cnx - self.type_cache = dbcnx.type_cache + self.connection = self._connection = connection + cnx = connection._cnx + if not cnx: + raise _op_error("Connection has been closed") + self._cnx = cnx + self.type_cache = connection.type_cache self._src = self._cnx.source() # the official attribute for describing the result columns - self._description = None + self._description: list[CursorDescription] | bool | None = None if self.row_factory is Cursor.row_factory: # the row factory needs to be determined dynamically - self.row_factory = None + self.row_factory = None # type: ignore else: - self.build_row_factory = None + self.build_row_factory = None # type: ignore self.rowcount = -1 self.arraysize = 1 self.lastrowid = None - def __iter__(self): + def __iter__(self) -> Cursor: """Make cursor compatible to the iteration protocol.""" return self - def __enter__(self): + def __enter__(self) -> Cursor: """Enter the runtime context for the cursor object.""" return self - def __exit__(self, et, ev, tb): + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: """Exit the runtime context for the cursor object.""" self.close() - def _quote(self, value): + def _quote(self, value: Any) -> Any: """Quote value depending on its type.""" if value is None: return 'NULL' if isinstance(value, (Hstore, Json)): value = str(value) if isinstance(value, (bytes, str)): + cnx = self._cnx if isinstance(value, Binary): - value = self._cnx.escape_bytea(value).decode('ascii') + value = cnx.escape_bytea(value).decode('ascii') else: - value = self._cnx.escape_string(value) + value = cnx.escape_string(value) return f"'{value}'" if isinstance(value, float): if isinf(value): @@ -887,7 +926,8 @@ def _quote(self, value): value = self._quote(value) return value - def _quoteparams(self, string, parameters): + def _quoteparams(self, string: str, + parameters: Mapping | Sequence | None) -> str: """Quote parameters. This function works for both mappings and sequences. @@ -907,12 +947,15 @@ def _quoteparams(self, string, parameters): parameters = tuple(map(self._quote, parameters)) return string % parameters - def _make_description(self, info): + def _make_description(self, info: tuple[int, str, int, int, int] + ) -> CursorDescription: """Make the description tuple for the given field info.""" name, typ, size, mod = info[1:] type_code = self.type_cache[typ] if mod > 0: mod -= 4 + precision: int | None + scale: int | None if type_code == 'numeric': precision, scale = mod >> 16, mod & 0xffff size = precision @@ -922,34 +965,39 @@ def _make_description(self, info): if size == -1: size = mod precision = scale = None - return CursorDescription(name, type_code, - None, size, precision, scale, None) + return CursorDescription( + name, type_code, None, size, precision, scale, None) @property - def description(self): + def description(self) -> list[CursorDescription] | None: """Read-only attribute describing the result columns.""" - descr = self._description - if self._description is True: + description = self._description + if description is None: + return None + if not isinstance(description, list): make = self._make_description - descr = [make(info) for info in self._src.listinfo()] - self._description = descr - return descr + description = [make(info) for info in self._src.listinfo()] + self._description = description + return description @property - def colnames(self): + def colnames(self) -> Sequence[str] | None: """Unofficial convenience method for getting the column names.""" - return [d[0] for d in self.description] + description = self.description + return None if description is None else [d[0] for d in description] @property - def coltypes(self): + def coltypes(self) -> Sequence[TypeCode] | None: """Unofficial convenience method for getting the column types.""" - return [d[1] for d in self.description] + description = self.description + return None if description is None else [d[1] for d in description] - def close(self): + def close(self) -> None: """Close the cursor object.""" self._src.close() - def execute(self, operation, parameters=None): + def execute(self, operation: str, parameters: Sequence | None = None + ) -> Cursor: """Prepare and execute a database operation (query or command).""" # The parameters may also be specified as list of tuples to e.g. # insert multiple rows in a single operation, but this kind of @@ -960,22 +1008,22 @@ def execute(self, operation, parameters=None): and all(isinstance(p, tuple) for p in parameters) and all(len(p) == len(parameters[0]) for p in parameters[1:])): return self.executemany(operation, parameters) - else: - # not a list of tuples - return self.executemany(operation, [parameters]) + # not a list of tuples + return self.executemany(operation, [parameters]) - def executemany(self, operation, seq_of_parameters): + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None]) -> Cursor: """Prepare operation and execute it against a parameter sequence.""" if not seq_of_parameters: # don't do anything without parameters - return + return self self._description = None self.rowcount = -1 # first try to execute all queries rowcount = 0 sql = "BEGIN" try: - if not self._dbcnx._tnx and not self._dbcnx.autocommit: + if not self._connection._tnx and not self._connection.autocommit: try: self._src.execute(sql) except DatabaseError: @@ -983,7 +1031,7 @@ def executemany(self, operation, seq_of_parameters): except Exception as e: raise _op_error("Can't start transaction") from e else: - self._dbcnx._tnx = True + self._connection._tnx = True for parameters in seq_of_parameters: sql = operation sql = self._quoteparams(sql, parameters) @@ -1005,8 +1053,9 @@ def executemany(self, operation, seq_of_parameters): self._description = True # fetch on demand self.rowcount = self._src.ntuples self.lastrowid = None - if self.build_row_factory: - self.row_factory = self.build_row_factory() + build_row_factory = self.build_row_factory + if build_row_factory: # type: ignore + self.row_factory = build_row_factory() # type: ignore else: self.rowcount = rowcount self.lastrowid = self._src.oidstatus() @@ -1014,7 +1063,7 @@ def executemany(self, operation, seq_of_parameters): # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" return self - def fetchone(self): + def fetchone(self) -> Sequence | None: """Fetch the next row of a query result set.""" res = self.fetchmany(1, False) try: @@ -1022,11 +1071,12 @@ def fetchone(self): except IndexError: return None - def fetchall(self): + def fetchall(self) -> Sequence[Sequence]: """Fetch all (remaining) rows of a query result.""" return self.fetchmany(-1, False) - def fetchmany(self, size=None, keep=False): + def fetchmany(self, size: int | None = None, keep: bool = False + ) -> Sequence[Sequence]: """Fetch the next set of rows of a query result. The number of rows to fetch per call is specified by the @@ -1046,6 +1096,9 @@ def fetchmany(self, size=None, keep=False): raise _db_error(str(err)) from err row_factory = self.row_factory coltypes = self.coltypes + if coltypes is None: + # cannot determine column types, return raw result + return [row_factory(row) for row in result] if len(result) > 5: # optimize the case where we really fetch many values # by looking up all type casting functions upfront @@ -1055,7 +1108,8 @@ def fetchmany(self, size=None, keep=False): return [row_factory([cast_value(value, typ) for typ, value in zip(coltypes, row)]) for row in result] - def callproc(self, procname, parameters=None): + def callproc(self, procname: str, parameters: Sequence | None = None + ) -> Sequence | None: """Call a stored database procedure with the given name. The sequence of parameters must contain one entry for each input @@ -1073,15 +1127,17 @@ def callproc(self, procname, parameters=None): return parameters # noinspection PyShadowingBuiltins - def copy_from(self, stream, table, - format=None, sep=None, null=None, size=None, columns=None): + def copy_from(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, size: int | None = None, + columns: Sequence[str] | None = None) -> Cursor: """Copy data from an input stream to the specified table. The input stream can be a file-like object with a read() method or it can also be an iterable returning a row or multiple rows of input on each iteration. - The format must be text, csv or binary. The sep option sets the + The format must be 'text', 'csv' or 'binary'. The sep option sets the column separator (delimiter) used in the non binary formats. The null option sets the textual representation of NULL in the input. @@ -1098,6 +1154,8 @@ def copy_from(self, stream, table, if size: raise ValueError( "Size must only be set for file-like objects") from e + input_type: type | tuple[type, ...] + type_name: str if binary_format: input_type = bytes type_name = 'byte strings' @@ -1116,12 +1174,12 @@ def copy_from(self, stream, table, if not stream.endswith(b'\n'): stream += b'\n' - def chunks(): + def chunks() -> Generator: yield stream elif isinstance(stream, Iterable): - def chunks(): + def chunks() -> Generator: for chunk in stream: if not isinstance(chunk, input_type): raise ValueError( @@ -1143,7 +1201,7 @@ def chunks(): raise TypeError("The size option must be an integer") if size > 0: - def chunks(): + def chunks() -> Generator: while True: buffer = read(size) yield buffer @@ -1152,19 +1210,18 @@ def chunks(): else: - def chunks(): + def chunks() -> Generator: yield read() if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") if table.lower().startswith('select '): raise ValueError("Must specify a table, not a query") - else: - table = '.'.join(map( - self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = [f'copy {table}'] + cnx = self._cnx + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] options = [] - params = [] + parameters = [] if format is not None: if not isinstance(format, str): raise TypeError("The format option must be be a string") @@ -1181,25 +1238,23 @@ def chunks(): raise ValueError( "The sep option must be a single one-byte character") options.append('delimiter %s') - params.append(sep) + parameters.append(sep) if null is not None: if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') - params.append(null) + parameters.append(null) if columns: if not isinstance(columns, str): - columns = ','.join(map( - self.connection._cnx.escape_identifier, columns)) - operation.append(f'({columns})') - operation.append("from stdin") + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + operation_parts.append("from stdin") if options: - options = ','.join(options) - operation.append(f'({options})') - operation = ' '.join(operation) + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) putdata = self._src.putdata - self.execute(operation, params) + self.execute(operation, parameters) try: for chunk in chunks(): @@ -1215,8 +1270,10 @@ def chunks(): return self # noinspection PyShadowingBuiltins - def copy_to(self, stream, table, - format=None, sep=None, null=None, decode=None, columns=None): + def copy_to(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, decode: bool | None = None, + columns: Sequence[str] | None = None) -> Cursor | Generator: """Copy data from the specified table to an output stream. The output stream can be a file-like object with a write() method or @@ -1227,7 +1284,7 @@ def copy_to(self, stream, table, Note that you can also use a select query instead of the table name. - The format must be text, csv or binary. The sep option sets the + The format must be 'text', 'csv' or 'binary'. The sep option sets the column separator (delimiter) used in the non binary formats. The null option sets the textual representation of NULL in the output. @@ -1235,23 +1292,25 @@ def copy_to(self, stream, table, columns are specified, all of them will be copied. """ binary_format = format == 'binary' - if stream is not None: + if stream is None: + write = None + else: try: write = stream.write except AttributeError as e: raise TypeError("Need an output stream to copy to") from e if not table or not isinstance(table, str): raise TypeError("Need a table to copy to") + cnx = self._cnx if table.lower().startswith('select '): if columns: raise ValueError("Columns must be specified in the query") table = f'({table})' else: - table = '.'.join(map( - self.connection._cnx.escape_identifier, table.split('.', 1))) - operation = [f'copy {table}'] + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] options = [] - params = [] + parameters = [] if format is not None: if not isinstance(format, str): raise TypeError("The format option must be a string") @@ -1268,12 +1327,12 @@ def copy_to(self, stream, table, raise ValueError( "The sep option must be a single one-byte character") options.append('delimiter %s') - params.append(sep) + parameters.append(sep) if null is not None: if not isinstance(null, str): raise TypeError("The null option must be a string") options.append('null %s') - params.append(null) + parameters.append(null) if decode is None: decode = format != 'binary' else: @@ -1284,20 +1343,18 @@ def copy_to(self, stream, table, "The decode option is not allowed with binary format") if columns: if not isinstance(columns, str): - columns = ','.join(map( - self.connection._cnx.escape_identifier, columns)) - operation.append(f'({columns})') + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') - operation.append("to stdout") + operation_parts.append("to stdout") if options: - options = ','.join(options) - operation.append(f'({options})') - operation = ' '.join(operation) + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) getdata = self._src.getdata - self.execute(operation, params) + self.execute(operation, parameters) - def copy(): + def copy() -> Generator: self.rowcount = 0 while True: row = getdata(decode) @@ -1308,7 +1365,7 @@ def copy(): self.rowcount += 1 yield row - if stream is None: + if write is None: # no input stream, return the generator return copy() @@ -1320,7 +1377,7 @@ def copy(): # return the cursor object, so you can chain operations return self - def __next__(self): + def __next__(self) -> Sequence: """Return the next row (support for the iteration protocol).""" res = self.fetchone() if res is None: @@ -1332,22 +1389,22 @@ def __next__(self): next = __next__ @staticmethod - def nextset(): + def nextset() -> bool | None: """Not supported.""" raise NotSupportedError("The nextset() method is not supported") @staticmethod - def setinputsizes(sizes): + def setinputsizes(sizes: Sequence[int]) -> None: """Not supported.""" pass # unsupported, but silently passed @staticmethod - def setoutputsize(size, column=0): + def setoutputsize(size: int, column: int = 0) -> None: """Not supported.""" pass # unsupported, but silently passed @staticmethod - def row_factory(row): + def row_factory(row: Sequence) -> Sequence: """Process rows before they are returned. You can overwrite this statically with a custom row factory, or @@ -1367,7 +1424,7 @@ def row_factory(self, row): """ raise NotImplementedError - def build_row_factory(self): + def build_row_factory(self) -> Callable[[Sequence], Sequence] | None: """Build a row factory based on the current description. This implementation builds a row factory for creating named tuples. @@ -1375,8 +1432,7 @@ def build_row_factory(self): different row factories whenever the column description changes. """ names = self.colnames - if names: - return _row_factory(tuple(names)) + return _row_factory(tuple(names)) if names else None CursorDescription = namedtuple('CursorDescription', ( @@ -1401,7 +1457,7 @@ class Connection: DataError = DataError NotSupportedError = NotSupportedError - def __init__(self, cnx): + def __init__(self, cnx: Cnx) -> None: """Create a database connection object.""" self._cnx = cnx # connection self._tnx = False # transaction state @@ -1413,7 +1469,7 @@ def __init__(self, cnx): except Exception as e: raise _op_error("Invalid connection") from e - def __enter__(self): + def __enter__(self) -> Connection: """Enter the runtime context for the connection object. The runtime context can be used for running transactions. @@ -1421,8 +1477,11 @@ def __enter__(self): This also starts a transaction in autocommit mode. """ if self.autocommit: + cnx = self._cnx + if not cnx: + raise _op_error("Connection has been closed") try: - self._cnx.source().execute("BEGIN") + cnx.source().execute("BEGIN") except DatabaseError: raise # database provides error message except Exception as e: @@ -1431,7 +1490,8 @@ def __enter__(self): self._tnx = True return self - def __exit__(self, et, ev, tb): + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: """Exit the runtime context for the connection object. This does not close the connection, but it ends a transaction. @@ -1441,103 +1501,101 @@ def __exit__(self, et, ev, tb): else: self.rollback() - def close(self): + def close(self) -> None: """Close the connection object.""" - if self._cnx: - if self._tnx: - with suppress(DatabaseError): - self.rollback() - self._cnx.close() - self._cnx = None - else: + if not self._cnx: raise _op_error("Connection has been closed") + if self._tnx: + with suppress(DatabaseError): + self.rollback() + self._cnx.close() + self._cnx = None @property - def closed(self): + def closed(self) -> bool: """Check whether the connection has been closed or is broken.""" try: return not self._cnx or self._cnx.status != 1 except TypeError: return True - def commit(self): + def commit(self) -> None: """Commit any pending transaction to the database.""" - if self._cnx: - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("COMMIT") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't commit transaction") from e - else: + if not self._cnx: raise _op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("COMMIT") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise _op_error("Can't commit transaction") from e - def rollback(self): + def rollback(self) -> None: """Roll back to the start of any pending transaction.""" - if self._cnx: - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("ROLLBACK") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't rollback transaction") from e - else: + if not self._cnx: raise _op_error("Connection has been closed") - - def cursor(self): - """Return a new cursor object using the connection.""" - if self._cnx: + if self._tnx: + self._tnx = False try: - return self.cursor_type(self) + self._cnx.source().execute("ROLLBACK") + except DatabaseError: + raise # database provides error message except Exception as e: - raise _op_error("Invalid connection") from e - else: + raise _op_error("Can't rollback transaction") from e + + def cursor(self) -> Cursor: + """Return a new cursor object using the connection.""" + if not self._cnx: raise _op_error("Connection has been closed") + try: + return self.cursor_type(self) + except Exception as e: + raise _op_error("Invalid connection") from e if shortcutmethods: # otherwise do not implement and document this - def execute(self, operation, params=None): + def execute(self, operation: str, + parameters: Sequence | None = None) -> Cursor: """Shortcut method to run an operation on an implicit cursor.""" cursor = self.cursor() - cursor.execute(operation, params) + cursor.execute(operation, parameters) return cursor - def executemany(self, operation, param_seq): + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None] + ) -> Cursor: """Shortcut method to run an operation against a sequence.""" cursor = self.cursor() - cursor.executemany(operation, param_seq) + cursor.executemany(operation, seq_of_parameters) return cursor # *** Module Interface *** -_connect = connect - - -def connect(dsn=None, - user=None, password=None, - host=None, database=None, **kwargs): +def connect(dsn: str | None = None, + user: str | None = None, password: str | None = None, + host: str | None = None, database: str | None = None, + **kwargs: Any) -> Connection: """Connect to a database.""" # first get params from DSN dbport = -1 - dbhost = "" - dbname = "" - dbuser = "" - dbpasswd = "" - dbopt = "" - try: - params = dsn.split(":") - dbhost = params[0] - dbname = params[1] - dbuser = params[2] - dbpasswd = params[3] - dbopt = params[4] - except (AttributeError, IndexError, TypeError): - pass + dbhost: str | None = "" + dbname: str | None = "" + dbuser: str | None = "" + dbpasswd: str | None = "" + dbopt: str | None = "" + if dsn: + try: + params = dsn.split(":", 4) + dbhost = params[0] + dbname = params[1] + dbuser = params[2] + dbpasswd = params[3] + dbopt = params[4] + except (AttributeError, IndexError, TypeError): + pass # override if necessary if user is not None: @@ -1546,9 +1604,9 @@ def connect(dsn=None, dbpasswd = password if database is not None: dbname = database - if host is not None: + if host: try: - params = host.split(":") + params = host.split(":", 1) dbhost = params[0] dbport = int(params[1]) except (AttributeError, IndexError, TypeError, ValueError): @@ -1562,22 +1620,21 @@ def connect(dsn=None, # pass keyword arguments as connection info string if kwargs: - kwargs = list(kwargs.items()) - if '=' in dbname: - dbname = [dbname] + kwarg_list = list(kwargs.items()) + kw_parts = [] + if dbname and '=' in dbname: + kw_parts.append(dbname) else: - kwargs.insert(0, ('dbname', dbname)) - dbname = [] - for kw, value in kwargs: + kwarg_list.insert(0, ('dbname', dbname)) + for kw, value in kwarg_list: value = str(value) if not value or ' ' in value: value = value.replace('\\', '\\\\').replace("'", "\\'") value = f"'{value}'" - dbname.append(f'{kw}={value}') - dbname = ' '.join(dbname) + kw_parts.append(f'{kw}={value}') + dbname = ' '.join(kw_parts) # open the connection - # noinspection PyArgumentList - cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) + cnx = get_cnx(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) return Connection(cnx) @@ -1590,67 +1647,61 @@ class DbType(frozenset): We must thus use type names as internal type codes. """ - def __new__(cls, values): + def __new__(cls, values: str | Iterable[str]) -> DbType: """Create new type object.""" if isinstance(values, str): values = values.split() - return super().__new__(cls, values) + return super().__new__(cls, values) # type: ignore - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Check whether types are considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other in self - else: - return super().__eq__(other) + return super().__eq__(other) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Check whether types are not considered equal.""" if isinstance(other, str): if other.startswith('_'): other = other[1:] return other not in self - else: - return super().__ne__(other) + return super().__ne__(other) class ArrayType: """Type class for PostgreSQL array types.""" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, str): return other.startswith('_') - else: - return isinstance(other, ArrayType) + return isinstance(other, ArrayType) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, str): return not other.startswith('_') - else: - return not isinstance(other, ArrayType) + return not isinstance(other, ArrayType) class RecordType: """Type class for PostgreSQL record types.""" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type == 'c' - elif isinstance(other, str): + if isinstance(other, str): return other == 'record' - else: - return isinstance(other, RecordType) + return isinstance(other, RecordType) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, TypeCode): # noinspection PyUnresolvedReferences return other.type != 'c' - elif isinstance(other, str): + if isinstance(other, str): return other != 'record' - else: - return not isinstance(other, RecordType) + return not isinstance(other, RecordType) # Mandatory type objects defined by DB-API 2 specs: @@ -1691,35 +1742,38 @@ def __ne__(self, other): # Mandatory type helpers defined by DB-API 2 specs: -def Date(year, month, day): # noqa: N802 +def Date(year: int, month: int, day: int) -> date: # noqa: N802 """Construct an object holding a date value.""" return date(year, month, day) -def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): # noqa: N802 +def Time(hour: int, minute: int = 0, # noqa: N802 + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> time: """Construct an object holding a time value.""" return time(hour, minute, second, microsecond, tzinfo) -def Timestamp(year, month, day, # noqa: N802 - hour=0, minute=0, second=0, microsecond=0, - tzinfo=None): +def Timestamp(year: int, month: int, day: int, # noqa: N802 + hour: int = 0, minute: int = 0, + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> datetime: """Construct an object holding a time stamp value.""" - return datetime(year, month, day, hour, minute, second, microsecond, - tzinfo) + return datetime(year, month, day, hour, minute, + second, microsecond, tzinfo) -def DateFromTicks(ticks): # noqa: N802 +def DateFromTicks(ticks: float | None) -> date: # noqa: N802 """Construct an object holding a date value from the given ticks value.""" return Date(*localtime(ticks)[:3]) -def TimeFromTicks(ticks): # noqa: N802 +def TimeFromTicks(ticks: float | None) -> time: # noqa: N802 """Construct an object holding a time value from the given ticks value.""" return Time(*localtime(ticks)[3:6]) -def TimestampFromTicks(ticks): # noqa: N802 +def TimestampFromTicks(ticks: float | None) -> datetime: # noqa: N802 """Construct an object holding a time stamp from the given ticks value.""" return Timestamp(*localtime(ticks)[:6]) @@ -1730,11 +1784,13 @@ class Binary(bytes): # Additional type helpers for PyGreSQL: -def Interval(days, # noqa: N802 - hours=0, minutes=0, seconds=0, microseconds=0): +def Interval(days: int | float, # noqa: N802 + hours: int | float = 0, minutes: int | float = 0, + seconds: int | float = 0, microseconds: int | float = 0 + ) -> timedelta: """Construct an object holding a time interval value.""" - return timedelta(days, hours=hours, minutes=minutes, seconds=seconds, - microseconds=microseconds) + return timedelta(days, hours=hours, minutes=minutes, + seconds=seconds, microseconds=microseconds) Uuid = Uuid # Construct an object holding a UUID value @@ -1747,7 +1803,7 @@ class Hstore(dict): _re_escape = regex(r'(["\\])') @classmethod - def _quote(cls, s): + def _quote(cls, s: Any) -> Any: if s is None: return 'NULL' if not isinstance(s, str): @@ -1760,7 +1816,7 @@ def _quote(cls, s): s = f'"{s}"' return s - def __str__(self): + def __str__(self) -> str: """Create a printable representation of the hstore value.""" q = self._quote return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) @@ -1769,12 +1825,13 @@ def __str__(self): class Json: """Construct a wrapper for holding an object serializable to JSON.""" - def __init__(self, obj, encode=None): + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: """Initialize the JSON object.""" self.obj = obj self.encode = encode or jsonencode - def __str__(self): + def __str__(self) -> str: """Create a printable representation of the JSON object.""" obj = self.obj if isinstance(obj, str): @@ -1785,11 +1842,11 @@ def __str__(self): class Literal: """Construct a wrapper for holding a literal SQL string.""" - def __init__(self, sql): + def __init__(self, sql: str) -> None: """Initialize literal SQL string.""" self.sql = sql - def __str__(self): + def __str__(self) -> str: """Return a printable representation of the SQL string.""" return self.sql diff --git a/pgsource.c b/pgsource.c index 73c9a52b..9bc6bb4a 100644 --- a/pgsource.c +++ b/pgsource.c @@ -680,7 +680,8 @@ _source_buildinfo(sourceObject *self, int num) /* Lists fields info. */ static char source_listinfo__doc__[] = - "listinfo() -- get information for all fields (position, name, type oid)"; + "listinfo() -- get information for all fields" + " (position, name, type oid, size, type modifier)"; static PyObject * source_listInfo(sourceObject *self, PyObject *noargs) diff --git a/tests/dbapi20.py b/tests/dbapi20.py index f72d99f7..0c038f72 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -12,7 +12,7 @@ import time import unittest from contextlib import suppress -from typing import Any, Mapping +from typing import Any, ClassVar __version__ = '1.15.0' @@ -22,7 +22,7 @@ class DatabaseAPI20Test(unittest.TestCase): This implementation tests Gadfly, but the TestCase is structured so that other self.drivers can subclass this test case to ensure compliance with the DB-API. It is - expected that this TestCase may be expanded in the future + expected that this TestCase may be expanded i qn the future if ambiguities or edge conditions are discovered. The 'Optional Extensions' are not yet being tested. @@ -43,7 +43,7 @@ class mytest(dbapi20.DatabaseAPI20Test): # method is to be found driver: Any = None connect_args: tuple = () # List of arguments to pass to connect - connect_kw_args: Mapping[str, Any] = {} # Keyword arguments for connect + connect_kw_args: ClassVar[dict[str, Any]] = {} # Keyword arguments table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables ddl1 = f'create table {table_prefix}booze (name varchar(20))' diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 31aec400..71438f71 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -23,7 +23,7 @@ from io import StringIO from operator import itemgetter from time import strftime -from typing import Any, ClassVar +from typing import Any, Callable, ClassVar from uuid import UUID import pg # the module under test @@ -4910,7 +4910,7 @@ def test_debug_multiple_args(self): class TestMemoryLeaks(unittest.TestCase): """Test that the DB class does not leak memory.""" - def get_leaks(self, fut): + def get_leaks(self, fut: Callable): ids: set = set() objs: list = [] add_ids = ids.update diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 657e820c..6838d03a 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1,9 +1,11 @@ #!/usr/bin/python +from __future__ import annotations + import gc import unittest from datetime import date, datetime, time, timedelta, timezone -from typing import Any, Mapping +from typing import Any, ClassVar from uuid import UUID as Uuid # noqa: N811 import pgdb @@ -26,7 +28,7 @@ class TestPgDb(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () - connect_kw_args: Mapping[str, Any] = { + connect_kw_args: ClassVar[dict[str, Any]] = { 'database': dbname, 'host': f"{dbhost or ''}:{dbport or -1}", 'user': dbuser, 'password': dbpasswd} @@ -159,8 +161,10 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): def row_factory(self, row): + description = self.description + assert isinstance(description, list) return {f'column {desc[0]}': value - for desc, value in zip(self.description, row)} + for desc, value in zip(description, row)} con = self._connect() con.cursor_type = TestCursor @@ -186,7 +190,9 @@ def test_build_row_factory(self): class TestCursor(pgdb.Cursor): def build_row_factory(self): - keys = [desc[0] for desc in self.description] + description = self.description + assert isinstance(description, list) + keys = [desc[0] for desc in description] return lambda row: { key: value for key, value in zip(keys, row)} @@ -566,19 +572,37 @@ def test_float(self): inval = -inf elif inval in ('nan', 'NaN'): inval = nan - if isinf(inval): + if isinf(inval): # type: ignore self.assertTrue(isinf(outval)) - if inval < 0: + if inval < 0: # type: ignore self.assertTrue(outval < 0) else: self.assertTrue(outval > 0) - elif isnan(inval): + elif isnan(inval): # type: ignore self.assertTrue(isnan(outval)) else: self.assertEqual(inval, outval) def test_datetime(self): dt = datetime(2011, 7, 17, 15, 47, 42, 317509) + values = [dt.date(), dt.time(), dt, dt.time(), dt] + assert isinstance(values[3], time) + values[3] = values[3].replace(tzinfo=timezone.utc) + assert isinstance(values[4], datetime) + values[4] = values[4].replace(tzinfo=timezone.utc) + d = (dt.year, dt.month, dt.day) + t = (dt.hour, dt.minute, dt.second, dt.microsecond) + z = (timezone.utc,) + inputs = [ + # input as objects + values, + # input as text + [v.isoformat() for v in values], # type: ignore + # # input using type helpers + [pgdb.Date(*d), pgdb.Time(*t), + pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), + pgdb.Timestamp(*(d + t + z))] + ] table = self.table_prefix + 'booze' con = self._connect() try: @@ -587,26 +611,11 @@ def test_datetime(self): cur.execute(f"create table {table} (" "d date, t time, ts timestamp," "tz timetz, tsz timestamptz)") - for n in range(3): - values = [dt.date(), dt.time(), dt, dt.time(), dt] - values[3] = values[3].replace(tzinfo=timezone.utc) - values[4] = values[4].replace(tzinfo=timezone.utc) - if n == 0: # input as objects - params = values - if n == 1: # input as text - params = [v.isoformat() for v in values] # as text - elif n == 2: # input using type helpers - d = (dt.year, dt.month, dt.day) - t = (dt.hour, dt.minute, dt.second, dt.microsecond) - z = (timezone.utc,) - params = [pgdb.Date(*d), pgdb.Time(*t), - pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), - pgdb.Timestamp(*(d + t + z))] + for params in inputs: for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy', 'sql, mdy', 'sql, dmy', 'german'): cur.execute(f"set datestyle to {datestyle}") - if n != 1: - # noinspection PyUnboundLocalVariable + if not isinstance(params[0], str): cur.execute("select %s,%s,%s,%s,%s", params) row = cur.fetchone() self.assertEqual(row, tuple(values)) @@ -615,11 +624,13 @@ def test_datetime(self): " values (%s,%s,%s,%s,%s)", params) cur.execute(f"select * from {table}") d = cur.description + assert isinstance(d, list) for i in range(5): - self.assertEqual(d[i].type_code, pgdb.DATETIME) - self.assertNotEqual(d[i].type_code, pgdb.STRING) - self.assertNotEqual(d[i].type_code, pgdb.ARRAY) - self.assertNotEqual(d[i].type_code, pgdb.RECORD) + tc = d[i].type_code + self.assertEqual(tc, pgdb.DATETIME) + self.assertNotEqual(tc, pgdb.STRING) + self.assertNotEqual(tc, pgdb.ARRAY) + self.assertNotEqual(tc, pgdb.RECORD) self.assertEqual(d[0].type_code, pgdb.DATE) self.assertEqual(d[1].type_code, pgdb.TIME) self.assertEqual(d[2].type_code, pgdb.TIMESTAMP) @@ -633,20 +644,20 @@ def test_datetime(self): def test_interval(self): td = datetime(2011, 7, 17, 15, 47, 42, 317509) - datetime(1970, 1, 1) + inputs = [ + # input as objects + td, + # input as text + f'{td.days} days {td.seconds} seconds' + f' {td.microseconds} microseconds', + # input using type helpers + pgdb.Interval(td.days, 0, 0, td.seconds, td.microseconds)] table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() cur.execute(f"create table {table} (i interval)") - for n in range(3): - if n == 0: # input as objects - param = td - if n == 1: # input as text - param = (f'{td.days} days {td.seconds} seconds' - f' {td.microseconds} microseconds') - elif n == 2: # input using type helpers - param = pgdb.Interval( - td.days, 0, 0, td.seconds, td.microseconds) + for param in inputs: for intervalstyle in ('sql_standard ', 'postgres', 'postgres_verbose', 'iso_8601'): cur.execute(f"set intervalstyle to {intervalstyle}") @@ -705,7 +716,7 @@ def test_uuid(self): self.assertEqual(result, d) def test_insert_array(self): - values = [ + values: list[tuple[Any, Any]] = [ (None, None), ([], []), ([None], [[None], ['null']]), ([1, 2, 3], [['a', 'b'], ['c', 'd']]), ([20000, 25000, 25000, 30000], @@ -819,15 +830,15 @@ def test_select_record(self): def test_custom_type(self): values = [3, 5, 65] - values = list(map(PgBitString, values)) + values = list(map(PgBitString, values)) # type: ignore table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() - params = enumerate(values) # params have __pg_repr__ method + seq_params = enumerate(values) # params have __pg_repr__ method cur.execute( f'create table "{table}" (n smallint, b bit varying(7))') - cur.executemany(f"insert into {table} values (%s,%s)", params) + cur.executemany(f"insert into {table} values (%s,%s)", seq_params) cur.execute(f"select * from {table}") rows = cur.fetchall() finally: @@ -850,20 +861,29 @@ def test_set_decimal_type(self): try: cur = con.cursor() # change decimal type globally to int - int_type = lambda v: int(float(v)) # noqa: E731 - self.assertTrue(pgdb.decimal_type(int_type) is int_type) + + class CustomDecimal(str): + + def __init__(self, value: Any) -> None: + self.value = value + + def __str__(self) -> str: + return str(self.value).replace('.', ',') + + self.assertTrue(pgdb.decimal_type(CustomDecimal) is CustomDecimal) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] - self.assertTrue(isinstance(value, int)) - self.assertEqual(value, 4) + self.assertTrue(isinstance(value, CustomDecimal)) + self.assertEqual(str(value), '4,25') # change decimal type again to float self.assertTrue(pgdb.decimal_type(float) is float) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] # the connection still uses the old setting - self.assertTrue(isinstance(value, int)) + self.assertTrue(isinstance(value, str)) + self.assertEqual(str(value), '4,25') # bust the cache for type functions for the connection con.type_cache.reset_typecast() cur.execute('select 4.25') @@ -1352,8 +1372,8 @@ def test_set_row_factory_size(self): info.hits, 0 if maxsize is not None and maxsize < 2 else 4) def test_memory_leaks(self): - ids = set() - objs = [] + ids: set = set() + objs: list = [] add_ids = ids.update gc.collect() objs[:] = gc.get_objects() diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py index 09211718..02810ba6 100644 --- a/tests/test_dbapi20_copy.py +++ b/tests/test_dbapi20_copy.py @@ -14,7 +14,7 @@ import unittest from collections.abc import Iterable from contextlib import suppress -from typing import Sequence +from typing import ClassVar import pgdb # the module under test @@ -101,6 +101,11 @@ class TestCopy(unittest.TestCase): cls_set_up = False + data: ClassVar[list[tuple[int, str]]] = [ + (1935, 'Luciano Pavarotti'), + (1941, 'Plácido Domingo'), + (1946, 'José Carreras')] + @staticmethod def connect(): host = f"{dbhost or ''}:{dbport or -1}" @@ -122,8 +127,9 @@ def setUpClass(cls): cur.execute("set client_encoding=utf8") cur.execute("select 'Plácido and José'").fetchone() except (pgdb.DataError, pgdb.NotSupportedError): - cls.data[1] = (1941, 'Plaacido Domingo') - cls.data[2] = (1946, 'Josee Carreras') + cls.data[1:3] = [ + (1941, 'Plaacido Domingo'), + (1946, 'Josee Carreras')] cls.can_encode = False cur.close() con.close() @@ -152,11 +158,6 @@ def tearDown(self): with suppress(Exception): self.con.close() - data: Sequence[tuple[int, str]] = [ - (1935, 'Luciano Pavarotti'), - (1941, 'Plácido Domingo'), - (1946, 'José Carreras')] - can_encode = True @property @@ -405,9 +406,9 @@ def test_generator(self): self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 3) - rows = ''.join(rows) - self.assertIsInstance(rows, str) - self.assertEqual(rows, self.data_text) + text = ''.join(rows) + self.assertIsInstance(text, str) + self.assertEqual(text, self.data_text) self.check_rowcount() def test_generator_with_schema_name(self): @@ -419,9 +420,9 @@ def test_generator_bytes(self): self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 3) - rows = b''.join(rows) - self.assertIsInstance(rows, bytes) - self.assertEqual(rows, self.data_text.encode()) + byte_text = b''.join(rows) + self.assertIsInstance(byte_text, bytes) + self.assertEqual(byte_text, self.data_text.encode()) def test_rowcount_increment(self): ret = self.copy_to() @@ -477,9 +478,9 @@ def test_csv(self): self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 3) - rows = ''.join(rows) - self.assertIsInstance(rows, str) - self.assertEqual(rows, self.data_csv) + csv = ''.join(rows) + self.assertIsInstance(csv, str) + self.assertEqual(csv, self.data_csv) self.check_rowcount(3) def test_csv_with_sep(self): diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index c28fbefc..c09d13b8 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -139,7 +139,7 @@ def test_all_steps(self): cursor.executemany("insert into fruits (name) values (%s)", parameters) con.commit() cursor.execute('select * from fruits where id=1') - r = cursor.fetchone() + r: Any = cursor.fetchone() self.assertIsInstance(r, tuple) self.assertEqual(len(r), 2) r = str(r) From 18bb347aeb5dd618b7c8877d090adc086ffa8e5f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 08:47:02 +0200 Subject: [PATCH 050/118] Do not use OrderedDict anymore --- docs/conf.py | 2 +- docs/contents/pg/db_wrapper.rst | 9 +++---- docs/contents/pgdb/cursor.rst | 4 +-- docs/contents/tutorial.rst | 9 ++----- pg.py | 44 ++++++++++++++++----------------- tox.ini | 4 +-- 6 files changed, 33 insertions(+), 39 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 0f95ab1b..9dd604f2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,7 +41,7 @@ 'list', 'object', 'set', 'str', 'tuple', 'False', 'True', 'None', 'namedtuple', 'namedtuples', - 'OrderedDict', 'decimal.Decimal', + 'decimal.Decimal', 'bytes/str', 'list of namedtuples', 'tuple of callables', 'first field', 'type of first field', 'Notice', 'DATETIME'), diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index ea4f71c1..1dbd18ef 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -823,7 +823,7 @@ has only one column anyway. :param int offset: number of rows to be skipped (the OFFSET clause) :param bool scalar: whether only the first column shall be returned :returns: the content of the table as a list - :rtype: dict or OrderedDict + :rtype: dict :raises TypeError: the table name has not been specified :raises KeyError: keyname(s) are invalid or not part of the result :raises pg.ProgrammingError: no keyname(s) and table has no primary key @@ -837,10 +837,9 @@ The rows will be also named tuples unless the *scalar* option has been set to *True*. With the optional parameter *keyname* you can specify a different set of columns to be used as the keys of the dictionary. -If the Python version supports it, the dictionary will be an *OrderedDict* -using the order specified with the *order* parameter or the key column(s) -if not specified. You can set *order* to *False* if you don't care about the -ordering. In this case the returned dictionary will be an ordinary one. +The dictionary will be ordered using the order specified with the *order* +parameter or the key column(s) if not specified. You can set *order* to +*False* if you don't care about the ordering. .. versionadded:: 5.0 diff --git a/docs/contents/pgdb/cursor.rst b/docs/contents/pgdb/cursor.rst index e1ed8b0f..72473057 100644 --- a/docs/contents/pgdb/cursor.rst +++ b/docs/contents/pgdb/cursor.rst @@ -340,8 +340,8 @@ be used for all result sets. If you overwrite this method, the method will be ignored. Note that named tuples are very efficient and can be easily converted to -dicts (even OrderedDicts) by calling ``row._asdict()``. If you still want -to return rows as dicts, you can create a custom cursor class like this:: +dicts by calling ``row._asdict()``. If you still want to return rows as dicts, +you can create a custom cursor class like this:: class DictCursor(pgdb.Cursor): diff --git a/docs/contents/tutorial.rst b/docs/contents/tutorial.rst index 15577ad3..79273c7c 100644 --- a/docs/contents/tutorial.rst +++ b/docs/contents/tutorial.rst @@ -117,13 +117,8 @@ Using the method :meth:`DB.get_as_dict`, you can easily import the whole table into a Python dictionary mapping the primary key *id* to the *name*:: >>> db.get_as_dict('fruits', scalar=True) - OrderedDict([(1, 'apple'), - (2, 'banana'), - (3, 'cherimaya'), - (4, 'durian'), - (5, 'eggfruit'), - (6, 'fig'), - (7, 'grapefruit')]) + {1: 'apple', 2: 'banana', 3: 'cherimaya', 4: 'durian', 5: 'eggfruit', + 6: 'fig', 7: 'grapefruit', 8: 'apple', 9: 'banana'} To change a single row in the database, you can use the :meth:`DB.update` method. For instance, if you want to capitalize the name 'banana':: diff --git a/pg.py b/pg.py index 45f8ae46..75c0b32c 100644 --- a/pg.py +++ b/pg.py @@ -2638,12 +2638,13 @@ def truncate(self, table: str | list[str] | tuple[str, ...] | self._do_debug(cmd) return self._valid_db.query(cmd) - def get_as_list(self, table: str, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> list: + def get_as_list( + self, table: str, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> list: """Get a table as a list. This gets a convenient representation of the table as a list @@ -2686,13 +2687,13 @@ def get_as_list(self, table: str, if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) cmd_parts.extend(['WHERE', where]) - if order is None: + if order is None or order is True: try: order = self.pkeys(table) except (KeyError, ProgrammingError): with suppress(KeyError, ProgrammingError): order = list(self.get_attnames(table)) - if order: + if order and not isinstance(order, bool): if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) cmd_parts.extend(['ORDER BY', order]) @@ -2708,13 +2709,14 @@ def get_as_list(self, table: str, res = [row[0] for row in res] return res - def get_as_dict(self, table: str, - keyname: str | list[str] | tuple[str, ...] | None = None, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> dict: + def get_as_dict( + self, table: str, + keyname: str | list[str] | tuple[str, ...] | None = None, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> dict: """Get a table as a dictionary. This method is similar to get_as_list(), but returns the table @@ -2728,11 +2730,9 @@ def get_as_dict(self, table: str, set of columns to be used as the keys of the dictionary. It must be set as a string, list or a tuple. - If the Python version supports it, the dictionary will be an - dict using the order specified with the 'order' parameter - or the key column(s) if not specified. You can set 'order' to False - if you don't care about the ordering. In this case the returned - dictionary will be an ordinary one. + The dictionary will be ordered using the order specified with the + 'order' parameter or the key column(s) if not specified. You can + set 'order' to False if you don't care about the ordering. """ if not table: raise TypeError('The table name is missing') @@ -2759,9 +2759,9 @@ def get_as_dict(self, table: str, if isinstance(where, (list, tuple)): where = ' AND '.join(map(str, where)) cmd_parts.extend(['WHERE', where]) - if order is None: + if order is None or order is True: order = keyname - if order: + if order and not isinstance(order, bool): if isinstance(order, (list, tuple)): order = ', '.join(map(str, order)) cmd_parts.extend(['ORDER BY', order]) diff --git a/tox.ini b/tox.ini index 7e52747d..322a3f32 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},ruff,cformat,docs +envlist = py3{7,8,9,10,11},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 @@ -13,7 +13,7 @@ commands = basepython = python3.11 deps = mypy>=1.5.1 commands = - mypy setup.py pg.py pgdb.py tests + mypy pg.py pgdb.py tests [testenv:cformat] basepython = python3.11 From 2ccc937ef018dac9c7f4ebd2c2ae2a81b808c254 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 12:51:06 +0200 Subject: [PATCH 051/118] Add back setup keywords in setup.py This brings some duplication with pyproject.toml but it avoids missing metadata and Python modules when using legacy install. Unfortunately, due to the C extension, we cannot completely get rid of setup.py yet. --- pg.py | 2 +- setup.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/pg.py b/pg.py index 75c0b32c..c7a34e1b 100644 --- a/pg.py +++ b/pg.py @@ -164,7 +164,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Query', + 'Connection', 'Query', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', diff --git a/setup.py b/setup.py index c20c9607..3ad3a906 100755 --- a/setup.py +++ b/setup.py @@ -25,11 +25,15 @@ raise Exception( f"Sorry, PyGreSQL {version} does not support this Python version") +with open('README.rst') as f: + long_description = f.read() + # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the # classic interface, and "pgdb" for the modern DB-API 2.0 interface. # These two top-level Python modules share the same C extension "_pg". +py_modules = ['pg', 'pgdb'] c_sources = ['pgmodule.c'] def pg_config(s): @@ -125,6 +129,40 @@ def finalize_options(self): setup( name="PyGreSQL", version=version, + description="Python PostgreSQL Interfaces", + long_description=long_description, + long_description_content_type='text/x-rst', + keywords="pygresql postgresql database api dbapi", + author="D'Arcy J. M. Cain", + author_email="darcy@PyGreSQL.org", + url="https://pygresql.github.io/", + download_url="https://pygresql.github.io/contents/download/", + project_urls={ + "Documentation": "https://pygresql.github.io/contents/", + "Issue Tracker": "https://github.com/PyGreSQL/PyGreSQL/issues/", + "Mailing List": "https://mail.vex.net/mailman/listinfo/pygresql", + "Source Code": "https://github.com/PyGreSQL/PyGreSQL"}, + classifiers=[ + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "License :: OSI Approved :: PostgreSQL License", + "Operating System :: OS Independent", + "Programming Language :: C", + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development :: Libraries :: Python Modules"], + license="PostgreSQL", + py_modules=py_modules, + test_suite='tests.discover', + zip_safe=False, ext_modules=[Extension( '_pg', c_sources, include_dirs=include_dirs, library_dirs=library_dirs, From 4d9103e61d9d96e7e99902307014846c390dace9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 4 Sep 2023 23:42:40 +0200 Subject: [PATCH 052/118] Make inline type hints available and add stub file For this to work it was necessary to convert the two single modules to packages that can hold the pg.typed hint and stub file. The advantage of the stub file is that we now have proper type hints also for the objects imported from the C extension module. --- .github/workflows/lint.yml | 2 +- MANIFEST.in | 16 +- docs/contents/install.rst | 9 +- docs/contents/pg/connection.rst | 4 +- docs/contents/pg/large_objects.rst | 4 +- docs/contents/pg/module.rst | 10 +- docs/download/files.rst | 29 +- pgconn.c => ext/pgconn.c | 0 pginternal.c => ext/pginternal.c | 0 pglarge.c => ext/pglarge.c | 0 pgmodule.c => ext/pgmodule.c | 2 +- pgnotice.c => ext/pgnotice.c | 0 pgquery.c => ext/pgquery.c | 0 pgsource.c => ext/pgsource.c | 4 +- pgtypes.h => ext/pgtypes.h | 0 pg.py => pg/__init__.py | 110 ++--- pg/_pg.pyi | 635 +++++++++++++++++++++++++++++ pg/py.typed | 4 + pgdb.py => pgdb/__init__.py | 88 ++-- pgdb/py.typed | 1 + pyproject.toml | 6 +- setup.py | 46 +-- tests/config.py | 5 +- tests/test_classic_connection.py | 78 ++-- tests/test_classic_dbwrapper.py | 36 +- tests/test_classic_functions.py | 48 +-- tests/test_dbapi20.py | 2 +- tox.ini | 6 +- 28 files changed, 888 insertions(+), 257 deletions(-) rename pgconn.c => ext/pgconn.c (100%) rename pginternal.c => ext/pginternal.c (100%) rename pglarge.c => ext/pglarge.c (100%) rename pgmodule.c => ext/pgmodule.c (99%) rename pgnotice.c => ext/pgnotice.c (100%) rename pgquery.c => ext/pgquery.c (100%) rename pgsource.c => ext/pgsource.c (99%) rename pgtypes.h => ext/pgtypes.h (100%) rename pg.py => pg/__init__.py (97%) create mode 100644 pg/_pg.pyi create mode 100644 pg/py.typed rename pgdb.py => pgdb/__init__.py (97%) create mode 100644 pgdb/py.typed diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 40f5299e..dad89096 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,5 +22,5 @@ jobs: with: python-version: 3.11 - name: Run quality checks - run: tox -e ruff,docs + run: tox -e ruff,mypy,cformat,docs timeout-minutes: 5 diff --git a/MANIFEST.in b/MANIFEST.in index e6e9e5a9..4ff1c2b6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,14 +1,20 @@ -include *.c -include *.h -include *.py +include setup.py + +recursive-include pg pgdb tests *.py + +include pg/*.pyi +include pg/py.typed +include pgdb/py.typed + +include ext/*.c +include ext/*.h include README.rst include LICENSE.txt include tox.ini - -recursive-include tests *.py +include pyproject.toml include docs/Makefile include docs/make.bat diff --git a/docs/contents/install.rst b/docs/contents/install.rst index fd4f99b5..f447abc3 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -16,11 +16,10 @@ On Windows, you also need to make sure that the directory that contains The current version of PyGreSQL has been tested with Python versions 3.7 to 3.11, and PostgreSQL versions 10 to 15. -PyGreSQL will be installed as three modules, a shared library called -``_pg.so`` (on Linux) or a DLL called ``_pg.pyd`` (on Windows), and two pure -Python wrapper modules called ``pg.py`` and ``pgdb.py``. -All three files will be installed directly into the Python site-packages -directory. To uninstall PyGreSQL, simply remove these three files. +PyGreSQL will be installed as two packages named ``pg`` (for the classic +interface) and ``pgdb`` (for the DB API 2 compliant interface). The former +also contains a shared library called ``_pg.so`` (on Linux) or a DLL called +``_pg.pyd`` (on Windows) and a stub file ``_pg.pyi`` for this library. Installing with Pip diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index b175a2a0..e4a08591 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -616,7 +616,7 @@ getline -- get a line from server socket Get a line from server socket - :returns: the line read + :returns: the line read :rtype: str :raises TypeError: invalid connection :raises TypeError: too many parameters @@ -666,7 +666,7 @@ getlo -- build a large object from given oid :param int oid: OID of the existing large object :returns: object handling the PostgreSQL large object :rtype: :class:`LargeObject` - :raises TypeError: invalid connection, bad parameter type, or too many parameters + :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises ValueError: bad OID value (0 is invalid_oid) This method allows reusing a previously created large object through the diff --git a/docs/contents/pg/large_objects.rst b/docs/contents/pg/large_objects.rst index a1d9818d..037b2128 100644 --- a/docs/contents/pg/large_objects.rst +++ b/docs/contents/pg/large_objects.rst @@ -75,9 +75,9 @@ current position. .. method:: LargeObject.write(string) - Read data to large object + Write data to large object - :param bytes string: string buffer to be written + :param bytes data: buffer of bytes to be written :rtype: None :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises IOError: object is not opened, or write error diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 2dc26d5f..acf75f93 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -349,8 +349,7 @@ get/set_decimal -- decimal type to be used for numeric values :rtype: class This function returns the Python class that is used by PyGreSQL to hold -PostgreSQL numeric values. The default class is :class:`decimal.Decimal` -if available, otherwise the :class:`float` type is used. +PostgreSQL numeric values. The default class is :class:`decimal.Decimal`. .. function:: set_decimal(cls) @@ -360,8 +359,7 @@ if available, otherwise the :class:`float` type is used. This function can be used to specify the Python class that shall be used by PyGreSQL to hold PostgreSQL numeric values. -The default class is :class:`decimal.Decimal` if available, -otherwise the :class:`float` type is used. +The default class is :class:`decimal.Decimal`. get/set_decimal_point -- decimal mark used for monetary values -------------------------------------------------------------- @@ -639,7 +637,7 @@ are not supported by default in PostgreSQL. :param str string: the string with the text representation of the array :param cast: a typecast function for the elements of the array :type cast: callable or None - :param delim: delimiter character between adjacent elements + :param bytes delim: delimiter character between adjacent elements :type str: byte string with a single character :returns: a list representing the PostgreSQL array in Python :rtype: list @@ -667,7 +665,7 @@ then a comma will be used by default. :param str string: the string with the text representation of the record :param cast: typecast function(s) for the elements of the record :type cast: callable, list or tuple of callables, or None - :param delim: delimiter character between adjacent elements + :param bytes delim: delimiter character between adjacent elements :type str: byte string with a single character :returns: a tuple representing the PostgreSQL record in Python :rtype: tuple diff --git a/docs/download/files.rst b/docs/download/files.rst index ec581bf0..f5e7a523 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -3,26 +3,13 @@ Distribution files ============== = -pgmodule.c the main source file for the C extension module (_pg) -pgconn.c the connection object -pginternal.c internal functions -pglarge.c large object support -pgnotice.c the notice object -pgquery.c the query object -pgsource.c the source object +pg/ the "classic" PyGreSQL module -pgtypes.h PostgreSQL type definitions +pgdb/ a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL -pg.py the "classic" PyGreSQL module -pgdb.py a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL +ext/ the source files for the C extension -setup.py the Python setup script - - To install PyGreSQL, you can run "python setup.py install". - -setup.cfg the Python setup configuration - -docs/ documentation directory +docs/ the documentation directory The documentation has been created with Sphinx. All text files are in ReST format; a HTML version of @@ -30,4 +17,12 @@ docs/ documentation directory tests/ a suite of unit tests for PyGreSQL +pyproject.toml contains project metadata and the build system requirements + +setup.py the Python setup script used for building the C extension + +LICENSE.text contains the license information for PyGreSQL + +README.rst a summary of the PyGreSQL project + ============== = diff --git a/pgconn.c b/ext/pgconn.c similarity index 100% rename from pgconn.c rename to ext/pgconn.c diff --git a/pginternal.c b/ext/pginternal.c similarity index 100% rename from pginternal.c rename to ext/pginternal.c diff --git a/pglarge.c b/ext/pglarge.c similarity index 100% rename from pglarge.c rename to ext/pglarge.c diff --git a/pgmodule.c b/ext/pgmodule.c similarity index 99% rename from pgmodule.c rename to ext/pgmodule.c index 64e769f6..546c5cc5 100644 --- a/pgmodule.c +++ b/ext/pgmodule.c @@ -180,7 +180,7 @@ typedef struct { /* Connect to a database. */ static char pg_connect__doc__[] = - "connect(dbname, host, port, opt, user, passwd, wait) -- connect to a " + "connect(dbname, host, port, opt, user, passwd, nowait) -- connect to a " "PostgreSQL database\n\n" "The connection uses the specified parameters (optional, keywords " "aware).\n"; diff --git a/pgnotice.c b/ext/pgnotice.c similarity index 100% rename from pgnotice.c rename to ext/pgnotice.c diff --git a/pgquery.c b/ext/pgquery.c similarity index 100% rename from pgquery.c rename to ext/pgquery.c diff --git a/pgsource.c b/ext/pgsource.c similarity index 99% rename from pgsource.c rename to ext/pgsource.c index 9bc6bb4a..42510b30 100644 --- a/pgsource.c +++ b/ext/pgsource.c @@ -119,8 +119,8 @@ source_setattr(sourceObject *self, char *name, PyObject *v) /* Close object. */ static char source_close__doc__[] = - "close() -- close query object without deleting it\n\n" - "All instances of the query object can no longer be used after this " + "close() -- close source object without deleting it\n\n" + "All instances of the source object can no longer be used after this " "call.\n"; static PyObject * diff --git a/pgtypes.h b/ext/pgtypes.h similarity index 100% rename from pgtypes.h rename to ext/pgtypes.h diff --git a/pg.py b/pg/__init__.py similarity index 97% rename from pg.py rename to pg/__init__.py index c7a34e1b..0740db20 100644 --- a/pg.py +++ b/pg/__init__.py @@ -51,7 +51,7 @@ from uuid import UUID try: - from _pg import version + from ._pg import version except ImportError as e: # noqa: F841 import os libpq = 'libpq.' @@ -66,7 +66,7 @@ for path in paths: with add_dll_dir(os.path.abspath(path)): try: - from _pg import version # type: ignore + from ._pg import version except ImportError: pass else: @@ -85,13 +85,17 @@ del version # import objects from extension module -from _pg import ( +from ._pg import ( INV_READ, INV_WRITE, POLLING_FAILED, POLLING_OK, POLLING_READING, POLLING_WRITING, + RESULT_DDL, + RESULT_DML, + RESULT_DQL, + RESULT_EMPTY, SEEK_CUR, SEEK_END, SEEK_SET, @@ -167,6 +171,7 @@ 'Connection', 'Query', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', + 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', 'TRANS_INTRANS', 'TRANS_UNKNOWN', @@ -186,6 +191,8 @@ # Auxiliary classes and functions that are independent of a DB connection: +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + def get_args(func: Callable) -> list: return list(signature(func).parameters) @@ -1188,7 +1195,7 @@ class DbType(str): category: str delim: str relid: int - + _get_attnames: Callable[[DbType], AttrDict] @property @@ -1336,14 +1343,14 @@ def _dictiter(q: Query) -> Generator[dict[str, Any], None, None]: yield dict(zip(fields, r)) -def _namediter(q: Query) -> Generator[NamedTuple, None, None]: +def _namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: """Get query result as an iterator of named tuples.""" row = _row_factory(q.listfields()) for r in q: yield row(r) -def _namednext(q: Query) -> NamedTuple: +def _namednext(q: Query) -> SomeNamedTuple: """Get next row from query result as a named tuple.""" return _row_factory(q.listfields())(next(q)) @@ -1378,23 +1385,29 @@ def __iter__(self) -> Iterator[Any]: # Error messages -E = TypeVar('E', bound=DatabaseError) +E = TypeVar('E', bound=Error) -def _db_error(msg: str, cls:type[E] = DatabaseError) -> type[E]: - """Return DatabaseError with empty sqlstate attribute.""" +def _error(msg: str, cls: type[E]) -> E: + """Return specified error object with empty sqlstate attribute.""" error = cls(msg) - error.sqlstate = None + if isinstance(error, DatabaseError): + error.sqlstate = None return error +def _db_error(msg: str) -> DatabaseError: + """Return DatabaseError.""" + return _error(msg, DatabaseError) + + def _int_error(msg: str) -> InternalError: """Return InternalError.""" - return _db_error(msg, InternalError) + return _error(msg, InternalError) def _prg_error(msg: str) -> ProgrammingError: """Return ProgrammingError.""" - return _db_error(msg, ProgrammingError) + return _error(msg, ProgrammingError) # Initialize the C module @@ -1468,7 +1481,7 @@ def unlisten(self) -> None: self.listening = False def notify(self, db: DB | None = None, stop: bool = False, - payload: str | None = None) -> None: + payload: str | None = None) -> Query | None: """Generate a notification. Optionally, you can pass a payload with the notification. @@ -1480,16 +1493,17 @@ def notify(self, db: DB | None = None, stop: bool = False, must pass a different database connection since PyGreSQL database connections are not thread-safe. """ - if self.listening: + if not self.listening: + return None + if not db: + db = self.db if not db: - db = self.db - if not db: - return - event = self.stop_event if stop else self.event - cmd = f'notify "{event}"' - if payload: - cmd += f", '{payload}'" - return db.query(cmd) + return None + event = self.stop_event if stop else self.event + cmd = f'notify "{event}"' + if payload: + cmd += f", '{payload}'" + return db.query(cmd) def __call__(self) -> None: """Invoke the notification handler. @@ -1545,6 +1559,7 @@ class DB: """Wrapper class for the _pg connection type.""" db: Connection | None = None # invalid fallback for underlying connection + _db_args: Any # either the connectoin args or the underlying connection def __init__(self, *args: Any, **kw: Any) -> None: """Create a new connection. @@ -1730,7 +1745,7 @@ def reset(self) -> None: All derived queries and large objects derived from this connection will not be usable after this call. - """ + """ self._valid_db.reset() def reopen(self) -> None: @@ -1741,7 +1756,8 @@ def reopen(self) -> None: """ # There is no such shared library function. if self._closeable: - db = connect(*self._db_args[0], **self._db_args[1]) + args, kw = self._db_args + db = connect(*args, **kw) if self.db: self.db.set_cast_hook(None) self.db.close() @@ -1750,7 +1766,7 @@ def reopen(self) -> None: else: self.db = self._db_args - def begin(self, mode: str | None = None) -> None: + def begin(self, mode: str | None = None) -> Query: """Begin a transaction.""" qstr = 'BEGIN' if mode: @@ -1759,13 +1775,13 @@ def begin(self, mode: str | None = None) -> None: start = begin - def commit(self) -> None: + def commit(self) -> Query: """Commit the current transaction.""" return self.query('COMMIT') end = commit - def rollback(self, name: str | None = None) -> None: + def rollback(self, name: str | None = None) -> Query: """Roll back the current transaction.""" qstr = 'ROLLBACK' if name: @@ -1774,11 +1790,11 @@ def rollback(self, name: str | None = None) -> None: abort = rollback - def savepoint(self, name: str) -> None: + def savepoint(self, name: str) -> Query: """Define a new savepoint within the current transaction.""" return self.query('SAVEPOINT ' + name) - def release(self, name: str) -> None: + def release(self, name: str) -> Query: """Destroy a previously defined savepoint.""" return self.query('RELEASE ' + name) @@ -1983,7 +1999,7 @@ def query_prepared(self, name: str, *args: Any) -> Query: self._do_debug('EXECUTE', name) return db.query_prepared(name) - def prepare(self, name: str, command: str) -> Query: + def prepare(self, name: str, command: str) -> None: """Create a prepared SQL statement. This creates a prepared statement for the given command with the @@ -1999,7 +2015,7 @@ def prepare(self, name: str, command: str) -> Query: if name is None: name = '' self._do_debug('prepare', name, command) - return self._valid_db.prepare(name, command) + self._valid_db.prepare(name, command) def describe_prepared(self, name: str | None = None) -> Query: """Describe a prepared SQL statement. @@ -2057,17 +2073,17 @@ def pkey(self, table: str, composite: bool = False, flush: bool = False " {}::pg_catalog.regclass" " AND i.indisprimary ORDER BY a.attnum").format( _quote_if_unqualified('$1', table)) - pkey = self._valid_db.query(cmd, (table,)).getresult() - if not pkey: + res = self._valid_db.query(cmd, (table,)).getresult() + if not res: raise KeyError(f'Table {table} has no primary key') from e # we want to use the order defined in the primary key index here, # not the order as defined by the columns in the table - if len(pkey) > 1: - indkey = pkey[0][2] + if len(res) > 1: + indkey = res[0][2] pkey = tuple(row[0] for row in sorted( - pkey, key=lambda row: indkey.index(row[1]))) + res, key=lambda row: indkey.index(row[1]))) else: - pkey = pkey[0][0] + pkey = res[0][0] pkeys[table] = pkey # cache it if composite and not isinstance(pkey, tuple): pkey = (pkey,) @@ -2075,7 +2091,7 @@ def pkey(self, table: str, composite: bool = False, flush: bool = False def pkeys(self, table: str) -> tuple[str, ...]: """Get the primary key of a table as a tuple. - + Same as pkey() with 'composite' set to True. """ return self.pkey(table, True) # type: ignore @@ -2146,9 +2162,9 @@ def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" cmd = self._query_attnames.format( _quote_if_unqualified('$1', table), cmd) - names = self._valid_db.query(cmd, (table,)).getresult() + res = self._valid_db.query(cmd, (table,)).getresult() types = self.dbtypes - names = AttrDict((name[0], types.add(*name[1:])) for name in names) + names = AttrDict((name[0], types.add(*name[1:])) for name in res) attnames[table] = names # cache it return names @@ -2172,8 +2188,8 @@ def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: cmd = f"{cmd} AND {self._query_generated}" cmd = self._query_attnames.format( _quote_if_unqualified('$1', table), cmd) - names = self._valid_db.query(cmd, (table,)).getresult() - names = frozenset(name[0] for name in names) + res = self._valid_db.query(cmd, (table,)).getresult() + names = frozenset(name[0] for name in res) generated[table] = names # cache it return names @@ -2578,7 +2594,7 @@ def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 self._do_debug(cmd, params) res = self._valid_db.query(cmd, params) - return int(res) + return int(res) # type: ignore def truncate(self, table: str | list[str] | tuple[str, ...] | set[str] | frozenset[str], restart: bool = False, @@ -2660,7 +2676,7 @@ def get_as_list( The parameter 'where' can restrict the query to only return a subset of the table rows. It can be a string, list or a tuple of SQL expressions that all need to be fulfilled. - + The parameter 'order' specifies the ordering of the rows. It can also be a string, list or a tuple. If no ordering is specified, the result will be ordered by the primary key(s) or all columns if @@ -2806,7 +2822,7 @@ def get_row(row : tuple) -> tuple: if key_tuple: keys = _namediter(_MemoryQuery(keys, keynames)) # type: ignore if row_is_tuple: - fields = [f for f in fields if f not in keyset] + fields = tuple(f for f in fields if f not in keyset) rows = _namediter(_MemoryQuery(rows, fields)) # type: ignore # noinspection PyArgumentList return dict(zip(keys, rows)) @@ -2824,6 +2840,6 @@ def notification_handler(self, event: str, callback: Callable, # if run as script, print some information if __name__ == '__main__': - print('PyGreSQL version' + version) - print('') + print('PyGreSQL version', version) + print() print(__doc__) diff --git a/pg/_pg.pyi b/pg/_pg.pyi new file mode 100644 index 00000000..70f6e37e --- /dev/null +++ b/pg/_pg.pyi @@ -0,0 +1,635 @@ +"""Type hints for the PyGreSQL C extension.""" + +from __future__ import annotations + +from typing import Any, Callable, Iterable, Sequence, TypeVar + +AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + +version: str +__version__: str + +RESULT_EMPTY: int +RESULT_DML: int +RESULT_DDL: int +RESULT_DQL: int + +TRANS_IDLE: int +TRANS_ACTIVE: int +TRANS_INTRANS: int +TRANS_INERROR: int +TRANS_UNKNOWN: int + +POLLING_OK: int +POLLING_FAILED: int +POLLING_READING: int +POLLING_WRITING: int + +INV_READ: int +INV_WRITE: int + +SEEK_SET: int +SEEK_CUR: int +SEEK_END: int + + +class Error(Exception): + """Exception that is the base class of all other error exceptions.""" + + +class Warning(Exception): # noqa: N818 + """Exception raised for important warnings.""" + + +class InterfaceError(Error): + """Exception raised for errors related to the database interface.""" + + +class DatabaseError(Error): + """Exception raised for errors that are related to the database.""" + + sqlstate: str | None + + +class InternalError(DatabaseError): + """Exception raised when the database encounters an internal error.""" + + +class OperationalError(DatabaseError): + """Exception raised for errors related to the operation of the database.""" + + +class ProgrammingError(DatabaseError): + """Exception raised for programming errors.""" + + +class IntegrityError(DatabaseError): + """Exception raised when the relational integrity is affected.""" + + +class DataError(DatabaseError): + """Exception raised for errors due to problems with the processed data.""" + + +class NotSupportedError(DatabaseError): + """Exception raised when a method or database API is not supported.""" + + +class InvalidResultError(DataError): + """Exception when a database operation produced an invalid result.""" + + +class NoResultError(InvalidResultError): + """Exception when a database operation did not produce any result.""" + + +class MultipleResultsError(InvalidResultError): + """Exception when a database operation produced multiple results.""" + + +class Source: + """Source object.""" + + arraysize: int + resulttype: int + ntuples: int + nfields: int + + def execute(self, sql: str) -> int | None: + """Execute a SQL statement.""" + ... + + def fetch(self, num: int) -> list[tuple]: + """Return the next num rows from the last result in a list.""" + ... + + def listinfo(self) -> tuple[tuple[int, str, int, int, int], ...]: + """Get information for all fields.""" + ... + + def oidstatus(self) -> int | None: + """Return oid of last inserted row (if available).""" + ... + + def putdata(self, buffer: str | bytes | BaseException | None + ) -> int | None: + """Send data to server during copy from stdin.""" + ... + + def getdata(self, decode: bool | None = None) -> str | bytes | int: + """Receive data to server during copy to stdout.""" + ... + + def close(self) -> None: + """Close query object without deleting it.""" + ... + + +class LargeObject: + """Large object.""" + + oid: int + pgcnx: Connection + error: str + + def open(self, mode: int) -> None: + """Open a large object. + + The valid values for 'mode' parameter are defined as the module level + constants INV_READ and INV_WRITE. + """ + ... + + def close(self) -> None: + """Close a large object.""" + ... + + def read(self, size: int) -> bytes: + """Read data from large object.""" + ... + + def write(self, data: bytes) -> None: + """Write data to large object.""" + ... + + def seek(self, offset: int, whence: int) -> int: + """Change current position in large object. + + The valid values for the 'whence' parameter are defined as the + module level constants SEEK_SET, SEEK_CUR and SEEK_END. + """ + ... + + def unlink(self) -> None: + """Delete large object.""" + ... + + def size(self) -> int: + """Return the large object size.""" + ... + + def export(self, filename: str) -> None: + """Export a large object to a file.""" + ... + + +class Connection: + """Connection object. + + This object handles a connection to a PostgreSQL database. + It embeds and hides all the parameters that define this connection, + thus just leaving really significant parameters in function calls. + """ + + host: str + port: int + db: str + options: str + error: str + status: int + user : str + protocol_version: int + server_version: int + socket: int + backend_pid: int + ssl_in_use: bool + ssl_attributes: dict[str, str | None] + + def source(self) -> Source: + """Create a new source object for this connection.""" + ... + + def query(self, cmd: str, args: Sequence | None = None) -> Query: + """Create a new query object for this connection. + + Note that if the command is something other than DQL, this method + can return an int, str or None instead of a Query. + """ + ... + + def send_query(self, cmd: str, args: Sequence | None = None) -> Query: + """Create a new asynchronous query object for this connection.""" + ... + + def query_prepared(self, name: str, args: Sequence | None = None) -> Query: + """Execute a prepared statement.""" + ... + + def prepare(self, name: str, cmd: str) -> None: + """Create a prepared statement.""" + ... + + def describe_prepared(self, name: str) -> Query: + """Describe a prepared statement.""" + ... + + def poll(self) -> int: + """Complete an asynchronous connection and get its state.""" + ... + + def reset(self) -> None: + """Reset the connection.""" + ... + + def cancel(self) -> None: + """Abandon processing of current SQL command.""" + ... + + def close(self) -> None: + """Close the database connection.""" + ... + + def fileno(self) -> int: + """Get the socket used to connect to the database.""" + ... + + def get_cast_hook(self) -> Callable | None: + """Get the function that handles all external typecasting.""" + ... + + def set_cast_hook(self, hook: Callable | None) -> None: + """Set a function that will handle all external typecasting.""" + ... + + def get_notice_receiver(self) -> Callable | None: + """Get the current notice receiver.""" + ... + + def set_notice_receiver(self, receiver: Callable | None) -> None: + """Set a custom notice receiver.""" + ... + + def getnotify(self) -> tuple[str, int, str] | None: + """Get the last notify from the server.""" + ... + + def inserttable(self, table: str, values: Sequence[list|tuple], + columns: list[str] | tuple[str, ...] | None = None) -> int: + """Insert a Python iterable into a database table.""" + ... + + def transaction(self) -> int: + """Get the current in-transaction status of the server. + + The status returned by this method can be TRANS_IDLE (currently idle), + TRANS_ACTIVE (a command is in progress), TRANS_INTRANS (idle, in a + valid transaction block), or TRANS_INERROR (idle, in a failed + transaction block). TRANS_UNKNOWN is reported if the connection is + bad. The status TRANS_ACTIVE is reported only when a query has been + sent to the server and not yet completed. + """ + ... + + def parameter(self, name: str) -> str | None: + """Look up a current parameter setting of the server.""" + ... + + def date_format(self) -> str: + """Look up the date format currently being used by the database.""" + ... + + def escape_literal(self, s: AnyStr) -> AnyStr: + """Escape a literal constant for use within SQL.""" + ... + + def escape_identifier(self, s: AnyStr) -> AnyStr: + """Escape an identifier for use within SQL.""" + ... + + def escape_string(self, s: AnyStr) -> AnyStr: + """Escape a string for use within SQL.""" + ... + + def escape_bytea(self, s: AnyStr) -> AnyStr: + """Escape binary data for use within SQL as type 'bytea'.""" + ... + + def putline(self, line: str) -> None: + """Write a line to the server socket.""" + ... + + def getline(self) -> str: + """Get a line from server socket.""" + ... + + def endcopy(self) -> None: + """Synchronize client and server.""" + ... + + def set_non_blocking(self, nb: bool) -> None: + """Set the non-blocking mode of the connection.""" + ... + + def is_non_blocking(self) -> bool: + """Get the non-blocking mode of the connection.""" + ... + + def locreate(self, mode: int) -> LargeObject: + """Create a large object in the database. + + The valid values for 'mode' parameter are defined as the module level + constants INV_READ and INV_WRITE. + """ + ... + + def getlo(self, oid: int) -> LargeObject: + """Build a large object from given oid.""" + ... + + def loimport(self, filename: str) -> LargeObject: + """Import a file to a large object.""" + ... + + +class Query: + """Query object. + + The Query object returned by Connection.query and DB.query can be used + as an iterable returning rows as tuples. You can also directly access + row tuples using their index, and get the number of rows with the + len() function. The Query class also provides the several methods + for accessing the results of the query. + """ + + def __len__(self) -> int: + ... + + def __getitem__(self, key: int) -> object: + ... + + def __iter__(self) -> Query: + ... + + def __next__(self) -> tuple: + ... + + def getresult(self) -> list[tuple]: + """Get query values as list of tuples.""" + ... + + def dictresult(self) -> list[dict[str, object]]: + """Get query values as list of dictionaries.""" + ... + + def dictiter(self) -> Iterable[dict[str, object]]: + """Get query values as iterable of dictionaries.""" + ... + + def namedresult(self) -> list[SomeNamedTuple]: + """Get query values as list of named tuples.""" + ... + + def namediter(self) -> Iterable[SomeNamedTuple]: + """Get query values as iterable of named tuples.""" + ... + + def one(self) -> tuple | None: + """Get one row from the result of a query as a tuple.""" + ... + + def single(self) -> tuple: + """Get single row from the result of a query as a tuple.""" + ... + + def onedict(self) -> dict[str, object] | None: + """Get one row from the result of a query as a dictionary.""" + ... + + def singledict(self) -> dict[str, object]: + """Get single row from the result of a query as a dictionary.""" + ... + + def onenamed(self) -> SomeNamedTuple | None: + """Get one row from the result of a query as named tuple.""" + ... + + def singlenamed(self) -> SomeNamedTuple: + """Get single row from the result of a query as named tuple.""" + ... + + def scalarresult(self) -> list: + """Get first fields from query result as list of scalar values.""" + + def scalariter(self) -> Iterable: + """Get first fields from query result as iterable of scalar values.""" + ... + + def onescalar(self) -> object | None: + """Get one row from the result of a query as scalar value.""" + ... + + def singlescalar(self) -> object: + """Get single row from the result of a query as scalar value.""" + ... + + def fieldname(self, num: int) -> str: + """Get field name from its number.""" + ... + + def fieldnum(self, name: str) -> int: + """Get field number from its name.""" + ... + + def listfields(self) -> tuple[str, ...]: + """List field names of query result.""" + ... + + def fieldinfo(self, column: int | str | None) -> tuple[str, int, int, int]: + """Get information on one or all fields of the query. + + The four-tuples contain the following information: + The field name, the internal OID number of the field type, + the size in bytes of the column or a negative value if it is + of variable size, and a type-specific modifier value. + """ + ... + + def memsize(self) -> int: + """Return number of bytes allocated by query result.""" + ... + + +def connect(dbname: str | None = None, + host: str | None = None, + port: int | None = None, + opt: str | None = None, + user: str | None = None, + passwd: str | None = None, + nowait: int | None = None) -> Connection: + """Connect to a PostgreSQL database.""" + ... + + +def cast_array(s: str, cast: Callable | None = None, + delim: bytes | None = None) -> list: + """Cast a string representing a PostgreSQL array to a Python list.""" + ... + + +def cast_record(s: str, + cast: Callable | list[Callable | None] | + tuple[Callable | None, ...] | None = None, + delim: bytes | None = None) -> tuple: + """Cast a string representing a PostgreSQL record to a Python tuple.""" + ... + + +def cast_hstore(s: str) -> dict[str, str | None]: + """Cast a string as a hstore.""" + ... + + +def escape_bytea(s: AnyStr) -> AnyStr: + """Escape binary data for use within SQL as type 'bytea'.""" + ... + + +def unescape_bytea(s: AnyStr) -> bytes: + """Unescape 'bytea' data that has been retrieved as text.""" + ... + + +def escape_string(s: AnyStr) -> AnyStr: + """Escape a string for use within SQL.""" + ... + + +def get_pqlib_version() -> int: + """Get the version of libpq that is being used by PyGreSQL.""" + ... + + +def get_array() -> bool: + """Check whether arrays are returned as list objects.""" + ... + + +def set_array(on: bool) -> None: + """Set whether arrays are returned as list objects.""" + ... + + +def get_bool() -> bool: + """Check whether boolean values are returned as bool objects.""" + ... + + +def set_bool(on: bool | int) -> None: + """Set whether boolean values are returned as bool objects.""" + ... + + +def get_bytea_escaped() -> bool: + """Check whether 'bytea' values are returned as escaped strings.""" + ... + + +def set_bytea_escaped(on: bool | int) -> None: + """Set whether 'bytea' values are returned as escaped strings.""" + ... + + +def get_datestyle() -> str | None: + """Get the assumed date style for typecasting.""" + ... + + +def set_datestyle(datestyle: str | None) -> None: + """Set a fixed date style that shall be assumed when typecasting.""" + ... + + +def get_decimal() -> type: + """Get the decimal type to be used for numeric values.""" + ... + + +def set_decimal(cls: type) -> None: + """Set a fixed date style that shall be assumed when typecasting.""" + ... + + +def get_decimal_point() -> str | None: + """Get the decimal mark used for monetary values.""" + ... + + +def set_decimal_point(mark: str | None) -> None: + """Specify which decimal mark is used for interpreting monetary values.""" + ... + + +def get_jsondecode() -> Callable[[str], object] | None: + """Get the function that deserializes JSON formatted strings.""" + ... + + +def set_jsondecode(decode: Callable[[str], object] | None) -> None: + """Set a function that will deserialize JSON formatted strings.""" + ... + + +def get_defbase() -> str | None: + """Get the default database name.""" + ... + + +def set_defbase(base: str | None) -> None: + """Set the default database name.""" + ... + + +def get_defhost() -> str | None: + """Get the default host.""" + ... + + +def set_defhost(host: str | None) -> None: + """Set the default host.""" + ... + + +def get_defport() -> int | None: + """Get the default host.""" + ... + + +def set_defport(port: int | None) -> None: + """Set the default port.""" + ... + + +def get_defopt() -> str | None: + """Get the default connection options.""" + ... + + +def set_defopt(opt: str | None) -> None: + """Set the default connection options.""" + ... + + +def get_defuser() -> str | None: + """Get the default database user.""" + ... + + +def set_defuser(user: str | None) -> None: + """Set the default database user.""" + ... + + +def get_defpasswd() -> str | None: + """Get the default database password.""" + ... + + +def set_defpasswd(passwd: str | None) -> None: + """Set the default database password.""" + ... + + +def set_query_helpers(*helpers: Callable) -> None: + """Set internal query helper functions.""" + ... diff --git a/pg/py.typed b/pg/py.typed new file mode 100644 index 00000000..ea6e1ace --- /dev/null +++ b/pg/py.typed @@ -0,0 +1,4 @@ +# Marker file for PEP 561. + +# The pg package use inline types, +# except for the _pg extension module which uses a stub file. diff --git a/pgdb.py b/pgdb/__init__.py similarity index 97% rename from pgdb.py rename to pgdb/__init__.py index 332ca3d0..74ad38e5 100644 --- a/pgdb.py +++ b/pgdb/__init__.py @@ -90,42 +90,8 @@ ) from uuid import UUID as Uuid # noqa: N811 -try: - from _pg import version -except ImportError as e: # noqa: F841 - import os - libpq = 'libpq.' - if os.name == 'nt': - libpq += 'dll' - import sys - paths = [path for path in os.environ["PATH"].split(os.pathsep) - if os.path.exists(os.path.join(path, libpq))] - if sys.version_info >= (3, 8): - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - add_dll_dir = os.add_dll_directory # type: ignore - for path in paths: - with add_dll_dir(os.path.abspath(path)): - try: - from _pg import version # type: ignore - except ImportError: - pass - else: - del version - e = None # type: ignore - break - if paths: - libpq = 'compatible ' + libpq - else: - libpq += 'so' - if e: - raise ImportError( - "Cannot import shared library for PyGreSQL,\n" - f"probably because no {libpq} is installed.\n{e}") from e -else: - del version - # import objects from extension module -from _pg import ( +from pg import ( RESULT_DQL, DatabaseError, DataError, @@ -143,10 +109,10 @@ unescape_bytea, version, ) -from _pg import ( +from pg import ( Connection as Cnx, # base connection ) -from _pg import ( +from pg import ( connect as get_cnx, # get base connection ) @@ -694,10 +660,9 @@ def __missing__(self, key: int | str) -> TypeCode: res = self._src.fetch(1) if not res: raise KeyError(f'Type {key} could not be found') - res = res[0] + r = res[0] type_code = TypeCode.create( - int(res[0]), res[1], int(res[2]), - res[3], res[4], res[5], int(res[6])) + int(r[0]), r[1], int(r[2]), r[3], r[4], r[5], int(r[6])) # noinspection PyUnresolvedReferences self[type_code.oid] = self[str(type_code)] = type_code return type_code @@ -782,19 +747,30 @@ def __getitem__(self, key: str) -> str: # *** Error Messages *** -E = TypeVar('E', bound=DatabaseError) +E = TypeVar('E', bound=Error) -def _db_error(msg: str, cls:type[E] = DatabaseError) -> type[E]: - """Return DatabaseError with empty sqlstate attribute.""" +def _error(msg: str, cls: type[E]) -> E: + """Return specified error object with empty sqlstate attribute.""" error = cls(msg) - error.sqlstate = None + if isinstance(error, DatabaseError): + error.sqlstate = None return error +def _db_error(msg: str) -> DatabaseError: + """Return DatabaseError.""" + return _error(msg, DatabaseError) + + +def _if_error(msg: str) -> InterfaceError: + """Return InterfaceError.""" + return _error(msg, InterfaceError) + + def _op_error(msg: str) -> OperationalError: """Return OperationalError.""" - return _db_error(msg, OperationalError) + return _error(msg, OperationalError) # *** Row Tuples *** @@ -835,8 +811,8 @@ def __init__(self, connection: Connection) -> None: cnx = connection._cnx if not cnx: raise _op_error("Connection has been closed") - self._cnx = cnx - self.type_cache = connection.type_cache + self._cnx: Cnx = cnx + self.type_cache: TypeCache = connection.type_cache self._src = self._cnx.source() # the official attribute for describing the result columns self._description: list[CursorDescription] | bool | None = None @@ -845,9 +821,9 @@ def __init__(self, connection: Connection) -> None: self.row_factory = None # type: ignore else: self.build_row_factory = None # type: ignore - self.rowcount = -1 - self.arraysize = 1 - self.lastrowid = None + self.rowcount: int | None = -1 + self.arraysize: int = 1 + self.lastrowid: int | None = None def __iter__(self) -> Cursor: """Make cursor compatible to the iteration protocol.""" @@ -1044,8 +1020,7 @@ def executemany(self, operation: str, raise # database provides error message except Error as err: # noinspection PyTypeChecker - raise _db_error( - f"Error in '{sql}': '{err}'", InterfaceError) from err + raise _if_error(f"Error in '{sql}': '{err}'") from err except Exception as err: raise _op_error(f"Internal error in '{sql}': {err}") from err # then initialize result raw count and description @@ -1264,7 +1239,8 @@ def chunks() -> Generator: # the following call will re-raise the error putdata(error) else: - self.rowcount = putdata(None) + rowcount = putdata(None) + self.rowcount = -1 if rowcount is None else rowcount # return the cursor object, so you can chain operations return self @@ -1459,7 +1435,7 @@ class Connection: def __init__(self, cnx: Cnx) -> None: """Create a database connection object.""" - self._cnx = cnx # connection + self._cnx: Cnx | None = cnx # connection self._tnx = False # transaction state self.type_cache = TypeCache(cnx) self.cursor_type = Cursor @@ -1509,7 +1485,7 @@ def close(self) -> None: with suppress(DatabaseError): self.rollback() self._cnx.close() - self._cnx = None + self._cnx = None @property def closed(self) -> bool: @@ -1857,5 +1833,5 @@ def __str__(self) -> str: if __name__ == '__main__': print('PyGreSQL version', version) - print('') + print() print(__doc__) diff --git a/pgdb/py.typed b/pgdb/py.typed new file mode 100644 index 00000000..ead52d46 --- /dev/null +++ b/pgdb/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The pgdb package uses inline types. diff --git a/pyproject.toml b/pyproject.toml index 1016b433..e289b38f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,9 +89,13 @@ module = [ disallow_untyped_defs = false [tool.setuptools] -py-modules = ["pg", "pgdb"] +packages = ["pg", "pgdb"] license-files = ["LICENSE.txt"] +[tool.setuptools.package-data] +pg = ["pg.typed"] +pgdb = ["pg.typed"] + [build-system] requires = ["setuptools>=68", "wheel>=0.41"] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 3ad3a906..4fd39c56 100755 --- a/setup.py +++ b/setup.py @@ -33,9 +33,6 @@ # classic interface, and "pgdb" for the modern DB-API 2.0 interface. # These two top-level Python modules share the same C extension "_pg". -py_modules = ['pg', 'pgdb'] -c_sources = ['pgmodule.c'] - def pg_config(s): """Retrieve information about installed version of PostgreSQL.""" f = os.popen(f'pg_config --{s}') # noqa: S605 @@ -127,27 +124,27 @@ def finalize_options(self): setup( - name="PyGreSQL", + name='PyGreSQL', version=version, - description="Python PostgreSQL Interfaces", + description='Python PostgreSQL Interfaces', long_description=long_description, long_description_content_type='text/x-rst', - keywords="pygresql postgresql database api dbapi", + keywords='pygresql postgresql database api dbapi', author="D'Arcy J. M. Cain", author_email="darcy@PyGreSQL.org", - url="https://pygresql.github.io/", - download_url="https://pygresql.github.io/contents/download/", + url='https://pygresql.github.io/', + download_url='https://pygresql.github.io/contents/download/', project_urls={ - "Documentation": "https://pygresql.github.io/contents/", - "Issue Tracker": "https://github.com/PyGreSQL/PyGreSQL/issues/", - "Mailing List": "https://mail.vex.net/mailman/listinfo/pygresql", - "Source Code": "https://github.com/PyGreSQL/PyGreSQL"}, + 'Documentation': 'https://pygresql.github.io/contents/', + 'Issue Tracker': 'https://github.com/PyGreSQL/PyGreSQL/issues/', + 'Mailing List': 'https://mail.vex.net/mailman/listinfo/pygresql', + 'Source Code': 'https://github.com/PyGreSQL/PyGreSQL'}, classifiers=[ - "Development Status :: 6 - Mature", - "Intended Audience :: Developers", - "License :: OSI Approved :: PostgreSQL License", - "Operating System :: OS Independent", - "Programming Language :: C", + 'Development Status :: 6 - Mature', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: PostgreSQL License', + 'Operating System :: OS Independent', + 'Programming Language :: C', 'Programming Language :: Python', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.7', @@ -155,16 +152,17 @@ def finalize_options(self): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', - "Programming Language :: SQL", - "Topic :: Database", - "Topic :: Database :: Front-Ends", - "Topic :: Software Development :: Libraries :: Python Modules"], - license="PostgreSQL", - py_modules=py_modules, + 'Programming Language :: SQL', + 'Topic :: Database', + 'Topic :: Database :: Front-Ends', + 'Topic :: Software Development :: Libraries :: Python Modules'], + license='PostgreSQL', test_suite='tests.discover', zip_safe=False, + packages=["pg", "pgdb"], + package_data={"pg": ["py.typed"], "pgdb": ["py.typed"]}, ext_modules=[Extension( - '_pg', c_sources, + 'pg._pg', ["ext/pgmodule.c"], include_dirs=include_dirs, library_dirs=library_dirs, define_macros=define_macros, undef_macros=undef_macros, libraries=libraries, extra_compile_args=extra_compile_args)], diff --git a/tests/config.py b/tests/config.py index 0b15f62e..4e27c3ae 100644 --- a/tests/config.py +++ b/tests/config.py @@ -18,13 +18,10 @@ dbname = get('PYGRESQL_DB', get('PGDATABASE', 'test')) dbhost = get('PYGRESQL_HOST', get('PGHOST', 'localhost')) -dbport = get('PYGRESQL_PORT', get('PGPORT', 5432)) +dbport = int(get('PYGRESQL_PORT', get('PGPORT', 5432))) dbuser = get('PYGRESQL_USER', get('PGUSER')) dbpasswd = get('PYGRESQL_PASSWD', get('PGPASSWORD')) -if dbport: - dbport = int(dbport) - try: from .LOCAL_PyGreSQL import * # type: ignore # noqa except (ImportError, ValueError): diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 242fdbb5..d6a742bf 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -19,7 +19,7 @@ from collections.abc import Iterable from contextlib import suppress from decimal import Decimal -from typing import Sequence +from typing import Any, Sequence import pg # the module under test @@ -999,7 +999,7 @@ def test_query_with_bool_params(self, bool_enabled=None): self.assertEqual(query(q, (False,)).getresult(), r_false) self.assertEqual(query(q, (True,)).getresult(), r_true) finally: - if bool_enabled is not None: + if bool_enabled_default is not None: pg.set_bool(bool_enabled_default) def test_query_with_bool_params_not_default(self): @@ -1557,7 +1557,7 @@ def test_single_with_empty_query(self): try: q.single() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1577,7 +1577,7 @@ def test_single_with_two_rows(self): try: q.single() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -1588,7 +1588,7 @@ def test_single_dict_with_empty_query(self): try: q.singledict() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1608,7 +1608,7 @@ def test_single_dict_with_two_rows(self): try: q.singledict() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -1619,7 +1619,7 @@ def test_single_named_with_empty_query(self): try: q.singlenamed() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1627,7 +1627,7 @@ def test_single_named_with_empty_query(self): def test_single_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") - r = q.singlenamed() + r: Any = q.singlenamed() self.assertEqual(r._fields, ('one', 'two')) self.assertEqual(r.one, 1) self.assertEqual(r.two, 2) @@ -1643,7 +1643,7 @@ def test_single_named_with_two_rows(self): try: q.singlenamed() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -1654,7 +1654,7 @@ def test_single_scalar_with_empty_query(self): try: q.singlescalar() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) @@ -1674,7 +1674,7 @@ def test_single_scalar_with_two_rows(self): try: q.singlescalar() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) @@ -2685,38 +2685,38 @@ def setUpClass(cls): def test_escape_string(self): self.assertTrue(self.cls_set_up) f = pg.escape_string - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("das is' käse".encode()) - self.assertIsInstance(r, bytes) - self.assertEqual(r, "das is'' käse".encode()) - r = f("that's cheesy") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheesy") - r = f(r"It's bad to have a \ inside.") - self.assertEqual(r, r"It''s bad to have a \\ inside.") + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + b = f("das is' käse".encode()) + self.assertIsInstance(b, bytes) + self.assertEqual(b, "das is'' käse".encode()) + s = f("that's cheesy") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheesy") + s = f(r"It's bad to have a \ inside.") + self.assertEqual(s, r"It''s bad to have a \\ inside.") def test_escape_bytea(self): self.assertTrue(self.cls_set_up) f = pg.escape_bytea - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("das is' käse".encode()) - self.assertIsInstance(r, bytes) - self.assertEqual(r, b"das is'' k\\\\303\\\\244se") - r = f("that's cheesy") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheesy") - r = f(b'O\x00ps\xff!') - self.assertEqual(r, b'O\\\\000ps\\\\377!') + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + b = f("das is' käse".encode()) + self.assertIsInstance(b, bytes) + self.assertEqual(b, b"das is'' k\\\\303\\\\244se") + s = f("that's cheesy") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheesy") + b = f(b'O\x00ps\xff!') + self.assertEqual(b, b'O\\\\000ps\\\\377!') if __name__ == '__main__': diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 71438f71..74d6df8e 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -3348,24 +3348,26 @@ def test_insert_update_get_bytea(self): def test_upsert_bytea(self): self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" - r = dict(n=7, data=s) - r = self.db.upsert('bytea_test', r) - self.assertIsInstance(r, dict) - self.assertIn('n', r) - self.assertEqual(r['n'], 7) - self.assertIn('data', r) + d = dict(n=7, data=s) + d = self.db.upsert('bytea_test', d) + self.assertIsInstance(d, dict) + self.assertIn('n', d) + self.assertEqual(d['n'], 7) + self.assertIn('data', d) + data = d['data'] if pg.get_bytea_escaped(): - self.assertNotEqual(r['data'], s) - r['data'] = pg.unescape_bytea(r['data']) - self.assertIsInstance(r['data'], bytes) - self.assertEqual(r['data'], s) - r['data'] = None - r = self.db.upsert('bytea_test', r) - self.assertIsInstance(r, dict) - self.assertIn('n', r) - self.assertEqual(r['n'], 7) - self.assertIn('data', r) - self.assertIsNone(r['data']) + self.assertNotEqual(data, s) + self.assertIsInstance(data, str) + data = pg.unescape_bytea(data) # type: ignore + self.assertIsInstance(data, bytes) + self.assertEqual(data, s) + d['data'] = None + d = self.db.upsert('bytea_test', d) + self.assertIsInstance(d, dict) + self.assertIn('n', d) + self.assertEqual(d['n'], 7) + self.assertIn('data', d) + self.assertIsNone(d['data']) def test_insert_get_json(self): self.create_table('json_test', 'n smallint primary key, data json') diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 33c2f6f9..19214c5d 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -287,11 +287,11 @@ def test_parser_nested(self): def test_parser_too_deeply_nested(self): f = pg.cast_array for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '{' * n + 'a,b,c' + '}' * n + s = '{' * n + 'a,b,c' + '}' * n if n > 16: # hard coded maximum depth - self.assertRaises(ValueError, f, r) + self.assertRaises(ValueError, f, s) else: - r = f(r) + r = f(s) for _i in range(n - 1): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) @@ -537,9 +537,9 @@ def test_parser_nested(self): def test_parser_many_elements(self): f = pg.cast_record for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = ','.join(map(str, range(n))) - r = f'({r})' - r = f(r, int) + s = ','.join(map(str, range(n))) + s = f'({s})' + r = f(s, int) self.assertEqual(r, tuple(range(n))) def test_parser_cast_uniform(self): @@ -877,27 +877,27 @@ class TestEscapeFunctions(unittest.TestCase): def test_escape_string(self): f = pg.escape_string - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("that's cheese") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheese") + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + s = f("that's cheese") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheese") def test_escape_bytea(self): f = pg.escape_bytea - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f('plain') - self.assertIsInstance(r, str) - self.assertEqual(r, 'plain') - r = f("that's cheese") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheese") + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + s = f("that's cheese") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheese") def test_unescape_bytea(self): f = pg.unescape_bytea diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 6838d03a..2e731c6e 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -869,7 +869,7 @@ def __init__(self, value: Any) -> None: def __str__(self) -> str: return str(self.value).replace('.', ',') - + self.assertTrue(pgdb.decimal_type(CustomDecimal) is CustomDecimal) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) diff --git a/tox.ini b/tox.ini index 322a3f32..eae93234 100644 --- a/tox.ini +++ b/tox.ini @@ -7,20 +7,20 @@ envlist = py3{7,8,9,10,11},ruff,mypy,cformat,docs basepython = python3.11 deps = ruff>=0.0.287 commands = - ruff setup.py pg.py pgdb.py tests + ruff setup.py pg pgdb tests [testenv:mypy] basepython = python3.11 deps = mypy>=1.5.1 commands = - mypy pg.py pgdb.py tests + mypy pg pgdb tests [testenv:cformat] basepython = python3.11 allowlist_externals = sh commands = - sh -c "! (clang-format --style=file -n *.c 2>&1 | tee /dev/tty | grep format-violations)" + sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)" [testenv:docs] basepython = python3.11 From 8758cdaa0e0230f452918e80fe44305d522c8e91 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 00:01:05 +0200 Subject: [PATCH 053/118] Fix manifest file --- MANIFEST.in | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 4ff1c2b6..8d4bbd33 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,11 +1,9 @@ include setup.py -recursive-include pg pgdb tests *.py - -include pg/*.pyi -include pg/py.typed -include pgdb/py.typed +recursive-include pg *.py *.pyi py.typed +recursive-include pgdb *.py py.typed +recursive-include tests *.py include ext/*.c include ext/*.h From 5f861cbe39c19cd92e9cc7180e3128cee5a2bc6f Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 00:13:11 +0200 Subject: [PATCH 054/118] Use organization for docs, mention myself in README --- README.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 150effb5..a010b944 100644 --- a/README.rst +++ b/README.rst @@ -7,8 +7,10 @@ powerful PostgreSQL features from Python. PyGreSQL should run on most platforms where PostgreSQL and Python is running. It is based on the PyGres95 code written by Pascal Andre. -D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 +D'Arcy J. M. Cain renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. +Christoph Zwerschke volunteered as another maintainer and has been the main +contributor since version 3.7 of PyGreSQL. The following Python versions are supported: @@ -16,7 +18,6 @@ The following Python versions are supported: * PyGreSQL 5.x: Python 2 and Python 3 * PyGreSQL 6.x and newer: Python 3 only - Installation ------------ @@ -31,6 +32,6 @@ Documentation ------------- The documentation is available at -`pygresql.github.io/PyGreSQL/ `_ -and at `pygresql.readthedocs.io `_, +`pygresql.github.io/ `_ and at +`pygresql.readthedocs.io `_, where you can also find the documentation for older versions. From 5252d13164c8b50f805ce4e8cb9a23c7190a6088 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 13:14:54 +0200 Subject: [PATCH 055/118] Split pg package into submodules Also fix a small issue in the code for adaptation of records. --- pg/__init__.py | 2728 +----------------------------- pg/adapt.py | 680 ++++++++ pg/attrs.py | 35 + pg/cast.py | 436 +++++ pg/core.py | 135 ++ pg/db.py | 1332 +++++++++++++++ pg/error.py | 35 + pg/helpers.py | 98 ++ pg/notify.py | 149 ++ pg/tz.py | 21 + tests/test_classic_attrdict.py | 100 ++ tests/test_classic_connection.py | 10 +- tests/test_classic_dbwrapper.py | 167 +- tests/test_classic_functions.py | 9 +- 14 files changed, 3096 insertions(+), 2839 deletions(-) create mode 100644 pg/adapt.py create mode 100644 pg/attrs.py create mode 100644 pg/cast.py create mode 100644 pg/core.py create mode 100644 pg/db.py create mode 100644 pg/error.py create mode 100644 pg/helpers.py create mode 100644 pg/notify.py create mode 100644 pg/tz.py create mode 100644 tests/test_classic_attrdict.py diff --git a/pg/__init__.py b/pg/__init__.py index 0740db20..e0e1b214 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -22,70 +22,9 @@ from __future__ import annotations -import select -import weakref -from collections import namedtuple -from contextlib import suppress -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from operator import itemgetter -from re import compile as regex -from types import MappingProxyType -from typing import ( - Any, - Callable, - ClassVar, - Generator, - Iterator, - List, - Mapping, - NamedTuple, - Sequence, - TypeVar, -) -from uuid import UUID - -try: - from ._pg import version -except ImportError as e: # noqa: F841 - import os - libpq = 'libpq.' - if os.name == 'nt': - libpq += 'dll' - import sys - paths = [path for path in os.environ["PATH"].split(os.pathsep) - if os.path.exists(os.path.join(path, libpq))] - if sys.version_info >= (3, 8): - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - add_dll_dir = os.add_dll_directory # type: ignore - for path in paths: - with add_dll_dir(os.path.abspath(path)): - try: - from ._pg import version - except ImportError: - pass - else: - del version - e = None # type: ignore - break - if paths: - libpq = 'compatible ' + libpq - else: - libpq += 'so' - if e: - raise ImportError( - "Cannot import shared library for PyGreSQL,\n" - f"probably because no {libpq} is installed.\n{e}") from e -else: - del version - -# import objects from extension module -from ._pg import ( +from .adapt import Adapter, Bytea, Hstore, Json, Literal +from .cast import Typecasts, get_typecast, set_typecast +from .core import ( INV_READ, INV_WRITE, POLLING_FAILED, @@ -155,6 +94,9 @@ unescape_bytea, version, ) +from .db import DB +from .helpers import init_core, set_row_factory_size +from .notify import NotificationHandler __version__ = version @@ -185,2661 +127,9 @@ 'set_datestyle', 'set_decimal', 'set_decimal_point', 'set_defbase', 'set_defhost', 'set_defopt', 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', 'set_typecast', + 'set_jsondecode', 'set_query_helpers', + 'set_row_factory_size', 'set_typecast', 'version', '__version__', ] -# Auxiliary classes and functions that are independent of a DB connection: - -SomeNamedTuple = Any # alias for accessing arbitrary named tuples - -def get_args(func: Callable) -> list: - return list(signature(func).parameters) - - -# time zones used in Postgres timestamptz output -_timezones: dict[str, str] = { - 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', - 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', - 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' -} - - -def _timezone_as_offset(tz: str) -> str: - if tz.startswith(('+', '-')): - if len(tz) < 5: - return tz + '00' - return tz.replace(':', '') - return _timezones.get(tz, '+0000') - - -def _oid_key(table: str) -> str: - """Build oid key from a table name.""" - return f'oid({table})' - - -class Bytea(bytes): - """Wrapper class for marking Bytea values.""" - - -class Hstore(dict): - """Wrapper class for marking hstore values.""" - - _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') - - @classmethod - def _quote(cls, s: Any) -> str: - if s is None: - return 'NULL' - if not isinstance(s, str): - s = str(s) - if not s: - return '""' - s = s.replace('"', '\\"') - if cls._re_quote.search(s): - s = f'"{s}"' - return s - - def __str__(self) -> str: - """Create a printable representation of the hstore value.""" - q = self._quote - return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) - - -class Json: - """Wrapper class for marking Json values.""" - - def __init__(self, obj: Any, - encode: Callable[[Any], str] | None = None) -> None: - """Initialize the JSON object.""" - self.obj = obj - self.encode = encode or jsonencode - - def __str__(self) -> str: - """Create a printable representation of the JSON object.""" - obj = self.obj - if isinstance(obj, str): - return obj - return self.encode(obj) - - -class _SimpleTypes(dict): - """Dictionary mapping pg_type names to simple type names. - - The corresponding Python types and simple names are also mapped. - """ - - _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ - 'bool': [bool], - 'bytea': [Bytea], - 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', - 'abstime', 'reltime', # these are very old - 'datetime', 'timedelta', # these do not really exist - date, time, datetime, timedelta], - 'float': ['float4', 'float8', float], - 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], - 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], - 'num': ['numeric', Decimal], 'money': [], - 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] - }) - - # noinspection PyMissingConstructor - def __init__(self) -> None: - """Initialize type mapping.""" - for typ, keys in self._type_aliases.items(): - keys = [typ, *keys] - for key in keys: - self[key] = typ - if isinstance(key, str): - self[f'_{key}'] = f'{typ}[]' - elif not isinstance(key, tuple): - self[List[key]] = f'{typ}[]' # type: ignore - - @staticmethod - def __missing__(key: str) -> str: - """Unmapped types are interpreted as text.""" - return 'text' - - def get_type_dict(self) -> dict[type, str]: - """Get a plain dictionary of only the types.""" - return {key: typ for key, typ in self.items() - if not isinstance(key, (str, tuple))} - - -_simpletypes = _SimpleTypes() -_simple_type_dict = _simpletypes.get_type_dict() - - -def _quote_if_unqualified(param: str, name: int | str) -> str: - """Quote parameter representing a qualified name. - - Puts a quote_ident() call around the given parameter unless - the name contains a dot, in which case the name is ambiguous - (could be a qualified name or just a name with a dot in it) - and must be quoted manually by the caller. - """ - if isinstance(name, str) and '.' not in name: - return f'quote_ident({param})' - return param - - -class _ParameterList(list): - """Helper class for building typed parameter lists.""" - - adapt: Callable - - def add(self, value: Any, typ:Any = None) -> str: - """Typecast value with known database type and build parameter list. - - If this is a literal value, it will be returned as is. Otherwise, a - placeholder will be returned and the parameter list will be augmented. - """ - # noinspection PyUnresolvedReferences - value = self.adapt(value, typ) - if isinstance(value, Literal): - return value - self.append(value) - return f'${len(self)}' - - -class Literal(str): - """Wrapper class for marking literal SQL values.""" - - -class AttrDict(dict): - """Simple read-only ordered dictionary for storing attribute names.""" - - def __init__(self, *args: Any, **kw: Any) -> None: - self._read_only = False - super().__init__(*args, **kw) - self._read_only = True - error = self._read_only_error - self.clear = self.update = error # type: ignore - self.pop = self.setdefault = self.popitem = error # type: ignore - - def __setitem__(self, key: str, value: Any) -> None: - if self._read_only: - self._read_only_error() - super().__setitem__(key, value) - - def __delitem__(self, key: str) -> None: - if self._read_only: - self._read_only_error() - super().__delitem__(key) - - @staticmethod - def _read_only_error(*_args: Any, **_kw: Any) -> Any: - raise TypeError('This object is read-only') - - -class Adapter: - """Class providing methods for adapting parameters to the database.""" - - _bool_true_values = frozenset('t true 1 y yes on'.split()) - - _date_literals = frozenset( - 'current_date current_time' - ' current_timestamp localtime localtimestamp'.split()) - - _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') - _re_record_quote = regex(r'[(,"\\]') - _re_array_escape = _re_record_escape = regex(r'(["\\])') - - def __init__(self, db: DB): - """Initialize the adapter object with the given connection.""" - self.db = weakref.proxy(db) - - @classmethod - def _adapt_bool(cls, v: Any) -> str | None: - """Adapt a boolean parameter.""" - if isinstance(v, str): - if not v: - return None - v = v.lower() in cls._bool_true_values - return 't' if v else 'f' - - @classmethod - def _adapt_date(cls, v: Any) -> Any: - """Adapt a date parameter.""" - if not v: - return None - if isinstance(v, str) and v.lower() in cls._date_literals: - return Literal(v) - return v - - @staticmethod - def _adapt_num(v: Any) -> Any: - """Adapt a numeric parameter.""" - if not v and v != 0: - return None - return v - - _adapt_int = _adapt_float = _adapt_money = _adapt_num - - def _adapt_bytea(self, v: Any) -> str: - """Adapt a bytea parameter.""" - return self.db.escape_bytea(v) - - def _adapt_json(self, v: Any) -> str | None: - """Adapt a json parameter.""" - if not v: - return None - if isinstance(v, str): - return v - if isinstance(v, Json): - return str(v) - return self.db.encode_json(v) - - def _adapt_hstore(self, v: Any) -> str | None: - """Adapt a hstore parameter.""" - if not v: - return None - if isinstance(v, str): - return v - if isinstance(v, Hstore): - return str(v) - if isinstance(v, dict): - return str(Hstore(v)) - raise TypeError(f'Hstore parameter {v} has wrong type') - - def _adapt_uuid(self, v: Any) -> str | None: - """Adapt a UUID parameter.""" - if not v: - return None - if isinstance(v, str): - return v - return str(v) - - @classmethod - def _adapt_text_array(cls, v: Any) -> str: - """Adapt a text type array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_text_array - return '{' + ','.join(adapt(v) for v in v) + '}' - if v is None: - return 'null' - if not v: - return '""' - v = str(v) - if cls._re_array_quote.search(v): - v = cls._re_array_escape.sub(r'\\\1', v) - v = f'"{v}"' - return v - - _adapt_date_array = _adapt_text_array - - @classmethod - def _adapt_bool_array(cls, v: Any) -> str: - """Adapt a boolean array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_bool_array - return '{' + ','.join(adapt(v) for v in v) + '}' - if v is None: - return 'null' - if isinstance(v, str): - if not v: - return 'null' - v = v.lower() in cls._bool_true_values - return 't' if v else 'f' - - @classmethod - def _adapt_num_array(cls, v: Any) -> str: - """Adapt a numeric array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_num_array - v = '{' + ','.join(adapt(v) for v in v) + '}' - if not v and v != 0: - return 'null' - return str(v) - - _adapt_int_array = _adapt_float_array = _adapt_money_array = \ - _adapt_num_array - - def _adapt_bytea_array(self, v: Any) -> bytes: - """Adapt a bytea array parameter.""" - if isinstance(v, list): - return b'{' + b','.join( - self._adapt_bytea_array(v) for v in v) + b'}' - if v is None: - return b'null' - return self.db.escape_bytea(v).replace(b'\\', b'\\\\') - - def _adapt_json_array(self, v: Any) -> str: - """Adapt a json array parameter.""" - if isinstance(v, list): - adapt = self._adapt_json_array - return '{' + ','.join(adapt(v) for v in v) + '}' - if not v: - return 'null' - if not isinstance(v, str): - v = self.db.encode_json(v) - if self._re_array_quote.search(v): - v = self._re_array_escape.sub(r'\\\1', v) - v = f'"{v}"' - return v - - def _adapt_record(self, v: Any, typ: Any) -> str: - """Adapt a record parameter with given type.""" - typ = self.get_attnames(typ).values() - if len(typ) != len(v): - raise TypeError(f'Record parameter {v} has wrong size') - adapt = self.adapt - value = [] - for v, t in zip(v, typ): # noqa: B020 - v = adapt(v, t) - if v is None: - v = '' - elif not v: - v = '""' - else: - if isinstance(v, bytes): - if str is not bytes: - v = v.decode('ascii') - else: - v = str(v) - if self._re_record_quote.search(v): - v = self._re_record_escape.sub(r'\\\1', v) - v = f'"{v}"' - value.append(v) - v = ','.join(value) - return f'({v})' - - def adapt(self, value: Any, typ: Any = None) -> str: - """Adapt a value with known database type.""" - if value is not None and not isinstance(value, Literal): - if typ: - simple = self.get_simple_name(typ) - else: - typ = simple = self.guess_simple_type(value) or 'text' - pg_str = getattr(value, '__pg_str__', None) - if pg_str: - value = pg_str(typ) - if simple == 'text': - pass - elif simple == 'record': - if isinstance(value, tuple): - value = self._adapt_record(value, typ) - elif simple.endswith('[]'): - if isinstance(value, list): - adapt = getattr(self, f'_adapt_{simple[:-2]}_array') - value = adapt(value) - else: - adapt = getattr(self, f'_adapt_{simple}') - value = adapt(value) - return value - - @staticmethod - def simple_type(name: str) -> DbType: - """Create a simple database type with given attribute names.""" - typ = DbType(name) - typ.simple = name - return typ - - @staticmethod - def get_simple_name(typ: Any) -> str: - """Get the simple name of a database type.""" - if isinstance(typ, DbType): - # noinspection PyUnresolvedReferences - return typ.simple - return _simpletypes[typ] - - @staticmethod - def get_attnames(typ: Any) -> dict[str, dict[str, str]]: - """Get the attribute names of a composite database type.""" - if isinstance(typ, DbType): - return typ.attnames - return {} - - @classmethod - def guess_simple_type(cls, value: Any) -> str | None: - """Try to guess which database type the given value has.""" - # optimize for most frequent types - try: - return _simple_type_dict[type(value)] - except KeyError: - pass - if isinstance(value, (bytes, str)): - return 'text' - if isinstance(value, bool): - return 'bool' - if isinstance(value, int): - return 'int' - if isinstance(value, float): - return 'float' - if isinstance(value, Decimal): - return 'num' - if isinstance(value, (date, time, datetime, timedelta)): - return 'date' - if isinstance(value, Bytea): - return 'bytea' - if isinstance(value, Json): - return 'json' - if isinstance(value, Hstore): - return 'hstore' - if isinstance(value, UUID): - return 'uuid' - if isinstance(value, list): - return (cls.guess_simple_base_type(value) or 'text') + '[]' - if isinstance(value, tuple): - simple_type = cls.simple_type - guess = cls.guess_simple_type - - # noinspection PyUnusedLocal - def get_attnames(self: DbType) -> AttrDict: - return AttrDict((str(n + 1), simple_type(guess(v) or 'text')) - for n, v in enumerate(value)) - - typ = simple_type('record') - typ._get_attnames = get_attnames - return typ - return None - - @classmethod - def guess_simple_base_type(cls, value: Any) -> str | None: - """Try to guess the base type of a given array.""" - for v in value: - if isinstance(v, list): - typ = cls.guess_simple_base_type(v) - else: - typ = cls.guess_simple_type(v) - if typ: - return typ - return None - - def adapt_inline(self, value: Any, nested: bool=False) -> Any: - """Adapt a value that is put into the SQL and needs to be quoted.""" - if value is None: - return 'NULL' - if isinstance(value, Literal): - return value - if isinstance(value, Bytea): - value = self.db.escape_bytea(value).decode('ascii') - elif isinstance(value, (datetime, date, time, timedelta)): - value = str(value) - if isinstance(value, (bytes, str)): - value = self.db.escape_string(value) - return f"'{value}'" - if isinstance(value, bool): - return 'true' if value else 'false' - if isinstance(value, float): - if isinf(value): - return "'-Infinity'" if value < 0 else "'Infinity'" - if isnan(value): - return "'NaN'" - return value - if isinstance(value, (int, Decimal)): - return value - if isinstance(value, list): - q = self.adapt_inline - s = '[{}]' if nested else 'ARRAY[{}]' - return s.format(','.join(str(q(v, nested=True)) for v in value)) - if isinstance(value, tuple): - q = self.adapt_inline - return '({})'.format(','.join(str(q(v)) for v in value)) - if isinstance(value, Json): - value = self.db.escape_string(str(value)) - return f"'{value}'::json" - if isinstance(value, Hstore): - value = self.db.escape_string(str(value)) - return f"'{value}'::hstore" - pg_repr = getattr(value, '__pg_repr__', None) - if not pg_repr: - raise InterfaceError( - f'Do not know how to adapt type {type(value)}') - value = pg_repr() - if isinstance(value, (tuple, list)): - value = self.adapt_inline(value) - return value - - def parameter_list(self) -> _ParameterList: - """Return a parameter list for parameters with known database types. - - The list has an add(value, typ) method that will build up the - list and return either the literal value or a placeholder. - """ - params = _ParameterList() - params.adapt = self.adapt - return params - - def format_query(self, command: str, - values: list | tuple | dict | None = None, - types: list | tuple | dict | None = None, - inline: bool=False - ) -> tuple[str, _ParameterList]: - """Format a database query using the given values and types. - - The optional types describe the values and must be passed as a list, - tuple or string (that will be split on whitespace) when values are - passed as a list or tuple, or as a dict if values are passed as a dict. - - If inline is set to True, then parameters will be passed inline - together with the query string. - """ - params = self.parameter_list() - if not values: - return command, params - if inline and types: - raise ValueError('Typed parameters must be sent separately') - if isinstance(values, (list, tuple)): - if inline: - adapt = self.adapt_inline - seq_literals = [adapt(value) for value in values] - else: - add = params.add - if types: - if isinstance(types, str): - types = types.split() - if (not isinstance(types, (list, tuple)) - or len(types) != len(values)): - raise TypeError('The values and types do not match') - seq_literals = [add(value, typ) - for value, typ in zip(values, types)] - else: - seq_literals = [add(value) for value in values] - command %= tuple(seq_literals) - elif isinstance(values, dict): - # we want to allow extra keys in the dictionary, - # so we first must find the values actually used in the command - used_values = {} - map_literals = dict.fromkeys(values, '') - for key in values: - del map_literals[key] - try: - command % map_literals - except KeyError: - used_values[key] = values[key] # pyright: ignore - map_literals[key] = '' - if inline: - adapt = self.adapt_inline - map_literals = {key: adapt(value) - for key, value in used_values.items()} - else: - add = params.add - if types: - if not isinstance(types, dict): - raise TypeError('The values and types do not match') - map_literals = {key: add(used_values[key], types.get(key)) - for key in sorted(used_values)} - else: - map_literals = {key: add(used_values[key]) - for key in sorted(used_values)} - command %= map_literals - else: - raise TypeError('The values must be passed as tuple, list or dict') - return command, params - - -def cast_bool(value: str) -> Any: - """Cast a boolean value.""" - if not get_bool(): - return value - return value[0] == 't' - - -def cast_json(value: str) -> Any: - """Cast a JSON value.""" - cast = get_jsondecode() - if not cast: - return value - return cast(value) - - -def cast_num(value: str) -> Any: - """Cast a numeric value.""" - return (get_decimal() or float)(value) - - -def cast_money(value: str) -> Any: - """Cast a money value.""" - point = get_decimal_point() - if not point: - return value - if point != '.': - value = value.replace(point, '.') - value = value.replace('(', '-') - value = ''.join(c for c in value if c.isdigit() or c in '.-') - return (get_decimal() or float)(value) - - -def cast_int2vector(value: str) -> list[int]: - """Cast an int2vector value.""" - return [int(v) for v in value.split()] - - -def cast_date(value: str, connection: DB) -> Any: - """Cast a date value.""" - # The output format depends on the server setting DateStyle. The default - # setting ISO and the setting for German are actually unambiguous. The - # order of days and months in the other two settings is however ambiguous, - # so at least here we need to consult the setting to properly parse values. - if value == '-infinity': - return date.min - if value == 'infinity': - return date.max - values = value.split() - if values[-1] == 'BC': - return date.min - value = values[0] - if len(value) > 10: - return date.max - format = connection.date_format() - return datetime.strptime(value, format).date() - - -def cast_time(value: str) -> Any: - """Cast a time value.""" - format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, format).time() - - -_re_timezone = regex('(.*)([+-].*)') - - -def cast_timetz(value: str) -> Any: - """Cast a timetz value.""" - m = _re_timezone.match(value) - if m: - value, tz = m.groups() - else: - tz = '+0000' - format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - value += _timezone_as_offset(tz) - format += '%z' - return datetime.strptime(value, format).timetz() - - -def cast_timestamp(value: str, connection: DB) -> Any: - """Cast a timestamp value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return datetime.min - format = connection.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:5] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - else: - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -def cast_timestamptz(value: str, connection: DB) -> Any: - """Cast a timestamptz value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return datetime.min - format = connection.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - values, tz = values[:-1], values[-1] - else: - if format.startswith('%Y-'): - m = _re_timezone.match(values[1]) - if m: - values[1], tz = m.groups() - else: - tz = '+0000' - else: - values, tz = values[:-1], values[-1] - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - values.append(_timezone_as_offset(tz)) - formats.append('%z') - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -_re_interval_sql_standard = regex( - '(?:([+-])?([0-9]+)-([0-9]+) ?)?' - '(?:([+-]?[0-9]+)(?!:) ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres = regex( - '(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres_verbose = regex( - '@ ?(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-]?[0-9]+) ?hours? ?)?' - '(?:([+-]?[0-9]+) ?mins? ?)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') - -_re_interval_iso_8601 = regex( - 'P(?:([+-]?[0-9]+)Y)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-]?[0-9]+)D)?' - '(?:T(?:([+-]?[0-9]+)H)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') - - -def cast_interval(value: str) -> timedelta: - """Cast an interval value.""" - # The output format depends on the server setting IntervalStyle, but it's - # not necessary to consult this setting to parse it. It's faster to just - # check all possible formats, and there is no ambiguity here. - m = _re_interval_iso_8601.match(value) - if m: - s = [v or '0' for v in m.groups()] - secs_ago = s.pop(5) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = -secs - usecs = -usecs - else: - m = _re_interval_postgres_verbose.match(value) - if m: - s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) - secs_ago = s.pop(5) == '-' - d = [-int(v) for v in s] if ago else [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = - secs - usecs = -usecs - else: - m = _re_interval_postgres.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - m = _re_interval_sql_standard.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - years_ago = s.pop(0) == '-' - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if years_ago: - years = -years - mons = -mons - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - raise ValueError(f'Cannot parse interval: {value}') - days += 365 * years + 30 * mons - return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) - - -class Typecasts(dict): - """Dictionary mapping database types to typecast functions. - - The cast functions get passed the string representation of a value in - the database which they need to convert to a Python object. The - passed string will never be None since NULL values are already - handled before the cast function is called. - - Note that the basic types are already handled by the C extension. - They only need to be handled here as record or array components. - """ - - # the default cast functions - # (str functions are ignored but have been added for faster access) - defaults: ClassVar[dict[str, Callable]] = { - 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, 'sql_identifier': str, - 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, - 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, - 'float4': float, 'float8': float, - 'numeric': cast_num, 'money': cast_money, - 'date': cast_date, 'interval': cast_interval, - 'time': cast_time, 'timetz': cast_timetz, - 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, - 'int2vector': cast_int2vector, 'uuid': UUID, - 'anyarray': cast_array, 'record': cast_record} # pyright: ignore - - connection: DB | None = None # set in a connection specific instance - - def __missing__(self, typ: str) -> Callable | None: - """Create a cast function if it is not cached. - - Note that this class never raises a KeyError, - but returns None when no special cast function exists. - """ - if not isinstance(typ, str): - raise TypeError(f'Invalid type: {typ}') - cast: Callable | None = self.defaults.get(typ) - if cast: - # store default for faster access - cast = self._add_connection(cast) - self[typ] = cast - elif typ.startswith('_'): - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - self[typ] = cast - else: - attnames = self.get_attnames(typ) - if attnames: - casts = [self[v.pgtype] for v in attnames.values()] - cast = self.create_record_cast(typ, attnames, casts) - self[typ] = cast - return cast - - @staticmethod - def _needs_connection(func: Callable) -> bool: - """Check if a typecast function needs a connection argument.""" - try: - args = get_args(func) - except (TypeError, ValueError): - return False - return 'connection' in args[1:] - - def _add_connection(self, cast: Callable) -> Callable: - """Add a connection argument to the typecast function if necessary.""" - if not self.connection or not self._needs_connection(cast): - return cast - return partial(cast, connection=self.connection) - - def get(self, typ: str, default: Callable | None = None # type: ignore - ) -> Callable | None: - """Get the typecast function for the given database type.""" - return self[typ] or default - - def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a typecast function for the specified database type(s).""" - if isinstance(typ, str): - typ = [typ] - if cast is None: - for t in typ: - self.pop(t, None) - self.pop(f'_{t}', None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - self[t] = self._add_connection(cast) - self.pop(f'_{t}', None) - - def reset(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecasts for the specified type(s) to their defaults. - - When no type is specified, all typecasts will be reset. - """ - if typ is None: - self.clear() - else: - if isinstance(typ, str): - typ = [typ] - for t in typ: - self.pop(t, None) - - @classmethod - def get_default(cls, typ: str) -> Any: - """Get the default typecast function for the given database type.""" - return cls.defaults.get(typ) - - @classmethod - def set_default(cls, typ: str | Sequence[str], - cast: Callable | None) -> None: - """Set a default typecast function for the given database type(s).""" - if isinstance(typ, str): - typ = [typ] - defaults = cls.defaults - if cast is None: - for t in typ: - defaults.pop(t, None) - defaults.pop(f'_{t}', None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - defaults[t] = cast - defaults.pop(f'_{t}', None) - - # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_attnames(self, typ: Any) -> AttrDict: - """Return the fields for the given record type. - - This method will be replaced with the get_attnames() method of DbTypes. - """ - return AttrDict() - - # noinspection PyMethodMayBeStatic - def dateformat(self) -> str: - """Return the current date format. - - This method will be replaced with the dateformat() method of DbTypes. - """ - return '%Y-%m-%d' - - def create_array_cast(self, basecast: Callable) -> Callable: - """Create an array typecast for the given base cast.""" - cast_array = self['anyarray'] - - def cast(v: Any) -> list: - return cast_array(v, basecast) - return cast - - def create_record_cast(self, name: str, fields: AttrDict, - casts: list[Callable]) -> Callable: - """Create a named record typecast for the given fields and casts.""" - cast_record = self['record'] - record = namedtuple(name, fields) # type: ignore - - def cast(v: Any) -> record: - # noinspection PyArgumentList - return record(*cast_record(v, casts)) - return cast - - -def get_typecast(typ: str) -> Callable | None: - """Get the global typecast function for the given database type.""" - return Typecasts.get_default(typ) - - -def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a global typecast function for the given database type(s). - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call db.db_types.reset_typecast(). - """ - Typecasts.set_default(typ, cast) - - -class DbType(str): - """Class augmenting the simple type name with additional info. - - The following additional information is provided: - - oid: the PostgreSQL type OID - pgtype: the internal PostgreSQL data type name - regtype: the registered PostgreSQL data type name - simple: the more coarse-grained PyGreSQL type name - typlen: the internal size, negative if variable - typtype: b = base type, c = composite type etc. - category: A = Array, b = Boolean, C = Composite etc. - delim: delimiter for array types - relid: corresponding table for composite types - attnames: attributes for composite types - """ - - oid: int - pgtype: str - regtype: str - simple: str - typlen: int - typtype: str - category: str - delim: str - relid: int - - _get_attnames: Callable[[DbType], AttrDict] - - @property - def attnames(self) -> AttrDict: - """Get names and types of the fields of a composite type.""" - # noinspection PyUnresolvedReferences - return self._get_attnames(self) - - -class DbTypes(dict): - """Cache for PostgreSQL data types. - - This cache maps type OIDs and names to DbType objects containing - information on the associated database type. - """ - - _num_types = frozenset('int float num money int2 int4 int8' - ' float4 float8 numeric money'.split()) - - def __init__(self, db: DB) -> None: - """Initialize type cache for connection.""" - super().__init__() - self._db = weakref.proxy(db) - self._regtypes = False - self._typecasts = Typecasts() - self._typecasts.get_attnames = self.get_attnames # type: ignore - self._typecasts.connection = self._db - self._query_pg_type = ( - "SELECT oid, typname, oid::pg_catalog.regtype," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") - - def add(self, oid: int, pgtype: str, regtype: str, - typlen: int, typtype: str, category: str, delim: str, relid: int - ) -> DbType: - """Create a PostgreSQL type name with additional info.""" - if oid in self: - return self[oid] - simple = 'record' if relid else _simpletypes[pgtype] - typ = DbType(regtype if self._regtypes else simple) - typ.oid = oid - typ.simple = simple - typ.pgtype = pgtype - typ.regtype = regtype - typ.typlen = typlen - typ.typtype = typtype - typ.category = category - typ.delim = delim - typ.relid = relid - typ._get_attnames = self.get_attnames # type: ignore - return typ - - def __missing__(self, key: int | str) -> DbType: - """Get the type info from the database if it is not cached.""" - try: - cmd = self._query_pg_type.format(_quote_if_unqualified('$1', key)) - res = self._db.query(cmd, (key,)).getresult() - except ProgrammingError: - res = None - if not res: - raise KeyError(f'Type {key} could not be found') - res = res[0] - typ = self.add(*res) - self[typ.oid] = self[typ.pgtype] = typ - return typ - - def get(self, key: int | str, # type: ignore - default: DbType | None = None) -> DbType | None: - """Get the type even if it is not cached.""" - try: - return self[key] - except KeyError: - return default - - def get_attnames(self, typ: Any) -> AttrDict | None: - """Get names and types of the fields of a composite type.""" - if not isinstance(typ, DbType): - typ = self.get(typ) - if not typ: - return None - if not typ.relid: - return None - return self._db.get_attnames(typ.relid, with_oid=False) - - def get_typecast(self, typ: Any) -> Callable | None: - """Get the typecast function for the given database type.""" - return self._typecasts.get(typ) - - def set_typecast(self, typ: str | Sequence[str], cast: Callable) -> None: - """Set a typecast function for the specified database type(s).""" - self._typecasts.set(typ, cast) - - def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecast function for the specified database type(s).""" - self._typecasts.reset(typ) - - def typecast(self, value: Any, typ: str) -> Any: - """Cast the given value according to the given database type.""" - if value is None: - # for NULL values, no typecast is necessary - return None - if not isinstance(typ, DbType): - db_type = self.get(typ) - if db_type: - typ = db_type.pgtype - cast = self.get_typecast(typ) if typ else None - if not cast or cast is str: - # no typecast is necessary - return value - return cast(value) - - -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. - -# noinspection PyUnresolvedReferences -@lru_cache(maxsize=1024) -def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: - """Get a namedtuple factory for row results with the given names.""" - try: - return namedtuple('Row', names, rename=True)._make # type: ignore - except ValueError: # there is still a problem with the field names - names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make # type: ignore - - -def set_row_factory_size(maxsize: int | None) -> None: - """Change the size of the namedtuple factory cache. - - If maxsize is set to None, the cache can grow without bound. - """ - # noinspection PyGlobalUndefined - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) - - -# Helper functions used by the query object - -def _dictiter(q: Query) -> Generator[dict[str, Any], None, None]: - """Get query result as an iterator of dictionaries.""" - fields: tuple[str, ...] = q.listfields() - for r in q: - yield dict(zip(fields, r)) - - -def _namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: - """Get query result as an iterator of named tuples.""" - row = _row_factory(q.listfields()) - for r in q: - yield row(r) - - -def _namednext(q: Query) -> SomeNamedTuple: - """Get next row from query result as a named tuple.""" - return _row_factory(q.listfields())(next(q)) - - -def _scalariter(q: Query) -> Generator[Any, None, None]: - """Get query result as an iterator of scalar values.""" - for r in q: - yield r[0] - - -class _MemoryQuery: - """Class that embodies a given query result.""" - - result: Any - fields: tuple[str, ...] - - def __init__(self, result: Any, fields: Sequence[str]) -> None: - """Create query from given result rows and field names.""" - self.result = result - self.fields = tuple(fields) - - def listfields(self) -> tuple[str, ...]: - """Return the stored field names of this query.""" - return self.fields - - def getresult(self) -> Any: - """Return the stored result of this query.""" - return self.result - - def __iter__(self) -> Iterator[Any]: - return iter(self.result) - -# Error messages - -E = TypeVar('E', bound=Error) - -def _error(msg: str, cls: type[E]) -> E: - """Return specified error object with empty sqlstate attribute.""" - error = cls(msg) - if isinstance(error, DatabaseError): - error.sqlstate = None - return error - - -def _db_error(msg: str) -> DatabaseError: - """Return DatabaseError.""" - return _error(msg, DatabaseError) - - -def _int_error(msg: str) -> InternalError: - """Return InternalError.""" - return _error(msg, InternalError) - - -def _prg_error(msg: str) -> ProgrammingError: - """Return ProgrammingError.""" - return _error(msg, ProgrammingError) - - -# Initialize the C module - -set_decimal(Decimal) -set_jsondecode(jsondecode) -set_query_helpers(_dictiter, _namediter, _namednext, _scalariter) - - -# The notification handler - -class NotificationHandler: - """A PostgreSQL client-side asynchronous notification handler.""" - - def __init__(self, db: DB, event: str, callback: Callable, - arg_dict: dict | None = None, - timeout: int | float | None = None, - stop_event: str | None = None): - """Initialize the notification handler. - - You must pass a PyGreSQL database connection, the name of an - event (notification channel) to listen for and a callback function. - - You can also specify a dictionary arg_dict that will be passed as - the single argument to the callback function, and a timeout value - in seconds (a floating point number denotes fractions of seconds). - If it is absent or None, the callers will never time out. If the - timeout is reached, the callback function will be called with a - single argument that is None. If you set the timeout to zero, - the handler will poll notifications synchronously and return. - - You can specify the name of the event that will be used to signal - the handler to stop listening as stop_event. By default, it will - be the event name prefixed with 'stop_'. - """ - self.db: DB | None = db - self.event = event - self.stop_event = stop_event or f'stop_{event}' - self.listening = False - self.callback = callback - if arg_dict is None: - arg_dict = {} - self.arg_dict = arg_dict - self.timeout = timeout - - def __del__(self) -> None: - """Delete the notification handler.""" - self.unlisten() - - def close(self) -> None: - """Stop listening and close the connection.""" - if self.db: - self.unlisten() - self.db.close() - self.db = None - - def listen(self) -> None: - """Start listening for the event and the stop event.""" - db = self.db - if db and not self.listening: - db.query(f'listen "{self.event}"') - db.query(f'listen "{self.stop_event}"') - self.listening = True - - def unlisten(self) -> None: - """Stop listening for the event and the stop event.""" - db = self.db - if db and self.listening: - db.query(f'unlisten "{self.event}"') - db.query(f'unlisten "{self.stop_event}"') - self.listening = False - - def notify(self, db: DB | None = None, stop: bool = False, - payload: str | None = None) -> Query | None: - """Generate a notification. - - Optionally, you can pass a payload with the notification. - - If you set the stop flag, a stop notification will be sent that - will cause the handler to stop listening. - - Note: If the notification handler is running in another thread, you - must pass a different database connection since PyGreSQL database - connections are not thread-safe. - """ - if not self.listening: - return None - if not db: - db = self.db - if not db: - return None - event = self.stop_event if stop else self.event - cmd = f'notify "{event}"' - if payload: - cmd += f", '{payload}'" - return db.query(cmd) - - def __call__(self) -> None: - """Invoke the notification handler. - - The handler is a loop that listens for notifications on the event - and stop event channels. When either of these notifications are - received, its associated 'pid', 'event' and 'extra' (the payload - passed with the notification) are inserted into its arg_dict - dictionary and the callback is invoked with this dictionary as - a single argument. When the handler receives a stop event, it - stops listening to both events and return. - - In the special case that the timeout of the handler has been set - to zero, the handler will poll all events synchronously and return. - If will keep listening until it receives a stop event. - - Note: If you run this loop in another thread, don't use the same - database connection for database operations in the main thread. - """ - if not self.db: - return - self.listen() - poll = self.timeout == 0 - rlist = [] if poll else [self.db.fileno()] - while self.db and self.listening: - # noinspection PyUnboundLocalVariable - if poll or select.select(rlist, [], [], self.timeout)[0]: - while self.db and self.listening: - notice = self.db.getnotify() - if not notice: # no more messages - break - event, pid, extra = notice - if event not in (self.event, self.stop_event): - self.unlisten() - raise _db_error( - f'Listening for "{self.event}"' - f' and "{self.stop_event}",' - f' but notified of "{event}"') - if event == self.stop_event: - self.unlisten() - self.arg_dict.update(pid=pid, event=event, extra=extra) - self.callback(self.arg_dict) - if poll: - break - else: # we timed out - self.unlisten() - self.callback(None) - - -# The actual PostgreSQL database connection interface: - -class DB: - """Wrapper class for the _pg connection type.""" - - db: Connection | None = None # invalid fallback for underlying connection - _db_args: Any # either the connectoin args or the underlying connection - - def __init__(self, *args: Any, **kw: Any) -> None: - """Create a new connection. - - You can pass either the connection parameters or an existing - _pg or pgdb connection. This allows you to use the methods - of the classic pg interface with a DB-API 2 pgdb connection. - """ - if not args and len(kw) == 1: - db = kw.get('db') - elif not kw and len(args) == 1: - db = args[0] - else: - db = None - if db: - if isinstance(db, DB): - db = db.db - else: - with suppress(AttributeError): - # noinspection PyUnresolvedReferences - db = db._cnx - if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): - db = connect(*args, **kw) - self._db_args = args, kw - self._closeable = True - else: - self._db_args = db - self._closeable = False - self.db = db - self.dbname = db.db - self._regtypes = False - self._attnames: dict[str, AttrDict] = {} - self._generated: dict[str, frozenset[str]] = {} - self._pkeys: dict[str, str | tuple[str, ...]] = {} - self._privileges: dict[tuple[str, str], bool] = {} - self.adapter = Adapter(self) - self.dbtypes = DbTypes(self) - self._query_attnames = ( - "SELECT a.attname," - " t.oid, t.typname, t.oid::pg_catalog.regtype," - " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=)" - " {}::pg_catalog.regclass" - " AND {} AND NOT a.attisdropped ORDER BY a.attnum") - if db.server_version < 120000: - self._query_generated = ( - "a.attidentity OPERATOR(pg_catalog.=) 'a'" - ) - else: - self._query_generated = ( - "(a.attidentity OPERATOR(pg_catalog.=) 'a' OR" - " a.attgenerated OPERATOR(pg_catalog.!=) '')" - ) - db.set_cast_hook(self.dbtypes.typecast) - # For debugging scripts, self.debug can be set - # * to a string format specification (e.g. in CGI set to "%s
"), - # * to a file object to write debug statements or - # * to a callable object which takes a string argument - # * to any other true value to just print debug statements - self.debug: Any = None - - def __getattr__(self, name: str) -> Any: - """Get the specified attritbute of the connection.""" - # All undefined members are same as in underlying connection: - if self.db: - return getattr(self.db, name) - else: - raise _int_error('Connection is not valid') - - def __dir__(self) -> list[str]: - """List all attributes of the connection.""" - # Custom dir function including the attributes of the connection: - attrs = set(self.__class__.__dict__) - attrs.update(self.__dict__) - attrs.update(dir(self.db)) - return sorted(attrs) - - # Context manager methods - - def __enter__(self) -> DB: - """Enter the runtime context. This will start a transaction.""" - self.begin() - return self - - def __exit__(self, et: type[BaseException] | None, - ev: BaseException | None, tb: Any) -> None: - """Exit the runtime context. This will end the transaction.""" - if et is None and ev is None and tb is None: - self.commit() - else: - self.rollback() - - def __del__(self) -> None: - """Delete the connection.""" - try: - db = self.db - except AttributeError: - db = None - if db: - with suppress(TypeError): # when already closed - db.set_cast_hook(None) - if self._closeable: - with suppress(InternalError): # when already closed - db.close() - - # Auxiliary methods - - def _do_debug(self, *args: Any) -> None: - """Print a debug message.""" - if self.debug: - s = '\n'.join(str(arg) for arg in args) - if isinstance(self.debug, str): - print(self.debug % s) - elif hasattr(self.debug, 'write'): - # noinspection PyCallingNonCallable - self.debug.write(s + '\n') - elif callable(self.debug): - self.debug(s) - else: - print(s) - - def _escape_qualified_name(self, s: str) -> str: - """Escape a qualified name. - - Escapes the name for use as an SQL identifier, unless the - name contains a dot, in which case the name is ambiguous - (could be a qualified name or just a name with a dot in it) - and must be quoted manually by the caller. - """ - if '.' not in s: - s = self.escape_identifier(s) - return s - - @staticmethod - def _make_bool(d: Any) -> bool | str: - """Get boolean value corresponding to d.""" - return bool(d) if get_bool() else ('t' if d else 'f') - - @staticmethod - def _list_params(params: Sequence) -> str: - """Create a human readable parameter list.""" - return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) - - @property - def _valid_db(self) -> Connection: - """Get underlying connection and make sure it is not closed.""" - db = self.db - if not db: - raise _int_error('Connection already closed') - return db - - # Public methods - - # escape_string and escape_bytea exist as methods, - # so we define unescape_bytea as a method as well - unescape_bytea = staticmethod(unescape_bytea) - - @staticmethod - def decode_json(s: str) -> Any: - """Decode a JSON string coming from the database.""" - return (get_jsondecode() or jsondecode)(s) - - @staticmethod - def encode_json(d: Any) -> str: - """Encode a JSON string for use within SQL.""" - return jsonencode(d) - - def close(self) -> None: - """Close the database connection.""" - # Wraps shared library function so we can track state. - db = self._valid_db - with suppress(TypeError): # when already closed - db.set_cast_hook(None) - if self._closeable: - db.close() - self.db = None - - def reset(self) -> None: - """Reset connection with current parameters. - - All derived queries and large objects derived from this connection - will not be usable after this call. - """ - self._valid_db.reset() - - def reopen(self) -> None: - """Reopen connection to the database. - - Used in case we need another connection to the same database. - Note that we can still reopen a database that we have closed. - """ - # There is no such shared library function. - if self._closeable: - args, kw = self._db_args - db = connect(*args, **kw) - if self.db: - self.db.set_cast_hook(None) - self.db.close() - db.set_cast_hook(self.dbtypes.typecast) - self.db = db - else: - self.db = self._db_args - - def begin(self, mode: str | None = None) -> Query: - """Begin a transaction.""" - qstr = 'BEGIN' - if mode: - qstr += ' ' + mode - return self.query(qstr) - - start = begin - - def commit(self) -> Query: - """Commit the current transaction.""" - return self.query('COMMIT') - - end = commit - - def rollback(self, name: str | None = None) -> Query: - """Roll back the current transaction.""" - qstr = 'ROLLBACK' - if name: - qstr += ' TO ' + name - return self.query(qstr) - - abort = rollback - - def savepoint(self, name: str) -> Query: - """Define a new savepoint within the current transaction.""" - return self.query('SAVEPOINT ' + name) - - def release(self, name: str) -> Query: - """Destroy a previously defined savepoint.""" - return self.query('RELEASE ' + name) - - def get_parameter(self, - parameter: str | list[str] | tuple[str, ...] | - set[str] | frozenset[str] | dict[str, Any] - ) -> str | list[str] | dict[str, str]: - """Get the value of a run-time parameter. - - If the parameter is a string, the return value will also be a string - that is the current setting of the run-time parameter with that name. - - You can get several parameters at once by passing a list, set or dict. - When passing a list of parameter names, the return value will be a - corresponding list of parameter settings. When passing a set of - parameter names, a new dict will be returned, mapping these parameter - names to their settings. Finally, if you pass a dict as parameter, - its values will be set to the current parameter settings corresponding - to its keys. - - By passing the special name 'all' as the parameter, you can get a dict - of all existing configuration parameters. - """ - values: Any - if isinstance(parameter, str): - parameter = [parameter] - values = None - elif isinstance(parameter, (list, tuple)): - values = [] - elif isinstance(parameter, (set, frozenset)): - values = {} - elif isinstance(parameter, dict): - values = parameter - else: - raise TypeError( - 'The parameter must be a string, list, set or dict') - if not parameter: - raise TypeError('No parameter has been specified') - query = self._valid_db.query - params: Any = {} if isinstance(values, dict) else [] - for param_key in parameter: - param = param_key.strip().lower() if isinstance( - param_key, (bytes, str)) else None - if not param: - raise TypeError('Invalid parameter') - if param == 'all': - cmd = 'SHOW ALL' - values = query(cmd).getresult() - values = {value[0]: value[1] for value in values} - break - if isinstance(params, dict): - params[param] = param_key - else: - params.append(param) - else: - for param in params: - cmd = f'SHOW {param}' - value = query(cmd).singlescalar() - if values is None: - values = value - elif isinstance(values, list): - values.append(value) - else: - values[params[param]] = value - return values - - def set_parameter(self, - parameter: str | list[str] | tuple[str, ...] | - set[str] | frozenset[str] | dict[str, Any], - value: str | list[str] | tuple[str, ...] | - set[str] | frozenset[str]| None = None, - local: bool = False) -> None: - """Set the value of a run-time parameter. - - If the parameter and the value are strings, the run-time parameter - will be set to that value. If no value or None is passed as a value, - then the run-time parameter will be restored to its default value. - - You can set several parameters at once by passing a list of parameter - names, together with a single value that all parameters should be - set to or with a corresponding list of values. You can also pass - the parameters as a set if you only provide a single value. - Finally, you can pass a dict with parameter names as keys. In this - case, you should not pass a value, since the values for the parameters - will be taken from the dict. - - By passing the special name 'all' as the parameter, you can reset - all existing settable run-time parameters to their default values. - - If you set local to True, then the command takes effect for only the - current transaction. After commit() or rollback(), the session-level - setting takes effect again. Setting local to True will appear to - have no effect if it is executed outside a transaction, since the - transaction will end immediately. - """ - if isinstance(parameter, str): - parameter = {parameter: value} - elif isinstance(parameter, (list, tuple)): - if isinstance(value, (list, tuple)): - parameter = dict(zip(parameter, value)) - else: - parameter = dict.fromkeys(parameter, value) - elif isinstance(parameter, (set, frozenset)): - if isinstance(value, (list, tuple, set, frozenset)): - value = set(value) - if len(value) == 1: - value = next(iter(value)) - if not (value is None or isinstance(value, str)): - raise ValueError( - 'A single value must be specified' - ' when parameter is a set') - parameter = dict.fromkeys(parameter, value) - elif isinstance(parameter, dict): - if value is not None: - raise ValueError( - 'A value must not be specified' - ' when parameter is a dictionary') - else: - raise TypeError( - 'The parameter must be a string, list, set or dict') - if not parameter: - raise TypeError('No parameter has been specified') - params: dict[str, str | None] = {} - for param, param_value in parameter.items(): - param = param.strip().lower() if isinstance(param, str) else None - if not param: - raise TypeError('Invalid parameter') - if param == 'all': - if param_value is not None: - raise ValueError( - 'A value must not be specified' - " when parameter is 'all'") - params = {'all': None} - break - params[param] = param_value - local_clause = ' LOCAL' if local else '' - for param, param_value in params.items(): - cmd = (f'RESET{local_clause} {param}' - if param_value is None else - f'SET{local_clause} {param} TO {param_value}') - self._do_debug(cmd) - self._valid_db.query(cmd) - - def query(self, command: str, *args: Any) -> Query: - """Execute a SQL command string. - - This method simply sends a SQL query to the database. If the query is - an insert statement that inserted exactly one row into a table that - has OIDs, the return value is the OID of the newly inserted row. - If the query is an update or delete statement, or an insert statement - that did not insert exactly one row in a table with OIDs, then the - number of rows affected is returned as a string. If it is a statement - that returns rows as a result (usually a select statement, but maybe - also an "insert/update ... returning" statement), this method returns - a Query object that can be accessed via getresult() or dictresult() - or simply printed. Otherwise, it returns `None`. - - The query can contain numbered parameters of the form $1 in place - of any data constant. Arguments given after the query string will - be substituted for the corresponding numbered parameter. Parameter - values can also be given as a single list or tuple argument. - """ - # Wraps shared library function for debugging. - db = self._valid_db - if args: - self._do_debug(command, args) - return db.query(command, args) - self._do_debug(command) - return db.query(command) - - def query_formatted(self, command: str, - parameters: tuple | list | dict | None = None, - types: tuple | list | dict | None = None, - inline: bool =False) -> Query: - """Execute a formatted SQL command string. - - Similar to query, but using Python format placeholders of the form - %s or %(names)s instead of PostgreSQL placeholders of the form $1. - The parameters must be passed as a tuple, list or dict. You can - also pass a corresponding tuple, list or dict of database types in - order to format the parameters properly in case there is ambiguity. - - If you set inline to True, the parameters will be sent to the database - embedded in the SQL command, otherwise they will be sent separately. - """ - return self.query(*self.adapter.format_query( - command, parameters, types, inline)) - - def query_prepared(self, name: str, *args: Any) -> Query: - """Execute a prepared SQL statement. - - This works like the query() method, except that instead of passing - the SQL command, you pass the name of a prepared statement. If you - pass an empty name, the unnamed statement will be executed. - """ - if name is None: - name = '' - db = self._valid_db - if args: - self._do_debug('EXECUTE', name, args) - return db.query_prepared(name, args) - self._do_debug('EXECUTE', name) - return db.query_prepared(name) - - def prepare(self, name: str, command: str) -> None: - """Create a prepared SQL statement. - - This creates a prepared statement for the given command with the - given name for later execution with the query_prepared() method. - - The name can be empty to create an unnamed statement, in which case - any pre-existing unnamed statement is automatically replaced; - otherwise it is an error if the statement name is already - defined in the current database session. We recommend always using - named queries, since unnamed queries have a limited lifetime and - can be automatically replaced or destroyed by various operations. - """ - if name is None: - name = '' - self._do_debug('prepare', name, command) - self._valid_db.prepare(name, command) - - def describe_prepared(self, name: str | None = None) -> Query: - """Describe a prepared SQL statement. - - This method returns a Query object describing the result columns of - the prepared statement with the given name. If you omit the name, - the unnamed statement will be described if you created one before. - """ - if name is None: - name = '' - return self._valid_db.describe_prepared(name) - - def delete_prepared(self, name: str | None = None) -> Query: - """Delete a prepared SQL statement. - - This deallocates a previously prepared SQL statement with the given - name, or deallocates all prepared statements if you do not specify a - name. Note that prepared statements are also deallocated automatically - when the current session ends. - """ - if not name: - name = 'ALL' - cmd = f"DEALLOCATE {name}" - self._do_debug(cmd) - return self._valid_db.query(cmd) - - def pkey(self, table: str, composite: bool = False, flush: bool = False - ) -> str | tuple[str, ...]: - """Get the primary key of a table. - - Single primary keys are returned as strings unless you - set the composite flag. Composite primary keys are always - represented as tuples. Note that this raises a KeyError - if the table does not have a primary key. - - If flush is set then the internal cache for primary keys will - be flushed. This may be necessary after the database schema or - the search path has been changed. - """ - pkeys = self._pkeys - if flush: - pkeys.clear() - self._do_debug('The pkey cache has been flushed') - try: # cache lookup - pkey = pkeys[table] - except KeyError as e: # cache miss, check the database - cmd = ("SELECT" # noqa: S608 - " a.attname, a.attnum, i.indkey" - " FROM pg_catalog.pg_index i" - " JOIN pg_catalog.pg_attribute a" - " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" - " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" - " AND NOT a.attisdropped" - " WHERE i.indrelid OPERATOR(pg_catalog.=)" - " {}::pg_catalog.regclass" - " AND i.indisprimary ORDER BY a.attnum").format( - _quote_if_unqualified('$1', table)) - res = self._valid_db.query(cmd, (table,)).getresult() - if not res: - raise KeyError(f'Table {table} has no primary key') from e - # we want to use the order defined in the primary key index here, - # not the order as defined by the columns in the table - if len(res) > 1: - indkey = res[0][2] - pkey = tuple(row[0] for row in sorted( - res, key=lambda row: indkey.index(row[1]))) - else: - pkey = res[0][0] - pkeys[table] = pkey # cache it - if composite and not isinstance(pkey, tuple): - pkey = (pkey,) - return pkey - - def pkeys(self, table: str) -> tuple[str, ...]: - """Get the primary key of a table as a tuple. - - Same as pkey() with 'composite' set to True. - """ - return self.pkey(table, True) # type: ignore - - def get_databases(self) -> list[str]: - """Get list of databases in the system.""" - return [r[0] for r in self._valid_db.query( - 'SELECT datname FROM pg_catalog.pg_database').getresult()] - - def get_relations(self, kinds: str | Sequence[str] | None = None, - system: bool = False) -> list[str]: - """Get list of relations in connected database of specified kinds. - - If kinds is None or empty, all kinds of relations are returned. - Otherwise, kinds can be a string or sequence of type letters - specifying which kind of relations you want to list. - - Set the system flag if you want to get the system relations as well. - """ - where_parts = [] - if kinds: - where_parts.append( - "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) - if not system: - where_parts.append("s.nspname NOT SIMILAR" - " TO 'pg/_%|information/_schema' ESCAPE '/'") - where = " WHERE " + ' AND '.join(where_parts) if where_parts else '' - cmd = ("SELECT" # noqa: S608 - " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" - " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" - " FROM pg_catalog.pg_class r" - " JOIN pg_catalog.pg_namespace s" - f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" - " ORDER BY s.nspname, r.relname") - return [r[0] for r in self._valid_db.query(cmd).getresult()] - - def get_tables(self, system: bool = False) -> list[str]: - """Return list of tables in connected database. - - Set the system flag if you want to get the system tables as well. - """ - return self.get_relations('r', system) - - def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False - ) -> AttrDict: - """Given the name of a table, dig out the set of attribute names. - - Returns a read-only dictionary of attribute names (the names are - the keys, the values are the names of the attributes' types) - with the column names in the proper order if you iterate over it. - - If flush is set, then the internal cache for attribute names will - be flushed. This may be necessary after the database schema or - the search path has been changed. - - By default, only a limited number of simple types will be returned. - You can get the registered types after calling use_regtypes(True). - """ - attnames = self._attnames - if flush: - attnames.clear() - self._do_debug('The attnames cache has been flushed') - try: # cache lookup - names = attnames[table] - except KeyError: # cache miss, check the database - cmd = "a.attnum OPERATOR(pg_catalog.>) 0" - if with_oid: - cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" - cmd = self._query_attnames.format( - _quote_if_unqualified('$1', table), cmd) - res = self._valid_db.query(cmd, (table,)).getresult() - types = self.dbtypes - names = AttrDict((name[0], types.add(*name[1:])) for name in res) - attnames[table] = names # cache it - return names - - def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: - """Given the name of a table, dig out the set of generated columns. - - Returns a set of column names that are generated and unalterable. - - If flush is set, then the internal cache for generated columns will - be flushed. This may be necessary after the database schema or - the search path has been changed. - """ - generated = self._generated - if flush: - generated.clear() - self._do_debug('The generated cache has been flushed') - try: # cache lookup - names = generated[table] - except KeyError: # cache miss, check the database - cmd = "a.attnum OPERATOR(pg_catalog.>) 0" - cmd = f"{cmd} AND {self._query_generated}" - cmd = self._query_attnames.format( - _quote_if_unqualified('$1', table), cmd) - res = self._valid_db.query(cmd, (table,)).getresult() - names = frozenset(name[0] for name in res) - generated[table] = names # cache it - return names - - def use_regtypes(self, regtypes: bool | None = None) -> bool: - """Use registered type names instead of simplified type names.""" - if regtypes is None: - return self.dbtypes._regtypes - regtypes = bool(regtypes) - if regtypes != self.dbtypes._regtypes: - self.dbtypes._regtypes = regtypes - self._attnames.clear() - self.dbtypes.clear() - return regtypes - - def has_table_privilege(self, table: str, privilege: str = 'select', - flush: bool = False) -> bool: - """Check whether current user has specified table privilege. - - If flush is set, then the internal cache for table privileges will - be flushed. This may be necessary after privileges have been changed. - """ - privileges = self._privileges - if flush: - privileges.clear() - self._do_debug('The privileges cache has been flushed') - privilege = privilege.lower() - try: # ask cache - ret = privileges[table, privilege] - except KeyError: # cache miss, ask the database - cmd = "SELECT pg_catalog.has_table_privilege({}, $2)".format( - _quote_if_unqualified('$1', table)) - query = self._valid_db.query(cmd, (table, privilege)) - ret = query.singlescalar() == self._make_bool(True) - privileges[table, privilege] = ret # cache it - return ret - - def get(self, table: str, row: Any, - keyname: str | tuple[str, ...] | None = None) -> dict[str, Any]: - """Get a row from a database table or view. - - This method is the basic mechanism to get a single row. It assumes - that the keyname specifies a unique row. It must be the name of a - single column or a tuple of column names. If the keyname is not - specified, then the primary key for the table is used. - - If row is a dictionary, then the value for the key is taken from it. - Otherwise, the row must be a single value or a tuple of values - corresponding to the passed keyname or primary key. The fetched row - from the table will be returned as a new dictionary or used to replace - the existing values when row was passed as a dictionary. - - The OID is also put into the dictionary if the table has one, but - in order to allow the caller to work with multiple tables, it is - munged as "oid(table)" using the actual name of the table. - """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if keyname and isinstance(keyname, str): - keyname = (keyname,) - if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if not keyname: - try: # if keyname is not specified, try using the primary key - keyname = self.pkeys(table) - except KeyError as e: # the table has no primary key - # try using the oid instead - if qoid and isinstance(row, dict) and 'oid' in row: - keyname = ('oid',) - else: - raise _prg_error( - f'Table {table} has no primary key') from e - else: # the table has a primary key - # check whether all key columns have values - if isinstance(row, dict) and not set(keyname).issubset(row): - # try using the oid instead - if qoid and 'oid' in row: - keyname = ('oid',) - else: - raise KeyError( - 'Missing value in row for specified keyname') - if not isinstance(row, dict): - if not isinstance(row, (tuple, list)): - row = [row] - if len(keyname) != len(row): - raise KeyError( - 'Differing number of items in keyname and row') - row = dict(zip(keyname, row)) - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - what = 'oid, *' if qoid else '*' - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keyname) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - t = self._escape_qualified_name(table) - cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s - self._do_debug(cmd, params) - query = self._valid_db.query(cmd, params) - res = query.dictresult() - if not res: - # make where clause in error message better readable - where = where.replace('OPERATOR(pg_catalog.=)', '=') - raise _db_error( - f'No such record in {table}\nwhere {where}\nwith ' - + self._list_params(params)) - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - return row - - def insert(self, table: str, row: dict[str, Any] | None = None, **kw: Any - ) -> dict[str, Any]: - """Insert a row into a database table. - - This method inserts a row into a table. The name of the table must - be passed as the first parameter. The other parameters are used for - providing the data of the row that shall be inserted into the table. - If a dictionary is supplied as the second parameter, it starts with - that. Otherwise, it uses a blank dictionary. - Either way the dictionary is updated from the keywords. - - The dictionary is then reloaded with the values actually inserted in - order to pick up values modified by rules, triggers, etc. - """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - if row is None: - row = {} - row.update(kw) - if 'oid' in row: - del row['oid'] # do not insert oid - attnames = self.get_attnames(table) - generated = self.get_generated(table) - qoid = _oid_key(table) if 'oid' in attnames else None - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - name_list, value_list = [], [] - for n in attnames: - if n in row and n not in generated: - name_list.append(col(n)) - value_list.append(adapt(row[n], attnames[n])) - if not name_list: - raise _prg_error('No column found that can be inserted') - names, values = ', '.join(name_list), ', '.join(value_list) - ret = 'oid, *' if qoid else '*' - t = self._escape_qualified_name(table) - cmd = (f'INSERT INTO {t} ({names})' # noqa: S608 - f' VALUES ({values}) RETURNING {ret}') - self._do_debug(cmd, params) - query = self._valid_db.query(cmd, params) - res = query.dictresult() - if res: # this should always be true - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - return row - - def update(self, table: str, row: dict[str, Any] | None = None, **kw : Any - ) -> dict[str, Any]: - """Update an existing row in a database table. - - Similar to insert, but updates an existing row. The update is based - on the primary key of the table or the OID value as munged by get() - or passed as keyword. The OID will take precedence if provided, so - that it is possible to update the primary key itself. - - The dictionary is then modified to reflect any changes caused by the - update due to triggers, rules, default values, etc. - """ - if table.endswith('*'): - table = table[:-1].rstrip() # need parent table name - attnames = self.get_attnames(table) - generated = self.get_generated(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if row is None: - row = {} - elif 'oid' in row: - del row['oid'] # only accept oid key from named args for safety - row.update(kw) - if qoid and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if qoid and 'oid' in row: # try using the oid - keynames: tuple[str, ...] = ('oid',) - keyset = set(keynames) - else: # try using the primary key - try: - keynames = self.pkeys(table) - except KeyError as e: # the table has no primary key - raise _prg_error(f'Table {table} has no primary key') from e - keyset = set(keynames) - # check whether all key columns have values - if not keyset.issubset(row): - raise KeyError('Missing value for primary key in row') - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keynames) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - values_list = [] - for n in attnames: - if n in row and n not in keyset and n not in generated: - values_list.append(f'{col(n)} = {adapt(row[n], attnames[n])}') - if not values_list: - return row - values = ', '.join(values_list) - ret = 'oid, *' if qoid else '*' - t = self._escape_qualified_name(table) - cmd = (f'UPDATE {t} SET {values}' # noqa: S608 - f' WHERE {where} RETURNING {ret}') - self._do_debug(cmd, params) - query = self._valid_db.query(cmd, params) - res = query.dictresult() - if res: # may be empty when row does not exist - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - return row - - def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any - ) -> dict[str, Any]: - """Insert a row into a database table with conflict resolution. - - This method inserts a row into a table, but instead of raising a - ProgrammingError exception in case a row with the same primary key - already exists, an update will be executed instead. This will be - performed as a single atomic operation on the database, so race - conditions can be avoided. - - Like the insert method, the first parameter is the name of the - table and the second parameter can be used to pass the values to - be inserted as a dictionary. - - Unlike the insert und update statement, keyword parameters are not - used to modify the dictionary, but to specify which columns shall - be updated in case of a conflict, and in which way: - - A value of False or None means the column shall not be updated, - a value of True means the column shall be updated with the value - that has been proposed for insertion, i.e. has been passed as value - in the dictionary. Columns that are not specified by keywords but - appear as keys in the dictionary are also updated like in the case - keywords had been passed with the value True. - - So if in the case of a conflict you want to update every column - that has been passed in the dictionary row, you would call - upsert(table, row). If you don't want to do anything in case - of a conflict, i.e. leave the existing row as it is, call - upsert(table, row, **dict.fromkeys(row)). - - If you need more fine-grained control of what gets updated, you can - also pass strings in the keyword parameters. These strings will - be used as SQL expressions for the update columns. In these - expressions you can refer to the value that already exists in - the table by prefixing the column name with "included.", and to - the value that has been proposed for insertion by prefixing the - column name with the "excluded." - - The dictionary is modified in any case to reflect the values in - the database after the operation has completed. - - Note: The method uses the PostgreSQL "upsert" feature which is - only available since PostgreSQL 9.5. - """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - if row is None: - row = {} - if 'oid' in row: - del row['oid'] # do not insert oid - if 'oid' in kw: - del kw['oid'] # do not update oid - attnames = self.get_attnames(table) - generated = self.get_generated(table) - qoid = _oid_key(table) if 'oid' in attnames else None - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - name_list, value_list = [], [] - for n in attnames: - if n in row and n not in generated: - name_list.append(col(n)) - value_list.append(adapt(row[n], attnames[n])) - names, values = ', '.join(name_list), ', '.join(value_list) - try: - keynames = self.pkeys(table) - except KeyError as e: - raise _prg_error(f'Table {table} has no primary key') from e - target = ', '.join(col(k) for k in keynames) - update = [] - keyset = set(keynames) - keyset.add('oid') - for n in attnames: - if n not in keyset and n not in generated: - value = kw.get(n, n in row) - if value: - if not isinstance(value, str): - value = f'excluded.{col(n)}' - update.append(f'{col(n)} = {value}') - if not values: - return row - do = 'update set ' + ', '.join(update) if update else 'nothing' - ret = 'oid, *' if qoid else '*' - t = self._escape_qualified_name(table) - cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 - f' VALUES ({values})' - f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') - self._do_debug(cmd, params) - query = self._valid_db.query(cmd, params) - res = query.dictresult() - if res: # may be empty with "do nothing" - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - else: - self.get(table, row) - return row - - def clear(self, table: str, row: dict[str, Any] | None = None - ) -> dict[str, Any]: - """Clear all the attributes to values determined by the types. - - Numeric types are set to 0, Booleans are set to false, and everything - else is set to the empty string. If the row argument is present, - it is used as the row dictionary and any entries matching attribute - names are cleared with everything else left unchanged. - """ - # At some point we will need a way to get defaults from a table. - if row is None: - row = {} # empty if argument is not present - attnames = self.get_attnames(table) - for n, t in attnames.items(): - if n == 'oid': - continue - t = t.simple - if t in DbTypes._num_types: - row[n] = 0 - elif t == 'bool': - row[n] = self._make_bool(False) - else: - row[n] = '' - return row - - def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any - ) -> int: - """Delete an existing row in a database table. - - This method deletes the row from a table. It deletes based on the - primary key of the table or the OID value as munged by get() or - passed as keyword. The OID will take precedence if provided. - - The return value is the number of deleted rows (i.e. 0 if the row - did not exist and 1 if the row was deleted). - - Note that if the row cannot be deleted because e.g. it is still - referenced by another table, this method raises a ProgrammingError. - """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if row is None: - row = {} - elif 'oid' in row: - del row['oid'] # only accept oid key from named args for safety - row.update(kw) - if qoid and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if qoid and 'oid' in row: # try using the oid - keynames: tuple[str, ...] = ('oid',) - else: # try using the primary key - try: - keynames = self.pkeys(table) - except KeyError as e: # the table has no primary key - raise _prg_error(f'Table {table} has no primary key') from e - # check whether all key columns have values - if not set(keynames).issubset(row): - raise KeyError('Missing value for primary key in row') - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keynames) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - t = self._escape_qualified_name(table) - cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 - self._do_debug(cmd, params) - res = self._valid_db.query(cmd, params) - return int(res) # type: ignore - - def truncate(self, table: str | list[str] | tuple[str, ...] | - set[str] | frozenset[str], restart: bool = False, - cascade: bool = False, only: bool = False) -> Query: - """Empty a table or set of tables. - - This method quickly removes all rows from the given table or set - of tables. It has the same effect as an unqualified DELETE on each - table, but since it does not actually scan the tables it is faster. - Furthermore, it reclaims disk space immediately, rather than requiring - a subsequent VACUUM operation. This is most useful on large tables. - - If restart is set to True, sequences owned by columns of the truncated - table(s) are automatically restarted. If cascade is set to True, it - also truncates all tables that have foreign-key references to any of - the named tables. If the parameter 'only' is not set to True, all the - descendant tables (if any) will also be truncated. Optionally, a '*' - can be specified after the table name to explicitly indicate that - descendant tables are included. - """ - if isinstance(table, str): - table_only = {table: only} - table = [table] - elif isinstance(table, (list, tuple)): - if isinstance(only, (list, tuple)): - table_only = dict(zip(table, only)) - else: - table_only = dict.fromkeys(table, only) - elif isinstance(table, (set, frozenset)): - table_only = dict.fromkeys(table, only) - else: - raise TypeError('The table must be a string, list or set') - if not (restart is None or isinstance(restart, (bool, int))): - raise TypeError('Invalid type for the restart option') - if not (cascade is None or isinstance(cascade, (bool, int))): - raise TypeError('Invalid type for the cascade option') - tables = [] - for t in table: - u = table_only.get(t) - if not (u is None or isinstance(u, (bool, int))): - raise TypeError('Invalid type for the only option') - if t.endswith('*'): - if u: - raise ValueError( - 'Contradictory table name and only options') - t = t[:-1].rstrip() - t = self._escape_qualified_name(t) - if u: - t = f'ONLY {t}' - tables.append(t) - cmd_parts = ['TRUNCATE', ', '.join(tables)] - if restart: - cmd_parts.append('RESTART IDENTITY') - if cascade: - cmd_parts.append('CASCADE') - cmd = ' '.join(cmd_parts) - self._do_debug(cmd) - return self._valid_db.query(cmd) - - def get_as_list( - self, table: str, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | bool | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> list: - """Get a table as a list. - - This gets a convenient representation of the table as a list - of named tuples in Python. You only need to pass the name of - the table (or any other SQL expression returning rows). Note that - by default this will return the full content of the table which - can be huge and overflow your memory. However, you can control - the amount of data returned using the other optional parameters. - - The parameter 'what' can restrict the query to only return a - subset of the table columns. It can be a string, list or a tuple. - - The parameter 'where' can restrict the query to only return a - subset of the table rows. It can be a string, list or a tuple - of SQL expressions that all need to be fulfilled. - - The parameter 'order' specifies the ordering of the rows. It can - also be a string, list or a tuple. If no ordering is specified, - the result will be ordered by the primary key(s) or all columns if - no primary key exists. You can set 'order' to False if you don't - care about the ordering. The parameters 'limit' and 'offset' can be - integers specifying the maximum number of rows returned and a number - of rows skipped over. - - If you set the 'scalar' option to True, then instead of the - named tuples you will get the first items of these tuples. - This is useful if the result has only one column anyway. - """ - if not table: - raise TypeError('The table name is missing') - if what: - if isinstance(what, (list, tuple)): - what = ', '.join(map(str, what)) - if order is None: - order = what - else: - what = '*' - cmd_parts = ['SELECT', what, 'FROM', table] - if where: - if isinstance(where, (list, tuple)): - where = ' AND '.join(map(str, where)) - cmd_parts.extend(['WHERE', where]) - if order is None or order is True: - try: - order = self.pkeys(table) - except (KeyError, ProgrammingError): - with suppress(KeyError, ProgrammingError): - order = list(self.get_attnames(table)) - if order and not isinstance(order, bool): - if isinstance(order, (list, tuple)): - order = ', '.join(map(str, order)) - cmd_parts.extend(['ORDER BY', order]) - if limit: - cmd_parts.append(f'LIMIT {limit}') - if offset: - cmd_parts.append(f'OFFSET {offset}') - cmd = ' '.join(cmd_parts) - self._do_debug(cmd) - query = self._valid_db.query(cmd) - res = query.namedresult() - if res and scalar: - res = [row[0] for row in res] - return res - - def get_as_dict( - self, table: str, - keyname: str | list[str] | tuple[str, ...] | None = None, - what: str | list[str] | tuple[str, ...] | None = None, - where: str | list[str] | tuple[str, ...] | None = None, - order: str | list[str] | tuple[str, ...] | bool | None = None, - limit: int | None = None, offset: int | None = None, - scalar: bool = False) -> dict: - """Get a table as a dictionary. - - This method is similar to get_as_list(), but returns the table - as a Python dict instead of a Python list, which can be even - more convenient. The primary key column(s) of the table will - be used as the keys of the dictionary, while the other column(s) - will be the corresponding values. The keys will be named tuples - if the table has a composite primary key. The rows will be also - named tuples unless the 'scalar' option has been set to True. - With the optional parameter 'keyname' you can specify an alternative - set of columns to be used as the keys of the dictionary. It must - be set as a string, list or a tuple. - - The dictionary will be ordered using the order specified with the - 'order' parameter or the key column(s) if not specified. You can - set 'order' to False if you don't care about the ordering. - """ - if not table: - raise TypeError('The table name is missing') - if not keyname: - try: - keyname = self.pkeys(table) - except (KeyError, ProgrammingError) as e: - raise _prg_error(f'Table {table} has no primary key') from e - if isinstance(keyname, str): - keynames: list[str] | tuple[str, ...] = (keyname,) - elif isinstance(keyname, (list, tuple)): - keynames = keyname - else: - raise KeyError('The keyname must be a string, list or tuple') - if what: - if isinstance(what, (list, tuple)): - what = ', '.join(map(str, what)) - if order is None: - order = what - else: - what = '*' - cmd_parts = ['SELECT', what, 'FROM', table] - if where: - if isinstance(where, (list, tuple)): - where = ' AND '.join(map(str, where)) - cmd_parts.extend(['WHERE', where]) - if order is None or order is True: - order = keyname - if order and not isinstance(order, bool): - if isinstance(order, (list, tuple)): - order = ', '.join(map(str, order)) - cmd_parts.extend(['ORDER BY', order]) - if limit: - cmd_parts.append(f'LIMIT {limit}') - if offset: - cmd_parts.append(f'OFFSET {offset}') - cmd = ' '.join(cmd_parts) - self._do_debug(cmd) - query = self._valid_db.query(cmd) - res = query.getresult() - if not res: - return {} - keyset = set(keynames) - fields = query.listfields() - if not keyset.issubset(fields): - raise KeyError('Missing keyname in row') - key_index: list[int] = [] - row_index: list[int] = [] - for i, f in enumerate(fields): - (key_index if f in keyset else row_index).append(i) - key_tuple = len(key_index) > 1 - get_key = itemgetter(*key_index) - keys = map(get_key, res) - if scalar: - row_index = row_index[:1] - row_is_tuple = False - else: - row_is_tuple = len(row_index) > 1 - if scalar or row_is_tuple: - get_row: Callable[[tuple], tuple] = itemgetter( # pyright: ignore - *row_index) - else: - frst_index = row_index[0] - - def get_row(row : tuple) -> tuple: - return row[frst_index], # tuple with one item - - row_is_tuple = True - rows = map(get_row, res) - if key_tuple or row_is_tuple: - if key_tuple: - keys = _namediter(_MemoryQuery(keys, keynames)) # type: ignore - if row_is_tuple: - fields = tuple(f for f in fields if f not in keyset) - rows = _namediter(_MemoryQuery(rows, fields)) # type: ignore - # noinspection PyArgumentList - return dict(zip(keys, rows)) - - def notification_handler(self, event: str, callback: Callable, - arg_dict: dict | None = None, - timeout: int | float | None = None, - stop_event: str | None = None - ) -> NotificationHandler: - """Get notification handler that will run the given callback.""" - return NotificationHandler(self, event, callback, - arg_dict, timeout, stop_event) - - -# if run as script, print some information - -if __name__ == '__main__': - print('PyGreSQL version', version) - print() - print(__doc__) +init_core() diff --git a/pg/adapt.py b/pg/adapt.py new file mode 100644 index 00000000..fd4705ae --- /dev/null +++ b/pg/adapt.py @@ -0,0 +1,680 @@ +"""Adaption of parameters.""" + +from __future__ import annotations + +import weakref +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from json import dumps as jsonencode +from math import isinf, isnan +from re import compile as regex +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Sequence +from uuid import UUID + +from .attrs import AttrDict +from .cast import Typecasts +from .core import InterfaceError, ProgrammingError +from .helpers import quote_if_unqualified + +if TYPE_CHECKING: + from .db import DB + +__all__ = [ + 'Adapter', 'Bytea', 'DbType', 'DbTypes', + 'Hstore', 'Literal', 'Json', 'UUID' +] + + +class Bytea(bytes): + """Wrapper class for marking Bytea values.""" + + +class Hstore(dict): + """Wrapper class for marking hstore values.""" + + _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') + + @classmethod + def _quote(cls, s: Any) -> str: + if s is None: + return 'NULL' + if not isinstance(s, str): + s = str(s) + if not s: + return '""' + s = s.replace('"', '\\"') + if cls._re_quote.search(s): + s = f'"{s}"' + return s + + def __str__(self) -> str: + """Create a printable representation of the hstore value.""" + q = self._quote + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) + + +class Json: + """Wrapper class for marking Json values.""" + + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: + """Initialize the JSON object.""" + self.obj = obj + self.encode = encode or jsonencode + + def __str__(self) -> str: + """Create a printable representation of the JSON object.""" + obj = self.obj + if isinstance(obj, str): + return obj + return self.encode(obj) + + +class Literal(str): + """Wrapper class for marking literal SQL values.""" + + + +class _SimpleTypes(dict): + """Dictionary mapping pg_type names to simple type names. + + The corresponding Python types and simple names are also mapped. + """ + + _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ + 'bool': [bool], + 'bytea': [Bytea], + 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', + 'abstime', 'reltime', # these are very old + 'datetime', 'timedelta', # these do not really exist + date, time, datetime, timedelta], + 'float': ['float4', 'float8', float], + 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], + 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], + 'num': ['numeric', Decimal], 'money': [], + 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] + }) + + # noinspection PyMissingConstructor + def __init__(self) -> None: + """Initialize type mapping.""" + for typ, keys in self._type_aliases.items(): + keys = [typ, *keys] + for key in keys: + self[key] = typ + if isinstance(key, str): + self[f'_{key}'] = f'{typ}[]' + elif not isinstance(key, tuple): + self[List[key]] = f'{typ}[]' # type: ignore + + @staticmethod + def __missing__(key: str) -> str: + """Unmapped types are interpreted as text.""" + return 'text' + + def get_type_dict(self) -> dict[type, str]: + """Get a plain dictionary of only the types.""" + return {key: typ for key, typ in self.items() + if not isinstance(key, (str, tuple))} + + +_simpletypes = _SimpleTypes() +_simple_type_dict = _simpletypes.get_type_dict() + + +class _ParameterList(list): + """Helper class for building typed parameter lists.""" + + adapt: Callable + + def add(self, value: Any, typ:Any = None) -> str: + """Typecast value with known database type and build parameter list. + + If this is a literal value, it will be returned as is. Otherwise, a + placeholder will be returned and the parameter list will be augmented. + """ + # noinspection PyUnresolvedReferences + value = self.adapt(value, typ) + if isinstance(value, Literal): + return value + self.append(value) + return f'${len(self)}' + + + +class DbType(str): + """Class augmenting the simple type name with additional info. + + The following additional information is provided: + + oid: the PostgreSQL type OID + pgtype: the internal PostgreSQL data type name + regtype: the registered PostgreSQL data type name + simple: the more coarse-grained PyGreSQL type name + typlen: the internal size, negative if variable + typtype: b = base type, c = composite type etc. + category: A = Array, b = Boolean, C = Composite etc. + delim: delimiter for array types + relid: corresponding table for composite types + attnames: attributes for composite types + """ + + oid: int + pgtype: str + regtype: str + simple: str + typlen: int + typtype: str + category: str + delim: str + relid: int + + _get_attnames: Callable[[DbType], AttrDict] + + @property + def attnames(self) -> AttrDict: + """Get names and types of the fields of a composite type.""" + # noinspection PyUnresolvedReferences + return self._get_attnames(self) + + +class DbTypes(dict): + """Cache for PostgreSQL data types. + + This cache maps type OIDs and names to DbType objects containing + information on the associated database type. + """ + + _num_types = frozenset('int float num money int2 int4 int8' + ' float4 float8 numeric money'.split()) + + def __init__(self, db: DB) -> None: + """Initialize type cache for connection.""" + super().__init__() + self._db = weakref.proxy(db) + self._regtypes = False + self._typecasts = Typecasts() + self._typecasts.get_attnames = self.get_attnames # type: ignore + self._typecasts.connection = self._db.db + self._query_pg_type = ( + "SELECT oid, typname, oid::pg_catalog.regtype," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type" + " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") + + def add(self, oid: int, pgtype: str, regtype: str, + typlen: int, typtype: str, category: str, delim: str, relid: int + ) -> DbType: + """Create a PostgreSQL type name with additional info.""" + if oid in self: + return self[oid] + simple = 'record' if relid else _simpletypes[pgtype] + typ = DbType(regtype if self._regtypes else simple) + typ.oid = oid + typ.simple = simple + typ.pgtype = pgtype + typ.regtype = regtype + typ.typlen = typlen + typ.typtype = typtype + typ.category = category + typ.delim = delim + typ.relid = relid + typ._get_attnames = self.get_attnames # type: ignore + return typ + + def __missing__(self, key: int | str) -> DbType: + """Get the type info from the database if it is not cached.""" + try: + cmd = self._query_pg_type.format(quote_if_unqualified('$1', key)) + res = self._db.query(cmd, (key,)).getresult() + except ProgrammingError: + res = None + if not res: + raise KeyError(f'Type {key} could not be found') + res = res[0] + typ = self.add(*res) + self[typ.oid] = self[typ.pgtype] = typ + return typ + + def get(self, key: int | str, # type: ignore + default: DbType | None = None) -> DbType | None: + """Get the type even if it is not cached.""" + try: + return self[key] + except KeyError: + return default + + def get_attnames(self, typ: Any) -> AttrDict | None: + """Get names and types of the fields of a composite type.""" + if not isinstance(typ, DbType): + typ = self.get(typ) + if not typ: + return None + if not typ.relid: + return None + return self._db.get_attnames(typ.relid, with_oid=False) + + def get_typecast(self, typ: Any) -> Callable | None: + """Get the typecast function for the given database type.""" + return self._typecasts.get(typ) + + def set_typecast(self, typ: str | Sequence[str], cast: Callable) -> None: + """Set a typecast function for the specified database type(s).""" + self._typecasts.set(typ, cast) + + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecast function for the specified database type(s).""" + self._typecasts.reset(typ) + + def typecast(self, value: Any, typ: str) -> Any: + """Cast the given value according to the given database type.""" + if value is None: + # for NULL values, no typecast is necessary + return None + if not isinstance(typ, DbType): + db_type = self.get(typ) + if db_type: + typ = db_type.pgtype + cast = self.get_typecast(typ) if typ else None + if not cast or cast is str: + # no typecast is necessary + return value + return cast(value) + + +class Adapter: + """Class providing methods for adapting parameters to the database.""" + + _bool_true_values = frozenset('t true 1 y yes on'.split()) + + _date_literals = frozenset( + 'current_date current_time' + ' current_timestamp localtime localtimestamp'.split()) + + _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') + _re_record_quote = regex(r'[(,"\\]') + _re_array_escape = _re_record_escape = regex(r'(["\\])') + + def __init__(self, db: DB): + """Initialize the adapter object with the given connection.""" + self.db = weakref.proxy(db) + + @classmethod + def _adapt_bool(cls, v: Any) -> str | None: + """Adapt a boolean parameter.""" + if isinstance(v, str): + if not v: + return None + v = v.lower() in cls._bool_true_values + return 't' if v else 'f' + + @classmethod + def _adapt_date(cls, v: Any) -> Any: + """Adapt a date parameter.""" + if not v: + return None + if isinstance(v, str) and v.lower() in cls._date_literals: + return Literal(v) + return v + + @staticmethod + def _adapt_num(v: Any) -> Any: + """Adapt a numeric parameter.""" + if not v and v != 0: + return None + return v + + _adapt_int = _adapt_float = _adapt_money = _adapt_num + + def _adapt_bytea(self, v: Any) -> str: + """Adapt a bytea parameter.""" + return self.db.escape_bytea(v) + + def _adapt_json(self, v: Any) -> str | None: + """Adapt a json parameter.""" + if not v: + return None + if isinstance(v, str): + return v + if isinstance(v, Json): + return str(v) + return self.db.encode_json(v) + + def _adapt_hstore(self, v: Any) -> str | None: + """Adapt a hstore parameter.""" + if not v: + return None + if isinstance(v, str): + return v + if isinstance(v, Hstore): + return str(v) + if isinstance(v, dict): + return str(Hstore(v)) + raise TypeError(f'Hstore parameter {v} has wrong type') + + def _adapt_uuid(self, v: Any) -> str | None: + """Adapt a UUID parameter.""" + if not v: + return None + if isinstance(v, str): + return v + return str(v) + + @classmethod + def _adapt_text_array(cls, v: Any) -> str: + """Adapt a text type array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_text_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if v is None: + return 'null' + if not v: + return '""' + v = str(v) + if cls._re_array_quote.search(v): + v = cls._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' + return v + + _adapt_date_array = _adapt_text_array + + @classmethod + def _adapt_bool_array(cls, v: Any) -> str: + """Adapt a boolean array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_bool_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if v is None: + return 'null' + if isinstance(v, str): + if not v: + return 'null' + v = v.lower() in cls._bool_true_values + return 't' if v else 'f' + + @classmethod + def _adapt_num_array(cls, v: Any) -> str: + """Adapt a numeric array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_num_array + v = '{' + ','.join(adapt(v) for v in v) + '}' + if not v and v != 0: + return 'null' + return str(v) + + _adapt_int_array = _adapt_float_array = _adapt_money_array = \ + _adapt_num_array + + def _adapt_bytea_array(self, v: Any) -> bytes: + """Adapt a bytea array parameter.""" + if isinstance(v, list): + return b'{' + b','.join( + self._adapt_bytea_array(v) for v in v) + b'}' + if v is None: + return b'null' + return self.db.escape_bytea(v).replace(b'\\', b'\\\\') + + def _adapt_json_array(self, v: Any) -> str: + """Adapt a json array parameter.""" + if isinstance(v, list): + adapt = self._adapt_json_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if not v: + return 'null' + if not isinstance(v, str): + v = self.db.encode_json(v) + if self._re_array_quote.search(v): + v = self._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' + return v + + def _adapt_record(self, v: Any, typ: Any) -> str: + """Adapt a record parameter with given type.""" + typ = self.get_attnames(typ).values() + if len(typ) != len(v): + raise TypeError(f'Record parameter {v} has wrong size') + adapt = self.adapt + value = [] + for v, t in zip(v, typ): # noqa: B020 + v = adapt(v, t) + if v is None: + v = '' + else: + if isinstance(v, bytes): + v = v.decode('ascii') + elif not isinstance(v, str): + v = str(v) + if v: + if self._re_record_quote.search(v): + v = self._re_record_escape.sub(r'\\\1', v) + v = f'"{v}"' + else: + v = '""' + value.append(v) + v = ','.join(value) + return f'({v})' + + def adapt(self, value: Any, typ: Any = None) -> str: + """Adapt a value with known database type.""" + if value is not None and not isinstance(value, Literal): + if typ: + simple = self.get_simple_name(typ) + else: + typ = simple = self.guess_simple_type(value) or 'text' + pg_str = getattr(value, '__pg_str__', None) + if pg_str: + value = pg_str(typ) + if simple == 'text': + pass + elif simple == 'record': + if isinstance(value, tuple): + value = self._adapt_record(value, typ) + elif simple.endswith('[]'): + if isinstance(value, list): + adapt = getattr(self, f'_adapt_{simple[:-2]}_array') + value = adapt(value) + else: + adapt = getattr(self, f'_adapt_{simple}') + value = adapt(value) + return value + + @staticmethod + def simple_type(name: str) -> DbType: + """Create a simple database type with given attribute names.""" + typ = DbType(name) + typ.simple = name + return typ + + @staticmethod + def get_simple_name(typ: Any) -> str: + """Get the simple name of a database type.""" + if isinstance(typ, DbType): + # noinspection PyUnresolvedReferences + return typ.simple + return _simpletypes[typ] + + @staticmethod + def get_attnames(typ: Any) -> dict[str, dict[str, str]]: + """Get the attribute names of a composite database type.""" + if isinstance(typ, DbType): + return typ.attnames + return {} + + @classmethod + def guess_simple_type(cls, value: Any) -> str | None: + """Try to guess which database type the given value has.""" + # optimize for most frequent types + try: + return _simple_type_dict[type(value)] + except KeyError: + pass + if isinstance(value, (bytes, str)): + return 'text' + if isinstance(value, bool): + return 'bool' + if isinstance(value, int): + return 'int' + if isinstance(value, float): + return 'float' + if isinstance(value, Decimal): + return 'num' + if isinstance(value, (date, time, datetime, timedelta)): + return 'date' + if isinstance(value, Bytea): + return 'bytea' + if isinstance(value, Json): + return 'json' + if isinstance(value, Hstore): + return 'hstore' + if isinstance(value, UUID): + return 'uuid' + if isinstance(value, list): + return (cls.guess_simple_base_type(value) or 'text') + '[]' + if isinstance(value, tuple): + simple_type = cls.simple_type + guess = cls.guess_simple_type + + # noinspection PyUnusedLocal + def get_attnames(self: DbType) -> AttrDict: + return AttrDict((str(n + 1), simple_type(guess(v) or 'text')) + for n, v in enumerate(value)) + + typ = simple_type('record') + typ._get_attnames = get_attnames + return typ + return None + + @classmethod + def guess_simple_base_type(cls, value: Any) -> str | None: + """Try to guess the base type of a given array.""" + for v in value: + if isinstance(v, list): + typ = cls.guess_simple_base_type(v) + else: + typ = cls.guess_simple_type(v) + if typ: + return typ + return None + + def adapt_inline(self, value: Any, nested: bool=False) -> Any: + """Adapt a value that is put into the SQL and needs to be quoted.""" + if value is None: + return 'NULL' + if isinstance(value, Literal): + return value + if isinstance(value, Bytea): + value = self.db.escape_bytea(value).decode('ascii') + elif isinstance(value, (datetime, date, time, timedelta)): + value = str(value) + if isinstance(value, (bytes, str)): + value = self.db.escape_string(value) + return f"'{value}'" + if isinstance(value, bool): + return 'true' if value else 'false' + if isinstance(value, float): + if isinf(value): + return "'-Infinity'" if value < 0 else "'Infinity'" + if isnan(value): + return "'NaN'" + return value + if isinstance(value, (int, Decimal)): + return value + if isinstance(value, list): + q = self.adapt_inline + s = '[{}]' if nested else 'ARRAY[{}]' + return s.format(','.join(str(q(v, nested=True)) for v in value)) + if isinstance(value, tuple): + q = self.adapt_inline + return '({})'.format(','.join(str(q(v)) for v in value)) + if isinstance(value, Json): + value = self.db.escape_string(str(value)) + return f"'{value}'::json" + if isinstance(value, Hstore): + value = self.db.escape_string(str(value)) + return f"'{value}'::hstore" + pg_repr = getattr(value, '__pg_repr__', None) + if not pg_repr: + raise InterfaceError( + f'Do not know how to adapt type {type(value)}') + value = pg_repr() + if isinstance(value, (tuple, list)): + value = self.adapt_inline(value) + return value + + def parameter_list(self) -> _ParameterList: + """Return a parameter list for parameters with known database types. + + The list has an add(value, typ) method that will build up the + list and return either the literal value or a placeholder. + """ + params = _ParameterList() + params.adapt = self.adapt + return params + + def format_query(self, command: str, + values: list | tuple | dict | None = None, + types: list | tuple | dict | None = None, + inline: bool=False + ) -> tuple[str, _ParameterList]: + """Format a database query using the given values and types. + + The optional types describe the values and must be passed as a list, + tuple or string (that will be split on whitespace) when values are + passed as a list or tuple, or as a dict if values are passed as a dict. + + If inline is set to True, then parameters will be passed inline + together with the query string. + """ + params = self.parameter_list() + if not values: + return command, params + if inline and types: + raise ValueError('Typed parameters must be sent separately') + if isinstance(values, (list, tuple)): + if inline: + adapt = self.adapt_inline + seq_literals = [adapt(value) for value in values] + else: + add = params.add + if types: + if isinstance(types, str): + types = types.split() + if (not isinstance(types, (list, tuple)) + or len(types) != len(values)): + raise TypeError('The values and types do not match') + seq_literals = [add(value, typ) + for value, typ in zip(values, types)] + else: + seq_literals = [add(value) for value in values] + command %= tuple(seq_literals) + elif isinstance(values, dict): + # we want to allow extra keys in the dictionary, + # so we first must find the values actually used in the command + used_values = {} + map_literals = dict.fromkeys(values, '') + for key in values: + del map_literals[key] + try: + command % map_literals + except KeyError: + used_values[key] = values[key] # pyright: ignore + map_literals[key] = '' + if inline: + adapt = self.adapt_inline + map_literals = {key: adapt(value) + for key, value in used_values.items()} + else: + add = params.add + if types: + if not isinstance(types, dict): + raise TypeError('The values and types do not match') + map_literals = {key: add(used_values[key], types.get(key)) + for key in sorted(used_values)} + else: + map_literals = {key: add(used_values[key]) + for key in sorted(used_values)} + command %= map_literals + else: + raise TypeError('The values must be passed as tuple, list or dict') + return command, params diff --git a/pg/attrs.py b/pg/attrs.py new file mode 100644 index 00000000..7a5e6c41 --- /dev/null +++ b/pg/attrs.py @@ -0,0 +1,35 @@ +"""Helpers for memorizing attributes.""" + +from typing import Any + +__all__ = ['AttrDict'] + + +class AttrDict(dict): + """Simple read-only ordered dictionary for storing attribute names.""" + + def __init__(self, *args: Any, **kw: Any) -> None: + """Initialize the dictionary.""" + self._read_only = False + super().__init__(*args, **kw) + self._read_only = True + error = self._read_only_error + self.clear = self.update = error # type: ignore + self.pop = self.setdefault = self.popitem = error # type: ignore + + def __setitem__(self, key: str, value: Any) -> None: + """Set a value.""" + if self._read_only: + self._read_only_error() + super().__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + """Delete a value.""" + if self._read_only: + self._read_only_error() + super().__delitem__(key) + + @staticmethod + def _read_only_error(*_args: Any, **_kw: Any) -> Any: + """Raise error for write operations.""" + raise TypeError('This object is read-only') diff --git a/pg/cast.py b/pg/cast.py new file mode 100644 index 00000000..ad1758be --- /dev/null +++ b/pg/cast.py @@ -0,0 +1,436 @@ +"""Typecasting mechanisms.""" + +from __future__ import annotations + +from collections import namedtuple +from datetime import date, datetime, timedelta +from functools import partial +from inspect import signature +from re import compile as regex +from typing import Any, Callable, ClassVar, Sequence +from uuid import UUID + +from .attrs import AttrDict +from .core import ( + Connection, + cast_array, + cast_hstore, + cast_record, + get_bool, + get_decimal, + get_decimal_point, + get_jsondecode, + unescape_bytea, +) +from .tz import timezone_as_offset + +__all__ = [ + 'cast_bool', 'cast_json', 'cast_num', 'cast_money', 'cast_int2vector', + 'cast_date', 'cast_time', 'cast_timetz', 'cast_interval', + 'cast_timestamp','cast_timestamptz', + 'Typecasts', 'get_typecast', 'set_typecast' +] + +def get_args(func: Callable) -> list: + """Get the arguments of a function.""" + return list(signature(func).parameters) + + +def cast_bool(value: str) -> Any: + """Cast a boolean value.""" + if not get_bool(): + return value + return value[0] == 't' + + +def cast_json(value: str) -> Any: + """Cast a JSON value.""" + cast = get_jsondecode() + if not cast: + return value + return cast(value) + + +def cast_num(value: str) -> Any: + """Cast a numeric value.""" + return (get_decimal() or float)(value) + + +def cast_money(value: str) -> Any: + """Cast a money value.""" + point = get_decimal_point() + if not point: + return value + if point != '.': + value = value.replace(point, '.') + value = value.replace('(', '-') + value = ''.join(c for c in value if c.isdigit() or c in '.-') + return (get_decimal() or float)(value) + + +def cast_int2vector(value: str) -> list[int]: + """Cast an int2vector value.""" + return [int(v) for v in value.split()] + + +def cast_date(value: str, connection: Connection) -> Any: + """Cast a date value.""" + # The output format depends on the server setting DateStyle. The default + # setting ISO and the setting for German are actually unambiguous. The + # order of days and months in the other two settings is however ambiguous, + # so at least here we need to consult the setting to properly parse values. + if value == '-infinity': + return date.min + if value == 'infinity': + return date.max + values = value.split() + if values[-1] == 'BC': + return date.min + value = values[0] + if len(value) > 10: + return date.max + format = connection.date_format() + return datetime.strptime(value, format).date() + + +def cast_time(value: str) -> Any: + """Cast a time value.""" + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, format).time() + + +_re_timezone = regex('(.*)([+-].*)') + + +def cast_timetz(value: str) -> Any: + """Cast a timetz value.""" + m = _re_timezone.match(value) + if m: + value, tz = m.groups() + else: + tz = '+0000' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + value += timezone_as_offset(tz) + format += '%z' + return datetime.strptime(value, format).timetz() + + +def cast_timestamp(value: str, connection: Connection) -> Any: + """Cast a timestamp value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + else: + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +def cast_timestamptz(value: str, connection: Connection) -> Any: + """Cast a timestamptz value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] + else: + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() + else: + tz = '+0000' + else: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +_re_interval_sql_standard = regex( + '(?:([+-])?([0-9]+)-([0-9]+) ?)?' + '(?:([+-]?[0-9]+)(?!:) ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres = regex( + '(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres_verbose = regex( + '@ ?(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-]?[0-9]+) ?hours? ?)?' + '(?:([+-]?[0-9]+) ?mins? ?)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') + +_re_interval_iso_8601 = regex( + 'P(?:([+-]?[0-9]+)Y)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-]?[0-9]+)D)?' + '(?:T(?:([+-]?[0-9]+)H)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') + + +def cast_interval(value: str) -> timedelta: + """Cast an interval value.""" + # The output format depends on the server setting IntervalStyle, but it's + # not necessary to consult this setting to parse it. It's faster to just + # check all possible formats, and there is no ambiguity here. + m = _re_interval_iso_8601.match(value) + if m: + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = -secs + usecs = -usecs + else: + m = _re_interval_postgres_verbose.match(value) + if m: + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = - secs + usecs = -usecs + else: + m = _re_interval_postgres.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + m = _re_interval_sql_standard.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if years_ago: + years = -years + mons = -mons + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + raise ValueError(f'Cannot parse interval: {value}') + days += 365 * years + 30 * mons + return timedelta(days=days, hours=hours, minutes=mins, + seconds=secs, microseconds=usecs) + + +class Typecasts(dict): + """Dictionary mapping database types to typecast functions. + + The cast functions get passed the string representation of a value in + the database which they need to convert to a Python object. The + passed string will never be None since NULL values are already + handled before the cast function is called. + + Note that the basic types are already handled by the C extension. + They only need to be handled here as record or array components. + """ + + # the default cast functions + # (str functions are ignored but have been added for faster access) + defaults: ClassVar[dict[str, Callable]] = { + 'char': str, 'bpchar': str, 'name': str, + 'text': str, 'varchar': str, 'sql_identifier': str, + 'bool': cast_bool, 'bytea': unescape_bytea, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, + 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, + 'float4': float, 'float8': float, + 'numeric': cast_num, 'money': cast_money, + 'date': cast_date, 'interval': cast_interval, + 'time': cast_time, 'timetz': cast_timetz, + 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, + 'int2vector': cast_int2vector, 'uuid': UUID, + 'anyarray': cast_array, 'record': cast_record} # pyright: ignore + + connection: Connection | None = None # set in connection specific instance + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached. + + Note that this class never raises a KeyError, + but returns None when no special cast function exists. + """ + if not isinstance(typ, str): + raise TypeError(f'Invalid type: {typ}') + cast: Callable | None = self.defaults.get(typ) + if cast: + # store default for faster access + cast = self._add_connection(cast) + self[typ] = cast + elif typ.startswith('_'): + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + self[typ] = cast + else: + attnames = self.get_attnames(typ) + if attnames: + casts = [self[v.pgtype] for v in attnames.values()] + cast = self.create_record_cast(typ, attnames, casts) + self[typ] = cast + return cast + + @staticmethod + def _needs_connection(func: Callable) -> bool: + """Check if a typecast function needs a connection argument.""" + try: + args = get_args(func) + except (TypeError, ValueError): + return False + return 'connection' in args[1:] + + def _add_connection(self, cast: Callable) -> Callable: + """Add a connection argument to the typecast function if necessary.""" + if not self.connection or not self._needs_connection(cast): + return cast + return partial(cast, connection=self.connection) + + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: + """Get the typecast function for the given database type.""" + return self[typ] or default + + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + if isinstance(typ, str): + typ = [typ] + if cast is None: + for t in typ: + self.pop(t, None) + self.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + self[t] = self._add_connection(cast) + self.pop(f'_{t}', None) + + def reset(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecasts for the specified type(s) to their defaults. + + When no type is specified, all typecasts will be reset. + """ + if typ is None: + self.clear() + else: + if isinstance(typ, str): + typ = [typ] + for t in typ: + self.pop(t, None) + + @classmethod + def get_default(cls, typ: str) -> Any: + """Get the default typecast function for the given database type.""" + return cls.defaults.get(typ) + + @classmethod + def set_default(cls, typ: str | Sequence[str], + cast: Callable | None) -> None: + """Set a default typecast function for the given database type(s).""" + if isinstance(typ, str): + typ = [typ] + defaults = cls.defaults + if cast is None: + for t in typ: + defaults.pop(t, None) + defaults.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + defaults[t] = cast + defaults.pop(f'_{t}', None) + + # noinspection PyMethodMayBeStatic,PyUnusedLocal + def get_attnames(self, typ: Any) -> AttrDict: + """Return the fields for the given record type. + + This method will be replaced with the get_attnames() method of DbTypes. + """ + return AttrDict() + + # noinspection PyMethodMayBeStatic + def dateformat(self) -> str: + """Return the current date format. + + This method will be replaced with the dateformat() method of DbTypes. + """ + return '%Y-%m-%d' + + def create_array_cast(self, basecast: Callable) -> Callable: + """Create an array typecast for the given base cast.""" + cast_array = self['anyarray'] + + def cast(v: Any) -> list: + return cast_array(v, basecast) + return cast + + def create_record_cast(self, name: str, fields: AttrDict, + casts: list[Callable]) -> Callable: + """Create a named record typecast for the given fields and casts.""" + cast_record = self['record'] + record = namedtuple(name, fields) # type: ignore + + def cast(v: Any) -> record: + # noinspection PyArgumentList + return record(*cast_record(v, casts)) + return cast + + +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" + return Typecasts.get_default(typ) + + +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a global typecast function for the given database type(s). + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call db.db_types.reset_typecast(). + """ + Typecasts.set_default(typ, cast) diff --git a/pg/core.py b/pg/core.py new file mode 100644 index 00000000..3eb8f745 --- /dev/null +++ b/pg/core.py @@ -0,0 +1,135 @@ +"""Core functionality from extension module.""" + +try: + from ._pg import version +except ImportError as e: # noqa: F841 + import os + libpq = 'libpq.' + if os.name == 'nt': + libpq += 'dll' + import sys + paths = [path for path in os.environ["PATH"].split(os.pathsep) + if os.path.exists(os.path.join(path, libpq))] + if sys.version_info >= (3, 8): + # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore + for path in paths: + with add_dll_dir(os.path.abspath(path)): + try: + from ._pg import version + except ImportError: + pass + else: + del version + e = None # type: ignore + break + if paths: + libpq = 'compatible ' + libpq + else: + libpq += 'so' + if e: + raise ImportError( + "Cannot import shared library for PyGreSQL,\n" + f"probably because no {libpq} is installed.\n{e}") from e +else: + del version + +# import objects from extension module +from ._pg import ( + INV_READ, + INV_WRITE, + POLLING_FAILED, + POLLING_OK, + POLLING_READING, + POLLING_WRITING, + RESULT_DDL, + RESULT_DML, + RESULT_DQL, + RESULT_EMPTY, + SEEK_CUR, + SEEK_END, + SEEK_SET, + TRANS_ACTIVE, + TRANS_IDLE, + TRANS_INERROR, + TRANS_INTRANS, + TRANS_UNKNOWN, + Connection, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + InvalidResultError, + MultipleResultsError, + NoResultError, + NotSupportedError, + OperationalError, + ProgrammingError, + Query, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + escape_bytea, + escape_string, + get_array, + get_bool, + get_bytea_escaped, + get_datestyle, + get_decimal, + get_decimal_point, + get_defbase, + get_defhost, + get_defopt, + get_defport, + get_defuser, + get_jsondecode, + get_pqlib_version, + set_array, + set_bool, + set_bytea_escaped, + set_datestyle, + set_decimal, + set_decimal_point, + set_defbase, + set_defhost, + set_defopt, + set_defpasswd, + set_defport, + set_defuser, + set_jsondecode, + set_query_helpers, + unescape_bytea, + version, +) + +__all__ = [ + 'Error', 'Warning', + 'DataError', 'DatabaseError', + 'IntegrityError', 'InterfaceError', 'InternalError', + 'InvalidResultError', 'MultipleResultsError', + 'NoResultError', 'NotSupportedError', + 'OperationalError', 'ProgrammingError', + 'Connection', 'Query', + 'INV_READ', 'INV_WRITE', + 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', + 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', + 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', + 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', + 'TRANS_INTRANS', 'TRANS_UNKNOWN', + 'cast_array', 'cast_hstore', 'cast_record', + 'connect', 'escape_bytea', 'escape_string', 'unescape_bytea', + 'get_array', 'get_bool', 'get_bytea_escaped', + 'get_datestyle', 'get_decimal', 'get_decimal_point', + 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', + 'get_jsondecode', 'get_pqlib_version', + 'set_array', 'set_bool', 'set_bytea_escaped', + 'set_datestyle', 'set_decimal', 'set_decimal_point', + 'set_defbase', 'set_defhost', 'set_defopt', + 'set_defpasswd', 'set_defport', 'set_defuser', + 'set_jsondecode', 'set_query_helpers', + 'version', +] diff --git a/pg/db.py b/pg/db.py new file mode 100644 index 00000000..ce7915f8 --- /dev/null +++ b/pg/db.py @@ -0,0 +1,1332 @@ +"""Connection wrapper.""" + +from __future__ import annotations + +from contextlib import suppress +from json import dumps as jsonencode +from json import loads as jsondecode +from operator import itemgetter +from typing import Any, Callable, Iterator, Sequence + +from . import Connection, connect +from .adapt import Adapter, DbTypes +from .attrs import AttrDict +from .core import ( + InternalError, + ProgrammingError, + Query, + get_bool, + get_jsondecode, + unescape_bytea, +) +from .error import db_error, int_error, prg_error +from .helpers import namediter, oid_key, quote_if_unqualified +from .notify import NotificationHandler + +__all__ = ['DB'] + +# The actual PostgreSQL database connection interface: + +class DB: + """Wrapper class for the _pg connection type.""" + + db: Connection | None = None # invalid fallback for underlying connection + _db_args: Any # either the connect args or the underlying connection + + def __init__(self, *args: Any, **kw: Any) -> None: + """Create a new connection. + + You can pass either the connection parameters or an existing + _pg or pgdb connection. This allows you to use the methods + of the classic pg interface with a DB-API 2 pgdb connection. + """ + if not args and len(kw) == 1: + db = kw.get('db') + elif not kw and len(args) == 1: + db = args[0] + else: + db = None + if db: + if isinstance(db, DB): + db = db.db + else: + with suppress(AttributeError): + # noinspection PyUnresolvedReferences + db = db._cnx + if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): + db = connect(*args, **kw) + self._db_args = args, kw + self._closeable = True + else: + self._db_args = db + self._closeable = False + self.db = db + self.dbname = db.db + self._regtypes = False + self._attnames: dict[str, AttrDict] = {} + self._generated: dict[str, frozenset[str]] = {} + self._pkeys: dict[str, str | tuple[str, ...]] = {} + self._privileges: dict[tuple[str, str], bool] = {} + self.adapter = Adapter(self) + self.dbtypes = DbTypes(self) + self._query_attnames = ( + "SELECT a.attname," + " t.oid, t.typname, t.oid::pg_catalog.regtype," + " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" + " FROM pg_catalog.pg_attribute a" + " JOIN pg_catalog.pg_type t" + " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" + " WHERE a.attrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND {} AND NOT a.attisdropped ORDER BY a.attnum") + if db.server_version < 120000: + self._query_generated = ( + "a.attidentity OPERATOR(pg_catalog.=) 'a'" + ) + else: + self._query_generated = ( + "(a.attidentity OPERATOR(pg_catalog.=) 'a' OR" + " a.attgenerated OPERATOR(pg_catalog.!=) '')" + ) + db.set_cast_hook(self.dbtypes.typecast) + # For debugging scripts, self.debug can be set + # * to a string format specification (e.g. in CGI set to "%s
"), + # * to a file object to write debug statements or + # * to a callable object which takes a string argument + # * to any other true value to just print debug statements + self.debug: Any = None + + def __getattr__(self, name: str) -> Any: + """Get the specified attritbute of the connection.""" + # All undefined members are same as in underlying connection: + if self.db: + return getattr(self.db, name) + else: + raise int_error('Connection is not valid') + + def __dir__(self) -> list[str]: + """List all attributes of the connection.""" + # Custom dir function including the attributes of the connection: + attrs = set(self.__class__.__dict__) + attrs.update(self.__dict__) + attrs.update(dir(self.db)) + return sorted(attrs) + + # Context manager methods + + def __enter__(self) -> DB: + """Enter the runtime context. This will start a transaction.""" + self.begin() + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context. This will end the transaction.""" + if et is None and ev is None and tb is None: + self.commit() + else: + self.rollback() + + def __del__(self) -> None: + """Delete the connection.""" + try: + db = self.db + except AttributeError: + db = None + if db: + with suppress(TypeError): # when already closed + db.set_cast_hook(None) + if self._closeable: + with suppress(InternalError): # when already closed + db.close() + + # Auxiliary methods + + def _do_debug(self, *args: Any) -> None: + """Print a debug message.""" + if self.debug: + s = '\n'.join(str(arg) for arg in args) + if isinstance(self.debug, str): + print(self.debug % s) + elif hasattr(self.debug, 'write'): + # noinspection PyCallingNonCallable + self.debug.write(s + '\n') + elif callable(self.debug): + self.debug(s) + else: + print(s) + + def _escape_qualified_name(self, s: str) -> str: + """Escape a qualified name. + + Escapes the name for use as an SQL identifier, unless the + name contains a dot, in which case the name is ambiguous + (could be a qualified name or just a name with a dot in it) + and must be quoted manually by the caller. + """ + if '.' not in s: + s = self.escape_identifier(s) + return s + + @staticmethod + def _make_bool(d: Any) -> bool | str: + """Get boolean value corresponding to d.""" + return bool(d) if get_bool() else ('t' if d else 'f') + + @staticmethod + def _list_params(params: Sequence) -> str: + """Create a human readable parameter list.""" + return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) + + @property + def _valid_db(self) -> Connection: + """Get underlying connection and make sure it is not closed.""" + db = self.db + if not db: + raise int_error('Connection already closed') + return db + + # Public methods + + # escape_string and escape_bytea exist as methods, + # so we define unescape_bytea as a method as well + unescape_bytea = staticmethod(unescape_bytea) + + @staticmethod + def decode_json(s: str) -> Any: + """Decode a JSON string coming from the database.""" + return (get_jsondecode() or jsondecode)(s) + + @staticmethod + def encode_json(d: Any) -> str: + """Encode a JSON string for use within SQL.""" + return jsonencode(d) + + def close(self) -> None: + """Close the database connection.""" + # Wraps shared library function so we can track state. + db = self._valid_db + with suppress(TypeError): # when already closed + db.set_cast_hook(None) + if self._closeable: + db.close() + self.db = None + + def reset(self) -> None: + """Reset connection with current parameters. + + All derived queries and large objects derived from this connection + will not be usable after this call. + """ + self._valid_db.reset() + + def reopen(self) -> None: + """Reopen connection to the database. + + Used in case we need another connection to the same database. + Note that we can still reopen a database that we have closed. + """ + # There is no such shared library function. + if self._closeable: + args, kw = self._db_args + db = connect(*args, **kw) + if self.db: + self.db.set_cast_hook(None) + self.db.close() + db.set_cast_hook(self.dbtypes.typecast) + self.db = db + else: + self.db = self._db_args + + def begin(self, mode: str | None = None) -> Query: + """Begin a transaction.""" + qstr = 'BEGIN' + if mode: + qstr += ' ' + mode + return self.query(qstr) + + start = begin + + def commit(self) -> Query: + """Commit the current transaction.""" + return self.query('COMMIT') + + end = commit + + def rollback(self, name: str | None = None) -> Query: + """Roll back the current transaction.""" + qstr = 'ROLLBACK' + if name: + qstr += ' TO ' + name + return self.query(qstr) + + abort = rollback + + def savepoint(self, name: str) -> Query: + """Define a new savepoint within the current transaction.""" + return self.query('SAVEPOINT ' + name) + + def release(self, name: str) -> Query: + """Destroy a previously defined savepoint.""" + return self.query('RELEASE ' + name) + + def get_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any] + ) -> str | list[str] | dict[str, str]: + """Get the value of a run-time parameter. + + If the parameter is a string, the return value will also be a string + that is the current setting of the run-time parameter with that name. + + You can get several parameters at once by passing a list, set or dict. + When passing a list of parameter names, the return value will be a + corresponding list of parameter settings. When passing a set of + parameter names, a new dict will be returned, mapping these parameter + names to their settings. Finally, if you pass a dict as parameter, + its values will be set to the current parameter settings corresponding + to its keys. + + By passing the special name 'all' as the parameter, you can get a dict + of all existing configuration parameters. + """ + values: Any + if isinstance(parameter, str): + parameter = [parameter] + values = None + elif isinstance(parameter, (list, tuple)): + values = [] + elif isinstance(parameter, (set, frozenset)): + values = {} + elif isinstance(parameter, dict): + values = parameter + else: + raise TypeError( + 'The parameter must be a string, list, set or dict') + if not parameter: + raise TypeError('No parameter has been specified') + query = self._valid_db.query + params: Any = {} if isinstance(values, dict) else [] + for param_key in parameter: + param = param_key.strip().lower() if isinstance( + param_key, (bytes, str)) else None + if not param: + raise TypeError('Invalid parameter') + if param == 'all': + cmd = 'SHOW ALL' + values = query(cmd).getresult() + values = {value[0]: value[1] for value in values} + break + if isinstance(params, dict): + params[param] = param_key + else: + params.append(param) + else: + for param in params: + cmd = f'SHOW {param}' + value = query(cmd).singlescalar() + if values is None: + values = value + elif isinstance(values, list): + values.append(value) + else: + values[params[param]] = value + return values + + def set_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any], + value: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str]| None = None, + local: bool = False) -> None: + """Set the value of a run-time parameter. + + If the parameter and the value are strings, the run-time parameter + will be set to that value. If no value or None is passed as a value, + then the run-time parameter will be restored to its default value. + + You can set several parameters at once by passing a list of parameter + names, together with a single value that all parameters should be + set to or with a corresponding list of values. You can also pass + the parameters as a set if you only provide a single value. + Finally, you can pass a dict with parameter names as keys. In this + case, you should not pass a value, since the values for the parameters + will be taken from the dict. + + By passing the special name 'all' as the parameter, you can reset + all existing settable run-time parameters to their default values. + + If you set local to True, then the command takes effect for only the + current transaction. After commit() or rollback(), the session-level + setting takes effect again. Setting local to True will appear to + have no effect if it is executed outside a transaction, since the + transaction will end immediately. + """ + if isinstance(parameter, str): + parameter = {parameter: value} + elif isinstance(parameter, (list, tuple)): + if isinstance(value, (list, tuple)): + parameter = dict(zip(parameter, value)) + else: + parameter = dict.fromkeys(parameter, value) + elif isinstance(parameter, (set, frozenset)): + if isinstance(value, (list, tuple, set, frozenset)): + value = set(value) + if len(value) == 1: + value = next(iter(value)) + if not (value is None or isinstance(value, str)): + raise ValueError( + 'A single value must be specified' + ' when parameter is a set') + parameter = dict.fromkeys(parameter, value) + elif isinstance(parameter, dict): + if value is not None: + raise ValueError( + 'A value must not be specified' + ' when parameter is a dictionary') + else: + raise TypeError( + 'The parameter must be a string, list, set or dict') + if not parameter: + raise TypeError('No parameter has been specified') + params: dict[str, str | None] = {} + for param, param_value in parameter.items(): + param = param.strip().lower() if isinstance(param, str) else None + if not param: + raise TypeError('Invalid parameter') + if param == 'all': + if param_value is not None: + raise ValueError( + 'A value must not be specified' + " when parameter is 'all'") + params = {'all': None} + break + params[param] = param_value + local_clause = ' LOCAL' if local else '' + for param, param_value in params.items(): + cmd = (f'RESET{local_clause} {param}' + if param_value is None else + f'SET{local_clause} {param} TO {param_value}') + self._do_debug(cmd) + self._valid_db.query(cmd) + + def query(self, command: str, *args: Any) -> Query: + """Execute a SQL command string. + + This method simply sends a SQL query to the database. If the query is + an insert statement that inserted exactly one row into a table that + has OIDs, the return value is the OID of the newly inserted row. + If the query is an update or delete statement, or an insert statement + that did not insert exactly one row in a table with OIDs, then the + number of rows affected is returned as a string. If it is a statement + that returns rows as a result (usually a select statement, but maybe + also an "insert/update ... returning" statement), this method returns + a Query object that can be accessed via getresult() or dictresult() + or simply printed. Otherwise, it returns `None`. + + The query can contain numbered parameters of the form $1 in place + of any data constant. Arguments given after the query string will + be substituted for the corresponding numbered parameter. Parameter + values can also be given as a single list or tuple argument. + """ + # Wraps shared library function for debugging. + db = self._valid_db + if args: + self._do_debug(command, args) + return db.query(command, args) + self._do_debug(command) + return db.query(command) + + def query_formatted(self, command: str, + parameters: tuple | list | dict | None = None, + types: tuple | list | dict | None = None, + inline: bool =False) -> Query: + """Execute a formatted SQL command string. + + Similar to query, but using Python format placeholders of the form + %s or %(names)s instead of PostgreSQL placeholders of the form $1. + The parameters must be passed as a tuple, list or dict. You can + also pass a corresponding tuple, list or dict of database types in + order to format the parameters properly in case there is ambiguity. + + If you set inline to True, the parameters will be sent to the database + embedded in the SQL command, otherwise they will be sent separately. + """ + return self.query(*self.adapter.format_query( + command, parameters, types, inline)) + + def query_prepared(self, name: str, *args: Any) -> Query: + """Execute a prepared SQL statement. + + This works like the query() method, except that instead of passing + the SQL command, you pass the name of a prepared statement. If you + pass an empty name, the unnamed statement will be executed. + """ + if name is None: + name = '' + db = self._valid_db + if args: + self._do_debug('EXECUTE', name, args) + return db.query_prepared(name, args) + self._do_debug('EXECUTE', name) + return db.query_prepared(name) + + def prepare(self, name: str, command: str) -> None: + """Create a prepared SQL statement. + + This creates a prepared statement for the given command with the + given name for later execution with the query_prepared() method. + + The name can be empty to create an unnamed statement, in which case + any pre-existing unnamed statement is automatically replaced; + otherwise it is an error if the statement name is already + defined in the current database session. We recommend always using + named queries, since unnamed queries have a limited lifetime and + can be automatically replaced or destroyed by various operations. + """ + if name is None: + name = '' + self._do_debug('prepare', name, command) + self._valid_db.prepare(name, command) + + def describe_prepared(self, name: str | None = None) -> Query: + """Describe a prepared SQL statement. + + This method returns a Query object describing the result columns of + the prepared statement with the given name. If you omit the name, + the unnamed statement will be described if you created one before. + """ + if name is None: + name = '' + return self._valid_db.describe_prepared(name) + + def delete_prepared(self, name: str | None = None) -> Query: + """Delete a prepared SQL statement. + + This deallocates a previously prepared SQL statement with the given + name, or deallocates all prepared statements if you do not specify a + name. Note that prepared statements are also deallocated automatically + when the current session ends. + """ + if not name: + name = 'ALL' + cmd = f"DEALLOCATE {name}" + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def pkey(self, table: str, composite: bool = False, flush: bool = False + ) -> str | tuple[str, ...]: + """Get the primary key of a table. + + Single primary keys are returned as strings unless you + set the composite flag. Composite primary keys are always + represented as tuples. Note that this raises a KeyError + if the table does not have a primary key. + + If flush is set then the internal cache for primary keys will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + pkeys = self._pkeys + if flush: + pkeys.clear() + self._do_debug('The pkey cache has been flushed') + try: # cache lookup + pkey = pkeys[table] + except KeyError as e: # cache miss, check the database + cmd = ("SELECT" # noqa: S608 + " a.attname, a.attnum, i.indkey" + " FROM pg_catalog.pg_index i" + " JOIN pg_catalog.pg_attribute a" + " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" + " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" + " AND NOT a.attisdropped" + " WHERE i.indrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND i.indisprimary ORDER BY a.attnum").format( + quote_if_unqualified('$1', table)) + res = self._valid_db.query(cmd, (table,)).getresult() + if not res: + raise KeyError(f'Table {table} has no primary key') from e + # we want to use the order defined in the primary key index here, + # not the order as defined by the columns in the table + if len(res) > 1: + indkey = res[0][2] + pkey = tuple(row[0] for row in sorted( + res, key=lambda row: indkey.index(row[1]))) + else: + pkey = res[0][0] + pkeys[table] = pkey # cache it + if composite and not isinstance(pkey, tuple): + pkey = (pkey,) + return pkey + + def pkeys(self, table: str) -> tuple[str, ...]: + """Get the primary key of a table as a tuple. + + Same as pkey() with 'composite' set to True. + """ + return self.pkey(table, True) # type: ignore + + def get_databases(self) -> list[str]: + """Get list of databases in the system.""" + return [r[0] for r in self._valid_db.query( + 'SELECT datname FROM pg_catalog.pg_database').getresult()] + + def get_relations(self, kinds: str | Sequence[str] | None = None, + system: bool = False) -> list[str]: + """Get list of relations in connected database of specified kinds. + + If kinds is None or empty, all kinds of relations are returned. + Otherwise, kinds can be a string or sequence of type letters + specifying which kind of relations you want to list. + + Set the system flag if you want to get the system relations as well. + """ + where_parts = [] + if kinds: + where_parts.append( + "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) + if not system: + where_parts.append("s.nspname NOT SIMILAR" + " TO 'pg/_%|information/_schema' ESCAPE '/'") + where = " WHERE " + ' AND '.join(where_parts) if where_parts else '' + cmd = ("SELECT" # noqa: S608 + " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" + " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" + " FROM pg_catalog.pg_class r" + " JOIN pg_catalog.pg_namespace s" + f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" + " ORDER BY s.nspname, r.relname") + return [r[0] for r in self._valid_db.query(cmd).getresult()] + + def get_tables(self, system: bool = False) -> list[str]: + """Return list of tables in connected database. + + Set the system flag if you want to get the system tables as well. + """ + return self.get_relations('r', system) + + def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False + ) -> AttrDict: + """Given the name of a table, dig out the set of attribute names. + + Returns a read-only dictionary of attribute names (the names are + the keys, the values are the names of the attributes' types) + with the column names in the proper order if you iterate over it. + + If flush is set, then the internal cache for attribute names will + be flushed. This may be necessary after the database schema or + the search path has been changed. + + By default, only a limited number of simple types will be returned. + You can get the registered types after calling use_regtypes(True). + """ + attnames = self._attnames + if flush: + attnames.clear() + self._do_debug('The attnames cache has been flushed') + try: # cache lookup + names = attnames[table] + except KeyError: # cache miss, check the database + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + if with_oid: + cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" + cmd = self._query_attnames.format( + quote_if_unqualified('$1', table), cmd) + res = self._valid_db.query(cmd, (table,)).getresult() + types = self.dbtypes + names = AttrDict((name[0], types.add(*name[1:])) for name in res) + attnames[table] = names # cache it + return names + + def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: + """Given the name of a table, dig out the set of generated columns. + + Returns a set of column names that are generated and unalterable. + + If flush is set, then the internal cache for generated columns will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + generated = self._generated + if flush: + generated.clear() + self._do_debug('The generated cache has been flushed') + try: # cache lookup + names = generated[table] + except KeyError: # cache miss, check the database + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + cmd = f"{cmd} AND {self._query_generated}" + cmd = self._query_attnames.format( + quote_if_unqualified('$1', table), cmd) + res = self._valid_db.query(cmd, (table,)).getresult() + names = frozenset(name[0] for name in res) + generated[table] = names # cache it + return names + + def use_regtypes(self, regtypes: bool | None = None) -> bool: + """Use registered type names instead of simplified type names.""" + if regtypes is None: + return self.dbtypes._regtypes + regtypes = bool(regtypes) + if regtypes != self.dbtypes._regtypes: + self.dbtypes._regtypes = regtypes + self._attnames.clear() + self.dbtypes.clear() + return regtypes + + def has_table_privilege(self, table: str, privilege: str = 'select', + flush: bool = False) -> bool: + """Check whether current user has specified table privilege. + + If flush is set, then the internal cache for table privileges will + be flushed. This may be necessary after privileges have been changed. + """ + privileges = self._privileges + if flush: + privileges.clear() + self._do_debug('The privileges cache has been flushed') + privilege = privilege.lower() + try: # ask cache + ret = privileges[table, privilege] + except KeyError: # cache miss, ask the database + cmd = "SELECT pg_catalog.has_table_privilege({}, $2)".format( + quote_if_unqualified('$1', table)) + query = self._valid_db.query(cmd, (table, privilege)) + ret = query.singlescalar() == self._make_bool(True) + privileges[table, privilege] = ret # cache it + return ret + + def get(self, table: str, row: Any, + keyname: str | tuple[str, ...] | None = None) -> dict[str, Any]: + """Get a row from a database table or view. + + This method is the basic mechanism to get a single row. It assumes + that the keyname specifies a unique row. It must be the name of a + single column or a tuple of column names. If the keyname is not + specified, then the primary key for the table is used. + + If row is a dictionary, then the value for the key is taken from it. + Otherwise, the row must be a single value or a tuple of values + corresponding to the passed keyname or primary key. The fetched row + from the table will be returned as a new dictionary or used to replace + the existing values when row was passed as a dictionary. + + The OID is also put into the dictionary if the table has one, but + in order to allow the caller to work with multiple tables, it is + munged as "oid(table)" using the actual name of the table. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + attnames = self.get_attnames(table) + qoid = oid_key(table) if 'oid' in attnames else None + if keyname and isinstance(keyname, str): + keyname = (keyname,) + if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if not keyname: + try: # if keyname is not specified, try using the primary key + keyname = self.pkeys(table) + except KeyError as e: # the table has no primary key + # try using the oid instead + if qoid and isinstance(row, dict) and 'oid' in row: + keyname = ('oid',) + else: + raise prg_error( + f'Table {table} has no primary key') from e + else: # the table has a primary key + # check whether all key columns have values + if isinstance(row, dict) and not set(keyname).issubset(row): + # try using the oid instead + if qoid and 'oid' in row: + keyname = ('oid',) + else: + raise KeyError( + 'Missing value in row for specified keyname') + if not isinstance(row, dict): + if not isinstance(row, (tuple, list)): + row = [row] + if len(keyname) != len(row): + raise KeyError( + 'Differing number of items in keyname and row') + row = dict(zip(keyname, row)) + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + what = 'oid, *' if qoid else '*' + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( + col(k), adapt(row[k], attnames[k])) for k in keyname) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + t = self._escape_qualified_name(table) + cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if not res: + # make where clause in error message better readable + where = where.replace('OPERATOR(pg_catalog.=)', '=') + raise db_error( + f'No such record in {table}\nwhere {where}\nwith ' + + self._list_params(params)) + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + return row + + def insert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: + """Insert a row into a database table. + + This method inserts a row into a table. The name of the table must + be passed as the first parameter. The other parameters are used for + providing the data of the row that shall be inserted into the table. + If a dictionary is supplied as the second parameter, it starts with + that. Otherwise, it uses a blank dictionary. + Either way the dictionary is updated from the keywords. + + The dictionary is then reloaded with the values actually inserted in + order to pick up values modified by rules, triggers, etc. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + if row is None: + row = {} + row.update(kw) + if 'oid' in row: + del row['oid'] # do not insert oid + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + name_list, value_list = [], [] + for n in attnames: + if n in row and n not in generated: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + if not name_list: + raise prg_error('No column found that can be inserted') + names, values = ', '.join(name_list), ', '.join(value_list) + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'INSERT INTO {t} ({names})' # noqa: S608 + f' VALUES ({values}) RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # this should always be true + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + return row + + def update(self, table: str, row: dict[str, Any] | None = None, **kw : Any + ) -> dict[str, Any]: + """Update an existing row in a database table. + + Similar to insert, but updates an existing row. The update is based + on the primary key of the table or the OID value as munged by get() + or passed as keyword. The OID will take precedence if provided, so + that it is possible to update the primary key itself. + + The dictionary is then modified to reflect any changes caused by the + update due to triggers, rules, default values, etc. + """ + if table.endswith('*'): + table = table[:-1].rstrip() # need parent table name + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + if row is None: + row = {} + elif 'oid' in row: + del row['oid'] # only accept oid key from named args for safety + row.update(kw) + if qoid and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if qoid and 'oid' in row: # try using the oid + keynames: tuple[str, ...] = ('oid',) + keyset = set(keynames) + else: # try using the primary key + try: + keynames = self.pkeys(table) + except KeyError as e: # the table has no primary key + raise prg_error(f'Table {table} has no primary key') from e + keyset = set(keynames) + # check whether all key columns have values + if not keyset.issubset(row): + raise KeyError('Missing value for primary key in row') + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( + col(k), adapt(row[k], attnames[k])) for k in keynames) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + values_list = [] + for n in attnames: + if n in row and n not in keyset and n not in generated: + values_list.append(f'{col(n)} = {adapt(row[n], attnames[n])}') + if not values_list: + return row + values = ', '.join(values_list) + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'UPDATE {t} SET {values}' # noqa: S608 + f' WHERE {where} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # may be empty when row does not exist + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + return row + + def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: + """Insert a row into a database table with conflict resolution. + + This method inserts a row into a table, but instead of raising a + ProgrammingError exception in case a row with the same primary key + already exists, an update will be executed instead. This will be + performed as a single atomic operation on the database, so race + conditions can be avoided. + + Like the insert method, the first parameter is the name of the + table and the second parameter can be used to pass the values to + be inserted as a dictionary. + + Unlike the insert und update statement, keyword parameters are not + used to modify the dictionary, but to specify which columns shall + be updated in case of a conflict, and in which way: + + A value of False or None means the column shall not be updated, + a value of True means the column shall be updated with the value + that has been proposed for insertion, i.e. has been passed as value + in the dictionary. Columns that are not specified by keywords but + appear as keys in the dictionary are also updated like in the case + keywords had been passed with the value True. + + So if in the case of a conflict you want to update every column + that has been passed in the dictionary row, you would call + upsert(table, row). If you don't want to do anything in case + of a conflict, i.e. leave the existing row as it is, call + upsert(table, row, **dict.fromkeys(row)). + + If you need more fine-grained control of what gets updated, you can + also pass strings in the keyword parameters. These strings will + be used as SQL expressions for the update columns. In these + expressions you can refer to the value that already exists in + the table by prefixing the column name with "included.", and to + the value that has been proposed for insertion by prefixing the + column name with the "excluded." + + The dictionary is modified in any case to reflect the values in + the database after the operation has completed. + + Note: The method uses the PostgreSQL "upsert" feature which is + only available since PostgreSQL 9.5. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + if row is None: + row = {} + if 'oid' in row: + del row['oid'] # do not insert oid + if 'oid' in kw: + del kw['oid'] # do not update oid + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + name_list, value_list = [], [] + for n in attnames: + if n in row and n not in generated: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + names, values = ', '.join(name_list), ', '.join(value_list) + try: + keynames = self.pkeys(table) + except KeyError as e: + raise prg_error(f'Table {table} has no primary key') from e + target = ', '.join(col(k) for k in keynames) + update = [] + keyset = set(keynames) + keyset.add('oid') + for n in attnames: + if n not in keyset and n not in generated: + value = kw.get(n, n in row) + if value: + if not isinstance(value, str): + value = f'excluded.{col(n)}' + update.append(f'{col(n)} = {value}') + if not values: + return row + do = 'update set ' + ', '.join(update) if update else 'nothing' + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 + f' VALUES ({values})' + f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # may be empty with "do nothing" + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + else: + self.get(table, row) + return row + + def clear(self, table: str, row: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Clear all the attributes to values determined by the types. + + Numeric types are set to 0, Booleans are set to false, and everything + else is set to the empty string. If the row argument is present, + it is used as the row dictionary and any entries matching attribute + names are cleared with everything else left unchanged. + """ + # At some point we will need a way to get defaults from a table. + if row is None: + row = {} # empty if argument is not present + attnames = self.get_attnames(table) + for n, t in attnames.items(): + if n == 'oid': + continue + t = t.simple + if t in DbTypes._num_types: + row[n] = 0 + elif t == 'bool': + row[n] = self._make_bool(False) + else: + row[n] = '' + return row + + def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> int: + """Delete an existing row in a database table. + + This method deletes the row from a table. It deletes based on the + primary key of the table or the OID value as munged by get() or + passed as keyword. The OID will take precedence if provided. + + The return value is the number of deleted rows (i.e. 0 if the row + did not exist and 1 if the row was deleted). + + Note that if the row cannot be deleted because e.g. it is still + referenced by another table, this method raises a ProgrammingError. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + attnames = self.get_attnames(table) + qoid = oid_key(table) if 'oid' in attnames else None + if row is None: + row = {} + elif 'oid' in row: + del row['oid'] # only accept oid key from named args for safety + row.update(kw) + if qoid and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if qoid and 'oid' in row: # try using the oid + keynames: tuple[str, ...] = ('oid',) + else: # try using the primary key + try: + keynames = self.pkeys(table) + except KeyError as e: # the table has no primary key + raise prg_error(f'Table {table} has no primary key') from e + # check whether all key columns have values + if not set(keynames).issubset(row): + raise KeyError('Missing value for primary key in row') + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( + col(k), adapt(row[k], attnames[k])) for k in keynames) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + t = self._escape_qualified_name(table) + cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 + self._do_debug(cmd, params) + res = self._valid_db.query(cmd, params) + return int(res) # type: ignore + + def truncate(self, table: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str], restart: bool = False, + cascade: bool = False, only: bool = False) -> Query: + """Empty a table or set of tables. + + This method quickly removes all rows from the given table or set + of tables. It has the same effect as an unqualified DELETE on each + table, but since it does not actually scan the tables it is faster. + Furthermore, it reclaims disk space immediately, rather than requiring + a subsequent VACUUM operation. This is most useful on large tables. + + If restart is set to True, sequences owned by columns of the truncated + table(s) are automatically restarted. If cascade is set to True, it + also truncates all tables that have foreign-key references to any of + the named tables. If the parameter 'only' is not set to True, all the + descendant tables (if any) will also be truncated. Optionally, a '*' + can be specified after the table name to explicitly indicate that + descendant tables are included. + """ + if isinstance(table, str): + table_only = {table: only} + table = [table] + elif isinstance(table, (list, tuple)): + if isinstance(only, (list, tuple)): + table_only = dict(zip(table, only)) + else: + table_only = dict.fromkeys(table, only) + elif isinstance(table, (set, frozenset)): + table_only = dict.fromkeys(table, only) + else: + raise TypeError('The table must be a string, list or set') + if not (restart is None or isinstance(restart, (bool, int))): + raise TypeError('Invalid type for the restart option') + if not (cascade is None or isinstance(cascade, (bool, int))): + raise TypeError('Invalid type for the cascade option') + tables = [] + for t in table: + u = table_only.get(t) + if not (u is None or isinstance(u, (bool, int))): + raise TypeError('Invalid type for the only option') + if t.endswith('*'): + if u: + raise ValueError( + 'Contradictory table name and only options') + t = t[:-1].rstrip() + t = self._escape_qualified_name(t) + if u: + t = f'ONLY {t}' + tables.append(t) + cmd_parts = ['TRUNCATE', ', '.join(tables)] + if restart: + cmd_parts.append('RESTART IDENTITY') + if cascade: + cmd_parts.append('CASCADE') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def get_as_list( + self, table: str, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> list: + """Get a table as a list. + + This gets a convenient representation of the table as a list + of named tuples in Python. You only need to pass the name of + the table (or any other SQL expression returning rows). Note that + by default this will return the full content of the table which + can be huge and overflow your memory. However, you can control + the amount of data returned using the other optional parameters. + + The parameter 'what' can restrict the query to only return a + subset of the table columns. It can be a string, list or a tuple. + + The parameter 'where' can restrict the query to only return a + subset of the table rows. It can be a string, list or a tuple + of SQL expressions that all need to be fulfilled. + + The parameter 'order' specifies the ordering of the rows. It can + also be a string, list or a tuple. If no ordering is specified, + the result will be ordered by the primary key(s) or all columns if + no primary key exists. You can set 'order' to False if you don't + care about the ordering. The parameters 'limit' and 'offset' can be + integers specifying the maximum number of rows returned and a number + of rows skipped over. + + If you set the 'scalar' option to True, then instead of the + named tuples you will get the first items of these tuples. + This is useful if the result has only one column anyway. + """ + if not table: + raise TypeError('The table name is missing') + if what: + if isinstance(what, (list, tuple)): + what = ', '.join(map(str, what)) + if order is None: + order = what + else: + what = '*' + cmd_parts = ['SELECT', what, 'FROM', table] + if where: + if isinstance(where, (list, tuple)): + where = ' AND '.join(map(str, where)) + cmd_parts.extend(['WHERE', where]) + if order is None or order is True: + try: + order = self.pkeys(table) + except (KeyError, ProgrammingError): + with suppress(KeyError, ProgrammingError): + order = list(self.get_attnames(table)) + if order and not isinstance(order, bool): + if isinstance(order, (list, tuple)): + order = ', '.join(map(str, order)) + cmd_parts.extend(['ORDER BY', order]) + if limit: + cmd_parts.append(f'LIMIT {limit}') + if offset: + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.namedresult() + if res and scalar: + res = [row[0] for row in res] + return res + + def get_as_dict( + self, table: str, + keyname: str | list[str] | tuple[str, ...] | None = None, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> dict: + """Get a table as a dictionary. + + This method is similar to get_as_list(), but returns the table + as a Python dict instead of a Python list, which can be even + more convenient. The primary key column(s) of the table will + be used as the keys of the dictionary, while the other column(s) + will be the corresponding values. The keys will be named tuples + if the table has a composite primary key. The rows will be also + named tuples unless the 'scalar' option has been set to True. + With the optional parameter 'keyname' you can specify an alternative + set of columns to be used as the keys of the dictionary. It must + be set as a string, list or a tuple. + + The dictionary will be ordered using the order specified with the + 'order' parameter or the key column(s) if not specified. You can + set 'order' to False if you don't care about the ordering. + """ + if not table: + raise TypeError('The table name is missing') + if not keyname: + try: + keyname = self.pkeys(table) + except (KeyError, ProgrammingError) as e: + raise prg_error(f'Table {table} has no primary key') from e + if isinstance(keyname, str): + keynames: list[str] | tuple[str, ...] = (keyname,) + elif isinstance(keyname, (list, tuple)): + keynames = keyname + else: + raise KeyError('The keyname must be a string, list or tuple') + if what: + if isinstance(what, (list, tuple)): + what = ', '.join(map(str, what)) + if order is None: + order = what + else: + what = '*' + cmd_parts = ['SELECT', what, 'FROM', table] + if where: + if isinstance(where, (list, tuple)): + where = ' AND '.join(map(str, where)) + cmd_parts.extend(['WHERE', where]) + if order is None or order is True: + order = keyname + if order and not isinstance(order, bool): + if isinstance(order, (list, tuple)): + order = ', '.join(map(str, order)) + cmd_parts.extend(['ORDER BY', order]) + if limit: + cmd_parts.append(f'LIMIT {limit}') + if offset: + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.getresult() + if not res: + return {} + keyset = set(keynames) + fields = query.listfields() + if not keyset.issubset(fields): + raise KeyError('Missing keyname in row') + key_index: list[int] = [] + row_index: list[int] = [] + for i, f in enumerate(fields): + (key_index if f in keyset else row_index).append(i) + key_tuple = len(key_index) > 1 + get_key = itemgetter(*key_index) + keys = map(get_key, res) + if scalar: + row_index = row_index[:1] + row_is_tuple = False + else: + row_is_tuple = len(row_index) > 1 + if scalar or row_is_tuple: + get_row: Callable[[tuple], tuple] = itemgetter( # pyright: ignore + *row_index) + else: + frst_index = row_index[0] + + def get_row(row : tuple) -> tuple: + return row[frst_index], # tuple with one item + + row_is_tuple = True + rows = map(get_row, res) + if key_tuple or row_is_tuple: + if key_tuple: + keys = namediter(_MemoryQuery(keys, keynames)) # type: ignore + if row_is_tuple: + fields = tuple(f for f in fields if f not in keyset) + rows = namediter(_MemoryQuery(rows, fields)) # type: ignore + # noinspection PyArgumentList + return dict(zip(keys, rows)) + + def notification_handler(self, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None + ) -> NotificationHandler: + """Get notification handler that will run the given callback.""" + return NotificationHandler(self, event, callback, + arg_dict, timeout, stop_event) + + +class _MemoryQuery: + """Class that embodies a given query result.""" + + result: Any + fields: tuple[str, ...] + + def __init__(self, result: Any, fields: Sequence[str]) -> None: + """Create query from given result rows and field names.""" + self.result = result + self.fields = tuple(fields) + + def listfields(self) -> tuple[str, ...]: + """Return the stored field names of this query.""" + return self.fields + + def getresult(self) -> Any: + """Return the stored result of this query.""" + return self.result + + def __iter__(self) -> Iterator[Any]: + return iter(self.result) \ No newline at end of file diff --git a/pg/error.py b/pg/error.py new file mode 100644 index 00000000..b3164b42 --- /dev/null +++ b/pg/error.py @@ -0,0 +1,35 @@ +"""Error helpers.""" + +from __future__ import annotations + +from typing import TypeVar + +from .core import DatabaseError, Error, InternalError, ProgrammingError + +__all__ = ['error', 'db_error', 'int_error', 'prg_error'] + +# Error messages + +E = TypeVar('E', bound=Error) + +def error(msg: str, cls: type[E]) -> E: + """Return specified error object with empty sqlstate attribute.""" + error = cls(msg) + if isinstance(error, DatabaseError): + error.sqlstate = None + return error + + +def db_error(msg: str) -> DatabaseError: + """Return DatabaseError.""" + return error(msg, DatabaseError) + + +def int_error(msg: str) -> InternalError: + """Return InternalError.""" + return error(msg, InternalError) + + +def prg_error(msg: str) -> ProgrammingError: + """Return ProgrammingError.""" + return error(msg, ProgrammingError) \ No newline at end of file diff --git a/pg/helpers.py b/pg/helpers.py new file mode 100644 index 00000000..4426cfbc --- /dev/null +++ b/pg/helpers.py @@ -0,0 +1,98 @@ +"""Helper functions.""" + +from __future__ import annotations + +from collections import namedtuple +from decimal import Decimal +from functools import lru_cache +from json import loads as jsondecode +from typing import Any, Callable, Generator, NamedTuple, Sequence + +from .core import Query, set_decimal, set_jsondecode, set_query_helpers + +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + +__all__ = [ + 'quote_if_unqualified', 'oid_key', 'set_row_factory_size', + 'dictiter', 'namediter', 'namednext', 'scalariter' +] + + +# Small helper functions + +def quote_if_unqualified(param: str, name: int | str) -> str: + """Quote parameter representing a qualified name. + + Puts a quote_ident() call around the given parameter unless + the name contains a dot, in which case the name is ambiguous + (could be a qualified name or just a name with a dot in it) + and must be quoted manually by the caller. + """ + if isinstance(name, str) and '.' not in name: + return f'quote_ident({param})' + return param + +def oid_key(table: str) -> str: + """Build oid key from a table name.""" + return f'oid({table})' + + +# Row factory + +# The result rows for database operations are returned as named tuples +# by default. Since creating namedtuple classes is a somewhat expensive +# operation, we cache up to 1024 of these classes by default. + +@lru_cache(maxsize=1024) +def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: + """Get a namedtuple factory for row results with the given names.""" + try: + return namedtuple('Row', names, rename=True)._make # type: ignore + except ValueError: # there is still a problem with the field names + names = [f'column_{n}' for n in range(len(names))] + return namedtuple('Row', names)._make # type: ignore + + +def set_row_factory_size(maxsize: int | None) -> None: + """Change the size of the namedtuple factory cache. + + If maxsize is set to None, the cache can grow without bound. + """ + global _row_factory + _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) + + +# Helper functions used by the query object + +def dictiter(q: Query) -> Generator[dict[str, Any], None, None]: + """Get query result as an iterator of dictionaries.""" + fields: tuple[str, ...] = q.listfields() + for r in q: + yield dict(zip(fields, r)) + + +def namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: + """Get query result as an iterator of named tuples.""" + row = _row_factory(q.listfields()) + for r in q: + yield row(r) + + +def namednext(q: Query) -> SomeNamedTuple: + """Get next row from query result as a named tuple.""" + return _row_factory(q.listfields())(next(q)) + + +def scalariter(q: Query) -> Generator[Any, None, None]: + """Get query result as an iterator of scalar values.""" + for r in q: + yield r[0] + + +# Initialization + +def init_core() -> None: + """Initialize the C extension module.""" + set_decimal(Decimal) + set_jsondecode(jsondecode) + set_query_helpers(dictiter, namediter, namednext, scalariter) diff --git a/pg/notify.py b/pg/notify.py new file mode 100644 index 00000000..e273c521 --- /dev/null +++ b/pg/notify.py @@ -0,0 +1,149 @@ +"""Handling of notifications.""" + +from __future__ import annotations + +import select +from typing import TYPE_CHECKING, Callable + +from .core import Query +from .error import db_error + +if TYPE_CHECKING: + from .db import DB + +__all__ = ['NotificationHandler'] + +# The notification handler + +class NotificationHandler: + """A PostgreSQL client-side asynchronous notification handler.""" + + def __init__(self, db: DB, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None): + """Initialize the notification handler. + + You must pass a PyGreSQL database connection, the name of an + event (notification channel) to listen for and a callback function. + + You can also specify a dictionary arg_dict that will be passed as + the single argument to the callback function, and a timeout value + in seconds (a floating point number denotes fractions of seconds). + If it is absent or None, the callers will never time out. If the + timeout is reached, the callback function will be called with a + single argument that is None. If you set the timeout to zero, + the handler will poll notifications synchronously and return. + + You can specify the name of the event that will be used to signal + the handler to stop listening as stop_event. By default, it will + be the event name prefixed with 'stop_'. + """ + self.db: DB | None = db + self.event = event + self.stop_event = stop_event or f'stop_{event}' + self.listening = False + self.callback = callback + if arg_dict is None: + arg_dict = {} + self.arg_dict = arg_dict + self.timeout = timeout + + def __del__(self) -> None: + """Delete the notification handler.""" + self.unlisten() + + def close(self) -> None: + """Stop listening and close the connection.""" + if self.db: + self.unlisten() + self.db.close() + self.db = None + + def listen(self) -> None: + """Start listening for the event and the stop event.""" + db = self.db + if db and not self.listening: + db.query(f'listen "{self.event}"') + db.query(f'listen "{self.stop_event}"') + self.listening = True + + def unlisten(self) -> None: + """Stop listening for the event and the stop event.""" + db = self.db + if db and self.listening: + db.query(f'unlisten "{self.event}"') + db.query(f'unlisten "{self.stop_event}"') + self.listening = False + + def notify(self, db: DB | None = None, stop: bool = False, + payload: str | None = None) -> Query | None: + """Generate a notification. + + Optionally, you can pass a payload with the notification. + + If you set the stop flag, a stop notification will be sent that + will cause the handler to stop listening. + + Note: If the notification handler is running in another thread, you + must pass a different database connection since PyGreSQL database + connections are not thread-safe. + """ + if not self.listening: + return None + if not db: + db = self.db + if not db: + return None + event = self.stop_event if stop else self.event + cmd = f'notify "{event}"' + if payload: + cmd += f", '{payload}'" + return db.query(cmd) + + def __call__(self) -> None: + """Invoke the notification handler. + + The handler is a loop that listens for notifications on the event + and stop event channels. When either of these notifications are + received, its associated 'pid', 'event' and 'extra' (the payload + passed with the notification) are inserted into its arg_dict + dictionary and the callback is invoked with this dictionary as + a single argument. When the handler receives a stop event, it + stops listening to both events and return. + + In the special case that the timeout of the handler has been set + to zero, the handler will poll all events synchronously and return. + If will keep listening until it receives a stop event. + + Note: If you run this loop in another thread, don't use the same + database connection for database operations in the main thread. + """ + if not self.db: + return + self.listen() + poll = self.timeout == 0 + rlist = [] if poll else [self.db.fileno()] + while self.db and self.listening: + # noinspection PyUnboundLocalVariable + if poll or select.select(rlist, [], [], self.timeout)[0]: + while self.db and self.listening: + notice = self.db.getnotify() + if not notice: # no more messages + break + event, pid, extra = notice + if event not in (self.event, self.stop_event): + self.unlisten() + raise db_error( + f'Listening for "{self.event}"' + f' and "{self.stop_event}",' + f' but notified of "{event}"') + if event == self.stop_event: + self.unlisten() + self.arg_dict.update(pid=pid, event=event, extra=extra) + self.callback(self.arg_dict) + if poll: + break + else: # we timed out + self.unlisten() + self.callback(None) \ No newline at end of file diff --git a/pg/tz.py b/pg/tz.py new file mode 100644 index 00000000..7f22e049 --- /dev/null +++ b/pg/tz.py @@ -0,0 +1,21 @@ +"""Timezone helpers.""" + +from __future__ import annotations + +__all__ = ['timezone_as_offset'] + +# time zones used in Postgres timestamptz output +_timezone_offsets: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} + + +def timezone_as_offset(tz: str) -> str: + """Convert timezone abbreviation to offset.""" + if tz.startswith(('+', '-')): + if len(tz) < 5: + return tz + '00' + return tz.replace(':', '') + return _timezone_offsets.get(tz, '+0000') \ No newline at end of file diff --git a/tests/test_classic_attrdict.py b/tests/test_classic_attrdict.py new file mode 100644 index 00000000..8eef00df --- /dev/null +++ b/tests/test_classic_attrdict.py @@ -0,0 +1,100 @@ +#!/usr/bin/python + +"""Test the classic PyGreSQL interface. + +Sub-tests for the DB wrapper object. + +Contributed by Christoph Zwerschke. + +These tests need a database to test against. +""" + +import unittest + +import pg.attrs # the module under test + + +class TestAttrDict(unittest.TestCase): + """Test the simple ordered dictionary for attribute names.""" + + cls = pg.attrs.AttrDict + + def test_init(self): + a = self.cls() + self.assertIsInstance(a, dict) + self.assertEqual(a, {}) + items = [('id', 'int'), ('name', 'text')] + a = self.cls(items) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) + iteritems = iter(items) + a = self.cls(iteritems) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) + + def test_iter(self): + a = self.cls() + self.assertEqual(list(a), []) + keys = ['id', 'name', 'age'] + items = [(key, None) for key in keys] + a = self.cls(items) + self.assertEqual(list(a), keys) + + def test_keys(self): + a = self.cls() + self.assertEqual(list(a.keys()), []) + keys = ['id', 'name', 'age'] + items = [(key, None) for key in keys] + a = self.cls(items) + self.assertEqual(list(a.keys()), keys) + + def test_values(self): + a = self.cls() + self.assertEqual(list(a.values()), []) + items = [('id', 'int'), ('name', 'text')] + values = [item[1] for item in items] + a = self.cls(items) + self.assertEqual(list(a.values()), values) + + def test_items(self): + a = self.cls() + self.assertEqual(list(a.items()), []) + items = [('id', 'int'), ('name', 'text')] + a = self.cls(items) + self.assertEqual(list(a.items()), items) + + def test_get(self): + a = self.cls([('id', 1)]) + try: + self.assertEqual(a['id'], 1) + except KeyError: + self.fail('AttrDict should be readable') + + def test_set(self): + a = self.cls() + try: + a['id'] = 1 + except TypeError: + pass + else: + self.fail('AttrDict should be read-only') + + def test_del(self): + a = self.cls([('id', 1)]) + try: + del a['id'] + except TypeError: + pass + else: + self.fail('AttrDict should be read-only') + + def test_write_methods(self): + a = self.cls([('id', 1)]) + self.assertEqual(a['id'], 1) + for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': + method = getattr(a, method) + self.assertRaises(TypeError, method, a) # type: ignore + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index d6a742bf..eca64afd 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2353,7 +2353,7 @@ def test_get_decimal_point(self): self.assertIsNone(r) def test_set_decimal_point(self): - d = pg.Decimal + d = Decimal point = pg.get_decimal_point() self.assertRaises(TypeError, pg.set_decimal_point) # error if decimal point is not a string @@ -2480,7 +2480,7 @@ def test_get_decimal(self): decimal_class = pg.get_decimal() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal, decimal_class) - self.assertIs(decimal_class, pg.Decimal) # the default setting + self.assertIs(decimal_class, Decimal) # the default setting pg.set_decimal(int) try: r = pg.get_decimal() @@ -2499,7 +2499,6 @@ def test_set_decimal(self): r = query("select 3425::numeric") except pg.DatabaseError: self.skipTest('database does not support numeric') - r = None r = r.getresult()[0][0] self.assertIsInstance(r, decimal_class) self.assertEqual(r, decimal_class('3425')) @@ -2557,7 +2556,6 @@ def test_set_bool(self): r = query("select true::bool") except pg.ProgrammingError: self.skipTest('database does not support bool') - r = None r = r.getresult()[0][0] self.assertIsInstance(r, bool) self.assertEqual(r, True) @@ -2620,7 +2618,6 @@ def test_set_bytea_escaped(self): r = query("select 'data'::bytea") except pg.ProgrammingError: self.skipTest('database does not support bytea') - r = None r = r.getresult()[0][0] self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') @@ -2653,7 +2650,8 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - info = pg._row_factory.cache_info() + from pg.helpers import _row_factory + info = _row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) self.assertEqual( diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 74d6df8e..8ebb8214 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -49,88 +49,6 @@ def DB(): # noqa: N802 return db -class TestAttrDict(unittest.TestCase): - """Test the simple ordered dictionary for attribute names.""" - - cls = pg.AttrDict - - def test_init(self): - a = self.cls() - self.assertIsInstance(a, dict) - self.assertEqual(a, {}) - items = [('id', 'int'), ('name', 'text')] - a = self.cls(items) - self.assertIsInstance(a, dict) - self.assertEqual(a, dict(items)) - iteritems = iter(items) - a = self.cls(iteritems) - self.assertIsInstance(a, dict) - self.assertEqual(a, dict(items)) - - def test_iter(self): - a = self.cls() - self.assertEqual(list(a), []) - keys = ['id', 'name', 'age'] - items = [(key, None) for key in keys] - a = self.cls(items) - self.assertEqual(list(a), keys) - - def test_keys(self): - a = self.cls() - self.assertEqual(list(a.keys()), []) - keys = ['id', 'name', 'age'] - items = [(key, None) for key in keys] - a = self.cls(items) - self.assertEqual(list(a.keys()), keys) - - def test_values(self): - a = self.cls() - self.assertEqual(list(a.values()), []) - items = [('id', 'int'), ('name', 'text')] - values = [item[1] for item in items] - a = self.cls(items) - self.assertEqual(list(a.values()), values) - - def test_items(self): - a = self.cls() - self.assertEqual(list(a.items()), []) - items = [('id', 'int'), ('name', 'text')] - a = self.cls(items) - self.assertEqual(list(a.items()), items) - - def test_get(self): - a = self.cls([('id', 1)]) - try: - self.assertEqual(a['id'], 1) - except KeyError: - self.fail('AttrDict should be readable') - - def test_set(self): - a = self.cls() - try: - a['id'] = 1 - except TypeError: - pass - else: - self.fail('AttrDict should be read-only') - - def test_del(self): - a = self.cls([('id', 1)]) - try: - del a['id'] - except TypeError: - pass - else: - self.fail('AttrDict should be read-only') - - def test_write_methods(self): - a = self.cls([('id', 1)]) - self.assertEqual(a['id'], 1) - for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': - method = getattr(a, method) - self.assertRaises(TypeError, method, a) # type: ignore - - class TestDBClassInit(unittest.TestCase): """Test proper handling of errors when creating DB instances.""" @@ -491,8 +409,8 @@ def test_class_name(self): self.assertEqual(self.db.__class__.__name__, 'DB') def test_module_name(self): - self.assertEqual(self.db.__module__, 'pg') - self.assertEqual(self.db.__class__.__module__, 'pg') + self.assertEqual(self.db.__module__, 'pg.db') + self.assertEqual(self.db.__class__.__module__, 'pg.db') def test_escape_literal(self): f = self.db.escape_literal @@ -1437,21 +1355,21 @@ def test_get_attnames_is_ordered(self): self.assertEqual(r, 'n alpha v gamma tau beta') def test_get_attnames_is_attr_dict(self): - AttrDict = pg.AttrDict # noqa: N806 + from pg.attrs import AttrDict get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict([ - ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'), - ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'), - ('m', 'money'), ('v4', 'character varying'), - ('c4', 'character'), ('t', 'text')])) + self.assertEqual(r, AttrDict({ + 'i2': 'smallint', 'i4': 'integer', 'i8': 'bigint', + 'd': 'numeric', 'f4': 'real', 'f8': 'double precision', + 'm': 'money', 'v4': 'character varying', + 'c4': 'character', 't': 'text'})) else: - self.assertEqual(r, AttrDict([ - ('i2', 'int'), ('i4', 'int'), ('i8', 'int'), - ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'), - ('v4', 'text'), ('c4', 'text'), ('t', 'text')])) + self.assertEqual(r, AttrDict({ + 'i2': 'int', 'i4': 'int', 'i8': 'int', + 'd': 'num', 'f4': 'float', 'f8': 'float', 'm': 'money', + 'v4': 'text', 'c4': 'text', 't': 'text'})) r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' @@ -1461,14 +1379,14 @@ def test_get_attnames_is_attr_dict(self): r = get_attnames(table) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict([ - ('n', 'integer'), ('alpha', 'smallint'), - ('v', 'character varying'), ('gamma', 'character'), - ('tau', 'text'), ('beta', 'boolean')])) + self.assertEqual(r, AttrDict({ + 'n': 'integer', 'alpha': 'smallint', + 'v': 'character varying', 'gamma': 'character', + 'tau': 'text', 'beta': 'boolean'})) else: - self.assertEqual(r, AttrDict([ - ('n', 'int'), ('alpha', 'int'), ('v', 'text'), - ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')])) + self.assertEqual(r, AttrDict({ + 'n': 'int', 'alpha': 'int', 'v': 'text', + 'gamma': 'text', 'tau': 'text', 'beta': 'bool'})) r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') @@ -4204,7 +4122,8 @@ def test_get_set_type_cast(self): self.assertNotIn('bool', dbtypes) self.assertIs(get_typecast('int4'), int) self.assertIs(get_typecast('float4'), float) - self.assertIs(get_typecast('bool'), pg.cast_bool) + from pg.cast import cast_bool + self.assertIs(get_typecast('bool'), cast_bool) cast_circle = get_typecast('circle') self.addCleanup(set_typecast, 'circle', cast_circle) squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 @@ -4416,14 +4335,19 @@ def test_adapt_query_typed_list(self): values = [(3, 7.5, 'hello', True, [123], ['abc'])] t = self.adapter.simple_type typ = t('record') - typ._get_attnames = lambda _self: pg.AttrDict([ - ('i', t('int')), ('f', t('float')), - ('t', t('text')), ('b', t('bool')), - ('i3', t('int[]')), ('t3', t('text[]'))]) + from pg.attrs import AttrDict + typ._get_attnames = lambda _self: AttrDict({ + 'i': t('int'), 'f': t('float'), + 't': t('text'), 'b': t('bool'), + 'i3': t('int[]'), 't3': t('text[]')}) types = [typ] sql, params = format_query('select %s', values, types) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values, types) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_typed_list_with_types_as_string(self): format_query = self.adapter.format_query @@ -4527,14 +4451,19 @@ def test_adapt_query_typed_dict(self): values = dict(record=(3, 7.5, 'hello', True, [123], ['abc'])) t = self.adapter.simple_type typ = t('record') - typ._get_attnames = lambda _self: pg.AttrDict([ - ('i', t('int')), ('f', t('float')), - ('t', t('text')), ('b', t('bool')), - ('i3', t('int[]')), ('t3', t('text[]'))]) + from pg.attrs import AttrDict + typ._get_attnames = lambda _self: AttrDict({ + 'i': t('int'), 'f': t('float'), + 't': t('text'), 'b': t('bool'), + 'i3': t('int[]'), 't3': t('text[]')}) types = dict(record=typ) sql, params = format_query('select %(record)s', values, types) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values, types) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_untyped_list(self): format_query = self.adapter.format_query @@ -4560,6 +4489,10 @@ def test_adapt_query_untyped_list(self): sql, params = format_query('select %s', values) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_untyped_list_with_json(self): format_query = self.adapter.format_query @@ -4601,6 +4534,10 @@ def test_adapt_query_untyped_dict(self): sql, params = format_query('select %(record)s', values) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) def test_adapt_query_inline_list(self): format_query = self.adapter.format_query @@ -4629,6 +4566,11 @@ def test_adapt_query_inline_list(self): self.assertEqual( sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values, inline=True) + self.assertEqual( + sql, "select (0,-3.25,'',false,ARRAY[0],ARRAY[''])") + self.assertEqual(params, []) def test_adapt_query_inline_list_with_json(self): format_query = self.adapter.format_query @@ -4676,6 +4618,11 @@ def test_adapt_query_inline_dict(self): self.assertEqual( sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") self.assertEqual(params, []) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values, inline=True) + self.assertEqual( + sql, "select (0,-3.25,'',false,ARRAY[0],ARRAY[''])") + self.assertEqual(params, []) def test_adapt_query_with_pg_repr(self): format_query = self.adapter.format_query diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 19214c5d..01ed752e 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -15,6 +15,7 @@ import re import unittest from datetime import timedelta +from decimal import Decimal from typing import Any, Sequence import pg # the module under test @@ -854,15 +855,15 @@ class TestCastInterval(unittest.TestCase): 'P-10M-3DT3H55M5.999993S'))] def test_cast_interval(self): + from pg.cast import cast_interval for result, values in self.intervals: - f = pg.cast_interval years, mons, days, hours, mins, secs, usecs = result days += 365 * years + 30 * mons interval = timedelta( days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) for value in values: - self.assertEqual(f(value), interval) + self.assertEqual(cast_interval(value), interval) class TestEscapeFunctions(unittest.TestCase): @@ -970,10 +971,10 @@ def test_set_decimal_point(self): def test_get_decimal(self): r = pg.get_decimal() - self.assertIs(r, pg.Decimal) + self.assertIs(r, Decimal) def test_set_decimal(self): - decimal_class = pg.Decimal + decimal_class = Decimal try: pg.set_decimal(int) r = pg.get_decimal() From 493c6ee81670e0fb8bcd1f397ab87e66e09c6924 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 15:02:30 +0200 Subject: [PATCH 056/118] Nicer initialization of AttrDict in tests --- tests/test_classic_dbwrapper.py | 48 ++++++++++++++++----------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 8ebb8214..2ddde601 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1360,16 +1360,16 @@ def test_get_attnames_is_attr_dict(self): r = get_attnames('test', flush=True) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict({ - 'i2': 'smallint', 'i4': 'integer', 'i8': 'bigint', - 'd': 'numeric', 'f4': 'real', 'f8': 'double precision', - 'm': 'money', 'v4': 'character varying', - 'c4': 'character', 't': 'text'})) + self.assertEqual(r, AttrDict( + i2='smallint', i4='integer', i8='bigint', + d='numeric', f4='real', f8='double precision', + m='money', v4='character varying', + c4='character', t='text')) else: - self.assertEqual(r, AttrDict({ - 'i2': 'int', 'i4': 'int', 'i8': 'int', - 'd': 'num', 'f4': 'float', 'f8': 'float', 'm': 'money', - 'v4': 'text', 'c4': 'text', 't': 'text'})) + self.assertEqual(r, AttrDict( + i2='int', i4='int', i8='int', + d='num', f4='float', f8='float', m='money', + v4='text', c4='text', t='text')) r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' @@ -1379,14 +1379,14 @@ def test_get_attnames_is_attr_dict(self): r = get_attnames(table) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict({ - 'n': 'integer', 'alpha': 'smallint', - 'v': 'character varying', 'gamma': 'character', - 'tau': 'text', 'beta': 'boolean'})) + self.assertEqual(r, AttrDict( + n='integer', alpha='smallint', + v='character varying', gamma='character', + tau='text', beta='boolean')) else: - self.assertEqual(r, AttrDict({ - 'n': 'int', 'alpha': 'int', 'v': 'text', - 'gamma': 'text', 'tau': 'text', 'beta': 'bool'})) + self.assertEqual(r, AttrDict( + n='int', alpha='int', v='text', + gamma='text', tau='text', beta='bool')) r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') @@ -4336,10 +4336,10 @@ def test_adapt_query_typed_list(self): t = self.adapter.simple_type typ = t('record') from pg.attrs import AttrDict - typ._get_attnames = lambda _self: AttrDict({ - 'i': t('int'), 'f': t('float'), - 't': t('text'), 'b': t('bool'), - 'i3': t('int[]'), 't3': t('text[]')}) + typ._get_attnames = lambda _self: AttrDict( + i=t('int'), f=t('float'), + t=t('text'), b=t('bool'), + i3=t('int[]'), t3=t('text[]')) types = [typ] sql, params = format_query('select %s', values, types) self.assertEqual(sql, 'select $1') @@ -4452,10 +4452,10 @@ def test_adapt_query_typed_dict(self): t = self.adapter.simple_type typ = t('record') from pg.attrs import AttrDict - typ._get_attnames = lambda _self: AttrDict({ - 'i': t('int'), 'f': t('float'), - 't': t('text'), 'b': t('bool'), - 'i3': t('int[]'), 't3': t('text[]')}) + typ._get_attnames = lambda _self: AttrDict( + i=t('int'), f=t('float'), + t=t('text'), b=t('bool'), + i3=t('int[]'), t3=t('text[]')) types = dict(record=typ) sql, params = format_query('select %(record)s', values, types) self.assertEqual(sql, 'select $1') From 3ff98e3d29d06a1bf5b9685ceaab2be6c0b05f4d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 5 Sep 2023 22:46:58 +0200 Subject: [PATCH 057/118] Split pgdb package into submodules --- docs/contents/changelog.rst | 7 + pg/__init__.py | 11 +- pg/adapt.py | 2 +- pg/error.py | 27 +- pg/helpers.py | 60 +- pgdb/__init__.py | 1784 +----------------------------- pgdb/adapt.py | 237 ++++ pgdb/cast.py | 581 ++++++++++ pgdb/connect.py | 74 ++ pgdb/connection.py | 156 +++ pgdb/constants.py | 14 + pgdb/cursor.py | 645 +++++++++++ pgdb/typecode.py | 34 + tests/dbapi20.py | 9 +- tests/test_classic.py | 6 +- tests/test_classic_connection.py | 14 +- tests/test_classic_dbwrapper.py | 5 +- tests/test_dbapi20.py | 72 +- 18 files changed, 1917 insertions(+), 1821 deletions(-) create mode 100644 pgdb/adapt.py create mode 100644 pgdb/cast.py create mode 100644 pgdb/connect.py create mode 100644 pgdb/connection.py create mode 100644 pgdb/constants.py create mode 100644 pgdb/cursor.py create mode 100644 pgdb/typecode.py diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 6afc68dd..077893a2 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -5,10 +5,17 @@ Version 6.0 (to be released) ---------------------------- - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). +- Converted the standalone modules `pg` and `pgdb` to packages with + several submodules each. The C extension module is now part of the + `pg` package and wrapped into the pure Python module `pg.core`. +- Added type hints and included a stub file for the C extension module. - Added method `pkeys()` to the `pg.DB` object. - Removed deprecated function `pg.pgnotify()`. - Removed deprecated method `ntuples()` of the `pg.Query` object. - Renamed `pgdb.Type` to `pgdb.DbType` to avoid confusion with `typing.Type`. +- `pg` and `pgdb` now use a shared row factory cache. +- The function `set_row_factory_size()` has been removed. The row cache is now + available as a `RowCache` class with methods `change_size()` and `clear()`. - Modernized code and tools for development, testing, linting and building. Version 5.2.5 (2023-08-28) diff --git a/pg/__init__.py b/pg/__init__.py index e0e1b214..37447c9e 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -95,11 +95,9 @@ version, ) from .db import DB -from .helpers import init_core, set_row_factory_size +from .helpers import RowCache, init_core from .notify import NotificationHandler -__version__ = version - __all__ = [ 'DB', 'Adapter', 'NotificationHandler', 'Typecasts', @@ -110,7 +108,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', + 'Connection', 'Query', 'RowCache', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', @@ -127,9 +125,10 @@ 'set_datestyle', 'set_decimal', 'set_decimal_point', 'set_defbase', 'set_defhost', 'set_defopt', 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', - 'set_row_factory_size', 'set_typecast', + 'set_jsondecode', 'set_query_helpers', 'set_typecast', 'version', '__version__', ] +__version__ = version + init_core() diff --git a/pg/adapt.py b/pg/adapt.py index fd4705ae..9cbecaaf 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -1,4 +1,4 @@ -"""Adaption of parameters.""" +"""Adaptation of parameters.""" from __future__ import annotations diff --git a/pg/error.py b/pg/error.py index b3164b42..484a1252 100644 --- a/pg/error.py +++ b/pg/error.py @@ -4,9 +4,18 @@ from typing import TypeVar -from .core import DatabaseError, Error, InternalError, ProgrammingError - -__all__ = ['error', 'db_error', 'int_error', 'prg_error'] +from .core import ( + DatabaseError, + Error, + InterfaceError, + InternalError, + OperationalError, + ProgrammingError, +) + +__all__ = [ + 'error', 'db_error', 'if_error', 'int_error', 'op_error', 'prg_error' +] # Error messages @@ -32,4 +41,14 @@ def int_error(msg: str) -> InternalError: def prg_error(msg: str) -> ProgrammingError: """Return ProgrammingError.""" - return error(msg, ProgrammingError) \ No newline at end of file + return error(msg, ProgrammingError) + + +def if_error(msg: str) -> InterfaceError: + """Return InterfaceError.""" + return error(msg, InterfaceError) + + +def op_error(msg: str) -> OperationalError: + """Return OperationalError.""" + return error(msg, OperationalError) diff --git a/pg/helpers.py b/pg/helpers.py index 4426cfbc..53689f6a 100644 --- a/pg/helpers.py +++ b/pg/helpers.py @@ -13,7 +13,7 @@ SomeNamedTuple = Any # alias for accessing arbitrary named tuples __all__ = [ - 'quote_if_unqualified', 'oid_key', 'set_row_factory_size', + 'quote_if_unqualified', 'oid_key', 'QuoteDict', 'RowCache', 'dictiter', 'namediter', 'namednext', 'scalariter' ] @@ -36,30 +36,50 @@ def oid_key(table: str) -> str: """Build oid key from a table name.""" return f'oid({table})' +class QuoteDict(dict): + """Dictionary with auto quoting of its items. -# Row factory + The quote attribute must be set to the desired quote function. + """ -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. + quote: Callable[[str], str] -@lru_cache(maxsize=1024) -def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: - """Get a namedtuple factory for row results with the given names.""" - try: - return namedtuple('Row', names, rename=True)._make # type: ignore - except ValueError: # there is still a problem with the field names - names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make # type: ignore + def __getitem__(self, key: str) -> str: + """Get a quoted value.""" + return self.quote(super().__getitem__(key)) -def set_row_factory_size(maxsize: int | None) -> None: - """Change the size of the namedtuple factory cache. +class RowCache: + """Global cache for the named tuples used for table rows. - If maxsize is set to None, the cache can grow without bound. + The result rows for database operations are returned as named tuples + by default. Since creating namedtuple classes is a somewhat expensive + operation, we cache up to 1024 of these classes by default. """ - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) + + @staticmethod + @lru_cache(maxsize=1024) + def row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: + """Get a namedtuple factory for row results with the given names.""" + try: + return namedtuple('Row', names, rename=True)._make # type: ignore + except ValueError: # there is still a problem with the field names + names = [f'column_{n}' for n in range(len(names))] + return namedtuple('Row', names)._make # type: ignore + + @classmethod + def clear(cls) -> None: + """Clear the namedtuple factory cache.""" + cls.row_factory.cache_clear() + + @classmethod + def change_size(cls, maxsize: int | None) -> None: + """Change the size of the namedtuple factory cache. + + If maxsize is set to None, the cache can grow without bound. + """ + row_factory = cls.row_factory.__wrapped__ + cls.row_factory = lru_cache(maxsize)(row_factory) # type: ignore # Helper functions used by the query object @@ -73,14 +93,14 @@ def dictiter(q: Query) -> Generator[dict[str, Any], None, None]: def namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: """Get query result as an iterator of named tuples.""" - row = _row_factory(q.listfields()) + row = RowCache.row_factory(q.listfields()) for r in q: yield row(r) def namednext(q: Query) -> SomeNamedTuple: """Get next row from query result as a named tuple.""" - return _row_factory(q.listfields())(next(q)) + return RowCache.row_factory(q.listfields())(next(q)) def scalariter(q: Query) -> Generator[Any, None, None]: diff --git a/pgdb/__init__.py b/pgdb/__init__.py index 74ad38e5..b9a4449a 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -64,35 +64,7 @@ connection.close() # close the connection """ -from __future__ import annotations - -from collections import namedtuple -from collections.abc import Iterable -from contextlib import suppress -from datetime import date, datetime, time, timedelta, tzinfo -from decimal import Decimal as StdDecimal -from functools import lru_cache, partial -from inspect import signature -from json import dumps as jsonencode -from json import loads as jsondecode -from math import isinf, isnan -from re import compile as regex -from time import localtime -from typing import ( - Any, - Callable, - ClassVar, - Generator, - Mapping, - NamedTuple, - Sequence, - TypeVar, -) -from uuid import UUID as Uuid # noqa: N811 - -# import objects from extension module -from pg import ( - RESULT_DQL, +from pg.core import ( DatabaseError, DataError, Error, @@ -103,20 +75,50 @@ OperationalError, ProgrammingError, Warning, - cast_array, - cast_hstore, - cast_record, - unescape_bytea, version, ) -from pg import ( - Connection as Cnx, # base connection -) -from pg import ( - connect as get_cnx, # get base connection -) -__version__ = version +from .adapt import ( + ARRAY, + BINARY, + BOOL, + DATE, + DATETIME, + FLOAT, + HSTORE, + INTEGER, + INTERVAL, + JSON, + LONG, + MONEY, + NUMBER, + NUMERIC, + RECORD, + ROWID, + SMALLINT, + STRING, + TIME, + TIMESTAMP, + UUID, + Binary, + Date, + DateFromTicks, + DbType, + Hstore, + Interval, + Json, + Literal, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, + Uuid, +) +from .cast import get_typecast, reset_typecast, set_typecast +from .connect import connect +from .connection import Connection +from .constants import apilevel, paramstyle, shortcutmethods, threadsafety +from .cursor import Cursor __all__ = [ 'Connection', 'Cursor', @@ -131,1707 +133,9 @@ 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', - 'apilevel', 'connect', 'paramstyle', 'threadsafety', 'get_typecast', 'set_typecast', 'reset_typecast', + 'apilevel', 'connect', 'paramstyle', 'shortcutmethods', 'threadsafety', 'version', '__version__', ] -Decimal: type = StdDecimal - - -# *** Module Constants *** - -# compliant with DB API 2.0 -apilevel = '2.0' - -# module may be shared, but not connections -threadsafety = 1 - -# this module use extended python format codes -paramstyle = 'pyformat' - -# shortcut methods have been excluded from DB API 2 and -# are not recommended by the DB SIG, but they can be handy -shortcutmethods = 1 - - -# *** Internal Type Handling *** - -def get_args(func: Callable) -> list: - return list(signature(func).parameters) - - -# time zones used in Postgres timestamptz output -_timezones: dict[str, str] = { - 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', - 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', - 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' -} - - -def _timezone_as_offset(tz: str) -> str: - if tz.startswith(('+', '-')): - if len(tz) < 5: - return tz + '00' - return tz.replace(':', '') - return _timezones.get(tz, '+0000') - - -def decimal_type(decimal_type: type | None = None) -> type: - """Get or set global type to be used for decimal values. - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - global Decimal - if decimal_type is not None: - Decimal = decimal_type - set_typecast('numeric', decimal_type) - return Decimal - - -def cast_bool(value: str) -> bool | None: - """Cast boolean value in database format to bool.""" - return value[0] in ('t', 'T') if value else None - - -def cast_money(value: str) -> StdDecimal | None: - """Cast money value in database format to Decimal.""" - if not value: - return None - value = value.replace('(', '-') - return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) - - -def cast_int2vector(value: str) -> list[int]: - """Cast an int2vector value.""" - return [int(v) for v in value.split()] - - -def cast_date(value: str, cnx: Cnx) -> date: - """Cast a date value.""" - # The output format depends on the server setting DateStyle. The default - # setting ISO and the setting for German are actually unambiguous. The - # order of days and months in the other two settings is however ambiguous, - # so at least here we need to consult the setting to properly parse values. - if value == '-infinity': - return date.min - if value == 'infinity': - return date.max - values = value.split() - if values[-1] == 'BC': - return date.min - value = values[0] - if len(value) > 10: - return date.max - format = cnx.date_format() - return datetime.strptime(value, format).date() - - -def cast_time(value: str) -> time: - """Cast a time value.""" - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, fmt).time() - - -_re_timezone = regex('(.*)([+-].*)') - - -def cast_timetz(value: str) -> time: - """Cast a timetz value.""" - m = _re_timezone.match(value) - if m: - value, tz = m.groups() - else: - tz = '+0000' - format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - value += _timezone_as_offset(tz) - format += '%z' - return datetime.strptime(value, format).timetz() - - -def cast_timestamp(value: str, cnx: Cnx) -> datetime: - """Cast a timestamp value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return datetime.min - format = cnx.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:5] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - else: - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -def cast_timestamptz(value: str, cnx: Cnx) -> datetime: - """Cast a timestamptz value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - values = value.split() - if values[-1] == 'BC': - return datetime.min - format = cnx.date_format() - if format.endswith('-%Y') and len(values) > 2: - values = values[1:] - if len(values[3]) > 4: - return datetime.max - formats = ['%d %b' if format.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] - values, tz = values[:-1], values[-1] - else: - if format.startswith('%Y-'): - m = _re_timezone.match(values[1]) - if m: - values[1], tz = m.groups() - else: - tz = '+0000' - else: - values, tz = values[:-1], values[-1] - if len(values[0]) > 10: - return datetime.max - formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] - values.append(_timezone_as_offset(tz)) - formats.append('%z') - return datetime.strptime(' '.join(values), ' '.join(formats)) - - -_re_interval_sql_standard = regex( - '(?:([+-])?([0-9]+)-([0-9]+) ?)?' - '(?:([+-]?[0-9]+)(?!:) ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres = regex( - '(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres_verbose = regex( - '@ ?(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-]?[0-9]+) ?hours? ?)?' - '(?:([+-]?[0-9]+) ?mins? ?)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') - -_re_interval_iso_8601 = regex( - 'P(?:([+-]?[0-9]+)Y)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-]?[0-9]+)D)?' - '(?:T(?:([+-]?[0-9]+)H)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') - - -def cast_interval(value: str) -> timedelta: - """Cast an interval value.""" - # The output format depends on the server setting IntervalStyle, but it's - # not necessary to consult this setting to parse it. It's faster to just - # check all possible formats, and there is no ambiguity here. - m = _re_interval_iso_8601.match(value) - if m: - s = [v or '0' for v in m.groups()] - secs_ago = s.pop(5) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = -secs - usecs = -usecs - else: - m = _re_interval_postgres_verbose.match(value) - if m: - s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) - secs_ago = s.pop(5) == '-' - d = [-int(v) for v in s] if ago else [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if secs_ago: - secs = - secs - usecs = -usecs - else: - m = _re_interval_postgres.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - m = _re_interval_sql_standard.match(value) - if m and any(m.groups()): - s = [v or '0' for v in m.groups()] - years_ago = s.pop(0) == '-' - hours_ago = s.pop(3) == '-' - d = [int(v) for v in s] - years, mons, days, hours, mins, secs, usecs = d - if years_ago: - years = -years - mons = -mons - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - raise ValueError(f'Cannot parse interval: {value}') - days += 365 * years + 30 * mons - return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) - - -class Typecasts(dict): - """Dictionary mapping database types to typecast functions. - - The cast functions get passed the string representation of a value in - the database which they need to convert to a Python object. The - passed string will never be None since NULL values are already - handled before the cast function is called. - """ - - # the default cast functions - # (str functions are ignored but have been added for faster access) - defaults: ClassVar[dict[str, Callable]] = { - 'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, 'sql_identifier': str, - 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, - 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, - 'float4': float, 'float8': float, - 'numeric': Decimal, 'money': cast_money, - 'date': cast_date, 'interval': cast_interval, - 'time': cast_time, 'timetz': cast_timetz, - 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, - 'int2vector': cast_int2vector, 'uuid': Uuid, - 'anyarray': cast_array, 'record': cast_record} - - cnx: Cnx | None = None # for local connection specific instances - - def __missing__(self, typ: str) -> Callable | None: - """Create a cast function if it is not cached. - - Note that this class never raises a KeyError, - but returns None when no special cast function exists. - """ - if not isinstance(typ, str): - raise TypeError(f'Invalid type: {typ}') - cast = self.defaults.get(typ) - if cast: - # store default for faster access - cast = self._add_connection(cast) - self[typ] = cast - elif typ.startswith('_'): - # create array cast - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - # store only if base type exists - self[typ] = cast - return cast - - @staticmethod - def _needs_connection(func: Callable) -> bool: - """Check if a typecast function needs a connection argument.""" - try: - args = get_args(func) - except (TypeError, ValueError): - return False - return 'cnx' in args[1:] - - def _add_connection(self, cast: Callable) -> Callable: - """Add a connection argument to the typecast function if necessary.""" - if not self.cnx or not self._needs_connection(cast): - return cast - return partial(cast, cnx=self.cnx) - - def get(self, typ: str, default: Callable | None = None # type: ignore - ) -> Callable | None: - """Get the typecast function for the given database type.""" - return self[typ] or default - - def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a typecast function for the specified database type(s).""" - if isinstance(typ, str): - typ = [typ] - if cast is None: - for t in typ: - self.pop(t, None) - self.pop(f'_{t}', None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - self[t] = self._add_connection(cast) - self.pop(f'_{t}', None) - - def reset(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecasts for the specified type(s) to their defaults. - - When no type is specified, all typecasts will be reset. - """ - defaults = self.defaults - if typ is None: - self.clear() - self.update(defaults) - else: - if isinstance(typ, str): - typ = [typ] - for t in typ: - cast = defaults.get(t) - if cast: - self[t] = self._add_connection(cast) - t = f'_{t}' - cast = defaults.get(t) - if cast: - self[t] = self._add_connection(cast) - else: - self.pop(t, None) - else: - self.pop(t, None) - self.pop(f'_{t}', None) - - def create_array_cast(self, basecast: Callable) -> Callable: - """Create an array typecast for the given base cast.""" - cast_array = self['anyarray'] - - def cast(v: Any) -> list: - return cast_array(v, basecast) - return cast - - def create_record_cast(self, name: str, fields: Sequence[str], - casts: Sequence[str]) -> Callable: - """Create a named record typecast for the given fields and casts.""" - cast_record = self['record'] - record = namedtuple(name, fields) # type: ignore - - def cast(v: Any) -> record: - # noinspection PyArgumentList - return record(*cast_record(v, casts)) - return cast - - -_typecasts = Typecasts() # this is the global typecast dictionary - - -def get_typecast(typ: str) -> Callable | None: - """Get the global typecast function for the given database type.""" - return _typecasts.get(typ) - - -def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: - """Set a global typecast function for the given database type(s). - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - _typecasts.set(typ, cast) - - -def reset_typecast(typ: str | Sequence[str] | None = None) -> None: - """Reset the global typecasts for the given type(s) to their default. - - When no type is specified, all typecasts will be reset. - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - _typecasts.reset(typ) - - -class LocalTypecasts(Typecasts): - """Map typecasts, including local composite types, to cast functions.""" - - defaults = _typecasts - - cnx: Cnx | None = None # set in connection specific instances - - def __missing__(self, typ: str) -> Callable | None: - """Create a cast function if it is not cached.""" - cast: Callable | None - if typ.startswith('_'): - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - self[typ] = cast - else: - cast = self.defaults.get(typ) - if cast: - cast = self._add_connection(cast) - self[typ] = cast - else: - fields = self.get_fields(typ) - if fields: - casts = [self[field.type] for field in fields] - field_names = [field.name for field in fields] - cast = self.create_record_cast(typ, field_names, casts) - self[typ] = cast - return cast - - # noinspection PyMethodMayBeStatic,PyUnusedLocal - def get_fields(self, typ: str) -> list[FieldInfo]: - """Return the fields for the given record type. - - This method will be replaced with a method that looks up the fields - using the type cache of the connection. - """ - return [] - - -class TypeCode(str): - """Class representing the type_code used by the DB-API 2.0. - - TypeCode objects are strings equal to the PostgreSQL type name, - but carry some additional information. - """ - - oid: int - len: int - type: str - category: str - delim: str - relid: int - - # noinspection PyShadowingBuiltins - @classmethod - def create(cls, oid: int, name: str, len: int, type: str, category: str, - delim: str, relid: int) -> TypeCode: - """Create a type code for a PostgreSQL data type.""" - self = cls(name) - self.oid = oid - self.len = len - self.type = type - self.category = category - self.delim = delim - self.relid = relid - return self - - -FieldInfo = namedtuple('FieldInfo', ('name', 'type')) - - -class TypeCache(dict): - """Cache for database types. - - This cache maps type OIDs and names to TypeCode strings containing - important information on the associated database type. - """ - - def __init__(self, cnx: Cnx) -> None: - """Initialize type cache for connection.""" - super().__init__() - self._escape_string = cnx.escape_string - self._src = cnx.source() - self._typecasts = LocalTypecasts() - self._typecasts.get_fields = self.get_fields # type: ignore - self._typecasts.cnx = cnx - self._query_pg_type = ( - "SELECT oid, typname," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") - - def __missing__(self, key: int | str) -> TypeCode: - """Get the type info from the database if it is not cached.""" - oid: int | str - if isinstance(key, int): - oid = key - else: - if '.' not in key and '"' not in key: - key = f'"{key}"' - oid = f"'{self._escape_string(key)}'::pg_catalog.regtype" - try: - self._src.execute(self._query_pg_type.format(oid)) - except ProgrammingError: - res = None - else: - res = self._src.fetch(1) - if not res: - raise KeyError(f'Type {key} could not be found') - r = res[0] - type_code = TypeCode.create( - int(r[0]), r[1], int(r[2]), r[3], r[4], r[5], int(r[6])) - # noinspection PyUnresolvedReferences - self[type_code.oid] = self[str(type_code)] = type_code - return type_code - - def get(self, key: int | str, # type: ignore - default: TypeCode | None = None) -> TypeCode | None: - """Get the type even if it is not cached.""" - try: - return self[key] - except KeyError: - return default - - def get_fields(self, typ: int | str | TypeCode) -> list[FieldInfo] | None: - """Get the names and types of the fields of composite types.""" - if isinstance(typ, TypeCode): - relid = typ.relid - else: - type_code = self.get(typ) - if not type_code: - return None - relid = type_code.relid - if not relid: - return None # this type is not composite - self._src.execute( - "SELECT attname, atttypid" # noqa: S608 - " FROM pg_catalog.pg_attribute" - f" WHERE attrelid OPERATOR(pg_catalog.=) {relid}" - " AND attnum OPERATOR(pg_catalog.>) 0" - " AND NOT attisdropped ORDER BY attnum") - return [FieldInfo(name, self.get(int(oid))) - for name, oid in self._src.fetch(-1)] - - def get_typecast(self, typ: str) -> Callable | None: - """Get the typecast function for the given database type.""" - return self._typecasts[typ] - - def set_typecast(self, typ: str | Sequence[str], - cast: Callable | None) -> None: - """Set a typecast function for the specified database type(s).""" - self._typecasts.set(typ, cast) - - def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: - """Reset the typecast function for the specified database type(s).""" - self._typecasts.reset(typ) - - def typecast(self, value: Any, typ: str) -> Any: - """Cast the given value according to the given database type.""" - if value is None: - # for NULL values, no typecast is necessary - return None - cast = self._typecasts[typ] - if cast is None or cast is str: - # no typecast is necessary - return value - return cast(value) - - def get_row_caster(self, types: Sequence[str]) -> Callable: - """Get a typecast function for a complete row of values.""" - typecasts = self._typecasts - casts = [typecasts[typ] for typ in types] - casts = [cast if cast is not str else None for cast in casts] - - def row_caster(row: Sequence) -> Sequence: - return [value if cast is None or value is None else cast(value) - for cast, value in zip(casts, row)] - - return row_caster - - -class _QuoteDict(dict): - """Dictionary with auto quoting of its items. - - The quote attribute must be set to the desired quote function. - """ - - quote: Callable[[str], str] - - def __getitem__(self, key: str) -> str: - # noinspection PyUnresolvedReferences - return self.quote(super().__getitem__(key)) - - -# *** Error Messages *** - -E = TypeVar('E', bound=Error) - - -def _error(msg: str, cls: type[E]) -> E: - """Return specified error object with empty sqlstate attribute.""" - error = cls(msg) - if isinstance(error, DatabaseError): - error.sqlstate = None - return error - - -def _db_error(msg: str) -> DatabaseError: - """Return DatabaseError.""" - return _error(msg, DatabaseError) - - -def _if_error(msg: str) -> InterfaceError: - """Return InterfaceError.""" - return _error(msg, InterfaceError) - - -def _op_error(msg: str) -> OperationalError: - """Return OperationalError.""" - return _error(msg, OperationalError) - - -# *** Row Tuples *** - -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. - -# noinspection PyUnresolvedReferences -@lru_cache(maxsize=1024) -def _row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: - """Get a namedtuple factory for row results with the given names.""" - try: - return namedtuple('Row', names, rename=True)._make # type: ignore - except ValueError: # there is still a problem with the field names - names = [f'column_{n}' for n in range(len(names))] - return namedtuple('Row', names)._make # type: ignore - - -def set_row_factory_size(maxsize: int | None) -> None: - """Change the size of the namedtuple factory cache. - - If maxsize is set to None, the cache can grow without bound. - """ - # noinspection PyGlobalUndefined - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) - - -# *** Cursor Object *** - -class Cursor: - """Cursor object.""" - - def __init__(self, connection: Connection) -> None: - """Create a cursor object for the database connection.""" - self.connection = self._connection = connection - cnx = connection._cnx - if not cnx: - raise _op_error("Connection has been closed") - self._cnx: Cnx = cnx - self.type_cache: TypeCache = connection.type_cache - self._src = self._cnx.source() - # the official attribute for describing the result columns - self._description: list[CursorDescription] | bool | None = None - if self.row_factory is Cursor.row_factory: - # the row factory needs to be determined dynamically - self.row_factory = None # type: ignore - else: - self.build_row_factory = None # type: ignore - self.rowcount: int | None = -1 - self.arraysize: int = 1 - self.lastrowid: int | None = None - - def __iter__(self) -> Cursor: - """Make cursor compatible to the iteration protocol.""" - return self - - def __enter__(self) -> Cursor: - """Enter the runtime context for the cursor object.""" - return self - - def __exit__(self, et: type[BaseException] | None, - ev: BaseException | None, tb: Any) -> None: - """Exit the runtime context for the cursor object.""" - self.close() - - def _quote(self, value: Any) -> Any: - """Quote value depending on its type.""" - if value is None: - return 'NULL' - if isinstance(value, (Hstore, Json)): - value = str(value) - if isinstance(value, (bytes, str)): - cnx = self._cnx - if isinstance(value, Binary): - value = cnx.escape_bytea(value).decode('ascii') - else: - value = cnx.escape_string(value) - return f"'{value}'" - if isinstance(value, float): - if isinf(value): - return "'-Infinity'" if value < 0 else "'Infinity'" - if isnan(value): - return "'NaN'" - return value - if isinstance(value, (int, Decimal, Literal)): - return value - if isinstance(value, datetime): - if value.tzinfo: - return f"'{value}'::timestamptz" - return f"'{value}'::timestamp" - if isinstance(value, date): - return f"'{value}'::date" - if isinstance(value, time): - if value.tzinfo: - return f"'{value}'::timetz" - return f"'{value}'::time" - if isinstance(value, timedelta): - return f"'{value}'::interval" - if isinstance(value, Uuid): - return f"'{value}'::uuid" - if isinstance(value, list): - # Quote value as an ARRAY constructor. This is better than using - # an array literal because it carries the information that this is - # an array and not a string. One issue with this syntax is that - # you need to add an explicit typecast when passing empty arrays. - # The ARRAY keyword is actually only necessary at the top level. - if not value: # exception for empty array - return "'{}'" - q = self._quote - v = ','.join(str(q(v)) for v in value) - return f'ARRAY[{v}]' - if isinstance(value, tuple): - # Quote as a ROW constructor. This is better than using a record - # literal because it carries the information that this is a record - # and not a string. We don't use the keyword ROW in order to make - # this usable with the IN syntax as well. It is only necessary - # when the records has a single column which is not really useful. - q = self._quote - v = ','.join(str(q(v)) for v in value) - return f'({v})' - try: # noinspection PyUnresolvedReferences - value = value.__pg_repr__() - except AttributeError as e: - raise InterfaceError( - f'Do not know how to adapt type {type(value)}') from e - if isinstance(value, (tuple, list)): - value = self._quote(value) - return value - - def _quoteparams(self, string: str, - parameters: Mapping | Sequence | None) -> str: - """Quote parameters. - - This function works for both mappings and sequences. - - The function should be used even when there are no parameters, - so that we have a consistent behavior regarding percent signs. - """ - if not parameters: - try: - return string % () # unescape literal quotes if possible - except (TypeError, ValueError): - return string # silently accept unescaped quotes - if isinstance(parameters, dict): - parameters = _QuoteDict(parameters) - parameters.quote = self._quote - else: - parameters = tuple(map(self._quote, parameters)) - return string % parameters - - def _make_description(self, info: tuple[int, str, int, int, int] - ) -> CursorDescription: - """Make the description tuple for the given field info.""" - name, typ, size, mod = info[1:] - type_code = self.type_cache[typ] - if mod > 0: - mod -= 4 - precision: int | None - scale: int | None - if type_code == 'numeric': - precision, scale = mod >> 16, mod & 0xffff - size = precision - else: - if not size: - size = type_code.size - if size == -1: - size = mod - precision = scale = None - return CursorDescription( - name, type_code, None, size, precision, scale, None) - - @property - def description(self) -> list[CursorDescription] | None: - """Read-only attribute describing the result columns.""" - description = self._description - if description is None: - return None - if not isinstance(description, list): - make = self._make_description - description = [make(info) for info in self._src.listinfo()] - self._description = description - return description - - @property - def colnames(self) -> Sequence[str] | None: - """Unofficial convenience method for getting the column names.""" - description = self.description - return None if description is None else [d[0] for d in description] - - @property - def coltypes(self) -> Sequence[TypeCode] | None: - """Unofficial convenience method for getting the column types.""" - description = self.description - return None if description is None else [d[1] for d in description] - - def close(self) -> None: - """Close the cursor object.""" - self._src.close() - - def execute(self, operation: str, parameters: Sequence | None = None - ) -> Cursor: - """Prepare and execute a database operation (query or command).""" - # The parameters may also be specified as list of tuples to e.g. - # insert multiple rows in a single operation, but this kind of - # usage is deprecated. We make several plausibility checks because - # tuples can also be passed with the meaning of ROW constructors. - if (parameters and isinstance(parameters, list) - and len(parameters) > 1 - and all(isinstance(p, tuple) for p in parameters) - and all(len(p) == len(parameters[0]) for p in parameters[1:])): - return self.executemany(operation, parameters) - # not a list of tuples - return self.executemany(operation, [parameters]) - - def executemany(self, operation: str, - seq_of_parameters: Sequence[Sequence | None]) -> Cursor: - """Prepare operation and execute it against a parameter sequence.""" - if not seq_of_parameters: - # don't do anything without parameters - return self - self._description = None - self.rowcount = -1 - # first try to execute all queries - rowcount = 0 - sql = "BEGIN" - try: - if not self._connection._tnx and not self._connection.autocommit: - try: - self._src.execute(sql) - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't start transaction") from e - else: - self._connection._tnx = True - for parameters in seq_of_parameters: - sql = operation - sql = self._quoteparams(sql, parameters) - rows = self._src.execute(sql) - if rows: # true if not DML - rowcount += rows - else: - self.rowcount = -1 - except DatabaseError: - raise # database provides error message - except Error as err: - # noinspection PyTypeChecker - raise _if_error(f"Error in '{sql}': '{err}'") from err - except Exception as err: - raise _op_error(f"Internal error in '{sql}': {err}") from err - # then initialize result raw count and description - if self._src.resulttype == RESULT_DQL: - self._description = True # fetch on demand - self.rowcount = self._src.ntuples - self.lastrowid = None - build_row_factory = self.build_row_factory - if build_row_factory: # type: ignore - self.row_factory = build_row_factory() # type: ignore - else: - self.rowcount = rowcount - self.lastrowid = self._src.oidstatus() - # return the cursor object, so you can write statements such as - # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" - return self - - def fetchone(self) -> Sequence | None: - """Fetch the next row of a query result set.""" - res = self.fetchmany(1, False) - try: - return res[0] - except IndexError: - return None - - def fetchall(self) -> Sequence[Sequence]: - """Fetch all (remaining) rows of a query result.""" - return self.fetchmany(-1, False) - - def fetchmany(self, size: int | None = None, keep: bool = False - ) -> Sequence[Sequence]: - """Fetch the next set of rows of a query result. - - The number of rows to fetch per call is specified by the - size parameter. If it is not given, the cursor's arraysize - determines the number of rows to be fetched. If you set - the keep parameter to true, this is kept as new arraysize. - """ - if size is None: - size = self.arraysize - if keep: - self.arraysize = size - try: - result = self._src.fetch(size) - except DatabaseError: - raise - except Error as err: - raise _db_error(str(err)) from err - row_factory = self.row_factory - coltypes = self.coltypes - if coltypes is None: - # cannot determine column types, return raw result - return [row_factory(row) for row in result] - if len(result) > 5: - # optimize the case where we really fetch many values - # by looking up all type casting functions upfront - cast_row = self.type_cache.get_row_caster(coltypes) - return [row_factory(cast_row(row)) for row in result] - cast_value = self.type_cache.typecast - return [row_factory([cast_value(value, typ) - for typ, value in zip(coltypes, row)]) for row in result] - - def callproc(self, procname: str, parameters: Sequence | None = None - ) -> Sequence | None: - """Call a stored database procedure with the given name. - - The sequence of parameters must contain one entry for each input - argument that the procedure expects. The result of the call is the - same as this input sequence; replacement of output and input/output - parameters in the return value is currently not supported. - - The procedure may also provide a result set as output. These can be - requested through the standard fetch methods of the cursor. - """ - n = len(parameters) if parameters else 0 - s = ','.join(n * ['%s']) - query = f'select * from "{procname}"({s})' # noqa: S608 - self.execute(query, parameters) - return parameters - - # noinspection PyShadowingBuiltins - def copy_from(self, stream: Any, table: str, - format: str | None = None, sep: str | None = None, - null: str | None = None, size: int | None = None, - columns: Sequence[str] | None = None) -> Cursor: - """Copy data from an input stream to the specified table. - - The input stream can be a file-like object with a read() method or - it can also be an iterable returning a row or multiple rows of input - on each iteration. - - The format must be 'text', 'csv' or 'binary'. The sep option sets the - column separator (delimiter) used in the non binary formats. - The null option sets the textual representation of NULL in the input. - - The size option sets the size of the buffer used when reading data - from file-like objects. - - The copy operation can be restricted to a subset of columns. If no - columns are specified, all of them will be copied. - """ - binary_format = format == 'binary' - try: - read = stream.read - except AttributeError as e: - if size: - raise ValueError( - "Size must only be set for file-like objects") from e - input_type: type | tuple[type, ...] - type_name: str - if binary_format: - input_type = bytes - type_name = 'byte strings' - else: - input_type = (bytes, str) - type_name = 'strings' - - if isinstance(stream, (bytes, str)): - if not isinstance(stream, input_type): - raise ValueError(f"The input must be {type_name}") from e - if not binary_format: - if isinstance(stream, str): - if not stream.endswith('\n'): - stream += '\n' - else: - if not stream.endswith(b'\n'): - stream += b'\n' - - def chunks() -> Generator: - yield stream - - elif isinstance(stream, Iterable): - - def chunks() -> Generator: - for chunk in stream: - if not isinstance(chunk, input_type): - raise ValueError( - f"Input stream must consist of {type_name}") - if isinstance(chunk, str): - if not chunk.endswith('\n'): - chunk += '\n' - else: - if not chunk.endswith(b'\n'): - chunk += b'\n' - yield chunk - - else: - raise TypeError("Need an input stream to copy from") from e - else: - if size is None: - size = 8192 - elif not isinstance(size, int): - raise TypeError("The size option must be an integer") - if size > 0: - - def chunks() -> Generator: - while True: - buffer = read(size) - yield buffer - if not buffer or len(buffer) < size: - break - - else: - - def chunks() -> Generator: - yield read() - - if not table or not isinstance(table, str): - raise TypeError("Need a table to copy to") - if table.lower().startswith('select '): - raise ValueError("Must specify a table, not a query") - cnx = self._cnx - table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) - operation_parts = [f'copy {table}'] - options = [] - parameters = [] - if format is not None: - if not isinstance(format, str): - raise TypeError("The format option must be be a string") - if format not in ('text', 'csv', 'binary'): - raise ValueError("Invalid format") - options.append(f'format {format}') - if sep is not None: - if not isinstance(sep, str): - raise TypeError("The sep option must be a string") - if format == 'binary': - raise ValueError( - "The sep option is not allowed with binary format") - if len(sep) != 1: - raise ValueError( - "The sep option must be a single one-byte character") - options.append('delimiter %s') - parameters.append(sep) - if null is not None: - if not isinstance(null, str): - raise TypeError("The null option must be a string") - options.append('null %s') - parameters.append(null) - if columns: - if not isinstance(columns, str): - columns = ','.join(map(cnx.escape_identifier, columns)) - operation_parts.append(f'({columns})') - operation_parts.append("from stdin") - if options: - operation_parts.append(f"({','.join(options)})") - operation = ' '.join(operation_parts) - - putdata = self._src.putdata - self.execute(operation, parameters) - - try: - for chunk in chunks(): - putdata(chunk) - except BaseException as error: - self.rowcount = -1 - # the following call will re-raise the error - putdata(error) - else: - rowcount = putdata(None) - self.rowcount = -1 if rowcount is None else rowcount - - # return the cursor object, so you can chain operations - return self - - # noinspection PyShadowingBuiltins - def copy_to(self, stream: Any, table: str, - format: str | None = None, sep: str | None = None, - null: str | None = None, decode: bool | None = None, - columns: Sequence[str] | None = None) -> Cursor | Generator: - """Copy data from the specified table to an output stream. - - The output stream can be a file-like object with a write() method or - it can also be None, in which case the method will return a generator - yielding a row on each iteration. - - Output will be returned as byte strings unless you set decode to true. - - Note that you can also use a select query instead of the table name. - - The format must be 'text', 'csv' or 'binary'. The sep option sets the - column separator (delimiter) used in the non binary formats. - The null option sets the textual representation of NULL in the output. - - The copy operation can be restricted to a subset of columns. If no - columns are specified, all of them will be copied. - """ - binary_format = format == 'binary' - if stream is None: - write = None - else: - try: - write = stream.write - except AttributeError as e: - raise TypeError("Need an output stream to copy to") from e - if not table or not isinstance(table, str): - raise TypeError("Need a table to copy to") - cnx = self._cnx - if table.lower().startswith('select '): - if columns: - raise ValueError("Columns must be specified in the query") - table = f'({table})' - else: - table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) - operation_parts = [f'copy {table}'] - options = [] - parameters = [] - if format is not None: - if not isinstance(format, str): - raise TypeError("The format option must be a string") - if format not in ('text', 'csv', 'binary'): - raise ValueError("Invalid format") - options.append(f'format {format}') - if sep is not None: - if not isinstance(sep, str): - raise TypeError("The sep option must be a string") - if binary_format: - raise ValueError( - "The sep option is not allowed with binary format") - if len(sep) != 1: - raise ValueError( - "The sep option must be a single one-byte character") - options.append('delimiter %s') - parameters.append(sep) - if null is not None: - if not isinstance(null, str): - raise TypeError("The null option must be a string") - options.append('null %s') - parameters.append(null) - if decode is None: - decode = format != 'binary' - else: - if not isinstance(decode, (int, bool)): - raise TypeError("The decode option must be a boolean") - if decode and binary_format: - raise ValueError( - "The decode option is not allowed with binary format") - if columns: - if not isinstance(columns, str): - columns = ','.join(map(cnx.escape_identifier, columns)) - operation_parts.append(f'({columns})') - - operation_parts.append("to stdout") - if options: - operation_parts.append(f"({','.join(options)})") - operation = ' '.join(operation_parts) - - getdata = self._src.getdata - self.execute(operation, parameters) - - def copy() -> Generator: - self.rowcount = 0 - while True: - row = getdata(decode) - if isinstance(row, int): - if self.rowcount != row: - self.rowcount = row - break - self.rowcount += 1 - yield row - - if write is None: - # no input stream, return the generator - return copy() - - # write the rows to the file-like input stream - for row in copy(): - # noinspection PyUnboundLocalVariable - write(row) - - # return the cursor object, so you can chain operations - return self - - def __next__(self) -> Sequence: - """Return the next row (support for the iteration protocol).""" - res = self.fetchone() - if res is None: - raise StopIteration - return res - - # Note that the iterator protocol now uses __next()__ instead of next(), - # but we keep it for backward compatibility of pgdb. - next = __next__ - - @staticmethod - def nextset() -> bool | None: - """Not supported.""" - raise NotSupportedError("The nextset() method is not supported") - - @staticmethod - def setinputsizes(sizes: Sequence[int]) -> None: - """Not supported.""" - pass # unsupported, but silently passed - - @staticmethod - def setoutputsize(size: int, column: int = 0) -> None: - """Not supported.""" - pass # unsupported, but silently passed - - @staticmethod - def row_factory(row: Sequence) -> Sequence: - """Process rows before they are returned. - - You can overwrite this statically with a custom row factory, or - you can build a row factory dynamically with build_row_factory(). - - For example, you can create a Cursor class that returns rows as - Python dictionaries like this: - - class DictCursor(pgdb.Cursor): - - def row_factory(self, row): - return {desc[0]: value - for desc, value in zip(self.description, row)} - - cur = DictCursor(con) # get one DictCursor instance or - con.cursor_type = DictCursor # always use DictCursor instances - """ - raise NotImplementedError - - def build_row_factory(self) -> Callable[[Sequence], Sequence] | None: - """Build a row factory based on the current description. - - This implementation builds a row factory for creating named tuples. - You can overwrite this method if you want to dynamically create - different row factories whenever the column description changes. - """ - names = self.colnames - return _row_factory(tuple(names)) if names else None - - -CursorDescription = namedtuple('CursorDescription', ( - 'name', 'type_code', 'display_size', 'internal_size', - 'precision', 'scale', 'null_ok')) - - -# *** Connection Objects *** - -class Connection: - """Connection object.""" - - # expose the exceptions as attributes on the connection object - Error = Error - Warning = Warning - InterfaceError = InterfaceError - DatabaseError = DatabaseError - InternalError = InternalError - OperationalError = OperationalError - ProgrammingError = ProgrammingError - IntegrityError = IntegrityError - DataError = DataError - NotSupportedError = NotSupportedError - - def __init__(self, cnx: Cnx) -> None: - """Create a database connection object.""" - self._cnx: Cnx | None = cnx # connection - self._tnx = False # transaction state - self.type_cache = TypeCache(cnx) - self.cursor_type = Cursor - self.autocommit = False - try: - self._cnx.source() - except Exception as e: - raise _op_error("Invalid connection") from e - - def __enter__(self) -> Connection: - """Enter the runtime context for the connection object. - - The runtime context can be used for running transactions. - - This also starts a transaction in autocommit mode. - """ - if self.autocommit: - cnx = self._cnx - if not cnx: - raise _op_error("Connection has been closed") - try: - cnx.source().execute("BEGIN") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't start transaction") from e - else: - self._tnx = True - return self - - def __exit__(self, et: type[BaseException] | None, - ev: BaseException | None, tb: Any) -> None: - """Exit the runtime context for the connection object. - - This does not close the connection, but it ends a transaction. - """ - if et is None and ev is None and tb is None: - self.commit() - else: - self.rollback() - - def close(self) -> None: - """Close the connection object.""" - if not self._cnx: - raise _op_error("Connection has been closed") - if self._tnx: - with suppress(DatabaseError): - self.rollback() - self._cnx.close() - self._cnx = None - - @property - def closed(self) -> bool: - """Check whether the connection has been closed or is broken.""" - try: - return not self._cnx or self._cnx.status != 1 - except TypeError: - return True - - def commit(self) -> None: - """Commit any pending transaction to the database.""" - if not self._cnx: - raise _op_error("Connection has been closed") - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("COMMIT") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't commit transaction") from e - - def rollback(self) -> None: - """Roll back to the start of any pending transaction.""" - if not self._cnx: - raise _op_error("Connection has been closed") - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("ROLLBACK") - except DatabaseError: - raise # database provides error message - except Exception as e: - raise _op_error("Can't rollback transaction") from e - - def cursor(self) -> Cursor: - """Return a new cursor object using the connection.""" - if not self._cnx: - raise _op_error("Connection has been closed") - try: - return self.cursor_type(self) - except Exception as e: - raise _op_error("Invalid connection") from e - - if shortcutmethods: # otherwise do not implement and document this - - def execute(self, operation: str, - parameters: Sequence | None = None) -> Cursor: - """Shortcut method to run an operation on an implicit cursor.""" - cursor = self.cursor() - cursor.execute(operation, parameters) - return cursor - - def executemany(self, operation: str, - seq_of_parameters: Sequence[Sequence | None] - ) -> Cursor: - """Shortcut method to run an operation against a sequence.""" - cursor = self.cursor() - cursor.executemany(operation, seq_of_parameters) - return cursor - - -# *** Module Interface *** - -def connect(dsn: str | None = None, - user: str | None = None, password: str | None = None, - host: str | None = None, database: str | None = None, - **kwargs: Any) -> Connection: - """Connect to a database.""" - # first get params from DSN - dbport = -1 - dbhost: str | None = "" - dbname: str | None = "" - dbuser: str | None = "" - dbpasswd: str | None = "" - dbopt: str | None = "" - if dsn: - try: - params = dsn.split(":", 4) - dbhost = params[0] - dbname = params[1] - dbuser = params[2] - dbpasswd = params[3] - dbopt = params[4] - except (AttributeError, IndexError, TypeError): - pass - - # override if necessary - if user is not None: - dbuser = user - if password is not None: - dbpasswd = password - if database is not None: - dbname = database - if host: - try: - params = host.split(":", 1) - dbhost = params[0] - dbport = int(params[1]) - except (AttributeError, IndexError, TypeError, ValueError): - pass - - # empty host is localhost - if dbhost == "": - dbhost = None - if dbuser == "": - dbuser = None - - # pass keyword arguments as connection info string - if kwargs: - kwarg_list = list(kwargs.items()) - kw_parts = [] - if dbname and '=' in dbname: - kw_parts.append(dbname) - else: - kwarg_list.insert(0, ('dbname', dbname)) - for kw, value in kwarg_list: - value = str(value) - if not value or ' ' in value: - value = value.replace('\\', '\\\\').replace("'", "\\'") - value = f"'{value}'" - kw_parts.append(f'{kw}={value}') - dbname = ' '.join(kw_parts) - # open the connection - cnx = get_cnx(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) - return Connection(cnx) - - -# *** Types Handling *** - -class DbType(frozenset): - """Type class for a couple of PostgreSQL data types. - - PostgreSQL is object-oriented: types are dynamic. - We must thus use type names as internal type codes. - """ - - def __new__(cls, values: str | Iterable[str]) -> DbType: - """Create new type object.""" - if isinstance(values, str): - values = values.split() - return super().__new__(cls, values) # type: ignore - - def __eq__(self, other: Any) -> bool: - """Check whether types are considered equal.""" - if isinstance(other, str): - if other.startswith('_'): - other = other[1:] - return other in self - return super().__eq__(other) - - def __ne__(self, other: Any) -> bool: - """Check whether types are not considered equal.""" - if isinstance(other, str): - if other.startswith('_'): - other = other[1:] - return other not in self - return super().__ne__(other) - - -class ArrayType: - """Type class for PostgreSQL array types.""" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, str): - return other.startswith('_') - return isinstance(other, ArrayType) - - def __ne__(self, other: Any) -> bool: - if isinstance(other, str): - return not other.startswith('_') - return not isinstance(other, ArrayType) - - -class RecordType: - """Type class for PostgreSQL record types.""" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, TypeCode): - # noinspection PyUnresolvedReferences - return other.type == 'c' - if isinstance(other, str): - return other == 'record' - return isinstance(other, RecordType) - - def __ne__(self, other: Any) -> bool: - if isinstance(other, TypeCode): - # noinspection PyUnresolvedReferences - return other.type != 'c' - if isinstance(other, str): - return other != 'record' - return not isinstance(other, RecordType) - - -# Mandatory type objects defined by DB-API 2 specs: - -STRING = DbType('char bpchar name text varchar') -BINARY = DbType('bytea') -NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money') -DATETIME = DbType('date time timetz timestamp timestamptz interval' - ' abstime reltime') # these are very old -ROWID = DbType('oid') - - -# Additional type objects (more specific): - -BOOL = DbType('bool') -SMALLINT = DbType('int2') -INTEGER = DbType('int2 int4 int8 serial') -LONG = DbType('int8') -FLOAT = DbType('float4 float8') -NUMERIC = DbType('numeric') -MONEY = DbType('money') -DATE = DbType('date') -TIME = DbType('time timetz') -TIMESTAMP = DbType('timestamp timestamptz') -INTERVAL = DbType('interval') -UUID = DbType('uuid') -HSTORE = DbType('hstore') -JSON = DbType('json jsonb') - -# Type object for arrays (also equate to their base types): - -ARRAY = ArrayType() - -# Type object for records (encompassing all composite types): - -RECORD = RecordType() - - -# Mandatory type helpers defined by DB-API 2 specs: - -def Date(year: int, month: int, day: int) -> date: # noqa: N802 - """Construct an object holding a date value.""" - return date(year, month, day) - - -def Time(hour: int, minute: int = 0, # noqa: N802 - second: int = 0, microsecond: int = 0, - tzinfo: tzinfo | None = None) -> time: - """Construct an object holding a time value.""" - return time(hour, minute, second, microsecond, tzinfo) - - -def Timestamp(year: int, month: int, day: int, # noqa: N802 - hour: int = 0, minute: int = 0, - second: int = 0, microsecond: int = 0, - tzinfo: tzinfo | None = None) -> datetime: - """Construct an object holding a time stamp value.""" - return datetime(year, month, day, hour, minute, - second, microsecond, tzinfo) - - -def DateFromTicks(ticks: float | None) -> date: # noqa: N802 - """Construct an object holding a date value from the given ticks value.""" - return Date(*localtime(ticks)[:3]) - - -def TimeFromTicks(ticks: float | None) -> time: # noqa: N802 - """Construct an object holding a time value from the given ticks value.""" - return Time(*localtime(ticks)[3:6]) - - -def TimestampFromTicks(ticks: float | None) -> datetime: # noqa: N802 - """Construct an object holding a time stamp from the given ticks value.""" - return Timestamp(*localtime(ticks)[:6]) - - -class Binary(bytes): - """Construct an object capable of holding a binary (long) string value.""" - - -# Additional type helpers for PyGreSQL: - -def Interval(days: int | float, # noqa: N802 - hours: int | float = 0, minutes: int | float = 0, - seconds: int | float = 0, microseconds: int | float = 0 - ) -> timedelta: - """Construct an object holding a time interval value.""" - return timedelta(days, hours=hours, minutes=minutes, - seconds=seconds, microseconds=microseconds) - - -Uuid = Uuid # Construct an object holding a UUID value - - -class Hstore(dict): - """Wrapper class for marking hstore values.""" - - _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') - _re_escape = regex(r'(["\\])') - - @classmethod - def _quote(cls, s: Any) -> Any: - if s is None: - return 'NULL' - if not isinstance(s, str): - s = str(s) - if not s: - return '""' - quote = cls._re_quote.search(s) - s = cls._re_escape.sub(r'\\\1', s) - if quote: - s = f'"{s}"' - return s - - def __str__(self) -> str: - """Create a printable representation of the hstore value.""" - q = self._quote - return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) - - -class Json: - """Construct a wrapper for holding an object serializable to JSON.""" - - def __init__(self, obj: Any, - encode: Callable[[Any], str] | None = None) -> None: - """Initialize the JSON object.""" - self.obj = obj - self.encode = encode or jsonencode - - def __str__(self) -> str: - """Create a printable representation of the JSON object.""" - obj = self.obj - if isinstance(obj, str): - return obj - return self.encode(obj) - - -class Literal: - """Construct a wrapper for holding a literal SQL string.""" - - def __init__(self, sql: str) -> None: - """Initialize literal SQL string.""" - self.sql = sql - - def __str__(self) -> str: - """Return a printable representation of the SQL string.""" - return self.sql - - __pg_repr__ = __str__ - - -# If run as script, print some information: - -if __name__ == '__main__': - print('PyGreSQL version', version) - print() - print(__doc__) +__version__ = version diff --git a/pgdb/adapt.py b/pgdb/adapt.py new file mode 100644 index 00000000..92b48a7e --- /dev/null +++ b/pgdb/adapt.py @@ -0,0 +1,237 @@ +"""Type helpers for adaptation of parameters.""" + +from __future__ import annotations + +from datetime import date, datetime, time, timedelta, tzinfo +from json import dumps as jsonencode +from re import compile as regex +from time import localtime +from typing import Any, Callable, Iterable +from uuid import UUID as Uuid # noqa: N811 + +from .typecode import TypeCode + +__all__ = [ + 'DbType', 'ArrayType', 'RecordType', + 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', 'SMALLINT', + 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', 'DATE', 'TIME', + 'TIMESTAMP', 'INTERVAL', 'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', + 'Date', 'Time', 'Timestamp', + 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks' + +] + + +class DbType(frozenset): + """Type class for a couple of PostgreSQL data types. + + PostgreSQL is object-oriented: types are dynamic. + We must thus use type names as internal type codes. + """ + + def __new__(cls, values: str | Iterable[str]) -> DbType: + """Create new type object.""" + if isinstance(values, str): + values = values.split() + return super().__new__(cls, values) # type: ignore + + def __eq__(self, other: Any) -> bool: + """Check whether types are considered equal.""" + if isinstance(other, str): + if other.startswith('_'): + other = other[1:] + return other in self + return super().__eq__(other) + + def __ne__(self, other: Any) -> bool: + """Check whether types are not considered equal.""" + if isinstance(other, str): + if other.startswith('_'): + other = other[1:] + return other not in self + return super().__ne__(other) + + +class ArrayType: + """Type class for PostgreSQL array types.""" + + def __eq__(self, other: Any) -> bool: + """Check whether arrays are equal.""" + if isinstance(other, str): + return other.startswith('_') + return isinstance(other, ArrayType) + + def __ne__(self, other: Any) -> bool: + """Check whether arrays are different.""" + if isinstance(other, str): + return not other.startswith('_') + return not isinstance(other, ArrayType) + + +class RecordType: + """Type class for PostgreSQL record types.""" + + def __eq__(self, other: Any) -> bool: + """Check whether records are equal.""" + if isinstance(other, TypeCode): + return other.type == 'c' + if isinstance(other, str): + return other == 'record' + return isinstance(other, RecordType) + + def __ne__(self, other: Any) -> bool: + """Check whether records are different.""" + if isinstance(other, TypeCode): + return other.type != 'c' + if isinstance(other, str): + return other != 'record' + return not isinstance(other, RecordType) + + +# Mandatory type objects defined by DB-API 2 specs: + +STRING = DbType('char bpchar name text varchar') +BINARY = DbType('bytea') +NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money') +DATETIME = DbType('date time timetz timestamp timestamptz interval' + ' abstime reltime') # these are very old +ROWID = DbType('oid') + + +# Additional type objects (more specific): + +BOOL = DbType('bool') +SMALLINT = DbType('int2') +INTEGER = DbType('int2 int4 int8 serial') +LONG = DbType('int8') +FLOAT = DbType('float4 float8') +NUMERIC = DbType('numeric') +MONEY = DbType('money') +DATE = DbType('date') +TIME = DbType('time timetz') +TIMESTAMP = DbType('timestamp timestamptz') +INTERVAL = DbType('interval') +UUID = DbType('uuid') +HSTORE = DbType('hstore') +JSON = DbType('json jsonb') + +# Type object for arrays (also equate to their base types): + +ARRAY = ArrayType() + +# Type object for records (encompassing all composite types): + +RECORD = RecordType() + + +# Mandatory type helpers defined by DB-API 2 specs: + +def Date(year: int, month: int, day: int) -> date: # noqa: N802 + """Construct an object holding a date value.""" + return date(year, month, day) + + +def Time(hour: int, minute: int = 0, # noqa: N802 + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> time: + """Construct an object holding a time value.""" + return time(hour, minute, second, microsecond, tzinfo) + + +def Timestamp(year: int, month: int, day: int, # noqa: N802 + hour: int = 0, minute: int = 0, + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> datetime: + """Construct an object holding a time stamp value.""" + return datetime(year, month, day, hour, minute, + second, microsecond, tzinfo) + + +def DateFromTicks(ticks: float | None) -> date: # noqa: N802 + """Construct an object holding a date value from the given ticks value.""" + return Date(*localtime(ticks)[:3]) + + +def TimeFromTicks(ticks: float | None) -> time: # noqa: N802 + """Construct an object holding a time value from the given ticks value.""" + return Time(*localtime(ticks)[3:6]) + + +def TimestampFromTicks(ticks: float | None) -> datetime: # noqa: N802 + """Construct an object holding a time stamp from the given ticks value.""" + return Timestamp(*localtime(ticks)[:6]) + + +class Binary(bytes): + """Construct an object capable of holding a binary (long) string value.""" + + +# Additional type helpers for PyGreSQL: + +def Interval(days: int | float, # noqa: N802 + hours: int | float = 0, minutes: int | float = 0, + seconds: int | float = 0, microseconds: int | float = 0 + ) -> timedelta: + """Construct an object holding a time interval value.""" + return timedelta(days, hours=hours, minutes=minutes, + seconds=seconds, microseconds=microseconds) + + +Uuid = Uuid # Construct an object holding a UUID value + + +class Hstore(dict): + """Wrapper class for marking hstore values.""" + + _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') + _re_escape = regex(r'(["\\])') + + @classmethod + def _quote(cls, s: Any) -> Any: + if s is None: + return 'NULL' + if not isinstance(s, str): + s = str(s) + if not s: + return '""' + quote = cls._re_quote.search(s) + s = cls._re_escape.sub(r'\\\1', s) + if quote: + s = f'"{s}"' + return s + + def __str__(self) -> str: + """Create a printable representation of the hstore value.""" + q = self._quote + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) + + +class Json: + """Construct a wrapper for holding an object serializable to JSON.""" + + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: + """Initialize the JSON object.""" + self.obj = obj + self.encode = encode or jsonencode + + def __str__(self) -> str: + """Create a printable representation of the JSON object.""" + obj = self.obj + if isinstance(obj, str): + return obj + return self.encode(obj) + + +class Literal: + """Construct a wrapper for holding a literal SQL string.""" + + def __init__(self, sql: str) -> None: + """Initialize literal SQL string.""" + self.sql = sql + + def __str__(self) -> str: + """Return a printable representation of the SQL string.""" + return self.sql + + __pg_repr__ = __str__ \ No newline at end of file diff --git a/pgdb/cast.py b/pgdb/cast.py new file mode 100644 index 00000000..03367506 --- /dev/null +++ b/pgdb/cast.py @@ -0,0 +1,581 @@ +"""Internal type handling.""" + +from __future__ import annotations + +from collections import namedtuple +from datetime import date, datetime, time, timedelta +from decimal import Decimal as _Decimal +from functools import partial +from inspect import signature +from json import loads as jsondecode +from re import compile as regex +from typing import Any, Callable, ClassVar, Sequence +from uuid import UUID as Uuid # noqa: N811 + +from pg.core import Connection as Cnx +from pg.core import ( + ProgrammingError, + cast_array, + cast_hstore, + cast_record, + unescape_bytea, +) + +from .typecode import TypeCode + +__all__ = [ + 'Decimal', 'decimal_type', 'cast_bool', 'cast_money', + 'cast_int2vector', 'cast_date', 'cast_time', 'cast_interval', + 'cast_timetz', 'cast_timestamp', 'cast_timestamptz', + 'get_typecast', 'set_typecast', 'reset_typecast', + 'Typecasts', 'LocalTypecasts', 'TypeCache', 'FieldInfo' +] + + +Decimal: type = _Decimal + + +def get_args(func: Callable) -> list: + return list(signature(func).parameters) + + +# time zones used in Postgres timestamptz output +_timezones: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} + + +def _timezone_as_offset(tz: str) -> str: + if tz.startswith(('+', '-')): + if len(tz) < 5: + return tz + '00' + return tz.replace(':', '') + return _timezones.get(tz, '+0000') + + +def decimal_type(decimal_type: type | None = None) -> type: + """Get or set global type to be used for decimal values. + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + global Decimal + if decimal_type is not None: + Decimal = decimal_type + set_typecast('numeric', decimal_type) + return Decimal + + +def cast_bool(value: str) -> bool | None: + """Cast boolean value in database format to bool.""" + return value[0] in ('t', 'T') if value else None + + +def cast_money(value: str) -> _Decimal | None: + """Cast money value in database format to Decimal.""" + if not value: + return None + value = value.replace('(', '-') + return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) + + +def cast_int2vector(value: str) -> list[int]: + """Cast an int2vector value.""" + return [int(v) for v in value.split()] + + +def cast_date(value: str, cnx: Cnx) -> date: + """Cast a date value.""" + # The output format depends on the server setting DateStyle. The default + # setting ISO and the setting for German are actually unambiguous. The + # order of days and months in the other two settings is however ambiguous, + # so at least here we need to consult the setting to properly parse values. + if value == '-infinity': + return date.min + if value == 'infinity': + return date.max + values = value.split() + if values[-1] == 'BC': + return date.min + value = values[0] + if len(value) > 10: + return date.max + format = cnx.date_format() + return datetime.strptime(value, format).date() + + +def cast_time(value: str) -> time: + """Cast a time value.""" + fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, fmt).time() + + +_re_timezone = regex('(.*)([+-].*)') + + +def cast_timetz(value: str) -> time: + """Cast a timetz value.""" + m = _re_timezone.match(value) + if m: + value, tz = m.groups() + else: + tz = '+0000' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + value += _timezone_as_offset(tz) + format += '%z' + return datetime.strptime(value, format).timetz() + + +def cast_timestamp(value: str, cnx: Cnx) -> datetime: + """Cast a timestamp value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = cnx.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + else: + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +def cast_timestamptz(value: str, cnx: Cnx) -> datetime: + """Cast a timestamptz value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = cnx.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] + else: + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() + else: + tz = '+0000' + else: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(_timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +_re_interval_sql_standard = regex( + '(?:([+-])?([0-9]+)-([0-9]+) ?)?' + '(?:([+-]?[0-9]+)(?!:) ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres = regex( + '(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres_verbose = regex( + '@ ?(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-]?[0-9]+) ?hours? ?)?' + '(?:([+-]?[0-9]+) ?mins? ?)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') + +_re_interval_iso_8601 = regex( + 'P(?:([+-]?[0-9]+)Y)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-]?[0-9]+)D)?' + '(?:T(?:([+-]?[0-9]+)H)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') + + +def cast_interval(value: str) -> timedelta: + """Cast an interval value.""" + # The output format depends on the server setting IntervalStyle, but it's + # not necessary to consult this setting to parse it. It's faster to just + # check all possible formats, and there is no ambiguity here. + m = _re_interval_iso_8601.match(value) + if m: + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = -secs + usecs = -usecs + else: + m = _re_interval_postgres_verbose.match(value) + if m: + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = - secs + usecs = -usecs + else: + m = _re_interval_postgres.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + m = _re_interval_sql_standard.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if years_ago: + years = -years + mons = -mons + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + raise ValueError(f'Cannot parse interval: {value}') + days += 365 * years + 30 * mons + return timedelta(days=days, hours=hours, minutes=mins, + seconds=secs, microseconds=usecs) + + +class Typecasts(dict): + """Dictionary mapping database types to typecast functions. + + The cast functions get passed the string representation of a value in + the database which they need to convert to a Python object. The + passed string will never be None since NULL values are already + handled before the cast function is called. + """ + + # the default cast functions + # (str functions are ignored but have been added for faster access) + defaults: ClassVar[dict[str, Callable]] = { + 'char': str, 'bpchar': str, 'name': str, + 'text': str, 'varchar': str, 'sql_identifier': str, + 'bool': cast_bool, 'bytea': unescape_bytea, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, + 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, + 'float4': float, 'float8': float, + 'numeric': Decimal, 'money': cast_money, + 'date': cast_date, 'interval': cast_interval, + 'time': cast_time, 'timetz': cast_timetz, + 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, + 'int2vector': cast_int2vector, 'uuid': Uuid, + 'anyarray': cast_array, 'record': cast_record} + + cnx: Cnx | None = None # for local connection specific instances + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached. + + Note that this class never raises a KeyError, + but returns None when no special cast function exists. + """ + if not isinstance(typ, str): + raise TypeError(f'Invalid type: {typ}') + cast = self.defaults.get(typ) + if cast: + # store default for faster access + cast = self._add_connection(cast) + self[typ] = cast + elif typ.startswith('_'): + # create array cast + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + # store only if base type exists + self[typ] = cast + return cast + + @staticmethod + def _needs_connection(func: Callable) -> bool: + """Check if a typecast function needs a connection argument.""" + try: + args = get_args(func) + except (TypeError, ValueError): + return False + return 'cnx' in args[1:] + + def _add_connection(self, cast: Callable) -> Callable: + """Add a connection argument to the typecast function if necessary.""" + if not self.cnx or not self._needs_connection(cast): + return cast + return partial(cast, cnx=self.cnx) + + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: + """Get the typecast function for the given database type.""" + return self[typ] or default + + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + if isinstance(typ, str): + typ = [typ] + if cast is None: + for t in typ: + self.pop(t, None) + self.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + self[t] = self._add_connection(cast) + self.pop(f'_{t}', None) + + def reset(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecasts for the specified type(s) to their defaults. + + When no type is specified, all typecasts will be reset. + """ + defaults = self.defaults + if typ is None: + self.clear() + self.update(defaults) + else: + if isinstance(typ, str): + typ = [typ] + for t in typ: + cast = defaults.get(t) + if cast: + self[t] = self._add_connection(cast) + t = f'_{t}' + cast = defaults.get(t) + if cast: + self[t] = self._add_connection(cast) + else: + self.pop(t, None) + else: + self.pop(t, None) + self.pop(f'_{t}', None) + + def create_array_cast(self, basecast: Callable) -> Callable: + """Create an array typecast for the given base cast.""" + cast_array = self['anyarray'] + + def cast(v: Any) -> list: + return cast_array(v, basecast) + return cast + + def create_record_cast(self, name: str, fields: Sequence[str], + casts: Sequence[str]) -> Callable: + """Create a named record typecast for the given fields and casts.""" + cast_record = self['record'] + record = namedtuple(name, fields) # type: ignore + + def cast(v: Any) -> record: + # noinspection PyArgumentList + return record(*cast_record(v, casts)) + return cast + + +_typecasts = Typecasts() # this is the global typecast dictionary + + +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" + return _typecasts.get(typ) + + +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a global typecast function for the given database type(s). + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + _typecasts.set(typ, cast) + + +def reset_typecast(typ: str | Sequence[str] | None = None) -> None: + """Reset the global typecasts for the given type(s) to their default. + + When no type is specified, all typecasts will be reset. + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + _typecasts.reset(typ) + + +class LocalTypecasts(Typecasts): + """Map typecasts, including local composite types, to cast functions.""" + + defaults = _typecasts + + cnx: Cnx | None = None # set in connection specific instances + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached.""" + cast: Callable | None + if typ.startswith('_'): + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + self[typ] = cast + else: + cast = self.defaults.get(typ) + if cast: + cast = self._add_connection(cast) + self[typ] = cast + else: + fields = self.get_fields(typ) + if fields: + casts = [self[field.type] for field in fields] + field_names = [field.name for field in fields] + cast = self.create_record_cast(typ, field_names, casts) + self[typ] = cast + return cast + + # noinspection PyMethodMayBeStatic,PyUnusedLocal + def get_fields(self, typ: str) -> list[FieldInfo]: + """Return the fields for the given record type. + + This method will be replaced with a method that looks up the fields + using the type cache of the connection. + """ + return [] + + +FieldInfo = namedtuple('FieldInfo', ('name', 'type')) + + +class TypeCache(dict): + """Cache for database types. + + This cache maps type OIDs and names to TypeCode strings containing + important information on the associated database type. + """ + + def __init__(self, cnx: Cnx) -> None: + """Initialize type cache for connection.""" + super().__init__() + self._escape_string = cnx.escape_string + self._src = cnx.source() + self._typecasts = LocalTypecasts() + self._typecasts.get_fields = self.get_fields # type: ignore + self._typecasts.cnx = cnx + self._query_pg_type = ( + "SELECT oid, typname," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") + + def __missing__(self, key: int | str) -> TypeCode: + """Get the type info from the database if it is not cached.""" + oid: int | str + if isinstance(key, int): + oid = key + else: + if '.' not in key and '"' not in key: + key = f'"{key}"' + oid = f"'{self._escape_string(key)}'::pg_catalog.regtype" + try: + self._src.execute(self._query_pg_type.format(oid)) + except ProgrammingError: + res = None + else: + res = self._src.fetch(1) + if not res: + raise KeyError(f'Type {key} could not be found') + r = res[0] + type_code = TypeCode.create( + int(r[0]), r[1], int(r[2]), r[3], r[4], r[5], int(r[6])) + # noinspection PyUnresolvedReferences + self[type_code.oid] = self[str(type_code)] = type_code + return type_code + + def get(self, key: int | str, # type: ignore + default: TypeCode | None = None) -> TypeCode | None: + """Get the type even if it is not cached.""" + try: + return self[key] + except KeyError: + return default + + def get_fields(self, typ: int | str | TypeCode) -> list[FieldInfo] | None: + """Get the names and types of the fields of composite types.""" + if isinstance(typ, TypeCode): + relid = typ.relid + else: + type_code = self.get(typ) + if not type_code: + return None + relid = type_code.relid + if not relid: + return None # this type is not composite + self._src.execute( + "SELECT attname, atttypid" # noqa: S608 + " FROM pg_catalog.pg_attribute" + f" WHERE attrelid OPERATOR(pg_catalog.=) {relid}" + " AND attnum OPERATOR(pg_catalog.>) 0" + " AND NOT attisdropped ORDER BY attnum") + return [FieldInfo(name, self.get(int(oid))) + for name, oid in self._src.fetch(-1)] + + def get_typecast(self, typ: str) -> Callable | None: + """Get the typecast function for the given database type.""" + return self._typecasts[typ] + + def set_typecast(self, typ: str | Sequence[str], + cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + self._typecasts.set(typ, cast) + + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecast function for the specified database type(s).""" + self._typecasts.reset(typ) + + def typecast(self, value: Any, typ: str) -> Any: + """Cast the given value according to the given database type.""" + if value is None: + # for NULL values, no typecast is necessary + return None + cast = self._typecasts[typ] + if cast is None or cast is str: + # no typecast is necessary + return value + return cast(value) + + def get_row_caster(self, types: Sequence[str]) -> Callable: + """Get a typecast function for a complete row of values.""" + typecasts = self._typecasts + casts = [typecasts[typ] for typ in types] + casts = [cast if cast is not str else None for cast in casts] + + def row_caster(row: Sequence) -> Sequence: + return [value if cast is None or value is None else cast(value) + for cast, value in zip(casts, row)] + + return row_caster \ No newline at end of file diff --git a/pgdb/connect.py b/pgdb/connect.py new file mode 100644 index 00000000..73b96a36 --- /dev/null +++ b/pgdb/connect.py @@ -0,0 +1,74 @@ +"""The DB API 2 connect function.""" + +from __future__ import annotations + +from typing import Any + +from pg.core import connect as get_cnx + +from .connection import Connection + +__all__ = ['connect'] + +def connect(dsn: str | None = None, + user: str | None = None, password: str | None = None, + host: str | None = None, database: str | None = None, + **kwargs: Any) -> Connection: + """Connect to a database.""" + # first get params from DSN + dbport = -1 + dbhost: str | None = "" + dbname: str | None = "" + dbuser: str | None = "" + dbpasswd: str | None = "" + dbopt: str | None = "" + if dsn: + try: + params = dsn.split(":", 4) + dbhost = params[0] + dbname = params[1] + dbuser = params[2] + dbpasswd = params[3] + dbopt = params[4] + except (AttributeError, IndexError, TypeError): + pass + + # override if necessary + if user is not None: + dbuser = user + if password is not None: + dbpasswd = password + if database is not None: + dbname = database + if host: + try: + params = host.split(":", 1) + dbhost = params[0] + dbport = int(params[1]) + except (AttributeError, IndexError, TypeError, ValueError): + pass + + # empty host is localhost + if dbhost == "": + dbhost = None + if dbuser == "": + dbuser = None + + # pass keyword arguments as connection info string + if kwargs: + kwarg_list = list(kwargs.items()) + kw_parts = [] + if dbname and '=' in dbname: + kw_parts.append(dbname) + else: + kwarg_list.insert(0, ('dbname', dbname)) + for kw, value in kwarg_list: + value = str(value) + if not value or ' ' in value: + value = value.replace('\\', '\\\\').replace("'", "\\'") + value = f"'{value}'" + kw_parts.append(f'{kw}={value}') + dbname = ' '.join(kw_parts) + # open the connection + cnx = get_cnx(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) + return Connection(cnx) diff --git a/pgdb/connection.py b/pgdb/connection.py new file mode 100644 index 00000000..17d32bcc --- /dev/null +++ b/pgdb/connection.py @@ -0,0 +1,156 @@ +"""The DB API 2 Connection objects.""" + +from __future__ import annotations + +from contextlib import suppress +from typing import Any, Sequence + +from pg.core import Connection as Cnx +from pg.core import ( + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, +) +from pg.error import op_error + +from .cast import TypeCache +from .constants import shortcutmethods +from .cursor import Cursor + +__all__ = ['Connection'] + +class Connection: + """Connection object.""" + + # expose the exceptions as attributes on the connection object + Error = Error + Warning = Warning + InterfaceError = InterfaceError + DatabaseError = DatabaseError + InternalError = InternalError + OperationalError = OperationalError + ProgrammingError = ProgrammingError + IntegrityError = IntegrityError + DataError = DataError + NotSupportedError = NotSupportedError + + def __init__(self, cnx: Cnx) -> None: + """Create a database connection object.""" + self._cnx: Cnx | None = cnx # connection + self._tnx = False # transaction state + self.type_cache = TypeCache(cnx) + self.cursor_type = Cursor + self.autocommit = False + try: + self._cnx.source() + except Exception as e: + raise op_error("Invalid connection") from e + + def __enter__(self) -> Connection: + """Enter the runtime context for the connection object. + + The runtime context can be used for running transactions. + + This also starts a transaction in autocommit mode. + """ + if self.autocommit: + cnx = self._cnx + if not cnx: + raise op_error("Connection has been closed") + try: + cnx.source().execute("BEGIN") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't start transaction") from e + else: + self._tnx = True + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context for the connection object. + + This does not close the connection, but it ends a transaction. + """ + if et is None and ev is None and tb is None: + self.commit() + else: + self.rollback() + + def close(self) -> None: + """Close the connection object.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + with suppress(DatabaseError): + self.rollback() + self._cnx.close() + self._cnx = None + + @property + def closed(self) -> bool: + """Check whether the connection has been closed or is broken.""" + try: + return not self._cnx or self._cnx.status != 1 + except TypeError: + return True + + def commit(self) -> None: + """Commit any pending transaction to the database.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("COMMIT") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't commit transaction") from e + + def rollback(self) -> None: + """Roll back to the start of any pending transaction.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("ROLLBACK") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't rollback transaction") from e + + def cursor(self) -> Cursor: + """Return a new cursor object using the connection.""" + if not self._cnx: + raise op_error("Connection has been closed") + try: + return self.cursor_type(self) + except Exception as e: + raise op_error("Invalid connection") from e + + if shortcutmethods: # otherwise do not implement and document this + + def execute(self, operation: str, + parameters: Sequence | None = None) -> Cursor: + """Shortcut method to run an operation on an implicit cursor.""" + cursor = self.cursor() + cursor.execute(operation, parameters) + return cursor + + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None] + ) -> Cursor: + """Shortcut method to run an operation against a sequence.""" + cursor = self.cursor() + cursor.executemany(operation, seq_of_parameters) + return cursor \ No newline at end of file diff --git a/pgdb/constants.py b/pgdb/constants.py new file mode 100644 index 00000000..e6547f9c --- /dev/null +++ b/pgdb/constants.py @@ -0,0 +1,14 @@ +"""The DB API 2 module constants.""" + +# compliant with DB API 2.0 +apilevel = '2.0' + +# module may be shared, but not connections +threadsafety = 1 + +# this module use extended python format codes +paramstyle = 'pyformat' + +# shortcut methods have been excluded from DB API 2 and +# are not recommended by the DB SIG, but they can be handy +shortcutmethods = 1 diff --git a/pgdb/cursor.py b/pgdb/cursor.py new file mode 100644 index 00000000..753f4691 --- /dev/null +++ b/pgdb/cursor.py @@ -0,0 +1,645 @@ +"""The DB API 2 Cursor object.""" + +from __future__ import annotations + +from collections import namedtuple +from collections.abc import Iterable +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from math import isinf, isnan +from typing import TYPE_CHECKING, Any, Callable, Generator, Mapping, Sequence +from uuid import UUID as Uuid # noqa: N811 + +from pg.core import ( + RESULT_DQL, + DatabaseError, + Error, + InterfaceError, + NotSupportedError, +) +from pg.core import Connection as Cnx +from pg.error import db_error, if_error, op_error +from pg.helpers import QuoteDict, RowCache + +from .adapt import Binary, Hstore, Json, Literal +from .cast import TypeCache +from .typecode import TypeCode + +if TYPE_CHECKING: + from .connection import Connection + +__all__ = ['Cursor', 'CursorDescription'] + + +class Cursor: + """Cursor object.""" + + def __init__(self, connection: Connection) -> None: + """Create a cursor object for the database connection.""" + self.connection = self._connection = connection + cnx = connection._cnx + if not cnx: + raise op_error("Connection has been closed") + self._cnx: Cnx = cnx + self.type_cache: TypeCache = connection.type_cache + self._src = self._cnx.source() + # the official attribute for describing the result columns + self._description: list[CursorDescription] | bool | None = None + if self.row_factory is Cursor.row_factory: + # the row factory needs to be determined dynamically + self.row_factory = None # type: ignore + else: + self.build_row_factory = None # type: ignore + self.rowcount: int | None = -1 + self.arraysize: int = 1 + self.lastrowid: int | None = None + + def __iter__(self) -> Cursor: + """Make cursor compatible to the iteration protocol.""" + return self + + def __enter__(self) -> Cursor: + """Enter the runtime context for the cursor object.""" + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context for the cursor object.""" + self.close() + + def _quote(self, value: Any) -> Any: + """Quote value depending on its type.""" + if value is None: + return 'NULL' + if isinstance(value, (Hstore, Json)): + value = str(value) + if isinstance(value, (bytes, str)): + cnx = self._cnx + if isinstance(value, Binary): + value = cnx.escape_bytea(value).decode('ascii') + else: + value = cnx.escape_string(value) + return f"'{value}'" + if isinstance(value, float): + if isinf(value): + return "'-Infinity'" if value < 0 else "'Infinity'" + if isnan(value): + return "'NaN'" + return value + if isinstance(value, (int, Decimal, Literal)): + return value + if isinstance(value, datetime): + if value.tzinfo: + return f"'{value}'::timestamptz" + return f"'{value}'::timestamp" + if isinstance(value, date): + return f"'{value}'::date" + if isinstance(value, time): + if value.tzinfo: + return f"'{value}'::timetz" + return f"'{value}'::time" + if isinstance(value, timedelta): + return f"'{value}'::interval" + if isinstance(value, Uuid): + return f"'{value}'::uuid" + if isinstance(value, list): + # Quote value as an ARRAY constructor. This is better than using + # an array literal because it carries the information that this is + # an array and not a string. One issue with this syntax is that + # you need to add an explicit typecast when passing empty arrays. + # The ARRAY keyword is actually only necessary at the top level. + if not value: # exception for empty array + return "'{}'" + q = self._quote + v = ','.join(str(q(v)) for v in value) + return f'ARRAY[{v}]' + if isinstance(value, tuple): + # Quote as a ROW constructor. This is better than using a record + # literal because it carries the information that this is a record + # and not a string. We don't use the keyword ROW in order to make + # this usable with the IN syntax as well. It is only necessary + # when the records has a single column which is not really useful. + q = self._quote + v = ','.join(str(q(v)) for v in value) + return f'({v})' + try: # noinspection PyUnresolvedReferences + value = value.__pg_repr__() + except AttributeError as e: + raise InterfaceError( + f'Do not know how to adapt type {type(value)}') from e + if isinstance(value, (tuple, list)): + value = self._quote(value) + return value + + def _quoteparams(self, string: str, + parameters: Mapping | Sequence | None) -> str: + """Quote parameters. + + This function works for both mappings and sequences. + + The function should be used even when there are no parameters, + so that we have a consistent behavior regarding percent signs. + """ + if not parameters: + try: + return string % () # unescape literal quotes if possible + except (TypeError, ValueError): + return string # silently accept unescaped quotes + if isinstance(parameters, dict): + parameters = QuoteDict(parameters) + parameters.quote = self._quote + else: + parameters = tuple(map(self._quote, parameters)) + return string % parameters + + def _make_description(self, info: tuple[int, str, int, int, int] + ) -> CursorDescription: + """Make the description tuple for the given field info.""" + name, typ, size, mod = info[1:] + type_code = self.type_cache[typ] + if mod > 0: + mod -= 4 + precision: int | None + scale: int | None + if type_code == 'numeric': + precision, scale = mod >> 16, mod & 0xffff + size = precision + else: + if not size: + size = type_code.size + if size == -1: + size = mod + precision = scale = None + return CursorDescription( + name, type_code, None, size, precision, scale, None) + + @property + def description(self) -> list[CursorDescription] | None: + """Read-only attribute describing the result columns.""" + description = self._description + if description is None: + return None + if not isinstance(description, list): + make = self._make_description + description = [make(info) for info in self._src.listinfo()] + self._description = description + return description + + @property + def colnames(self) -> Sequence[str] | None: + """Unofficial convenience method for getting the column names.""" + description = self.description + return None if description is None else [d[0] for d in description] + + @property + def coltypes(self) -> Sequence[TypeCode] | None: + """Unofficial convenience method for getting the column types.""" + description = self.description + return None if description is None else [d[1] for d in description] + + def close(self) -> None: + """Close the cursor object.""" + self._src.close() + + def execute(self, operation: str, parameters: Sequence | None = None + ) -> Cursor: + """Prepare and execute a database operation (query or command).""" + # The parameters may also be specified as list of tuples to e.g. + # insert multiple rows in a single operation, but this kind of + # usage is deprecated. We make several plausibility checks because + # tuples can also be passed with the meaning of ROW constructors. + if (parameters and isinstance(parameters, list) + and len(parameters) > 1 + and all(isinstance(p, tuple) for p in parameters) + and all(len(p) == len(parameters[0]) for p in parameters[1:])): + return self.executemany(operation, parameters) + # not a list of tuples + return self.executemany(operation, [parameters]) + + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None]) -> Cursor: + """Prepare operation and execute it against a parameter sequence.""" + if not seq_of_parameters: + # don't do anything without parameters + return self + self._description = None + self.rowcount = -1 + # first try to execute all queries + rowcount = 0 + sql = "BEGIN" + try: + if not self._connection._tnx and not self._connection.autocommit: + try: + self._src.execute(sql) + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't start transaction") from e + else: + self._connection._tnx = True + for parameters in seq_of_parameters: + sql = operation + sql = self._quoteparams(sql, parameters) + rows = self._src.execute(sql) + if rows: # true if not DML + rowcount += rows + else: + self.rowcount = -1 + except DatabaseError: + raise # database provides error message + except Error as err: + # noinspection PyTypeChecker + raise if_error(f"Error in '{sql}': '{err}'") from err + except Exception as err: + raise op_error(f"Internal error in '{sql}': {err}") from err + # then initialize result raw count and description + if self._src.resulttype == RESULT_DQL: + self._description = True # fetch on demand + self.rowcount = self._src.ntuples + self.lastrowid = None + build_row_factory = self.build_row_factory + if build_row_factory: # type: ignore + self.row_factory = build_row_factory() # type: ignore + else: + self.rowcount = rowcount + self.lastrowid = self._src.oidstatus() + # return the cursor object, so you can write statements such as + # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" + return self + + def fetchone(self) -> Sequence | None: + """Fetch the next row of a query result set.""" + res = self.fetchmany(1, False) + try: + return res[0] + except IndexError: + return None + + def fetchall(self) -> Sequence[Sequence]: + """Fetch all (remaining) rows of a query result.""" + return self.fetchmany(-1, False) + + def fetchmany(self, size: int | None = None, keep: bool = False + ) -> Sequence[Sequence]: + """Fetch the next set of rows of a query result. + + The number of rows to fetch per call is specified by the + size parameter. If it is not given, the cursor's arraysize + determines the number of rows to be fetched. If you set + the keep parameter to true, this is kept as new arraysize. + """ + if size is None: + size = self.arraysize + if keep: + self.arraysize = size + try: + result = self._src.fetch(size) + except DatabaseError: + raise + except Error as err: + raise db_error(str(err)) from err + row_factory = self.row_factory + coltypes = self.coltypes + if coltypes is None: + # cannot determine column types, return raw result + return [row_factory(row) for row in result] + if len(result) > 5: + # optimize the case where we really fetch many values + # by looking up all type casting functions upfront + cast_row = self.type_cache.get_row_caster(coltypes) + return [row_factory(cast_row(row)) for row in result] + cast_value = self.type_cache.typecast + return [row_factory([cast_value(value, typ) + for typ, value in zip(coltypes, row)]) for row in result] + + def callproc(self, procname: str, parameters: Sequence | None = None + ) -> Sequence | None: + """Call a stored database procedure with the given name. + + The sequence of parameters must contain one entry for each input + argument that the procedure expects. The result of the call is the + same as this input sequence; replacement of output and input/output + parameters in the return value is currently not supported. + + The procedure may also provide a result set as output. These can be + requested through the standard fetch methods of the cursor. + """ + n = len(parameters) if parameters else 0 + s = ','.join(n * ['%s']) + query = f'select * from "{procname}"({s})' # noqa: S608 + self.execute(query, parameters) + return parameters + + # noinspection PyShadowingBuiltins + def copy_from(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, size: int | None = None, + columns: Sequence[str] | None = None) -> Cursor: + """Copy data from an input stream to the specified table. + + The input stream can be a file-like object with a read() method or + it can also be an iterable returning a row or multiple rows of input + on each iteration. + + The format must be 'text', 'csv' or 'binary'. The sep option sets the + column separator (delimiter) used in the non binary formats. + The null option sets the textual representation of NULL in the input. + + The size option sets the size of the buffer used when reading data + from file-like objects. + + The copy operation can be restricted to a subset of columns. If no + columns are specified, all of them will be copied. + """ + binary_format = format == 'binary' + try: + read = stream.read + except AttributeError as e: + if size: + raise ValueError( + "Size must only be set for file-like objects") from e + input_type: type | tuple[type, ...] + type_name: str + if binary_format: + input_type = bytes + type_name = 'byte strings' + else: + input_type = (bytes, str) + type_name = 'strings' + + if isinstance(stream, (bytes, str)): + if not isinstance(stream, input_type): + raise ValueError(f"The input must be {type_name}") from e + if not binary_format: + if isinstance(stream, str): + if not stream.endswith('\n'): + stream += '\n' + else: + if not stream.endswith(b'\n'): + stream += b'\n' + + def chunks() -> Generator: + yield stream + + elif isinstance(stream, Iterable): + + def chunks() -> Generator: + for chunk in stream: + if not isinstance(chunk, input_type): + raise ValueError( + f"Input stream must consist of {type_name}") + if isinstance(chunk, str): + if not chunk.endswith('\n'): + chunk += '\n' + else: + if not chunk.endswith(b'\n'): + chunk += b'\n' + yield chunk + + else: + raise TypeError("Need an input stream to copy from") from e + else: + if size is None: + size = 8192 + elif not isinstance(size, int): + raise TypeError("The size option must be an integer") + if size > 0: + + def chunks() -> Generator: + while True: + buffer = read(size) + yield buffer + if not buffer or len(buffer) < size: + break + + else: + + def chunks() -> Generator: + yield read() + + if not table or not isinstance(table, str): + raise TypeError("Need a table to copy to") + if table.lower().startswith('select '): + raise ValueError("Must specify a table, not a query") + cnx = self._cnx + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] + options = [] + parameters = [] + if format is not None: + if not isinstance(format, str): + raise TypeError("The format option must be be a string") + if format not in ('text', 'csv', 'binary'): + raise ValueError("Invalid format") + options.append(f'format {format}') + if sep is not None: + if not isinstance(sep, str): + raise TypeError("The sep option must be a string") + if format == 'binary': + raise ValueError( + "The sep option is not allowed with binary format") + if len(sep) != 1: + raise ValueError( + "The sep option must be a single one-byte character") + options.append('delimiter %s') + parameters.append(sep) + if null is not None: + if not isinstance(null, str): + raise TypeError("The null option must be a string") + options.append('null %s') + parameters.append(null) + if columns: + if not isinstance(columns, str): + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + operation_parts.append("from stdin") + if options: + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) + + putdata = self._src.putdata + self.execute(operation, parameters) + + try: + for chunk in chunks(): + putdata(chunk) + except BaseException as error: + self.rowcount = -1 + # the following call will re-raise the error + putdata(error) + else: + rowcount = putdata(None) + self.rowcount = -1 if rowcount is None else rowcount + + # return the cursor object, so you can chain operations + return self + + # noinspection PyShadowingBuiltins + def copy_to(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, decode: bool | None = None, + columns: Sequence[str] | None = None) -> Cursor | Generator: + """Copy data from the specified table to an output stream. + + The output stream can be a file-like object with a write() method or + it can also be None, in which case the method will return a generator + yielding a row on each iteration. + + Output will be returned as byte strings unless you set decode to true. + + Note that you can also use a select query instead of the table name. + + The format must be 'text', 'csv' or 'binary'. The sep option sets the + column separator (delimiter) used in the non binary formats. + The null option sets the textual representation of NULL in the output. + + The copy operation can be restricted to a subset of columns. If no + columns are specified, all of them will be copied. + """ + binary_format = format == 'binary' + if stream is None: + write = None + else: + try: + write = stream.write + except AttributeError as e: + raise TypeError("Need an output stream to copy to") from e + if not table or not isinstance(table, str): + raise TypeError("Need a table to copy to") + cnx = self._cnx + if table.lower().startswith('select '): + if columns: + raise ValueError("Columns must be specified in the query") + table = f'({table})' + else: + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] + options = [] + parameters = [] + if format is not None: + if not isinstance(format, str): + raise TypeError("The format option must be a string") + if format not in ('text', 'csv', 'binary'): + raise ValueError("Invalid format") + options.append(f'format {format}') + if sep is not None: + if not isinstance(sep, str): + raise TypeError("The sep option must be a string") + if binary_format: + raise ValueError( + "The sep option is not allowed with binary format") + if len(sep) != 1: + raise ValueError( + "The sep option must be a single one-byte character") + options.append('delimiter %s') + parameters.append(sep) + if null is not None: + if not isinstance(null, str): + raise TypeError("The null option must be a string") + options.append('null %s') + parameters.append(null) + if decode is None: + decode = format != 'binary' + else: + if not isinstance(decode, (int, bool)): + raise TypeError("The decode option must be a boolean") + if decode and binary_format: + raise ValueError( + "The decode option is not allowed with binary format") + if columns: + if not isinstance(columns, str): + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + + operation_parts.append("to stdout") + if options: + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) + + getdata = self._src.getdata + self.execute(operation, parameters) + + def copy() -> Generator: + self.rowcount = 0 + while True: + row = getdata(decode) + if isinstance(row, int): + if self.rowcount != row: + self.rowcount = row + break + self.rowcount += 1 + yield row + + if write is None: + # no input stream, return the generator + return copy() + + # write the rows to the file-like input stream + for row in copy(): + # noinspection PyUnboundLocalVariable + write(row) + + # return the cursor object, so you can chain operations + return self + + def __next__(self) -> Sequence: + """Return the next row (support for the iteration protocol).""" + res = self.fetchone() + if res is None: + raise StopIteration + return res + + # Note that the iterator protocol now uses __next()__ instead of next(), + # but we keep it for backward compatibility of pgdb. + next = __next__ + + @staticmethod + def nextset() -> bool | None: + """Not supported.""" + raise NotSupportedError("The nextset() method is not supported") + + @staticmethod + def setinputsizes(sizes: Sequence[int]) -> None: + """Not supported.""" + pass # unsupported, but silently passed + + @staticmethod + def setoutputsize(size: int, column: int = 0) -> None: + """Not supported.""" + pass # unsupported, but silently passed + + @staticmethod + def row_factory(row: Sequence) -> Sequence: + """Process rows before they are returned. + + You can overwrite this statically with a custom row factory, or + you can build a row factory dynamically with build_row_factory(). + + For example, you can create a Cursor class that returns rows as + Python dictionaries like this: + + class DictCursor(pgdb.Cursor): + + def row_factory(self, row): + return {desc[0]: value + for desc, value in zip(self.description, row)} + + cur = DictCursor(con) # get one DictCursor instance or + con.cursor_type = DictCursor # always use DictCursor instances + """ + raise NotImplementedError + + def build_row_factory(self) -> Callable[[Sequence], Sequence] | None: + """Build a row factory based on the current description. + + This implementation builds a row factory for creating named tuples. + You can overwrite this method if you want to dynamically create + different row factories whenever the column description changes. + """ + names = self.colnames + return RowCache.row_factory(tuple(names)) if names else None + + +CursorDescription = namedtuple('CursorDescription', ( + 'name', 'type_code', 'display_size', 'internal_size', + 'precision', 'scale', 'null_ok')) diff --git a/pgdb/typecode.py b/pgdb/typecode.py new file mode 100644 index 00000000..fcfb4620 --- /dev/null +++ b/pgdb/typecode.py @@ -0,0 +1,34 @@ +"""Support for DB API 2 type codes.""" + +from __future__ import annotations + +__all__ = ['TypeCode'] + + +class TypeCode(str): + """Class representing the type_code used by the DB-API 2.0. + + TypeCode objects are strings equal to the PostgreSQL type name, + but carry some additional information. + """ + + oid: int + len: int + type: str + category: str + delim: str + relid: int + + # noinspection PyShadowingBuiltins + @classmethod + def create(cls, oid: int, name: str, len: int, type: str, category: str, + delim: str, relid: int) -> TypeCode: + """Create a type code for a PostgreSQL data type.""" + self = cls(name) + self.oid = oid + self.len = len + self.type = type + self.category = category + self.delim = delim + self.relid = relid + return self \ No newline at end of file diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 0c038f72..bf3c5718 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -97,12 +97,13 @@ def tearDown(self): def _connect(self): try: - r = self.driver.connect( - *self.connect_args, **self.connect_kw_args - ) + con = self.driver.connect( + *self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") - return r + if not isinstance(con, self.driver.Connection): + self.fail("The connect method does not return a Connection") + return con def test_connect(self): con = self._connect() diff --git a/tests/test_classic.py b/tests/test_classic.py index a6f78197..3bf0fe5c 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -148,7 +148,7 @@ def test_sqlstate(self): try: db.query("INSERT INTO _test_schema VALUES (1234)") except DatabaseError as error: - self.assertTrue(isinstance(error, IntegrityError)) + self.assertIsInstance(error, IntegrityError) # the SQLSTATE error code for unique violation is 23505 # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '23505') @@ -238,7 +238,7 @@ def test_notify(self, options=None): self.assertTrue(arg_dict['called']) self.assertEqual(arg_dict['event'], 'event_1') self.assertEqual(arg_dict['extra'], 'payload 1') - self.assertTrue(isinstance(arg_dict['pid'], int)) + self.assertIsInstance(arg_dict['pid'], int) self.assertFalse(self.notify_timeout) arg_dict['called'] = False self.assertTrue(thread.is_alive()) @@ -257,7 +257,7 @@ def test_notify(self, options=None): self.assertTrue(arg_dict['called']) self.assertEqual(arg_dict['event'], 'stop_event_1') self.assertEqual(arg_dict['extra'], 'payload 2') - self.assertTrue(isinstance(arg_dict['pid'], int)) + self.assertIsInstance(arg_dict['pid'], int) self.assertFalse(self.notify_timeout) thread.join(5) self.assertFalse(thread.is_alive()) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index eca64afd..dcb7a5e2 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2250,7 +2250,7 @@ def test_get_notify(self): self.assertIsNone(self.c.getnotify()) query("notify test_notify, 'test_payload'") r = getnotify() - self.assertTrue(isinstance(r, tuple)) + self.assertIsInstance(r, tuple) self.assertEqual(len(r), 3) self.assertIsInstance(r[0], str) self.assertIsInstance(r[1], int) @@ -2636,11 +2636,12 @@ def test_set_bytea_escaped(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') - def test_set_row_factory_size(self): + def test_change_row_factory_cache_size(self): + cache = pg.RowCache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): - pg.set_row_factory_size(maxsize) + cache.change_size(maxsize) for _i in range(3): for q in queries: r = query(q).namedresult()[0] @@ -2650,12 +2651,11 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - from pg.helpers import _row_factory - info = _row_factory.cache_info() + info = cache.row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + self.assertEqual(info.hits, + 0 if maxsize is not None and maxsize < 2 else 4) class TestStandaloneEscapeFunctions(unittest.TestCase): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 2ddde601..0755d95e 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -3164,7 +3164,7 @@ def test_context_manager(self): query("insert into test_table values (6)") query("insert into test_table values (-1)") except pg.IntegrityError as error: - self.assertTrue('check' in str(error)) + self.assertIn('check', str(error)) with self.db: query("insert into test_table values (7)") r = [r[0] for r in query( @@ -3276,7 +3276,8 @@ def test_upsert_bytea(self): if pg.get_bytea_escaped(): self.assertNotEqual(data, s) self.assertIsInstance(data, str) - data = pg.unescape_bytea(data) # type: ignore + assert isinstance(data, str) # type guard + data = pg.unescape_bytea(data) self.assertIsInstance(data, bytes) self.assertEqual(data, s) d['data'] = None diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 2e731c6e..ef4857d3 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -5,6 +5,7 @@ import gc import unittest from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal from typing import Any, ClassVar from uuid import UUID as Uuid # noqa: N811 @@ -443,7 +444,6 @@ def test_cursor_invalidation(self): self.assertRaises(pgdb.OperationalError, cur.fetchone) def test_fetch_2_rows(self): - Decimal = pgdb.decimal_type() # noqa: N806 values = ('test', pgdb.Binary(b'\xff\x52\xb2'), True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'), pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42), @@ -536,7 +536,7 @@ def test_sqlstate(self): try: cur.execute("select 1/0") except pgdb.DatabaseError as error: - self.assertTrue(isinstance(error, pgdb.DataError)) + self.assertIsInstance(error, pgdb.DataError) # the SQLSTATE error code for division by zero is 22012 # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') @@ -575,9 +575,9 @@ def test_float(self): if isinf(inval): # type: ignore self.assertTrue(isinf(outval)) if inval < 0: # type: ignore - self.assertTrue(outval < 0) + self.assertLess(outval, 0) else: - self.assertTrue(outval > 0) + self.assertGreater(outval, 0) elif isnan(inval): # type: ignore self.assertTrue(isnan(outval)) else: @@ -586,25 +586,27 @@ def test_float(self): def test_datetime(self): dt = datetime(2011, 7, 17, 15, 47, 42, 317509) values = [dt.date(), dt.time(), dt, dt.time(), dt] - assert isinstance(values[3], time) + self.assertIsInstance(values[3], time) + assert isinstance(values[3], time) # type guard values[3] = values[3].replace(tzinfo=timezone.utc) - assert isinstance(values[4], datetime) + self.assertIsInstance(values[4], datetime) + assert isinstance(values[4], datetime) # type guard values[4] = values[4].replace(tzinfo=timezone.utc) - d = (dt.year, dt.month, dt.day) - t = (dt.hour, dt.minute, dt.second, dt.microsecond) - z = (timezone.utc,) + da = (dt.year, dt.month, dt.day) + ti = (dt.hour, dt.minute, dt.second, dt.microsecond) + tz = (timezone.utc,) inputs = [ # input as objects values, # input as text [v.isoformat() for v in values], # type: ignore # # input using type helpers - [pgdb.Date(*d), pgdb.Time(*t), - pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), - pgdb.Timestamp(*(d + t + z))] + [pgdb.Date(*da), pgdb.Time(*ti), + pgdb.Timestamp(*(da + ti)), pgdb.Time(*(ti + tz)), + pgdb.Timestamp(*(da + ti + tz))] ] table = self.table_prefix + 'booze' - con = self._connect() + con: pgdb.Connection = self._connect() try: cur = con.cursor() cur.execute("set timezone = UTC") @@ -624,7 +626,8 @@ def test_datetime(self): " values (%s,%s,%s,%s,%s)", params) cur.execute(f"select * from {table}") d = cur.description - assert isinstance(d, list) + self.assertIsInstance(d, list) + assert d is not None # type guard for i in range(5): tc = d[i].type_code self.assertEqual(tc, pgdb.DATETIME) @@ -855,8 +858,8 @@ def test_custom_type(self): con.close() def test_set_decimal_type(self): - decimal_type = pgdb.decimal_type() - self.assertTrue(decimal_type is not None and callable(decimal_type)) + from pgdb.cast import decimal_type + self.assertIs(decimal_type(), Decimal) con = self._connect() try: cur = con.cursor() @@ -870,19 +873,19 @@ def __init__(self, value: Any) -> None: def __str__(self) -> str: return str(self.value).replace('.', ',') - self.assertTrue(pgdb.decimal_type(CustomDecimal) is CustomDecimal) + self.assertIs(decimal_type(CustomDecimal), CustomDecimal) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] - self.assertTrue(isinstance(value, CustomDecimal)) + self.assertIsInstance(value, CustomDecimal) self.assertEqual(str(value), '4,25') # change decimal type again to float - self.assertTrue(pgdb.decimal_type(float) is float) + self.assertIs(decimal_type(float), float) cur.execute('select 4.25') self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] # the connection still uses the old setting - self.assertTrue(isinstance(value, str)) + self.assertIsInstance(value, str) self.assertEqual(str(value), '4,25') # bust the cache for type functions for the connection con.type_cache.reset_typecast() @@ -890,12 +893,12 @@ def __str__(self) -> str: self.assertEqual(cur.description[0].type_code, pgdb.NUMBER) value = cur.fetchone()[0] # now the connection uses the new setting - self.assertTrue(isinstance(value, float)) + self.assertIsInstance(value, float) self.assertEqual(value, 4.25) finally: con.close() - pgdb.decimal_type(decimal_type) - self.assertTrue(pgdb.decimal_type() is decimal_type) + decimal_type(Decimal) + self.assertIs(decimal_type(), Decimal) def test_global_typecast(self): try: @@ -1272,7 +1275,7 @@ def test_connection_as_contextmanager(self): cur.execute(f"insert into {table} values (3)") cur.execute(f"insert into {table} values (4)") except con.IntegrityError as error: - self.assertTrue('check' in str(error).lower()) + self.assertIn('check', str(error).lower()) with con: cur.execute(f"insert into {table} values (5)") cur.execute(f"insert into {table} values (6)") @@ -1325,11 +1328,11 @@ def test_pgdb_type(self): self.assertEqual('int8', pgdb.INTEGER) self.assertNotEqual('int4', pgdb.LONG) self.assertEqual('int8', pgdb.LONG) - self.assertTrue('char' in pgdb.STRING) - self.assertTrue(pgdb.NUMERIC <= pgdb.NUMBER) - self.assertTrue(pgdb.NUMBER >= pgdb.INTEGER) - self.assertTrue(pgdb.TIME <= pgdb.DATETIME) - self.assertTrue(pgdb.DATETIME >= pgdb.DATE) + self.assertIn('char', pgdb.STRING) + self.assertLess(pgdb.NUMERIC, pgdb.NUMBER) + self.assertGreaterEqual(pgdb.NUMBER, pgdb.INTEGER) + self.assertLessEqual(pgdb.TIME, pgdb.DATETIME) + self.assertGreaterEqual(pgdb.DATETIME, pgdb.DATE) self.assertEqual(pgdb.ARRAY, pgdb.ARRAY) self.assertNotEqual(pgdb.ARRAY, pgdb.STRING) self.assertEqual('_char', pgdb.ARRAY) @@ -1349,12 +1352,13 @@ def test_no_close(self): row = cur.fetchone() self.assertEqual(row, data) - def test_set_row_factory_size(self): + def test_change_row_factory_cache_size(self): + from pg import RowCache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] con = self._connect() cur = con.cursor() for maxsize in (None, 0, 1, 2, 3, 10, 1024): - pgdb.set_row_factory_size(maxsize) + RowCache.change_size(maxsize) for _i in range(3): for q in queries: cur.execute(q) @@ -1365,11 +1369,11 @@ def test_set_row_factory_size(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - info = pgdb._row_factory.cache_info() + info = RowCache.row_factory.cache_info() self.assertEqual(info.maxsize, maxsize) self.assertEqual(info.hits + info.misses, 6) - self.assertEqual( - info.hits, 0 if maxsize is not None and maxsize < 2 else 4) + self.assertEqual(info.hits, + 0 if maxsize is not None and maxsize < 2 else 4) def test_memory_leaks(self): ids: set = set() From b08b775cf77bd2cb8275429897f0dcb5dcd61fe9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 11:59:52 +0200 Subject: [PATCH 058/118] Improve distribution files wording --- docs/download/files.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/download/files.rst b/docs/download/files.rst index f5e7a523..fc3ad26f 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -3,11 +3,11 @@ Distribution files ============== = -pg/ the "classic" PyGreSQL module +pg/ the "classic" PyGreSQL package pgdb/ a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL -ext/ the source files for the C extension +ext/ the source files for the C extension module docs/ the documentation directory From c70b726984e2320883ac14451bcce2381324cc26 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 14:04:46 +0200 Subject: [PATCH 059/118] Improve typing of DB wrapper init method --- pg/db.py | 46 +++++++++++++++++++++++---------- tests/test_classic_dbwrapper.py | 7 ++--- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/pg/db.py b/pg/db.py index ce7915f8..03f7d7a4 100644 --- a/pg/db.py +++ b/pg/db.py @@ -6,7 +6,7 @@ from json import dumps as jsonencode from json import loads as jsondecode from operator import itemgetter -from typing import Any, Callable, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence, overload from . import Connection, connect from .adapt import Adapter, DbTypes @@ -23,6 +23,9 @@ from .helpers import namediter, oid_key, quote_if_unqualified from .notify import NotificationHandler +if TYPE_CHECKING: + from pgdb.connection import Connection as DbApi2Connection + __all__ = ['DB'] # The actual PostgreSQL database connection interface: @@ -33,33 +36,48 @@ class DB: db: Connection | None = None # invalid fallback for underlying connection _db_args: Any # either the connect args or the underlying connection - def __init__(self, *args: Any, **kw: Any) -> None: + @overload + def __init__(self, dbname: str | None = None, + host: str | None = None, port: int = -1, + opt: str | None = None, + user: str | None = None, passwd: str | None = None, + nowait: bool = False) -> None: + ... + + @overload + def __init__(self, db: Connection | DB | DbApi2Connection) -> None: + ... + + def __init__(self, *args: Any, **kw: Any) -> None: """Create a new connection. You can pass either the connection parameters or an existing - _pg or pgdb connection. This allows you to use the methods - of the classic pg interface with a DB-API 2 pgdb connection. + pg or pgdb Connection. This allows you to use the methods + of the classic pg interface with a DB-API 2 pgdb Connection. """ - if not args and len(kw) == 1: + if kw: db = kw.get('db') - elif not kw and len(args) == 1: + if db is not None and (args or len(kw) > 1): + raise TypeError("Conflicting connection parameters") + elif len(args) == 1 and not isinstance(args[0], str): db = args[0] else: db = None if db: if isinstance(db, DB): - db = db.db + db = db.db # allow db to be a wrapped Connection else: with suppress(AttributeError): - # noinspection PyUnresolvedReferences - db = db._cnx - if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): + db = db._cnx # allow db to be a pgdb Connection + if not isinstance(db, Connection): + raise TypeError( + "The 'db' argument must be a valid database connection.") + self._db_args = db + self._closeable = False + else: db = connect(*args, **kw) self._db_args = args, kw self._closeable = True - else: - self._db_args = db - self._closeable = False self.db = db self.dbname = db.db self._regtypes = False @@ -97,7 +115,7 @@ def __init__(self, *args: Any, **kw: Any) -> None: self.debug: Any = None def __getattr__(self, name: str) -> Any: - """Get the specified attritbute of the connection.""" + """Get the specified attribute of the connection.""" # All undefined members are same as in underlying connection: if self.db: return getattr(self.db, name) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 0755d95e..e53617dd 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -301,12 +301,13 @@ def test_existing_connection(self): self.assertIsNone(db.db) db = pg.DB(self.db) self.assertEqual(self.db.db, db.db) + assert self.db.db is not None db = pg.DB(db=self.db.db) self.assertEqual(self.db.db, db.db) def test_existing_db_api2_connection(self): - class DBApi2Con: + class FakeDbApi2Connection: def __init__(self, cnx): self._cnx = cnx @@ -314,8 +315,8 @@ def __init__(self, cnx): def close(self): self._cnx.close() - db2 = DBApi2Con(self.db.db) - db = pg.DB(db2) + db2 = FakeDbApi2Connection(self.db.db) + db = pg.DB(db2) # type: ignore self.assertEqual(self.db.db, db.db) db.close() self.assertIsNone(db.db) From 17c42afbf46196510a7eacb822acad16de82948a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 14:13:44 +0200 Subject: [PATCH 060/118] Use different docstrings for overloaded methods --- pg/db.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pg/db.py b/pg/db.py index 03f7d7a4..d541ac54 100644 --- a/pg/db.py +++ b/pg/db.py @@ -42,10 +42,12 @@ def __init__(self, dbname: str | None = None, opt: str | None = None, user: str | None = None, passwd: str | None = None, nowait: bool = False) -> None: + """Create a new connection using the specified parameters.""" ... @overload def __init__(self, db: Connection | DB | DbApi2Connection) -> None: + """Create a connection wrapper based on an existing connection.""" ... def __init__(self, *args: Any, **kw: Any) -> None: From 7cc9c879581299042af725de1f31ef869429c7b8 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 14:15:46 +0200 Subject: [PATCH 061/118] Actually overloaded methods shouldn't have docstrings --- pg/db.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pg/db.py b/pg/db.py index d541ac54..a13ea357 100644 --- a/pg/db.py +++ b/pg/db.py @@ -42,13 +42,11 @@ def __init__(self, dbname: str | None = None, opt: str | None = None, user: str | None = None, passwd: str | None = None, nowait: bool = False) -> None: - """Create a new connection using the specified parameters.""" - ... + ... # create a new connection using the specified parameters @overload def __init__(self, db: Connection | DB | DbApi2Connection) -> None: - """Create a connection wrapper based on an existing connection.""" - ... + ... # create a connection wrapper based on an existing connection def __init__(self, *args: Any, **kw: Any) -> None: """Create a new connection. From cc781024cce02bc7fb732a387a9db7dc41631eeb Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 16:16:30 +0200 Subject: [PATCH 062/118] Add immediately wrapped methods These methods now also check that the underlying connection is still valid, and they allow proper typing and auto completion for wrapped connections. --- ext/pgmodule.c | 4 +- pg/_pg.pyi | 5 +- pg/core.py | 3 +- pg/db.py | 153 ++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 160 insertions(+), 5 deletions(-) diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 546c5cc5..761ae1b7 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -21,7 +21,7 @@ static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, - *MultipleResultsError, *Connection, *Query; + *MultipleResultsError, *Connection, *Query, *LargeObject; #define _TOSTRING(x) #x #define TOSTRING(x) _TOSTRING(x) @@ -1310,6 +1310,8 @@ PyInit__pg(void) PyDict_SetItemString(dict, "Connection", Connection); Query = (PyObject *)&queryType; PyDict_SetItemString(dict, "Query", Query); + LargeObject = (PyObject *)&largeType; + PyDict_SetItemString(dict, "LargeObject", LargeObject); /* Make the version available */ s = PyUnicode_FromString(PyPgVersion); diff --git a/pg/_pg.pyi b/pg/_pg.pyi index 70f6e37e..b14bd5fc 100644 --- a/pg/_pg.pyi +++ b/pg/_pg.pyi @@ -4,7 +4,10 @@ from __future__ import annotations from typing import Any, Callable, Iterable, Sequence, TypeVar -AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +try: + AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +except TypeError: # Python < 3.10 + AnyStr = Any # type: ignore SomeNamedTuple = Any # alias for accessing arbitrary named tuples version: str diff --git a/pg/core.py b/pg/core.py index 3eb8f745..e20bdbd0 100644 --- a/pg/core.py +++ b/pg/core.py @@ -62,6 +62,7 @@ InterfaceError, InternalError, InvalidResultError, + LargeObject, MultipleResultsError, NoResultError, NotSupportedError, @@ -113,7 +114,7 @@ 'InvalidResultError', 'MultipleResultsError', 'NoResultError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', + 'Connection', 'Query', 'LargeObject', 'INV_READ', 'INV_WRITE', 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', diff --git a/pg/db.py b/pg/db.py index a13ea357..f824cc9d 100644 --- a/pg/db.py +++ b/pg/db.py @@ -6,13 +6,22 @@ from json import dumps as jsonencode from json import loads as jsondecode from operator import itemgetter -from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + Sequence, + TypeVar, + overload, +) from . import Connection, connect from .adapt import Adapter, DbTypes from .attrs import AttrDict from .core import ( InternalError, + LargeObject, ProgrammingError, Query, get_bool, @@ -26,12 +35,32 @@ if TYPE_CHECKING: from pgdb.connection import Connection as DbApi2Connection +try: + AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +except TypeError: # Python < 3.10 + AnyStr = Any # type: ignore + __all__ = ['DB'] + # The actual PostgreSQL database connection interface: class DB: - """Wrapper class for the _pg connection type.""" + """Wrapper class for the core connection type.""" + + dbname: str + host: str + port: int + options: str + error: str + status: int + user : str + protocol_version: int + server_version: int + socket: int + backend_pid: int + ssl_in_use: bool + ssl_attributes: dict[str, str | None] db: Connection | None = None # invalid fallback for underlying connection _db_args: Any # either the connect args or the underlying connection @@ -1326,6 +1355,126 @@ def notification_handler(self, event: str, callback: Callable, return NotificationHandler(self, event, callback, arg_dict, timeout, stop_event) + # immediately wrapped methods + + def send_query(self, cmd: str, args: Sequence | None = None) -> Query: + """Create a new asynchronous query object for this connection.""" + if args is None: + return self._valid_db.send_query(cmd) + return self._valid_db.send_query(cmd, args) + + def poll(self) -> int: + """Complete an asynchronous connection and get its state.""" + return self._valid_db.poll() + + def cancel(self) -> None: + """Abandon processing of current SQL command.""" + self._valid_db.cancel() + + def fileno(self) -> int: + """Get the socket used to connect to the database.""" + return self._valid_db.fileno() + + def get_cast_hook(self) -> Callable | None: + """Get the function that handles all external typecasting.""" + return self._valid_db.get_cast_hook() + + def set_cast_hook(self, hook: Callable | None) -> None: + """Set a function that will handle all external typecasting.""" + self._valid_db.set_cast_hook(hook) + + def get_notice_receiver(self) -> Callable | None: + """Get the current notice receiver.""" + return self._valid_db.get_notice_receiver() + + def set_notice_receiver(self, receiver: Callable | None) -> None: + """Set a custom notice receiver.""" + self._valid_db.set_notice_receiver(receiver) + + def getnotify(self) -> tuple[str, int, str] | None: + """Get the last notify from the server.""" + return self._valid_db.getnotify() + + def inserttable(self, table: str, values: Sequence[list|tuple], + columns: list[str] | tuple[str, ...] | None = None) -> int: + """Insert a Python iterable into a database table.""" + if columns is None: + return self._valid_db.inserttable(table, values) + return self._valid_db.inserttable(table, values, columns) + + def transaction(self) -> int: + """Get the current in-transaction status of the server. + + The status returned by this method can be TRANS_IDLE (currently idle), + TRANS_ACTIVE (a command is in progress), TRANS_INTRANS (idle, in a + valid transaction block), or TRANS_INERROR (idle, in a failed + transaction block). TRANS_UNKNOWN is reported if the connection is + bad. The status TRANS_ACTIVE is reported only when a query has been + sent to the server and not yet completed. + """ + return self._valid_db.transaction() + + def parameter(self, name: str) -> str | None: + """Look up a current parameter setting of the server.""" + return self._valid_db.parameter(name) + + + def date_format(self) -> str: + """Look up the date format currently being used by the database.""" + return self._valid_db.date_format() + + def escape_literal(self, s: AnyStr) -> AnyStr: + """Escape a literal constant for use within SQL.""" + return self._valid_db.escape_literal(s) + + def escape_identifier(self, s: AnyStr) -> AnyStr: + """Escape an identifier for use within SQL.""" + return self._valid_db.escape_identifier(s) + + def escape_string(self, s: AnyStr) -> AnyStr: + """Escape a string for use within SQL.""" + return self._valid_db.escape_string(s) + + def escape_bytea(self, s: AnyStr) -> AnyStr: + """Escape binary data for use within SQL as type 'bytea'.""" + return self._valid_db.escape_bytea(s) + + def putline(self, line: str) -> None: + """Write a line to the server socket.""" + self._valid_db.putline(line) + + def getline(self) -> str: + """Get a line from server socket.""" + return self._valid_db.getline() + + def endcopy(self) -> None: + """Synchronize client and server.""" + self._valid_db.endcopy() + + def set_non_blocking(self, nb: bool) -> None: + """Set the non-blocking mode of the connection.""" + self._valid_db.set_non_blocking(nb) + + def is_non_blocking(self) -> bool: + """Get the non-blocking mode of the connection.""" + return self._valid_db.is_non_blocking() + + def locreate(self, mode: int) -> LargeObject: + """Create a large object in the database. + + The valid values for 'mode' parameter are defined as the module level + constants INV_READ and INV_WRITE. + """ + return self._valid_db.locreate(mode) + + def getlo(self, oid: int) -> LargeObject: + """Build a large object from given oid.""" + return self._valid_db.getlo(oid) + + def loimport(self, filename: str) -> LargeObject: + """Import a file to a large object.""" + return self._valid_db.loimport(filename) + class _MemoryQuery: """Class that embodies a given query result.""" From 20ce949bd8428eb416e85b45263dfd3f986865f7 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 16:55:20 +0200 Subject: [PATCH 063/118] Support Python 3.12 and PostgreSQL 16 --- .bumpversion.cfg | 2 +- .devcontainer/provision.sh | 7 ++++++- README.rst | 6 ++++++ docs/about.rst | 4 ++-- docs/conf.py | 2 +- docs/contents/changelog.rst | 5 +++-- docs/contents/install.rst | 2 +- pyproject.toml | 3 ++- setup.py | 3 ++- tests/test_classic_connection.py | 3 ++- tests/test_classic_dbwrapper.py | 3 ++- tests/test_classic_functions.py | 2 +- tox.ini | 2 +- 13 files changed, 30 insertions(+), 14 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 769d02cf..1e499975 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 6.0 +current_version = 6.0b1 commit = False tag = False diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index c780e7df..05a681e4 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -26,6 +26,7 @@ sudo apt-get install -y python3.8 python3.8-dev python3.8-distutils sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils +sudo apt-get install -y python3.12 python3.12-dev python3.12-distutils # install build and testing tool @@ -43,7 +44,7 @@ sudo apt-get install -y tox clang-format sudo apt-get install -y postgresql libpq-dev -for pghost in pg10 pg12 pg14 pg15 +for pghost in pg10 pg12 pg14 pg15 pg16 do export PGHOST=$pghost export PGDATABASE=postgres @@ -76,3 +77,7 @@ do psql -c "create extension hstore" test_latin9 psql -c "create extension hstore" test_cyrillic done + +export PGDATABASE=test +export PGUSER=test +export PGPASSWORD=test diff --git a/README.rst b/README.rst index a010b944..e9f9465c 100644 --- a/README.rst +++ b/README.rst @@ -18,6 +18,9 @@ The following Python versions are supported: * PyGreSQL 5.x: Python 2 and Python 3 * PyGreSQL 6.x and newer: Python 3 only +The current version of PyGreSQL supports Python versions 3.7 to 3.12 +and PostgreSQL versions 10 to 16 on the server. + Installation ------------ @@ -28,6 +31,9 @@ The simplest way to install PyGreSQL is to type:: For other ways of installing PyGreSQL and requirements, see the documentation. +Note that PyGreSQL also requires the libpq shared library to be +installed and accessible on the client machine. + Documentation ------------- diff --git a/docs/about.rst b/docs/about.rst index 8235e5cc..18c6b7a6 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -39,6 +39,6 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL |version| needs PostgreSQL 10 to 15, and Python -3.7 to 3.11. If you need to support older PostgreSQL or Python versions, +The current version PyGreSQL |version| needs PostgreSQL 10 to 16, and Python +3.7 to 3.12. If you need to support older PostgreSQL or Python versions, you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/conf.py b/docs/conf.py index 9dd604f2..48cb7dc0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,7 @@ author = 'The PyGreSQL team' copyright = '2023, ' + author -version = release = '6.0' +version = release = '6.0b1' language = 'en' diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 077893a2..9f35f716 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,8 +1,9 @@ ChangeLog ========= -Version 6.0 (to be released) ----------------------------- +Version 6.0b1 (2023-09-06) +-------------------------- +- Officially support Python 3.12 and PostgreSQL 16 (tested with rc versions). - Removed support for Python versions older than 3.7 (released June 2017) and PostgreSQL older than version 10 (released October 2017). - Converted the standalone modules `pg` and `pgdb` to packages with diff --git a/docs/contents/install.rst b/docs/contents/install.rst index f447abc3..7d28ea59 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -3.7 to 3.11, and PostgreSQL versions 10 to 15. +3.7 to 3.12, and PostgreSQL versions 10 to 16. PyGreSQL will be installed as two packages named ``pg`` (for the classic interface) and ``pgdb`` (for the DB API 2 compliant interface). The former diff --git a/pyproject.toml b/pyproject.toml index e289b38f..30d255e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0" +version = "6.0b1" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, @@ -22,6 +22,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", diff --git a/setup.py b/setup.py index 4fd39c56..d0f70ea0 100755 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -version = '6.0' +version = '6.0b1' if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( @@ -152,6 +152,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Programming Language :: SQL', 'Topic :: Database', 'Topic :: Database :: Front-Ends', diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index dcb7a5e2..be1b5a42 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -174,7 +174,8 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertTrue(100000 <= server_version < 160000) + self.assertGreaterEqual(server_version, 100000) + self.assertLess(server_version, 170000) def test_attribute_socket(self): socket = self.connection.socket diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index e53617dd..d1224a53 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -168,7 +168,8 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(100000 <= server_version < 160000) + self.assertGreaterEqual(server_version, 100000) + self.assertLess(server_version, 170000) self.assertEqual(server_version, self.db.db.server_version) def test_attribute_socket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 01ed752e..4351f794 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -125,7 +125,7 @@ def test_pqlib_version(self): v = pg.get_pqlib_version() self.assertIsInstance(v, int) self.assertGreater(v, 100000) - self.assertLess(v, 160000) + self.assertLess(v, 170000) class TestParseArray(unittest.TestCase): diff --git a/tox.ini b/tox.ini index eae93234..b86f9fea 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11},ruff,mypy,cformat,docs +envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 From 741f734c73a2baaa667ed5980b53fb98090a4f1b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 17:44:35 +0200 Subject: [PATCH 064/118] Install current build tools for development --- .devcontainer/provision.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 05a681e4..7cb14be0 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -30,11 +30,11 @@ sudo apt-get install -y python3.12 python3.12-dev python3.12-distutils # install build and testing tool -python3.7 -m pip install build -python3.8 -m pip install build -python3.9 -m pip install build -python3.10 -m pip install build -python3.11 -m pip install build +python3.7 -m pip install -U pip setuptools wheel build +python3.8 -m pip install -U pip setuptools wheel build +python3.9 -m pip install -U pip setuptools wheel build +python3.10 -m pip install -U pip setuptools wheel build +python3.11 -m pip install -U pip setuptools wheel build pip install ruff From 7d93eb222a48ff0b14012f684b316d338cad5dd9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 17:47:18 +0200 Subject: [PATCH 065/118] Avoid segfault when there is a poll error --- ext/pgconn.c | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/pgconn.c b/ext/pgconn.c index 10e5b780..9ffc0009 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -45,7 +45,7 @@ conn_getattr(connObject *self, PyObject *nameobj) /* postmaster host */ if (!strcmp(name, "host")) { char *r = PQhost(self->cnx); - if (!r || r[0] == '/') /* Pg >= 9.6 can return a Unix socket path */ + if (!r || r[0] == '/') /* this can return a Unix socket path */ r = "localhost"; return PyUnicode_FromString(r); } @@ -1577,7 +1577,6 @@ conn_poll(connObject *self, PyObject *noargs) if (rc == PGRES_POLLING_FAILED) { set_error(InternalError, "Polling failed", self->cnx, NULL); - Py_XDECREF(self); return NULL; } From b96d64f49be97731041948df12c0e868d8ea42d2 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 18:38:34 +0200 Subject: [PATCH 066/118] Add coverage to tox file --- tox.ini | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tox.ini b/tox.ini index b86f9fea..dd747abe 100644 --- a/tox.ini +++ b/tox.ini @@ -38,6 +38,14 @@ deps = commands = python -m build -n -C strict -C memory-size +[testenv:coverage] +basepython = python3.11 +deps = + coverage>=7,<8 +commands = + coverage run -m unittest discover + coverage html + [testenv] passenv = PG* From f8e79fb17640f5f5e2e5da735b550bd3339b2d38 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 18:39:04 +0200 Subject: [PATCH 067/118] Update bump file --- .bumpversion.cfg | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 1e499975..f9acd8ee 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -12,6 +12,10 @@ serialize = search = version = '{current_version}' replace = version = '{new_version}' +[bumpversion:file:pyproject.toml] +search = version = "{current_version}" +replace = version = "{new_version}" + [bumpversion:file:docs/conf.py] search = version = release = '{current_version}' replace = version = release = '{new_version}' From 439cbdd77b6be77281eb0eafcab1d3130226ec82 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 6 Sep 2023 19:52:34 +0200 Subject: [PATCH 068/118] Use consistent project urls --- .devcontainer/provision.sh | 1 + pyproject.toml | 14 +++++++------- setup.py | 2 +- tox.ini | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 7cb14be0..09acd893 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -39,6 +39,7 @@ python3.11 -m pip install -U pip setuptools wheel build pip install ruff sudo apt-get install -y tox clang-format +pip install -U tox # install PostgreSQL client tools diff --git a/pyproject.toml b/pyproject.toml index 30d255e4..fdbbcbea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,13 +33,13 @@ classifiers = [ file = "LICENSE.txt" [project.urls] -homepage = "https://pygresql.github.io/" -documentation = "https://pygresql.github.io/contents/" -source = "https://github.com/PyGreSQL/PyGreSQL" -issues = "https://github.com/PyGreSQL/PyGreSQL/issues/" -changelog = "https://pygresql.github.io/contents/changelog.html" -download = "https://pygresql.github.io/download/" -"mailing list" = "https://mail.vex.net/mailman/listinfo/pygresql" +Homepage = "https://pygresql.github.io/" +Documentation = "https://pygresql.github.io/contents/" +"Source Code" = "https://github.com/PyGreSQL/PyGreSQL" +"Issue Tracker" = "https://github.com/PyGreSQL/PyGreSQL/issues/" +Changelog = "https://pygresql.github.io/contents/changelog.html" +Download = "https://pygresql.github.io/download/" +"Mailing List" = "https://mail.vex.net/mailman/listinfo/pygresql" [tool.ruff] target-version = "py37" diff --git a/setup.py b/setup.py index d0f70ea0..c2d72a61 100755 --- a/setup.py +++ b/setup.py @@ -133,7 +133,7 @@ def finalize_options(self): author="D'Arcy J. M. Cain", author_email="darcy@PyGreSQL.org", url='https://pygresql.github.io/', - download_url='https://pygresql.github.io/contents/download/', + download_url='https://pygresql.github.io/download/', project_urls={ 'Documentation': 'https://pygresql.github.io/contents/', 'Issue Tracker': 'https://github.com/PyGreSQL/PyGreSQL/issues/', diff --git a/tox.ini b/tox.ini index dd747abe..c58f40b1 100644 --- a/tox.ini +++ b/tox.ini @@ -36,7 +36,7 @@ deps = wheel>=0.41 build>=0.10 commands = - python -m build -n -C strict -C memory-size + python -m build -s -n -C strict -C memory-size [testenv:coverage] basepython = python3.11 From e73f4ae9a2ed81b81cc833deae97f807fd25f420 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 16:07:27 +0000 Subject: [PATCH 069/118] Test with Python 3.12 and Postgres 16 on GitHub --- .github/workflows/tests.yml | 12 +++++++----- tox.ini | 6 +++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ca8e4a36..43da55df 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,13 +20,15 @@ jobs: - { python: "3.9", postgres: "13" } - { python: "3.10", postgres: "14" } - { python: "3.11", postgres: "15" } + - { python: "3.12", postgres: "16" } # Opposite extremes of the supported Py/PG range, other architecture - - { python: "3.7", postgres: "15", architecture: "x86" } - - { python: "3.8", postgres: "14", architecture: "x86" } - - { python: "3.9", postgres: "13", architecture: "x86" } - - { python: "3.10", postgres: "12", architecture: "x86" } - - { python: "3.11", postgres: "11", architecture: "x86" } + - { python: "3.7", postgres: "16", architecture: "x86" } + - { python: "3.8", postgres: "15", architecture: "x86" } + - { python: "3.9", postgres: "14", architecture: "x86" } + - { python: "3.10", postgres: "13", architecture: "x86" } + - { python: "3.11", postgres: "12", architecture: "x86" } + - { python: "3.12", postgres: "11", architecture: "x86" } env: PYGRESQL_DB: test diff --git a/tox.ini b/tox.ini index c58f40b1..fd36f2a2 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 -deps = ruff>=0.0.287 +deps = ruff>=0.0.292 commands = ruff setup.py pg pgdb tests @@ -33,8 +33,8 @@ commands = basepython = python3.11 deps = setuptools>=68 - wheel>=0.41 - build>=0.10 + wheel>=0.41,<1 + build>=1,<2 commands = python -m build -s -n -C strict -C memory-size From a581e0448244f439f7bbc3d66ee88879e1da47f6 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 17:25:08 +0000 Subject: [PATCH 070/118] Update checkout action --- .github/workflows/docs.yml | 2 +- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index aae221a0..7d1ba05a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -11,7 +11,7 @@ jobs: steps: - name: CHeck out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python 3.11 uses: actions/setup-python@v4 with: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dad89096..267c54c2 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install tox run: pip install tox - name: Setup Python diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43da55df..31a48265 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -53,7 +53,7 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install tox run: pip install tox - name: Setup Python From cdf9f427ee936a838992f28b478bb31171d7da44 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 20:16:03 +0200 Subject: [PATCH 071/118] In some test environments, there is no SSL support The test started to break on GitHub in September 2023. Might have to do with changes in the Ubuntu docker image. --- tests/test_classic_connection.py | 7 ++++--- tests/test_classic_dbwrapper.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index be1b5a42..3f9427b2 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -195,9 +195,10 @@ def test_attribute_ssl_in_use(self): def test_attribute_ssl_attributes(self): ssl_attributes = self.connection.ssl_attributes self.assertIsInstance(ssl_attributes, dict) - self.assertEqual(ssl_attributes, { - 'cipher': None, 'compression': None, 'key_bits': None, - 'library': None, 'protocol': None}) + if ssl_attributes: + self.assertEqual(ssl_attributes, { + 'cipher': None, 'compression': None, 'key_bits': None, + 'library': None, 'protocol': None}) def test_attribute_status(self): status_ok = 1 diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index d1224a53..8aa691f5 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -190,9 +190,10 @@ def test_attribute_ssl_in_use(self): def test_attribute_ssl_attributes(self): ssl_attributes = self.db.ssl_attributes self.assertIsInstance(ssl_attributes, dict) - self.assertEqual(ssl_attributes, { - 'cipher': None, 'compression': None, 'key_bits': None, - 'library': None, 'protocol': None}) + if ssl_attributes: + self.assertEqual(ssl_attributes, { + 'cipher': None, 'compression': None, 'key_bits': None, + 'library': None, 'protocol': None}) def test_attribute_status(self): status_ok = 1 From 3e0de8e7e7d48f34141b473ba45e2b6c0bfef6f3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 18:34:34 +0000 Subject: [PATCH 072/118] Python 3.12 needs setuptools --- tox.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tox.ini b/tox.ini index fd36f2a2..3917f73c 100644 --- a/tox.ini +++ b/tox.ini @@ -50,6 +50,8 @@ commands = passenv = PG* PYGRESQL_* +deps = + setuptools>=68 commands = python setup.py clean --all build_ext --force --inplace --strict --memory-size python -m unittest {posargs:discover} From 0587d593fe97b42a94735e87bf958e296ccbe704 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 18:53:55 +0000 Subject: [PATCH 073/118] Keep version only in pyproject.toml --- .bumpversion.cfg | 21 --------------------- docs/conf.py | 10 +++++++++- pyproject.toml | 2 +- setup.py | 19 +++++++++++++++---- 4 files changed, 25 insertions(+), 27 deletions(-) delete mode 100644 .bumpversion.cfg diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index f9acd8ee..00000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,21 +0,0 @@ -[bumpversion] -current_version = 6.0b1 -commit = False -tag = False - -parse = (?P\d+)\.(?P\d+)(?:\.(?P\d+))? -serialize = - {major}.{minor}.{patch} - {major}.{minor} - -[bumpversion:file:setup.py] -search = version = '{current_version}' -replace = version = '{new_version}' - -[bumpversion:file:pyproject.toml] -search = version = "{current_version}" -replace = version = "{new_version}" - -[bumpversion:file:docs/conf.py] -search = version = release = '{current_version}' -replace = version = release = '{new_version}' diff --git a/docs/conf.py b/docs/conf.py index 48cb7dc0..1a63dac4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,15 @@ author = 'The PyGreSQL team' copyright = '2023, ' + author -version = release = '6.0b1' +def project_version(): + with open('../pyproject.toml') as f: + for d in f: + if d.startswith("version ="): + version = d.split("=")[1].strip().strip('"') + return version + raise Exception("Cannot determine PyGreSQL version") + +version = release = project_version() language = 'en' diff --git a/pyproject.toml b/pyproject.toml index fdbbcbea..bfcac161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0b1" +version = "6.0" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, diff --git a/setup.py b/setup.py index c2d72a61..813ecde8 100755 --- a/setup.py +++ b/setup.py @@ -19,14 +19,25 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -version = '6.0b1' - if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( f"Sorry, PyGreSQL {version} does not support this Python version") -with open('README.rst') as f: - long_description = f.read() +def project_version(): + with open('pyproject.toml') as f: + for d in f: + if d.startswith("version ="): + version = d.split("=")[1].strip().strip('"') + return version + raise Exception("Cannot determine PyGreSQL version") + +def project_readme(): + with open('README.rst') as f: + return f.read() + +version = project_version() + +long_description = project_readme() # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the From 7a9a6fbd9120c77c2d69d4f5975e56fde95a75a9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 19:02:09 +0000 Subject: [PATCH 074/118] Handle ruff complaints --- setup.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 813ecde8..8b1ec5dc 100755 --- a/setup.py +++ b/setup.py @@ -19,11 +19,9 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -if not (3, 7) <= sys.version_info[:2] < (4, 0): - raise Exception( - f"Sorry, PyGreSQL {version} does not support this Python version") def project_version(): + """Read the PyGreSQL version from the pyproject.toml file.""" with open('pyproject.toml') as f: for d in f: if d.startswith("version ="): @@ -31,14 +29,22 @@ def project_version(): return version raise Exception("Cannot determine PyGreSQL version") + def project_readme(): + """Get the content of the README file.""" with open('README.rst') as f: return f.read() + version = project_version() +if not (3, 7) <= sys.version_info[:2] < (4, 0): + raise Exception( + f"Sorry, PyGreSQL {version} does not support this Python version") + long_description = project_readme() + # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the # classic interface, and "pgdb" for the modern DB-API 2.0 interface. From 9bc5a1ec81ef15595645012e9493c26fd96333a9 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Tue, 3 Oct 2023 19:11:02 +0000 Subject: [PATCH 075/118] Update changelog --- docs/contents/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 9f35f716..ac501b56 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 6.0 (2023-10-03) +------------------------ +- Tested with the recent releases of Python 3.12 and PostgreSQL 16. +- Make pyproject.toml the only source of truth for the version number. +- Please also note the changes already made in version 6.0b1. + Version 6.0b1 (2023-09-06) -------------------------- - Officially support Python 3.12 and PostgreSQL 16 (tested with rc versions). From a5af0d2897f8565f7bbcf9f7fe9b67251ce185f3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 29 Feb 2024 14:53:22 +0000 Subject: [PATCH 076/118] Ignore dll files for Python --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 83732331..8b08bb41 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ *.patch *.pid *.pstats -*.py[co] +*.py[cdo] *.so *.swp From b2e1752c1e0ff18040a2280770edf2686873eddb Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 29 Feb 2024 16:11:22 +0000 Subject: [PATCH 077/118] Properly adapt falsy JSON values (#86) --- pg/adapt.py | 2 +- tests/test_classic_dbwrapper.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/pg/adapt.py b/pg/adapt.py index 9cbecaaf..2a5efaa2 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -333,7 +333,7 @@ def _adapt_bytea(self, v: Any) -> str: def _adapt_json(self, v: Any) -> str | None: """Adapt a json parameter.""" - if not v: + if v is None: return None if isinstance(v, str): return v diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 8aa691f5..f02955c7 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4390,6 +4390,14 @@ def test_adapt_query_typed_list_with_json(self): self.assertEqual(sql, 'select $1') self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + def test_adapt_query_typed_list_with_empty_json(self): + format_query = self.adapter.format_query + values: Any = [None, 0, False, '', [], {}] + types = ('json',) * 6 + sql, params = format_query("select %s,%s,%s,%s,%s,%s", values, types) + self.assertEqual(sql, 'select $1,$2,$3,$4,$5,$6') + self.assertEqual(params, [None, '0', 'false', '', '[]', '{}']) + def test_adapt_query_typed_with_hstore(self): format_query = self.adapter.format_query value: Any = {'one': "it's fine", 'two': 2} From a8507e0f1f1f63c19ae7a85ba22ed5c4e2883070 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 29 Feb 2024 16:21:02 +0000 Subject: [PATCH 078/118] Update ruff and mypy --- pyproject.toml | 29 ++++++++++++++++------------- tox.ini | 6 +++--- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bfcac161..e720490b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,18 +44,6 @@ Download = "https://pygresql.github.io/download/" [tool.ruff] target-version = "py37" line-length = 79 -select = [ - "E", # pycodestyle - "F", # pyflakes - "I", # isort - "N", # pep8-naming - "UP", # pyupgrade - "D", # pydocstyle - "B", # bugbear - "S", # bandit - "SIM", # simplify - "RUF", # ruff -] exclude = [ "__pycache__", "__pypackages__", @@ -71,7 +59,22 @@ exclude = [ "venv", ] -[tool.ruff.per-file-ignores] +[tool.ruff.lint] +select = [ + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "D", # pydocstyle + "B", # bugbear + "S", # bandit + "SIM", # simplify + "RUF", # ruff +] +ignore = ["D203", "D213"] + +[tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] [tool.mypy] diff --git a/tox.ini b/tox.ini index 3917f73c..96ed1c5e 100644 --- a/tox.ini +++ b/tox.ini @@ -5,13 +5,13 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.11 -deps = ruff>=0.0.292 +deps = ruff>=0.3.0 commands = - ruff setup.py pg pgdb tests + ruff check setup.py pg pgdb tests [testenv:mypy] basepython = python3.11 -deps = mypy>=1.5.1 +deps = mypy>=1.8.0 commands = mypy pg pgdb tests From 8d5b39b35b196edaa8bfbe67de909bfeba8a794d Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 21:49:43 +0200 Subject: [PATCH 079/118] Update lint tools and GitHub actions --- .github/workflows/docs.yml | 11 ++++++----- .github/workflows/lint.yml | 4 ++-- .github/workflows/tests.yml | 2 +- pyproject.toml | 4 ++-- tox.ini | 18 +++++++++--------- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 7d1ba05a..50248b64 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -6,16 +6,17 @@ on: - main jobs: - build: + docs: + name: Build documentation runs-on: ubuntu-22.04 steps: - - name: CHeck out repository + - name: Check out repository uses: actions/checkout@v4 - - name: Set up Python 3.11 - uses: actions/setup-python@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 with: - python-version: 3.11 + python-version: 3.12 - name: Install dependencies run: | sudo apt install libpq-dev diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 267c54c2..c32a6e58 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,13 +14,13 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Install tox run: pip install tox - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.11 + python-version: 3.12 - name: Run quality checks run: tox -e ruff,mypy,cformat,docs timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 31a48265..822fabdb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -57,7 +57,7 @@ jobs: - name: Install tox run: pip install tox - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: Run tests diff --git a/pyproject.toml b/pyproject.toml index e720490b..a3be7012 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ ignore = ["D203", "D213"] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] [tool.mypy] -python_version = "3.11" +python_version = "3.12" check_untyped_defs = true no_implicit_optional = true strict_optional = true @@ -101,5 +101,5 @@ pg = ["pg.typed"] pgdb = ["pg.typed"] [build-system] -requires = ["setuptools>=68", "wheel>=0.41"] +requires = ["setuptools>=68", "wheel>=0.42"] build-backend = "setuptools.build_meta" diff --git a/tox.ini b/tox.ini index 96ed1c5e..f703abb5 100644 --- a/tox.ini +++ b/tox.ini @@ -4,42 +4,42 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] -basepython = python3.11 -deps = ruff>=0.3.0 +basepython = python3.12 +deps = ruff>=0.3.7 commands = ruff check setup.py pg pgdb tests [testenv:mypy] -basepython = python3.11 -deps = mypy>=1.8.0 +basepython = python3.12 +deps = mypy>=1.9.0 commands = mypy pg pgdb tests [testenv:cformat] -basepython = python3.11 +basepython = python3.12 allowlist_externals = sh commands = sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)" [testenv:docs] -basepython = python3.11 +basepython = python3.12 deps = sphinx>=7,<8 commands = sphinx-build -b html -nEW docs docs/_build/html [testenv:build] -basepython = python3.11 +basepython = python3.12 deps = setuptools>=68 - wheel>=0.41,<1 + wheel>=0.42,<1 build>=1,<2 commands = python -m build -s -n -C strict -C memory-size [testenv:coverage] -basepython = python3.11 +basepython = python3.12 deps = coverage>=7,<8 commands = From 633324d45b15c23de295cd6e680ddbe365df91c7 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 22:05:20 +0200 Subject: [PATCH 080/118] Add docker files to repository --- .devcontainer/Dockerfile | 14 +++++++ .devcontainer/docker-compose.yml | 69 ++++++++++++++++++++++++++++++++ .gitignore | 4 -- 3 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/docker-compose.yml diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..5aced2f4 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,14 @@ +FROM mcr.microsoft.com/devcontainers/base:jammy + +ENV PYTHONUNBUFFERED 1 + +# [Optional] If your requirements rarely change, uncomment this section to add them to the image. +# COPY requirements.txt /tmp/pip-tmp/ +# RUN pip3 --disable-pip-version-check --no-cache-dir install -r /tmp/pip-tmp/requirements.txt \ +# && rm -rf /tmp/pip-tmp + +# [Optional] Uncomment this section to install additional OS packages. +# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ +# && apt-get -y install --no-install-recommends + +CMD ["sleep", "infinity"] diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml new file mode 100644 index 00000000..61b13a7c --- /dev/null +++ b/.devcontainer/docker-compose.yml @@ -0,0 +1,69 @@ +services: + dev: + build: + context: . + dockerfile: ./Dockerfile + + env_file: dev.env + + volumes: + - ..:/workspace:cached + + command: sleep infinity + + pg10: + image: postgres:10 + restart: unless-stopped + volumes: + - postgres-data-10:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + + pg12: + image: postgres:12 + restart: unless-stopped + volumes: + - postgres-data-12:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + + pg14: + image: postgres:14 + restart: unless-stopped + volumes: + - postgres-data-14:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + + pg15: + image: postgres:15 + restart: unless-stopped + volumes: + - postgres-data-15:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + + pg16: + image: postgres:16 + restart: unless-stopped + volumes: + - postgres-data-16:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + +volumes: + postgres-data-10: + postgres-data-12: + postgres-data-14: + postgres-data-15: + postgres-data-16: diff --git a/.gitignore b/.gitignore index 8b08bb41..22c5ce3c 100644 --- a/.gitignore +++ b/.gitignore @@ -20,10 +20,6 @@ _build_doctrees/ /local/ /tests/LOCAL_*.py -docker-compose.yml -Dockerfile -Vagrantfile -Vagrantfile-* .coverage .tox/ From 8ec45a29e65f6fe4fe12e903b611db7a0236ac07 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 22:54:41 +0200 Subject: [PATCH 081/118] Fix mintor linting issues --- pg/db.py | 15 +++++++++------ pgdb/adapt.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pg/db.py b/pg/db.py index f824cc9d..5c8beea7 100644 --- a/pg/db.py +++ b/pg/db.py @@ -802,8 +802,9 @@ def get(self, table: str, row: Any, adapt = params.add col = self.escape_identifier what = 'oid, *' if qoid else '*' - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keyname) + where = ' AND '.join( + f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}' + for k in keyname) if 'oid' in row: if qoid: row[qoid] = row['oid'] @@ -913,8 +914,9 @@ def update(self, table: str, row: dict[str, Any] | None = None, **kw : Any params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keynames) + where = ' AND '.join( + f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}' + for k in keynames) if 'oid' in row: if qoid: row[qoid] = row['oid'] @@ -1103,8 +1105,9 @@ def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any params = self.adapter.parameter_list() adapt = params.add col = self.escape_identifier - where = ' AND '.join('{} OPERATOR(pg_catalog.=) {}'.format( - col(k), adapt(row[k], attnames[k])) for k in keynames) + where = ' AND '.join( + f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}' + for k in keynames) if 'oid' in row: if qoid: row[qoid] = row['oid'] diff --git a/pgdb/adapt.py b/pgdb/adapt.py index 92b48a7e..b89986b6 100644 --- a/pgdb/adapt.py +++ b/pgdb/adapt.py @@ -33,7 +33,7 @@ def __new__(cls, values: str | Iterable[str]) -> DbType: """Create new type object.""" if isinstance(values, str): values = values.split() - return super().__new__(cls, values) # type: ignore + return super().__new__(cls, values) def __eq__(self, other: Any) -> bool: """Check whether types are considered equal.""" From 683a63632727d229a684d47bb264c1a9bee37b4c Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 20:57:11 +0000 Subject: [PATCH 082/118] Update the year of the copyright --- LICENSE.txt | 2 +- docs/about.rst | 2 +- docs/conf.py | 2 +- docs/copyright.rst | 2 +- ext/pgconn.c | 2 +- ext/pginternal.c | 2 +- ext/pglarge.c | 2 +- ext/pgmodule.c | 2 +- ext/pgnotice.c | 2 +- ext/pgquery.c | 2 +- ext/pgsource.c | 2 +- pg/__init__.py | 2 +- pgdb/__init__.py | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index eea706fe..b34bf23b 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2023 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2024 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/docs/about.rst b/docs/about.rst index 18c6b7a6..180af459 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -8,7 +8,7 @@ powerful PostgreSQL features from Python. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2023 by the PyGreSQL team. + | Further modifications are copyright © 2009-2024 by the PyGreSQL team. | For licensing details, see the full :doc:`copyright`. **PostgreSQL** is a highly scalable, SQL compliant, open source diff --git a/docs/conf.py b/docs/conf.py index 1a63dac4..45a86cd4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,7 +8,7 @@ project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2023, ' + author +copyright = '2024, ' + author def project_version(): with open('../pyproject.toml') as f: diff --git a/docs/copyright.rst b/docs/copyright.rst index 9a8113ec..60739ef0 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2023 by the PyGreSQL team. +Further modifications copyright (c) 2009-2024 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/ext/pgconn.c b/ext/pgconn.c index 9ffc0009..ddc958ea 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pginternal.c b/ext/pginternal.c index 124661c1..9b3952cc 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pglarge.c b/ext/pglarge.c index 77455361..f19568c4 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 761ae1b7..26b916d6 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c index 0252a56f..ca051d88 100644 --- a/ext/pgnotice.c +++ b/ext/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgquery.c b/ext/pgquery.c index 6346497d..fe5dda47 100644 --- a/ext/pgquery.c +++ b/ext/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgsource.c b/ext/pgsource.c index 42510b30..4e197578 100644 --- a/ext/pgsource.c +++ b/ext/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2023 by the PyGreSQL Development Team + * Copyright (c) 2024 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pg/__init__.py b/pg/__init__.py index 37447c9e..cb4c7c34 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2023 by the PyGreSQL Development Team +# Copyright (c) 2024 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. diff --git a/pgdb/__init__.py b/pgdb/__init__.py index b9a4449a..2604074a 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2023 by the PyGreSQL Development Team +# Copyright (c) 2024 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. From 0ab37131bce025668717735969a20667bd9bcb4b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 18 Apr 2024 21:06:01 +0000 Subject: [PATCH 083/118] Fix GitHub action --- .github/workflows/lint.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c32a6e58..9e5c0bde 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,11 +14,11 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v5 + uses: actions/checkout@v4 - name: Install tox run: pip install tox - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: 3.12 - name: Run quality checks From d55137969d33130e9c025dc4da77c87522008c4b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Apr 2024 21:34:09 +0200 Subject: [PATCH 084/118] Bump version number --- docs/contents/changelog.rst | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index ac501b56..6afcb1e8 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,10 @@ ChangeLog ========= +Version 6.0.1 (2024-04-19) +-------------------------- +- Properly adapt falsy JSON values (#86) + Version 6.0 (2023-10-03) ------------------------ - Tested with the recent releases of Python 3.12 and PostgreSQL 16. diff --git a/pyproject.toml b/pyproject.toml index a3be7012..ef1de2a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0" +version = "6.0.1" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, From 40ba811a3b424088cebcc7338e33a3f265c0fb87 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 19 Apr 2024 20:33:25 +0000 Subject: [PATCH 085/118] Fix issues with provision.sh --- .devcontainer/devcontainer.json | 2 +- .devcontainer/provision.sh | 37 ++++++++++++++++++--------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index b9fbaaeb..0333b8e6 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -56,7 +56,7 @@ // Use 'forwardPorts' to make a list of ports inside the container available locally. // "forwardPorts": [], // Use 'postCreateCommand' to run commands after the container is created. - "postCreateCommand": "bash /workspace/.devcontainer/provision.sh" + "postCreateCommand": "sudo bash /workspace/.devcontainer/provision.sh" // Configure tool-specific properties. // "customizations": {}, // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 09acd893..5515b687 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -4,46 +4,49 @@ export DEBIAN_FRONTEND=noninteractive -sudo apt-get update -sudo apt-get -y upgrade +apt-get update +apt-get -y upgrade # install base utilities and configure time zone -sudo ln -fs /usr/share/zoneinfo/UTC /etc/localtime -sudo apt-get install -y apt-utils software-properties-common -sudo apt-get install -y tzdata -sudo dpkg-reconfigure --frontend noninteractive tzdata +ln -fs /usr/share/zoneinfo/UTC /etc/localtime +apt-get install -y apt-utils software-properties-common +ap-get install -y tzdata +dpkg-reconfigure --frontend noninteractive tzdata -sudo apt-get install -y rpm wget zip +apt-get install -y rpm wget zip # install all supported Python versions -sudo add-apt-repository -y ppa:deadsnakes/ppa -sudo apt-get update +add-apt-repository -y ppa:deadsnakes/ppa +apt-get update -sudo apt-get install -y python3.7 python3.7-dev python3.7-distutils -sudo apt-get install -y python3.8 python3.8-dev python3.8-distutils -sudo apt-get install -y python3.9 python3.9-dev python3.9-distutils -sudo apt-get install -y python3.10 python3.10-dev python3.10-distutils -sudo apt-get install -y python3.11 python3.11-dev python3.11-distutils -sudo apt-get install -y python3.12 python3.12-dev python3.12-distutils +apt-get install -y python3.7 python3.7-dev python3.7-distutils +apt-get install -y python3.8 python3.8-dev python3.8-distutils +apt-get install -y python3.9 python3.9-dev python3.9-distutils +apt-get install -y python3.10 python3.10-dev python3.10-distutils +apt-get install -y python3.11 python3.11-dev python3.11-distutils +apt-get install -y python3.12 python3.12-dev python3.12-distutils # install build and testing tool +python -m ensurepip -U + python3.7 -m pip install -U pip setuptools wheel build python3.8 -m pip install -U pip setuptools wheel build python3.9 -m pip install -U pip setuptools wheel build python3.10 -m pip install -U pip setuptools wheel build python3.11 -m pip install -U pip setuptools wheel build +python3.12 -m pip install -U pip setuptools wheel build pip install ruff -sudo apt-get install -y tox clang-format +apt-get install -y tox clang-format pip install -U tox # install PostgreSQL client tools -sudo apt-get install -y postgresql libpq-dev +apt-get install -y postgresql libpq-dev for pghost in pg10 pg12 pg14 pg15 pg16 do From 0f18e8060c9037ac555c133f3c8ec80202d33492 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 27 Jul 2024 17:39:38 +0200 Subject: [PATCH 086/118] Fix doc for DB.delete (#87) --- docs/contents/pg/db_wrapper.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 1dbd18ef..b9e72b69 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -715,7 +715,7 @@ delete -- delete a row from a database table Delete a row from a database table :param str table: name of table - :param dict d: optional dictionary of values + :param dict row: optional dictionary of values :param col: optional keyword arguments for updating the dictionary :rtype: None :raises pg.ProgrammingError: table has no primary key, From a07a71d3fba3e3a653623f7274109f658fcd19a0 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 27 Jul 2024 17:38:35 +0000 Subject: [PATCH 087/118] Use newer mypy and ruff --- tests/test_classic_connection.py | 8 ++++---- tests/test_dbapi20.py | 2 +- tox.ini | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 3f9427b2..180563ed 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1997,10 +1997,10 @@ def test_inserttable_byte_values(self): row_bytes = tuple( s.encode() if isinstance(s, str) else s for s in row_unicode) - data = [row_bytes] * 2 - self.c.inserttable('test', data) - data = [row_unicode] * 2 - self.assertEqual(self.get_back(), data) + data_bytes = [row_bytes] * 2 + self.c.inserttable('test', data_bytes) + data_unicode = [row_unicode] * 2 + self.assertEqual(self.get_back(), data_unicode) def test_inserttable_unicode_utf8(self): try: diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index ef4857d3..0e70e073 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -161,7 +161,7 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): - def row_factory(self, row): + def row_factory(self, row): # type: ignore[override] description = self.description assert isinstance(description, list) return {f'column {desc[0]}': value diff --git a/tox.ini b/tox.ini index f703abb5..0679e456 100644 --- a/tox.ini +++ b/tox.ini @@ -5,13 +5,13 @@ envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.12 -deps = ruff>=0.3.7 +deps = ruff>=0.5,<0.6 commands = ruff check setup.py pg pgdb tests [testenv:mypy] basepython = python3.12 -deps = mypy>=1.9.0 +deps = mypy>=1.11,<1.12 commands = mypy pg pgdb tests @@ -43,7 +43,7 @@ basepython = python3.12 deps = coverage>=7,<8 commands = - coverage run -m unittest discover + coverage run -m unittest discover -v coverage html [testenv] @@ -54,4 +54,4 @@ deps = setuptools>=68 commands = python setup.py clean --all build_ext --force --inplace --strict --memory-size - python -m unittest {posargs:discover} + python -m unittest {posargs:discover -v} From 487452e988e212db426780c3851323629c27b55b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 12:18:55 +0000 Subject: [PATCH 088/118] Update dependencies and supported versions --- .devcontainer/docker-compose.yml | 11 ++++ .devcontainer/provision.sh | 4 +- .github/workflows/docs.yml | 6 +- .github/workflows/lint.yml | 2 +- .github/workflows/tests.yml | 14 ++-- README.rst | 4 +- docs/about.rst | 4 +- docs/contents/install.rst | 2 +- pg/__init__.py | 108 +++++++++++++++++++++++-------- pg/adapt.py | 10 ++- pg/cast.py | 18 ++++-- pg/core.py | 92 +++++++++++++++++++------- pg/error.py | 7 +- pg/helpers.py | 10 ++- pgdb/__init__.py | 71 +++++++++++++++----- pgdb/adapt.py | 36 +++++++++-- pgdb/cast.py | 23 +++++-- pyproject.toml | 3 +- setup.py | 1 + tests/test_classic_connection.py | 2 +- tests/test_classic_largeobj.py | 15 ++--- tox.ini | 20 +++--- 22 files changed, 340 insertions(+), 123 deletions(-) diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 61b13a7c..541d63e9 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -61,9 +61,20 @@ services: POSTGRES_DB: postgres POSTGRES_PASSWORD: postgres + pg17: + image: postgres:17 + restart: unless-stopped + volumes: + - postgres-data-17:/var/lib/postgresql/data + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + volumes: postgres-data-10: postgres-data-12: postgres-data-14: postgres-data-15: postgres-data-16: + postgres-data-17: diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 5515b687..1ca7b020 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -27,6 +27,7 @@ apt-get install -y python3.9 python3.9-dev python3.9-distutils apt-get install -y python3.10 python3.10-dev python3.10-distutils apt-get install -y python3.11 python3.11-dev python3.11-distutils apt-get install -y python3.12 python3.12-dev python3.12-distutils +apt-get install -y python3.13 python3.13-dev python3.13-distutils # install build and testing tool @@ -38,6 +39,7 @@ python3.9 -m pip install -U pip setuptools wheel build python3.10 -m pip install -U pip setuptools wheel build python3.11 -m pip install -U pip setuptools wheel build python3.12 -m pip install -U pip setuptools wheel build +python3.13 -m pip install -U pip setuptools wheel build pip install ruff @@ -48,7 +50,7 @@ pip install -U tox apt-get install -y postgresql libpq-dev -for pghost in pg10 pg12 pg14 pg15 pg16 +for pghost in pg10 pg12 pg14 pg15 pg16 pg17 do export PGHOST=$pghost export PGDATABASE=postgres diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 50248b64..d88cd64a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -13,16 +13,16 @@ jobs: steps: - name: Check out repository uses: actions/checkout@v4 - - name: Set up Python 3.12 + - name: Set up Python 3.13 uses: actions/setup-python@v5 with: - python-version: 3.12 + python-version: 3.13 - name: Install dependencies run: | sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=7,<8" + pip install "sphinx>=8,<9" - name: Create docs with Sphinx run: | cd docs diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9e5c0bde..66d79095 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,7 +20,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: 3.12 + python-version: 3.13 - name: Run quality checks run: tox -e ruff,mypy,cformat,docs timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 822fabdb..920e3f3e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,14 +21,16 @@ jobs: - { python: "3.10", postgres: "14" } - { python: "3.11", postgres: "15" } - { python: "3.12", postgres: "16" } + - { python: "3.13", postgres: "17" } # Opposite extremes of the supported Py/PG range, other architecture - - { python: "3.7", postgres: "16", architecture: "x86" } - - { python: "3.8", postgres: "15", architecture: "x86" } - - { python: "3.9", postgres: "14", architecture: "x86" } - - { python: "3.10", postgres: "13", architecture: "x86" } - - { python: "3.11", postgres: "12", architecture: "x86" } - - { python: "3.12", postgres: "11", architecture: "x86" } + - { python: "3.7", postgres: "17", architecture: "x86" } + - { python: "3.8", postgres: "16", architecture: "x86" } + - { python: "3.9", postgres: "15", architecture: "x86" } + - { python: "3.10", postgres: "14", architecture: "x86" } + - { python: "3.11", postgres: "13", architecture: "x86" } + - { python: "3.12", postgres: "12", architecture: "x86" } + - { python: "3.13", postgres: "11", architecture: "x86" } env: PYGRESQL_DB: test diff --git a/README.rst b/README.rst index e9f9465c..46a09c2b 100644 --- a/README.rst +++ b/README.rst @@ -18,8 +18,8 @@ The following Python versions are supported: * PyGreSQL 5.x: Python 2 and Python 3 * PyGreSQL 6.x and newer: Python 3 only -The current version of PyGreSQL supports Python versions 3.7 to 3.12 -and PostgreSQL versions 10 to 16 on the server. +The current version of PyGreSQL supports Python versions 3.7 to 3.13 +and PostgreSQL versions 10 to 17 on the server. Installation ------------ diff --git a/docs/about.rst b/docs/about.rst index 180af459..ec1dbd2f 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -39,6 +39,6 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL |version| needs PostgreSQL 10 to 16, and Python -3.7 to 3.12. If you need to support older PostgreSQL or Python versions, +The current version PyGreSQL |version| needs PostgreSQL 10 to 17, and Python +3.7 to 3.13. If you need to support older PostgreSQL or Python versions, you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 7d28ea59..23694528 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -3.7 to 3.12, and PostgreSQL versions 10 to 16. +3.7 to 3.13, and PostgreSQL versions 10 to 17. PyGreSQL will be installed as two packages named ``pg`` (for the classic interface) and ``pgdb`` (for the DB API 2 compliant interface). The former diff --git a/pg/__init__.py b/pg/__init__.py index cb4c7c34..eeda3b73 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -99,34 +99,86 @@ from .notify import NotificationHandler __all__ = [ - 'DB', 'Adapter', - 'NotificationHandler', 'Typecasts', - 'Bytea', 'Hstore', 'Json', 'Literal', - 'Error', 'Warning', - 'DataError', 'DatabaseError', - 'IntegrityError', 'InterfaceError', 'InternalError', - 'InvalidResultError', 'MultipleResultsError', - 'NoResultError', 'NotSupportedError', - 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', 'RowCache', - 'INV_READ', 'INV_WRITE', - 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', - 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', - 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', - 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', - 'TRANS_INTRANS', 'TRANS_UNKNOWN', - 'cast_array', 'cast_hstore', 'cast_record', - 'connect', 'escape_bytea', 'escape_string', 'unescape_bytea', - 'get_array', 'get_bool', 'get_bytea_escaped', - 'get_datestyle', 'get_decimal', 'get_decimal_point', - 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', 'get_pqlib_version', 'get_typecast', - 'set_array', 'set_bool', 'set_bytea_escaped', - 'set_datestyle', 'set_decimal', 'set_decimal_point', - 'set_defbase', 'set_defhost', 'set_defopt', - 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', 'set_typecast', - 'version', '__version__', + 'DB', + 'INV_READ', + 'INV_WRITE', + 'POLLING_FAILED', + 'POLLING_OK', + 'POLLING_READING', + 'POLLING_WRITING', + 'RESULT_DDL', + 'RESULT_DML', + 'RESULT_DQL', + 'RESULT_EMPTY', + 'SEEK_CUR', + 'SEEK_END', + 'SEEK_SET', + 'TRANS_ACTIVE', + 'TRANS_IDLE', + 'TRANS_INERROR', + 'TRANS_INTRANS', + 'TRANS_UNKNOWN', + 'Adapter', + 'Bytea', + 'Connection', + 'DataError', + 'DatabaseError', + 'Error', + 'Hstore', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'InvalidResultError', + 'Json', + 'Literal', + 'MultipleResultsError', + 'NoResultError', + 'NotSupportedError', + 'NotificationHandler', + 'OperationalError', + 'ProgrammingError', + 'Query', + 'RowCache', + 'Typecasts', + 'Warning', + '__version__', + 'cast_array', + 'cast_hstore', + 'cast_record', + 'connect', + 'escape_bytea', + 'escape_string', + 'get_array', + 'get_bool', + 'get_bytea_escaped', + 'get_datestyle', + 'get_decimal', + 'get_decimal_point', + 'get_defbase', + 'get_defhost', + 'get_defopt', + 'get_defport', + 'get_defuser', + 'get_jsondecode', + 'get_pqlib_version', + 'get_typecast', + 'set_array', + 'set_bool', + 'set_bytea_escaped', + 'set_datestyle', + 'set_decimal', + 'set_decimal_point', + 'set_defbase', + 'set_defhost', + 'set_defopt', + 'set_defpasswd', + 'set_defport', + 'set_defuser', + 'set_jsondecode', + 'set_query_helpers', + 'set_typecast', + 'unescape_bytea', + 'version', ] __version__ = version diff --git a/pg/adapt.py b/pg/adapt.py index 2a5efaa2..97e0391c 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -21,8 +21,14 @@ from .db import DB __all__ = [ - 'Adapter', 'Bytea', 'DbType', 'DbTypes', - 'Hstore', 'Literal', 'Json', 'UUID' + 'UUID', + 'Adapter', + 'Bytea', + 'DbType', + 'DbTypes', + 'Hstore', + 'Json', + 'Literal' ] diff --git a/pg/cast.py b/pg/cast.py index ad1758be..98baa8f6 100644 --- a/pg/cast.py +++ b/pg/cast.py @@ -25,10 +25,20 @@ from .tz import timezone_as_offset __all__ = [ - 'cast_bool', 'cast_json', 'cast_num', 'cast_money', 'cast_int2vector', - 'cast_date', 'cast_time', 'cast_timetz', 'cast_interval', - 'cast_timestamp','cast_timestamptz', - 'Typecasts', 'get_typecast', 'set_typecast' + 'Typecasts', + 'cast_bool', + 'cast_date', + 'cast_int2vector', + 'cast_interval', + 'cast_json', + 'cast_money', + 'cast_num', + 'cast_time', + 'cast_timestamp', + 'cast_timestamptz', + 'cast_timetz', + 'get_typecast', + 'set_typecast' ] def get_args(func: Callable) -> list: diff --git a/pg/core.py b/pg/core.py index e20bdbd0..4d0c03c0 100644 --- a/pg/core.py +++ b/pg/core.py @@ -108,29 +108,73 @@ ) __all__ = [ - 'Error', 'Warning', - 'DataError', 'DatabaseError', - 'IntegrityError', 'InterfaceError', 'InternalError', - 'InvalidResultError', 'MultipleResultsError', - 'NoResultError', 'NotSupportedError', - 'OperationalError', 'ProgrammingError', - 'Connection', 'Query', 'LargeObject', - 'INV_READ', 'INV_WRITE', - 'POLLING_OK', 'POLLING_FAILED', 'POLLING_READING', 'POLLING_WRITING', - 'RESULT_DDL', 'RESULT_DML', 'RESULT_DQL', 'RESULT_EMPTY', - 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', - 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', - 'TRANS_INTRANS', 'TRANS_UNKNOWN', - 'cast_array', 'cast_hstore', 'cast_record', - 'connect', 'escape_bytea', 'escape_string', 'unescape_bytea', - 'get_array', 'get_bool', 'get_bytea_escaped', - 'get_datestyle', 'get_decimal', 'get_decimal_point', - 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', 'get_pqlib_version', - 'set_array', 'set_bool', 'set_bytea_escaped', - 'set_datestyle', 'set_decimal', 'set_decimal_point', - 'set_defbase', 'set_defhost', 'set_defopt', - 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', + 'INV_READ', + 'INV_WRITE', + 'POLLING_FAILED', + 'POLLING_OK', + 'POLLING_READING', + 'POLLING_WRITING', + 'RESULT_DDL', + 'RESULT_DML', + 'RESULT_DQL', + 'RESULT_EMPTY', + 'SEEK_CUR', + 'SEEK_END', + 'SEEK_SET', + 'TRANS_ACTIVE', + 'TRANS_IDLE', + 'TRANS_INERROR', + 'TRANS_INTRANS', + 'TRANS_UNKNOWN', + 'Connection', + 'DataError', + 'DatabaseError', + 'Error', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'InvalidResultError', + 'LargeObject', + 'MultipleResultsError', + 'NoResultError', + 'NotSupportedError', + 'OperationalError', + 'ProgrammingError', + 'Query', + 'Warning', + 'cast_array', + 'cast_hstore', + 'cast_record', + 'connect', + 'escape_bytea', + 'escape_string', + 'get_array', + 'get_bool', + 'get_bytea_escaped', + 'get_datestyle', + 'get_decimal', + 'get_decimal_point', + 'get_defbase', + 'get_defhost', + 'get_defopt', + 'get_defport', + 'get_defuser', + 'get_jsondecode', + 'get_pqlib_version', + 'set_array', + 'set_bool', + 'set_bytea_escaped', + 'set_datestyle', + 'set_decimal', + 'set_decimal_point', + 'set_defbase', + 'set_defhost', + 'set_defopt', + 'set_defpasswd', + 'set_defport', + 'set_defuser', + 'set_jsondecode', + 'set_query_helpers', + 'unescape_bytea', 'version', ] diff --git a/pg/error.py b/pg/error.py index 484a1252..f4b9fd0f 100644 --- a/pg/error.py +++ b/pg/error.py @@ -14,7 +14,12 @@ ) __all__ = [ - 'error', 'db_error', 'if_error', 'int_error', 'op_error', 'prg_error' + 'db_error', + 'error', + 'if_error', + 'int_error', + 'op_error', + 'prg_error' ] # Error messages diff --git a/pg/helpers.py b/pg/helpers.py index 53689f6a..9d176740 100644 --- a/pg/helpers.py +++ b/pg/helpers.py @@ -13,8 +13,14 @@ SomeNamedTuple = Any # alias for accessing arbitrary named tuples __all__ = [ - 'quote_if_unqualified', 'oid_key', 'QuoteDict', 'RowCache', - 'dictiter', 'namediter', 'namednext', 'scalariter' + 'QuoteDict', + 'RowCache', + 'dictiter', + 'namediter', + 'namednext', + 'oid_key', + 'quote_if_unqualified', + 'scalariter' ] diff --git a/pgdb/__init__.py b/pgdb/__init__.py index 2604074a..5db2fd46 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -121,21 +121,62 @@ from .cursor import Cursor __all__ = [ - 'Connection', 'Cursor', - 'Date', 'Time', 'Timestamp', - 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', - 'Binary', 'Interval', 'Uuid', - 'Hstore', 'Json', 'Literal', 'DbType', - 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', - 'SMALLINT', 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', - 'DATE', 'TIME', 'TIMESTAMP', 'INTERVAL', - 'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', - 'Error', 'Warning', - 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', - 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', - 'get_typecast', 'set_typecast', 'reset_typecast', - 'apilevel', 'connect', 'paramstyle', 'shortcutmethods', 'threadsafety', - 'version', '__version__', + 'ARRAY', + 'BINARY', + 'BOOL', + 'DATE', + 'DATETIME', + 'FLOAT', + 'HSTORE', + 'INTEGER', + 'INTERVAL', + 'JSON', + 'LONG', + 'MONEY', + 'NUMBER', + 'NUMERIC', + 'RECORD', + 'ROWID', + 'SMALLINT', + 'STRING', + 'TIME', + 'TIMESTAMP', + 'UUID', + 'Binary', + 'Connection', + 'Cursor', + 'DataError', + 'DatabaseError', + 'Date', + 'DateFromTicks', + 'DbType', + 'Error', + 'Hstore', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'Interval', + 'Json', + 'Literal', + 'NotSupportedError', + 'OperationalError', + 'ProgrammingError', + 'Time', + 'TimeFromTicks', + 'Timestamp', + 'TimestampFromTicks', + 'Uuid', + 'Warning', + '__version__', + 'apilevel', + 'connect', + 'get_typecast', + 'paramstyle', + 'reset_typecast', + 'set_typecast', + 'shortcutmethods', + 'threadsafety', + 'version', ] __version__ = version diff --git a/pgdb/adapt.py b/pgdb/adapt.py index b89986b6..f657b190 100644 --- a/pgdb/adapt.py +++ b/pgdb/adapt.py @@ -12,12 +12,36 @@ from .typecode import TypeCode __all__ = [ - 'DbType', 'ArrayType', 'RecordType', - 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', 'SMALLINT', - 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', 'DATE', 'TIME', - 'TIMESTAMP', 'INTERVAL', 'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', - 'Date', 'Time', 'Timestamp', - 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks' + 'ARRAY', + 'BINARY', + 'BOOL', + 'DATE', + 'DATETIME', + 'FLOAT', + 'HSTORE', + 'INTEGER', + 'INTERVAL', + 'JSON', + 'LONG', + 'MONEY', + 'NUMBER', + 'NUMERIC', + 'RECORD', + 'ROWID', + 'SMALLINT', + 'STRING', + 'TIME', + 'TIMESTAMP', + 'UUID', + 'ArrayType', + 'Date', + 'DateFromTicks', + 'DbType', + 'RecordType', + 'Time', + 'TimeFromTicks', + 'Timestamp', + 'TimestampFromTicks' ] diff --git a/pgdb/cast.py b/pgdb/cast.py index 03367506..49b4bd84 100644 --- a/pgdb/cast.py +++ b/pgdb/cast.py @@ -24,11 +24,24 @@ from .typecode import TypeCode __all__ = [ - 'Decimal', 'decimal_type', 'cast_bool', 'cast_money', - 'cast_int2vector', 'cast_date', 'cast_time', 'cast_interval', - 'cast_timetz', 'cast_timestamp', 'cast_timestamptz', - 'get_typecast', 'set_typecast', 'reset_typecast', - 'Typecasts', 'LocalTypecasts', 'TypeCache', 'FieldInfo' + 'Decimal', + 'FieldInfo', + 'LocalTypecasts', + 'TypeCache', + 'Typecasts', + 'cast_bool', + 'cast_date', + 'cast_int2vector', + 'cast_interval', + 'cast_money', + 'cast_time', + 'cast_timestamp', + 'cast_timestamptz', + 'cast_timetz', + 'decimal_type', + 'get_typecast', + 'reset_typecast', + 'set_typecast' ] diff --git a/pyproject.toml b/pyproject.toml index ef1de2a6..3ef4d645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", @@ -78,7 +79,7 @@ ignore = ["D203", "D213"] "tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] [tool.mypy] -python_version = "3.12" +python_version = "3.13" check_untyped_defs = true no_implicit_optional = true strict_optional = true diff --git a/setup.py b/setup.py index 8b1ec5dc..950364c0 100755 --- a/setup.py +++ b/setup.py @@ -170,6 +170,7 @@ def finalize_options(self): 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', 'Programming Language :: SQL', 'Topic :: Database', 'Topic :: Database :: Front-Ends', diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 180563ed..7234ffb6 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -984,7 +984,7 @@ def test_query_with_bool_params(self, bool_enabled=None): pg.set_bool(bool_enabled) try: bool_on = bool_enabled or bool_enabled is None - v_false, v_true = (False, True) if bool_on else 'ft' + v_false, v_true = (False, True) if bool_on else ('f', 't') r_false, r_true = [(v_false,)], [(v_true,)] self.assertEqual(query("select false").getresult(), r_false) self.assertEqual(query("select true").getresult(), r_true) diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 4fb8773c..7c53053d 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -112,7 +112,7 @@ def test_lo_import(self): fname = 'temp_test_pg_largeobj_import.txt' f = open(fname, 'wb') # noqa: SIM115 else: - f = tempfile.NamedTemporaryFile() + f = tempfile.NamedTemporaryFile() # noqa: SIM115 fname = f.name data = b'some data to be imported' f.write(data) @@ -420,7 +420,7 @@ def test_export(self): fname = 'temp_test_pg_largeobj_export.txt' f = open(fname, 'wb') # noqa: SIM115 else: - f = tempfile.NamedTemporaryFile() + f = tempfile.NamedTemporaryFile() # noqa: SIM115 fname = f.name data = b'some data to be exported' self.obj.open(pg.INV_WRITE) @@ -441,12 +441,11 @@ def test_export(self): def test_export_in_existent(self): export = self.obj.export - f = tempfile.NamedTemporaryFile() - self.obj.open(pg.INV_WRITE) - self.obj.close() - self.pgcnx.query(f'select lo_unlink({self.obj.oid})') - self.assertRaises(IOError, export, f.name) - f.close() + with tempfile.NamedTemporaryFile() as f: + self.obj.open(pg.INV_WRITE) + self.obj.close() + self.pgcnx.query(f'select lo_unlink({self.obj.oid})') + self.assertRaises(IOError, export, f.name) if __name__ == '__main__': diff --git a/tox.ini b/tox.ini index 0679e456..e89c7d73 100644 --- a/tox.ini +++ b/tox.ini @@ -1,36 +1,36 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11,12},ruff,mypy,cformat,docs +envlist = py3{7,8,9,10,11,12,13},ruff,mypy,cformat,docs [testenv:ruff] -basepython = python3.12 -deps = ruff>=0.5,<0.6 +basepython = python3.13 +deps = ruff>=0.8,<0.9 commands = ruff check setup.py pg pgdb tests [testenv:mypy] -basepython = python3.12 -deps = mypy>=1.11,<1.12 +basepython = python3.13 +deps = mypy>=1.13,<1.14 commands = mypy pg pgdb tests [testenv:cformat] -basepython = python3.12 +basepython = python3.13 allowlist_externals = sh commands = sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)" [testenv:docs] -basepython = python3.12 +basepython = python3.13 deps = - sphinx>=7,<8 + sphinx>=8,<9 commands = sphinx-build -b html -nEW docs docs/_build/html [testenv:build] -basepython = python3.12 +basepython = python3.13 deps = setuptools>=68 wheel>=0.42,<1 @@ -39,7 +39,7 @@ commands = python -m build -s -n -C strict -C memory-size [testenv:coverage] -basepython = python3.12 +basepython = python3.13 deps = coverage>=7,<8 commands = From a29e5822c90a3a5fe1a86baf8571f89843c1afef Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 14:10:44 +0000 Subject: [PATCH 089/118] Test should work with Pg 17 client and newer --- tests/test_classic_connection.py | 4 ++-- tests/test_classic_dbwrapper.py | 4 ++-- tests/test_classic_functions.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 7234ffb6..90d69a59 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -174,8 +174,8 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertGreaterEqual(server_version, 100000) - self.assertLess(server_version, 170000) + self.assertGreaterEqual(server_version, 100000) # >= 10.0 + self.assertLess(server_version, 190000) # < 20.0 def test_attribute_socket(self): socket = self.connection.socket diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index f02955c7..1d64c754 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -168,8 +168,8 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertGreaterEqual(server_version, 100000) - self.assertLess(server_version, 170000) + self.assertGreaterEqual(server_version, 100000) # >= 10.0 + self.assertLess(server_version, 200000) # < 20.0 self.assertEqual(server_version, self.db.db.server_version) def test_attribute_socket(self): diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index 4351f794..d1bde01c 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -124,8 +124,8 @@ def test_pqlib_version(self): # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() self.assertIsInstance(v, int) - self.assertGreater(v, 100000) - self.assertLess(v, 170000) + self.assertGreater(v, 100000) # >= 10.0 + self.assertLess(v, 200000) # < 20.0 class TestParseArray(unittest.TestCase): From 6ee4c4565bf20332656503ef9f3201cfec2eaa18 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 14:13:32 +0000 Subject: [PATCH 090/118] Make tox work with Python 3.7 --- tox.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tox.ini b/tox.ini index e89c7d73..2359c8df 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,9 @@ [tox] envlist = py3{7,8,9,10,11,12,13},ruff,mypy,cformat,docs +requires = # this is needed for compatibility with Python 3.7 + pip<24.1 + virtualenv<20.27 [testenv:ruff] basepython = python3.13 From 417f5430f5375550e6955e185b60742c52c861a0 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 16:51:23 +0100 Subject: [PATCH 091/118] Bump minor version --- docs/contents/changelog.rst | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 6afcb1e8..ad5f7f0e 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,10 @@ ChangeLog ========= +Version 6.1.0 (2024-12-05) +-------------------------- +- Support Python 3.13 and PostgreSQL 17. + Version 6.0.1 (2024-04-19) -------------------------- - Properly adapt falsy JSON values (#86) diff --git a/pyproject.toml b/pyproject.toml index 3ef4d645..01b5086f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.0.1" +version = "6.1.0" requires-python = ">=3.7" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, From fae41b5cfd1c28a405839d9e3ed5938432041f64 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Thu, 5 Dec 2024 17:43:33 +0100 Subject: [PATCH 092/118] Make it compile with latest MSVSC --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 950364c0..bf652276 100755 --- a/setup.py +++ b/setup.py @@ -136,7 +136,7 @@ def finalize_options(self): define_macros.append(('MS_WIN64', None)) elif compiler == 'msvc': # Microsoft Visual C++ extra_compile_args[1:] = [ - '-J', '-W3', '-WX', + '-J', '-W3', '-WX', '-wd4391', '-Dinline=__inline'] # needed for MSVC 9 From ca4392e98febb5b0b50cb087afc3a09538a07d17 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 6 Jan 2025 16:02:14 +0000 Subject: [PATCH 093/118] Update year of copyright --- LICENSE.txt | 2 +- docs/about.rst | 2 +- docs/conf.py | 2 +- docs/copyright.rst | 2 +- ext/pgconn.c | 2 +- ext/pginternal.c | 2 +- ext/pglarge.c | 2 +- ext/pgmodule.c | 2 +- ext/pgnotice.c | 2 +- ext/pgquery.c | 2 +- ext/pgsource.c | 2 +- pg/__init__.py | 2 +- pgdb/__init__.py | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index b34bf23b..e905706e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2024 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2025 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/docs/about.rst b/docs/about.rst index ec1dbd2f..10ceaf59 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -8,7 +8,7 @@ powerful PostgreSQL features from Python. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2024 by the PyGreSQL team. + | Further modifications are copyright © 2009-2025 by the PyGreSQL team. | For licensing details, see the full :doc:`copyright`. **PostgreSQL** is a highly scalable, SQL compliant, open source diff --git a/docs/conf.py b/docs/conf.py index 45a86cd4..f25d78e7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,7 +8,7 @@ project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2024, ' + author +copyright = '2025, ' + author def project_version(): with open('../pyproject.toml') as f: diff --git a/docs/copyright.rst b/docs/copyright.rst index 60739ef0..bf7d9b04 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2024 by the PyGreSQL team. +Further modifications copyright (c) 2009-2025 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/ext/pgconn.c b/ext/pgconn.c index ddc958ea..783eaffc 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pginternal.c b/ext/pginternal.c index 9b3952cc..25290950 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pglarge.c b/ext/pglarge.c index f19568c4..1b817b25 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 26b916d6..916adda2 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c index ca051d88..c56b249f 100644 --- a/ext/pgnotice.c +++ b/ext/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgquery.c b/ext/pgquery.c index fe5dda47..b87eba18 100644 --- a/ext/pgquery.c +++ b/ext/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgsource.c b/ext/pgsource.c index 4e197578..bbec2f86 100644 --- a/ext/pgsource.c +++ b/ext/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2024 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pg/__init__.py b/pg/__init__.py index eeda3b73..c3b7f4e9 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2024 by the PyGreSQL Development Team +# Copyright (c) 2025 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. diff --git a/pgdb/__init__.py b/pgdb/__init__.py index 5db2fd46..132ce292 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2024 by the PyGreSQL Development Team +# Copyright (c) 2025 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. From dc37be796227e9cfd6c2600043acb7a4a628c116 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 20 Dec 2025 15:12:38 +0100 Subject: [PATCH 094/118] Update ReadTheDocs config --- .readthedocs.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 9712e405..21780348 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,9 +7,9 @@ version: 2 # Set the version of Python and other tools you might need build: - os: ubuntu-22.04 + os: ubuntu-24.04 tools: - python: "3.11" + python: "3.14" # Build documentation in the docs/ directory with Sphinx sphinx: From cbb804a3167b527882877b51b5edcff8bb8fdeb0 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Dec 2025 10:50:51 +0000 Subject: [PATCH 095/118] Fix set_parameter with special chars --- pg/db.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pg/db.py b/pg/db.py index 5c8beea7..df91ec58 100644 --- a/pg/db.py +++ b/pg/db.py @@ -451,9 +451,14 @@ def set_parameter(self, params[param] = param_value local_clause = ' LOCAL' if local else '' for param, param_value in params.items(): - cmd = (f'RESET{local_clause} {param}' - if param_value is None else - f'SET{local_clause} {param} TO {param_value}') + if param_value is None: + cmd = f'RESET{local_clause} {param}' + else: + if isinstance(param_value, str) and ( + not param_value.isalnum() or + param_value.upper() != param_value.lower()): + param_value = f"'{param_value}'" + cmd = f'SET{local_clause} {param} TO {param_value}' self._do_debug(cmd) self._valid_db.query(cmd) From b9bab2c4414c0260d276e5f7668c9eee691d825a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Dec 2025 10:51:27 +0000 Subject: [PATCH 096/118] Avoid ambiguous time zones in tests --- tests/test_classic_dbwrapper.py | 4 ++-- tests/test_dbapi20.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 1d64c754..b70ccd3c 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -3899,7 +3899,7 @@ def test_time(self): def test_timetz(self): query = self.db.query - timezones = dict(CET=1, EET=2, EST=-5, UTC=0) + timezones = {'GMT': 0, 'Etc/GMT-1': 1, 'Etc/GMT-5': -5, 'UTC': 0} for timezone in sorted(timezones): tz = f'{timezones[timezone]:+03d}00' tzinfo = datetime.strptime(tz, '%z').tzinfo @@ -3951,7 +3951,7 @@ def test_timestamp(self): def test_timestamptz(self): query = self.db.query - timezones = dict(CET=1, EET=2, EST=-5, UTC=0) + timezones = {'GMT': 0, 'Etc/GMT-1': 1, 'Etc/GMT-5': -5, 'UTC': 0} for timezone in sorted(timezones): tz = f'{timezones[timezone]:+03d}00' tzinfo = datetime.strptime(tz, '%z').tzinfo diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 0e70e073..959abfad 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1403,7 +1403,7 @@ def test_cve_2018_1058(self): CREATE OR REPLACE FUNCTION public.bad_eq(oid, integer) RETURNS boolean AS $$ BEGIN - SET TIMEZONE TO 'CET'; + SET TIMEZONE TO 'Europe/Athens'; RETURN oideq($1, $2::oid); END $$ LANGUAGE plpgsql From d1229ae22d32e066cc6347bebe82100a1bf1be23 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Dec 2025 10:58:33 +0000 Subject: [PATCH 097/118] Fix dev container provisioning --- .devcontainer/provision.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 1ca7b020..5b2846d9 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -26,8 +26,8 @@ apt-get install -y python3.8 python3.8-dev python3.8-distutils apt-get install -y python3.9 python3.9-dev python3.9-distutils apt-get install -y python3.10 python3.10-dev python3.10-distutils apt-get install -y python3.11 python3.11-dev python3.11-distutils -apt-get install -y python3.12 python3.12-dev python3.12-distutils -apt-get install -y python3.13 python3.13-dev python3.13-distutils +apt-get install -y python3.12 python3.12-dev +apt-get install -y python3.13 python3.13-dev # install build and testing tool From 8298b77c28c43c47ab952615f6507ecf7ba3b590 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Dec 2025 14:30:08 +0000 Subject: [PATCH 098/118] Update Python and Postgres versions --- .devcontainer/dev.env | 4 +- .devcontainer/docker-compose.yml | 25 +++-- .devcontainer/provision.sh | 16 ++-- .github/workflows/docs.yml | 12 +-- .github/workflows/lint.yml | 8 +- .github/workflows/tests.yml | 20 ++-- README.rst | 6 +- docs/about.rst | 4 +- docs/contents/changelog.rst | 6 ++ docs/contents/install.rst | 2 +- ext/pgtypes.h | 153 +++++++++++++++---------------- pg/core.py | 25 +++-- pyproject.toml | 15 ++- setup.py | 11 +-- tests/test_classic_connection.py | 4 +- tests/test_classic_dbwrapper.py | 2 +- tests/test_classic_functions.py | 2 +- tox.ini | 36 ++++---- 18 files changed, 181 insertions(+), 170 deletions(-) diff --git a/.devcontainer/dev.env b/.devcontainer/dev.env index 996ee8d2..fd84d4cc 100644 --- a/.devcontainer/dev.env +++ b/.devcontainer/dev.env @@ -1,11 +1,11 @@ -PGHOST=pg15 +PGHOST=pg17 PGPORT=5432 PGDATABASE=test PGUSER=test PGPASSWORD=test PYGRESQL_DB=test -PYGRESQL_HOST=pg15 +PYGRESQL_HOST=pg17 PYGRESQL_PORT=5432 PYGRESQL_USER=test PYGRESQL_PASSWD=test diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 541d63e9..c3521cb2 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -11,21 +11,21 @@ services: command: sleep infinity - pg10: - image: postgres:10 + pg12: + image: postgres:12 restart: unless-stopped volumes: - - postgres-data-10:/var/lib/postgresql/data + - postgres-data-12:/var/lib/postgresql/data environment: POSTGRES_USER: postgres POSTGRES_DB: postgres POSTGRES_PASSWORD: postgres - pg12: - image: postgres:12 + pg13: + image: postgres:13 restart: unless-stopped volumes: - - postgres-data-12:/var/lib/postgresql/data + - postgres-data-13:/var/lib/postgresql/data environment: POSTGRES_USER: postgres POSTGRES_DB: postgres @@ -71,10 +71,21 @@ services: POSTGRES_DB: postgres POSTGRES_PASSWORD: postgres + pg18: + image: postgres:18 + restart: unless-stopped + volumes: + - postgres-data-18:/var/lib/postgresql + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + volumes: - postgres-data-10: postgres-data-12: + postgres-data-13: postgres-data-14: postgres-data-15: postgres-data-16: postgres-data-17: + postgres-data-18: diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh index 5b2846d9..d34a5cf7 100644 --- a/.devcontainer/provision.sh +++ b/.devcontainer/provision.sh @@ -21,25 +21,27 @@ apt-get install -y rpm wget zip add-apt-repository -y ppa:deadsnakes/ppa apt-get update -apt-get install -y python3.7 python3.7-dev python3.7-distutils apt-get install -y python3.8 python3.8-dev python3.8-distutils apt-get install -y python3.9 python3.9-dev python3.9-distutils apt-get install -y python3.10 python3.10-dev python3.10-distutils -apt-get install -y python3.11 python3.11-dev python3.11-distutils -apt-get install -y python3.12 python3.12-dev -apt-get install -y python3.13 python3.13-dev +apt-get install -y python3.11 python3.11-dev python3.11-distutils +apt-get install -y python3.12 python3.12-dev python3.12-venv +apt-get install -y python3.13 python3.13-dev python3.13-venv +apt-get install -y python3.14 python3.14-dev python3.14-venv # install build and testing tool -python -m ensurepip -U +python3.12 -m ensurepip --upgrade --default-pip +python3.13 -m ensurepip --upgrade --default-pip +python3.14 -m ensurepip --upgrade --default-pip -python3.7 -m pip install -U pip setuptools wheel build python3.8 -m pip install -U pip setuptools wheel build python3.9 -m pip install -U pip setuptools wheel build python3.10 -m pip install -U pip setuptools wheel build python3.11 -m pip install -U pip setuptools wheel build python3.12 -m pip install -U pip setuptools wheel build python3.13 -m pip install -U pip setuptools wheel build +python3.14 -m pip install -U pip setuptools wheel build pip install ruff @@ -50,7 +52,7 @@ pip install -U tox apt-get install -y postgresql libpq-dev -for pghost in pg10 pg12 pg14 pg15 pg16 pg17 +for pghost in pg12 pg13 pg14 pg15 pg16 pg17 pg18 do export PGHOST=$pghost export PGDATABASE=postgres diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index d88cd64a..cb56b57b 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -8,21 +8,21 @@ on: jobs: docs: name: Build documentation - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - name: Check out repository - uses: actions/checkout@v4 - - name: Set up Python 3.13 - uses: actions/setup-python@v5 + uses: actions/checkout@v6 + - name: Set up Python 3.14 + uses: actions/setup-python@v6 with: - python-version: 3.13 + python-version: 3.14 - name: Install dependencies run: | sudo apt install libpq-dev python -m pip install --upgrade pip pip install . - pip install "sphinx>=8,<9" + pip install "sphinx>=9,<10" - name: Create docs with Sphinx run: | cd docs diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 66d79095..29a4b042 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -7,20 +7,20 @@ on: jobs: checks: name: Quality checks run - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: fail-fast: false steps: - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install tox run: pip install tox - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: - python-version: 3.13 + python-version: 3.14 - name: Run quality checks run: tox -e ruff,mypy,cformat,docs timeout-minutes: 5 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 920e3f3e..caaa1955 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,22 +15,22 @@ jobs: fail-fast: false matrix: include: - - { python: "3.7", postgres: "11" } - { python: "3.8", postgres: "12" } - { python: "3.9", postgres: "13" } - { python: "3.10", postgres: "14" } - { python: "3.11", postgres: "15" } - { python: "3.12", postgres: "16" } - { python: "3.13", postgres: "17" } + - { python: "3.14", postgres: "18" } # Opposite extremes of the supported Py/PG range, other architecture - - { python: "3.7", postgres: "17", architecture: "x86" } - - { python: "3.8", postgres: "16", architecture: "x86" } - - { python: "3.9", postgres: "15", architecture: "x86" } - - { python: "3.10", postgres: "14", architecture: "x86" } - - { python: "3.11", postgres: "13", architecture: "x86" } - - { python: "3.12", postgres: "12", architecture: "x86" } - - { python: "3.13", postgres: "11", architecture: "x86" } + - { python: "3.8", postgres: "18", architecture: "x86" } + - { python: "3.9", postgres: "17", architecture: "x86" } + - { python: "3.10", postgres: "16", architecture: "x86" } + - { python: "3.11", postgres: "15", architecture: "x86" } + - { python: "3.12", postgres: "14", architecture: "x86" } + - { python: "3.13", postgres: "13", architecture: "x86" } + - { python: "3.14", postgres: "12", architecture: "x86" } env: PYGRESQL_DB: test @@ -55,11 +55,11 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install tox run: pip install tox - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python }} - name: Run tests diff --git a/README.rst b/README.rst index 46a09c2b..112e912d 100644 --- a/README.rst +++ b/README.rst @@ -10,7 +10,7 @@ It is based on the PyGres95 code written by Pascal Andre. D'Arcy J. M. Cain renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. Christoph Zwerschke volunteered as another maintainer and has been the main -contributor since version 3.7 of PyGreSQL. +contributor since version 3.8 of PyGreSQL. The following Python versions are supported: @@ -18,8 +18,8 @@ The following Python versions are supported: * PyGreSQL 5.x: Python 2 and Python 3 * PyGreSQL 6.x and newer: Python 3 only -The current version of PyGreSQL supports Python versions 3.7 to 3.13 -and PostgreSQL versions 10 to 17 on the server. +The current version of PyGreSQL supports Python versions 3.8 to 3.14 +and PostgreSQL versions 12 to 18 on the server. Installation ------------ diff --git a/docs/about.rst b/docs/about.rst index 10ceaf59..96284f05 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -39,6 +39,6 @@ on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -The current version PyGreSQL |version| needs PostgreSQL 10 to 17, and Python -3.7 to 3.13. If you need to support older PostgreSQL or Python versions, +The current version PyGreSQL |version| needs PostgreSQL 12 to 18, and Python +3.8 to 3.14. If you need to support older PostgreSQL or Python versions, you can resort to the PyGreSQL 5.x versions that still support them. diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index ad5f7f0e..22294d1c 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 6.2.0 (2025-12-21) +-------------------------- +- Added support for Python 3.14 and PostgreSQL 18. +- Removed support for Python versions older than 3.8 (released October 2019) + and PostgreSQL older than version 12 (released October 2019). + Version 6.1.0 (2024-12-05) -------------------------- - Support Python 3.13 and PostgreSQL 17. diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 23694528..2d806311 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -14,7 +14,7 @@ On Windows, you also need to make sure that the directory that contains ``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -3.7 to 3.13, and PostgreSQL versions 10 to 17. +3.8 to 3.14, and PostgreSQL versions 12 to 18. PyGreSQL will be installed as two packages named ``pg`` (for the classic interface) and ``pgdb`` (for the DB API 2 compliant interface). The former diff --git a/ext/pgtypes.h b/ext/pgtypes.h index 72c42ca9..216f2e49 100644 --- a/ext/pgtypes.h +++ b/ext/pgtypes.h @@ -1,14 +1,14 @@ /* pgtypes - PostgreSQL type definitions - These are the standard PostgreSQL 11.1 built-in types, - extracted from src/backend/catalog/pg_type_d.h, - because that header file is sometimes not available - or needs other header files to get properly included. + These are the standard PostgreSQL 12.22 built-in types, + extracted from src/include/catalog/pg_type_d.h, + because that header file is generated and often unavailable. You can also query pg_type to get this information. */ #ifndef PG_TYPE_D_H +#define PG_TYPE_D_H #define BOOLOID 16 #define BYTEAOID 17 @@ -27,90 +27,39 @@ #define OIDVECTOROID 30 #define JSONOID 114 #define XMLOID 142 -#define XMLARRAYOID 143 -#define JSONARRAYOID 199 #define PGNODETREEOID 194 #define PGNDISTINCTOID 3361 #define PGDEPENDENCIESOID 3402 +#define PGMCVLISTOID 5017 #define PGDDLCOMMANDOID 32 -#define SMGROID 210 #define POINTOID 600 #define LSEGOID 601 #define PATHOID 602 #define BOXOID 603 #define POLYGONOID 604 #define LINEOID 628 -#define LINEARRAYOID 629 #define FLOAT4OID 700 #define FLOAT8OID 701 -#define ABSTIMEOID 702 -#define RELTIMEOID 703 -#define TINTERVALOID 704 #define UNKNOWNOID 705 #define CIRCLEOID 718 -#define CIRCLEARRAYOID 719 #define CASHOID 790 -#define MONEYARRAYOID 791 #define MACADDROID 829 #define INETOID 869 #define CIDROID 650 #define MACADDR8OID 774 -#define BOOLARRAYOID 1000 -#define BYTEAARRAYOID 1001 -#define CHARARRAYOID 1002 -#define NAMEARRAYOID 1003 -#define INT2ARRAYOID 1005 -#define INT2VECTORARRAYOID 1006 -#define INT4ARRAYOID 1007 -#define REGPROCARRAYOID 1008 -#define TEXTARRAYOID 1009 -#define OIDARRAYOID 1028 -#define TIDARRAYOID 1010 -#define XIDARRAYOID 1011 -#define CIDARRAYOID 1012 -#define OIDVECTORARRAYOID 1013 -#define BPCHARARRAYOID 1014 -#define VARCHARARRAYOID 1015 -#define INT8ARRAYOID 1016 -#define POINTARRAYOID 1017 -#define LSEGARRAYOID 1018 -#define PATHARRAYOID 1019 -#define BOXARRAYOID 1020 -#define FLOAT4ARRAYOID 1021 -#define FLOAT8ARRAYOID 1022 -#define ABSTIMEARRAYOID 1023 -#define RELTIMEARRAYOID 1024 -#define TINTERVALARRAYOID 1025 -#define POLYGONARRAYOID 1027 #define ACLITEMOID 1033 -#define ACLITEMARRAYOID 1034 -#define MACADDRARRAYOID 1040 -#define MACADDR8ARRAYOID 775 -#define INETARRAYOID 1041 -#define CIDRARRAYOID 651 -#define CSTRINGARRAYOID 1263 #define BPCHAROID 1042 #define VARCHAROID 1043 #define DATEOID 1082 #define TIMEOID 1083 #define TIMESTAMPOID 1114 -#define TIMESTAMPARRAYOID 1115 -#define DATEARRAYOID 1182 -#define TIMEARRAYOID 1183 #define TIMESTAMPTZOID 1184 -#define TIMESTAMPTZARRAYOID 1185 #define INTERVALOID 1186 -#define INTERVALARRAYOID 1187 -#define NUMERICARRAYOID 1231 #define TIMETZOID 1266 -#define TIMETZARRAYOID 1270 #define BITOID 1560 -#define BITARRAYOID 1561 #define VARBITOID 1562 -#define VARBITARRAYOID 1563 #define NUMERICOID 1700 #define REFCURSOROID 1790 -#define REFCURSORARRAYOID 2201 #define REGPROCEDUREOID 2202 #define REGOPEROID 2203 #define REGOPERATOROID 2204 @@ -118,43 +67,22 @@ #define REGTYPEOID 2206 #define REGROLEOID 4096 #define REGNAMESPACEOID 4089 -#define REGPROCEDUREARRAYOID 2207 -#define REGOPERARRAYOID 2208 -#define REGOPERATORARRAYOID 2209 -#define REGCLASSARRAYOID 2210 -#define REGTYPEARRAYOID 2211 -#define REGROLEARRAYOID 4097 -#define REGNAMESPACEARRAYOID 4090 #define UUIDOID 2950 -#define UUIDARRAYOID 2951 #define LSNOID 3220 -#define PG_LSNARRAYOID 3221 #define TSVECTOROID 3614 #define GTSVECTOROID 3642 #define TSQUERYOID 3615 #define REGCONFIGOID 3734 #define REGDICTIONARYOID 3769 -#define TSVECTORARRAYOID 3643 -#define GTSVECTORARRAYOID 3644 -#define TSQUERYARRAYOID 3645 -#define REGCONFIGARRAYOID 3735 -#define REGDICTIONARYARRAYOID 3770 #define JSONBOID 3802 -#define JSONBARRAYOID 3807 +#define JSONPATHOID 4072 #define TXID_SNAPSHOTOID 2970 -#define TXID_SNAPSHOTARRAYOID 2949 #define INT4RANGEOID 3904 -#define INT4RANGEARRAYOID 3905 #define NUMRANGEOID 3906 -#define NUMRANGEARRAYOID 3907 #define TSRANGEOID 3908 -#define TSRANGEARRAYOID 3909 #define TSTZRANGEOID 3910 -#define TSTZRANGEARRAYOID 3911 #define DATERANGEOID 3912 -#define DATERANGEARRAYOID 3913 #define INT8RANGEOID 3926 -#define INT8RANGEARRAYOID 3927 #define RECORDOID 2249 #define RECORDARRAYOID 2287 #define CSTRINGOID 2275 @@ -172,6 +100,75 @@ #define FDW_HANDLEROID 3115 #define INDEX_AM_HANDLEROID 325 #define TSM_HANDLEROID 3310 +#define TABLE_AM_HANDLEROID 269 #define ANYRANGEOID 3831 +#define BOOLARRAYOID 1000 +#define BYTEAARRAYOID 1001 +#define CHARARRAYOID 1002 +#define NAMEARRAYOID 1003 +#define INT8ARRAYOID 1016 +#define INT2ARRAYOID 1005 +#define INT2VECTORARRAYOID 1006 +#define INT4ARRAYOID 1007 +#define REGPROCARRAYOID 1008 +#define TEXTARRAYOID 1009 +#define OIDARRAYOID 1028 +#define TIDARRAYOID 1010 +#define XIDARRAYOID 1011 +#define CIDARRAYOID 1012 +#define OIDVECTORARRAYOID 1013 +#define JSONARRAYOID 199 +#define XMLARRAYOID 143 +#define POINTARRAYOID 1017 +#define LSEGARRAYOID 1018 +#define PATHARRAYOID 1019 +#define BOXARRAYOID 1020 +#define POLYGONARRAYOID 1027 +#define LINEARRAYOID 629 +#define FLOAT4ARRAYOID 1021 +#define FLOAT8ARRAYOID 1022 +#define CIRCLEARRAYOID 719 +#define MONEYARRAYOID 791 +#define MACADDRARRAYOID 1040 +#define INETARRAYOID 1041 +#define CIDRARRAYOID 651 +#define MACADDR8ARRAYOID 775 +#define ACLITEMARRAYOID 1034 +#define BPCHARARRAYOID 1014 +#define VARCHARARRAYOID 1015 +#define DATEARRAYOID 1182 +#define TIMEARRAYOID 1183 +#define TIMESTAMPARRAYOID 1115 +#define TIMESTAMPTZARRAYOID 1185 +#define INTERVALARRAYOID 1187 +#define TIMETZARRAYOID 1270 +#define BITARRAYOID 1561 +#define VARBITARRAYOID 1563 +#define NUMERICARRAYOID 1231 +#define REFCURSORARRAYOID 2201 +#define REGPROCEDUREARRAYOID 2207 +#define REGOPERARRAYOID 2208 +#define REGOPERATORARRAYOID 2209 +#define REGCLASSARRAYOID 2210 +#define REGTYPEARRAYOID 2211 +#define REGROLEARRAYOID 4097 +#define REGNAMESPACEARRAYOID 4090 +#define UUIDARRAYOID 2951 +#define PG_LSNARRAYOID 3221 +#define TSVECTORARRAYOID 3643 +#define GTSVECTORARRAYOID 3644 +#define TSQUERYARRAYOID 3645 +#define REGCONFIGARRAYOID 3735 +#define REGDICTIONARYARRAYOID 3770 +#define JSONBARRAYOID 3807 +#define JSONPATHARRAYOID 4073 +#define TXID_SNAPSHOTARRAYOID 2949 +#define INT4RANGEARRAYOID 3905 +#define NUMRANGEARRAYOID 3907 +#define TSRANGEARRAYOID 3909 +#define TSTZRANGEARRAYOID 3911 +#define DATERANGEARRAYOID 3913 +#define INT8RANGEARRAYOID 3927 +#define CSTRINGARRAYOID 1263 #endif /* PG_TYPE_D_H */ diff --git a/pg/core.py b/pg/core.py index 4d0c03c0..87191ae6 100644 --- a/pg/core.py +++ b/pg/core.py @@ -10,19 +10,18 @@ import sys paths = [path for path in os.environ["PATH"].split(os.pathsep) if os.path.exists(os.path.join(path, libpq))] - if sys.version_info >= (3, 8): - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - add_dll_dir = os.add_dll_directory # type: ignore - for path in paths: - with add_dll_dir(os.path.abspath(path)): - try: - from ._pg import version - except ImportError: - pass - else: - del version - e = None # type: ignore - break + # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore + for path in paths: + with add_dll_dir(os.path.abspath(path)): + try: + from ._pg import version + except ImportError: + pass + else: + del version + e = None # type: ignore + break if paths: libpq = 'compatible ' + libpq else: diff --git a/pyproject.toml b/pyproject.toml index 01b5086f..343bdec0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,38 +1,36 @@ [project] name = "PyGreSQL" -version = "6.1.0" -requires-python = ">=3.7" +version = "6.2.0" +requires-python = ">=3.8" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, {name = "Christoph Zwerschke", email = "cito@online.de"}, ] description = "Python PostgreSQL interfaces" readme = "README.rst" +license = "PostgreSQL" +license-files = ["LICENSE.txt"] keywords = ["pygresql", "postgresql", "database", "api", "dbapi"] classifiers = [ "Development Status :: 6 - Mature", "Intended Audience :: Developers", - "License :: OSI Approved :: PostgreSQL License", "Operating System :: OS Independent", "Programming Language :: C", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", "Topic :: Software Development :: Libraries :: Python Modules", ] -[project.license] -file = "LICENSE.txt" - [project.urls] Homepage = "https://pygresql.github.io/" Documentation = "https://pygresql.github.io/contents/" @@ -95,12 +93,11 @@ disallow_untyped_defs = false [tool.setuptools] packages = ["pg", "pgdb"] -license-files = ["LICENSE.txt"] [tool.setuptools.package-data] pg = ["pg.typed"] pgdb = ["pg.typed"] [build-system] -requires = ["setuptools>=68", "wheel>=0.42"] +requires = ["setuptools>=80", "wheel>=0.45"] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index bf652276..61d99ab1 100755 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ def project_readme(): version = project_version() -if not (3, 7) <= sys.version_info[:2] < (4, 0): +if not (3, 8) <= sys.version_info[:2] < (4, 0): raise Exception( f"Sorry, PyGreSQL {version} does not support this Python version") @@ -66,7 +66,7 @@ def pg_version(): match = re.search(r'(\d+)\.(\d+)', pg_config('version')) if match: return tuple(map(int, match.groups())) - return 10, 0 + return 12, 0 pg_version = pg_version() @@ -105,7 +105,7 @@ def initialize_options(self): build_ext.initialize_options(self) self.strict = False self.memory_size = None - supported = pg_version >= (10, 0) + supported = pg_version >= (12, 0) if not supported: warnings.warn( "PyGreSQL does not support the installed PostgreSQL version.", @@ -159,24 +159,23 @@ def finalize_options(self): classifiers=[ 'Development Status :: 6 - Mature', 'Intended Audience :: Developers', - 'License :: OSI Approved :: PostgreSQL License', 'Operating System :: OS Independent', 'Programming Language :: C', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: 3.13', + 'Programming Language :: Python :: 3.14', 'Programming Language :: SQL', 'Topic :: Database', 'Topic :: Database :: Front-Ends', 'Topic :: Software Development :: Libraries :: Python Modules'], license='PostgreSQL', - test_suite='tests.discover', + license_files=['LICENSE.txt'], zip_safe=False, packages=["pg", "pgdb"], package_data={"pg": ["py.typed"], "pgdb": ["py.typed"]}, diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 90d69a59..4d9c2bb0 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -174,8 +174,8 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.connection.server_version self.assertIsInstance(server_version, int) - self.assertGreaterEqual(server_version, 100000) # >= 10.0 - self.assertLess(server_version, 190000) # < 20.0 + self.assertGreaterEqual(server_version, 120000) # >= 12.0 + self.assertLess(server_version, 200000) # < 20.0 def test_attribute_socket(self): socket = self.connection.socket diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index b70ccd3c..1c3a6e7d 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -168,7 +168,7 @@ def test_attribute_protocol_version(self): def test_attribute_server_version(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertGreaterEqual(server_version, 100000) # >= 10.0 + self.assertGreaterEqual(server_version, 120000) # >= 12.0 self.assertLess(server_version, 200000) # < 20.0 self.assertEqual(server_version, self.db.db.server_version) diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index d1bde01c..eb9a6bf0 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -124,7 +124,7 @@ def test_pqlib_version(self): # noinspection PyUnresolvedReferences v = pg.get_pqlib_version() self.assertIsInstance(v, int) - self.assertGreater(v, 100000) # >= 10.0 + self.assertGreater(v, 120000) # >= 12.0 self.assertLess(v, 200000) # < 20.0 diff --git a/tox.ini b/tox.ini index 2359c8df..f3d60de6 100644 --- a/tox.ini +++ b/tox.ini @@ -1,48 +1,45 @@ # config file for tox [tox] -envlist = py3{7,8,9,10,11,12,13},ruff,mypy,cformat,docs -requires = # this is needed for compatibility with Python 3.7 - pip<24.1 - virtualenv<20.27 +envlist = py3{9,10,11,12,13,14},ruff,mypy,cformat,docs [testenv:ruff] -basepython = python3.13 -deps = ruff>=0.8,<0.9 +basepython = python3.14 +deps = ruff>=0.14,<0.15 commands = ruff check setup.py pg pgdb tests [testenv:mypy] -basepython = python3.13 -deps = mypy>=1.13,<1.14 +basepython = python3.14 +deps = mypy>=1.19,<1.20 commands = mypy pg pgdb tests [testenv:cformat] -basepython = python3.13 +basepython = python3.14 allowlist_externals = sh commands = sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)" [testenv:docs] -basepython = python3.13 +basepython = python3.14 deps = - sphinx>=8,<9 + sphinx>=9,<10 commands = sphinx-build -b html -nEW docs docs/_build/html [testenv:build] -basepython = python3.13 +basepython = python3.14 deps = - setuptools>=68 - wheel>=0.42,<1 - build>=1,<2 + setuptools>=80 + wheel>=0.45,<1 + build>=1.3,<2 commands = python -m build -s -n -C strict -C memory-size [testenv:coverage] -basepython = python3.13 +basepython = python3.14 deps = coverage>=7,<8 commands = @@ -54,7 +51,10 @@ passenv = PG* PYGRESQL_* deps = - setuptools>=68 + setuptools>=75 +skip_install = + # The built distribution runs with Python 3.8, + # but the pyproject.toml is not compatible with it + py38: true commands = - python setup.py clean --all build_ext --force --inplace --strict --memory-size python -m unittest {posargs:discover -v} From 7d8545ab35844defcefcdeef1035e81548010ae8 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Dec 2025 14:39:29 +0000 Subject: [PATCH 099/118] Fix formatting issues --- pg/adapt.py | 13 +++++----- pg/core.py | 1 - pg/db.py | 5 ++-- tests/test_classic_attrdict.py | 2 +- tests/test_classic_connection.py | 43 +++++++++++++++++--------------- tests/test_classic_largeobj.py | 2 +- tests/test_tutorial.py | 4 +-- 7 files changed, 36 insertions(+), 34 deletions(-) diff --git a/pg/adapt.py b/pg/adapt.py index 97e0391c..6f65be0d 100644 --- a/pg/adapt.py +++ b/pg/adapt.py @@ -192,8 +192,9 @@ class DbTypes(dict): information on the associated database type. """ - _num_types = frozenset('int float num money int2 int4 int8' - ' float4 float8 numeric money'.split()) + _num_types = frozenset([ + 'int', 'float', 'num', 'money', 'int2', 'int4', 'int8', + 'float4', 'float8', 'numeric', 'money']) def __init__(self, db: DB) -> None: """Initialize type cache for connection.""" @@ -292,11 +293,11 @@ def typecast(self, value: Any, typ: str) -> Any: class Adapter: """Class providing methods for adapting parameters to the database.""" - _bool_true_values = frozenset('t true 1 y yes on'.split()) + _bool_true_values = frozenset(['t', 'true', '1', 'y', 'yes', 'on']) - _date_literals = frozenset( - 'current_date current_time' - ' current_timestamp localtime localtimestamp'.split()) + _date_literals = frozenset([ + 'current_date', 'current_time', 'current_timestamp', + 'localtime', 'localtimestamp']) _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') _re_record_quote = regex(r'[(,"\\]') diff --git a/pg/core.py b/pg/core.py index 87191ae6..c9dd7a1b 100644 --- a/pg/core.py +++ b/pg/core.py @@ -7,7 +7,6 @@ libpq = 'libpq.' if os.name == 'nt': libpq += 'dll' - import sys paths = [path for path in os.environ["PATH"].split(os.pathsep) if os.path.exists(os.path.join(path, libpq))] # see https://docs.python.org/3/whatsnew/3.8.html#ctypes diff --git a/pg/db.py b/pg/db.py index df91ec58..0ba37e92 100644 --- a/pg/db.py +++ b/pg/db.py @@ -815,7 +815,7 @@ def get(self, table: str, row: Any, row[qoid] = row['oid'] del row['oid'] t = self._escape_qualified_name(table) - cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s + cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608 self._do_debug(cmd, params) query = self._valid_db.query(cmd, params) res = query.dictresult() @@ -998,8 +998,7 @@ def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any row = {} if 'oid' in row: del row['oid'] # do not insert oid - if 'oid' in kw: - del kw['oid'] # do not update oid + kw.pop('oid', None) # do not update oid attnames = self.get_attnames(table) generated = self.get_generated(table) qoid = oid_key(table) if 'oid' in attnames else None diff --git a/tests/test_classic_attrdict.py b/tests/test_classic_attrdict.py index 8eef00df..319abedf 100644 --- a/tests/test_classic_attrdict.py +++ b/tests/test_classic_attrdict.py @@ -93,7 +93,7 @@ def test_write_methods(self): self.assertEqual(a['id'], 1) for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': method = getattr(a, method) - self.assertRaises(TypeError, method, a) # type: ignore + self.assertRaises(TypeError, method, a) if __name__ == '__main__': diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 4d9c2bb0..0a3f7bb0 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -121,24 +121,26 @@ def test_repr(self): self.assertTrue(r.startswith(' Date: Sun, 21 Dec 2025 16:11:31 +0000 Subject: [PATCH 100/118] Make setup backward compatible --- pyproject.toml | 2 +- setup.py | 41 +++++++++++++++++++++++++++++++++++++---- tox.ini | 20 +++++++++++++------- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 343bdec0..63bd206b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,5 +99,5 @@ pg = ["pg.typed"] pgdb = ["pg.typed"] [build-system] -requires = ["setuptools>=80", "wheel>=0.45"] +requires = ["setuptools>=75", "wheel>=0.45"] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 61d99ab1..9e9dbcbe 100755 --- a/setup.py +++ b/setup.py @@ -8,15 +8,16 @@ python -m build -C strict -C memory-size """ +import contextlib import os import platform import re import sys import warnings -from distutils.ccompiler import get_default_compiler -from distutils.sysconfig import get_python_inc, get_python_lib from setuptools import Extension, setup +from setuptools._distutils.ccompiler import get_default_compiler +from setuptools._distutils.sysconfig import get_python_inc, get_python_lib from setuptools.command.build_ext import build_ext @@ -42,7 +43,39 @@ def project_readme(): raise Exception( f"Sorry, PyGreSQL {version} does not support this Python version") -long_description = project_readme() + +def patch_pyproject_toml(): + """Patch pyproject.toml to make it work with old setuptools versions. + + This allows building PyGreSQL with Python < 3.9 which only supports + setuptools up to version 75, since our pyproject.toml requires version 77. + """ + from setuptools import __version__ as version + + try: + version = int(version.split('.', 1)[0]) + except Exception: + return + if version >= 77: + return + + from setuptools.config import pyprojecttoml + + load_file = pyprojecttoml.load_file + + def load_file_patched(filepath): + d = load_file(filepath) + with contextlib.suppress(KeyError): + p = d['project'] + t = p['license'] + f = p.pop('license-files') + p['license'] = {'text': t, 'files': f[0]} + return d + + pyprojecttoml.load_file = load_file_patched + + +patch_pyproject_toml() # needed for Python < 3.9 # For historical reasons, PyGreSQL does not install itself as a single @@ -144,7 +177,7 @@ def finalize_options(self): name='PyGreSQL', version=version, description='Python PostgreSQL Interfaces', - long_description=long_description, + long_description=project_readme(), long_description_content_type='text/x-rst', keywords='pygresql postgresql database api dbapi', author="D'Arcy J. M. Cain", diff --git a/tox.ini b/tox.ini index f3d60de6..64d0b9d5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,19 +1,23 @@ # config file for tox [tox] -envlist = py3{9,10,11,12,13,14},ruff,mypy,cformat,docs +envlist = py3{8,9,10,11,12,13,14},ruff,mypy,cformat,docs [testenv:ruff] basepython = python3.14 -deps = ruff>=0.14,<0.15 +deps = + ruff>=0.14,<0.15 commands = ruff check setup.py pg pgdb tests +commands_pre = [testenv:mypy] basepython = python3.14 -deps = mypy>=1.19,<1.20 +deps = + mypy>=1.19,<1.20 commands = mypy pg pgdb tests +commands_pre = [testenv:cformat] basepython = python3.14 @@ -21,6 +25,7 @@ allowlist_externals = sh commands = sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)" +commands_pre = [testenv:docs] basepython = python3.14 @@ -28,6 +33,7 @@ deps = sphinx>=9,<10 commands = sphinx-build -b html -nEW docs docs/_build/html +commands_pre = [testenv:build] basepython = python3.14 @@ -37,6 +43,7 @@ deps = build>=1.3,<2 commands = python -m build -s -n -C strict -C memory-size +commands_pre = [testenv:coverage] basepython = python3.14 @@ -52,9 +59,8 @@ passenv = PYGRESQL_* deps = setuptools>=75 -skip_install = - # The built distribution runs with Python 3.8, - # but the pyproject.toml is not compatible with it - py38: true + wheel>=0.45 commands = python -m unittest {posargs:discover -v} +commands_pre = + python setup.py clean --all build_ext --force --inplace --strict --memory-size From dee6fa5e3abcf6660ece771968df79ca332a8400 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 21 Dec 2025 20:35:11 +0000 Subject: [PATCH 101/118] Add note to changelog --- docs/contents/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 22294d1c..ceede228 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -6,6 +6,7 @@ Version 6.2.0 (2025-12-21) - Added support for Python 3.14 and PostgreSQL 18. - Removed support for Python versions older than 3.8 (released October 2019) and PostgreSQL older than version 12 (released October 2019). +- Fixed `set_parameter()` for values containing special characters. Version 6.1.0 (2024-12-05) -------------------------- From 72a219fa69f9841e9af1026a7698e4cd440ae426 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 27 Dec 2025 15:11:23 +0000 Subject: [PATCH 102/118] Make setup.py cleaner and more robust --- docs/contents/changelog.rst | 5 ++++ pyproject.toml | 2 +- setup.py | 57 ++++++++++++++++++++++--------------- 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index ceede228..cbb8e73f 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,11 @@ ChangeLog ========= +Version 6.2.1 (2025-12-27) +-------------------------- +- Setup is now cleaner and works with a wider range of setuptools versions. + + Version 6.2.0 (2025-12-21) -------------------------- - Added support for Python 3.14 and PostgreSQL 18. diff --git a/pyproject.toml b/pyproject.toml index 63bd206b..669bafd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.2.0" +version = "6.2.1" requires-python = ">=3.8" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, diff --git a/setup.py b/setup.py index 9e9dbcbe..288cc303 100755 --- a/setup.py +++ b/setup.py @@ -13,13 +13,16 @@ import platform import re import sys +import sysconfig import warnings from setuptools import Extension, setup -from setuptools._distutils.ccompiler import get_default_compiler -from setuptools._distutils.sysconfig import get_python_inc, get_python_lib from setuptools.command.build_ext import build_ext +min_py_version = 3, 8 # supported: Python >= 3.8 +max_py_version = 4, 0 # and < 4.0 +min_pg_version = 12, 0 # supported: PostgreSQL >= 12.0 + def project_version(): """Read the PyGreSQL version from the pyproject.toml file.""" @@ -39,7 +42,7 @@ def project_readme(): version = project_version() -if not (3, 8) <= sys.version_info[:2] < (4, 0): +if not min_py_version <= sys.version_info[:2] < max_py_version: raise Exception( f"Sorry, PyGreSQL {version} does not support this Python version") @@ -54,12 +57,15 @@ def patch_pyproject_toml(): try: version = int(version.split('.', 1)[0]) - except Exception: + except (IndexError, TypeError, ValueError): return - if version >= 77: + if not 61 <= version < 77: # only needed for setuptools 61 to 76 return - from setuptools.config import pyprojecttoml + try: + from setuptools.config import pyprojecttoml + except ImportError: + return load_file = pyprojecttoml.load_file @@ -99,21 +105,21 @@ def pg_version(): match = re.search(r'(\d+)\.(\d+)', pg_config('version')) if match: return tuple(map(int, match.groups())) - return 12, 0 + return min_pg_version pg_version = pg_version() libraries = ['pq'] # Make sure that the Python header files are searched before # those of PostgreSQL, because PostgreSQL can have its own Python.h -include_dirs = [get_python_inc(), pg_config('includedir')] -library_dirs = [get_python_lib(), pg_config('libdir')] +include_dirs = [sysconfig.get_path("include"), pg_config('includedir')] +library_dirs = [sysconfig.get_path("purelib"), pg_config('libdir')] define_macros = [('PYGRESQL_VERSION', version)] undef_macros = [] extra_compile_args = ['-O2', '-funsigned-char', '-Wall', '-Wconversion'] -class build_pg_ext(build_ext): # noqa: N801 +class BuildPgExt(build_ext): """Customized build_ext command for PyGreSQL.""" description = "build the PyGreSQL C extension" @@ -129,28 +135,28 @@ class build_pg_ext(build_ext): # noqa: N801 negative_opt = { # noqa: RUF012 'no-memory-size': 'memory-size'} - def get_compiler(self): - """Return the C compiler used for building the extension.""" - return self.compiler or get_default_compiler() - def initialize_options(self): """Initialize the supported options with default values.""" - build_ext.initialize_options(self) + super().initialize_options() self.strict = False self.memory_size = None - supported = pg_version >= (12, 0) + supported = pg_version >= min_pg_version if not supported: warnings.warn( "PyGreSQL does not support the installed PostgreSQL version.", stacklevel=2) def finalize_options(self): - """Set final values for all build_pg options.""" - build_ext.finalize_options(self) + """Set values for all build_pg options. + + Some values are set in build_extensions() since they depend + on the compiler version which is not yet known at this point. + """ + super().finalize_options() if self.strict: extra_compile_args.append('-Werror') wanted = self.memory_size - supported = pg_version >= (12, 0) + supported = pg_version >= min_pg_version if (wanted is None and supported) or wanted: define_macros.append(('MEMORY_SIZE', None)) if not supported: @@ -158,19 +164,24 @@ def finalize_options(self): "The installed PostgreSQL version" " does not support the memory size function.", stacklevel=2) + + def build_extensions(self): + """Build the PyGreSQL C extension.""" + # Adjust settings for Windows platforms if sys.platform == 'win32': libraries[0] = 'lib' + libraries[0] if os.path.exists(os.path.join( library_dirs[1], libraries[0] + 'dll.lib')): libraries[0] += 'dll' - compiler = self.get_compiler() - if compiler == 'mingw32': # MinGW + compiler_type = self.compiler.compiler_type + if compiler_type == 'mingw32': # MinGW if platform.architecture()[0] == '64bit': # needs MinGW-w64 define_macros.append(('MS_WIN64', None)) - elif compiler == 'msvc': # Microsoft Visual C++ + elif compiler_type == 'msvc': # Microsoft Visual C++ extra_compile_args[1:] = [ '-J', '-W3', '-WX', '-wd4391', '-Dinline=__inline'] # needed for MSVC 9 + super().build_extensions() setup( @@ -217,5 +228,5 @@ def finalize_options(self): include_dirs=include_dirs, library_dirs=library_dirs, define_macros=define_macros, undef_macros=undef_macros, libraries=libraries, extra_compile_args=extra_compile_args)], - cmdclass=dict(build_ext=build_pg_ext), + cmdclass=dict(build_ext=BuildPgExt), ) From 70b3392a52310c38b6300b7a5ae0c164f3f23e9e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 27 Dec 2025 16:29:31 +0000 Subject: [PATCH 103/118] Add freeze parameter to inserttable --- docs/contents/changelog.rst | 5 +++-- docs/contents/pg/connection.rst | 3 ++- ext/pgconn.c | 26 ++++++++++++++++++-------- pg/_pg.pyi | 3 ++- pg/db.py | 8 +++++--- tests/test_classic_connection.py | 23 ++++++++++++++++++++++- tests/test_classic_dbwrapper.py | 18 ++++++++++++++++++ 7 files changed, 70 insertions(+), 16 deletions(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index cbb8e73f..df7ee495 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -3,8 +3,9 @@ ChangeLog Version 6.2.1 (2025-12-27) -------------------------- -- Setup is now cleaner and works with a wider range of setuptools versions. - +- Made setup cleaner and work with a wider range of setuptools versions (#90). +- The `inserttable()` method in the `pg` module now supports an optional + `freeze` parameter to optimize initial bulk loading (#81). Version 6.2.0 (2025-12-21) -------------------------- diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index e4a08591..21ff2d79 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -478,13 +478,14 @@ first, otherwise :meth:`Connection.getnotify` will always return ``None``. inserttable -- insert an iterable into a table ---------------------------------------------- -.. method:: Connection.inserttable(table, values, [columns]) +.. method:: Connection.inserttable(table, values, [columns], *, freeze=False) Insert a Python iterable into a database table :param str table: the table name :param list values: iterable of row values, which must be lists or tuples :param list columns: list or tuple of column names + :param bool freeze: if True, immediately freeze the inserted rows :rtype: int :raises TypeError: invalid connection, bad argument type, or too many arguments :raises MemoryError: insert buffer could not be allocated diff --git a/ext/pgconn.c b/ext/pgconn.c index 783eaffc..d81dc905 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -708,28 +708,34 @@ conn_is_non_blocking(connObject *self, PyObject *noargs) static char conn_inserttable__doc__[] = "inserttable(table, data, [columns]) -- insert iterable into table\n\n" "The fields in the iterable must be in the same order as in the table\n" - "or in the list or tuple of columns if one is specified.\n"; + "or in the list or tuple of columns if one is specified.\n\n" + "If the optional argument 'freeze' is set to True, the inserted rows\n" + "will be immediately frozen (can be useful for initial bulk loads).\n"; static PyObject * -conn_inserttable(connObject *self, PyObject *args) +conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) { PGresult *result; char *table, *buffer, *bufpt, *bufmax, *s, *t; - int encoding, ret; + int freeze = 0, encoding, ret; size_t bufsiz; PyObject *rows, *iter_row, *item, *columns = NULL; Py_ssize_t i, j, m, n; + static char *kwlist[] = {"table", "data", "columns", "freeze", NULL}; + if (!self->cnx) { PyErr_SetString(PyExc_TypeError, "Connection is not valid"); return NULL; } /* gets arguments */ - if (!PyArg_ParseTuple(args, "sO|O", &table, &rows, &columns)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "sO|O$p", kwlist, &table, + &rows, &columns, &freeze)) { PyErr_SetString( PyExc_TypeError, - "Method inserttable() expects a string and a list as arguments"); + "Method inserttable() expects a string, an iterable, an optional " + "list/tuple and an optional boolean 'freeze' as arguments"); return NULL; } @@ -834,7 +840,11 @@ conn_inserttable(connObject *self, PyObject *args) } } if (bufpt < bufmax) - snprintf(bufpt, (size_t)(bufmax - bufpt), " from stdin"); + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " from stdin"); + if (freeze && bufpt < bufmax) { + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " freeze"); + } + if (bufpt >= bufmax) { PyMem_Free(buffer); Py_DECREF(iter_row); @@ -1753,8 +1763,8 @@ static struct PyMethodDef conn_methods[] = { conn_set_notice_receiver__doc__}, {"getnotify", (PyCFunction)conn_get_notify, METH_NOARGS, conn_get_notify__doc__}, - {"inserttable", (PyCFunction)conn_inserttable, METH_VARARGS, - conn_inserttable__doc__}, + {"inserttable", (PyCFunction)conn_inserttable, + METH_VARARGS | METH_KEYWORDS, conn_inserttable__doc__}, {"transaction", (PyCFunction)conn_transaction, METH_NOARGS, conn_transaction__doc__}, {"parameter", (PyCFunction)conn_parameter, METH_VARARGS, diff --git a/pg/_pg.pyi b/pg/_pg.pyi index b14bd5fc..bbb24219 100644 --- a/pg/_pg.pyi +++ b/pg/_pg.pyi @@ -268,7 +268,8 @@ class Connection: ... def inserttable(self, table: str, values: Sequence[list|tuple], - columns: list[str] | tuple[str, ...] | None = None) -> int: + columns: list[str] | tuple[str, ...] | None = None, + freeze: bool=False) -> int: """Insert a Python iterable into a database table.""" ... diff --git a/pg/db.py b/pg/db.py index 0ba37e92..cbfab8ee 100644 --- a/pg/db.py +++ b/pg/db.py @@ -1403,11 +1403,13 @@ def getnotify(self) -> tuple[str, int, str] | None: return self._valid_db.getnotify() def inserttable(self, table: str, values: Sequence[list|tuple], - columns: list[str] | tuple[str, ...] | None = None) -> int: + columns: list[str] | tuple[str, ...] | None = None, + freeze: bool=False) -> int: """Insert a Python iterable into a database table.""" if columns is None: - return self._valid_db.inserttable(table, values) - return self._valid_db.inserttable(table, values, columns) + return self._valid_db.inserttable(table, values, freeze=freeze) + return self._valid_db.inserttable( + table, values, columns, freeze=freeze) def transaction(self) -> int: """Get the current in-transaction status of the server. diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 0a3f7bb0..d9dad42e 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -2118,7 +2118,28 @@ def test_insert_table_small_int_overflow(self): self.assertIn( 'value "33000" is out of range for type smallint', str(e)) else: - self.assertFalse('expected an error') + self.assertFalse('expected an error since value is out of range') + + def test_insert_table_with_freeze_false(self): + data = self.data + self.c.inserttable('test', data, freeze=False) + self.assertEqual(self.get_back(), data) + + def test_insert_table_with_freeze_true_without_truncate(self): + try: + self.c.inserttable('test', self.data, freeze=True) + except ValueError as e: + self.assertIn('cannot perform COPY FREEZE', str(e)) + else: + self.assertFalse('expected an error since table was not truncated') + + def test_insert_table_with_freeze_true_with_truncate(self): + data = self.data + self.c.query("begin") + self.c.query('truncate table test') + self.c.inserttable('test', data, freeze=True) + self.c.query('commit') + self.assertEqual(self.get_back(), data) class TestDirectSocketAccess(unittest.TestCase): diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index 1c3a6e7d..61a06ed9 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -4233,6 +4233,24 @@ def test_inserttable_from_query(self): self.assertEqual([row[0] for row in data_from], [1, 2, 3]) self.assertEqual(data_from, data_to) + def test_inserttable_with_freeze(self): + # use inserttable() with freeze and table created in same transaction + query = self.db.query + values = [(i,) for i in range(1, 4)] + self.db.begin() + self.create_table('test_table_freeze', 'n integer') + self.db.inserttable('test_table_freeze', values, freeze=True) + self.db.commit() + r = query("select * from test_table_freeze").getresult() + self.assertEqual(r, values) + + def test_inserttable_with_freeze_no_transaction(self): + # use inserttable() with freeze and table created before transaction + values = [(i,) for i in range(1, 4)] + self.create_table('test_table_freeze', 'n integer') + self.assertRaises(ValueError, self.db.inserttable, + 'test_table_freeze', values, freeze=True) + class TestDBClassNonStdOpts(TestDBClass): """Test the methods of the DB class with non-standard global options.""" From e60f78f73e1547b99f8d80497ab4c222f4297d05 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 2 Jan 2026 23:21:26 +0100 Subject: [PATCH 104/118] inserttable: remove arbitrary buffer limit (#91) Implementation by Justin Pryzby. Minor changes by Christoph Zwerschke. --- ext/pgconn.c | 179 +++++++++++++------------------ ext/pginternal.c | 48 +++++++++ ext/pgmodule.c | 15 ++- tests/test_classic_connection.py | 12 +-- 4 files changed, 140 insertions(+), 114 deletions(-) diff --git a/ext/pgconn.c b/ext/pgconn.c index d81dc905..302b5c62 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -716,9 +716,9 @@ static PyObject * conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) { PGresult *result; - char *table, *buffer, *bufpt, *bufmax, *s, *t; + char *table, *s, *t; + struct CharBuffer buffer = {0}; int freeze = 0, encoding, ret; - size_t bufsiz; PyObject *rows, *iter_row, *item, *columns = NULL; Py_ssize_t i, j, m, n; @@ -775,18 +775,18 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) n = -1; /* number of columns not yet known */ } - /* allocate buffer */ - if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) { - Py_DECREF(iter_row); - return PyErr_NoMemory(); - } - encoding = PQclientEncoding(self->cnx); /* starts query */ - bufpt = buffer; - bufmax = bufpt + MAX_BUFFER_SIZE; - bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "copy "); + ext_char_buffer_s(&buffer, "copy "); + + /* return early if there is no buffer */ + if (buffer.error) { + if (buffer.data) + PyMem_Free(buffer.data); + Py_DECREF(iter_row); + return PyErr_NoMemory(); + } s = table; do { @@ -794,18 +794,16 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) if (!t) t = s + strlen(s); table = PQescapeIdentifier(self->cnx, s, (size_t)(t - s)); - if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s", table); + ext_char_buffer_s(&buffer, table); PQfreemem(table); s = t; - if (*s && bufpt < bufmax) - *bufpt++ = *s++; + if (*s) + ext_char_buffer_c(&buffer, *s++); } while (*s); if (columns) { /* adds a string like f" ({','.join(columns)})" */ - if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " ("); + ext_char_buffer_s(&buffer, " ("); for (j = 0; j < n; ++j) { PyObject *obj = PySequence_Fast_GET_ITEM(columns, j); Py_ssize_t slen; @@ -817,7 +815,7 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) else if (PyUnicode_Check(obj)) { obj = get_encoded_string(obj, encoding); if (!obj) { - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(iter_row); return NULL; /* pass the UnicodeEncodeError */ } @@ -826,37 +824,34 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) PyErr_SetString( PyExc_TypeError, "The third argument must contain only strings"); - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(iter_row); return NULL; } PyBytes_AsStringAndSize(obj, &col, &slen); col = PQescapeIdentifier(self->cnx, col, (size_t)slen); Py_DECREF(obj); - if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s%s", col, - j == n - 1 ? ")" : ","); + ext_char_buffer_s(&buffer, col); + ext_char_buffer_c(&buffer, j == n - 1 ? ')' : ','); PQfreemem(col); } } - if (bufpt < bufmax) - bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " from stdin"); - if (freeze && bufpt < bufmax) { - bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " freeze"); - } + ext_char_buffer_s(&buffer, " from stdin"); + if (freeze) + ext_char_buffer_s(&buffer, " freeze"); - if (bufpt >= bufmax) { - PyMem_Free(buffer); + if (buffer.error) { + PyMem_Free(buffer.data); Py_DECREF(iter_row); return PyErr_NoMemory(); } Py_BEGIN_ALLOW_THREADS - result = PQexec(self->cnx, buffer); + result = PQexec(self->cnx, buffer.data); Py_END_ALLOW_THREADS if (!result || PQresultStatus(result) != PGRES_COPY_IN) { - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(iter_row); PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); return NULL; @@ -871,7 +866,7 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) if (!(PyTuple_Check(columns) || PyList_Check(columns))) { PQputCopyEnd(self->cnx, "Invalid arguments"); - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(columns); Py_DECREF(columns); Py_DECREF(iter_row); @@ -887,7 +882,7 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) } else if (j != n) { PQputCopyEnd(self->cnx, "Invalid arguments"); - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(columns); Py_DECREF(iter_row); PyErr_SetString( @@ -896,65 +891,54 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) return NULL; } + /* reset buffer to empty */ + buffer.len = 0; + /* builds insert line */ - bufpt = buffer; - bufsiz = MAX_BUFFER_SIZE - 1; for (j = 0; j < n; ++j) { - if (j) { - *bufpt++ = '\t'; - --bufsiz; - } + if (j) + ext_char_buffer_c(&buffer, '\t'); item = PySequence_Fast_GET_ITEM(columns, j); /* convert item to string and append to buffer */ if (item == Py_None) { - if (bufsiz > 2) { - *bufpt++ = '\\'; - *bufpt++ = 'N'; - bufsiz -= 2; - } - else - bufsiz = 0; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, 'N'); } else if (PyBytes_Check(item)) { const char *t = PyBytes_AsString(item); - while (*t && bufsiz) { + while (*t) { switch (*t) { case '\\': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = '\\'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\\'); break; case '\t': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 't'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\t'); break; case '\r': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 'r'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\r'); break; case '\n': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 'n'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\n'); break; default: - *bufpt++ = *t; + ext_char_buffer_c(&buffer, *t); } ++t; - --bufsiz; } } else if (PyUnicode_Check(item)) { PyObject *s = get_encoded_string(item, encoding); if (!s) { PQputCopyEnd(self->cnx, "Encoding error"); - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(item); Py_DECREF(columns); Py_DECREF(iter_row); @@ -963,33 +947,28 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) else { const char *t = PyBytes_AsString(s); - while (*t && bufsiz) { + while (*t) { switch (*t) { case '\\': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = '\\'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\\'); break; case '\t': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 't'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\t'); break; case '\r': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 'r'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\r'); break; case '\n': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 'n'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\n'); break; default: - *bufpt++ = *t; + ext_char_buffer_c(&buffer, *t); } ++t; - --bufsiz; } Py_DECREF(s); } @@ -998,50 +977,42 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) PyObject *s = PyObject_Str(item); const char *t = PyUnicode_AsUTF8(s); - while (*t && bufsiz) { - *bufpt++ = *t++; - --bufsiz; - } + ext_char_buffer_s(&buffer, t); Py_DECREF(s); } else { PyObject *s = PyObject_Repr(item); const char *t = PyUnicode_AsUTF8(s); - while (*t && bufsiz) { + while (*t) { switch (*t) { case '\\': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = '\\'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\\'); break; case '\t': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 't'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\t'); break; case '\r': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 'r'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\r'); break; case '\n': - *bufpt++ = '\\'; - if (--bufsiz) - *bufpt++ = 'n'; + ext_char_buffer_c(&buffer, '\\'); + ext_char_buffer_c(&buffer, '\n'); break; default: - *bufpt++ = *t; + ext_char_buffer_c(&buffer, *t); } ++t; - --bufsiz; } Py_DECREF(s); } - if (bufsiz <= 0) { + if (buffer.error) { PQputCopyEnd(self->cnx, "Memory error"); - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(columns); Py_DECREF(iter_row); return PyErr_NoMemory(); @@ -1050,16 +1021,16 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) Py_DECREF(columns); - *bufpt++ = '\n'; + ext_char_buffer_c(&buffer, '\n'); /* sends data */ - ret = PQputCopyData(self->cnx, buffer, (int)(bufpt - buffer)); + ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.len); if (ret != 1) { char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) : "Data cannot be queued"; PyErr_SetString(PyExc_IOError, errormsg); PQputCopyEnd(self->cnx, errormsg); - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_DECREF(iter_row); return NULL; } @@ -1067,7 +1038,7 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) Py_DECREF(iter_row); if (PyErr_Occurred()) { - PyMem_Free(buffer); + PyMem_Free(buffer.data); return NULL; /* pass the iteration error */ } @@ -1075,11 +1046,11 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) if (ret != 1) { PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) : "Data cannot be queued"); - PyMem_Free(buffer); + PyMem_Free(buffer.data); return NULL; } - PyMem_Free(buffer); + PyMem_Free(buffer.data); Py_BEGIN_ALLOW_THREADS result = PQgetResult(self->cnx); diff --git a/ext/pginternal.c b/ext/pginternal.c index 25290950..b02a4307 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -1493,3 +1493,51 @@ notice_receiver(void *arg, const PGresult *res) } PyGILState_Release(gstate); } + +/* Extend char buffer with given string */ +static void +ext_char_buffer_s(struct CharBuffer *buf, const char *s) +{ + size_t len = strlen(s); + size_t need = buf->len + len + 1; + + if (!len || buf->error) + return; + + if (need >= buf->max_len) { + void *tmp; + + // Allocate powers of two unless it's large + if (2 * buf->max_len >= need && buf->max_len < 1024 * 1024) + need = 2 * buf->max_len; + + tmp = PyMem_Realloc(buf->data, need); + if (!tmp) { + buf->error = 1; + return; + } + + buf->data = tmp; + buf->max_len = need; + } + + memcpy(buf->data + buf->len, s, len + 1); + buf->len += len; +} + +/* Extend char buffer with given character */ +static void +ext_char_buffer_c(struct CharBuffer *buf, char c) +{ + if (buf->len > buf->max_len - 2) { + // slow path dealing with reallocation + char tmp[2] = {c, '\0'}; + ext_char_buffer_s(buf, tmp); + } + else { + if (buf->error) + return; + buf->data[buf->len++] = c; + buf->data[buf->len] = '\0'; + } +} diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 916adda2..b5949a43 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -53,8 +53,7 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #define QUERY_MOVENEXT 3 #define QUERY_MOVEPREV 4 -#define MAX_BUFFER_SIZE 65536 /* maximum transaction size */ -#define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ +#define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ /* MODULE GLOBAL VARIABLES */ @@ -158,6 +157,18 @@ typedef struct { } largeObject; #define is_largeObject(v) (PyType(v) == &largeType) +/* + A buffer for character data with routines to handle resizing. + This is inspired by libpq's PQExpBufferData. + The buffer can be extended with the extend_char_buffer_s/x() functions. +*/ +struct CharBuffer { + char *data; /* actual string data */ + size_t len; /* strlen() of data */ + size_t max_len; /* allocated size */ + int error; /* error flag (invalid string)*/ +}; + /* Internal functions */ #include "pginternal.c" diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index d9dad42e..3bd36495 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1967,9 +1967,9 @@ def test_inserttable_with_huge_list_of_column_names(self): cols = ['very_long_column_name'] * 2000 # Should raise a value error because the column does not exist self.assertRaises(ValueError, self.c.inserttable, 'test', data, cols) - # double the size, should catch buffer overflow and raise memory error + # double the size, should not overflow buffer nor raise memory error cols *= 2 - self.assertRaises(MemoryError, self.c.inserttable, 'test', data, cols) + self.assertRaises(ValueError, self.c.inserttable, 'test', data, cols) def test_inserttable_with_out_of_range_data(self): # try inserting data out of range for the column type @@ -2095,16 +2095,12 @@ def __repr__(self): self.c.query('select t from test').getresult(), [(s,)] * 3) def test_insert_table_big_row_size(self): - # inserting rows with a size of up to 64k bytes should work - t = '*' * 50000 + # inserting rows with a size exceeding 64k bytes should work + t = '*' * 75000 data = [(t,)] self.c.inserttable('test', data, ['t']) self.assertEqual( self.c.query('select t from test').getresult(), data) - # double the size, should catch buffer overflow and raise memory error - t *= 2 - data = [(t,)] - self.assertRaises(MemoryError, self.c.inserttable, 'test', data, ['t']) def test_insert_table_small_int_overflow(self): rest_row = self.data[2][1:] From f090e587b448b0532f9e582849eff9a1cbf562e3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 2 Jan 2026 22:53:43 +0000 Subject: [PATCH 105/118] Fix edge case error handling in two functions --- ext/pgconn.c | 6 ++++++ ext/pgquery.c | 5 +---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ext/pgconn.c b/ext/pgconn.c index 302b5c62..fe34f34b 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -532,6 +532,12 @@ conn_describe_prepared(connObject *self, PyObject *args) query_obj->max_row = PQntuples(result); query_obj->num_fields = PQnfields(result); query_obj->col_types = get_col_types(result, query_obj->num_fields); + if (!query_obj->col_types) { + PQclear(result); + Py_DECREF(query_obj); + Py_DECREF(self); + return NULL; + } return (PyObject *)query_obj; } set_error(ProgrammingError, "Cannot describe prepared statement", diff --git a/ext/pgquery.c b/ext/pgquery.c index b87eba18..2eb148bd 100644 --- a/ext/pgquery.c +++ b/ext/pgquery.c @@ -179,11 +179,8 @@ _get_async_result(queryObject *self, int keep) self->max_row = PQntuples(self->result); self->num_fields = PQnfields(self->result); self->col_types = get_col_types(self->result, self->num_fields); - if (!self->col_types) { - Py_DECREF(self); - Py_DECREF(self); + if (!self->col_types) return NULL; - } } else if (self->async == 2 && !self->max_row && !self->num_fields && !self->col_types) { From 40df5a3c1e22f2649afc4c10228e83f35bd34cd4 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 3 Jan 2026 17:41:30 +0000 Subject: [PATCH 106/118] inserttable: optimize memory allocation --- ext/pgconn.c | 41 ++++++++++++++++----------------- ext/pginternal.c | 59 ++++++++++++++++++++++++++++++++---------------- ext/pgmodule.c | 8 +++---- 3 files changed, 63 insertions(+), 45 deletions(-) diff --git a/ext/pgconn.c b/ext/pgconn.c index fe34f34b..3a03a83e 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -783,17 +783,15 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) encoding = PQclientEncoding(self->cnx); - /* starts query */ - ext_char_buffer_s(&buffer, "copy "); - - /* return early if there is no buffer */ - if (buffer.error) { - if (buffer.data) - PyMem_Free(buffer.data); + /* pre-allocate some memory for the query buffer */ + if (!init_char_buffer(&buffer, 4096)) { Py_DECREF(iter_row); return PyErr_NoMemory(); } + /* starts query */ + ext_char_buffer_s(&buffer, "copy "); + s = table; do { t = strchr(s, '.'); @@ -838,13 +836,14 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) col = PQescapeIdentifier(self->cnx, col, (size_t)slen); Py_DECREF(obj); ext_char_buffer_s(&buffer, col); - ext_char_buffer_c(&buffer, j == n - 1 ? ')' : ','); PQfreemem(col); + ext_char_buffer_c(&buffer, j == n - 1 ? ')' : ','); } } ext_char_buffer_s(&buffer, " from stdin"); if (freeze) ext_char_buffer_s(&buffer, " freeze"); + ext_char_buffer_c(&buffer, '\0'); if (buffer.error) { PyMem_Free(buffer.data); @@ -874,7 +873,6 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) PQputCopyEnd(self->cnx, "Invalid arguments"); PyMem_Free(buffer.data); Py_DECREF(columns); - Py_DECREF(columns); Py_DECREF(iter_row); PyErr_SetString( PyExc_TypeError, @@ -897,10 +895,10 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) return NULL; } - /* reset buffer to empty */ - buffer.len = 0; + /* empty buffer while keeping allocated memory */ + buffer.size = 0; - /* builds insert line */ + /* build insert line */ for (j = 0; j < n; ++j) { if (j) @@ -1015,22 +1013,21 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) } Py_DECREF(s); } - - if (buffer.error) { - PQputCopyEnd(self->cnx, "Memory error"); - PyMem_Free(buffer.data); - Py_DECREF(columns); - Py_DECREF(iter_row); - return PyErr_NoMemory(); - } } Py_DECREF(columns); + /* terminate line */ ext_char_buffer_c(&buffer, '\n'); + if (buffer.error) { + PQputCopyEnd(self->cnx, "Memory error"); + PyMem_Free(buffer.data); + Py_DECREF(iter_row); + return PyErr_NoMemory(); + } - /* sends data */ - ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.len); + /* send data */ + ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.size); if (ret != 1) { char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) : "Data cannot be queued"; diff --git a/ext/pginternal.c b/ext/pginternal.c index b02a4307..bb0c5561 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -1494,22 +1494,42 @@ notice_receiver(void *arg, const PGresult *res) PyGILState_Release(gstate); } -/* Extend char buffer with given string */ +/* Pre-allocate some memory for a char buffer and return success status. */ +static int +init_char_buffer(struct CharBuffer *buf, size_t initial_size) +{ + buf->size = 0; + buf->data = PyMem_Malloc(initial_size); + if (buf->data) { + buf->max_size = initial_size; + buf->error = 0; + } + else { + buf->max_size = 0; + buf->error = 1; + } + return !buf->error; +} + +/* Extend char buffer with given string. + Note: We do not assume or guarantee that the buffer is zero-terminated. */ static void ext_char_buffer_s(struct CharBuffer *buf, const char *s) { - size_t len = strlen(s); - size_t need = buf->len + len + 1; + size_t len = strlen(s), need; if (!len || buf->error) return; - if (need >= buf->max_len) { + if ((need = buf->size + len) >= buf->max_size) { void *tmp; - // Allocate powers of two unless it's large - if (2 * buf->max_len >= need && buf->max_len < 1024 * 1024) - need = 2 * buf->max_len; + if (buf->max_size < 1024 * 1024) { + /* allocate powers of two unless it's large */ + size_t double_size = 2 * buf->max_size; + if (double_size >= need) /* overflow check */ + need = double_size; + } tmp = PyMem_Realloc(buf->data, need); if (!tmp) { @@ -1518,26 +1538,27 @@ ext_char_buffer_s(struct CharBuffer *buf, const char *s) } buf->data = tmp; - buf->max_len = need; + buf->max_size = need; } - memcpy(buf->data + buf->len, s, len + 1); - buf->len += len; + memcpy(buf->data + buf->size, s, len); + buf->size += len; } /* Extend char buffer with given character */ static void ext_char_buffer_c(struct CharBuffer *buf, char c) { - if (buf->len > buf->max_len - 2) { - // slow path dealing with reallocation - char tmp[2] = {c, '\0'}; + if (buf->error) + return; + + if (buf->size >= buf->max_size) { /* buffer is full? */ + /* slow path dealing with reallocation */ + char tmp[2] = {c ? c : '\n', '\0'}; /* allow adding a zero-byte */ ext_char_buffer_s(buf, tmp); + if (!c) + buf->data[buf->size - 1] = '\0'; /* fix zero-byte */ } - else { - if (buf->error) - return; - buf->data[buf->len++] = c; - buf->data[buf->len] = '\0'; - } + else + buf->data[buf->size++] = c; } diff --git a/ext/pgmodule.c b/ext/pgmodule.c index b5949a43..cbd3c66c 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -163,10 +163,10 @@ typedef struct { The buffer can be extended with the extend_char_buffer_s/x() functions. */ struct CharBuffer { - char *data; /* actual string data */ - size_t len; /* strlen() of data */ - size_t max_len; /* allocated size */ - int error; /* error flag (invalid string)*/ + char *data; /* actual string data */ + size_t size; /* current size of data */ + size_t max_size; /* allocated size */ + int error; /* error flag (invalid data) */ }; /* Internal functions */ From 3b25216f7bc88989058f72141e9ae746a3799bcf Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sat, 3 Jan 2026 18:06:51 +0000 Subject: [PATCH 107/118] Bump version and year of copyright --- LICENSE.txt | 2 +- docs/about.rst | 2 +- docs/conf.py | 2 +- docs/contents/changelog.rst | 6 ++++++ docs/copyright.rst | 2 +- ext/pgconn.c | 2 +- ext/pginternal.c | 2 +- ext/pglarge.c | 2 +- ext/pgmodule.c | 2 +- ext/pgnotice.c | 2 +- ext/pgquery.c | 2 +- ext/pgsource.c | 2 +- pg/__init__.py | 2 +- pgdb/__init__.py | 2 +- pyproject.toml | 2 +- 15 files changed, 20 insertions(+), 14 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index e905706e..fcdfd8d3 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2025 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2026 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/docs/about.rst b/docs/about.rst index 96284f05..783d759c 100644 --- a/docs/about.rst +++ b/docs/about.rst @@ -8,7 +8,7 @@ powerful PostgreSQL features from Python. | This software is copyright © 1995, Pascal Andre. | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain. - | Further modifications are copyright © 2009-2025 by the PyGreSQL team. + | Further modifications are copyright © 2009-2026 by the PyGreSQL team. | For licensing details, see the full :doc:`copyright`. **PostgreSQL** is a highly scalable, SQL compliant, open source diff --git a/docs/conf.py b/docs/conf.py index f25d78e7..36d65aab 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,7 +8,7 @@ project = 'PyGreSQL' author = 'The PyGreSQL team' -copyright = '2025, ' + author +copyright = '2026, ' + author def project_version(): with open('../pyproject.toml') as f: diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index df7ee495..76747403 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,12 @@ ChangeLog ========= +Version 6.2.2 (2026-01-03) +-------------------------- +- The `inserttable()` method in the `pg` module can now handle rows of + arbitrary size (this was previously limited to 64 KB) (#79, thanks + to Justin Pryzby for this contribution). + Version 6.2.1 (2025-12-27) -------------------------- - Made setup cleaner and work with a wider range of setuptools versions (#90). diff --git a/docs/copyright.rst b/docs/copyright.rst index bf7d9b04..300bd716 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2025 by the PyGreSQL team. +Further modifications copyright (c) 2009-2026 by the PyGreSQL team. Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/ext/pgconn.c b/ext/pgconn.c index 3a03a83e..c6569321 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -3,7 +3,7 @@ * * The connection object - this file is part a of the C extension module. * - * Copyright (c) 2025 by the PyGreSQL Development Team + * Copyright (c) 2026 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pginternal.c b/ext/pginternal.c index bb0c5561..20ab26f2 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2025 by the PyGreSQL Development Team + * Copyright (c) 2026 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pglarge.c b/ext/pglarge.c index 1b817b25..bac84d96 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2025 by the PyGreSQL Development Team + * Copyright (c) 2026 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgmodule.c b/ext/pgmodule.c index cbd3c66c..c66637cc 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2025 by the PyGreSQL Development Team + * Copyright (c) 2026 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c index c56b249f..0d64d6cc 100644 --- a/ext/pgnotice.c +++ b/ext/pgnotice.c @@ -3,7 +3,7 @@ * * The notice object - this file is part a of the C extension module. * - * Copyright (c) 2025 by the PyGreSQL Development Team + * Copyright (c) 2026 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgquery.c b/ext/pgquery.c index 2eb148bd..84a69fdd 100644 --- a/ext/pgquery.c +++ b/ext/pgquery.c @@ -3,7 +3,7 @@ * * The query object - this file is part a of the C extension module. * - * Copyright (c) 2025 by the PyGreSQL Development Team + * Copyright (c) 2026 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/ext/pgsource.c b/ext/pgsource.c index bbec2f86..972d7e76 100644 --- a/ext/pgsource.c +++ b/ext/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2025 by the PyGreSQL Development Team + * Copyright (c) 2026 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ diff --git a/pg/__init__.py b/pg/__init__.py index c3b7f4e9..da95e505 100644 --- a/pg/__init__.py +++ b/pg/__init__.py @@ -4,7 +4,7 @@ # # This file contains the classic pg module. # -# Copyright (c) 2025 by the PyGreSQL Development Team +# Copyright (c) 2026 by the PyGreSQL Development Team # # The notification handler is based on pgnotify which is # Copyright (c) 2001 Ng Pheng Siong. All rights reserved. diff --git a/pgdb/__init__.py b/pgdb/__init__.py index 132ce292..9fb6158a 100644 --- a/pgdb/__init__.py +++ b/pgdb/__init__.py @@ -4,7 +4,7 @@ # # This file contains the DB-API 2 compatible pgdb module. # -# Copyright (c) 2025 by the PyGreSQL Development Team +# Copyright (c) 2026 by the PyGreSQL Development Team # # Please see the LICENSE.TXT file for specific restrictions. diff --git a/pyproject.toml b/pyproject.toml index 669bafd7..2777003f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.2.1" +version = "6.2.2" requires-python = ">=3.8" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, From dbd3ec81f5e499642a5cfe6d2fe5ff5a346ba466 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 9 Jan 2026 17:32:23 +0000 Subject: [PATCH 108/118] Modularize C extension into separate units (#25) --- ext/pgconn.c | 10 ++- ext/pginternal.c | 47 +++++++------- ext/pginternal.h | 91 +++++++++++++++++++++++++++ ext/pglarge.c | 6 +- ext/pgmodule.c | 157 ++++++++++++++--------------------------------- ext/pgmodule.h | 145 +++++++++++++++++++++++++++++++++++++++++++ ext/pgnotice.c | 6 +- ext/pgquery.c | 6 +- ext/pgsource.c | 6 +- setup.py | 10 ++- 10 files changed, 343 insertions(+), 141 deletions(-) create mode 100644 ext/pginternal.h create mode 100644 ext/pgmodule.h diff --git a/ext/pgconn.c b/ext/pgconn.c index c6569321..31767a3d 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -8,6 +8,10 @@ * Please see the LICENSE.TXT file for specific restrictions. */ +/* Shared headers */ +#include "pginternal.h" +#include "pgmodule.h" + /* Deallocate connection object. */ static void conn_dealloc(connObject *self) @@ -113,7 +117,7 @@ conn_getattr(connObject *self, PyObject *nameobj) } /* Check connection validity. */ -static int +int _check_cnx_obj(connObject *self) { if (!self || !self->valid || !self->cnx) { @@ -154,7 +158,7 @@ conn_source(connObject *self, PyObject *noargs) /* For a non-query result, set the appropriate error status, return the appropriate value, and free the result set. */ -static PyObject * +PyObject * _conn_non_query_result(int status, PGresult *result, PGconn *cnx) { switch (status) { @@ -1775,7 +1779,7 @@ static struct PyMethodDef conn_methods[] = { static char conn__doc__[] = "PostgreSQL connection object"; /* Connection type definition */ -static PyTypeObject connType = { +PyTypeObject connType = { PyVarObject_HEAD_INIT(NULL, 0) "pg.Connection", /* tp_name */ sizeof(connObject), /* tp_basicsize */ 0, /* tp_itemsize */ diff --git a/ext/pginternal.c b/ext/pginternal.c index 20ab26f2..ad8bfffc 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -8,6 +8,9 @@ * Please see the LICENSE.TXT file for specific restrictions. */ +/* Shared headers */ +#include "pginternal.h" + /* PyGreSQL internal types */ /* Simple types */ @@ -27,7 +30,7 @@ /* Shared functions for encoding and decoding strings */ -static PyObject * +PyObject * get_decoded_string(const char *str, Py_ssize_t size, int encoding) { if (encoding == pg_encoding_utf8) @@ -41,7 +44,7 @@ get_decoded_string(const char *str, Py_ssize_t size, int encoding) "strict"); } -static PyObject * +PyObject * get_encoded_string(PyObject *unicode_obj, int encoding) { if (encoding == pg_encoding_utf8) @@ -176,7 +179,7 @@ get_type(Oid pgtype) } /* Get PyGreSQL column types for all result columns. */ -static int * +int * get_col_types(PGresult *result, int nfields) { int *types, *t, j; @@ -194,7 +197,7 @@ get_col_types(PGresult *result, int nfields) /* Cast a bytea encoded text based type to a Python object. This assumes the text is null-terminated character string. */ -static PyObject * +PyObject * cast_bytea_text(char *s) { PyObject *obj; @@ -212,7 +215,7 @@ cast_bytea_text(char *s) /* Cast a text based type to a Python object. This needs the character string, size and encoding. */ -static PyObject * +PyObject * cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) { PyObject *obj, *tmp_obj; @@ -264,7 +267,7 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) /* Cast an arbitrary type to a Python object using a callback function. This needs the character string, size, encoding, the Postgres type and the external typecast function to be called. */ -static PyObject * +PyObject * cast_other(char *s, Py_ssize_t size, int encoding, Oid pgtype, PyObject *cast_hook) { @@ -282,7 +285,7 @@ cast_other(char *s, Py_ssize_t size, int encoding, Oid pgtype, /* Cast a simple type to a Python object. This needs a character string representation with a given size. */ -static PyObject * +PyObject * cast_sized_simple(char *s, Py_ssize_t size, int type) { PyObject *obj, *tmp_obj; @@ -375,7 +378,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) /* Cast a simple type to a Python object. This needs a null-terminated character string representation. */ -static PyObject * +PyObject * cast_unsized_simple(char *s, int type) { PyObject *obj, *tmp_obj; @@ -454,7 +457,7 @@ cast_unsized_simple(char *s, int type) Use internal type or cast function to cast elements. The parameter delim specifies the delimiter for the elements, since some types do not use the default delimiter of a comma. */ -static PyObject * +PyObject * cast_array(char *s, Py_ssize_t size, int encoding, int type, PyObject *cast, char delim) { @@ -721,7 +724,7 @@ cast_array(char *s, Py_ssize_t size, int encoding, int type, PyObject *cast, functions to cast elements. The parameter len is the record size. The parameter delim can specify a delimiter for the elements, although composite types always use a comma as delimiter. */ -static PyObject * +PyObject * cast_record(char *s, Py_ssize_t size, int encoding, int *type, PyObject *cast, Py_ssize_t len, char delim) { @@ -905,7 +908,7 @@ cast_record(char *s, Py_ssize_t size, int encoding, int *type, PyObject *cast, /* Cast string s with size and encoding to a Python dictionary. using the input and output syntax for hstore values. */ -static PyObject * +PyObject * cast_hstore(char *s, Py_ssize_t size, int encoding) { PyObject *result; @@ -1198,14 +1201,14 @@ set_error_msg_and_state(PyObject *type, const char *msg, int encoding, } /* Set given database error message. */ -static void +void set_error_msg(PyObject *type, const char *msg) { set_error_msg_and_state(type, msg, pg_encoding_ascii, NULL); } /* Set database error from connection and/or result. */ -static void +void set_error(PyObject *type, const char *msg, PGconn *cnx, PGresult *result) { char *sqlstate = NULL; @@ -1228,7 +1231,7 @@ set_error(PyObject *type, const char *msg, PGconn *cnx, PGresult *result) } /* Get SSL attributes and values as a dictionary. */ -static PyObject * +PyObject * get_ssl_attributes(PGconn *cnx) { PyObject *attr_dict = NULL; @@ -1260,7 +1263,7 @@ get_ssl_attributes(PGconn *cnx) PQprint() is not used because handing over a stream from Python to PostgreSQL can be problematic if they use different libs for streams and because using PQprint() and tp_print is not recommended any more. */ -static PyObject * +PyObject * format_result(const PGresult *res) { const int n = PQnfields(res); @@ -1402,7 +1405,7 @@ format_result(const PGresult *res) } /* Internal function converting a Postgres datestyles to date formats. */ -static const char * +const char * date_style_to_format(const char *s) { static const char *formats[] = { @@ -1435,7 +1438,7 @@ date_style_to_format(const char *s) } /* Internal function converting a date format to a Postgres datestyle. */ -static const char * +const char * date_format_to_style(const char *s) { static const char *datestyle[] = { @@ -1470,7 +1473,7 @@ date_format_to_style(const char *s) } /* Internal wrapper for the notice receiver callback. */ -static void +void notice_receiver(void *arg, const PGresult *res) { PyGILState_STATE gstate = PyGILState_Ensure(); @@ -1495,7 +1498,7 @@ notice_receiver(void *arg, const PGresult *res) } /* Pre-allocate some memory for a char buffer and return success status. */ -static int +int init_char_buffer(struct CharBuffer *buf, size_t initial_size) { buf->size = 0; @@ -1512,8 +1515,8 @@ init_char_buffer(struct CharBuffer *buf, size_t initial_size) } /* Extend char buffer with given string. - Note: We do not assume or guarantee that the buffer is zero-terminated. */ -static void + Note: We do not assume or guarantee that the buffer is zero-terminated. */ +void ext_char_buffer_s(struct CharBuffer *buf, const char *s) { size_t len = strlen(s), need; @@ -1546,7 +1549,7 @@ ext_char_buffer_s(struct CharBuffer *buf, const char *s) } /* Extend char buffer with given character */ -static void +void ext_char_buffer_c(struct CharBuffer *buf, char c) { if (buf->error) diff --git a/ext/pginternal.h b/ext/pginternal.h new file mode 100644 index 00000000..0567e31d --- /dev/null +++ b/ext/pginternal.h @@ -0,0 +1,91 @@ +/* + * Internal functions header for the PyGreSQL C extension. + * Provides prototypes for helpers implemented in pginternal.c + * and extern declarations for module globals used therein. + */ + +#ifndef PYGRE_SQL_PGINTERNAL_H +#define PYGRE_SQL_PGINTERNAL_H + +#define PY_SSIZE_T_CLEAN +#include +#include + +#include "pgmodule.h" + +/* Encoding helpers */ +/* PyGreSQL internal types */ +#define PYGRES_INT 1 +#define PYGRES_LONG 2 +#define PYGRES_FLOAT 3 +#define PYGRES_DECIMAL 4 +#define PYGRES_MONEY 5 +#define PYGRES_BOOL 6 +/* Text based types */ +#define PYGRES_TEXT 8 +#define PYGRES_BYTEA 9 +#define PYGRES_JSON 10 +#define PYGRES_OTHER 11 +/* Array types */ +#define PYGRES_ARRAY 16 +PyObject * +get_decoded_string(const char *str, Py_ssize_t size, int encoding); +PyObject * +get_encoded_string(PyObject *unicode_obj, int encoding); + +/* Result/Type helpers */ +int * +get_col_types(PGresult *result, int nfields); +PyObject * +format_result(const PGresult *res); + +/* Casting helpers */ +PyObject * +cast_bytea_text(char *s); +PyObject * +cast_sized_text(char *s, Py_ssize_t size, int encoding, int type); +PyObject * +cast_other(char *s, Py_ssize_t size, int encoding, Oid pgtype, + PyObject *cast_hook); +PyObject * +cast_sized_simple(char *s, Py_ssize_t size, int type); +PyObject * +cast_unsized_simple(char *s, int type); +PyObject * +cast_array(char *s, Py_ssize_t size, int encoding, int type, PyObject *cast, + char delim); +PyObject * +cast_record(char *s, Py_ssize_t size, int encoding, int *type, PyObject *cast, + Py_ssize_t len, char delim); +PyObject * +cast_hstore(char *s, Py_ssize_t size, int encoding); + +/* Error helpers */ +void +set_error_msg(PyObject *type, const char *msg); +void +set_error(PyObject *type, const char *msg, PGconn *cnx, PGresult *result); + +/* SSL attributes helper */ +PyObject * +get_ssl_attributes(PGconn *cnx); + +/* Date style helpers */ +const char * +date_style_to_format(const char *s); +const char * +date_format_to_style(const char *s); + +/* Notice receiver */ +void +notice_receiver(void *arg, const PGresult *res); + +/* Char buffer helpers */ +int +init_char_buffer(struct CharBuffer *buf, size_t initial_size); +void +ext_char_buffer_s(struct CharBuffer *buf, const char *s); +void +ext_char_buffer_c(struct CharBuffer *buf, char c); + +#endif /* PYGRE_SQL_PGINTERNAL_H */ diff --git a/ext/pglarge.c b/ext/pglarge.c index bac84d96..46740f3b 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -8,6 +8,10 @@ * Please see the LICENSE.TXT file for specific restrictions. */ +/* Shared headers */ +#include "pginternal.h" +#include "pgmodule.h" + /* Deallocate large object. */ static void large_dealloc(largeObject *self) @@ -423,7 +427,7 @@ static struct PyMethodDef large_methods[] = { static char large__doc__[] = "PostgreSQL large object"; /* Large object type definition */ -static PyTypeObject largeType = { +PyTypeObject largeType = { PyVarObject_HEAD_INIT(NULL, 0) "pg.LargeObject", /* tp_name */ sizeof(largeObject), /* tp_basicsize */ 0, /* tp_itemsize */ diff --git a/ext/pgmodule.c b/ext/pgmodule.c index c66637cc..6b04787e 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -15,12 +15,13 @@ #include #include -/* The type definitions from */ -#include "pgtypes.h" +/* Shared headers */ +#include "pginternal.h" +#include "pgmodule.h" -static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, - *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, - *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, +PyObject *Error, *Warning, *InterfaceError, *DatabaseError, *InternalError, + *OperationalError, *ProgrammingError, *IntegrityError, *DataError, + *NotSupportedError, *InvalidResultError, *NoResultError, *MultipleResultsError, *Connection, *Query, *LargeObject; #define _TOSTRING(x) #x @@ -32,9 +33,15 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #endif /* Default values */ +#undef PG_ARRAYSIZE #define PG_ARRAYSIZE 1 /* Flags for object validity checks */ +#undef CHECK_OPEN +#undef CHECK_CLOSE +#undef CHECK_CNX +#undef CHECK_RESULT +#undef CHECK_DQL #define CHECK_OPEN 1 #define CHECK_CLOSE 2 #define CHECK_CNX 4 @@ -42,44 +49,52 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #define CHECK_DQL 16 /* Query result types */ +#undef RESULT_EMPTY +#undef RESULT_DML +#undef RESULT_DDL +#undef RESULT_DQL #define RESULT_EMPTY 1 #define RESULT_DML 2 #define RESULT_DDL 3 #define RESULT_DQL 4 /* Flags for move methods */ +#undef QUERY_MOVEFIRST +#undef QUERY_MOVELAST +#undef QUERY_MOVENEXT +#undef QUERY_MOVEPREV #define QUERY_MOVEFIRST 1 #define QUERY_MOVELAST 2 #define QUERY_MOVENEXT 3 #define QUERY_MOVEPREV 4 +#undef MAX_ARRAY_DEPTH #define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ /* MODULE GLOBAL VARIABLES */ -static PyObject *pg_default_host; /* default database host */ -static PyObject *pg_default_base; /* default database name */ -static PyObject *pg_default_opt; /* default connection options */ -static PyObject *pg_default_port; /* default connection port */ -static PyObject *pg_default_user; /* default username */ -static PyObject *pg_default_passwd; /* default password */ - -static PyObject *decimal = NULL, /* decimal type */ - *dictiter = NULL, /* function for getting dict results */ - *namediter = NULL, /* function for getting named results */ - *namednext = NULL, /* function for getting one named result */ - *scalariter = NULL, /* function for getting scalar results */ - *jsondecode = - NULL; /* function for decoding json strings */ -static const char *date_format = NULL; /* date format that is always assumed */ -static char decimal_point = '.'; /* decimal point used in money values */ -static int bool_as_text = 0; /* whether bool shall be returned as text */ -static int array_as_text = 0; /* whether arrays shall be returned as text */ -static int bytea_escaped = 0; /* whether bytea shall be returned escaped */ - -static int pg_encoding_utf8 = 0; -static int pg_encoding_latin1 = 0; -static int pg_encoding_ascii = 0; +PyObject *pg_default_host; /* default database host */ +PyObject *pg_default_base; /* default database name */ +PyObject *pg_default_opt; /* default connection options */ +PyObject *pg_default_port; /* default connection port */ +PyObject *pg_default_user; /* default username */ +PyObject *pg_default_passwd; /* default password */ + +PyObject *decimal = NULL; /* decimal type */ +PyObject *dictiter = NULL; /* function for getting dict results */ +PyObject *namediter = NULL; /* function for getting named results */ +PyObject *namednext = NULL; /* function for getting one named result */ +PyObject *scalariter = NULL; /* function for getting scalar results */ +PyObject *jsondecode = NULL; /* function for decoding json strings */ +const char *date_format = NULL; /* date format that is always assumed */ +char decimal_point = '.'; /* decimal point used in money values */ +int bool_as_text = 0; /* whether bool shall be returned as text */ +int array_as_text = 0; /* whether arrays shall be returned as text */ +int bytea_escaped = 0; /* whether bytea shall be returned escaped */ + +int pg_encoding_utf8 = 0; +int pg_encoding_latin1 = 0; +int pg_encoding_ascii = 0; /* OBJECTS @@ -101,91 +116,11 @@ OBJECTS - source: Source object returned by pg.conn.source(). */ -/* Forward declarations for types */ -static PyTypeObject connType, sourceType, queryType, noticeType, largeType; - -/* Forward static declarations */ -static void -notice_receiver(void *, const PGresult *); - -/* Object declarations */ - -typedef struct { - PyObject_HEAD int valid; /* validity flag */ - PGconn *cnx; /* Postgres connection handle */ - const char *date_format; /* date format derived from datestyle */ - PyObject *cast_hook; /* external typecast method */ - PyObject *notice_receiver; /* current notice receiver */ -} connObject; -#define is_connObject(v) (PyType(v) == &connType) - -typedef struct { - PyObject_HEAD int valid; /* validity flag */ - connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int encoding; /* client encoding */ - int result_type; /* result type (DDL/DML/DQL) */ - long arraysize; /* array size for fetch method */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ -} sourceObject; -#define is_sourceObject(v) (PyType(v) == &sourceType) - -typedef struct { - PyObject_HEAD connObject *pgcnx; /* parent connection object */ - PGresult const *res; /* an error or warning */ -} noticeObject; -#define is_noticeObject(v) (PyType(v) == ¬iceType) - -typedef struct { - PyObject_HEAD connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int async; /* flag for asynchronous queries */ - int encoding; /* client encoding */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ - int *col_types; /* PyGreSQL column types */ -} queryObject; -#define is_queryObject(v) (PyType(v) == &queryType) - -typedef struct { - PyObject_HEAD connObject *pgcnx; /* parent connection object */ - Oid lo_oid; /* large object oid */ - int lo_fd; /* large object fd */ -} largeObject; -#define is_largeObject(v) (PyType(v) == &largeType) +/* Type objects are defined in their respective .c files */ -/* - A buffer for character data with routines to handle resizing. - This is inspired by libpq's PQExpBufferData. - The buffer can be extended with the extend_char_buffer_s/x() functions. -*/ -struct CharBuffer { - char *data; /* actual string data */ - size_t size; /* current size of data */ - size_t max_size; /* allocated size */ - int error; /* error flag (invalid data) */ -}; - -/* Internal functions */ -#include "pginternal.c" - -/* Connection object */ -#include "pgconn.c" - -/* Query object */ -#include "pgquery.c" - -/* Source object */ -#include "pgsource.c" - -/* Notice object */ -#include "pgnotice.c" +/* Object declarations are provided by shared headers */ -/* Large objects */ -#include "pglarge.c" +/* Object implementation files are compiled separately now */ /* MODULE FUNCTIONS */ diff --git a/ext/pgmodule.h b/ext/pgmodule.h new file mode 100644 index 00000000..c2ce06d9 --- /dev/null +++ b/ext/pgmodule.h @@ -0,0 +1,145 @@ +/* + * Shared header for the PyGreSQL C extension module. + * Declares common types, macros, and extern symbols used across files. + */ + +#ifndef PYGRE_SQL_PGMODULE_H +#define PYGRE_SQL_PGMODULE_H + +#define PY_SSIZE_T_CLEAN +#include +#include +#include + +#include "pgtypes.h" + +/* Default values */ +#define PG_ARRAYSIZE 1 + +/* Flags for object validity checks */ +#define CHECK_OPEN 1 +#define CHECK_CLOSE 2 +#define CHECK_CNX 4 +#define CHECK_RESULT 8 +#define CHECK_DQL 16 + +/* Query result types */ +#define RESULT_EMPTY 1 +#define RESULT_DML 2 +#define RESULT_DDL 3 +#define RESULT_DQL 4 + +/* Flags for move methods */ +#define QUERY_MOVEFIRST 1 +#define QUERY_MOVELAST 2 +#define QUERY_MOVENEXT 3 +#define QUERY_MOVEPREV 4 + +#define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ + +/* Character buffer used by COPY and formatting helpers */ +struct CharBuffer { + char *data; /* actual string data */ + size_t size; /* current size of data */ + size_t max_size; /* allocated size */ + int error; /* error flag (invalid data) */ +}; + +/* Forward declarations for type objects (defined in their respective .c files) + */ +extern PyTypeObject connType; +extern PyTypeObject sourceType; +extern PyTypeObject noticeType; +extern PyTypeObject queryType; +extern PyTypeObject largeType; + +/* Exception types (created in pgmodule.c) */ +extern PyObject *Error, *Warning, *InterfaceError, *DatabaseError, + *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, + *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, + *MultipleResultsError; + +/* Module global configuration/state (defined in pgmodule.c) */ +extern PyObject *pg_default_host; /* default database host */ +extern PyObject *pg_default_base; /* default database name */ +extern PyObject *pg_default_opt; /* default connection options */ +extern PyObject *pg_default_port; /* default connection port */ +extern PyObject *pg_default_user; /* default username */ +extern PyObject *pg_default_passwd; /* default password */ + +extern PyObject *decimal; /* decimal type */ +extern PyObject *dictiter; /* function for getting dict results */ +extern PyObject *namediter; /* function for getting named results */ +extern PyObject *namednext; /* function for getting one named result */ +extern PyObject *scalariter; /* function for getting scalar results */ +extern PyObject *jsondecode; /* function for decoding json strings */ +extern const char *date_format; /* date format that is always assumed */ +extern char decimal_point; /* decimal point used in money values */ +extern int bool_as_text; /* whether bool shall be returned as text */ +extern int array_as_text; /* whether arrays shall be returned as text */ +extern int bytea_escaped; /* whether bytea shall be returned escaped */ + +extern int pg_encoding_utf8; +extern int pg_encoding_latin1; +extern int pg_encoding_ascii; + +/* Object declarations */ +typedef struct { + PyObject_HEAD int valid; /* validity flag */ + PGconn *cnx; /* Postgres connection handle */ + const char *date_format; /* date format derived from datestyle */ + PyObject *cast_hook; /* external typecast method */ + PyObject *notice_receiver; /* current notice receiver */ +} connObject; + +#define is_connObject(v) (PyType(v) == &connType) + +typedef struct { + PyObject_HEAD int valid; /* validity flag */ + connObject *pgcnx; /* parent connection object */ + PGresult *result; /* result content */ + int encoding; /* client encoding */ + int result_type; /* result type (DDL/DML/DQL) */ + long arraysize; /* array size for fetch method */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ +} sourceObject; + +#define is_sourceObject(v) (PyType(v) == &sourceType) + +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult const *res; /* an error or warning */ +} noticeObject; + +#define is_noticeObject(v) (PyType(v) == ¬iceType) + +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult *result; /* result content */ + int async; /* flag for asynchronous queries */ + int encoding; /* client encoding */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ + int *col_types; /* PyGreSQL column types */ +} queryObject; + +#define is_queryObject(v) (PyType(v) == &queryType) + +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + Oid lo_oid; /* large object oid */ + int lo_fd; /* large object fd */ +} largeObject; + +#define is_largeObject(v) (PyType(v) == &largeType) + +/* Cross-object helpers exported by pgconn.c */ +int +_check_cnx_obj(connObject *self); +PyObject * +_conn_non_query_result(int status, PGresult *result, PGconn *cnx); + +#endif /* PYGRE_SQL_PGMODULE_H */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c index 0d64d6cc..9922e628 100644 --- a/ext/pgnotice.c +++ b/ext/pgnotice.c @@ -8,6 +8,10 @@ * Please see the LICENSE.TXT file for specific restrictions. */ +/* Shared headers */ +#include "pginternal.h" +#include "pgmodule.h" + /* Get notice object attributes. */ static PyObject * notice_getattr(noticeObject *self, PyObject *nameobj) @@ -89,7 +93,7 @@ static struct PyMethodDef notice_methods[] = { static char notice__doc__[] = "PostgreSQL notice object"; /* Notice type definition */ -static PyTypeObject noticeType = { +PyTypeObject noticeType = { PyVarObject_HEAD_INIT(NULL, 0) "pg.Notice", /* tp_name */ sizeof(noticeObject), /* tp_basicsize */ 0, /* tp_itemsize */ diff --git a/ext/pgquery.c b/ext/pgquery.c index 84a69fdd..bf921a44 100644 --- a/ext/pgquery.c +++ b/ext/pgquery.c @@ -8,6 +8,10 @@ * Please see the LICENSE.TXT file for specific restrictions. */ +/* Shared headers */ +#include "pginternal.h" +#include "pgmodule.h" + /* Deallocate the query object. */ static void query_dealloc(queryObject *self) @@ -969,7 +973,7 @@ static struct PyMethodDef query_methods[] = { static char query__doc__[] = "PyGreSQL query object"; /* Query type definition */ -static PyTypeObject queryType = { +PyTypeObject queryType = { PyVarObject_HEAD_INIT(NULL, 0) "pg.Query", /* tp_name */ sizeof(queryObject), /* tp_basicsize */ 0, /* tp_itemsize */ diff --git a/ext/pgsource.c b/ext/pgsource.c index 972d7e76..b8610eb3 100644 --- a/ext/pgsource.c +++ b/ext/pgsource.c @@ -8,6 +8,10 @@ * Please see the LICENSE.TXT file for specific restrictions. */ +/* Shared headers */ +#include "pginternal.h" +#include "pgmodule.h" + /* Deallocate source object. */ static void source_dealloc(sourceObject *self) @@ -796,7 +800,7 @@ static PyMethodDef source_methods[] = { static char source__doc__[] = "PyGreSQL source object"; /* Source type definition */ -static PyTypeObject sourceType = { +PyTypeObject sourceType = { PyVarObject_HEAD_INIT(NULL, 0) "pgdb.Source", /* tp_name */ sizeof(sourceObject), /* tp_basicsize */ 0, /* tp_itemsize */ diff --git a/setup.py b/setup.py index 288cc303..eaf04862 100755 --- a/setup.py +++ b/setup.py @@ -224,7 +224,15 @@ def build_extensions(self): packages=["pg", "pgdb"], package_data={"pg": ["py.typed"], "pgdb": ["py.typed"]}, ext_modules=[Extension( - 'pg._pg', ["ext/pgmodule.c"], + 'pg._pg', [ + "ext/pgmodule.c", + "ext/pginternal.c", + "ext/pgconn.c", + "ext/pgquery.c", + "ext/pgsource.c", + "ext/pgnotice.c", + "ext/pglarge.c", + ], include_dirs=include_dirs, library_dirs=library_dirs, define_macros=define_macros, undef_macros=undef_macros, libraries=libraries, extra_compile_args=extra_compile_args)], From e439b2fe533a887611202927076b70a48cc3d32a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 9 Jan 2026 17:43:51 +0000 Subject: [PATCH 109/118] Always use snprintf instead of sprintf --- ext/pginternal.c | 19 ++++++++++++------- ext/pglarge.c | 8 ++++---- ext/pgmodule.c | 2 +- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/ext/pginternal.c b/ext/pginternal.c index ad8bfffc..f8feb71f 100644 --- a/ext/pginternal.c +++ b/ext/pginternal.c @@ -1343,8 +1343,10 @@ format_result(const PGresult *res) const size_t k = sizes[j]; const size_t h = (k - (size_t)strlen(s)) / 2; - sprintf(p, "%*s", (int)h, ""); - sprintf(p + h, "%-*s", (int)(k - h), s); + snprintf(p, size - (size_t)(p - buffer), "%*s", (int)h, + ""); + snprintf(p + h, size - (size_t)(p + h - buffer), "%-*s", + (int)(k - h), s); p += k; if (j + 1 < n) *p++ = '|'; @@ -1365,12 +1367,14 @@ format_result(const PGresult *res) const size_t k = sizes[j]; if (align) { - sprintf(p, align == 'r' ? "%*s" : "%-*s", (int)k, - PQgetvalue(res, i, j)); + snprintf(p, size - (size_t)(p - buffer), + align == 'r' ? "%*s" : "%-*s", (int)k, + PQgetvalue(res, i, j)); } else { - sprintf(p, "%-*s", (int)k, - PQgetisnull(res, i, j) ? "" : ""); + snprintf(p, size - (size_t)(p - buffer), "%-*s", + (int)k, + PQgetisnull(res, i, j) ? "" : ""); } p += k; if (j + 1 < n) @@ -1382,7 +1386,8 @@ format_result(const PGresult *res) PyMem_Free(aligns); PyMem_Free(sizes); /* create the footer */ - sprintf(p, "(%d row%s)", m, m == 1 ? "" : "s"); + snprintf(p, size - (size_t)(p - buffer), "(%d row%s)", m, + m == 1 ? "" : "s"); /* return the result */ result = PyUnicode_FromString(buffer); PyMem_Free(buffer); diff --git a/ext/pglarge.c b/ext/pglarge.c index 46740f3b..026ba12f 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -32,10 +32,10 @@ static PyObject * large_str(largeObject *self) { char str[80]; - sprintf(str, - self->lo_fd >= 0 ? "Opened large object, oid %ld" - : "Closed large object, oid %ld", - (long)self->lo_oid); + snprintf(str, sizeof(str), + self->lo_fd >= 0 ? "Opened large object, oid %ld" + : "Closed large object, oid %ld", + (long)self->lo_oid); return PyUnicode_FromString(str); } diff --git a/ext/pgmodule.c b/ext/pgmodule.c index 6b04787e..c3ebf717 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -210,7 +210,7 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) } if (pgport != -1) { memset(port_buffer, 0, sizeof(port_buffer)); - sprintf(port_buffer, "%d", pgport); + snprintf(port_buffer, sizeof(port_buffer), "%d", pgport); keywords[nkw] = "port"; values[nkw++] = port_buffer; From 7805b843420729c658fe00a2b023de83ed605d76 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 9 Jan 2026 17:55:15 +0000 Subject: [PATCH 110/118] Catch error in PyUnicode_AsUTF8 --- ext/pgconn.c | 2 ++ ext/pglarge.c | 2 ++ ext/pgnotice.c | 3 +++ ext/pgsource.c | 2 ++ 4 files changed, 9 insertions(+) diff --git a/ext/pgconn.c b/ext/pgconn.c index 31767a3d..36b302c5 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -31,6 +31,8 @@ static PyObject * conn_getattr(connObject *self, PyObject *nameobj) { const char *name = PyUnicode_AsUTF8(nameobj); + if (!name) + return NULL; /* * Although we could check individually, there are only a few diff --git a/ext/pglarge.c b/ext/pglarge.c index 026ba12f..0cb1ab31 100644 --- a/ext/pglarge.c +++ b/ext/pglarge.c @@ -73,6 +73,8 @@ static PyObject * large_getattr(largeObject *self, PyObject *nameobj) { const char *name = PyUnicode_AsUTF8(nameobj); + if (!name) + return NULL; /* list postgreSQL large object fields */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c index 9922e628..a2079ef8 100644 --- a/ext/pgnotice.c +++ b/ext/pgnotice.c @@ -20,6 +20,9 @@ notice_getattr(noticeObject *self, PyObject *nameobj) const char *name = PyUnicode_AsUTF8(nameobj); int fieldcode; + if (!name) + return NULL; + if (!res) { PyErr_SetString(PyExc_TypeError, "Cannot get current notice"); return NULL; diff --git a/ext/pgsource.c b/ext/pgsource.c index b8610eb3..303b41b8 100644 --- a/ext/pgsource.c +++ b/ext/pgsource.c @@ -70,6 +70,8 @@ static PyObject * source_getattr(sourceObject *self, PyObject *nameobj) { const char *name = PyUnicode_AsUTF8(nameobj); + if (!name) + return NULL; /* pg connection object */ if (!strcmp(name, "pgcnx")) { From c3aee23a5be3432faaa8d18102477dcc06a9fc59 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 9 Jan 2026 18:51:35 +0000 Subject: [PATCH 111/118] Remove unnecessary macro section in pgmodule.c --- ext/pgmodule.c | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/ext/pgmodule.c b/ext/pgmodule.c index c3ebf717..d93bb4a6 100644 --- a/ext/pgmodule.c +++ b/ext/pgmodule.c @@ -32,45 +32,6 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #define Py_InitModule4 Py_InitModule4_64 #endif -/* Default values */ -#undef PG_ARRAYSIZE -#define PG_ARRAYSIZE 1 - -/* Flags for object validity checks */ -#undef CHECK_OPEN -#undef CHECK_CLOSE -#undef CHECK_CNX -#undef CHECK_RESULT -#undef CHECK_DQL -#define CHECK_OPEN 1 -#define CHECK_CLOSE 2 -#define CHECK_CNX 4 -#define CHECK_RESULT 8 -#define CHECK_DQL 16 - -/* Query result types */ -#undef RESULT_EMPTY -#undef RESULT_DML -#undef RESULT_DDL -#undef RESULT_DQL -#define RESULT_EMPTY 1 -#define RESULT_DML 2 -#define RESULT_DDL 3 -#define RESULT_DQL 4 - -/* Flags for move methods */ -#undef QUERY_MOVEFIRST -#undef QUERY_MOVELAST -#undef QUERY_MOVENEXT -#undef QUERY_MOVEPREV -#define QUERY_MOVEFIRST 1 -#define QUERY_MOVELAST 2 -#define QUERY_MOVENEXT 3 -#define QUERY_MOVEPREV 4 - -#undef MAX_ARRAY_DEPTH -#define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ - /* MODULE GLOBAL VARIABLES */ PyObject *pg_default_host; /* default database host */ From 15ab3ba93da5dd754f3014a7e809dfbcf7ef52d3 Mon Sep 17 00:00:00 2001 From: justinpryzby Date: Sun, 25 Jan 2026 13:26:34 -0600 Subject: [PATCH 112/118] fix memory leak and Inserttable batch (#92) --- ext/pgconn.c | 50 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/ext/pgconn.c b/ext/pgconn.c index 36b302c5..f5a325c0 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -862,6 +862,8 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) Py_END_ALLOW_THREADS if (!result || PQresultStatus(result) != PGRES_COPY_IN) { + if (result) + PQclear(result); PyMem_Free(buffer.data); Py_DECREF(iter_row); PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); @@ -870,6 +872,9 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) PQclear(result); + /* empty buffer while keeping allocated memory */ + buffer.size = 0; + /* feed table */ for (i = 0; m < 0 || i < m; ++i) { if (!(columns = PyIter_Next(iter_row))) @@ -901,9 +906,6 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) return NULL; } - /* empty buffer while keeping allocated memory */ - buffer.size = 0; - /* build insert line */ for (j = 0; j < n; ++j) { @@ -1032,22 +1034,39 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) return PyErr_NoMemory(); } - /* send data */ - ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.size); - if (ret != 1) { - char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) - : "Data cannot be queued"; - PyErr_SetString(PyExc_IOError, errormsg); - PQputCopyEnd(self->cnx, errormsg); - PyMem_Free(buffer.data); - Py_DECREF(iter_row); - return NULL; + if (buffer.size > 128 * 1024) { + /* send buffered data */ + ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.size); + buffer.size = 0; + if (ret != 1) { + char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"; + PyErr_SetString(PyExc_IOError, errormsg); + PQputCopyEnd(self->cnx, errormsg); + PyMem_Free(buffer.data); + Py_DECREF(iter_row); + return NULL; + } } } + /* flush any remaining data */ + // XXX: if buffer.size + ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.size); + if (ret != 1) { + char *errormsg = + ret == -1 ? PQerrorMessage(self->cnx) : "Data cannot be queued"; + PyErr_SetString(PyExc_IOError, errormsg); + PQputCopyEnd(self->cnx, errormsg); + PyMem_Free(buffer.data); + Py_DECREF(iter_row); + return NULL; + } + + PyMem_Free(buffer.data); + Py_DECREF(iter_row); if (PyErr_Occurred()) { - PyMem_Free(buffer.data); return NULL; /* pass the iteration error */ } @@ -1055,12 +1074,9 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) if (ret != 1) { PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) : "Data cannot be queued"); - PyMem_Free(buffer.data); return NULL; } - PyMem_Free(buffer.data); - Py_BEGIN_ALLOW_THREADS result = PQgetResult(self->cnx); Py_END_ALLOW_THREADS From 060931703c65581b7a8573ae3eda1806dc0fa0f3 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 25 Jan 2026 19:47:22 +0000 Subject: [PATCH 113/118] inserttable: flush remaining only if needed --- ext/pgconn.c | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/ext/pgconn.c b/ext/pgconn.c index f5a325c0..a55eb295 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -1035,7 +1035,7 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) } if (buffer.size > 128 * 1024) { - /* send buffered data */ + /* send buffered data after reaching 128KB */ ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.size); buffer.size = 0; if (ret != 1) { @@ -1051,16 +1051,18 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) } /* flush any remaining data */ - // XXX: if buffer.size - ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.size); - if (ret != 1) { - char *errormsg = - ret == -1 ? PQerrorMessage(self->cnx) : "Data cannot be queued"; - PyErr_SetString(PyExc_IOError, errormsg); - PQputCopyEnd(self->cnx, errormsg); - PyMem_Free(buffer.data); - Py_DECREF(iter_row); - return NULL; + if (buffer.size) { + ret = PQputCopyData(self->cnx, buffer.data, (int)buffer.size); + buffer.size = 0; + if (ret != 1) { + char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"; + PyErr_SetString(PyExc_IOError, errormsg); + PQputCopyEnd(self->cnx, errormsg); + PyMem_Free(buffer.data); + Py_DECREF(iter_row); + return NULL; + } } PyMem_Free(buffer.data); From fdb0931f0ceff62681adfa0e016bd75663c6dbc3 Mon Sep 17 00:00:00 2001 From: justinpryzby Date: Sun, 25 Jan 2026 13:52:00 -0600 Subject: [PATCH 114/118] inserttable: use str(datetime) Previously, datetime fields needed to be str()'ified by the caller --- ext/pgconn.c | 16 ++++++++++++++++ tests/test_classic_connection.py | 7 +++++++ 2 files changed, 23 insertions(+) diff --git a/ext/pgconn.c b/ext/pgconn.c index a55eb295..57c86962 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -12,6 +12,9 @@ #include "pginternal.h" #include "pgmodule.h" +/* Needs to be after Python.h */ +#include "datetime.h" + /* Deallocate connection object. */ static void conn_dealloc(connObject *self) @@ -789,6 +792,11 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) encoding = PQclientEncoding(self->cnx); + PyDateTime_IMPORT; + if (PyErr_Occurred()) { + return NULL; /* pass the error */ + } + /* pre-allocate some memory for the query buffer */ if (!init_char_buffer(&buffer, 4096)) { Py_DECREF(iter_row); @@ -992,6 +1000,14 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) ext_char_buffer_s(&buffer, t); Py_DECREF(s); } + else if (PyDate_Check(item) || PyDateTime_Check(item) || + PyTime_Check(item) || PyDelta_Check(item)) { + PyObject *s = PyObject_Str(item); + const char *t = PyUnicode_AsUTF8(s); + + ext_char_buffer_s(&buffer, t); + Py_DECREF(s); + } else { PyObject *s = PyObject_Repr(item); const char *t = PyUnicode_AsUTF8(s); diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 3bd36495..c6ced8cd 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -11,6 +11,7 @@ from __future__ import annotations +import datetime as dt import os import threading import time @@ -1897,6 +1898,12 @@ def test_inserttable_no_column(self): self.c.inserttable('test', data, []) self.assertEqual(self.get_back(), []) + def test_inserttable_datetime_adapt(self): + data = [(dt.date(1999,1,2), dt.time(11,12,13))] + self.c.inserttable('test', data, ['dt', 'ti']) + self.assertEqual([i[4:6] for i in self.get_back()], + [tuple(str(j) for j in i) for i in data]) + def test_inserttable_only_one_column(self): data: list[tuple] = [(42,)] * 50 self.c.inserttable('test', data, ['i4']) From fe08ff24537cc9a087df9ef024bb6802345c2318 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 25 Jan 2026 20:41:37 +0000 Subject: [PATCH 115/118] inserttable: test all date types --- tests/test_classic_connection.py | 106 ++++++++++++++++++------------- 1 file changed, 63 insertions(+), 43 deletions(-) diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index c6ced8cd..7dc24053 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -11,7 +11,7 @@ from __future__ import annotations -import datetime as dt +import datetime import os import threading import time @@ -1712,7 +1712,7 @@ def setUpClass(cls): c.query("drop table if exists test cascade") c.query("create table test (" "i2 smallint, i4 integer, i8 bigint," - "b boolean, dt date, ti time," + "b boolean, dt date, ti time, ts timestamp, td interval," "d numeric, f4 real, f8 double precision, m money," "c char(1), v4 varchar(4), c4 char(4), t text)") # Check whether the test database uses SQL_ASCII - this means @@ -1746,13 +1746,17 @@ def tearDown(self): self.c.close() data: Sequence[tuple] = [ - (-1, -1, -1, True, '1492-10-12', '08:30:00', + (-1, -1, -1, True, + '1492-10-12', '08:30:00', '1492-10-12 08:30:00', '-3 days', -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), - (0, 0, 0, False, '1607-04-14', '09:00:00', + (0, 0, 0, False, + '1607-04-14', '09:00:00', '1607-04-14 09:00:00', '7 days', 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), - (1, 1, 1, True, '1801-03-04', '03:45:00', + (1, 1, 1, True, + '1801-03-04', '03:45:00', '1801-03-04 03:45:00', '3 mons', 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), - (2, 2, 2, False, '1903-12-17', '11:22:00', + (2, 2, 2, False, + '1903-12-17', '11:22:00', '1903-12-17 11:22:00', '1 year', 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] @classmethod @@ -1784,29 +1788,37 @@ def get_back(self, encoding='utf-8'): if row[5] is not None: # time self.assertIsInstance(row[5], str) self.assertTrue(row[5].replace(':', '').isdigit()) - if row[6] is not None: # numeric - self.assertIsInstance(row[6], Decimal) - row[6] = float(row[6]) - if row[7] is not None: # real - self.assertIsInstance(row[7], float) - if row[8] is not None: # double precision - self.assertIsInstance(row[8], float) + if row[6] is not None: # timestamp + self.assertIsInstance(row[6], str) + parts = row[6].split(' ') + self.assertEqual(len(parts), 2) + self.assertTrue(parts[0].replace('-', '').isdigit()) + self.assertTrue(parts[1].replace(':', '').isdigit()) + if row[7] is not None: # interval + self.assertIsInstance(row[7], str) + if row[8] is not None: # numeric + self.assertIsInstance(row[8], Decimal) row[8] = float(row[8]) - if row[9] is not None: # money - self.assertIsInstance(row[9], Decimal) - row[9] = str(float(row[9])) - if row[10] is not None: # char(1) - self.assertIsInstance(row[10], str) - self.assertEqual(self.db_len(row[10], encoding), 1) - if row[11] is not None: # varchar(4) - self.assertIsInstance(row[11], str) - self.assertLessEqual(self.db_len(row[11], encoding), 4) - if row[12] is not None: # char(4) + if row[9] is not None: # real + self.assertIsInstance(row[9], float) + if row[10] is not None: # double precision + self.assertIsInstance(row[10], float) + row[10] = float(row[10]) + if row[11] is not None: # money + self.assertIsInstance(row[11], Decimal) + row[11] = str(float(row[11])) + if row[12] is not None: # char(1) self.assertIsInstance(row[12], str) - self.assertEqual(self.db_len(row[12], encoding), 4) - row[12] = row[12].rstrip() - if row[13] is not None: # text + self.assertEqual(self.db_len(row[12], encoding), 1) + if row[13] is not None: # varchar(4) self.assertIsInstance(row[13], str) + self.assertLessEqual(self.db_len(row[13], encoding), 4) + if row[14] is not None: # char(4) + self.assertIsInstance(row[14], str) + self.assertEqual(self.db_len(row[14], encoding), 4) + row[14] = row[14].rstrip() + if row[15] is not None: # text + self.assertIsInstance(row[15], str) row = tuple(row) data.append(row) return data @@ -1889,7 +1901,7 @@ def test_inserttable_multiple_calls(self): self.assertEqual(r, num_rows) def test_inserttable_null_values(self): - data = [(None,) * 14] * 100 + data = [(None,) * 16] * 100 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -1899,22 +1911,25 @@ def test_inserttable_no_column(self): self.assertEqual(self.get_back(), []) def test_inserttable_datetime_adapt(self): - data = [(dt.date(1999,1,2), dt.time(11,12,13))] - self.c.inserttable('test', data, ['dt', 'ti']) - self.assertEqual([i[4:6] for i in self.get_back()], - [tuple(str(j) for j in i) for i in data]) + data = [(datetime.date(1999, 1, 2), datetime.time(11, 12, 13), + datetime.datetime(1999, 1, 2, 11, 12, 13), + datetime.timedelta(days=123))] + self.c.inserttable('test', data, ['dt', 'ti', 'ts', 'td']) + back = [row[4:8] for row in self.get_back()] + self.assertEqual(back, [( + '1999-01-02', '11:12:13', '1999-01-02 11:12:13', '123 days')]) def test_inserttable_only_one_column(self): data: list[tuple] = [(42,)] * 50 self.c.inserttable('test', data, ['i4']) - data = [tuple([42 if i == 1 else None for i in range(14)])] * 50 + data = [tuple([42 if i == 1 else None for i in range(16)])] * 50 self.assertEqual(self.get_back(), data) def test_inserttable_only_two_columns(self): data: list[tuple] = [(bool(i % 2), i * .5) for i in range(20)] self.c.inserttable('test', data, ('b', 'f4')) # noinspection PyTypeChecker - data = [(None,) * 3 + (bool(i % 2),) + (None,) * 3 + (i * .5,) + data = [(None,) * 3 + (bool(i % 2),) + (None,) * 5 + (i * .5,) + (None,) * 6 for i in range(20)] self.assertEqual(self.get_back(), data) @@ -1985,9 +2000,9 @@ def test_inserttable_with_out_of_range_data(self): ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) def test_inserttable_max_values(self): - data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, - True, '2999-12-31', '11:59:59', 1e99, - 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, + data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, True, + '2999-12-31', '11:59:59', '2999-12-31 23:59:59', '9999 years', + 1e99, 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, "1", "1234", "1234", "1234" * 100)] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) @@ -2000,7 +2015,8 @@ def test_inserttable_byte_values(self): # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, '1970-01-01', '00:00:00', + 0, 0, 0, False, + '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "käse сыр pont-l'évêque") row_bytes = tuple( @@ -2019,7 +2035,8 @@ def test_inserttable_unicode_utf8(self): # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, '1970-01-01', '00:00:00', + 0, 0, 0, False, + '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "käse сыр pont-l'évêque") data = [row_unicode] * 2 @@ -2035,7 +2052,8 @@ def test_inserttable_unicode_latin1(self): # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' row_unicode: tuple = ( - 0, 0, 0, False, '1970-01-01', '00:00:00', + 0, 0, 0, False, + '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] @@ -2058,7 +2076,8 @@ def test_inserttable_unicode_latin9(self): # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, '1970-01-01', '00:00:00', + 0, 0, 0, False, + '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] * 2 @@ -2070,7 +2089,8 @@ def test_inserttable_no_encoding(self): # non-ascii chars do not fit in char(1) when there is no encoding c = '€' if self.has_encoding else '$' row_unicode = ( - 0, 0, 0, False, '1970-01-01', '00:00:00', + 0, 0, 0, False, + '1970-01-01', '00:00:00', '1970-01-01 00:00:00', '00:00:00', 0.0, 0.0, 0.0, '0.0', c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] @@ -2080,12 +2100,12 @@ def test_inserttable_no_encoding(self): def test_inserttable_from_query(self): data = self.c.query( "select 2::int2 as i2, 4::int4 as i4, 8::int8 as i8, true as b," - "null as dt, null as ti, null as d," + "null as dt, null as ti, null as ts, null as td, null as d," "4.5::float as float4, 8.5::float8 as f8," "null as m, 'c' as c, 'v4' as v4, null as c4, 'text' as text") self.c.inserttable('test', data) self.assertEqual(self.get_back(), [ - (2, 4, 8, True, None, None, None, 4.5, 8.5, + (2, 4, 8, True, None, None, None, None, None, 4.5, 8.5, None, 'c', 'v4', None, 'text')]) def test_inserttable_special_chars(self): From aa6fee8c00adff14a060f591df39a7f2ae771f3a Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 25 Jan 2026 21:31:00 +0000 Subject: [PATCH 116/118] inserttable: update documentation --- docs/contents/pg/connection.rst | 3 ++- ext/pgconn.c | 13 ++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 21ff2d79..cf901210 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -495,7 +495,8 @@ This method allows to *quickly* insert large blocks of data in a table. Internally, it uses the COPY command of the PostgreSQL database. The method takes an iterable of row values which must be tuples or lists of the same size, containing the values for each inserted row. -These may contain string, integer, long or double (real) values. +These may contain string, integer, long or double (real) values as well as +date, time, datetime or timedelta objects. ``columns`` is an optional tuple or list of column names to be passed on to the COPY command. The number of rows affected is returned. diff --git a/ext/pgconn.c b/ext/pgconn.c index 57c86962..dbab58bd 100644 --- a/ext/pgconn.c +++ b/ext/pgconn.c @@ -792,6 +792,7 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) encoding = PQclientEncoding(self->cnx); + /* import datetime C API (this is not compatible with subinterpreters) */ PyDateTime_IMPORT; if (PyErr_Occurred()) { return NULL; /* pass the error */ @@ -993,15 +994,9 @@ conn_inserttable(connObject *self, PyObject *args, PyObject *kwds) Py_DECREF(s); } } - else if (PyLong_Check(item)) { - PyObject *s = PyObject_Str(item); - const char *t = PyUnicode_AsUTF8(s); - - ext_char_buffer_s(&buffer, t); - Py_DECREF(s); - } - else if (PyDate_Check(item) || PyDateTime_Check(item) || - PyTime_Check(item) || PyDelta_Check(item)) { + else if (PyLong_Check(item) || PyDate_Check(item) || + PyDateTime_Check(item) || PyTime_Check(item) || + PyDelta_Check(item)) { PyObject *s = PyObject_Str(item); const char *t = PyUnicode_AsUTF8(s); From fff495afb40a2a469fb5463e0893dfa98ad99e72 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Sun, 25 Jan 2026 21:44:33 +0000 Subject: [PATCH 117/118] Bump version and update changelog --- docs/contents/changelog.rst | 9 +++++++++ pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index 76747403..c86cb097 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,15 @@ ChangeLog ========= +Version 6.2.3 (2026-01-25) +-------------------------- +- Minor improvements and fixes (thanks to Justin Pryzby): + + - The `inserttable()` method in the `pg` module can now directly import + date, time, datetime and timedelta objects from Python (#93) and was + slightly improved to avoid memory exhaustion and memory leaks (#92). + - The C extension has no been properly modularized into separate units (#25). + Version 6.2.2 (2026-01-03) -------------------------- - The `inserttable()` method in the `pg` module can now handle rows of diff --git a/pyproject.toml b/pyproject.toml index 2777003f..c12c8196 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "PyGreSQL" -version = "6.2.2" +version = "6.2.3" requires-python = ">=3.8" authors = [ {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, From 3359feb4524fba3718b4f9638212d51db229829b Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Mon, 26 Jan 2026 08:41:33 +0100 Subject: [PATCH 118/118] Fix typo --- docs/contents/changelog.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index c86cb097..ac428b9e 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -8,7 +8,7 @@ Version 6.2.3 (2026-01-25) - The `inserttable()` method in the `pg` module can now directly import date, time, datetime and timedelta objects from Python (#93) and was slightly improved to avoid memory exhaustion and memory leaks (#92). - - The C extension has no been properly modularized into separate units (#25). + - The C extension has been properly modularized into separate units (#25). Version 6.2.2 (2026-01-03) --------------------------