From 5632a2ba266ae7eb6930342ae1eead3348f971e7 Mon Sep 17 00:00:00 2001 From: Evert Date: Sat, 12 Jul 2025 18:15:02 +0200 Subject: [PATCH 001/472] 7 PRs between v1.3.1 and 1.3.2 - Add start/end offset percentage options to Python test runner (#18091) - Switch to Optional for type hints in polars lazy dataframe function (#18078) - Use `timestamp_t` instead of `time_t` for file last modified time (#18037) - fix star expr exclude error (#18063) - Remove match-case statements from polars_io.py (#18052) - Add support to produce Polars Lazy Dataframes (#17947) - Implement consumption and production of Arrow Binary View (#17975) --- duckdb/__init__.pyi | 4 +- duckdb/polars_io.py | 211 ++++++++ external/duckdb | 2 +- scripts/cache_data.json | 19 +- scripts/connection_methods.json | 7 + scripts/imports.py | 4 + sqllogic/conftest.py | 87 +++- src/duckdb_py/arrow/arrow_array_stream.cpp | 7 +- src/duckdb_py/duckdb_python.cpp | 6 +- .../import_cache/modules/duckdb_module.hpp | 22 +- .../pyconnection/pyconnection.hpp | 2 +- .../include/duckdb_python/pyfilesystem.hpp | 3 +- .../include/duckdb_python/pyrelation.hpp | 2 +- src/duckdb_py/pyconnection.cpp | 6 +- src/duckdb_py/pyexpression.cpp | 5 +- src/duckdb_py/pyfilesystem.cpp | 4 +- src/duckdb_py/pyrelation.cpp | 31 +- src/duckdb_py/pyrelation/initialize.cpp | 2 +- tests/fast/arrow/test_arrow_binary_view.py | 20 + tests/fast/arrow/test_polars.py | 479 ++++++++++++++++++ 20 files changed, 881 insertions(+), 42 deletions(-) create mode 100644 duckdb/polars_io.py create mode 100644 tests/fast/arrow/test_arrow_binary_view.py diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index 8723f2bf..7ed8b4e1 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -318,7 +318,7 @@ class DuckDBPyConnection: def fetch_df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... def df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... def fetch_df_chunk(self, vectors_per_chunk: int = 1, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def pl(self, rows_per_batch: int = 1000000) -> polars.DataFrame: ... + def pl(self, rows_per_batch: int = 1000000, *, lazy: bool = False) -> polars.DataFrame: ... def fetch_arrow_table(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... def arrow(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... def fetch_record_batch(self, rows_per_batch: int = 1000000) -> pyarrow.lib.RecordBatchReader: ... @@ -666,7 +666,7 @@ def fetchdf(*, date_as_object: bool = False, connection: DuckDBPyConnection = .. def fetch_df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... def df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... def fetch_df_chunk(vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def pl(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... +def pl(rows_per_batch: int = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... def fetch_arrow_table(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... def arrow(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... def fetch_record_batch(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py new file mode 100644 index 00000000..bea155e7 --- /dev/null +++ b/duckdb/polars_io.py @@ -0,0 +1,211 @@ +import duckdb +import polars as pl +from typing import Iterator, Optional + +from polars.io.plugins import register_io_source +from duckdb import SQLExpression +import json +from decimal import Decimal +import datetime + +def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: + """ + Convert a Polars predicate expression to a DuckDB-compatible SQL expression. + + Parameters: + predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5) + + Returns: + SQLExpression: A DuckDB SQL expression string equivalent. + None: If conversion fails. + + Example: + >>> _predicate_to_expression(pl.col("foo") > 5) + SQLExpression("(foo > 5)") + """ + # Serialize the Polars expression tree to JSON + tree = json.loads(predicate.meta.serialize(format="json")) + + try: + # Convert the tree to SQL + sql_filter = _pl_tree_to_sql(tree) + return SQLExpression(sql_filter) + except: + # If the conversion fails, we return None + return None + + +def _pl_operation_to_sql(op: str) -> str: + """ + Map Polars binary operation strings to SQL equivalents. + + Example: + >>> _pl_operation_to_sql("Eq") + '=' + """ + try: + return { + "Lt": "<", + "LtEq": "<=", + "Gt": ">", + "GtEq": ">=", + "Eq": "=", + "Modulus": "%", + "And": "AND", + "Or": "OR", + }[op] + except KeyError: + raise NotImplementedError(op) + + +def _pl_tree_to_sql(tree: dict) -> str: + """ + Recursively convert a Polars expression tree (as JSON) to a SQL string. + + Parameters: + tree (dict): JSON-deserialized expression tree from Polars + + Returns: + str: SQL expression string + + Example: + Input tree: + { + "BinaryExpr": { + "left": { "Column": "foo" }, + "op": "Gt", + "right": { "Literal": { "Int": 5 } } + } + } + Output: "(foo > 5)" + """ + [node_type] = tree.keys() + subtree = tree[node_type] + + if node_type == "BinaryExpr": + # Binary expressions: left OP right + return ( + "(" + + " ".join(( + _pl_tree_to_sql(subtree['left']), + _pl_operation_to_sql(subtree['op']), + _pl_tree_to_sql(subtree['right']) + )) + + ")" + ) + if node_type == "Column": + # A reference to a column name + return subtree + + if node_type in ("Literal", "Dyn"): + # Recursively process dynamic or literal values + return _pl_tree_to_sql(subtree) + + if node_type == "Int": + # Direct integer literals + return str(subtree) + + if node_type == "Function": + # Handle boolean functions like IsNull, IsNotNull + inputs = subtree["input"] + func_dict = subtree["function"] + + if "Boolean" in func_dict: + func = func_dict["Boolean"] + arg_sql = _pl_tree_to_sql(inputs[0]) + + if func == "IsNull": + return f"({arg_sql} IS NULL)" + if func == "IsNotNull": + return f"({arg_sql} IS NOT NULL)" + raise NotImplementedError(f"Boolean function not supported: {func}") + + raise NotImplementedError(f"Unsupported function type: {func_dict}") + + if node_type == "Scalar": + # Handle scalar values with typed representations + dtype = str(subtree["dtype"]) + value = subtree["value"] + + # Decimal support + if dtype.startswith("{'Decimal'"): + decimal_value = value['Decimal'] + decimal_value = Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]) + return str(decimal_value) + + # Datetime with microseconds since epoch + if dtype.startswith("{'Datetime'"): + micros = value['Datetime'][0] + dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC) + return f"'{str(dt_timestamp)}'::TIMESTAMP" + + # Match simple types + if dtype in ("Int8", "Int16", "Int32", "Int64", "UInt8", "UInt16", "UInt32", "UInt64", "Float32", "Float64", "Boolean"): + return str(value[dtype]) + + if dtype == "Time": + # Convert nanoseconds to TIME + nanoseconds = value["Time"] + seconds = nanoseconds // 1_000_000_000 + microseconds = (nanoseconds % 1_000_000_000) // 1_000 + dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time() + return f"'{str(dt_time)}'::TIME" + + if dtype == "Date": + # Convert days since Unix epoch to SQL DATE + days_since_epoch = value["Date"] + date = datetime.date(1970, 1, 1) + datetime.timedelta(days=days_since_epoch) + return f"'{str(date)}'::DATE" + if dtype == "Binary": + # Convert binary data to hex string for BLOB + binary_data = bytes(value["Binary"]) + escaped = ''.join(f'\\x{b:02x}' for b in binary_data) + return f"'{escaped}'::BLOB" + + if dtype == "String": + return f"'{value['StringOwned']}'" + + raise NotImplementedError(f"Unsupported scalar type {str(dtype)}, with value {value}") + + raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") + +def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: + """ + A polars IO plugin for DuckDB. + """ + def source_generator( + with_columns: Optional[list[str]], + predicate: Optional[pl.Expr], + n_rows: Optional[int], + batch_size: Optional[int], + ) -> Iterator[pl.DataFrame]: + duck_predicate = None + relation_final = relation + if with_columns is not None: + cols = ",".join(with_columns) + relation_final = relation_final.project(cols) + if n_rows is not None: + relation_final = relation_final.limit(n_rows) + if predicate is not None: + # We have a predicate, if possible, we push it down to DuckDB + duck_predicate = _predicate_to_expression(predicate) + # Try to pushdown filter, if one exists + if duck_predicate is not None: + relation_final = relation_final.filter(duck_predicate) + if batch_size is None: + results = relation_final.fetch_arrow_reader() + else: + results = relation_final.fetch_arrow_reader(batch_size) + while True: + try: + record_batch = results.read_next_batch() + df = pl.from_arrow(record_batch) + if predicate is not None and duck_predicate is None: + # We have a predicate, but did not manage to push it down, we fallback here + yield pl.from_arrow(record_batch).filter(predicate) + else: + yield pl.from_arrow(record_batch) + except StopIteration: + break + + return register_io_source(source_generator, schema=schema) diff --git a/external/duckdb b/external/duckdb index 04daea14..2a04781a 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 04daea14ebf1fa2e0c293b3f59f1af37460a4304 +Subproject commit 2a04781aa9298e0354178cfd1cddd5d77dd5eb85 diff --git a/scripts/cache_data.json b/scripts/cache_data.json index a8968386..640052cd 100644 --- a/scripts/cache_data.json +++ b/scripts/cache_data.json @@ -538,7 +538,8 @@ "name": "duckdb", "children": [ "duckdb.filesystem", - "duckdb.Value" + "duckdb.Value", + "duckdb.polars_io" ] }, "duckdb.filesystem": { @@ -692,5 +693,21 @@ "full_path": "pyarrow.ipc.MessageReader", "name": "MessageReader", "children": [] + }, + "duckdb.polars_io": { + "type": "module", + "full_path": "duckdb.polars_io", + "name": "polars_io", + "children": [ + "duckdb.polars_io.duckdb_source" + ], + "required": false + }, + "duckdb.polars_io.duckdb_source": { + "type": "attribute", + "full_path": "duckdb.polars_io.duckdb_source", + "name": "duckdb_source", + "children": [], + "required": false } } \ No newline at end of file diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index c852c60e..521d7acb 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -385,6 +385,13 @@ "type": "int" } ], + "kwargs": [ + { + "name": "lazy", + "default": "False", + "type": "bool" + } + ], "return": "polars.DataFrame" }, { diff --git a/scripts/imports.py b/scripts/imports.py index b765f8a1..6b035768 100644 --- a/scripts/imports.py +++ b/scripts/imports.py @@ -122,3 +122,7 @@ collections.abc.Iterable collections.abc.Mapping + +import duckdb.polars_io + +duckdb.polars_io.duckdb_source diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 7c5ce2e2..34f92ce2 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -1,3 +1,4 @@ +import glob import itertools import pathlib import pytest @@ -5,12 +6,11 @@ import re import typing import warnings -import glob from .skipped_tests import SKIPPED_TESTS SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" SQLLOGIC_TEST_PARAMETER = "test_script_path" -DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / 'external' / 'duckdb').resolve() +DUCKDB_ROOT_DIR = pathlib.Path(__file__).parent.joinpath("../../..").resolve() def pytest_addoption(parser: pytest.Parser): @@ -38,6 +38,18 @@ def pytest_addoption(parser: pytest.Parser): ) parser.addoption("--start-offset", type=int, dest="start_offset", help="Index of the first test to run") parser.addoption("--end-offset", type=int, dest="end_offset", help="Index of the last test to run") + parser.addoption( + "--start-offset-percentage", + type=int, + dest="start_offset_percentage", + help="Runs the tests starting at N % of the total test suite", + ) + parser.addoption( + "--end-offset-percentage", + type=int, + dest="end_offset_percentage", + help="Runs the tests ending at N % of the total test suite, excluding the Nth % test", + ) parser.addoption( "--order", choices=["decl", "lex", "rand"], @@ -142,7 +154,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): for test_dir in test_dirs: # Create absolute & normalized path test_dir = test_dir.resolve() - assert test_dir.is_dir() + assert test_dir.is_dir(), f"{test_dir} is not a directory" parameters.extend(scan_for_test_scripts(test_dir, metafunc.config)) if parameters == []: @@ -153,6 +165,58 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): metafunc.parametrize(SQLLOGIC_TEST_PARAMETER, parameters) +def determine_test_offsets(config: pytest.Config, num_tests: int) -> typing.Tuple[int, int]: + """ + If start_offset and end_offset are specified, then these are used. + start_offset defaults to 0. end_offset defaults to and is capped to the last test index. + start_offset_percentage and end_offset_percentage are used to calculate the start and end offsets based on the total number of tests. + This is done in a way that a test run to 25% and another test run starting at 25% do not overlap by excluding the 25th percent test. + """ + + start_offset = config.getoption("start_offset") + end_offset = config.getoption("end_offset") + start_offset_percentage = config.getoption("start_offset_percentage") + end_offset_percentage = config.getoption("end_offset_percentage") + + index_specified = start_offset is not None or end_offset is not None + percentage_specified = start_offset_percentage is not None or end_offset_percentage is not None + + if index_specified and percentage_specified: + raise ValueError("You can only specify either start/end offsets or start/end offset percentages, not both") + + if start_offset is not None and start_offset < 0: + raise ValueError("--start-offset must be a non-negative integer") + + if start_offset_percentage is not None and (start_offset_percentage < 0 or start_offset_percentage > 100): + raise ValueError("--start-offset-percentage must be between 0 and 100") + + if end_offset_percentage is not None and (end_offset_percentage < 0 or end_offset_percentage > 100): + raise ValueError("--end-offset-percentage must be between 0 and 100") + + if start_offset is None: + if start_offset_percentage is not None: + start_offset = start_offset_percentage * num_tests // 100 + else: + start_offset = 0 + + if end_offset is not None and end_offset < start_offset: + raise ValueError( + f"--end-offset ({end_offset}) must be greater than or equal to the start offset ({start_offset})" + ) + + if end_offset is None: + if end_offset_percentage is not None: + end_offset = end_offset_percentage * num_tests // 100 - 1 + else: + end_offset = num_tests - 1 + + max_end_offset = num_tests - 1 + if end_offset > max_end_offset: + end_offset = max_end_offset + + return start_offset, end_offset + + # Execute last, after pytest has already deselected tests based on -k and -m parameters @pytest.hookimpl(trylast=True) def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config, items: list[pytest.Item]): @@ -184,22 +248,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config config.hook.pytest_deselected(items=deselected_items) items[:] = selected_items - start_offset = config.getoption("start_offset") - if start_offset is None: - start_offset = 0 - - end_offset = config.getoption("end_offset") - if end_offset is None: - end_offset = len(items) - 1 - - if start_offset < 0: - raise ValueError("--start-offset must be a non-negative integer") - elif end_offset < start_offset: - raise ValueError(f"--end-offset ({end_offset}) must be greater than or equal to --start-offset") - - max_end_offset = len(items) - 1 - if end_offset > max_end_offset: - end_offset = max_end_offset + start_offset, end_offset = determine_test_offsets(config, len(items)) # Order tests based on --order option. Take as is if order is "decl". if config.getoption("order") == "rand": diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index ccde8c90..6094dcb1 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -279,8 +279,13 @@ py::object GetScalar(Value &constant, const string &timezone_config, const Arrow return dataset_scalar(constant.GetValue()); case LogicalTypeId::VARCHAR: return dataset_scalar(constant.ToString()); - case LogicalTypeId::BLOB: + case LogicalTypeId::BLOB: { + if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { + py::object binary_view_type = py::module_::import("pyarrow").attr("binary_view"); + return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); + } return dataset_scalar(py::bytes(constant.GetValueUnsafe())); + } case LogicalTypeId::DECIMAL: { py::object decimal_type; auto &datetime_info = type.GetTypeInfo(); diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp index 2c729590..27ebe8b9 100644 --- a/src/duckdb_py/duckdb_python.cpp +++ b/src/duckdb_py/duckdb_python.cpp @@ -408,14 +408,14 @@ static void InitializeConnectionMethods(py::module_ &m) { py::arg("date_as_object") = false, py::arg("connection") = py::none()); m.def( "pl", - [](idx_t rows_per_batch, shared_ptr conn = nullptr) { + [](idx_t rows_per_batch, bool lazy, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } - return conn->FetchPolars(rows_per_batch); + return conn->FetchPolars(rows_per_batch, lazy); }, "Fetch a result as Polars DataFrame following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), - py::arg("connection") = py::none()); + py::arg("lazy") = false, py::arg("connection") = py::none()); m.def( "fetch_arrow_table", [](idx_t rows_per_batch, shared_ptr conn = nullptr) { diff --git a/src/duckdb_py/include/duckdb_python/import_cache/modules/duckdb_module.hpp b/src/duckdb_py/include/duckdb_python/import_cache/modules/duckdb_module.hpp index 04104777..2272a7d7 100644 --- a/src/duckdb_py/include/duckdb_python/import_cache/modules/duckdb_module.hpp +++ b/src/duckdb_py/include/duckdb_python/import_cache/modules/duckdb_module.hpp @@ -20,6 +20,25 @@ namespace duckdb { +struct DuckdbPolarsioCacheItem : public PythonImportCacheItem { + +public: + static constexpr const char *Name = "duckdb.polars_io"; + +public: + DuckdbPolarsioCacheItem() : PythonImportCacheItem("duckdb.polars_io"), duckdb_source("duckdb_source", this) { + } + ~DuckdbPolarsioCacheItem() override { + } + + PythonImportCacheItem duckdb_source; + +protected: + bool IsRequired() const override final { + return false; + } +}; + struct DuckdbFilesystemCacheItem : public PythonImportCacheItem { public: @@ -46,13 +65,14 @@ struct DuckdbCacheItem : public PythonImportCacheItem { static constexpr const char *Name = "duckdb"; public: - DuckdbCacheItem() : PythonImportCacheItem("duckdb"), filesystem(), Value("Value", this) { + DuckdbCacheItem() : PythonImportCacheItem("duckdb"), filesystem(), Value("Value", this), polars_io() { } ~DuckdbCacheItem() override { } DuckdbFilesystemCacheItem filesystem; PythonImportCacheItem Value; + DuckdbPolarsioCacheItem polars_io; }; } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp index 0347a076..48ee055e 100644 --- a/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp +++ b/src/duckdb_py/include/duckdb_python/pyconnection/pyconnection.hpp @@ -319,7 +319,7 @@ struct DuckDBPyConnection : public enable_shared_from_this { PandasDataFrame FetchDFChunk(const idx_t vectors_per_chunk = 1, bool date_as_object = false); duckdb::pyarrow::Table FetchArrow(idx_t rows_per_batch); - PolarsDataFrame FetchPolars(idx_t rows_per_batch); + PolarsDataFrame FetchPolars(idx_t rows_per_batch, bool lazy); py::dict FetchPyTorch(); diff --git a/src/duckdb_py/include/duckdb_python/pyfilesystem.hpp b/src/duckdb_py/include/duckdb_python/pyfilesystem.hpp index 46d8c845..677513f7 100644 --- a/src/duckdb_py/include/duckdb_python/pyfilesystem.hpp +++ b/src/duckdb_py/include/duckdb_python/pyfilesystem.hpp @@ -5,6 +5,7 @@ #include "duckdb_python/pybind11/pybind_wrapper.hpp" #include "duckdb_python/pybind11/gil_wrapper.hpp" #include "duckdb/common/vector.hpp" +#include "duckdb/common/types/timestamp.hpp" namespace duckdb { @@ -90,7 +91,7 @@ class PythonFilesystem : public FileSystem { int64_t GetFileSize(FileHandle &handle) override; void RemoveFile(const string &filename, optional_ptr opener = nullptr) override; void MoveFile(const string &source, const string &dest, optional_ptr opener = nullptr) override; - time_t GetLastModifiedTime(FileHandle &handle) override; + timestamp_t GetLastModifiedTime(FileHandle &handle) override; void FileSync(FileHandle &handle) override; bool DirectoryExists(const string &directory, optional_ptr opener = nullptr) override; void CreateDirectory(const string &directory, optional_ptr opener = nullptr) override; diff --git a/src/duckdb_py/include/duckdb_python/pyrelation.hpp b/src/duckdb_py/include/duckdb_python/pyrelation.hpp index a27433c6..b1feb8ba 100644 --- a/src/duckdb_py/include/duckdb_python/pyrelation.hpp +++ b/src/duckdb_py/include/duckdb_python/pyrelation.hpp @@ -192,7 +192,7 @@ struct DuckDBPyRelation { duckdb::pyarrow::Table ToArrowTableInternal(idx_t batch_size, bool to_polars); - PolarsDataFrame ToPolars(idx_t batch_size); + PolarsDataFrame ToPolars(idx_t batch_size, bool lazy); py::object ToArrowCapsule(const py::object &requested_schema = py::none()); diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 51051ca7..30b58701 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -201,7 +201,7 @@ static void InitializeConnectionMethods(py::class_(list_p); for (auto item : list) { if (py::isinstance(item)) { - exclude.insert(QualifiedColumnName(std::string(py::str(item)))); + string col_str = std::string(py::str(item)); + QualifiedColumnName qname = QualifiedColumnName::Parse(col_str); + exclude.insert(qname); continue; } shared_ptr expr; diff --git a/src/duckdb_py/pyfilesystem.cpp b/src/duckdb_py/pyfilesystem.cpp index f7cb4c70..d9821779 100644 --- a/src/duckdb_py/pyfilesystem.cpp +++ b/src/duckdb_py/pyfilesystem.cpp @@ -187,14 +187,14 @@ void PythonFilesystem::RemoveFile(const string &filename, optional_ptr(pybind11::module_::import("polars").attr("DataFrame")(arrow)); +PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) { + if (!lazy) { + auto arrow = ToArrowTableInternal(batch_size, true); + return py::cast(pybind11::module_::import("polars").attr("DataFrame")(arrow)); + } + auto &import_cache = *DuckDBPyConnection::ImportCache(); + auto lazy_frame_produce = import_cache.duckdb.polars_io.duckdb_source(); + // We also have to get a polars schema here, for this we can get at empty arrow table + // We start by extracting the arrow schema + ArrowSchema arrow_schema; + auto result_names = names; + QueryResult::DeduplicateColumns(result_names); + auto client_properties = rel->context->GetContext()->GetClientProperties(); + ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, client_properties); + py::list batches; + // Now we create an empty arrow table + auto empty_table = pyarrow::ToArrowTable(types, result_names, std::move(batches), client_properties); + + // And we extract the polars schema from the arrow table + auto polars_df = py::cast(pybind11::module_::import("polars").attr("DataFrame")(empty_table)); + auto polars_schema = polars_df.attr("schema"); + + return lazy_frame_produce(*this, polars_schema); } duckdb::pyarrow::RecordBatchReader DuckDBPyRelation::ToRecordBatch(idx_t batch_size) { @@ -984,7 +1005,9 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyRelation::ToRecordBatch(idx_t batch_s ExecuteOrThrow(true); } AssertResultOpen(); - return result->FetchRecordBatchReader(batch_size); + auto res = result->FetchRecordBatchReader(batch_size); + result = nullptr; + return res; } void DuckDBPyRelation::Close() { diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index ad78fbaa..a93a54b5 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -68,7 +68,7 @@ static void InitializeConsumers(py::class_ &m) { .def("to_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) .def("pl", &DuckDBPyRelation::ToPolars, "Execute and fetch all rows as a Polars DataFrame", - py::arg("batch_size") = 1000000) + py::arg("batch_size") = 1000000, py::kw_only(), py::arg("lazy") = false) .def("torch", &DuckDBPyRelation::FetchPyTorch, "Fetch a result as dict of PyTorch Tensors") .def("tf", &DuckDBPyRelation::FetchTF, "Fetch a result as dict of TensorFlow Tensors"); const char *capsule_docs = R"( diff --git a/tests/fast/arrow/test_arrow_binary_view.py b/tests/fast/arrow/test_arrow_binary_view.py new file mode 100644 index 00000000..070de196 --- /dev/null +++ b/tests/fast/arrow/test_arrow_binary_view.py @@ -0,0 +1,20 @@ +import duckdb +import pytest + +pa = pytest.importorskip("pyarrow") + + +class TestArrowBinaryView(object): + def test_arrow_binary_view(self, duckdb_cursor): + con = duckdb.connect() + tab = pa.table({"x": pa.array([b"abc", b"thisisaverybigbinaryyaymorethanfifteen", None], pa.binary_view())}) + assert con.execute("FROM tab").fetchall() == [(b'abc',), (b'thisisaverybigbinaryyaymorethanfifteen',), (None,)] + # By default we won't export a view + assert not con.execute("FROM tab").arrow().equals(tab) + # We do the binary view from 1.4 onwards + con.execute("SET arrow_output_version = 1.4") + assert con.execute("FROM tab").arrow().equals(tab) + + assert con.execute("FROM tab where x = 'thisisaverybigbinaryyaymorethanfifteen'").fetchall() == [ + (b'thisisaverybigbinaryyaymorethanfifteen',) + ] diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index b2928f27..1a86c82d 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -1,11 +1,24 @@ import duckdb import pytest import sys +import datetime pl = pytest.importorskip("polars") arrow = pytest.importorskip("pyarrow") pl_testing = pytest.importorskip("polars.testing") +from duckdb.polars_io import _predicate_to_expression + + +def valid_filter(filter): + sql_expression = _predicate_to_expression(filter) + assert sql_expression is not None + + +def invalid_filter(filter): + sql_expression = _predicate_to_expression(filter) + assert sql_expression is None + class TestPolars(object): def test_polars(self, duckdb_cursor): @@ -87,3 +100,469 @@ def test_polars_from_json_error(self, duckdb_cursor): my_table = conn.query("select 'x' my_str").pl() my_res = duckdb.query("select my_str from my_table where my_str != 'y'") assert my_res.fetchall() == [('x',)] + + def test_polars_lazy(self, duckdb_cursor): + con = duckdb.connect() + con.execute("Create table names (a varchar, b integer)") + con.execute("insert into names values ('Pedro',32), ('Mark',31), ('Thijs', 29)") + rel = con.sql("FROM names") + lazy_df = rel.pl(lazy=True) + + assert isinstance(lazy_df, pl.LazyFrame) + assert lazy_df.collect().to_dicts() == [ + {'a': 'Pedro', 'b': 32}, + {'a': 'Mark', 'b': 31}, + {'a': 'Thijs', 'b': 29}, + ] + + assert lazy_df.select('a').collect().to_dicts() == [{'a': 'Pedro'}, {'a': 'Mark'}, {'a': 'Thijs'}] + assert lazy_df.limit(1).collect().to_dicts() == [{'a': 'Pedro', 'b': 32}] + assert lazy_df.filter(pl.col("b") < 32).collect().to_dicts() == [ + {'a': 'Mark', 'b': 31}, + {'a': 'Thijs', 'b': 29}, + ] + assert lazy_df.filter(pl.col("b") < 32).select('a').collect().to_dicts() == [{'a': 'Mark'}, {'a': 'Thijs'}] + + @pytest.mark.parametrize( + 'data_type', + [ + 'TINYINT', + 'SMALLINT', + 'INTEGER', + 'BIGINT', + 'UTINYINT', + 'USMALLINT', + 'UINTEGER', + 'UBIGINT', + 'FLOAT', + 'DOUBLE', + 'HUGEINT', + 'DECIMAL(4,1)', + 'DECIMAL(9,1)', + 'DECIMAL(18,4)', + 'DECIMAL(30,12)', + ], + ) + def test_polars_lazy_pushdown_numeric(self, data_type, duckdb_cursor): + con = duckdb.connect() + tbl_name = "test" + con.execute( + f""" + CREATE TABLE {tbl_name} ( + a {data_type}, + b {data_type}, + c {data_type} + ) + """ + ) + con.execute( + f""" + INSERT INTO {tbl_name} VALUES + (1,1,1), + (10,10,10), + (100,10,100), + (NULL,NULL,NULL) + """ + ) + rel = con.sql(f"FROM {tbl_name}") + lazy_df = rel.pl(lazy=True) + + # Equality + assert lazy_df.filter(pl.col("a") == 1).select("a").collect().to_dicts() == [{"a": 1}] + + # Greater than + assert lazy_df.filter(pl.col("a") > 1).select("a").collect().to_dicts() == [{"a": 10}, {"a": 100}] + # Greater than or equal + assert lazy_df.filter(pl.col("a") >= 10).select("a").collect().to_dicts() == [{"a": 10}, {"a": 100}] + # Less than + assert lazy_df.filter(pl.col("a") < 10).select("a").collect().to_dicts() == [{"a": 1}] + # Less than or equal + assert lazy_df.filter(pl.col("a") <= 10).select("a").collect().to_dicts() == [{"a": 1}, {"a": 10}] + + # IS NULL + assert lazy_df.filter(pl.col("a").is_null()).select("a").collect().to_dicts() == [{"a": None}] + # IS NOT NULL + assert lazy_df.filter(pl.col("a").is_not_null()).select("a").collect().to_dicts() == [ + {"a": 1}, + {"a": 10}, + {"a": 100}, + ] + + # AND + assert lazy_df.filter((pl.col("a") == 10) & (pl.col("b") == 1)).collect().to_dicts() == [] + assert lazy_df.filter( + (pl.col("a") == 100) & (pl.col("b") == 10) & (pl.col("c") == 100) + ).collect().to_dicts() == [{"a": 100, "b": 10, "c": 100}] + + # OR + assert lazy_df.filter((pl.col("a") == 100) | (pl.col("b") == 1)).select("a", "b").collect().to_dicts() == [ + {"a": 1, "b": 1}, + {"a": 100, "b": 10}, + ] + + # Validate Filters + valid_filter(pl.col("a") == 1) + valid_filter(pl.col("a") > 1) + valid_filter(pl.col("a") >= 10) + valid_filter(pl.col("a") < 10) + valid_filter(pl.col("a") <= 10) + valid_filter(pl.col("a").is_null()) + valid_filter(pl.col("a").is_not_null()) + valid_filter((pl.col("a") == 10) & (pl.col("b") == 1)) + valid_filter((pl.col("a") == 100) & (pl.col("b") == 10) & (pl.col("c") == 100)) + valid_filter((pl.col("a") == 100) | (pl.col("b") == 1)) + + def test_polars_lazy_pushdown_bool(self, duckdb_cursor): + duckdb_cursor.execute( + """ + CREATE TABLE test_bool ( + a BOOL, + b BOOL + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_bool VALUES + (TRUE,TRUE), + (TRUE,FALSE), + (FALSE,TRUE), + (NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_bool") + + lazy_df = duck_tbl.pl(lazy=True) + # == True + assert lazy_df.filter(pl.col("a") == True).select(pl.len()).collect().item() == 2 + + # IS NULL + assert lazy_df.filter(pl.col("a").is_null()).select(pl.len()).collect().item() == 1 + + # IS NOT NULL + assert lazy_df.filter(pl.col("a").is_not_null()).select(pl.len()).collect().item() == 3 + + # AND + assert lazy_df.filter((pl.col("a") == True) & (pl.col("b") == True)).select(pl.len()).collect().item() == 1 + + # OR + assert lazy_df.filter((pl.col("a") == True) | (pl.col("b") == True)).select(pl.len()).collect().item() == 3 + + # Validate Filters + valid_filter(pl.col("a") == True) + valid_filter(pl.col("a").is_null()) + valid_filter(pl.col("a").is_not_null()) + valid_filter((pl.col("a") == True) & (pl.col("b") == True)) + valid_filter((pl.col("a") == True) | (pl.col("b") == True)) + + def test_polars_lazy_pushdown_time(self, duckdb_cursor): + duckdb_cursor.execute( + """ + CREATE TABLE test_time ( + a TIME, + b TIME, + c TIME + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_time VALUES + ('00:01:00','00:01:00','00:01:00'), + ('00:10:00','00:10:00','00:10:00'), + ('01:00:00','00:10:00','01:00:00'), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_time") + lazy_df = duck_tbl.pl(lazy=True) + + # Comparison time values + t_001 = datetime.time(0, 1) + t_010 = datetime.time(0, 10) + t_100 = datetime.time(1, 0) + + # == + assert lazy_df.filter(pl.col("a") == t_001).select(pl.len()).collect().item() == 1 + # > + assert lazy_df.filter(pl.col("a") > t_001).select(pl.len()).collect().item() == 2 + # >= + assert lazy_df.filter(pl.col("a") >= t_010).select(pl.len()).collect().item() == 2 + # < + assert lazy_df.filter(pl.col("a") < t_010).select(pl.len()).collect().item() == 1 + # <= + assert lazy_df.filter(pl.col("a") <= t_010).select(pl.len()).collect().item() == 2 + + # IS NULL + assert lazy_df.filter(pl.col("a").is_null()).select(pl.len()).collect().item() == 1 + # IS NOT NULL + assert lazy_df.filter(pl.col("a").is_not_null()).select(pl.len()).collect().item() == 3 + + # AND conditions + assert lazy_df.filter((pl.col("a") == t_010) & (pl.col("b") == t_001)).select(pl.len()).collect().item() == 0 + assert ( + lazy_df.filter((pl.col("a") == t_100) & (pl.col("b") == t_010) & (pl.col("c") == t_100)) + .select(pl.len()) + .collect() + .item() + == 1 + ) + + # OR condition + assert lazy_df.filter((pl.col("a") == t_100) | (pl.col("b") == t_001)).select(pl.len()).collect().item() == 2 + + # Validate Filter + valid_filter(pl.col("a") == t_001) + valid_filter(pl.col("a") > t_001) + valid_filter(pl.col("a") >= t_010) + valid_filter(pl.col("a") < t_010) + valid_filter(pl.col("a") <= t_010) + valid_filter(pl.col("a").is_null()) + valid_filter(pl.col("a").is_not_null()) + valid_filter((pl.col("a") == t_010) & (pl.col("b") == t_001)) + valid_filter((pl.col("a") == t_100) & (pl.col("b") == t_010) & (pl.col("c") == t_100)) + valid_filter((pl.col("a") == t_100) | (pl.col("b") == t_001)) + + def test_polars_lazy_pushdown_timestamp(self, duckdb_cursor): + duckdb_cursor.execute( + """ + CREATE TABLE test_timestamp ( + a TIMESTAMP, + b TIMESTAMP, + c TIMESTAMP + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_timestamp VALUES + ('2008-01-01 00:00:01','2008-01-01 00:00:01','2008-01-01 00:00:01'), + ('2010-01-01 10:00:01','2010-01-01 10:00:01','2010-01-01 10:00:01'), + ('2020-03-01 10:00:01','2010-01-01 10:00:01','2020-03-01 10:00:01'), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_timestamp") + lazy_df = duck_tbl.pl(lazy=True) + + # Define timestamps + ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1) + ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1) + ts_2020 = datetime.datetime(2020, 3, 1, 10, 0, 1) + + # These will require a cast, which we currently do not support, hence the filter won't be pushed down, but the results + # Should still be correct, and we check we can't really pushdown the filter yet. + + # == + assert lazy_df.filter(pl.col("a") == ts_2008).select(pl.len()).collect().item() == 1 + # > + assert lazy_df.filter(pl.col("a") > ts_2008).select(pl.len()).collect().item() == 2 + # >= + assert lazy_df.filter(pl.col("a") >= ts_2010).select(pl.len()).collect().item() == 2 + # < + assert lazy_df.filter(pl.col("a") < ts_2010).select(pl.len()).collect().item() == 1 + # <= + assert lazy_df.filter(pl.col("a") <= ts_2010).select(pl.len()).collect().item() == 2 + + # IS NULL + assert lazy_df.filter(pl.col("a").is_null()).select(pl.len()).collect().item() == 1 + # IS NOT NULL + assert lazy_df.filter(pl.col("a").is_not_null()).select(pl.len()).collect().item() == 3 + + # AND + assert ( + lazy_df.filter((pl.col("a") == ts_2010) & (pl.col("b") == ts_2008)).select(pl.len()).collect().item() == 0 + ) + assert ( + lazy_df.filter((pl.col("a") == ts_2020) & (pl.col("b") == ts_2010) & (pl.col("c") == ts_2020)) + .select(pl.len()) + .collect() + .item() + == 1 + ) + + # OR + assert ( + lazy_df.filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)).select(pl.len()).collect().item() == 2 + ) + + # Validate Filter + invalid_filter(pl.col("a") == ts_2008) + invalid_filter(pl.col("a") > ts_2008) + invalid_filter(pl.col("a") >= ts_2010) + invalid_filter(pl.col("a") < ts_2010) + invalid_filter(pl.col("a") <= ts_2010) + # These two are actually valid because they don't produce a cast + valid_filter(pl.col("a").is_null()) + valid_filter(pl.col("a").is_not_null()) + invalid_filter((pl.col("a") == ts_2010) & (pl.col("b") == ts_2008)) + invalid_filter((pl.col("a") == ts_2020) & (pl.col("b") == ts_2010) & (pl.col("c") == ts_2020)) + invalid_filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)) + + def test_polars_lazy_pushdown_date(self, duckdb_cursor): + duckdb_cursor.execute( + """ + CREATE TABLE test_date ( + a DATE, + b DATE, + c DATE + ) + """ + ) + duckdb_cursor.execute( + """ + INSERT INTO test_date VALUES + ('2000-01-01','2000-01-01','2000-01-01'), + ('2000-10-01','2000-10-01','2000-10-01'), + ('2010-01-01','2000-10-01','2010-01-01'), + (NULL,NULL,NULL) + """ + ) + duck_tbl = duckdb_cursor.table("test_date") + lazy_df = duck_tbl.pl(lazy=True) + + # Reference dates + d_2000_01_01 = datetime.date(2000, 1, 1) + d_2000_10_01 = datetime.date(2000, 10, 1) + d_2010_01_01 = datetime.date(2010, 1, 1) + + # == + assert lazy_df.filter(pl.col("a") == d_2000_01_01).select(pl.len()).collect().item() == 1 + # > + assert lazy_df.filter(pl.col("a") > d_2000_01_01).select(pl.len()).collect().item() == 2 + # >= + assert lazy_df.filter(pl.col("a") >= d_2000_10_01).select(pl.len()).collect().item() == 2 + # < + assert lazy_df.filter(pl.col("a") < d_2000_10_01).select(pl.len()).collect().item() == 1 + # <= + assert lazy_df.filter(pl.col("a") <= d_2000_10_01).select(pl.len()).collect().item() == 2 + + # IS NULL + assert lazy_df.filter(pl.col("a").is_null()).select(pl.len()).collect().item() == 1 + # IS NOT NULL + assert lazy_df.filter(pl.col("a").is_not_null()).select(pl.len()).collect().item() == 3 + + # AND + assert ( + lazy_df.filter((pl.col("a") == d_2000_10_01) & (pl.col("b") == d_2000_01_01)) + .select(pl.len()) + .collect() + .item() + == 0 + ) + assert ( + lazy_df.filter( + (pl.col("a") == d_2010_01_01) & (pl.col("b") == d_2000_10_01) & (pl.col("c") == d_2010_01_01) + ) + .select(pl.len()) + .collect() + .item() + == 1 + ) + + # OR + assert ( + lazy_df.filter((pl.col("a") == d_2010_01_01) | (pl.col("b") == d_2000_01_01)) + .select(pl.len()) + .collect() + .item() + == 2 + ) + + # Validate Filter + valid_filter(pl.col("a") == d_2000_01_01) + valid_filter(pl.col("a") > d_2000_01_01) + valid_filter(pl.col("a") >= d_2000_10_01) + valid_filter(pl.col("a") < d_2000_10_01) + valid_filter(pl.col("a") <= d_2000_10_01) + valid_filter(pl.col("a").is_null()) + valid_filter(pl.col("a").is_not_null()) + valid_filter((pl.col("a") == d_2000_10_01) & (pl.col("b") == d_2000_01_01)) + valid_filter((pl.col("a") == d_2010_01_01) & (pl.col("b") == d_2000_10_01) & (pl.col("c") == d_2010_01_01)) + valid_filter((pl.col("a") == d_2010_01_01) | (pl.col("b") == d_2000_01_01)) + + def test_polars_lazy_pushdown_blob(self, duckdb_cursor): + import pandas + + df = pandas.DataFrame( + { + 'a': [bytes([1]), bytes([2]), bytes([3]), None], + 'b': [bytes([1]), bytes([2]), bytes([3]), None], + 'c': [bytes([1]), bytes([2]), bytes([3]), None], + } + ) + duck_tbl = duckdb.from_df(df) + lazy_df = duck_tbl.pl(lazy=True) + + # Reference bytes + b1 = b"\x01" + b2 = b"\x02" + + # == + assert lazy_df.filter(pl.col("a") == b1).select(pl.len()).collect().item() == 1 + # > + assert lazy_df.filter(pl.col("a") > b1).select(pl.len()).collect().item() == 2 + # >= + assert lazy_df.filter(pl.col("a") >= b2).select(pl.len()).collect().item() == 2 + # < + assert lazy_df.filter(pl.col("a") < b2).select(pl.len()).collect().item() == 1 + # <= + assert lazy_df.filter(pl.col("a") <= b2).select(pl.len()).collect().item() == 2 + + # IS NULL + assert lazy_df.filter(pl.col("a").is_null()).select(pl.len()).collect().item() == 1 + # IS NOT NULL + assert lazy_df.filter(pl.col("a").is_not_null()).select(pl.len()).collect().item() == 3 + + # AND + assert lazy_df.filter((pl.col("a") == b2) & (pl.col("b") == b1)).select(pl.len()).collect().item() == 0 + assert ( + lazy_df.filter((pl.col("a") == b2) & (pl.col("b") == b2) & (pl.col("c") == b2)) + .select(pl.len()) + .collect() + .item() + == 1 + ) + + # OR + assert lazy_df.filter((pl.col("a") == b1) | (pl.col("b") == b2)).select(pl.len()).collect().item() == 2 + + # Validate Filter + valid_filter(pl.col("a") == b1) + valid_filter(pl.col("a") > b1) + valid_filter(pl.col("a") >= b2) + valid_filter(pl.col("a") < b2) + valid_filter(pl.col("a") <= b2) + valid_filter(pl.col("a").is_null()) + valid_filter(pl.col("a").is_not_null()) + valid_filter((pl.col("a") == b2) & (pl.col("b") == b1)) + valid_filter((pl.col("a") == b2) & (pl.col("b") == b2) & (pl.col("c") == b2)) + valid_filter((pl.col("a") == b1) | (pl.col("b") == b2)) + + def test_polars_lazy_many_batches(self, duckdb_cursor): + duckdb_cursor = duckdb.connect() + duckdb_cursor.execute("CREATE table t as select range a from range(3000);") + duck_tbl = duckdb_cursor.table("t") + + lazy_df = duck_tbl.pl(1024, lazy=True) + + streamed_result = lazy_df.collect(engine="streaming") + + batches = streamed_result.iter_slices(1024) + + chunk1 = next(batches) + assert len(chunk1) == 1024 + + chunk2 = next(batches) + assert len(chunk2) == 1024 + + chunk3 = next(batches) + assert len(chunk3) == 952 + + with pytest.raises(StopIteration): + next(batches) + + res = duckdb_cursor.execute("FROM streamed_result").fetchall() + correct = duckdb_cursor.execute("FROM t").fetchall() + + assert res == correct From b7a2f7a61df92328ed8aa23e303e014d566176cb Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 14 Jul 2025 14:10:57 +0200 Subject: [PATCH 002/472] Fix on-pr workflow (#4) Fix PR workflow to use the correct git refs for checkout --- .github/workflows/coverage.yml | 7 ++++--- .github/workflows/on_pr.yml | 6 +++--- .github/workflows/pypi_packaging.yml | 7 +++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 2c05d032..fdd2a838 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -5,11 +5,11 @@ on: git_ref: type: string description: Git ref of the DuckDB python package - default: refs/heads/main - required: true + required: false duckdb_git_ref: type: string description: Git ref of DuckDB + required: false testsuite: type: choice description: Testsuite to run ('all' or 'fast') @@ -23,9 +23,10 @@ on: git_ref: type: string description: Git ref of the DuckDB python package - required: true + required: false duckdb_git_ref: type: string + required: false testsuite: type: string description: Testsuite to run ('all' or 'fast') diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 5b727166..81fd3247 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -19,6 +19,7 @@ concurrency: cancel-in-progress: true jobs: + packaging_test: name: Build a minimal set of packages and run all tests on them # Skip packaging tests for draft PRs @@ -27,13 +28,12 @@ jobs: with: minimal: true testsuite: all - git_ref: ${{ github.ref }} - duckdb_git_ref: ${{ github.ref_name }} + duckdb_git_ref: ${{ github.base_ref }} coverage_test: name: Run coverage tests if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }} uses: ./.github/workflows/coverage.yml with: - git_ref: ${{ github.ref }} + duckdb_git_ref: ${{ github.base_ref }} testsuite: all diff --git a/.github/workflows/pypi_packaging.yml b/.github/workflows/pypi_packaging.yml index a7f659f9..ef4f527a 100644 --- a/.github/workflows/pypi_packaging.yml +++ b/.github/workflows/pypi_packaging.yml @@ -19,8 +19,7 @@ on: git_ref: type: string description: Git ref of the DuckDB python package - required: true - default: refs/heads/main + required: false duckdb_git_ref: type: string description: Git ref of DuckDB @@ -44,11 +43,11 @@ on: git_ref: type: string description: Git ref of the DuckDB python package - required: true + required: false duckdb_git_ref: type: string description: Git ref of DuckDB - required: true + required: false force_version: description: Force version (vX.Y.Z-((rc|post)N)) required: false From eac6a1c6b2a0b87b3b1acfbfa508fdf0544b0e86 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 14 Jul 2025 14:19:18 +0200 Subject: [PATCH 003/472] Nightly python builds (#5) Add workflow for external dispatch --- .github/workflows/on_external_dispatch.yml | 28 ++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 .github/workflows/on_external_dispatch.yml diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml new file mode 100644 index 00000000..3604ce34 --- /dev/null +++ b/.github/workflows/on_external_dispatch.yml @@ -0,0 +1,28 @@ +name: Builds triggered externally by DuckDB +on: + workflow_dispatch: + inputs: + duckdb-sha: + type: string + description: The DuckDB SHA to build against + required: true + force_version: + type: string + description: Force version (vX.Y.Z-((rc|post)N)) + required: false + publish_to_pypi: + type: boolean + description: Publish packages to PyPI? + required: true + default: false + +jobs: + externally_triggered_build: + name: Build and test releases + uses: ./.github/workflows/pypi_packaging.yml + with: + minimal: false + testsuite: all + git_ref: ${{ github.ref }} + duckdb_git_ref: ${{ inputs.duckdb-sha }} + force_version: ${{ inputs.force_version }} From 48a47a5e927fe5a1427de22d72b050191b0dee62 Mon Sep 17 00:00:00 2001 From: Evert Date: Mon, 14 Jul 2025 16:39:08 +0200 Subject: [PATCH 004/472] Enable upload to staging --- .github/workflows/on_external_dispatch.yml | 33 ++++++++++++++++++++-- .github/workflows/pypi_packaging.yml | 4 ++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 3604ce34..46252f0d 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -10,9 +10,9 @@ on: type: string description: Force version (vX.Y.Z-((rc|post)N)) required: false - publish_to_pypi: + publish_packages: type: boolean - description: Publish packages to PyPI? + description: Publish packages on S3 and PyPI? required: true default: false @@ -26,3 +26,32 @@ jobs: git_ref: ${{ github.ref }} duckdb_git_ref: ${{ inputs.duckdb-sha }} force_version: ${{ inputs.force_version }} + + upload_to_staging: + name: Upload Artifacts to staging + runs-on: ubuntu-latest + needs: [ externally_triggered_build ] + if: ${{ github.repository_owner == 'duckdb' && inputs.publish_packages }} + steps: + - name: Fetch artifacts + uses: actions/download-artifact@v4 + with: + pattern: '{sdist,wheel}*' + path: artifacts/ + merge-multiple: true + + - name: Authenticate with AWS + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: 'us-east-2' + aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} + aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + + - name: Upload artifacts to S3 bucket + shell: bash + run: | + DUCKDB_SHA="${{ inputs.duckdb-sha }}" + aws s3 cp \ + artifacts \ + s3://duckdb-staging/${DUCKDB_SHA:0:7}/${{ github.repository }}/ \ + --recursive diff --git a/.github/workflows/pypi_packaging.yml b/.github/workflows/pypi_packaging.yml index ef4f527a..84016645 100644 --- a/.github/workflows/pypi_packaging.yml +++ b/.github/workflows/pypi_packaging.yml @@ -114,8 +114,9 @@ jobs: - uses: actions/upload-artifact@v4 with: - name: sdist-main + name: sdist path: dist/*.tar.gz + compression-level: 0 build_wheels: name: 'Wheel: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}' @@ -186,3 +187,4 @@ jobs: with: name: wheel-${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} path: wheelhouse/*.whl + compression-level: 0 From c7d074a4d5fca542fd2e229b73f3f5ae01c359fe Mon Sep 17 00:00:00 2001 From: Evert Date: Mon, 14 Jul 2025 16:59:52 +0200 Subject: [PATCH 005/472] Enable automatic submodule switching --- .githooks/post-checkout | 6 ++++++ .gitmodules | 2 ++ README.md | 21 +++++++++++++++++++ external/README_GIT_SUBMODULE.md | 36 -------------------------------- 4 files changed, 29 insertions(+), 36 deletions(-) create mode 100755 .githooks/post-checkout delete mode 100644 external/README_GIT_SUBMODULE.md diff --git a/.githooks/post-checkout b/.githooks/post-checkout new file mode 100755 index 00000000..8e5f7089 --- /dev/null +++ b/.githooks/post-checkout @@ -0,0 +1,6 @@ +#!/bin/sh +if [ "$3" = "1" ]; then + echo "Updating submodules..." + git submodule update --init --recursive +fi + diff --git a/.gitmodules b/.gitmodules index 2f4bbe0e..d4c01492 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,5 @@ path = external/duckdb url = https://github.com/duckdb/duckdb.git branch = main +[submodule] + recurse = true diff --git a/README.md b/README.md index c5cdd815..ca48ccd5 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,27 @@ pip install 'duckdb[all]' ## Development +### Cloning + +When you clone the repo or your fork, make sure you initialize the duckdb submodule: +```shell +git clone --recurse-submodules +``` + +... or, if you already have the repo locally: +```shell +git clone +cd +git submodule update --init --recursive +``` + +If you'll be switching between branches that are have the submodule set to different refs, then make your life +easier and add the git hooks in the .githooks directory to your local config: +```shell +git config --local core.hooksPath .githooks/ +``` + + ### Building wheels and sdists To build a wheel and sdist for your system and the default Python version: diff --git a/external/README_GIT_SUBMODULE.md b/external/README_GIT_SUBMODULE.md deleted file mode 100644 index 777bb495..00000000 --- a/external/README_GIT_SUBMODULE.md +++ /dev/null @@ -1,36 +0,0 @@ -# DuckDB vendored as submodule - -The submodule has a relative path. Git resolves it against the super-project’s remote, so a clone of `git@github.com:alice/duckdb-python.git` will automatically look for `git@github.com:alice/duckdb.git`, while a clone of the canonical repo falls back to -`https://github.com/duckdb/duckdb.git`. - -### Clone python repo - -From the main repo: -```shell -git clone --recurse-submodules git@github.com:duckdb/duckdb-python.git -``` - -Or from your fork (NOTE that you must also fork duckdb): -```shell -git clone --recurse-submodules git@github.com:/duckdb-python.git -``` - -### Switch submodule to fork -```shell -git submodule set-url external/duckdb git@github.com:/duckdb.git -``` - -### Work on a feature branch -```shell -git -C external/duckdb checkout my-branch -``` - -### Jump to latest DuckDB tag -```shell -git -C external/duckdb fetch --tags && git -C external/duckdb checkout v1.3.1 # Example -``` - -### Pull latest changes -```shell -git submodule update --remote --merge external/duckdb -``` \ No newline at end of file From bc556d0f252f0a13784941bc77cae5cf6a9cbf04 Mon Sep 17 00:00:00 2001 From: Evert Date: Wed, 16 Jul 2025 14:11:27 +0200 Subject: [PATCH 006/472] Better IDE integration --- CMakeLists.txt | 18 ++++++++++++------ README.md | 37 +++++++++++++++++++++++-------------- pyproject.toml | 1 + 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c927920e..7c6c5332 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,12 +17,6 @@ setup_compiler_launcher_if_available() # ──────────────────────────────────────────── # Create compile_commands.json for IntelliSense and clang-tidy set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -# If we're not building through scikit-build-core then we have to set a different dest dir -include(GNUInstallDirs) -set(_DUCKDB_PY_INSTALL_DIR "${SKBUILD_PLATLIB_DIR}") -if(NOT _DUCKDB_PY_INSTALL_DIR) - set(_DUCKDB_PY_INSTALL_DIR "${CMAKE_INSTALL_LIBDIR}") -endif() # ──────────────────────────────────────────── # Policy hygiene @@ -87,4 +81,16 @@ target_link_libraries(_duckdb PRIVATE _duckdb_dependencies) # ──────────────────────────────────────────── # Put the object file in the correct place # ──────────────────────────────────────────── + +# If we're not building through scikit-build-core then we have to set a different dest dir +include(GNUInstallDirs) +if(DEFINED SKBUILD_PLATLIB_DIR) + set(_DUCKDB_PY_INSTALL_DIR "${SKBUILD_PLATLIB_DIR}") +elseif(DEFINED Python_SITEARCH) + set(_DUCKDB_PY_INSTALL_DIR "${Python_SITEARCH}") +else() + message(WARNING "Could not determine Python install dir. Falling back to CMAKE_INSTALL_LIBDIR.") + set(_DUCKDB_PY_INSTALL_DIR "${CMAKE_INSTALL_LIBDIR}") +endif() + install(TARGETS _duckdb LIBRARY DESTINATION "${_DUCKDB_PY_INSTALL_DIR}") diff --git a/README.md b/README.md index ca48ccd5..13e9569d 100644 --- a/README.md +++ b/README.md @@ -65,19 +65,6 @@ git config --local core.hooksPath .githooks/ ``` -### Building wheels and sdists - - To build a wheel and sdist for your system and the default Python version: -```bash -uv build -```` - - To build a wheel for a different Python version: -```bash -# E.g. for Python 3.9 -uv build -p 3.9 -``` - ### Editable installs (general) It's good to be aware of the following when creating an editable install: @@ -93,7 +80,7 @@ uv build -p 3.9 # install all dev dependencies without building the project (needed once) uv sync -p 3.9 --no-install-project # build and install without build isolation -uv sync --no-build-isolation +uv sync --no-build-isolation ``` ### Editable installs (IDEs) @@ -103,6 +90,28 @@ uv sync --no-build-isolation compilation and editable rebuilds. This will skip scikit-build-core's build backend and all of uv's dependency management, so for "real" builds you better revert to the CLI. However, this should work fine for coding and debugging. + +### Cleaning + +```shell +uv cache clean +rm -rf build .venv uv.lock +``` + + +### Building wheels and sdists + +To build a wheel and sdist for your system and the default Python version: +```bash +uv build +```` + +To build a wheel for a different Python version: +```bash +# E.g. for Python 3.9 +uv build -p 3.9 +``` + ### Running tests Run all pytests: diff --git a/pyproject.toml b/pyproject.toml index 5532f41b..d0642b3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ if.state = "editable" if.env.COVERAGE = false build-dir = "build/debug/" editable.rebuild = true +editable.mode = "redirect" cmake.build-type = "Debug" # Separate override because we have to append to cmake.define with `inherit` in order not to overwrite other defines. From 119a8b962721a95702af986a1facf2e3aff58e1a Mon Sep 17 00:00:00 2001 From: Evert Date: Wed, 16 Jul 2025 14:16:32 +0200 Subject: [PATCH 007/472] Set MAIN_BRANCH_VERSIONING to False for the bugfix branch --- duckdb_packaging/setuptools_scm_version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index f555e384..956cf38a 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -12,8 +12,8 @@ # Import from our own versioning module to avoid duplication from ._versioning import parse_version, format_version -# MAIN_BRANCH_VERSIONING default should be 'True' for main branch and feature branches -MAIN_BRANCH_VERSIONING = True +# MAIN_BRANCH_VERSIONING should be 'True' on main branch only +MAIN_BRANCH_VERSIONING = False SCM_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" SCM_GLOBAL_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION" From 7caf97f14c8d9c078d3e6e67455e003735dfa56b Mon Sep 17 00:00:00 2001 From: Evert Date: Wed, 16 Jul 2025 14:17:53 +0200 Subject: [PATCH 008/472] Use same length SHA for staging upload as duckdb itself --- .github/workflows/on_external_dispatch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 46252f0d..33a807db 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -53,5 +53,5 @@ jobs: DUCKDB_SHA="${{ inputs.duckdb-sha }}" aws s3 cp \ artifacts \ - s3://duckdb-staging/${DUCKDB_SHA:0:7}/${{ github.repository }}/ \ + s3://duckdb-staging/${DUCKDB_SHA:0:10}/${{ github.repository }}/ \ --recursive From 09002df4e54ecbbbf7dbdb09ec2c558646115757 Mon Sep 17 00:00:00 2001 From: Evert Date: Wed, 16 Jul 2025 14:23:59 +0200 Subject: [PATCH 009/472] Set MAIN_BRANCH_VERSIONING to True for main branch --- duckdb_packaging/setuptools_scm_version.py | 2 +- external/duckdb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 956cf38a..932fcd52 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -13,7 +13,7 @@ from ._versioning import parse_version, format_version # MAIN_BRANCH_VERSIONING should be 'True' on main branch only -MAIN_BRANCH_VERSIONING = False +MAIN_BRANCH_VERSIONING = True SCM_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" SCM_GLOBAL_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION" diff --git a/external/duckdb b/external/duckdb index 2a04781a..0b83e5d2 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 2a04781aa9298e0354178cfd1cddd5d77dd5eb85 +Subproject commit 0b83e5d2f68bc02dfefde74b846bd039f078affa From 4c771e94faca10a969e1f9e1c4e2ceaafe6c7c67 Mon Sep 17 00:00:00 2001 From: Evert Date: Thu, 17 Jul 2025 16:53:06 +0200 Subject: [PATCH 010/472] Fix versioning tests with explicit MAIN_BRANCH_VERSIONING values --- tests/fast/test_versioning.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 08a0e730..7a3c7a68 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -109,9 +109,9 @@ def test_bump_version_exact_tag(self): assert _bump_version("1.2.3", 0, False) == "1.2.3" assert _bump_version("1.2.3.post1", 0, False) == "1.2.3.post1" + @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) def test_bump_version_with_distance(self): """Test bump_version with distance from tag.""" - # Main branch versioning (default) assert _bump_version("1.2.3", 5, False) == "1.3.0.dev5" # Post-release development @@ -119,19 +119,16 @@ def test_bump_version_with_distance(self): @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '0'}) def test_bump_version_release_branch(self): - """Test bump_version on release branch (MAIN_BRANCH_VERSIONING=False).""" - # Need to reload the module to pick up the environment variable - import importlib - from duckdb_packaging import setuptools_scm_version - importlib.reload(setuptools_scm_version) - - assert setuptools_scm_version._bump_version("1.2.3", 5, False) == "1.2.4.dev5" + """Test bump_version on bugfix branch.""" + assert _bump_version("1.2.3", 5, False) == "1.2.4.dev5" + @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) def test_bump_version_dirty(self): """Test bump_version with dirty working directory.""" assert _bump_version("1.2.3", 0, True) == "1.3.0.dev0" assert _bump_version("1.2.3.post1", 0, True) == "1.2.3.post2.dev0" + @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) def test_version_scheme_function(self): """Test the version_scheme function that setuptools_scm calls.""" # Mock setuptools_scm version object From 7b2c189d82a7a9522f28b110684606b8501445fd Mon Sep 17 00:00:00 2001 From: Evert Date: Fri, 18 Jul 2025 10:51:27 +0200 Subject: [PATCH 011/472] Bumped submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 0b83e5d2..641e95d1 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 0b83e5d2f68bc02dfefde74b846bd039f078affa +Subproject commit 641e95d140ca7728085445a919c3e8d436aaf0c1 From 553e5ee0849acce18757c3aaa1daded779b557da Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 18 Jul 2025 16:28:08 +0200 Subject: [PATCH 012/472] Upload to PyPI --- .github/workflows/on_external_dispatch.yml | 18 ++-- .github/workflows/upload_to_pypi.yml | 103 +++++++++++++++++++++ 2 files changed, 112 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/upload_to_pypi.yml diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 33a807db..d863063e 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -1,4 +1,4 @@ -name: Builds triggered externally by DuckDB +name: External Dispatch on: workflow_dispatch: inputs: @@ -12,7 +12,7 @@ on: required: false publish_packages: type: boolean - description: Publish packages on S3 and PyPI? + description: Publish to S3 required: true default: false @@ -27,8 +27,8 @@ jobs: duckdb_git_ref: ${{ inputs.duckdb-sha }} force_version: ${{ inputs.force_version }} - upload_to_staging: - name: Upload Artifacts to staging + publish-s3: + name: Publish Artifacts to the S3 Staging Bucket runs-on: ubuntu-latest needs: [ externally_triggered_build ] if: ${{ github.repository_owner == 'duckdb' && inputs.publish_packages }} @@ -48,10 +48,10 @@ jobs: aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - name: Upload artifacts to S3 bucket + # semantics: if a version is forced then we upload into a folder by the version name, otherwise we upload + # into a folder that is named -. Only the latter will be discovered be + # upload_to_pypi.yml. shell: bash run: | - DUCKDB_SHA="${{ inputs.duckdb-sha }}" - aws s3 cp \ - artifacts \ - s3://duckdb-staging/${DUCKDB_SHA:0:10}/${{ github.repository }}/ \ - --recursive + FOLDER="${{ inputs.force_version != '' && inputs.force_version || format('{0}-{1}', github.run_id, github.run_attempt) }}" + aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${FOLDER}/ --recursive diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml new file mode 100644 index 00000000..48d0ec02 --- /dev/null +++ b/.github/workflows/upload_to_pypi.yml @@ -0,0 +1,103 @@ +name: upload_to_pypi.yml +on: + # this workflow runs after the below workflows are completed + workflow_run: + workflows: [ External Dispatch ] + types: [ completed ] + branches: + - main + - v*.*-* + workflow_dispatch: + inputs: + environment: + description: Environment to run in () + type: choice + required: true + default: test.pypi + options: + - test.pypi + - production.pypi + artifact_folder: + description: The S3 folder that contains the artifacts (s3://duckdb-staging/duckdb/duckdb-python/) + type: string + required: true + +jobs: + prepare: + name: Prepare and guard upload + if: ${{ github.repository_owner == 'duckdb' && ( github.event.workflow_run.conclusion == 'success' || github.event_name != 'workflow_run' ) }} + runs-on: ubuntu-latest + outputs: + s3_prefix: ${{ steps.get_s3_prefix.outputs.s3_prefix }} + steps: + - name: Determine S3 Prefix + id: get_s3_prefix + run: | + if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then + artifact_folder="${{ inputs.artifact_folder }}" + elif [[ -n "${{ github.event.workflow_run.id }}" && -n "${{ github.event.workflow_run.run_attempt }}" ]]; then + artifact_folder="${{ github.event.workflow_run.id }}-${{ github.event.workflow_run.run_attempt }}" + fi + if [[ -n "${artifact_folder}" ]]; then + s3_prefix="${{ github.repository }}/${artifact_folder}" + echo "Created S3 prefix: ${s3_prefix}" + echo "s3_prefix=${s3_prefix}" >> $GITHUB_OUTPUT + else + echo "Can't determine S3 prefix for event: ${{ github.event_name }}. Quitting." + exit 1 + fi + + - name: Authenticate With AWS + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: 'us-east-2' + aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} + aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + + - name: Check S3 Prefix + shell: bash + run: | + if [[ $(aws s3api list-objects-v2 \ + --bucket duckdb-staging \ + --prefix "${{ steps.get_s3_prefix.outputs.s3_prefix }}/" \ + --max-items 1 \ + --query 'Contents[0].Key' \ + --output text) == "None" ]]; then + echo "Prefix does not exist: ${{ steps.get_s3_prefix.outputs.s3_prefix }}" + echo "${{ github.event_name == 'workflow_run' && 'Possibly built a stable release?' || 'Unexpected error' }}" + exit 1 + fi + + publish-pypi: + name: Publish Artifacts to PyPI + needs: [ prepare ] + runs-on: ubuntu-latest + environment: + name: ${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }} + if: ${{ vars.PYPI_URL != '' }} + permissions: + # this is needed for the OIDC flow that is used with trusted publishing on PyPI + id-token: write + steps: + - name: Authenticate With AWS + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: 'us-east-2' + aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} + aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + + - name: Download Artifacts From S3 + env: + S3_URL: 's3://duckdb-staging/${{ needs.prepare.outputs.s3_prefix }}/' + AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + run: | + mkdir packages + aws s3 cp --recursive "${S3_URL}" packages + + - name: Upload artifacts to PyPI + if: ${{ vars.PYPI_URL != '' }} + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: ${{ vars.PYPI_URL }} + packages-dir: packages From 557f5a2cdb0a9e6c52073cd3f5c647332587c609 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 18 Jul 2025 22:30:03 +0200 Subject: [PATCH 013/472] Debugging upload --- .github/workflows/upload_to_pypi.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index 48d0ec02..e239b158 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -74,7 +74,7 @@ jobs: runs-on: ubuntu-latest environment: name: ${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }} - if: ${{ vars.PYPI_URL != '' }} + #if: ${{ vars.PYPI_URL != '' }} permissions: # this is needed for the OIDC flow that is used with trusted publishing on PyPI id-token: write From b60c1977d3fbfc2c949502b9f99d322047d07e00 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 18 Jul 2025 22:46:19 +0200 Subject: [PATCH 014/472] Small pypi upload fixes --- .github/workflows/upload_to_pypi.yml | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index e239b158..2c933436 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -22,6 +22,8 @@ on: type: string required: true +concurrency: ${{ inputs.artifact_folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }} + jobs: prepare: name: Prepare and guard upload @@ -33,11 +35,7 @@ jobs: - name: Determine S3 Prefix id: get_s3_prefix run: | - if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then - artifact_folder="${{ inputs.artifact_folder }}" - elif [[ -n "${{ github.event.workflow_run.id }}" && -n "${{ github.event.workflow_run.run_attempt }}" ]]; then - artifact_folder="${{ github.event.workflow_run.id }}-${{ github.event.workflow_run.run_attempt }}" - fi + artifact_folder="${{ inputs.artifact_folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }}" if [[ -n "${artifact_folder}" ]]; then s3_prefix="${{ github.repository }}/${artifact_folder}" echo "Created S3 prefix: ${s3_prefix}" @@ -74,11 +72,18 @@ jobs: runs-on: ubuntu-latest environment: name: ${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }} - #if: ${{ vars.PYPI_URL != '' }} permissions: # this is needed for the OIDC flow that is used with trusted publishing on PyPI id-token: write steps: + - name: Fail if PYPI_URL is not set + if: ${{ vars.PYPI_URL == '' }} + shell: bash + run: | + env_name="${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }}" + echo "Error: vars.PYPI_URL is not set in the resolved environment (${env_name})" + exit 1 + - name: Authenticate With AWS uses: aws-actions/configure-aws-credentials@v4 with: @@ -96,7 +101,6 @@ jobs: aws s3 cp --recursive "${S3_URL}" packages - name: Upload artifacts to PyPI - if: ${{ vars.PYPI_URL != '' }} uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: ${{ vars.PYPI_URL }} From 41171a9f2613d1df3ca45109c1e50457dff43093 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 18 Jul 2025 16:28:08 +0200 Subject: [PATCH 015/472] Upload to PyPI --- .github/workflows/on_external_dispatch.yml | 18 ++-- .github/workflows/upload_to_pypi.yml | 103 +++++++++++++++++++++ 2 files changed, 112 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/upload_to_pypi.yml diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 33a807db..d863063e 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -1,4 +1,4 @@ -name: Builds triggered externally by DuckDB +name: External Dispatch on: workflow_dispatch: inputs: @@ -12,7 +12,7 @@ on: required: false publish_packages: type: boolean - description: Publish packages on S3 and PyPI? + description: Publish to S3 required: true default: false @@ -27,8 +27,8 @@ jobs: duckdb_git_ref: ${{ inputs.duckdb-sha }} force_version: ${{ inputs.force_version }} - upload_to_staging: - name: Upload Artifacts to staging + publish-s3: + name: Publish Artifacts to the S3 Staging Bucket runs-on: ubuntu-latest needs: [ externally_triggered_build ] if: ${{ github.repository_owner == 'duckdb' && inputs.publish_packages }} @@ -48,10 +48,10 @@ jobs: aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - name: Upload artifacts to S3 bucket + # semantics: if a version is forced then we upload into a folder by the version name, otherwise we upload + # into a folder that is named -. Only the latter will be discovered be + # upload_to_pypi.yml. shell: bash run: | - DUCKDB_SHA="${{ inputs.duckdb-sha }}" - aws s3 cp \ - artifacts \ - s3://duckdb-staging/${DUCKDB_SHA:0:10}/${{ github.repository }}/ \ - --recursive + FOLDER="${{ inputs.force_version != '' && inputs.force_version || format('{0}-{1}', github.run_id, github.run_attempt) }}" + aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${FOLDER}/ --recursive diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml new file mode 100644 index 00000000..48d0ec02 --- /dev/null +++ b/.github/workflows/upload_to_pypi.yml @@ -0,0 +1,103 @@ +name: upload_to_pypi.yml +on: + # this workflow runs after the below workflows are completed + workflow_run: + workflows: [ External Dispatch ] + types: [ completed ] + branches: + - main + - v*.*-* + workflow_dispatch: + inputs: + environment: + description: Environment to run in () + type: choice + required: true + default: test.pypi + options: + - test.pypi + - production.pypi + artifact_folder: + description: The S3 folder that contains the artifacts (s3://duckdb-staging/duckdb/duckdb-python/) + type: string + required: true + +jobs: + prepare: + name: Prepare and guard upload + if: ${{ github.repository_owner == 'duckdb' && ( github.event.workflow_run.conclusion == 'success' || github.event_name != 'workflow_run' ) }} + runs-on: ubuntu-latest + outputs: + s3_prefix: ${{ steps.get_s3_prefix.outputs.s3_prefix }} + steps: + - name: Determine S3 Prefix + id: get_s3_prefix + run: | + if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then + artifact_folder="${{ inputs.artifact_folder }}" + elif [[ -n "${{ github.event.workflow_run.id }}" && -n "${{ github.event.workflow_run.run_attempt }}" ]]; then + artifact_folder="${{ github.event.workflow_run.id }}-${{ github.event.workflow_run.run_attempt }}" + fi + if [[ -n "${artifact_folder}" ]]; then + s3_prefix="${{ github.repository }}/${artifact_folder}" + echo "Created S3 prefix: ${s3_prefix}" + echo "s3_prefix=${s3_prefix}" >> $GITHUB_OUTPUT + else + echo "Can't determine S3 prefix for event: ${{ github.event_name }}. Quitting." + exit 1 + fi + + - name: Authenticate With AWS + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: 'us-east-2' + aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} + aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + + - name: Check S3 Prefix + shell: bash + run: | + if [[ $(aws s3api list-objects-v2 \ + --bucket duckdb-staging \ + --prefix "${{ steps.get_s3_prefix.outputs.s3_prefix }}/" \ + --max-items 1 \ + --query 'Contents[0].Key' \ + --output text) == "None" ]]; then + echo "Prefix does not exist: ${{ steps.get_s3_prefix.outputs.s3_prefix }}" + echo "${{ github.event_name == 'workflow_run' && 'Possibly built a stable release?' || 'Unexpected error' }}" + exit 1 + fi + + publish-pypi: + name: Publish Artifacts to PyPI + needs: [ prepare ] + runs-on: ubuntu-latest + environment: + name: ${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }} + if: ${{ vars.PYPI_URL != '' }} + permissions: + # this is needed for the OIDC flow that is used with trusted publishing on PyPI + id-token: write + steps: + - name: Authenticate With AWS + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: 'us-east-2' + aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} + aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + + - name: Download Artifacts From S3 + env: + S3_URL: 's3://duckdb-staging/${{ needs.prepare.outputs.s3_prefix }}/' + AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + run: | + mkdir packages + aws s3 cp --recursive "${S3_URL}" packages + + - name: Upload artifacts to PyPI + if: ${{ vars.PYPI_URL != '' }} + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: ${{ vars.PYPI_URL }} + packages-dir: packages From 7d82f3d8ce2d9dd2476de6ebf0853c64a3ad6cf1 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 18 Jul 2025 22:30:03 +0200 Subject: [PATCH 016/472] Debugging upload --- .github/workflows/upload_to_pypi.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index 48d0ec02..e239b158 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -74,7 +74,7 @@ jobs: runs-on: ubuntu-latest environment: name: ${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }} - if: ${{ vars.PYPI_URL != '' }} + #if: ${{ vars.PYPI_URL != '' }} permissions: # this is needed for the OIDC flow that is used with trusted publishing on PyPI id-token: write From 04a3dbdf3c953e6b649abbf10fc2a1ddc5925a4b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 18 Jul 2025 22:46:19 +0200 Subject: [PATCH 017/472] Small pypi upload fixes --- .github/workflows/upload_to_pypi.yml | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index e239b158..2c933436 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -22,6 +22,8 @@ on: type: string required: true +concurrency: ${{ inputs.artifact_folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }} + jobs: prepare: name: Prepare and guard upload @@ -33,11 +35,7 @@ jobs: - name: Determine S3 Prefix id: get_s3_prefix run: | - if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then - artifact_folder="${{ inputs.artifact_folder }}" - elif [[ -n "${{ github.event.workflow_run.id }}" && -n "${{ github.event.workflow_run.run_attempt }}" ]]; then - artifact_folder="${{ github.event.workflow_run.id }}-${{ github.event.workflow_run.run_attempt }}" - fi + artifact_folder="${{ inputs.artifact_folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }}" if [[ -n "${artifact_folder}" ]]; then s3_prefix="${{ github.repository }}/${artifact_folder}" echo "Created S3 prefix: ${s3_prefix}" @@ -74,11 +72,18 @@ jobs: runs-on: ubuntu-latest environment: name: ${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }} - #if: ${{ vars.PYPI_URL != '' }} permissions: # this is needed for the OIDC flow that is used with trusted publishing on PyPI id-token: write steps: + - name: Fail if PYPI_URL is not set + if: ${{ vars.PYPI_URL == '' }} + shell: bash + run: | + env_name="${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }}" + echo "Error: vars.PYPI_URL is not set in the resolved environment (${env_name})" + exit 1 + - name: Authenticate With AWS uses: aws-actions/configure-aws-credentials@v4 with: @@ -96,7 +101,6 @@ jobs: aws s3 cp --recursive "${S3_URL}" packages - name: Upload artifacts to PyPI - if: ${{ vars.PYPI_URL != '' }} uses: pypa/gh-action-pypi-publish@release/v1 with: repository-url: ${{ vars.PYPI_URL }} From a65f44347732e9903f994ff22e2448260a422090 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 10:50:33 +0200 Subject: [PATCH 018/472] Merge changes from "New Arrow C-API #18246" --- external/duckdb | 2 +- src/duckdb_py/arrow/arrow_array_stream.cpp | 8 +- .../arrow/arrow_array_stream.hpp | 2 +- tests/fast/adbc/test_adbc.py | 90 ++++++++++++++++++- 4 files changed, 94 insertions(+), 8 deletions(-) diff --git a/external/duckdb b/external/duckdb index 641e95d1..61f07c32 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 641e95d140ca7728085445a919c3e8d436aaf0c1 +Subproject commit 61f07c3221e01674d8fe40b4a25364ba2f3159a7 diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index 6094dcb1..533c31ed 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -41,10 +41,8 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(DBConfig &config, D_ASSERT(!py::isinstance(arrow_obj_handle)); ArrowSchemaWrapper schema; PythonTableArrowArrayStreamFactory::GetSchemaInternal(arrow_obj_handle, schema); - vector unused_names; - vector unused_types; - ArrowTableType arrow_table; - ArrowTableFunction::PopulateArrowTableType(config, arrow_table, schema, unused_names, unused_types); + ArrowTableSchema arrow_table; + ArrowTableFunction::PopulateArrowTableSchema(config, arrow_table, schema.arrow_schema); auto filters = parameters.filters; auto &column_list = parameters.projected_columns.columns; @@ -466,7 +464,7 @@ py::object PythonTableArrowArrayStreamFactory::TransformFilter(TableFilterSet &f std::unordered_map &columns, unordered_map filter_to_col, const ClientProperties &config, - const ArrowTableType &arrow_table) { + const ArrowTableSchema &arrow_table) { auto &filters_map = filter_collection.filters; py::object expression = py::none(); diff --git a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp index 494be16a..7eb6d20b 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp @@ -89,7 +89,7 @@ class PythonTableArrowArrayStreamFactory { //! We transform a TableFilterSet to an Arrow Expression Object static py::object TransformFilter(TableFilterSet &filters, std::unordered_map &columns, unordered_map filter_to_col, - const ClientProperties &client_properties, const ArrowTableType &arrow_table); + const ClientProperties &client_properties, const ArrowTableSchema &arrow_table); static py::object ProduceScanner(DBConfig &config, py::object &arrow_scanner, py::handle &arrow_obj_handle, ArrowStreamParameters ¶meters, const ClientProperties &client_properties); diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py index 3f9111bc..663563cf 100644 --- a/tests/fast/adbc/test_adbc.py +++ b/tests/fast/adbc/test_adbc.py @@ -3,6 +3,7 @@ import sys import datetime import os +import numpy as np if sys.version_info < (3, 9): pytest.skip( @@ -224,7 +225,7 @@ def test_insertion(duck_conn): with duck_conn.cursor() as cursor: with pytest.raises( adbc_driver_manager_lib.InternalError, - match=r'Failed to create table \'ingest_table\': Table with name "ingest_table" already exists!', + match=r'Table with name "ingest_table" already exists!', ): cursor.adbc_ingest("ingest_table", table, "create") cursor.adbc_ingest("ingest_table", table, "append") @@ -277,6 +278,93 @@ def test_read(duck_conn): } +def test_large_chunk(tmp_path): + num_chunks = 3 + chunk_size = 10_000 + + # Create data for each chunk + chunks_col1 = [pyarrow.array(np.random.randint(0, 100, chunk_size)) for _ in range(num_chunks)] + chunks_col2 = [pyarrow.array(np.random.rand(chunk_size)) for _ in range(num_chunks)] + chunks_col3 = [ + pyarrow.array([f"str_{i}" for i in range(j * chunk_size, (j + 1) * chunk_size)]) for j in range(num_chunks) + ] + + # Create chunked arrays + col1 = pyarrow.chunked_array(chunks_col1) + col2 = pyarrow.chunked_array(chunks_col2) + col3 = pyarrow.chunked_array(chunks_col3) + + # Create the table + table = pyarrow.table([col1, col2, col3], names=["ints", "floats", "strings"]) + + db = os.path.join(tmp_path, "tmp.db") + if os.path.exists(db): + os.remove(db) + db_kwargs = {"path": f"{db}"} + with adbc_driver_manager.connect( + driver=driver_path, + entrypoint="duckdb_adbc_init", + db_kwargs=db_kwargs, + autocommit=True, + ) as conn: + with conn.cursor() as cur: + cur.adbc_ingest("ingest", table, "create") + cur.execute("SELECT count(*) from ingest") + assert cur.fetch_arrow_table().to_pydict() == {'count_star()': [30_000]} + + +def test_dictionary_data(tmp_path): + data = ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] + + dict_type = pyarrow.dictionary(index_type=pyarrow.int32(), value_type=pyarrow.string()) + dict_array = pyarrow.array(data, type=dict_type) + + # Wrap in a table + table = pyarrow.table({'fruits': dict_array}) + db = os.path.join(tmp_path, "tmp.db") + if os.path.exists(db): + os.remove(db) + db_kwargs = {"path": f"{db}"} + with adbc_driver_manager.connect( + driver=driver_path, + entrypoint="duckdb_adbc_init", + db_kwargs=db_kwargs, + autocommit=True, + ) as conn: + with conn.cursor() as cur: + cur.adbc_ingest("ingest", table, "create") + cur.execute("from ingest") + assert cur.fetch_arrow_table().to_pydict() == { + 'fruits': ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] + } + + +def test_ree_data(tmp_path): + run_ends = pyarrow.array([3, 5, 6], type=pyarrow.int32()) # positions: [0-2], [3-4], [5] + values = pyarrow.array(["apple", "banana", "orange"], type=pyarrow.string()) + + ree_array = pyarrow.RunEndEncodedArray.from_arrays(run_ends, values) + + table = pyarrow.table({"fruits": ree_array}) + + db = os.path.join(tmp_path, "tmp.db") + if os.path.exists(db): + os.remove(db) + db_kwargs = {"path": f"{db}"} + with adbc_driver_manager.connect( + driver=driver_path, + entrypoint="duckdb_adbc_init", + db_kwargs=db_kwargs, + autocommit=True, + ) as conn: + with conn.cursor() as cur: + cur.adbc_ingest("ingest", table, "create") + cur.execute("from ingest") + assert cur.fetch_arrow_table().to_pydict() == { + 'fruits': ['apple', 'apple', 'apple', 'banana', 'banana', 'orange'] + } + + def sorted_get_objects(catalogs): res = [] for catalog in sorted(catalogs, key=lambda cat: cat['catalog_name']): From a57ca8d4ad4e8863b3e47758e43abad51b8dd597 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 11:03:12 +0200 Subject: [PATCH 019/472] Bump submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 0b83e5d2..c1c3d888 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 0b83e5d2f68bc02dfefde74b846bd039f078affa +Subproject commit c1c3d88864ff367936b7c667e0ac071bceed5215 From 5a06c58daac021a13bf9aedcddcacf8643969db1 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 11:25:38 +0200 Subject: [PATCH 020/472] Push given duckdb sha as submodule ref --- .github/workflows/on_external_dispatch.yml | 41 +++++++++++++++++-- .github/workflows/on_pr.yml | 2 +- .github/workflows/on_push_postrelease.yml | 2 +- .../{pypi_packaging.yml => packaging.yml} | 0 4 files changed, 40 insertions(+), 5 deletions(-) rename .github/workflows/{pypi_packaging.yml => packaging.yml} (100%) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index d863063e..de751cf4 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -6,6 +6,10 @@ on: type: string description: The DuckDB SHA to build against required: true + commit_duckdb_sha: + type: boolean + description: Set (commit) the DuckDB submodule to the given SHA + default: false force_version: type: string description: Force version (vX.Y.Z-((rc|post)N)) @@ -14,12 +18,44 @@ on: type: boolean description: Publish to S3 required: true - default: false + default: true + +defaults: + run: + shell: bash jobs: + commit_duckdb_submodule_sha: + name: Commit the submodule to the given DuckDB sha + if: ${{ inputs.commit_duckdb_sha }} + runs-on: ubuntu-24.04 + steps: + + - name: Checkout DuckDB Python + uses: actions/checkout@v4 + with: + ref: ${{ github.ref }} + fetch-depth: 0 + submodules: true + + - name: Checkout DuckDB + run: | + cd external/duckdb + git fetch origin + git checkout ${{ inputs.duckdb_git_ref }} + + - name: Commit and push new submodule ref + # see https://github.com/actions/checkout?tab=readme-ov-file#push-a-commit-to-a-pr-using-the-built-in-token + run: | + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add external/duckdb + git commit -m "Update submodule ref" + git push + externally_triggered_build: name: Build and test releases - uses: ./.github/workflows/pypi_packaging.yml + uses: ./.github/workflows/packaging.yml with: minimal: false testsuite: all @@ -51,7 +87,6 @@ jobs: # semantics: if a version is forced then we upload into a folder by the version name, otherwise we upload # into a folder that is named -. Only the latter will be discovered be # upload_to_pypi.yml. - shell: bash run: | FOLDER="${{ inputs.force_version != '' && inputs.force_version || format('{0}-{1}', github.run_id, github.run_attempt) }}" aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${FOLDER}/ --recursive diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 81fd3247..fd7d0df6 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -24,7 +24,7 @@ jobs: name: Build a minimal set of packages and run all tests on them # Skip packaging tests for draft PRs if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }} - uses: ./.github/workflows/pypi_packaging.yml + uses: ./.github/workflows/packaging.yml with: minimal: true testsuite: all diff --git a/.github/workflows/on_push_postrelease.yml b/.github/workflows/on_push_postrelease.yml index d12ccdf8..51dedaa6 100644 --- a/.github/workflows/on_push_postrelease.yml +++ b/.github/workflows/on_push_postrelease.yml @@ -33,7 +33,7 @@ jobs: packaging_test: name: Build and test post release packages and upload to S3 needs: extract_duckdb_tag - uses: ./.github/workflows/pypi_packaging.yml + uses: ./.github/workflows/packaging.yml with: minimal: false testsuite: all diff --git a/.github/workflows/pypi_packaging.yml b/.github/workflows/packaging.yml similarity index 100% rename from .github/workflows/pypi_packaging.yml rename to .github/workflows/packaging.yml From 85e7cfd64a88a96e6cba7df906c821138742e0b5 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 12:51:28 +0200 Subject: [PATCH 021/472] Workflow naming and defaults fixes --- .github/workflows/on_external_dispatch.yml | 24 ++++++++--------- .github/workflows/on_pr.yml | 2 +- .github/workflows/on_push_postrelease.yml | 6 ++--- .github/workflows/packaging.yml | 30 +++++++++++----------- .github/workflows/upload_to_pypi.yml | 8 +++--- 5 files changed, 35 insertions(+), 35 deletions(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index de751cf4..5ab7f8a7 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -6,19 +6,19 @@ on: type: string description: The DuckDB SHA to build against required: true - commit_duckdb_sha: + commit-duckdb-sha: type: boolean description: Set (commit) the DuckDB submodule to the given SHA - default: false - force_version: + default: true + force-version: type: string description: Force version (vX.Y.Z-((rc|post)N)) required: false - publish_packages: + publish-packages: type: boolean description: Publish to S3 required: true - default: true + default: false defaults: run: @@ -27,7 +27,7 @@ defaults: jobs: commit_duckdb_submodule_sha: name: Commit the submodule to the given DuckDB sha - if: ${{ inputs.commit_duckdb_sha }} + if: ${{ inputs.commit-duckdb-sha }} runs-on: ubuntu-24.04 steps: @@ -59,15 +59,15 @@ jobs: with: minimal: false testsuite: all - git_ref: ${{ github.ref }} - duckdb_git_ref: ${{ inputs.duckdb-sha }} - force_version: ${{ inputs.force_version }} + git-ref: ${{ github.ref }} + duckdb-git-ref: ${{ inputs.duckdb-sha }} + force-version: ${{ inputs.force-version }} - publish-s3: + publish_s3: name: Publish Artifacts to the S3 Staging Bucket runs-on: ubuntu-latest needs: [ externally_triggered_build ] - if: ${{ github.repository_owner == 'duckdb' && inputs.publish_packages }} + if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages }} steps: - name: Fetch artifacts uses: actions/download-artifact@v4 @@ -88,5 +88,5 @@ jobs: # into a folder that is named -. Only the latter will be discovered be # upload_to_pypi.yml. run: | - FOLDER="${{ inputs.force_version != '' && inputs.force_version || format('{0}-{1}', github.run_id, github.run_attempt) }}" + FOLDER="${{ inputs.force-version != '' && inputs.force-version || format('{0}-{1}', github.run_id, github.run_attempt) }}" aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${FOLDER}/ --recursive diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index fd7d0df6..d8bfe825 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -28,7 +28,7 @@ jobs: with: minimal: true testsuite: all - duckdb_git_ref: ${{ github.base_ref }} + duckdb-git-ref: ${{ github.base_ref }} coverage_test: name: Run coverage tests diff --git a/.github/workflows/on_push_postrelease.yml b/.github/workflows/on_push_postrelease.yml index 51dedaa6..14575754 100644 --- a/.github/workflows/on_push_postrelease.yml +++ b/.github/workflows/on_push_postrelease.yml @@ -37,6 +37,6 @@ jobs: with: minimal: false testsuite: all - git_ref: ${{ github.ref }} - duckdb_git_ref: ${{ needs.extract_duckdb_tag.outputs.duckdb_version }} - force_version: ${{ github.ref_name }} + git-ref: ${{ github.ref }} + duckdb-git-ref: ${{ needs.extract_duckdb_tag.outputs.duckdb_version }} + force-version: ${{ github.ref_name }} diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 84016645..977bc914 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -1,5 +1,5 @@ name: Packaging -run-name: Build ${{ inputs.minimal && 'minimal set of' || 'all' }} packages (version=${{ inputs.force_version != '' && inputs.force_version || 'dev' }}, tests=${{ inputs.testsuite }}, ref=${{ inputs.git_ref }}, duckdb ref=${{ inputs.duckdb_git_ref }}) +run-name: Build ${{ inputs.minimal && 'minimal set of' || 'all' }} packages (version=${{ inputs.force-version != '' && inputs.force-version || 'dev' }}, tests=${{ inputs.testsuite }}, ref=${{ inputs.git-ref }}, duckdb ref=${{ inputs.duckdb-git-ref }}) on: workflow_dispatch: inputs: @@ -16,16 +16,16 @@ on: - none - fast - all - git_ref: + git-ref: type: string description: Git ref of the DuckDB python package required: false - duckdb_git_ref: + duckdb-git-ref: type: string description: Git ref of DuckDB required: true default: refs/heads/main - force_version: + force-version: type: string description: Force version (vX.Y.Z-((rc|post)N)) required: false @@ -40,15 +40,15 @@ on: description: Testsuite to run (none, fast, all) required: true default: all - git_ref: + git-ref: type: string description: Git ref of the DuckDB python package required: false - duckdb_git_ref: + duckdb-git-ref: type: string description: Git ref of DuckDB required: false - force_version: + force-version: description: Force version (vX.Y.Z-((rc|post)N)) required: false type: string @@ -70,7 +70,7 @@ jobs: - name: Checkout DuckDB Python uses: actions/checkout@v4 with: - ref: ${{ inputs.git_ref }} + ref: ${{ inputs.git-ref }} fetch-depth: 0 submodules: true @@ -79,11 +79,11 @@ jobs: run: | cd external/duckdb git fetch origin - git checkout ${{ inputs.duckdb_git_ref }} + git checkout ${{ inputs.duckdb-git-ref }} - name: Set OVERRIDE_GIT_DESCRIBE - if: ${{ inputs.force_version != '' }} - run: echo "OVERRIDE_GIT_DESCRIBE=${{ inputs.force_version }}" >> $GITHUB_ENV + if: ${{ inputs.force-version != '' }} + run: echo "OVERRIDE_GIT_DESCRIBE=${{ inputs.force-version }}" >> $GITHUB_ENV - name: Install Astral UV uses: astral-sh/setup-uv@v6 @@ -153,7 +153,7 @@ jobs: - name: Checkout DuckDB Python uses: actions/checkout@v4 with: - ref: ${{ inputs.git_ref }} + ref: ${{ inputs.git-ref }} fetch-depth: 0 submodules: true @@ -162,12 +162,12 @@ jobs: run: | cd external/duckdb git fetch origin - git checkout ${{ inputs.duckdb_git_ref }} + git checkout ${{ inputs.duckdb-git-ref }} # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds - name: Set OVERRIDE_GIT_DESCRIBE - if: ${{ inputs.force_version != '' }} - run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.force_version }}" >> $GITHUB_ENV + if: ${{ inputs.force-version != '' }} + run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.force-version }}" >> $GITHUB_ENV # Install Astral UV, which will be used as build-frontend for cibuildwheel - uses: astral-sh/setup-uv@v6 diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index 2c933436..bce4c315 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -17,12 +17,12 @@ on: options: - test.pypi - production.pypi - artifact_folder: - description: The S3 folder that contains the artifacts (s3://duckdb-staging/duckdb/duckdb-python/) + artifact-folder: + description: The S3 folder that contains the artifacts (s3://duckdb-staging/duckdb/duckdb-python/) type: string required: true -concurrency: ${{ inputs.artifact_folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }} +concurrency: ${{ inputs.artifact-folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }} jobs: prepare: @@ -35,7 +35,7 @@ jobs: - name: Determine S3 Prefix id: get_s3_prefix run: | - artifact_folder="${{ inputs.artifact_folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }}" + artifact_folder="${{ inputs.artifact-folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }}" if [[ -n "${artifact_folder}" ]]; then s3_prefix="${{ github.repository }}/${artifact_folder}" echo "Created S3 prefix: ${s3_prefix}" From 046769e77445d7cc5d8d4bcd4c9929c19e0a93b2 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 14:13:48 +0200 Subject: [PATCH 022/472] Make sure win and osx flags are set correctly --- CMakeLists.txt | 2 ++ cmake/duckdb_loader.cmake | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c6c5332..a9bc047d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,8 @@ project(duckdb_py LANGUAGES CXX) # Always use C++11 set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) # Set the library name set(DUCKDB_PYTHON_LIB_NAME "_duckdb") diff --git a/cmake/duckdb_loader.cmake b/cmake/duckdb_loader.cmake index 78ad7784..5309924b 100644 --- a/cmake/duckdb_loader.cmake +++ b/cmake/duckdb_loader.cmake @@ -147,6 +147,23 @@ function(_duckdb_create_interface_target target_name) $<$:DUCKDB_DEBUG_MODE> ) + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + target_compile_options(${target_name} INTERFACE + /wd4244 # suppress Conversion from 'type1' to 'type2', possible loss of data + /wd4267 # suppress Conversion from ‘size_t’ to ‘type’, possible loss of data + /wd4200 # suppress Nonstandard extension used: zero-sized array in struct/union + /wd26451 /wd26495 # suppress Code Analysis + /D_CRT_SECURE_NO_WARNINGS # suppress warnings about unsafe functions + /D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR # see https://github.com/duckdblabs/duckdb-internal/issues/5151 + /utf-8 # treat source files as UTF-8 encoded + ) + elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + target_compile_options(${target_name} INTERFACE + -stdlib=libc++ # for libc++ in favor of older libstdc++ + -mmacosx-version-min=10.7 # minimum osx version compatibility + ) + endif() + # Link to the DuckDB static library target_link_libraries(${target_name} INTERFACE duckdb_static) From 3165e3e14bbaf87b3d7582d4df43fc09b7076f80 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 14:13:48 +0200 Subject: [PATCH 023/472] Make sure win and osx flags are set correctly --- CMakeLists.txt | 2 ++ cmake/duckdb_loader.cmake | 17 +++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7c6c5332..a9bc047d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,8 @@ project(duckdb_py LANGUAGES CXX) # Always use C++11 set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) # Set the library name set(DUCKDB_PYTHON_LIB_NAME "_duckdb") diff --git a/cmake/duckdb_loader.cmake b/cmake/duckdb_loader.cmake index 78ad7784..5309924b 100644 --- a/cmake/duckdb_loader.cmake +++ b/cmake/duckdb_loader.cmake @@ -147,6 +147,23 @@ function(_duckdb_create_interface_target target_name) $<$:DUCKDB_DEBUG_MODE> ) + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + target_compile_options(${target_name} INTERFACE + /wd4244 # suppress Conversion from 'type1' to 'type2', possible loss of data + /wd4267 # suppress Conversion from ‘size_t’ to ‘type’, possible loss of data + /wd4200 # suppress Nonstandard extension used: zero-sized array in struct/union + /wd26451 /wd26495 # suppress Code Analysis + /D_CRT_SECURE_NO_WARNINGS # suppress warnings about unsafe functions + /D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR # see https://github.com/duckdblabs/duckdb-internal/issues/5151 + /utf-8 # treat source files as UTF-8 encoded + ) + elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + target_compile_options(${target_name} INTERFACE + -stdlib=libc++ # for libc++ in favor of older libstdc++ + -mmacosx-version-min=10.7 # minimum osx version compatibility + ) + endif() + # Link to the DuckDB static library target_link_libraries(${target_name} INTERFACE duckdb_static) From b16c9cee52e3f149411eb4276103f18719ebd7f4 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 14:36:01 +0200 Subject: [PATCH 024/472] Dont fail if submodule already at given ref --- .github/workflows/on_external_dispatch.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 5ab7f8a7..2ecccf2c 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -50,6 +50,10 @@ jobs: git config user.name "github-actions[bot]" git config user.email "41898282+github-actions[bot]@users.noreply.github.com" git add external/duckdb + if git diff --cached --quiet; then + echo "No changes to commit: submodule ref is unchanged." + exit 0 + fi git commit -m "Update submodule ref" git push From 09ce14a86b6f799b5a120ffcadd89ad265aec6ab Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 14:36:01 +0200 Subject: [PATCH 025/472] Dont fail if submodule already at given ref --- .github/workflows/on_external_dispatch.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 5ab7f8a7..2ecccf2c 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -50,6 +50,10 @@ jobs: git config user.name "github-actions[bot]" git config user.email "41898282+github-actions[bot]@users.noreply.github.com" git add external/duckdb + if git diff --cached --quiet; then + echo "No changes to commit: submodule ref is unchanged." + exit 0 + fi git commit -m "Update submodule ref" git push From 7383b64fd8e3c604f51aba3a26caa610eb6f2619 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 17:01:54 +0200 Subject: [PATCH 026/472] Fix submodule commit --- .github/workflows/on_external_dispatch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 2ecccf2c..c1b52765 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -42,7 +42,7 @@ jobs: run: | cd external/duckdb git fetch origin - git checkout ${{ inputs.duckdb_git_ref }} + git checkout ${{ inputs.duckdb-sha }} - name: Commit and push new submodule ref # see https://github.com/actions/checkout?tab=readme-ov-file#push-a-commit-to-a-pr-using-the-built-in-token From ca66232b8f00b076eade26500207f4b4838d7161 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 17:04:42 +0200 Subject: [PATCH 027/472] Workflow input names and defaults --- .github/workflows/on_external_dispatch.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index c1b52765..f8587b65 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -4,12 +4,12 @@ on: inputs: duckdb-sha: type: string - description: The DuckDB SHA to build against + description: The DuckDB submodule commit to build against required: true commit-duckdb-sha: type: boolean - description: Set (commit) the DuckDB submodule to the given SHA - default: true + description: Commit and push the DuckDB submodule ref + default: false force-version: type: string description: Force version (vX.Y.Z-((rc|post)N)) From cd76006ed5ae6a6290fd0666a7a91ac53208fa7a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 21 Jul 2025 17:46:03 +0200 Subject: [PATCH 028/472] Disable pushing submodule ref --- .github/workflows/on_external_dispatch.yml | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index f8587b65..7a65e7cf 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -27,7 +27,9 @@ defaults: jobs: commit_duckdb_submodule_sha: name: Commit the submodule to the given DuckDB sha - if: ${{ inputs.commit-duckdb-sha }} + if: ${{ inputs.commit-duckdb-sha && false }} + outputs: + sha-after-commit: ${{ steps.commit_and_push.outputs.commit_sha }} runs-on: ubuntu-24.04 steps: @@ -52,18 +54,21 @@ jobs: git add external/duckdb if git diff --cached --quiet; then echo "No changes to commit: submodule ref is unchanged." - exit 0 + else + git commit -m "Update submodule ref" + git push fi - git commit -m "Update submodule ref" - git push + echo 'commit_sha=$( git log -1 --format="%h" )' >> $GITHUB_OUTPUT externally_triggered_build: name: Build and test releases + needs: commit_duckdb_submodule_sha + if: always() uses: ./.github/workflows/packaging.yml with: minimal: false testsuite: all - git-ref: ${{ github.ref }} + git-ref: ${{ needs.commit_duckdb_submodule_sha.outputs.sha-after-commit || github.ref }} duckdb-git-ref: ${{ inputs.duckdb-sha }} force-version: ${{ inputs.force-version }} From 836453d08a5d1345552e5e895a4ca67910fe0d07 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 22 Jul 2025 09:36:34 +0200 Subject: [PATCH 029/472] Make sure S3 job runs --- .github/workflows/on_external_dispatch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 7a65e7cf..3226c5b0 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -76,7 +76,7 @@ jobs: name: Publish Artifacts to the S3 Staging Bucket runs-on: ubuntu-latest needs: [ externally_triggered_build ] - if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages }} + if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages && needs.externally_triggered_build.result == 'success' && always() }} steps: - name: Fetch artifacts uses: actions/download-artifact@v4 From 911bdaf578f14ac79ec8277c942f40953a783813 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 22 Jul 2025 16:34:45 +0200 Subject: [PATCH 030/472] Fix problems with external dispatch workflow --- .github/workflows/on_external_dispatch.yml | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 3226c5b0..947a7780 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -25,14 +25,14 @@ defaults: shell: bash jobs: - commit_duckdb_submodule_sha: + commit_submodule: name: Commit the submodule to the given DuckDB sha - if: ${{ inputs.commit-duckdb-sha && false }} outputs: - sha-after-commit: ${{ steps.commit_and_push.outputs.commit_sha }} + sha-after-commit: ${{ steps.git_commit_sha.outputs.commit_sha }} runs-on: ubuntu-24.04 + permissions: + contents: write steps: - - name: Checkout DuckDB Python uses: actions/checkout@v4 with: @@ -47,6 +47,7 @@ jobs: git checkout ${{ inputs.duckdb-sha }} - name: Commit and push new submodule ref + if: ${{ inputs.commit-duckdb-sha }} # see https://github.com/actions/checkout?tab=readme-ov-file#push-a-commit-to-a-pr-using-the-built-in-token run: | git config user.name "github-actions[bot]" @@ -58,17 +59,20 @@ jobs: git commit -m "Update submodule ref" git push fi - echo 'commit_sha=$( git log -1 --format="%h" )' >> $GITHUB_OUTPUT + + - name: Get the SHA of the latest commit + id: git_commit_sha + run: | + echo "commit_sha=$( git rev-parse HEAD )" >> $GITHUB_OUTPUT externally_triggered_build: name: Build and test releases - needs: commit_duckdb_submodule_sha - if: always() + needs: commit_submodule uses: ./.github/workflows/packaging.yml with: minimal: false testsuite: all - git-ref: ${{ needs.commit_duckdb_submodule_sha.outputs.sha-after-commit || github.ref }} + git-ref: ${{ needs.commit_submodule.outputs.sha-after-commit }} duckdb-git-ref: ${{ inputs.duckdb-sha }} force-version: ${{ inputs.force-version }} @@ -76,7 +80,7 @@ jobs: name: Publish Artifacts to the S3 Staging Bucket runs-on: ubuntu-latest needs: [ externally_triggered_build ] - if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages && needs.externally_triggered_build.result == 'success' && always() }} + if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages }} steps: - name: Fetch artifacts uses: actions/download-artifact@v4 From 645faadc78b06f3d0419bebdd3ec893e1c8e5fd9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 22 Jul 2025 16:20:35 +0000 Subject: [PATCH 031/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index c1c3d888..35391cb7 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit c1c3d88864ff367936b7c667e0ac071bceed5215 +Subproject commit 35391cb7f32572e45fd202513af4a1a63ae9daa3 From c39fbffa29be469c80a3bb1ed1fdc6026f70e0b3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 23 Jul 2025 05:44:23 +0000 Subject: [PATCH 032/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 35391cb7..524b7f9c 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 35391cb7f32572e45fd202513af4a1a63ae9daa3 +Subproject commit 524b7f9c25769a9779b04a2c4ef0526c31810c6d From 6373b7ef5b6d1158bbcf11a7ed73494c75a62c87 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 23 Jul 2025 05:56:18 +0000 Subject: [PATCH 033/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 61f07c32..5aa18eac 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 61f07c3221e01674d8fe40b4a25364ba2f3159a7 +Subproject commit 5aa18eace84c64b9395eaac7f517d8a08cfa5d3d From cba8650e1e9a20ec9448ab5032c960b01be5b646 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 23 Jul 2025 14:26:39 +0200 Subject: [PATCH 034/472] Use PYPI_HOST instead of full URL so we can re-use it for the cleanup flow --- .github/workflows/upload_to_pypi.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index bce4c315..d1d734b4 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -76,12 +76,12 @@ jobs: # this is needed for the OIDC flow that is used with trusted publishing on PyPI id-token: write steps: - - name: Fail if PYPI_URL is not set - if: ${{ vars.PYPI_URL == '' }} + - name: Fail if PYPI_HOST is not set + if: ${{ vars.PYPI_HOST == '' }} shell: bash run: | env_name="${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }}" - echo "Error: vars.PYPI_URL is not set in the resolved environment (${env_name})" + echo "Error: vars.PYPI_HOST is not set in the resolved environment (${env_name})" exit 1 - name: Authenticate With AWS @@ -103,5 +103,5 @@ jobs: - name: Upload artifacts to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: - repository-url: ${{ vars.PYPI_URL }} + repository-url: 'https://${{ vars.PYPI_HOST }}/legacy/' packages-dir: packages From ab3bb9fa7b95d33f22bdec7568e7c39b858ff827 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 23 Jul 2025 22:46:12 +0200 Subject: [PATCH 035/472] Also do pypi cleanup --- pyproject.toml | 1 + scripts/pypi_cleanup.py | 303 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 304 insertions(+) create mode 100644 scripts/pypi_cleanup.py diff --git a/pyproject.toml b/pyproject.toml index d0642b3b..669d188a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -260,6 +260,7 @@ scripts = [ # dependencies used for running scripts "pcpp", "polars", "pyarrow", + "pyotp>=2.9.0", "pytz" ] build = [ diff --git a/scripts/pypi_cleanup.py b/scripts/pypi_cleanup.py new file mode 100644 index 00000000..1eaf3d9a --- /dev/null +++ b/scripts/pypi_cleanup.py @@ -0,0 +1,303 @@ +import argparse +import pyotp +import datetime +import logging +import os +import re +import sys +import time +from collections import defaultdict +from html.parser import HTMLParser +from textwrap import dedent +from urllib.parse import urlparse + +import requests +from requests.exceptions import RequestException + +import argparse +import re +import os + +def valid_hostname(hostname): + """Validate hostname format""" + if len(hostname) > 253: + raise argparse.ArgumentTypeError("Hostname too long (max 253 characters)") + + # Check for valid hostname pattern + hostname_pattern = r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$' + if not re.match(hostname_pattern, hostname): + raise argparse.ArgumentTypeError(f"Invalid hostname format: {hostname}") + + return hostname + +def non_empty_string(value): + """Validate non-empty string""" + if not value or not value.strip(): + raise argparse.ArgumentTypeError("Value cannot be empty") + return value.strip() + +parser = argparse.ArgumentParser( + description="PyPI cleanup script", + epilog="Environment variables required (unless --dry): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP" +) +parser.add_argument("--dry", action="store_true", help="Show what would be deleted but don't actually do it") +parser.add_argument("--index-hostname", type=valid_hostname, required=True, help="Index hostname (required)") +parser.add_argument("--max-nightlies", type=int, default=2, help="Max number of nightlies of unreleased versions (default=2)") +parser.add_argument("--username", type=non_empty_string, help="Username (required unless --dry)") +args = parser.parse_args() + +# Handle secrets from environment variables +password = None +otp = None + +if not args.dry: + if not args.username: + parser.error("--username is required when not in dry-run mode") + + password = os.getenv('PYPI_CLEANUP_PASSWORD') + otp = os.getenv('PYPI_CLEANUP_OTP') + + if not password: + parser.error("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") + if not otp: + parser.error("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") + +print(f"Dry run: {args.dry}") +print(f"Max nightlies: {args.max_nightlies}") +if not args.dry: + print(f"Hostname: {args.index_hostname}") + print(f"Username: {args.username}") + print("Password and OTP loaded from environment variables") + +# deletes old dev wheels from pypi. evil hack. +actually_delete = not args.dry +pypi_username = args.username or "user" +max_dev_releases = args.max_nightlies +host = 'https://{}/'.format(args.index_hostname) +pypi_password = password or "password" +pypi_otp = otp or "otp" + +patterns = [re.compile(r".*\.dev\d+$")] +###### NOTE: This code is taken from the pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) +class CsfrParser(HTMLParser): + def __init__(self, target, contains_input=None): + super().__init__() + self._target = target + self._contains_input = contains_input + self.csrf = None # Result value from all forms on page + self._csrf = None # Temp value from current form + self._in_form = False # Currently parsing a form with an action we're interested in + self._input_contained = False # Input field requested is contained in the current form + + def handle_starttag(self, tag, attrs): + if tag == "form": + attrs = dict(attrs) + action = attrs.get("action") # Might be None. + if action and (action == self._target or action.startswith(self._target)): + self._in_form = True + return + + if self._in_form and tag == "input": + attrs = dict(attrs) + if attrs.get("name") == "csrf_token": + self._csrf = attrs["value"] + + if self._contains_input and attrs.get("name") == self._contains_input: + self._input_contained = True + + return + + def handle_endtag(self, tag): + if tag == "form": + self._in_form = False + # If we're in a right form that contains the requested input and csrf is not set + if (not self._contains_input or self._input_contained) and not self.csrf: + self.csrf = self._csrf + return + + +class PypiCleanup: + def __init__(self, url, username, package, password, otp, patterns, delete, max_dev_releases, verbose=False): + self.url = urlparse(url).geturl() + if self.url[-1] == "/": + self.url = self.url[:-1] + self.username = username + self.password = password + self.otp = otp + self.do_it = delete + self.package = package + self.patterns = patterns + self.max_dev_releases = max_dev_releases + self.verbose = verbose + + def run(self): + csrf = None + + if self.verbose: + logging.root.setLevel(logging.DEBUG) + + if self.do_it: + logging.warning("!!! WILL ACTUALLY DELETE THINGS !!!") + logging.warning("Will sleep for 3 seconds - Ctrl-C to abort!") + time.sleep(3.0) + else: + logging.info("Running in DRY RUN mode") + + logging.info(f"Will use the following patterns {self.patterns} on package {self.package}") + + with requests.Session() as s: + with s.get(f"{self.url}/pypi/{self.package}/json") as r: + try: + r.raise_for_status() + except RequestException as e: + logging.error(f"Unable to find package {repr(self.package)}", exc_info=e) + return 1 + + releases_by_date = {} + for release, files in r.json()["releases"].items(): + releases_by_date[release] = max( + [datetime.datetime.strptime(f["upload_time"], '%Y-%m-%dT%H:%M:%S') for f in files] + ) + + if not releases_by_date: + logging.info(f"No releases for package {self.package} have been found") + return + + version_dict = defaultdict(list) + releases = [] + for key in releases_by_date.keys(): + if '.dev' in key: + prefix, postfix = key.split('.dev') + version_dict[prefix].append(key) + + pkg_vers = [] + for version_key, versions in version_dict.items(): + # releases_by_date.keys() is a list of release versions, so when the version key appears in that list, + # that means the version have been released and we don't need to keep PRE-RELEASE (dev) versions anymore. + # All versions for that key should be added into a list to delete from PyPi (pkg_vers). + # When the version is not released yet, it appears among the version_dict keys. In this case we'd like to keep + # some number of versions (self.max_dev_releases), so we add the version names from the beginning + # of the versions list sorted by date, except for mentioned number of versions to keep. + if version_key in releases_by_date.keys() or self.max_dev_releases == 0: + pkg_vers.extend(versions) + else: + # sort by the suffix casted to int to keep only the most recent builds + sorted_versions = sorted(versions, key=lambda x: int(x.split('dev')[-1])) + pkg_vers.extend(sorted_versions[:-self.max_dev_releases]) + + if not self.do_it: + print("Following pkg_vers can be deleted: ", pkg_vers) + return + + if not pkg_vers: + logging.info(f"No releases were found matching specified patterns and dates in package {self.package}") + return + + if set(pkg_vers) == set(releases_by_date.keys()): + print( + dedent( + f""" + WARNING: + \tYou have selected the following patterns: {self.patterns} + \tThese patterns would delete all available released versions of `{self.package}`. + \tThis will render your project/package permanently inaccessible. + \tSince the costs of an error are too high I'm refusing to do this. + \tGoodbye. + """ + ), + file=sys.stderr, + ) + + if not self.do_it: + return 3 + for pkg in pkg_vers: + if 'dev' not in pkg: + raise Exception(f"Would be deleting version {pkg} but the version is not a dev version") + + if self.username is None: + raise Exception("No username provided") + + if self.password is None: + raise Exception("No password provided") + + with s.get(f"{self.url}/account/login/") as r: + r.raise_for_status() + form_action = "/account/login/" + parser = CsfrParser(form_action) + parser.feed(r.text) + if not parser.csrf: + raise ValueError(f"No CSFR found in {form_action}") + csrf = parser.csrf + + two_factor = False + with s.post( + f"{self.url}/account/login/", + data={"csrf_token": csrf, "username": self.username, "password": self.password}, + headers={"referer": f"{self.url}/account/login/"}, + ) as r: + r.raise_for_status() + if r.url == f"{self.url}/account/login/": + logging.error(f"Login for user {self.username} failed") + return 1 + + if r.url.startswith(f"{self.url}/account/two-factor/"): + form_action = r.url[len(self.url) :] + parser = CsfrParser(form_action) + parser.feed(r.text) + if not parser.csrf: + raise ValueError(f"No CSFR found in {form_action}") + csrf = parser.csrf + two_factor = True + two_factor_url = r.url + + if two_factor: + success = False + for i in range(3): + auth_code = pyotp.TOTP(self.otp).now() + with s.post( + two_factor_url, + data={"csrf_token": csrf, "method": "totp", "totp_value": auth_code}, + headers={"referer": two_factor_url}, + ) as r: + r.raise_for_status() + if r.url == two_factor_url: + logging.error(f"Authentication code {auth_code} is invalid, retrying in 5 seconds...") + time.sleep(5) + else: + success = True + break + if not success: + raise Exception("Could not authenticate with OTP") + + for pkg_ver in pkg_vers: + if 'dev' not in pkg_ver: + raise Exception(f"Would be deleting version {pkg_ver} but the version is not a dev version") + if self.do_it: + logging.info(f"Deleting {self.package} version {pkg_ver}") + form_action = f"/manage/project/{self.package}/release/{pkg_ver}/" + form_url = f"{self.url}{form_action}" + with s.get(form_url) as r: + r.raise_for_status() + parser = CsfrParser(form_action, "confirm_delete_version") + parser.feed(r.text) + if not parser.csrf: + raise ValueError(f"No CSFR found in {form_action}") + csrf = parser.csrf + referer = r.url + + with s.post( + form_url, + data={ + "csrf_token": csrf, + "confirm_delete_version": pkg_ver, + }, + headers={"referer": referer}, + ) as r: + r.raise_for_status() + + logging.info(f"Deleted {self.package} version {pkg_ver}") + else: + logging.info(f"Would be deleting {self.package} version {pkg_ver}, but not doing it!") + + +PypiCleanup(host, pypi_username, 'duckdb', pypi_password, pypi_otp, patterns, actually_delete, max_dev_releases).run() From 6279c5b13326250757c1c29e29a4f9b266db603d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 24 Jul 2025 06:25:01 +0000 Subject: [PATCH 036/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 524b7f9c..c8164851 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 524b7f9c25769a9779b04a2c4ef0526c31810c6d +Subproject commit c8164851be62bcad38c080e975c577ff951b16be From 8e464ea23c5dcdacd9127b3997c2255abf49f93b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 09:37:48 +0200 Subject: [PATCH 037/472] Added to workflow --- .github/workflows/cleanup_pypi.yml | 31 ++++++++++++++++++++++++++++++ pyproject.toml | 6 ++++-- scripts/pypi_cleanup.py | 6 +++--- 3 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/cleanup_pypi.yml diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml new file mode 100644 index 00000000..9f2367a8 --- /dev/null +++ b/.github/workflows/cleanup_pypi.yml @@ -0,0 +1,31 @@ +name: cleanup_pypi.yml +on: + workflow_dispatch: + inputs: + dry-run: + description: List packages that would be deleted but don't delete them + type: boolean + default: false + workflow_call: +jobs: + cleanup_pypi: + name: Remove Nightlies from PyPI + runs-on: ubuntu-latest + env: + PYPI_CLEANUP_PASSWORD: ${{secrets.PYPI_CLEANUP_PASSWORD}} + PYPI_CLEANUP_OTP: ${{secrets.PYPI_CLEANUP_OTP}} + steps: + - uses: actions/checkout@v4 + - name: Install Astral UV + uses: astral-sh/setup-uv@v6 + with: + version: "0.7.14" + python-version: 3.11 + + - name: Run Cleanup + run: | + uv sync --only-group pypi --no-install-project + uv run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ + --index-hostname "${{ vars.PYPI_HOST }}" \ + --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ + --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} diff --git a/pyproject.toml b/pyproject.toml index 669d188a..2b70ba71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,7 +250,6 @@ test = [ # dependencies used for running tests "numpy<2; python_version < '3.12'", "numpy>=2; python_version >= '3.12'", ] - scripts = [ # dependencies used for running scripts "cxxheaderparser", "ipython", @@ -260,9 +259,12 @@ scripts = [ # dependencies used for running scripts "pcpp", "polars", "pyarrow", - "pyotp>=2.9.0", "pytz" ] +pypi = [ # dependencies used by the pypi cleanup script + "pyotp>=2.9.0", + "requests>=2.32.4", +] build = [ "cmake>=3.29.0", "ninja>=1.10", diff --git a/scripts/pypi_cleanup.py b/scripts/pypi_cleanup.py index 1eaf3d9a..914d0a93 100644 --- a/scripts/pypi_cleanup.py +++ b/scripts/pypi_cleanup.py @@ -41,9 +41,9 @@ def non_empty_string(value): epilog="Environment variables required (unless --dry): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP" ) parser.add_argument("--dry", action="store_true", help="Show what would be deleted but don't actually do it") -parser.add_argument("--index-hostname", type=valid_hostname, required=True, help="Index hostname (required)") -parser.add_argument("--max-nightlies", type=int, default=2, help="Max number of nightlies of unreleased versions (default=2)") -parser.add_argument("--username", type=non_empty_string, help="Username (required unless --dry)") +parser.add_argument("-i", "--index-hostname", type=valid_hostname, required=True, help="Index hostname (required)") +parser.add_argument("-m", "--max-nightlies", type=int, default=2, help="Max number of nightlies of unreleased versions (default=2)") +parser.add_argument("-u", "--username", type=non_empty_string, help="Username (required unless --dry)") args = parser.parse_args() # Handle secrets from environment variables From ae507ba2a7da18aeed1455ed68897f0cd0e5bc86 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 09:50:21 +0200 Subject: [PATCH 038/472] Plugged into the upload workflow --- .github/workflows/cleanup_pypi.yml | 2 +- .github/workflows/upload_to_pypi.yml | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 9f2367a8..d007a7d8 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -1,12 +1,12 @@ name: cleanup_pypi.yml on: + workflow_call: workflow_dispatch: inputs: dry-run: description: List packages that would be deleted but don't delete them type: boolean default: false - workflow_call: jobs: cleanup_pypi: name: Remove Nightlies from PyPI diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index d1d734b4..a4fedaf0 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -105,3 +105,8 @@ jobs: with: repository-url: 'https://${{ vars.PYPI_HOST }}/legacy/' packages-dir: packages + + cleanup_nightlies: + name: Remove Nightlies from PyPI + needs: publish-pypi + uses: ./.github/workflows/cleanup_pypi.yml From 902968b85d60613c7bd1455927ee8d292772b48d Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 09:56:04 +0200 Subject: [PATCH 039/472] Naming and comments --- .github/workflows/on_external_dispatch.yml | 1 + .github/workflows/upload_to_pypi.yml | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 947a7780..fb73f9b5 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -1,3 +1,4 @@ +# External Dispatch is called by duckdb's InvokeCI -> NotifyExternalRepositories job name: External Dispatch on: workflow_dispatch: diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index a4fedaf0..72ab7e88 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -1,6 +1,6 @@ -name: upload_to_pypi.yml +name: Upload Artifacts to PyPI on: - # this workflow runs after the below workflows are completed + # this workflow runs after the "External Dispatch" workflow is completed workflow_run: workflows: [ External Dispatch ] types: [ completed ] From 3fbb71ad1a0c47cb11de26830ff636a16ea08661 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 10:08:03 +0200 Subject: [PATCH 040/472] Bypass project install --- .github/workflows/cleanup_pypi.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index d007a7d8..72c1e93d 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -1,4 +1,4 @@ -name: cleanup_pypi.yml +name: Cleanup PyPI on: workflow_call: workflow_dispatch: @@ -20,10 +20,10 @@ jobs: uses: astral-sh/setup-uv@v6 with: version: "0.7.14" - python-version: 3.11 - name: Run Cleanup run: | + uv venv uv sync --only-group pypi --no-install-project uv run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ --index-hostname "${{ vars.PYPI_HOST }}" \ From 45287b36cc3057f00dc7559de7fa103bc7351ca6 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 10:12:00 +0200 Subject: [PATCH 041/472] Be verbose --- .github/workflows/cleanup_pypi.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 72c1e93d..1119b7d6 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -23,9 +23,10 @@ jobs: - name: Run Cleanup run: | - uv venv - uv sync --only-group pypi --no-install-project - uv run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ + set -x + uv -v venv + uv -v sync --only-group pypi --no-install-project + uv -v run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ --index-hostname "${{ vars.PYPI_HOST }}" \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} From 39f6f24976cbf6e1423c9727db927b9167616bd8 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 10:14:36 +0200 Subject: [PATCH 042/472] Allow version discovery --- .github/workflows/cleanup_pypi.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 1119b7d6..7aa88743 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -16,6 +16,9 @@ jobs: PYPI_CLEANUP_OTP: ${{secrets.PYPI_CLEANUP_OTP}} steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Install Astral UV uses: astral-sh/setup-uv@v6 with: @@ -24,7 +27,6 @@ jobs: - name: Run Cleanup run: | set -x - uv -v venv uv -v sync --only-group pypi --no-install-project uv -v run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ --index-hostname "${{ vars.PYPI_HOST }}" \ From 69084f9909c3389b4132da99b0cb2314cf732f07 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 11:35:44 +0200 Subject: [PATCH 043/472] Made pypi upload simpler --- .github/workflows/cleanup_pypi.yml | 33 ++++++++- .github/workflows/on_external_dispatch.yml | 30 +++++--- .github/workflows/upload_to_pypi.yml | 82 ++++++---------------- 3 files changed, 73 insertions(+), 72 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 7aa88743..685be9b4 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -1,16 +1,32 @@ name: Cleanup PyPI on: workflow_call: + inputs: + environment: + description: CI environment to run in (test.pypi or production.pypi) + type: string + required: true workflow_dispatch: inputs: dry-run: description: List packages that would be deleted but don't delete them type: boolean default: false + environment: + description: CI environment to run in + type: choice + required: true + default: test.pypi + options: + - test.pypi + - production.pypi + jobs: cleanup_pypi: name: Remove Nightlies from PyPI runs-on: ubuntu-latest + environment: + name: ${{ inputs.environment }} env: PYPI_CLEANUP_PASSWORD: ${{secrets.PYPI_CLEANUP_PASSWORD}} PYPI_CLEANUP_OTP: ${{secrets.PYPI_CLEANUP_OTP}} @@ -19,6 +35,19 @@ jobs: with: fetch-depth: 0 + - if: ${{ vars.PYPI_HOST == '' }} + run: | + echo "Error: PYPI_HOST is not set in CI environment '${{ inputs.environment }}'" + exit 1 + - if: ${{ vars.PYPI_CLEANUP_USERNAME == '' }} + run: | + echo "Error: PYPI_CLEANUP_USERNAME is not set in CI environment '${{ inputs.environment }}'" + exit 1 + - if: ${{ vars.PYPI_MAX_NIGHTLIES == '' }} + run: | + echo "Error: PYPI_MAX_NIGHTLIES is not set in CI environment '${{ inputs.environment }}'" + exit 1 + - name: Install Astral UV uses: astral-sh/setup-uv@v6 with: @@ -27,8 +56,8 @@ jobs: - name: Run Cleanup run: | set -x - uv -v sync --only-group pypi --no-install-project - uv -v run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ + uv sync --only-group pypi --no-install-project + uv run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ --index-hostname "${{ vars.PYPI_HOST }}" \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index fb73f9b5..38ff7520 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -17,7 +17,7 @@ on: required: false publish-packages: type: boolean - description: Publish to S3 + description: Upload packages to S3 required: true default: false @@ -77,10 +77,12 @@ jobs: duckdb-git-ref: ${{ inputs.duckdb-sha }} force-version: ${{ inputs.force-version }} - publish_s3: - name: Publish Artifacts to the S3 Staging Bucket + upload_s3: + name: Upload Artifacts to the S3 Staging Bucket runs-on: ubuntu-latest - needs: [ externally_triggered_build ] + needs: externally_triggered_build + outputs: + version: ${{ steps.s3_upload.outputs.version }} if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages }} steps: - name: Fetch artifacts @@ -97,10 +99,18 @@ jobs: aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - - name: Upload artifacts to S3 bucket - # semantics: if a version is forced then we upload into a folder by the version name, otherwise we upload - # into a folder that is named -. Only the latter will be discovered be - # upload_to_pypi.yml. + - name: Upload Artifacts + id: s3_upload run: | - FOLDER="${{ inputs.force-version != '' && inputs.force-version || format('{0}-{1}', github.run_id, github.run_attempt) }}" - aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${FOLDER}/ --recursive + version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') + aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${version}/ --recursive + echo "version=${version}" >> $GITHUB_OUTPUT + + publish_to_pypi: + name: Upload Artifacts to PyPI + needs: upload_s3 + if: ${{ force-version == '' }} + uses: ./.github/workflows/upload_to_pypi.yml + with: + version: ${{ needs.upload_s3.outputs.version }} + environment: pypi.production \ No newline at end of file diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index 72ab7e88..44f4dd22 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -1,87 +1,47 @@ name: Upload Artifacts to PyPI on: - # this workflow runs after the "External Dispatch" workflow is completed - workflow_run: - workflows: [ External Dispatch ] - types: [ completed ] - branches: - - main - - v*.*-* + workflow_call: + inputs: + environment: + description: CI environment to run in (test.pypi or production.pypi) + type: string + required: true + version: + description: The version to upload (must be present in the S3 staging bucket) + type: string + required: true workflow_dispatch: inputs: environment: - description: Environment to run in () + description: CI environment to run in (test.pypi or production.pypi) type: choice required: true default: test.pypi options: - test.pypi - production.pypi - artifact-folder: - description: The S3 folder that contains the artifacts (s3://duckdb-staging/duckdb/duckdb-python/) + version: + description: The version to upload (must be present in the S3 staging bucket) type: string required: true -concurrency: ${{ inputs.artifact-folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }} +concurrency: + group: ${{ inputs.version }} + cancel-in-progress: true jobs: - prepare: - name: Prepare and guard upload - if: ${{ github.repository_owner == 'duckdb' && ( github.event.workflow_run.conclusion == 'success' || github.event_name != 'workflow_run' ) }} - runs-on: ubuntu-latest - outputs: - s3_prefix: ${{ steps.get_s3_prefix.outputs.s3_prefix }} - steps: - - name: Determine S3 Prefix - id: get_s3_prefix - run: | - artifact_folder="${{ inputs.artifact-folder || format('{0}-{1}', github.event.workflow_run.id, github.event.workflow_run.run_attempt) }}" - if [[ -n "${artifact_folder}" ]]; then - s3_prefix="${{ github.repository }}/${artifact_folder}" - echo "Created S3 prefix: ${s3_prefix}" - echo "s3_prefix=${s3_prefix}" >> $GITHUB_OUTPUT - else - echo "Can't determine S3 prefix for event: ${{ github.event_name }}. Quitting." - exit 1 - fi - - - name: Authenticate With AWS - uses: aws-actions/configure-aws-credentials@v4 - with: - aws-region: 'us-east-2' - aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} - aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - - - name: Check S3 Prefix - shell: bash - run: | - if [[ $(aws s3api list-objects-v2 \ - --bucket duckdb-staging \ - --prefix "${{ steps.get_s3_prefix.outputs.s3_prefix }}/" \ - --max-items 1 \ - --query 'Contents[0].Key' \ - --output text) == "None" ]]; then - echo "Prefix does not exist: ${{ steps.get_s3_prefix.outputs.s3_prefix }}" - echo "${{ github.event_name == 'workflow_run' && 'Possibly built a stable release?' || 'Unexpected error' }}" - exit 1 - fi - publish-pypi: name: Publish Artifacts to PyPI - needs: [ prepare ] runs-on: ubuntu-latest environment: - name: ${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }} + name: ${{ inputs.environment }} permissions: # this is needed for the OIDC flow that is used with trusted publishing on PyPI id-token: write steps: - - name: Fail if PYPI_HOST is not set - if: ${{ vars.PYPI_HOST == '' }} - shell: bash + - if: ${{ vars.PYPI_HOST == '' }} run: | - env_name="${{ github.event_name == 'workflow_dispatch' && inputs.environment || 'test.pypi' }}" - echo "Error: vars.PYPI_HOST is not set in the resolved environment (${env_name})" + echo "Error: PYPI_HOST is not set in CI environment '${{ inputs.environment }}'" exit 1 - name: Authenticate With AWS @@ -93,7 +53,7 @@ jobs: - name: Download Artifacts From S3 env: - S3_URL: 's3://duckdb-staging/${{ needs.prepare.outputs.s3_prefix }}/' + S3_URL: 's3://duckdb-staging/${{ github.repository }}/${{ inputs.version }}/' AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }} run: | @@ -110,3 +70,5 @@ jobs: name: Remove Nightlies from PyPI needs: publish-pypi uses: ./.github/workflows/cleanup_pypi.yml + with: + environment: ${{ inputs.environment }} From 3891894c690f85640ac8a971b62c3a44a6fb24b8 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 12:10:18 +0200 Subject: [PATCH 044/472] Added bug report and discussion templates --- .github/ISSUE_TEMPLATE/bug_report.yml | 119 ++++++++++++++++++++++++++ .github/ISSUE_TEMPLATE/config.yml | 8 ++ 2 files changed, 127 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.yml create mode 100644 .github/ISSUE_TEMPLATE/config.yml diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 00000000..0c77dcf4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,119 @@ +name: Bug report +description: Create a report to help us improve +labels: + - needs triage +body: + - type: markdown + attributes: + value: > + Please report security vulnerabilities using GitHub's [report vulnerability form](https://github.com/duckdb/duckdb/security/advisories/new). + + - type: textarea + attributes: + label: What happens? + description: A short, clear and concise description of what the bug is. + validations: + required: true + + - type: textarea + attributes: + label: To Reproduce + description: | + Please provide steps to reproduce the behavior, preferably a [minimal reproducible example](https://en.wikipedia.org/wiki/Minimal_reproducible_example). Please adhere the following guidelines: + + * Format the code and the output as [code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks) using triple backticks: + + ```` + ``` + CODE HERE + ``` + ```` + * Add all required imports for scripts, e.g., `import duckdb`, `import pandas as pd`. + * Remove all prompts from the scripts. This include DuckDB's 'D' prompt and Python's `>>>` prompt. Removing these prompts makes reproduction attempts quicker. + * Make sure that the script and its outputs are provided in separate code blocks. + * If applicable, please check whether the issue is reproducible via running plain SQL queries from the DuckDB CLI client. + validations: + required: true + + - type: markdown + attributes: + value: "# Environment (please complete the following information):" + - type: input + attributes: + label: "OS:" + placeholder: e.g., OSX + description: Please include operating system version and architecture (e.g., aarch64, x86_64, etc.). + validations: + required: true + - type: input + attributes: + label: "DuckDB Package Version:" + placeholder: e.g., 1.3.2 + validations: + required: true + - type: input + attributes: + label: "Python Version:" + placeholder: e.g., 3.12 + validations: + required: true + - type: markdown + attributes: + value: "# Identity Disclosure:" + - type: input + attributes: + label: "Full Name:" + placeholder: e.g., John Doe + validations: + required: true + - type: input + attributes: + label: "Affiliation:" + placeholder: e.g., Acme Corporation + validations: + required: true + + - type: markdown + attributes: + value: | + If the above is not given and is not obvious from your GitHub profile page, we might close your issue without further review. Please refer to the [reasoning behind this rule](https://berthub.eu/articles/posts/anonymous-help/) if you have questions. + + # Before Submitting: + + - type: dropdown + attributes: + label: What is the latest build you tested with? If possible, we recommend testing with the latest nightly build. + description: | + Visit the [installation page](https://duckdb.org/docs/installation/) for instructions. + options: + - I have not tested with any build + - I have tested with a stable release + - I have tested with a nightly build + - I have tested with a source build + validations: + required: true + + - type: dropdown + attributes: + label: Did you include all relevant data sets for reproducing the issue? + options: + - "No - Other reason (please specify in the issue body)" + - "No - I cannot share the data sets because they are confidential" + - "No - I cannot easily share my data sets due to their large size" + - "Not applicable - the reproduction does not require a data set" + - "Yes" + default: 0 + validations: + required: true + + - type: checkboxes + attributes: + label: Did you include all code required to reproduce the issue? + options: + - label: Yes, I have + + - type: checkboxes + attributes: + label: Did you include all relevant configuration to reproduce the issue? + options: + - label: Yes, I have diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..1a21e420 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Feature Request + url: https://github.com/duckdb/duckdb-python/discussions/new?category=ideas&title=Feature%20Request:%20...&labels=feature&body=Why%20do%20you%20want%20this%20feature%3F + about: Submit feature requests here + - name: Discussions + url: https://github.com/duckdb/duckdb-python/discussions + about: Please ask and answer general questions here. From bfe005e3ef46aee39db4771aebf12d17a31ec0ac Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 12:44:08 +0200 Subject: [PATCH 045/472] Ensure unicity of S3 path --- .github/workflows/cleanup_pypi.yml | 13 ++++++++++- .github/workflows/on_external_dispatch.yml | 21 ++++++++++++----- .github/workflows/upload_to_pypi.yml | 20 ++++++++++++----- scripts/pypi_cleanup.py | 26 +++++++++++----------- 4 files changed, 54 insertions(+), 26 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 685be9b4..f4a9fd36 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -60,4 +60,15 @@ jobs: uv run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ --index-hostname "${{ vars.PYPI_HOST }}" \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ - --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} + --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} > cleanup_output 2>&1 + + - name: PyPI Cleanup Summary + run : | + echo "## PyPI Cleanup Summary" >> $GITHUB_STEP_SUMMARY + echo "* Dry run: ${{ inputs.dry-run }}" >> $GITHUB_STEP_SUMMARY + echo "* PyPI Host: ${{ vars.PYPI_HOST }}" >> $GITHUB_STEP_SUMMARY + echo "* CI Environment: " >> $GITHUB_STEP_SUMMARY + echo "* Output:" >> $GITHUB_STEP_SUMMARY + echo '```' >> $GITHUB_STEP_SUMMARY + cat cleanup_output >> $GITHUB_STEP_SUMMARY + echo '```' >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 38ff7520..66afde42 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -80,7 +80,7 @@ jobs: upload_s3: name: Upload Artifacts to the S3 Staging Bucket runs-on: ubuntu-latest - needs: externally_triggered_build + needs: [commit_submodule, externally_triggered_build] outputs: version: ${{ steps.s3_upload.outputs.version }} if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages }} @@ -102,15 +102,24 @@ jobs: - name: Upload Artifacts id: s3_upload run: | - version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') - aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${version}/ --recursive + sha=${{ needs.commit_submodule.outputs.sha-after-commit }} + aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/ --recursive echo "version=${version}" >> $GITHUB_OUTPUT + - name: S3 Upload Summary + run : | + sha=${{ needs.commit_submodule.outputs.sha-after-commit }} + version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') + echo "## S3 Upload Summary" >> $GITHUB_STEP_SUMMARY + echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY + echo "* SHA: ${sha:0:10}" >> $GITHUB_STEP_SUMMARY + echo "* S3 URL: s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/" >> $GITHUB_STEP_SUMMARY + publish_to_pypi: name: Upload Artifacts to PyPI - needs: upload_s3 + needs: [ commit_submodule, upload_s3 ] if: ${{ force-version == '' }} uses: ./.github/workflows/upload_to_pypi.yml with: - version: ${{ needs.upload_s3.outputs.version }} - environment: pypi.production \ No newline at end of file + sha: ${{ needs.commit_submodule.outputs.sha-after-commit }} + environment: pypi.production diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index 44f4dd22..edffcbe8 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -6,8 +6,8 @@ on: description: CI environment to run in (test.pypi or production.pypi) type: string required: true - version: - description: The version to upload (must be present in the S3 staging bucket) + sha: + description: The SHA of the commit that the packages were built from type: string required: true workflow_dispatch: @@ -20,13 +20,13 @@ on: options: - test.pypi - production.pypi - version: - description: The version to upload (must be present in the S3 staging bucket) + sha: + description: The SHA of the commit that the packages were built from type: string required: true concurrency: - group: ${{ inputs.version }} + group: ${{ inputs.sha }} cancel-in-progress: true jobs: @@ -53,7 +53,7 @@ jobs: - name: Download Artifacts From S3 env: - S3_URL: 's3://duckdb-staging/${{ github.repository }}/${{ inputs.version }}/' + S3_URL: 's3://duckdb-staging/${{ github.repository }}/${{ inputs.sha }}/' AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }} run: | @@ -66,6 +66,14 @@ jobs: repository-url: 'https://${{ vars.PYPI_HOST }}/legacy/' packages-dir: packages + - name: PyPI Upload Summary + run : | + version=$(basename packages/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') + echo "## PyPI Upload Summary" >> $GITHUB_STEP_SUMMARY + echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY + echo "* PyPI Host: ${{ vars.PYPI_HOST }}" >> $GITHUB_STEP_SUMMARY + echo "* CI Environment: ${{ inputs.environment }}" >> $GITHUB_STEP_SUMMARY + cleanup_nightlies: name: Remove Nightlies from PyPI needs: publish-pypi diff --git a/scripts/pypi_cleanup.py b/scripts/pypi_cleanup.py index 914d0a93..39bc48b4 100644 --- a/scripts/pypi_cleanup.py +++ b/scripts/pypi_cleanup.py @@ -185,8 +185,8 @@ def run(self): sorted_versions = sorted(versions, key=lambda x: int(x.split('dev')[-1])) pkg_vers.extend(sorted_versions[:-self.max_dev_releases]) + print("Following pkg_vers can be deleted: ", pkg_vers) if not self.do_it: - print("Following pkg_vers can be deleted: ", pkg_vers) return if not pkg_vers: @@ -231,9 +231,9 @@ def run(self): two_factor = False with s.post( - f"{self.url}/account/login/", - data={"csrf_token": csrf, "username": self.username, "password": self.password}, - headers={"referer": f"{self.url}/account/login/"}, + f"{self.url}/account/login/", + data={"csrf_token": csrf, "username": self.username, "password": self.password}, + headers={"referer": f"{self.url}/account/login/"}, ) as r: r.raise_for_status() if r.url == f"{self.url}/account/login/": @@ -255,9 +255,9 @@ def run(self): for i in range(3): auth_code = pyotp.TOTP(self.otp).now() with s.post( - two_factor_url, - data={"csrf_token": csrf, "method": "totp", "totp_value": auth_code}, - headers={"referer": two_factor_url}, + two_factor_url, + data={"csrf_token": csrf, "method": "totp", "totp_value": auth_code}, + headers={"referer": two_factor_url}, ) as r: r.raise_for_status() if r.url == two_factor_url: @@ -286,12 +286,12 @@ def run(self): referer = r.url with s.post( - form_url, - data={ - "csrf_token": csrf, - "confirm_delete_version": pkg_ver, - }, - headers={"referer": referer}, + form_url, + data={ + "csrf_token": csrf, + "confirm_delete_version": pkg_ver, + }, + headers={"referer": referer}, ) as r: r.raise_for_status() From ec742603985c154df088a3f24399f0b766b35c72 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 12:46:29 +0200 Subject: [PATCH 046/472] add ci environment --- .github/workflows/cleanup_pypi.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index f4a9fd36..f67c780c 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -67,7 +67,7 @@ jobs: echo "## PyPI Cleanup Summary" >> $GITHUB_STEP_SUMMARY echo "* Dry run: ${{ inputs.dry-run }}" >> $GITHUB_STEP_SUMMARY echo "* PyPI Host: ${{ vars.PYPI_HOST }}" >> $GITHUB_STEP_SUMMARY - echo "* CI Environment: " >> $GITHUB_STEP_SUMMARY + echo "* CI Environment: ${{ inputs.environment }}" >> $GITHUB_STEP_SUMMARY echo "* Output:" >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY cat cleanup_output >> $GITHUB_STEP_SUMMARY From c360d19624d7ba617b43551a5202d8fb11a878b7 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 16:02:50 +0200 Subject: [PATCH 047/472] Only support pypi in cleanup script --- .github/workflows/cleanup_pypi.yml | 8 ++----- scripts/pypi_cleanup.py | 38 ++++++++++++++++-------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index f67c780c..5026ac52 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -34,11 +34,6 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 - - - if: ${{ vars.PYPI_HOST == '' }} - run: | - echo "Error: PYPI_HOST is not set in CI environment '${{ inputs.environment }}'" - exit 1 - if: ${{ vars.PYPI_CLEANUP_USERNAME == '' }} run: | echo "Error: PYPI_CLEANUP_USERNAME is not set in CI environment '${{ inputs.environment }}'" @@ -56,9 +51,10 @@ jobs: - name: Run Cleanup run: | set -x + pypi_index_flag=${{ inputs.environment == 'production.pypi' && '--prod' || '--test' }} uv sync --only-group pypi --no-install-project uv run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ - --index-hostname "${{ vars.PYPI_HOST }}" \ + ${pypi_index_flag} \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} > cleanup_output 2>&1 diff --git a/scripts/pypi_cleanup.py b/scripts/pypi_cleanup.py index 39bc48b4..cf3548b6 100644 --- a/scripts/pypi_cleanup.py +++ b/scripts/pypi_cleanup.py @@ -14,10 +14,6 @@ import requests from requests.exceptions import RequestException -import argparse -import re -import os - def valid_hostname(hostname): """Validate hostname format""" if len(hostname) > 253: @@ -41,7 +37,9 @@ def non_empty_string(value): epilog="Environment variables required (unless --dry): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP" ) parser.add_argument("--dry", action="store_true", help="Show what would be deleted but don't actually do it") -parser.add_argument("-i", "--index-hostname", type=valid_hostname, required=True, help="Index hostname (required)") +host_group = parser.add_mutually_exclusive_group(required=True) +host_group.add_argument("--prod", action="store_true", help="Use production PyPI (pypi.org)") +host_group.add_argument("--test", action="store_true", help="Use test PyPI (test.pypi.org)") parser.add_argument("-m", "--max-nightlies", type=int, default=2, help="Max number of nightlies of unreleased versions (default=2)") parser.add_argument("-u", "--username", type=non_empty_string, help="Username (required unless --dry)") args = parser.parse_args() @@ -62,21 +60,25 @@ def non_empty_string(value): if not otp: parser.error("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") -print(f"Dry run: {args.dry}") -print(f"Max nightlies: {args.max_nightlies}") +dry_run = args.dry +pypi_username = args.username +max_dev_releases = args.max_nightlies +host = None +if args.prod: + host = 'pypi.org' +elif args.test: + host = 'test.pypi.org' +pypi_url = 'https://{}/'.format(host) +pypi_password = password +pypi_otp = otp + +print(f"Dry run: {dry_run}") +print(f"Max nightlies: {max_dev_releases}") if not args.dry: - print(f"Hostname: {args.index_hostname}") - print(f"Username: {args.username}") + print(f"URL: {pypi_url}") + print(f"Username: {pypi_username}") print("Password and OTP loaded from environment variables") -# deletes old dev wheels from pypi. evil hack. -actually_delete = not args.dry -pypi_username = args.username or "user" -max_dev_releases = args.max_nightlies -host = 'https://{}/'.format(args.index_hostname) -pypi_password = password or "password" -pypi_otp = otp or "otp" - patterns = [re.compile(r".*\.dev\d+$")] ###### NOTE: This code is taken from the pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) class CsfrParser(HTMLParser): @@ -300,4 +302,4 @@ def run(self): logging.info(f"Would be deleting {self.package} version {pkg_ver}, but not doing it!") -PypiCleanup(host, pypi_username, 'duckdb', pypi_password, pypi_otp, patterns, actually_delete, max_dev_releases).run() +PypiCleanup(pypi_url, pypi_username, 'duckdb', pypi_password, pypi_otp, patterns, not dry_run, max_dev_releases).run() From 8695fafadc903284cfa45332658a659354851aa2 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 24 Jul 2025 16:05:11 +0200 Subject: [PATCH 048/472] use globals --- scripts/pypi_cleanup.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scripts/pypi_cleanup.py b/scripts/pypi_cleanup.py index cf3548b6..a1da4536 100644 --- a/scripts/pypi_cleanup.py +++ b/scripts/pypi_cleanup.py @@ -14,6 +14,9 @@ import requests from requests.exceptions import RequestException +_PYPI_URL_PROD = 'https://pypi.org/' +_PYPI_URL_TEST = 'https://test.pypi.org/' + def valid_hostname(hostname): """Validate hostname format""" if len(hostname) > 253: @@ -63,12 +66,10 @@ def non_empty_string(value): dry_run = args.dry pypi_username = args.username max_dev_releases = args.max_nightlies -host = None if args.prod: - host = 'pypi.org' -elif args.test: - host = 'test.pypi.org' -pypi_url = 'https://{}/'.format(host) + pypi_url = _PYPI_URL_PROD +else: + pypi_url = _PYPI_URL_TEST pypi_password = password pypi_otp = otp From d1c0ea447955878df4319a34871788bb0aec436c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 25 Jul 2025 08:06:56 +0200 Subject: [PATCH 049/472] Fix workflow --- .github/workflows/on_external_dispatch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 66afde42..a9895634 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -118,7 +118,7 @@ jobs: publish_to_pypi: name: Upload Artifacts to PyPI needs: [ commit_submodule, upload_s3 ] - if: ${{ force-version == '' }} + if: ${{ inputs.force-version == '' }} uses: ./.github/workflows/upload_to_pypi.yml with: sha: ${{ needs.commit_submodule.outputs.sha-after-commit }} From ed6ea21a4a8af0413d6245c4cf22ff13cd01bb5d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 25 Jul 2025 06:10:52 +0000 Subject: [PATCH 050/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 2ca36309..8755ee6e 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 2ca36309ec79d076a2023323aea168616d48e17a +Subproject commit 8755ee6e1c6aaace193466373c9f46635969576e From 56422fb2286bb6b6e4dcfdfad5f80eeea73868ab Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 25 Jul 2025 06:48:05 +0000 Subject: [PATCH 051/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index c8164851..e78a4989 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit c8164851be62bcad38c080e975c577ff951b16be +Subproject commit e78a49891b34c25c78cccf2c88de775d49750775 From a825dc675a370bd5284847062747471a152fd5ff Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 25 Jul 2025 14:02:05 +0200 Subject: [PATCH 052/472] fix gh environment name --- .github/workflows/on_external_dispatch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index a9895634..6e07d764 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -122,4 +122,4 @@ jobs: uses: ./.github/workflows/upload_to_pypi.yml with: sha: ${{ needs.commit_submodule.outputs.sha-after-commit }} - environment: pypi.production + environment: production.pypi From a54d813a57d5a64e97d3e0273c880c7a494cb8f0 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 25 Jul 2025 17:16:40 +0200 Subject: [PATCH 053/472] Cleanup script tests passing --- .github/workflows/cleanup_pypi.yml | 5 +- pyproject.toml | 1 + scripts/pypi_cleanup.py | 306 ----------------------------- 3 files changed, 3 insertions(+), 309 deletions(-) delete mode 100644 scripts/pypi_cleanup.py diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 5026ac52..e1aad516 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -51,10 +51,9 @@ jobs: - name: Run Cleanup run: | set -x - pypi_index_flag=${{ inputs.environment == 'production.pypi' && '--prod' || '--test' }} uv sync --only-group pypi --no-install-project - uv run --no-sync -s scripts/pypi_cleanup.py ${{ inputs.dry-run && '--dry' || '' }} \ - ${pypi_index_flag} \ + uv run --no-sync python -u -m duckdb_packaging.pypi_cleanup ${{ inputs.dry-run && '--dry' || '' }} \ + ${{ inputs.environment == 'production.pypi' && '--prod' || '--test' }} \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} > cleanup_output 2>&1 diff --git a/pyproject.toml b/pyproject.toml index 2b70ba71..2ade07cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,6 +235,7 @@ test = [ # dependencies used for running tests "polars", "psutil", "py4j", + "pyotp", "pyspark", "pytz", "requests", diff --git a/scripts/pypi_cleanup.py b/scripts/pypi_cleanup.py deleted file mode 100644 index a1da4536..00000000 --- a/scripts/pypi_cleanup.py +++ /dev/null @@ -1,306 +0,0 @@ -import argparse -import pyotp -import datetime -import logging -import os -import re -import sys -import time -from collections import defaultdict -from html.parser import HTMLParser -from textwrap import dedent -from urllib.parse import urlparse - -import requests -from requests.exceptions import RequestException - -_PYPI_URL_PROD = 'https://pypi.org/' -_PYPI_URL_TEST = 'https://test.pypi.org/' - -def valid_hostname(hostname): - """Validate hostname format""" - if len(hostname) > 253: - raise argparse.ArgumentTypeError("Hostname too long (max 253 characters)") - - # Check for valid hostname pattern - hostname_pattern = r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$' - if not re.match(hostname_pattern, hostname): - raise argparse.ArgumentTypeError(f"Invalid hostname format: {hostname}") - - return hostname - -def non_empty_string(value): - """Validate non-empty string""" - if not value or not value.strip(): - raise argparse.ArgumentTypeError("Value cannot be empty") - return value.strip() - -parser = argparse.ArgumentParser( - description="PyPI cleanup script", - epilog="Environment variables required (unless --dry): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP" -) -parser.add_argument("--dry", action="store_true", help="Show what would be deleted but don't actually do it") -host_group = parser.add_mutually_exclusive_group(required=True) -host_group.add_argument("--prod", action="store_true", help="Use production PyPI (pypi.org)") -host_group.add_argument("--test", action="store_true", help="Use test PyPI (test.pypi.org)") -parser.add_argument("-m", "--max-nightlies", type=int, default=2, help="Max number of nightlies of unreleased versions (default=2)") -parser.add_argument("-u", "--username", type=non_empty_string, help="Username (required unless --dry)") -args = parser.parse_args() - -# Handle secrets from environment variables -password = None -otp = None - -if not args.dry: - if not args.username: - parser.error("--username is required when not in dry-run mode") - - password = os.getenv('PYPI_CLEANUP_PASSWORD') - otp = os.getenv('PYPI_CLEANUP_OTP') - - if not password: - parser.error("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") - if not otp: - parser.error("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") - -dry_run = args.dry -pypi_username = args.username -max_dev_releases = args.max_nightlies -if args.prod: - pypi_url = _PYPI_URL_PROD -else: - pypi_url = _PYPI_URL_TEST -pypi_password = password -pypi_otp = otp - -print(f"Dry run: {dry_run}") -print(f"Max nightlies: {max_dev_releases}") -if not args.dry: - print(f"URL: {pypi_url}") - print(f"Username: {pypi_username}") - print("Password and OTP loaded from environment variables") - -patterns = [re.compile(r".*\.dev\d+$")] -###### NOTE: This code is taken from the pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) -class CsfrParser(HTMLParser): - def __init__(self, target, contains_input=None): - super().__init__() - self._target = target - self._contains_input = contains_input - self.csrf = None # Result value from all forms on page - self._csrf = None # Temp value from current form - self._in_form = False # Currently parsing a form with an action we're interested in - self._input_contained = False # Input field requested is contained in the current form - - def handle_starttag(self, tag, attrs): - if tag == "form": - attrs = dict(attrs) - action = attrs.get("action") # Might be None. - if action and (action == self._target or action.startswith(self._target)): - self._in_form = True - return - - if self._in_form and tag == "input": - attrs = dict(attrs) - if attrs.get("name") == "csrf_token": - self._csrf = attrs["value"] - - if self._contains_input and attrs.get("name") == self._contains_input: - self._input_contained = True - - return - - def handle_endtag(self, tag): - if tag == "form": - self._in_form = False - # If we're in a right form that contains the requested input and csrf is not set - if (not self._contains_input or self._input_contained) and not self.csrf: - self.csrf = self._csrf - return - - -class PypiCleanup: - def __init__(self, url, username, package, password, otp, patterns, delete, max_dev_releases, verbose=False): - self.url = urlparse(url).geturl() - if self.url[-1] == "/": - self.url = self.url[:-1] - self.username = username - self.password = password - self.otp = otp - self.do_it = delete - self.package = package - self.patterns = patterns - self.max_dev_releases = max_dev_releases - self.verbose = verbose - - def run(self): - csrf = None - - if self.verbose: - logging.root.setLevel(logging.DEBUG) - - if self.do_it: - logging.warning("!!! WILL ACTUALLY DELETE THINGS !!!") - logging.warning("Will sleep for 3 seconds - Ctrl-C to abort!") - time.sleep(3.0) - else: - logging.info("Running in DRY RUN mode") - - logging.info(f"Will use the following patterns {self.patterns} on package {self.package}") - - with requests.Session() as s: - with s.get(f"{self.url}/pypi/{self.package}/json") as r: - try: - r.raise_for_status() - except RequestException as e: - logging.error(f"Unable to find package {repr(self.package)}", exc_info=e) - return 1 - - releases_by_date = {} - for release, files in r.json()["releases"].items(): - releases_by_date[release] = max( - [datetime.datetime.strptime(f["upload_time"], '%Y-%m-%dT%H:%M:%S') for f in files] - ) - - if not releases_by_date: - logging.info(f"No releases for package {self.package} have been found") - return - - version_dict = defaultdict(list) - releases = [] - for key in releases_by_date.keys(): - if '.dev' in key: - prefix, postfix = key.split('.dev') - version_dict[prefix].append(key) - - pkg_vers = [] - for version_key, versions in version_dict.items(): - # releases_by_date.keys() is a list of release versions, so when the version key appears in that list, - # that means the version have been released and we don't need to keep PRE-RELEASE (dev) versions anymore. - # All versions for that key should be added into a list to delete from PyPi (pkg_vers). - # When the version is not released yet, it appears among the version_dict keys. In this case we'd like to keep - # some number of versions (self.max_dev_releases), so we add the version names from the beginning - # of the versions list sorted by date, except for mentioned number of versions to keep. - if version_key in releases_by_date.keys() or self.max_dev_releases == 0: - pkg_vers.extend(versions) - else: - # sort by the suffix casted to int to keep only the most recent builds - sorted_versions = sorted(versions, key=lambda x: int(x.split('dev')[-1])) - pkg_vers.extend(sorted_versions[:-self.max_dev_releases]) - - print("Following pkg_vers can be deleted: ", pkg_vers) - if not self.do_it: - return - - if not pkg_vers: - logging.info(f"No releases were found matching specified patterns and dates in package {self.package}") - return - - if set(pkg_vers) == set(releases_by_date.keys()): - print( - dedent( - f""" - WARNING: - \tYou have selected the following patterns: {self.patterns} - \tThese patterns would delete all available released versions of `{self.package}`. - \tThis will render your project/package permanently inaccessible. - \tSince the costs of an error are too high I'm refusing to do this. - \tGoodbye. - """ - ), - file=sys.stderr, - ) - - if not self.do_it: - return 3 - for pkg in pkg_vers: - if 'dev' not in pkg: - raise Exception(f"Would be deleting version {pkg} but the version is not a dev version") - - if self.username is None: - raise Exception("No username provided") - - if self.password is None: - raise Exception("No password provided") - - with s.get(f"{self.url}/account/login/") as r: - r.raise_for_status() - form_action = "/account/login/" - parser = CsfrParser(form_action) - parser.feed(r.text) - if not parser.csrf: - raise ValueError(f"No CSFR found in {form_action}") - csrf = parser.csrf - - two_factor = False - with s.post( - f"{self.url}/account/login/", - data={"csrf_token": csrf, "username": self.username, "password": self.password}, - headers={"referer": f"{self.url}/account/login/"}, - ) as r: - r.raise_for_status() - if r.url == f"{self.url}/account/login/": - logging.error(f"Login for user {self.username} failed") - return 1 - - if r.url.startswith(f"{self.url}/account/two-factor/"): - form_action = r.url[len(self.url) :] - parser = CsfrParser(form_action) - parser.feed(r.text) - if not parser.csrf: - raise ValueError(f"No CSFR found in {form_action}") - csrf = parser.csrf - two_factor = True - two_factor_url = r.url - - if two_factor: - success = False - for i in range(3): - auth_code = pyotp.TOTP(self.otp).now() - with s.post( - two_factor_url, - data={"csrf_token": csrf, "method": "totp", "totp_value": auth_code}, - headers={"referer": two_factor_url}, - ) as r: - r.raise_for_status() - if r.url == two_factor_url: - logging.error(f"Authentication code {auth_code} is invalid, retrying in 5 seconds...") - time.sleep(5) - else: - success = True - break - if not success: - raise Exception("Could not authenticate with OTP") - - for pkg_ver in pkg_vers: - if 'dev' not in pkg_ver: - raise Exception(f"Would be deleting version {pkg_ver} but the version is not a dev version") - if self.do_it: - logging.info(f"Deleting {self.package} version {pkg_ver}") - form_action = f"/manage/project/{self.package}/release/{pkg_ver}/" - form_url = f"{self.url}{form_action}" - with s.get(form_url) as r: - r.raise_for_status() - parser = CsfrParser(form_action, "confirm_delete_version") - parser.feed(r.text) - if not parser.csrf: - raise ValueError(f"No CSFR found in {form_action}") - csrf = parser.csrf - referer = r.url - - with s.post( - form_url, - data={ - "csrf_token": csrf, - "confirm_delete_version": pkg_ver, - }, - headers={"referer": referer}, - ) as r: - r.raise_for_status() - - logging.info(f"Deleted {self.package} version {pkg_ver}") - else: - logging.info(f"Would be deleting {self.package} version {pkg_ver}, but not doing it!") - - -PypiCleanup(pypi_url, pypi_username, 'duckdb', pypi_password, pypi_otp, patterns, not dry_run, max_dev_releases).run() From a6afa86b09b1a778f4403d9bbfd54f113c87b3c7 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 25 Jul 2025 17:35:40 +0200 Subject: [PATCH 054/472] Forgot to add files --- duckdb_packaging/pypi_cleanup.py | 576 +++++++++++++++++++++++++++++++ tests/fast/test_pypi_cleanup.py | 427 +++++++++++++++++++++++ 2 files changed, 1003 insertions(+) create mode 100644 duckdb_packaging/pypi_cleanup.py create mode 100644 tests/fast/test_pypi_cleanup.py diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py new file mode 100644 index 00000000..d386a606 --- /dev/null +++ b/duckdb_packaging/pypi_cleanup.py @@ -0,0 +1,576 @@ +""" +!!HERE BE DRAGONS!! Use this script with care! + +PyPI package cleanup tool. This script will: +* Never remove a stable version (including a post release version) +* Remove all release candidates for versions that have stable releases +* Remove all dev releases for versions that have stable releases +* Keep the configured amount of dev releases per version, and remove older dev releases +""" + +import argparse +import contextlib +import datetime +import heapq +import logging +import os +import re +import sys +import time +from collections import defaultdict +from html.parser import HTMLParser +from typing import Dict, Optional, Set, Generator +from urllib.parse import urlparse + +import pyotp +import requests +from requests import Session +from requests.adapters import HTTPAdapter +from requests.exceptions import RequestException +from urllib3 import Retry + +_PYPI_URL_PROD = 'https://pypi.org/' +_PYPI_URL_TEST = 'https://test.pypi.org/' +_DEFAULT_MAX_NIGHTLIES = 2 +_LOGIN_RETRY_ATTEMPTS = 3 +_LOGIN_RETRY_DELAY = 5 + + +def create_argument_parser() -> argparse.ArgumentParser: + """Create and configure the argument parser.""" + parser = argparse.ArgumentParser( + description=""" +PyPI cleanup script for removing development versions. + +!!HERE BE DRAGONS!! Use this script with care! + +This script will: +* Never remove a stable version (including a post release version) +* Remove all release candidates for versions that have stable releases +* Remove all dev releases for versions that have stable releases +* Keep the configured amount of dev releases per version, and remove older dev releases + """, + epilog="Environment variables required (unless --dry-run): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be deleted but don't actually do it" + ) + + host_group = parser.add_mutually_exclusive_group(required=True) + host_group.add_argument( + "--prod", + action="store_true", + help="Use production PyPI (pypi.org)" + ) + host_group.add_argument( + "--test", + action="store_true", + help="Use test PyPI (test.pypi.org)" + ) + + parser.add_argument( + "-m", "--max-nightlies", + type=int, + default=_DEFAULT_MAX_NIGHTLIES, + help=f"Max number of nightlies of unreleased versions (default={_DEFAULT_MAX_NIGHTLIES})" + ) + + parser.add_argument( + "-u", "--username", + type=validate_username, + help="PyPI username (required unless --dry-run)" + ) + + parser.add_argument( + "-v", "--verbose", + action="store_true", + help="Enable verbose debug logging" + ) + + return parser + +class PyPICleanupError(Exception): + """Base exception for PyPI cleanup operations.""" + pass + + +class AuthenticationError(PyPICleanupError): + """Raised when authentication fails.""" + pass + + +class ValidationError(PyPICleanupError): + """Raised when input validation fails.""" + pass + + +def setup_logging(verbose: bool = False) -> None: + """Configure logging with appropriate level and format.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + +def validate_username(value: str) -> str: + """Validate and sanitize username input.""" + if not value or not value.strip(): + raise argparse.ArgumentTypeError("Username cannot be empty") + + username = value.strip() + if len(username) > 100: # Reasonable limit + raise argparse.ArgumentTypeError("Username too long (max 100 characters)") + + # Basic validation - PyPI usernames are alphanumeric with limited special chars + if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$', username): + raise argparse.ArgumentTypeError("Invalid username format") + + return username + +@contextlib.contextmanager +def session_with_retries() -> Generator[Session, None, None]: + """Create a requests session with retry strategy for ephemeral errors.""" + with requests.Session() as session: + retry_strategy = Retry( + allowed_methods=["GET", "POST"], + total=None, # disable to make the below take effect + redirect=10, # Don't follow more than 10 redirects in a row + connect=3, # try 3 times before giving up on connection errors + read=3, # try 3 times before giving up on read errors + status=3, # try 3 times before giving up on status errors (see forcelist below) + status_forcelist=[429] + [status for status in range(500, 512)], + other=0, # whatever else may cause an error should break + backoff_factor=0.1, # [0.0s, 0.2s, 0.4s] + raise_on_redirect=True, # raise exception when redirect error retries are exhausted + raise_on_status=True, # raise exception when status error retries are exhausted + respect_retry_after_header=True, # respect Retry-After headers + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + session.mount("https://", adapter) + yield session + +def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: + """Load credentials from environment variables.""" + if dry_run: + return None, None + + password = os.getenv('PYPI_CLEANUP_PASSWORD') + otp = os.getenv('PYPI_CLEANUP_OTP') + + if not password: + raise ValidationError("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") + if not otp: + raise ValidationError("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") + + return password, otp + + +def validate_arguments(args: argparse.Namespace) -> None: + """Validate parsed arguments.""" + if not args.dry_run and not args.username: + raise ValidationError("--username is required when not in dry-run mode") + + if args.max_nightlies < 0: + raise ValidationError("--max-nightlies must be non-negative") + +class CsrfParser(HTMLParser): + """HTML parser to extract CSRF tokens from PyPI forms. + + Based on pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) + """ + def __init__(self, target, contains_input=None): + super().__init__() + self._target = target + self._contains_input = contains_input + self.csrf = None # Result value from all forms on page + self._csrf = None # Temp value from current form + self._in_form = False # Currently parsing a form with an action we're interested in + self._input_contained = False # Input field requested is contained in the current form + + def handle_starttag(self, tag, attrs): + if tag == "form": + attrs = dict(attrs) + action = attrs.get("action") # Might be None. + if action and (action == self._target or action.startswith(self._target)): + self._in_form = True + return + + if self._in_form and tag == "input": + attrs = dict(attrs) + if attrs.get("name") == "csrf_token": + self._csrf = attrs["value"] + + if self._contains_input and attrs.get("name") == self._contains_input: + self._input_contained = True + + return + + def handle_endtag(self, tag): + if tag == "form": + self._in_form = False + # If we're in a right form that contains the requested input and csrf is not set + if (not self._contains_input or self._input_contained) and not self.csrf: + self.csrf = self._csrf + return + + +class PyPICleanup: + """Main class for performing PyPI package cleanup operations.""" + + def __init__(self, index_url: str, do_delete: bool, max_dev_releases: int=_DEFAULT_MAX_NIGHTLIES, + username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None): + parsed_url = urlparse(index_url) + self._index_url = parsed_url.geturl().rstrip('/') + self._index_host = parsed_url.hostname + self._do_delete = do_delete + self._max_dev_releases = max_dev_releases + self._username = username + self._password = password + self._otp = otp + self._package = 'duckdb' + self._dev_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.dev(?P\d+)$") + self._rc_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.rc\d+$") + self._stable_version_pattern = re.compile(r"^\d+\.\d+\.\d+(\.post\d+)?$") + + def run(self) -> int: + """Execute the cleanup process. + + Returns: + int: Exit code (0 for success, non-zero for failure) + """ + if self._do_delete: + logging.warning(f"NOT A DRILL: WILL DELETE PACKAGES") + else: + logging.info("Running in DRY RUN mode, nothing will be deleted") + + logging.info(f"Max development releases to keep per unreleased version: {self._max_dev_releases}") + + try: + return self._execute_cleanup() + except PyPICleanupError as e: + logging.error(f"Cleanup failed: {e}") + return 1 + except Exception as e: + logging.error(f"Unexpected error: {e}", exc_info=True) + return 1 + + def _execute_cleanup(self) -> int: + """Execute the main cleanup logic.""" + + # Get released versions + versions = self._fetch_released_versions() + if not versions: + logging.info(f"No releases found for {self._package}") + return 0 + + # Determine versions to delete + versions_to_delete = self._determine_versions_to_delete(versions) + if not versions_to_delete: + logging.info("No versions to delete (no stale rc's or dev releases)") + return 0 + + logging.warning(f"Found {len(versions_to_delete)} versions to clean up:") + for version in sorted(versions_to_delete): + logging.warning(version) + + if not self._do_delete: + logging.info("Dry run complete - no packages were deleted") + return 0 + + # Perform authentication and deletion + self._authenticate() + self._delete_versions(versions_to_delete) + + logging.info(f"Successfully cleaned up {len(versions_to_delete)} development versions") + return 0 + + def _fetch_released_versions(self) -> Set[str]: + """Fetch package release information from PyPI API.""" + logging.debug(f"Fetching package information for '{self._package}'") + + try: + with session_with_retries() as session: + req = session.get(f"{self._index_url}/pypi/{self._package}/json") + req.raise_for_status() + data = req.json() + versions = {v for v, files in data["releases"].items() if len(files) > 0} + logging.debug(f"Found {len(versions)} releases with files") + return versions + except RequestException as e: + raise PyPICleanupError(f"Failed to fetch package information for '{self._package}': {e}") from e + + def _is_stable_release_version(self, version: str) -> bool: + """Determine whether a version string denotes a stable release.""" + return self._stable_version_pattern.match(version) is not None + + def _is_rc_version(self, version: str) -> bool: + """Determine whether a version string denotes a stable release.""" + return self._rc_version_pattern.match(version) is not None + + def _is_dev_version(self, version: str) -> bool: + """Determine whether a version string denotes a dev release.""" + return self._dev_version_pattern.match(version) is not None + + def _parse_rc_version(self, version: str) -> str: + """Parse a rc version string to determine the base version.""" + match = self._rc_version_pattern.match(version) + if not match: + raise PyPICleanupError(f"Invalid rc version '{version}'") + return match.group("version") if match else None + + def _parse_dev_version(self, version: str) -> tuple[str, int]: + """Parse a dev version string to determine the base version and dev version id.""" + match = self._dev_version_pattern.match(version) + if not match: + raise PyPICleanupError(f"Invalid dev version '{version}'") + return match.group("version"), int(match.group("dev_id")) + + def _determine_versions_to_delete(self, versions: Set[str]) -> Set[str]: + """Determine which package versions should be deleted.""" + logging.debug("Analyzing versions to determine cleanup candidates") + + # Get all stable, rc and dev versions + stable_versions = {v for v in versions if self._is_stable_release_version(v)} + rc_versions = {v for v in versions if self._is_rc_version(v)} + rc_base_versions = {self._parse_rc_version(v) for v in versions if self._is_rc_version(v)} + dev_versions = {v for v in versions if self._is_dev_version(v)} + + # Set of all rc releases of versions that have a stable release + rcs_of_stable = {v for v in rc_versions if self._parse_rc_version(v) in stable_versions} + # Set of all dev releases of versions that have a stable or rc release + devs_of_stable = {v for v in dev_versions if self._parse_dev_version(v)[0] in stable_versions} + devs_of_rc = {v for v in dev_versions if self._parse_dev_version(v)[0] in rc_base_versions} + # Set of orphan dev versions + orphan_devs = dev_versions.difference(devs_of_stable).difference(devs_of_rc) + + # Construct list of orphan dev + orphan_devs_per_version = defaultdict(list) + # 1. put all dev keep candidates on a max heap indexed by negative dev id (i.e. dev10 -> -10) + for version in orphan_devs: + base_version, dev_id = self._parse_dev_version(version) + heapq.heappush(orphan_devs_per_version[base_version], (-dev_id, version)) + # 2. remove the amount of latest dev releases we want to keep + for version_list in orphan_devs_per_version.values(): + for _ in range(min(self._max_dev_releases, len(version_list))): + heapq.heappop(version_list) + # 3. Result: set of outdated dev versions + devs_outdated = {v for version_list in orphan_devs_per_version.values() for _, v in version_list} + + # Construct final deletion set + versions_to_delete = set() + if rcs_of_stable: + versions_to_delete.update(rcs_of_stable) + logging.info(f"Found {len(rcs_of_stable)} release candidates that have stable releases") + if devs_of_stable: + versions_to_delete.update(devs_of_stable) + logging.info(f"Found {len(devs_of_stable)} dev releases that have stable releases") + if devs_of_rc: + versions_to_delete.update(devs_of_rc) + logging.info(f"Found {len(devs_of_rc)} dev releases that have release candidates") + if devs_outdated: + versions_to_delete.update(devs_outdated) + logging.info(f"Found {len(devs_outdated)} dev releases that are outdated") + + # Final safety checks + if versions_to_delete == versions: + raise PyPICleanupError( + f"Safety check failed: cleanup would delete ALL versions of '{self._package}'. " + "This would make the package permanently inaccessible. Aborting." + ) + if len(versions_to_delete.intersection(stable_versions)) > 0: + raise PyPICleanupError( + f"Safety check failed: cleanup would delete one or more stable versions of '{self._package}'. " + f"A regexp might be broken? (would delete {versions_to_delete.intersection(stable_versions)})" + ) + unknown_versions = versions.difference(stable_versions).difference(rc_versions).difference(dev_versions) + if unknown_versions: + logging.warning(f"Found version string(s) in an unsupported format: {unknown_versions}") + + return versions_to_delete + + def _authenticate(self) -> None: + """Authenticate with PyPI.""" + if not self._username or not self._password: + raise AuthenticationError("Username and password are required for authentication") + + logging.info(f"Authenticating user '{self._username}' with PyPI") + + try: + # Get login form and CSRF token + csrf_token = self._get_csrf_token("/account/login/") + + # Attempt login + login_response = self._perform_login(csrf_token) + + # Handle two-factor authentication if required + if login_response.url.startswith(f"{self._index_url}/account/two-factor/"): + logging.debug("Two-factor authentication required") + self._handle_two_factor_auth(login_response) + + logging.info("Authentication successful") + + except RequestException as e: + raise AuthenticationError(f"Network error during authentication: {e}") from e + + def _get_csrf_token(self, form_action: str) -> str: + """Extract CSRF token from a form page.""" + with session_with_retries() as session: + req = session.get(f"{self._index_url}{form_action}") + req.raise_for_status() + parser = CsrfParser(form_action) + parser.feed(req.text) + if not parser.csrf: + raise AuthenticationError(f"No CSRF token found in {form_action}") + return parser.csrf + + def _perform_login(self, csrf_token: str) -> requests.Response: + """Perform the initial login with username/password.""" + login_data = { + "csrf_token": csrf_token, + "username": self._username, + "password": self._password + } + + with session_with_retries() as session: + response = session.post( + f"{self._index_url}/account/login/", + data=login_data, + headers={"referer": f"{self._index_url}/account/login/"} + ) + response.raise_for_status() + + # Check if login failed (redirected back to login page) + if response.url == f"{self._index_url}/account/login/": + raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") + + return response + + def _handle_two_factor_auth(self, response: requests.Response) -> None: + """Handle two-factor authentication.""" + if not self._otp: + raise AuthenticationError("Two-factor authentication required but no OTP secret provided") + + two_factor_url = response.url + form_action = two_factor_url[len(self._index_url):] + csrf_token = self._get_csrf_token(form_action) + + # Try authentication with retries + for attempt in range(_LOGIN_RETRY_ATTEMPTS): + try: + auth_code = pyotp.TOTP(self._otp).now() + logging.debug(f"Attempting 2FA with code (attempt {attempt + 1}/{_LOGIN_RETRY_ATTEMPTS})") + + with session_with_retries() as session: + auth_response = session.post( + two_factor_url, + data={"csrf_token": csrf_token, "method": "totp", "totp_value": auth_code}, + headers={"referer": two_factor_url} + ) + auth_response.raise_for_status() + + # Check if 2FA succeeded (redirected away from 2FA page) + if auth_response.url != two_factor_url: + logging.debug("Two-factor authentication successful") + return + + if attempt < _LOGIN_RETRY_ATTEMPTS - 1: + logging.debug(f"2FA code rejected, retrying in {_LOGIN_RETRY_DELAY} seconds...") + time.sleep(_LOGIN_RETRY_DELAY) + + except RequestException as e: + if attempt == _LOGIN_RETRY_ATTEMPTS - 1: + raise AuthenticationError(f"Network error during 2FA: {e}") from e + logging.debug(f"Network error during 2FA attempt {attempt + 1}, retrying...") + time.sleep(_LOGIN_RETRY_DELAY) + + raise AuthenticationError("Two-factor authentication failed after all attempts") + + def _delete_versions(self, versions_to_delete: Set[str]) -> None: + """Delete the specified package versions.""" + logging.info(f"Starting deletion of {len(versions_to_delete)} development versions") + + failed_deletions = list() + for version in sorted(versions_to_delete): + try: + self._delete_single_version(version) + logging.info(f"Successfully deleted {self._package} version {version}") + except Exception as e: + # Continue with other versions rather than failing completely + logging.error(f"Failed to delete version {version}: {e}") + failed_deletions.append(version) + + if failed_deletions: + raise PyPICleanupError( + f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" + ) + + def _delete_single_version(self, version: str) -> None: + """Delete a single package version.""" + # Safety check + if not self._is_dev_version(version) or self._is_rc_version(version): + raise PyPICleanupError(f"Refusing to delete non-[dev|rc] version: {version}") + + logging.debug(f"Deleting {self._package} version {version}") + + # Get deletion form and CSRF token + form_action = f"/manage/project/{self._package}/release/{version}/" + form_url = f"{self._index_url}{form_action}" + + csrf_token = self._get_csrf_token(form_action) + + with session_with_retries() as session: + # Submit deletion request + delete_response = session.post( + form_url, + data={ + "csrf_token": csrf_token, + "confirm_delete_version": version, + }, + headers={"referer": form_url} + ) + delete_response.raise_for_status() + + +def main() -> int: + """Main entry point for the script.""" + parser = create_argument_parser() + args = parser.parse_args() + + # Setup logging + setup_logging(args.verbose) + + try: + # Validate arguments + validate_arguments(args) + + # Load credentials + password, otp = load_credentials(args.dry_run) + + # Determine PyPI URL + pypi_url = _PYPI_URL_PROD if args.prod else _PYPI_URL_TEST + + # Create and run cleanup + cleanup = PyPICleanup(pypi_url, not args.dry_run, args.max_nightlies, username=args.username, + password=password, otp=otp) + + return cleanup.run() + + except ValidationError as e: + logging.error(f"Configuration error: {e}") + return 2 + except KeyboardInterrupt: + logging.info("Operation cancelled by user") + return 130 + except Exception as e: + logging.error(f"Unexpected error: {e}", exc_info=args.verbose) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py new file mode 100644 index 00000000..d90f9fff --- /dev/null +++ b/tests/fast/test_pypi_cleanup.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python3 +""" +Unit tests for pypi_cleanup.py + +Run with: python -m pytest test_pypi_cleanup.py -v +""" + +import os +from unittest.mock import Mock, patch + +import pytest +import requests +from urllib3 import Retry + +duckdb_packaging = pytest.importorskip("duckdb_packaging") + +from duckdb_packaging.pypi_cleanup import ( + PyPICleanup, CsrfParser, PyPICleanupError, AuthenticationError, ValidationError, + setup_logging, validate_username, create_argument_parser, session_with_retries, + load_credentials, validate_arguments, main +) + +class TestValidation: + """Test input validation functions.""" + + def test_validate_username_valid(self): + """Test valid usernames.""" + assert validate_username("user123") == "user123" + assert validate_username(" user.name ") == "user.name" + assert validate_username("test-user_name") == "test-user_name" + assert validate_username("a") == "a" + + def test_validate_username_invalid(self): + """Test invalid usernames.""" + from argparse import ArgumentTypeError + + with pytest.raises(ArgumentTypeError, match="cannot be empty"): + validate_username("") + + with pytest.raises(ArgumentTypeError, match="cannot be empty"): + validate_username(" ") + + with pytest.raises(ArgumentTypeError, match="too long"): + validate_username("a" * 101) + + with pytest.raises(ArgumentTypeError, match="Invalid username format"): + validate_username("-invalid") + + with pytest.raises(ArgumentTypeError, match="Invalid username format"): + validate_username("invalid-") + + def test_validate_arguments_dry_run(self): + """Test argument validation for dry run mode.""" + args = Mock(dry_run=True, username=None, max_nightlies=2) + validate_arguments(args) # Should not raise + + def test_validate_arguments_live_mode_no_username(self): + """Test argument validation for live mode without username.""" + args = Mock(dry_run=False, username=None, max_nightlies=2) + with pytest.raises(ValidationError, match="username is required"): + validate_arguments(args) + + def test_validate_arguments_negative_nightlies(self): + """Test argument validation with negative max nightlies.""" + args = Mock(dry_run=True, username="test", max_nightlies=-1) + with pytest.raises(ValidationError, match="must be non-negative"): + validate_arguments(args) + + +class TestCredentials: + """Test credential loading.""" + + def test_load_credentials_dry_run(self): + """Test credential loading in dry run mode.""" + password, otp = load_credentials(dry_run=True) + assert password is None + assert otp is None + + @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass', 'PYPI_CLEANUP_OTP': 'test_otp'}) + def test_load_credentials_live_mode_success(self): + """Test successful credential loading in live mode.""" + password, otp = load_credentials(dry_run=False) + assert password == 'test_pass' + assert otp == 'test_otp' + + @patch.dict(os.environ, {}, clear=True) + def test_load_credentials_missing_password(self): + """Test credential loading with missing password.""" + with pytest.raises(ValidationError, match="PYPI_CLEANUP_PASSWORD"): + load_credentials(dry_run=False) + + @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass'}) + def test_load_credentials_missing_otp(self): + """Test credential loading with missing OTP.""" + with pytest.raises(ValidationError, match="PYPI_CLEANUP_OTP"): + load_credentials(dry_run=False) + + +class TestUtilities: + """Test utility functions.""" + + def test_create_session_with_retries(self): + """Test session creation with retry configuration.""" + with session_with_retries() as session: + assert isinstance(session, requests.Session) + # Verify retry adapter is mounted + adapter = session.get_adapter("https://example.com") + assert hasattr(adapter, 'max_retries') + retries = getattr(adapter, 'max_retries') + assert isinstance(retries, Retry) + + @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + def test_setup_logging_normal(self, mock_basicConfig): + """Test logging setup in normal mode.""" + setup_logging(verbose=False) + mock_basicConfig.assert_called_once() + call_args = mock_basicConfig.call_args[1] + assert call_args['level'] == 20 # INFO level + + @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + def test_setup_logging_verbose(self, mock_basicConfig): + """Test logging setup in verbose mode.""" + setup_logging(verbose=True) + mock_basicConfig.assert_called_once() + call_args = mock_basicConfig.call_args[1] + assert call_args['level'] == 10 # DEBUG level + + +class TestCsrfParser: + """Test CSRF token parser.""" + + def test_csrf_parser_simple_form(self): + """Test parsing CSRF token from simple form.""" + html = ''' +
+ + +
+ ''' + parser = CsrfParser("/test") + parser.feed(html) + assert parser.csrf == "abc123" + + def test_csrf_parser_multiple_forms(self): + """Test parsing CSRF token when multiple forms exist.""" + html = ''' +
+ +
+
+ +
+ ''' + parser = CsrfParser("/test") + parser.feed(html) + assert parser.csrf == "correct" + + def test_csrf_parser_no_token(self): + """Test parser when no CSRF token is found.""" + html = '
' + parser = CsrfParser("/test") + parser.feed(html) + assert parser.csrf is None + + +class TestPyPICleanup: + """Test the main PyPICleanup class.""" + @pytest.fixture + def cleanup_dryrun_max_2(self) -> PyPICleanup: + return PyPICleanup("https://test.pypi.org/", False, 2) + + @pytest.fixture + def cleanup_dryrun_max_0(self) -> PyPICleanup: + return PyPICleanup("https://test.pypi.org/", False, 0) + + @pytest.fixture + def cleanup_max_2(self) -> PyPICleanup: + return PyPICleanup("https://test.pypi.org/", True, 2, + username="", password="", otp="") + + def test_determine_versions_to_delete_max_2(self, cleanup_dryrun_max_2): + start_state = { + "0.1.0", + "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", + "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", + "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", + "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", + "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + } + expected_deletions = { + "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", + "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", + "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", + "1.1.1.dev142", + "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", + "2.0.1.dev974" + } + versions_to_delete = cleanup_dryrun_max_2._determine_versions_to_delete(start_state) + assert versions_to_delete == expected_deletions + + def test_determine_versions_to_delete_max_0(self, cleanup_dryrun_max_0): + start_state = { + "0.1.0", + "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", + "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", + "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", + "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", + "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + } + expected_deletions = { + "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", + "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", + "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", + "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", + "2.0.1.dev974" + } + versions_to_delete = cleanup_dryrun_max_0._determine_versions_to_delete(start_state) + assert versions_to_delete == expected_deletions + + def test_determine_versions_to_delete_only_devs_max_2(self, cleanup_dryrun_max_2): + start_state = { + "1.0.0.dev1", "1.0.0.dev2", + "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", + "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", + "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "2.0.0.dev602", + "2.0.1.dev974", + } + expected_deletions = { + "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", + "1.1.0.dev34", + "1.1.1.dev142", + } + versions_to_delete = cleanup_dryrun_max_2._determine_versions_to_delete(start_state) + assert versions_to_delete == expected_deletions + + def test_determine_versions_to_delete_only_devs_max_0_fails(self, cleanup_dryrun_max_0): + start_state = { + "1.0.0.dev1", "1.0.0.dev2", + "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", + "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", + "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "2.0.0.dev602", + "2.0.1.dev974", + } + with pytest.raises(PyPICleanupError, match="Safety check failed"): + cleanup_dryrun_max_0._determine_versions_to_delete(start_state) + + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_versions') + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._determine_versions_to_delete') + def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, cleanup_dryrun_max_2): + mock_fetch.return_value = {"1.0.0.dev1"} + mock_determine.return_value = {"1.0.0.dev1"} + + result = cleanup_dryrun_max_2._execute_cleanup() + + assert result == 0 + mock_fetch.assert_called_once() + mock_determine.assert_called_once() + mock_delete.assert_not_called() + + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') + def test_execute_cleanup_no_releases(self, mock_fetch, cleanup_dryrun_max_2): + mock_fetch.return_value = {} + result = cleanup_dryrun_max_2._execute_cleanup() + assert result == 0 + + @patch('requests.Session.get') + def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): + """Test successful package release fetching.""" + mock_response = Mock() + mock_response.json.return_value = { + "releases": { + "1.0.0": [{"upload_time": "2023-01-01T10:00:00"}], + "1.0.0.dev1": [{"upload_time": "2022-12-01T10:00:00"}], + } + } + mock_get.return_value = mock_response + + releases = cleanup_dryrun_max_2._fetch_released_versions() + + assert releases == {"1.0.0", "1.0.0.dev1"} + + @patch('requests.Session.get') + def test_fetch_released_versions_not_found(self, mock_get, cleanup_dryrun_max_2): + """Test package release fetching when package not found.""" + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("404") + mock_get.return_value = mock_response + + with pytest.raises(PyPICleanupError, match="Failed to fetch package information"): + cleanup_dryrun_max_2._fetch_released_versions() + + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._perform_login') + def test_authenticate_success(self, mock_login, mock_csrf, cleanup_max_2): + """Test successful authentication.""" + mock_csrf.return_value = "csrf123" + mock_response = Mock() + mock_response.url = "https://test.pypi.org/manage/" + mock_login.return_value = mock_response + + cleanup_max_2._authenticate() # Should not raise + + mock_csrf.assert_called_once_with("/account/login/") + mock_login.assert_called_once_with("csrf123") + + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._perform_login') + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth') + def test_authenticate_with_2fa(self, mock_2fa, mock_login, mock_csrf, cleanup_max_2): + mock_csrf.return_value = "csrf123" + mock_response = Mock() + mock_response.url = "https://test.pypi.org/account/two-factor/totp" + mock_login.return_value = mock_response + + cleanup_max_2._authenticate() + + mock_2fa.assert_called_once_with(mock_response) + + def test_authenticate_missing_credentials(self, cleanup_dryrun_max_2): + with pytest.raises(AuthenticationError, match="Username and password are required"): + cleanup_dryrun_max_2._authenticate() + + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + def test_delete_versions_success(self, mock_delete, cleanup_max_2): + """Test successful version deletion.""" + versions = {"1.0.0.dev1", "1.0.0.dev2"} + mock_delete.side_effect = [None, None] # Successful deletions + + cleanup_max_2._delete_versions(versions) + + assert mock_delete.call_count == 2 + + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + def test_delete_versions_partial_failure(self, mock_delete, cleanup_max_2): + """Test version deletion with partial failures.""" + versions = {"1.0.0.dev1", "1.0.0.dev2"} + mock_delete.side_effect = [None, Exception("Delete failed")] + + with pytest.raises(PyPICleanupError, match="Failed to delete 1/2 versions"): + cleanup_max_2._delete_versions(versions) + + def test_delete_single_version_safety_check(self, cleanup_max_2): + """Test single version deletion safety check.""" + with pytest.raises(PyPICleanupError, match="Refusing to delete non-\\[dev\\|rc\\] version"): + cleanup_max_2._delete_single_version("1.0.0") # Non-dev version + + +class TestArgumentParser: + """Test command line argument parsing.""" + + def test_argument_parser_creation(self): + """Test argument parser creation.""" + parser = create_argument_parser() + assert parser.prog is not None + + def test_parse_args_prod_dry_run(self): + """Test parsing arguments for production dry run.""" + parser = create_argument_parser() + args = parser.parse_args(['--prod', '--dry-run']) + + assert args.prod is True + assert args.test is False + assert args.dry_run is True + assert args.max_nightlies == 2 + assert args.verbose is False + + def test_parse_args_test_with_username(self): + """Test parsing arguments for test with username.""" + parser = create_argument_parser() + args = parser.parse_args(['--test', '-u', 'testuser', '--verbose']) + + assert args.test is True + assert args.prod is False + assert args.username == 'testuser' + assert args.verbose is True + + def test_parse_args_missing_host(self): + """Test parsing arguments with missing host selection.""" + parser = create_argument_parser() + + with pytest.raises(SystemExit): + parser.parse_args(['--dry-run']) # Missing --prod or --test + + +class TestMainFunction: + """Test the main function.""" + + @patch('duckdb_packaging.pypi_cleanup.setup_logging') + @patch('duckdb_packaging.pypi_cleanup.PyPICleanup') + @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test', 'PYPI_CLEANUP_OTP': 'test'}) + def test_main_success(self, mock_cleanup_class, mock_setup_logging): + """Test successful main function execution.""" + mock_cleanup = Mock() + mock_cleanup.run.return_value = 0 + mock_cleanup_class.return_value = mock_cleanup + + with patch('sys.argv', ['pypi_cleanup.py', '--test', '-u', 'testuser']): + result = main() + + assert result == 0 + mock_setup_logging.assert_called_once() + mock_cleanup.run.assert_called_once() + + @patch('duckdb_packaging.pypi_cleanup.setup_logging') + def test_main_validation_error(self, mock_setup_logging): + """Test main function with validation error.""" + with patch('sys.argv', ['pypi_cleanup.py', '--test']): # Missing username for live mode + result = main() + + assert result == 2 # Validation error exit code + + @patch('duckdb_packaging.pypi_cleanup.setup_logging') + @patch('duckdb_packaging.pypi_cleanup.validate_arguments') + def test_main_keyboard_interrupt(self, mock_validate, mock_setup_logging): + """Test main function with keyboard interrupt.""" + mock_validate.side_effect = KeyboardInterrupt() + + with patch('sys.argv', ['pypi_cleanup.py', '--test', '--dry-run']): + result = main() + + assert result == 130 # Keyboard interrupt exit code From 6d36321ea45165ed81cd099f1b8efdef07492e79 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 25 Jul 2025 17:46:33 +0200 Subject: [PATCH 055/472] Always use dev pypi --- .github/workflows/cleanup_pypi.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index e1aad516..5ceb62c3 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -49,11 +49,14 @@ jobs: version: "0.7.14" - name: Run Cleanup + env: + PYTHON_UNBUFFERED: 1 run: | set -x uv sync --only-group pypi --no-install-project + # TODO: set test/prod flag according to env (inputs.environment == 'production.pypi' && '--prod' || '--test') uv run --no-sync python -u -m duckdb_packaging.pypi_cleanup ${{ inputs.dry-run && '--dry' || '' }} \ - ${{ inputs.environment == 'production.pypi' && '--prod' || '--test' }} \ + --test \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} > cleanup_output 2>&1 From 760cc25ad051cfd506e79f7bd48507a4fa87b764 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 25 Jul 2025 17:48:01 +0200 Subject: [PATCH 056/472] Revert submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 4245d5e3..c1062e49 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 4245d5e37981f41026d829b10a90b44b80c94f3a +Subproject commit c1062e494b57b2a858a5b0520a008d3e05d25622 From e67c5d502b94af094ddb8784726fb911f0640214 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 25 Jul 2025 15:51:26 +0000 Subject: [PATCH 057/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index e78a4989..4f243e8c 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit e78a49891b34c25c78cccf2c88de775d49750775 +Subproject commit 4f243e8c1fe894c06efb252d8ae8c64bcf272906 From a6e79c701da90039e34d03b0c58e84da7ac3eea7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 25 Jul 2025 17:53:44 +0000 Subject: [PATCH 058/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 4f243e8c..e78a4989 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 4f243e8c1fe894c06efb252d8ae8c64bcf272906 +Subproject commit e78a49891b34c25c78cccf2c88de775d49750775 From cc968a5d053f7cbf62e4fcaccb9e195083bab1ca Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 25 Jul 2025 17:57:57 +0000 Subject: [PATCH 059/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index c1062e49..8755ee6e 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit c1062e494b57b2a858a5b0520a008d3e05d25622 +Subproject commit 8755ee6e1c6aaace193466373c9f46635969576e From 01005894a0c3d551237a4fe20bfb5132c9b36c07 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 26 Jul 2025 05:45:54 +0000 Subject: [PATCH 060/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 8755ee6e..c1062e49 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 8755ee6e1c6aaace193466373c9f46635969576e +Subproject commit c1062e494b57b2a858a5b0520a008d3e05d25622 From 908abd6216bba9d948fd1ffd0c2ce9ed866b6b3e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 26 Jul 2025 06:20:56 +0000 Subject: [PATCH 061/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index e78a4989..4f243e8c 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit e78a49891b34c25c78cccf2c88de775d49750775 +Subproject commit 4f243e8c1fe894c06efb252d8ae8c64bcf272906 From 31740e0a879f3909aef78e9a2b7d960ab05521ff Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Sat, 26 Jul 2025 13:06:05 +0200 Subject: [PATCH 062/472] Tee output --- .github/workflows/cleanup_pypi.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 5ceb62c3..5da0c700 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -58,7 +58,7 @@ jobs: uv run --no-sync python -u -m duckdb_packaging.pypi_cleanup ${{ inputs.dry-run && '--dry' || '' }} \ --test \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ - --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} > cleanup_output 2>&1 + --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} 2>&1 | tee cleanup_output - name: PyPI Cleanup Summary run : | From e357031f40eca2704ff8a76956b66f0b7f156a1c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 27 Jul 2025 05:57:52 +0000 Subject: [PATCH 063/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index c1062e49..3de8d76b 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit c1062e494b57b2a858a5b0520a008d3e05d25622 +Subproject commit 3de8d76ba83a82fb4337ca7cb8e4de2eab748561 From 59bb1cce9dfa2e1dc7d11b0c63666bd22f0704d9 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 29 Jul 2025 11:06:10 +0200 Subject: [PATCH 064/472] Let upload inherit secrets --- .github/workflows/on_external_dispatch.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml index 6e07d764..504d6a49 100644 --- a/.github/workflows/on_external_dispatch.yml +++ b/.github/workflows/on_external_dispatch.yml @@ -120,6 +120,7 @@ jobs: needs: [ commit_submodule, upload_s3 ] if: ${{ inputs.force-version == '' }} uses: ./.github/workflows/upload_to_pypi.yml + secrets: inherit with: sha: ${{ needs.commit_submodule.outputs.sha-after-commit }} environment: production.pypi From caaab9b1a3b289c31d37caa56bb5f47d7c28f7a8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 29 Jul 2025 09:11:21 +0000 Subject: [PATCH 065/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index eae0dd31..3de8d76b 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit eae0dd31557759cc57a6a1401998e23250369fa0 +Subproject commit 3de8d76ba83a82fb4337ca7cb8e4de2eab748561 From 831d790edb69aa4067b0ae8396d33b7a2a68d691 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 29 Jul 2025 13:24:23 +0200 Subject: [PATCH 066/472] Fix S3 URL --- .github/workflows/upload_to_pypi.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index edffcbe8..12e3bdb6 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -57,8 +57,9 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }} run: | + sha=${{ inputs.sha }} mkdir packages - aws s3 cp --recursive "${S3_URL}" packages + aws s3 cp --recursive s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/ packages - name: Upload artifacts to PyPI uses: pypa/gh-action-pypi-publish@release/v1 From 703403d389c942816c5fd017df06cb2a9f4d4eea Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 30 Jul 2025 05:41:06 +0000 Subject: [PATCH 067/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 4f243e8c..2efc9ec5 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 4f243e8c1fe894c06efb252d8ae8c64bcf272906 +Subproject commit 2efc9ec537476eccd9a0cb5e730e4e9ea314df7a From f6d7a8fdc25680628df6a6526375c5fb66867697 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 30 Jul 2025 11:35:13 +0200 Subject: [PATCH 068/472] Verbose mode during pypi upload --- .github/workflows/upload_to_pypi.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml index 12e3bdb6..0452f868 100644 --- a/.github/workflows/upload_to_pypi.yml +++ b/.github/workflows/upload_to_pypi.yml @@ -66,6 +66,7 @@ jobs: with: repository-url: 'https://${{ vars.PYPI_HOST }}/legacy/' packages-dir: packages + verbose: 'true' - name: PyPI Upload Summary run : | From af5cbd27541f7ce4317157735996c08c23227911 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 31 Jul 2025 05:41:03 +0000 Subject: [PATCH 069/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 2efc9ec5..3ac7e19a 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 2efc9ec537476eccd9a0cb5e730e4e9ea314df7a +Subproject commit 3ac7e19ac4942a910436157afd0c530b3bb5aba1 From 79c1a7acea30cbef359e5f3416faa9bb69825eb9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 2 Aug 2025 05:35:16 +0000 Subject: [PATCH 070/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 3ac7e19a..0e258eca 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 3ac7e19ac4942a910436157afd0c530b3bb5aba1 +Subproject commit 0e258ecaaf50d89eb4e73b5969994f9fb3656681 From ead47b39435a58367b1415dce674783dd6777169 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 3 Aug 2025 03:44:24 +0000 Subject: [PATCH 071/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 3de8d76b..dee87268 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 3de8d76ba83a82fb4337ca7cb8e4de2eab748561 +Subproject commit dee8726869f57cf0672580785ea01aa51339ad6f From 25ef0bde1b700bd37a71e21349d78499e30f01d8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 4 Aug 2025 05:06:14 +0000 Subject: [PATCH 072/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index dee87268..631b11e9 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit dee8726869f57cf0672580785ea01aa51339ad6f +Subproject commit 631b11e9ed4eeb21a9f54d9cf5cdacc58266a007 From 6505c6fd827fc7e1d0690ea4c7a7ce902583a2ea Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 5 Aug 2025 05:00:22 +0000 Subject: [PATCH 073/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 631b11e9..13bef88c 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 631b11e9ed4eeb21a9f54d9cf5cdacc58266a007 +Subproject commit 13bef88caefe22b998fddaed7c1eb8c583fc4d0d From 768881b973b727920b9d0ef8065a50f6ed8e9847 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 5 Aug 2025 06:15:30 +0000 Subject: [PATCH 074/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 0e258eca..67cbce34 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 0e258ecaaf50d89eb4e73b5969994f9fb3656681 +Subproject commit 67cbce34e13c7b6c9178d13b3886428b3f6f7485 From b0c3ea1e03475a3af91034bbd9f0acae9c978f74 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 6 Aug 2025 06:04:44 +0000 Subject: [PATCH 075/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 13bef88c..22b928b4 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 13bef88caefe22b998fddaed7c1eb8c583fc4d0d +Subproject commit 22b928b4d998648c8f44c0b5b24ecf37b8a622a1 From b9f8cdf44d499163e9b60cf7cd492e9050f87b57 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 7 Aug 2025 06:13:23 +0000 Subject: [PATCH 076/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 22b928b4..482b5702 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 22b928b4d998648c8f44c0b5b24ecf37b8a622a1 +Subproject commit 482b5702f78ee4122612cfe4de6e373e8e1ac963 From 90cc2b6686e45d392b1486e586dc654e55be2fec Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 7 Aug 2025 06:18:04 +0000 Subject: [PATCH 077/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 67cbce34..5766d00d 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 67cbce34e13c7b6c9178d13b3886428b3f6f7485 +Subproject commit 5766d00d2bea7ec1bd8fbd92e498f48d5b92953b From 86763da0aa56d6293aaff321e3057ccbdacf7013 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 8 Aug 2025 06:17:20 +0000 Subject: [PATCH 078/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 482b5702..e76b5346 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 482b5702f78ee4122612cfe4de6e373e8e1ac963 +Subproject commit e76b5346e2856599925828ac257d0f4a4dbdf10f From cfb89208fe4c3b4b410793e9eb7fc9165c0006c1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 9 Aug 2025 05:32:26 +0000 Subject: [PATCH 079/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index e76b5346..582bf198 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit e76b5346e2856599925828ac257d0f4a4dbdf10f +Subproject commit 582bf198c49a9345465d5af5b4a1fc4568dc9465 From cd2b80abf82d17735930f26310e0131f5712850b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 12 Aug 2025 06:13:53 +0000 Subject: [PATCH 080/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 582bf198..d0e18172 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 582bf198c49a9345465d5af5b4a1fc4568dc9465 +Subproject commit d0e18172e9a7a262cf586bc4116070405ea7b8ab From 80068d777a3d3a99fb583453ef70b36cf67e4b96 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 13 Aug 2025 05:52:34 +0000 Subject: [PATCH 081/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index d0e18172..3483d12a 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit d0e18172e9a7a262cf586bc4116070405ea7b8ab +Subproject commit 3483d12aab380beacb3f0228ed997388c105aed5 From e72fd7d5493edc477dad068a6e611c033e7b97bd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 14 Aug 2025 06:48:00 +0000 Subject: [PATCH 082/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 3483d12a..42115215 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 3483d12aab380beacb3f0228ed997388c105aed5 +Subproject commit 421152158231ac5fe2d04e9958cedf06cb4aee64 From ab33930deb1566cce29364ec3ee41a60b6474ce6 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 14 Aug 2025 06:48:21 +0000 Subject: [PATCH 083/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 5766d00d..17c093c0 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 5766d00d2bea7ec1bd8fbd92e498f48d5b92953b +Subproject commit 17c093c0bbafaeeec6d614aad298b7357cd10f39 From 7a1f18a842ad500cb7c18b2e993ffb8847fb874e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 16 Aug 2025 05:41:40 +0000 Subject: [PATCH 084/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 17c093c0..99791844 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 17c093c0bbafaeeec6d614aad298b7357cd10f39 +Subproject commit 997918446788c6036a50fa21108a6f63d44fb865 From d7cc8b22a5cdab177e60803e5d267828f3a0a9ec Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 16 Aug 2025 05:55:58 +0000 Subject: [PATCH 085/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 42115215..2ed9bf88 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 421152158231ac5fe2d04e9958cedf06cb4aee64 +Subproject commit 2ed9bf887f61a0ac226ab8c8f1164601d985d607 From ab457f5d68b7ab83ef77e7ae8efe3659b893d1fc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 19 Aug 2025 06:02:18 +0000 Subject: [PATCH 086/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 2ed9bf88..fbb0df1e 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 2ed9bf887f61a0ac226ab8c8f1164601d985d607 +Subproject commit fbb0df1e225e7c6f6f8fe37b7876ffef019f07ba From 14e07c947f7c203a3a63598dab9e70ecabedbb8b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 19 Aug 2025 06:38:05 +0000 Subject: [PATCH 087/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 99791844..3bcf0148 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 997918446788c6036a50fa21108a6f63d44fb865 +Subproject commit 3bcf01485639fc272d527456232e47f6694f589c From 793a86b2c6e4a44e4a75dccd7a1b9ac3260b1182 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 20 Aug 2025 05:32:08 +0000 Subject: [PATCH 088/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 3bcf0148..0663a714 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 3bcf01485639fc272d527456232e47f6694f589c +Subproject commit 0663a7142014ac4e21779575634bf35cbfd4cc1b From bde1487a7588f86a037b97930017fbd8f41f4773 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 20 Aug 2025 06:10:27 +0000 Subject: [PATCH 089/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index fbb0df1e..a8206a21 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit fbb0df1e225e7c6f6f8fe37b7876ffef019f07ba +Subproject commit a8206a211f01652e5109fc05b0a56b3b778dea1d From 6ee31df3ff1107d798cc96034d48a06c49566c47 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 21 Aug 2025 05:43:23 +0000 Subject: [PATCH 090/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 0663a714..aaa4635f 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 0663a7142014ac4e21779575634bf35cbfd4cc1b +Subproject commit aaa4635fff6a92736c6fc5bf4023f75c0414be02 From a7873e5ce0ce0a3c0cd05e9e3be8b7d6130ae7c8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 21 Aug 2025 05:49:07 +0000 Subject: [PATCH 091/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index a8206a21..129b1fe5 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit a8206a211f01652e5109fc05b0a56b3b778dea1d +Subproject commit 129b1fe55ef24e616754238cb100e3b9a926e4b6 From a4fdc8354d4c953a40d83b0387621975471247c5 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 21 Aug 2025 10:49:32 +0200 Subject: [PATCH 092/472] Fwd port of PR 18658: Load pandas in import cache before binding --- src/duckdb_py/pandas/bind.cpp | 9 +++++++++ tests/fast/pandas/test_import_cache.py | 28 ++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 tests/fast/pandas/test_import_cache.py diff --git a/src/duckdb_py/pandas/bind.cpp b/src/duckdb_py/pandas/bind.cpp index 8f6919cb..4e40c20e 100644 --- a/src/duckdb_py/pandas/bind.cpp +++ b/src/duckdb_py/pandas/bind.cpp @@ -1,6 +1,7 @@ #include "duckdb_python/pandas/pandas_bind.hpp" #include "duckdb_python/pandas/pandas_analyzer.hpp" #include "duckdb_python/pandas/column/pandas_numpy_column.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" namespace duckdb { @@ -123,6 +124,14 @@ void Pandas::Bind(const ClientContext &context, py::handle df_p, vector Date: Thu, 21 Aug 2025 11:16:35 +0200 Subject: [PATCH 093/472] Fwd ports of: PR #18642: Change arrow() to export record batch reader PR #18624: Adjust filter pushdown to latest polars release PR #18547: Rename the Varint type to Bignum --- duckdb/__init__.pyi | 5 +- duckdb/polars_io.py | 44 +++++++++----- scripts/connection_methods.json | 11 ++-- src/duckdb_py/duckdb_python.cpp | 40 ++++++------- src/duckdb_py/native/python_objects.cpp | 8 +-- src/duckdb_py/pyconnection.cpp | 4 +- tests/fast/api/test_dbapi_fetch.py | 6 +- tests/fast/api/test_duckdb_connection.py | 4 +- tests/fast/api/test_native_tz.py | 28 +++++++-- tests/fast/arrow/test_6584.py | 2 +- tests/fast/arrow/test_arrow_binary_view.py | 4 +- tests/fast/arrow/test_arrow_decimal256.py | 2 +- tests/fast/arrow/test_arrow_decimal_32_64.py | 4 +- tests/fast/arrow/test_arrow_extensions.py | 60 ++++++++++--------- tests/fast/arrow/test_arrow_fetch.py | 6 +- .../fast/arrow/test_arrow_run_end_encoding.py | 38 ++++++------ tests/fast/arrow/test_arrow_string_view.py | 6 +- tests/fast/arrow/test_arrow_types.py | 2 +- tests/fast/arrow/test_arrow_union.py | 4 +- tests/fast/arrow/test_arrow_version_format.py | 28 ++++----- tests/fast/arrow/test_buffer_size_option.py | 10 ++-- tests/fast/arrow/test_date.py | 6 +- tests/fast/arrow/test_dictionary_arrow.py | 2 +- tests/fast/arrow/test_filter_pushdown.py | 22 +++---- tests/fast/arrow/test_integration.py | 30 +++++----- tests/fast/arrow/test_interval.py | 6 +- tests/fast/arrow/test_large_offsets.py | 8 +-- tests/fast/arrow/test_nested_arrow.py | 18 ++++-- tests/fast/arrow/test_projection_pushdown.py | 2 +- tests/fast/arrow/test_time.py | 10 ++-- tests/fast/arrow/test_timestamp_timezone.py | 4 +- tests/fast/arrow/test_timestamps.py | 6 +- tests/fast/arrow/test_tpch.py | 6 +- tests/fast/relational_api/test_rapi_close.py | 2 +- tests/fast/spark/test_spark_types.py | 2 +- tests/fast/test_all_types.py | 8 +-- tests/fast/test_replacement_scan.py | 2 +- tests/fast/test_runtime_error.py | 6 +- tests/fast/udf/test_scalar_arrow.py | 2 +- 39 files changed, 249 insertions(+), 209 deletions(-) diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index 7ed8b4e1..adf142dd 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -180,6 +180,7 @@ class StatementType: DETACH: StatementType MULTI: StatementType COPY_DATABASE: StatementType + MERGE_INTO: StatementType def __int__(self) -> int: ... def __index__(self) -> int: ... @property @@ -320,8 +321,8 @@ class DuckDBPyConnection: def fetch_df_chunk(self, vectors_per_chunk: int = 1, *, date_as_object: bool = False) -> pandas.DataFrame: ... def pl(self, rows_per_batch: int = 1000000, *, lazy: bool = False) -> polars.DataFrame: ... def fetch_arrow_table(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... - def arrow(self, rows_per_batch: int = 1000000) -> pyarrow.lib.Table: ... def fetch_record_batch(self, rows_per_batch: int = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def arrow(self, rows_per_batch: int = 1000000) -> pyarrow.lib.RecordBatchReader: ... def torch(self) -> dict: ... def tf(self) -> dict: ... def begin(self) -> DuckDBPyConnection: ... @@ -668,8 +669,8 @@ def df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> def fetch_df_chunk(vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... def pl(rows_per_batch: int = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... def fetch_arrow_table(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... -def arrow(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... def fetch_record_batch(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... +def arrow(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... def torch(*, connection: DuckDBPyConnection = ...) -> dict: ... def tf(*, connection: DuckDBPyConnection = ...) -> dict: ... def begin(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index bea155e7..dbe8727b 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -123,47 +123,61 @@ def _pl_tree_to_sql(tree: dict) -> str: raise NotImplementedError(f"Unsupported function type: {func_dict}") if node_type == "Scalar": - # Handle scalar values with typed representations - dtype = str(subtree["dtype"]) - value = subtree["value"] + # Detect format: old style (dtype/value) or new style (direct type key) + if "dtype" in subtree and "value" in subtree: + dtype = str(subtree["dtype"]) + value = subtree["value"] + else: + # New style: dtype is the single key in the dict + dtype = next(iter(subtree.keys())) + value = subtree # Decimal support - if dtype.startswith("{'Decimal'"): + if dtype.startswith("{'Decimal'") or dtype == "Decimal": decimal_value = value['Decimal'] decimal_value = Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]) return str(decimal_value) # Datetime with microseconds since epoch - if dtype.startswith("{'Datetime'"): + if dtype.startswith("{'Datetime'") or dtype == "Datetime": micros = value['Datetime'][0] dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC) return f"'{str(dt_timestamp)}'::TIMESTAMP" - # Match simple types - if dtype in ("Int8", "Int16", "Int32", "Int64", "UInt8", "UInt16", "UInt32", "UInt64", "Float32", "Float64", "Boolean"): + # Match simple numeric/boolean types + if dtype in ("Int8", "Int16", "Int32", "Int64", + "UInt8", "UInt16", "UInt32", "UInt64", + "Float32", "Float64", "Boolean"): return str(value[dtype]) + # Time type if dtype == "Time": - # Convert nanoseconds to TIME nanoseconds = value["Time"] seconds = nanoseconds // 1_000_000_000 microseconds = (nanoseconds % 1_000_000_000) // 1_000 - dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time() - return f"'{str(dt_time)}'::TIME" + dt_time = (datetime.datetime.min + datetime.timedelta( + seconds=seconds, microseconds=microseconds + )).time() + return f"'{dt_time}'::TIME" + # Date type if dtype == "Date": - # Convert days since Unix epoch to SQL DATE days_since_epoch = value["Date"] date = datetime.date(1970, 1, 1) + datetime.timedelta(days=days_since_epoch) - return f"'{str(date)}'::DATE" + return f"'{date}'::DATE" + + # Binary type if dtype == "Binary": - # Convert binary data to hex string for BLOB binary_data = bytes(value["Binary"]) escaped = ''.join(f'\\x{b:02x}' for b in binary_data) return f"'{escaped}'::BLOB" - if dtype == "String": - return f"'{value['StringOwned']}'" + # String type + if dtype == "String" or dtype == "StringOwned": + # Some new formats may store directly under StringOwned + string_val = value.get("StringOwned", value.get("String", None)) + return f"'{string_val}'" + raise NotImplementedError(f"Unsupported scalar type {str(dtype)}, with value {value}") diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index 521d7acb..27705d6a 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -395,10 +395,7 @@ "return": "polars.DataFrame" }, { - "name": [ - "fetch_arrow_table", - "arrow" - ], + "name": "fetch_arrow_table", "function": "FetchArrow", "docs": "Fetch a result as Arrow table following execute()", "args": [ @@ -411,7 +408,11 @@ "return": "pyarrow.lib.Table" }, { - "name": "fetch_record_batch", + "name": [ + "fetch_record_batch", + "arrow" + ], + "function": "FetchRecordBatchReader", "docs": "Fetch an Arrow RecordBatchReader following execute()", "args": [ diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp index 27ebe8b9..939fa41a 100644 --- a/src/duckdb_py/duckdb_python.cpp +++ b/src/duckdb_py/duckdb_python.cpp @@ -427,17 +427,17 @@ static void InitializeConnectionMethods(py::module_ &m) { "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), py::arg("connection") = py::none()); m.def( - "arrow", - [](idx_t rows_per_batch, shared_ptr conn = nullptr) { + "fetch_record_batch", + [](const idx_t rows_per_batch, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } - return conn->FetchArrow(rows_per_batch); + return conn->FetchRecordBatchReader(rows_per_batch); }, - "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), + "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), py::arg("connection") = py::none()); m.def( - "fetch_record_batch", + "arrow", [](const idx_t rows_per_batch, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); @@ -971,21 +971,21 @@ static void InitializeConnectionMethods(py::module_ &m) { static void RegisterStatementType(py::handle &m) { auto statement_type = py::enum_(m, "StatementType"); static const duckdb::StatementType TYPES[] = { - duckdb::StatementType::INVALID_STATEMENT, duckdb::StatementType::SELECT_STATEMENT, - duckdb::StatementType::INSERT_STATEMENT, duckdb::StatementType::UPDATE_STATEMENT, - duckdb::StatementType::CREATE_STATEMENT, duckdb::StatementType::DELETE_STATEMENT, - duckdb::StatementType::PREPARE_STATEMENT, duckdb::StatementType::EXECUTE_STATEMENT, - duckdb::StatementType::ALTER_STATEMENT, duckdb::StatementType::TRANSACTION_STATEMENT, - duckdb::StatementType::COPY_STATEMENT, duckdb::StatementType::ANALYZE_STATEMENT, - duckdb::StatementType::VARIABLE_SET_STATEMENT, duckdb::StatementType::CREATE_FUNC_STATEMENT, - duckdb::StatementType::EXPLAIN_STATEMENT, duckdb::StatementType::DROP_STATEMENT, - duckdb::StatementType::EXPORT_STATEMENT, duckdb::StatementType::PRAGMA_STATEMENT, - duckdb::StatementType::VACUUM_STATEMENT, duckdb::StatementType::CALL_STATEMENT, - duckdb::StatementType::SET_STATEMENT, duckdb::StatementType::LOAD_STATEMENT, - duckdb::StatementType::RELATION_STATEMENT, duckdb::StatementType::EXTENSION_STATEMENT, - duckdb::StatementType::LOGICAL_PLAN_STATEMENT, duckdb::StatementType::ATTACH_STATEMENT, - duckdb::StatementType::DETACH_STATEMENT, duckdb::StatementType::MULTI_STATEMENT, - duckdb::StatementType::COPY_DATABASE_STATEMENT}; + duckdb::StatementType::INVALID_STATEMENT, duckdb::StatementType::SELECT_STATEMENT, + duckdb::StatementType::INSERT_STATEMENT, duckdb::StatementType::UPDATE_STATEMENT, + duckdb::StatementType::CREATE_STATEMENT, duckdb::StatementType::DELETE_STATEMENT, + duckdb::StatementType::PREPARE_STATEMENT, duckdb::StatementType::EXECUTE_STATEMENT, + duckdb::StatementType::ALTER_STATEMENT, duckdb::StatementType::TRANSACTION_STATEMENT, + duckdb::StatementType::COPY_STATEMENT, duckdb::StatementType::ANALYZE_STATEMENT, + duckdb::StatementType::VARIABLE_SET_STATEMENT, duckdb::StatementType::CREATE_FUNC_STATEMENT, + duckdb::StatementType::EXPLAIN_STATEMENT, duckdb::StatementType::DROP_STATEMENT, + duckdb::StatementType::EXPORT_STATEMENT, duckdb::StatementType::PRAGMA_STATEMENT, + duckdb::StatementType::VACUUM_STATEMENT, duckdb::StatementType::CALL_STATEMENT, + duckdb::StatementType::SET_STATEMENT, duckdb::StatementType::LOAD_STATEMENT, + duckdb::StatementType::RELATION_STATEMENT, duckdb::StatementType::EXTENSION_STATEMENT, + duckdb::StatementType::LOGICAL_PLAN_STATEMENT, duckdb::StatementType::ATTACH_STATEMENT, + duckdb::StatementType::DETACH_STATEMENT, duckdb::StatementType::MULTI_STATEMENT, + duckdb::StatementType::COPY_DATABASE_STATEMENT, duckdb::StatementType::MERGE_INTO_STATEMENT}; static const idx_t AMOUNT = sizeof(TYPES) / sizeof(duckdb::StatementType); for (idx_t i = 0; i < AMOUNT; i++) { auto &type = TYPES[i]; diff --git a/src/duckdb_py/native/python_objects.cpp b/src/duckdb_py/native/python_objects.cpp index 45112483..21aa281f 100644 --- a/src/duckdb_py/native/python_objects.cpp +++ b/src/duckdb_py/native/python_objects.cpp @@ -8,7 +8,7 @@ #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" #include "duckdb/common/operator/add.hpp" -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/types/bignum.hpp" #include "duckdb/function/to_interval.hpp" #include "datetime.h" // Python datetime initialize #1 @@ -683,9 +683,9 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, auto uuid_value = val.GetValueUnsafe(); return import_cache.uuid.UUID()(UUID::ToString(uuid_value)); } - case LogicalTypeId::VARINT: { - auto varint_value = val.GetValueUnsafe(); - return py::str(Varint::VarIntToVarchar(varint_value)); + case LogicalTypeId::BIGNUM: { + auto bignum_value = val.GetValueUnsafe(); + return py::str(Bignum::BignumToVarchar(bignum_value)); } case LogicalTypeId::INTERVAL: { auto interval_value = val.GetValueUnsafe(); diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 30b58701..90156395 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -204,10 +204,10 @@ static void InitializeConnectionMethods(py::class_ '2019-01-01'").df() expected_df = duckdb.from_parquet(glob_pattern.as_posix()).filter("date > '2019-01-01'").df() @@ -737,7 +737,7 @@ def test_filter_column_removal(self, duckdb_cursor, create_table): match = re.search("│ +b +│", query_res[0][1]) assert not match - @pytest.mark.skipif(sys.version_info <= (3, 9), reason="Requires python 3.9") + @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) def test_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( @@ -808,7 +808,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): match = re.search(".*ARROW_SCAN.*Filters: s\\.a IS NULL.*", query_res[0][1], flags=re.DOTALL) assert not match - @pytest.mark.skipif(sys.version_info <= (3, 9), reason="Requires python 3.9") + @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( @@ -896,7 +896,7 @@ def test_filter_pushdown_not_supported(self): con.execute( "CREATE TABLE T as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d FROM range(5) tbl(i)" ) - arrow_tbl = con.execute("FROM T").arrow() + arrow_tbl = con.execute("FROM T").fetch_arrow_table() # No projection just unsupported filter assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, '3', 3, 3)] @@ -920,7 +920,7 @@ def test_filter_pushdown_not_supported(self): "CREATE TABLE T_2 as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d , i::uhugeint e, i::smallint f, i::uhugeint g FROM range(50) tbl(i)" ) - arrow_tbl = con.execute("FROM T_2").arrow() + arrow_tbl = con.execute("FROM T_2").fetch_arrow_table() assert con.execute( "select a, b from arrow_tbl where a > 2 and c < 40 and b == '28' and g > 15 and e < 30" @@ -932,8 +932,8 @@ def test_join_filter_pushdown(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE build as select (random()*9999)::INT b from range(20);") duck_probe = duckdb_conn.table("probe") duck_build = duckdb_conn.table("build") - duck_probe_arrow = duck_probe.arrow() - duck_build_arrow = duck_build.arrow() + duck_probe_arrow = duck_probe.fetch_arrow_table() + duck_build_arrow = duck_build.fetch_arrow_table() duckdb_conn.register("duck_probe_arrow", duck_probe_arrow) duckdb_conn.register("duck_build_arrow", duck_build_arrow) assert duckdb_conn.execute("SELECT count(*) from duck_probe_arrow, duck_build_arrow where a=b").fetchall() == [ @@ -944,7 +944,7 @@ def test_in_filter_pushdown(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE probe as select range a from range(1000);") duck_probe = duckdb_conn.table("probe") - duck_probe_arrow = duck_probe.arrow() + duck_probe_arrow = duck_probe.fetch_arrow_table() duckdb_conn.register("duck_probe_arrow", duck_probe_arrow) assert duckdb_conn.execute("SELECT * from duck_probe_arrow where a = any([1,999])").fetchall() == [(1,), (999,)] @@ -1006,7 +1006,7 @@ def assert_equal_results(con, arrow_table, query): arrow_res = con.sql(query.format(table='arrow_table')).fetchall() assert len(duckdb_res) == len(arrow_res) - arrow_table = duckdb_cursor.table('test').arrow() + arrow_table = duckdb_cursor.table('test').fetch_arrow_table() assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a > 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a >= 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a < 'NaN'::FLOAT") diff --git a/tests/fast/arrow/test_integration.py b/tests/fast/arrow/test_integration.py index 7562221b..d9006758 100644 --- a/tests/fast/arrow/test_integration.py +++ b/tests/fast/arrow/test_integration.py @@ -17,10 +17,10 @@ def test_parquet_roundtrip(self, duckdb_cursor): userdata_parquet_table = pq.read_table(parquet_filename) userdata_parquet_table.validate(full=True) - rel_from_arrow = duckdb.arrow(userdata_parquet_table).project(cols).arrow() + rel_from_arrow = duckdb.arrow(userdata_parquet_table).project(cols).fetch_arrow_table() rel_from_arrow.validate(full=True) - rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).arrow() + rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).fetch_arrow_table() rel_from_duckdb.validate(full=True) # batched version, lets use various values for batch size @@ -28,7 +28,7 @@ def test_parquet_roundtrip(self, duckdb_cursor): userdata_parquet_table2 = pa.Table.from_batches(userdata_parquet_table.to_batches(i)) assert userdata_parquet_table.equals(userdata_parquet_table2, check_metadata=True) - rel_from_arrow2 = duckdb.arrow(userdata_parquet_table2).project(cols).arrow() + rel_from_arrow2 = duckdb.arrow(userdata_parquet_table2).project(cols).fetch_arrow_table() rel_from_arrow2.validate(full=True) assert rel_from_arrow.equals(rel_from_arrow2, check_metadata=True) @@ -40,10 +40,10 @@ def test_unsigned_roundtrip(self, duckdb_cursor): unsigned_parquet_table = pq.read_table(parquet_filename) unsigned_parquet_table.validate(full=True) - rel_from_arrow = duckdb.arrow(unsigned_parquet_table).project(cols).arrow() + rel_from_arrow = duckdb.arrow(unsigned_parquet_table).project(cols).fetch_arrow_table() rel_from_arrow.validate(full=True) - rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).arrow() + rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).fetch_arrow_table() rel_from_duckdb.validate(full=True) assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True) @@ -70,7 +70,7 @@ def test_decimals_roundtrip(self, duckdb_cursor): duck_tbl = duckdb_cursor.table("test") - duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.arrow()) + duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.fetch_arrow_table()) duck_from_arrow.create("testarrow") @@ -112,7 +112,7 @@ def test_intervals_roundtrip(self, duckdb_cursor): data = pa.array(arr, pa.month_day_nano_interval()) arrow_tbl = pa.Table.from_arrays([data], ['a']) duckdb_cursor.from_arrow(arrow_tbl).create("intervaltbl") - duck_arrow_tbl = duckdb_cursor.table("intervaltbl").arrow()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()['a'] assert duck_arrow_tbl[0].value == expected_value @@ -120,7 +120,7 @@ def test_intervals_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a INTERVAL)") duckdb_cursor.execute("INSERT INTO test VALUES (INTERVAL 1 YEAR + INTERVAL 1 DAY + INTERVAL 1 SECOND)") expected_value = pa.MonthDayNano([12, 1, 1000000000]) - duck_tbl_arrow = duckdb_cursor.table("test").arrow()['a'] + duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()['a'] assert duck_tbl_arrow[0].value.months == expected_value.months assert duck_tbl_arrow[0].value.days == expected_value.days assert duck_tbl_arrow[0].value.nanoseconds == expected_value.nanoseconds @@ -142,7 +142,7 @@ def test_null_intervals_roundtrip(self, duckdb_cursor): data = pa.array(arr, pa.month_day_nano_interval()) arrow_tbl = pa.Table.from_arrays([data], ['a']) duckdb_cursor.from_arrow(arrow_tbl).create("intervalnulltbl") - duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").arrow()['a'] + duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()['a'] assert duckdb_tbl_arrow[0].value == None assert duckdb_tbl_arrow[1].value == expected_value @@ -156,7 +156,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) arrow_table = pa.Table.from_arrays([dict_array], ['a']) duckdb_cursor.from_arrow(arrow_table).create("dictionarytbl") - duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").arrow()['a'] + duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()['a'] assert duckdb_tbl_arrow[0].value == first_value assert duckdb_tbl_arrow[1].value == second_value @@ -170,7 +170,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): # List query = duckdb_cursor.sql( "SELECT a from (select list_value(INTERVAL 3 MONTHS, INTERVAL 5 DAYS, INTERVAL 10 SECONDS, NULL) as a) as t" - ).arrow()['a'] + ).fetch_arrow_table()['a'] assert query[0][0].value == pa.MonthDayNano([3, 0, 0]) assert query[0][1].value == pa.MonthDayNano([0, 5, 0]) assert query[0][2].value == pa.MonthDayNano([0, 0, 10000000000]) @@ -179,7 +179,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): # Struct query = "SELECT a from (SELECT STRUCT_PACK(a := INTERVAL 1 MONTHS, b := INTERVAL 10 DAYS, c:= INTERVAL 20 SECONDS) as a) as t" true_answer = duckdb_cursor.sql(query).fetchall() - from_arrow = duckdb_cursor.from_arrow(duckdb_cursor.sql(query).arrow()).fetchall() + from_arrow = duckdb_cursor.from_arrow(duckdb_cursor.sql(query).fetch_arrow_table()).fetchall() assert true_answer[0][0]['a'] == from_arrow[0][0]['a'] assert true_answer[0][0]['b'] == from_arrow[0][0]['b'] assert true_answer[0][0]['c'] == from_arrow[0][0]['c'] @@ -191,7 +191,7 @@ def test_min_max_interval_roundtrip(self, duckdb_cursor): arrow_tbl = pa.Table.from_arrays([data], ['a']) duckdb_cursor.from_arrow(arrow_tbl).create("intervalminmaxtbl") - duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").arrow()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()['a'] assert duck_arrow_tbl[0].value == pa.MonthDayNano([0, 0, 0]) assert duck_arrow_tbl[1].value == pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) @@ -209,7 +209,7 @@ def test_duplicate_column_names(self, duckdb_cursor): df_b table2 ON table1.join_key = table2.join_key """ - ).arrow() + ).fetch_arrow_table() assert res.schema.names == ['join_key', 'col_a', 'join_key', 'col_a'] def test_strings_roundtrip(self, duckdb_cursor): @@ -225,7 +225,7 @@ def test_strings_roundtrip(self, duckdb_cursor): duck_tbl = duckdb_cursor.table("test") - duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.arrow()) + duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.fetch_arrow_table()) duck_from_arrow.create("testarrow") diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index 7a891a61..a548818f 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -26,7 +26,7 @@ def test_duration_types(self, duckdb_cursor): pa.array([1], pa.duration('s')), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == expected_arrow['a'] assert rel['b'] == expected_arrow['a'] assert rel['c'] == expected_arrow['a'] @@ -43,7 +43,7 @@ def test_duration_null(self, duckdb_cursor): pa.array([None], pa.duration('s')), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == expected_arrow['a'] assert rel['b'] == expected_arrow['a'] assert rel['c'] == expected_arrow['a'] @@ -58,4 +58,4 @@ def test_duration_overflow(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([data], ['a']) with pytest.raises(duckdb.ConversionException, match='Could not convert Interval to Microsecond'): - arrow_from_duck = duckdb.from_arrow(arrow_table).arrow() + arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() diff --git a/tests/fast/arrow/test_large_offsets.py b/tests/fast/arrow/test_large_offsets.py index 03705e75..1bcdd1b7 100644 --- a/tests/fast/arrow/test_large_offsets.py +++ b/tests/fast/arrow/test_large_offsets.py @@ -20,11 +20,11 @@ def test_large_lists(self, duckdb_cursor): duckdb.InvalidInputException, match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', ): - res = duckdb_cursor.sql("SELECT col FROM tbl").arrow() + res = duckdb_cursor.sql("SELECT col FROM tbl").fetch_arrow_table() tbl2 = pa.Table.from_pydict(dict(col=ary.cast(pa.large_list(pa.uint8())))) duckdb_cursor.sql("set arrow_large_buffer_size = true") - res2 = duckdb_cursor.sql("SELECT col FROM tbl2").arrow() + res2 = duckdb_cursor.sql("SELECT col FROM tbl2").fetch_arrow_table() res2.validate() @pytest.mark.skip(reason="CI does not have enough memory to validate this") @@ -36,8 +36,8 @@ def test_large_maps(self, duckdb_cursor): duckdb.InvalidInputException, match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', ): - arrow_map = duckdb_cursor.sql("select map(col, col) from tbl").arrow() + arrow_map = duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() duckdb_cursor.sql("set arrow_large_buffer_size = true") - arrow_map_large = duckdb_cursor.sql("select map(col, col) from tbl").arrow() + arrow_map_large = duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() arrow_map_large.validate() diff --git a/tests/fast/arrow/test_nested_arrow.py b/tests/fast/arrow/test_nested_arrow.py index bdb211ac..693a5155 100644 --- a/tests/fast/arrow/test_nested_arrow.py +++ b/tests/fast/arrow/test_nested_arrow.py @@ -10,13 +10,13 @@ def compare_results(duckdb_cursor, query): true_answer = duckdb_cursor.query(query).fetchall() - produced_arrow = duckdb_cursor.query(query).arrow() + produced_arrow = duckdb_cursor.query(query).fetch_arrow_table() from_arrow = duckdb_cursor.from_arrow(produced_arrow).fetchall() assert true_answer == from_arrow def arrow_to_pandas(duckdb_cursor, query): - return duckdb_cursor.query(query).arrow().to_pandas()['a'].values.tolist() + return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()['a'].values.tolist() def get_use_list_view_options(): @@ -30,17 +30,25 @@ def get_use_list_view_options(): class TestArrowNested(object): def test_lists_basic(self, duckdb_cursor): # Test Constant List - query = duckdb_cursor.query("SELECT a from (select list_value(3,5,10) as a) as t").arrow()['a'].to_numpy() + query = ( + duckdb_cursor.query("SELECT a from (select list_value(3,5,10) as a) as t") + .fetch_arrow_table()['a'] + .to_numpy() + ) assert query[0][0] == 3 assert query[0][1] == 5 assert query[0][2] == 10 # Empty List - query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").arrow()['a'].to_numpy() + query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").fetch_arrow_table()['a'].to_numpy() assert len(query[0]) == 0 # Test Constant List With Null - query = duckdb_cursor.query("SELECT a from (select list_value(3,NULL) as a) as t").arrow()['a'].to_numpy() + query = ( + duckdb_cursor.query("SELECT a from (select list_value(3,NULL) as a) as t") + .fetch_arrow_table()['a'] + .to_numpy() + ) assert query[0][0] == 3 assert np.isnan(query[0][1]) diff --git a/tests/fast/arrow/test_projection_pushdown.py b/tests/fast/arrow/test_projection_pushdown.py index 39022ab4..802259e1 100644 --- a/tests/fast/arrow/test_projection_pushdown.py +++ b/tests/fast/arrow/test_projection_pushdown.py @@ -23,7 +23,7 @@ def test_projection_pushdown_no_filter(self, duckdb_cursor): """ ) duck_tbl = duckdb_cursor.table("test") - arrow_table = duck_tbl.arrow() + arrow_table = duck_tbl.fetch_arrow_table() assert duckdb_cursor.execute("SELECT sum(c) FROM arrow_table").fetchall() == [(333,)] # RecordBatch does not use projection pushdown, test that this also still works diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index d575fc0f..726b0f6a 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -24,7 +24,7 @@ def test_time_types(self, duckdb_cursor): pa.array([1000000000], pa.time64('ns')), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == arrow_table['c'] assert rel['b'] == arrow_table['c'] assert rel['c'] == arrow_table['c'] @@ -40,7 +40,7 @@ def test_time_null(self, duckdb_cursor): pa.array([None], pa.time64('ns')), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == arrow_table['c'] assert rel['b'] == arrow_table['c'] assert rel['c'] == arrow_table['c'] @@ -54,7 +54,7 @@ def test_max_times(self, duckdb_cursor): # Max Sec data = pa.array([2147483647], type=pa.time32('s')) arrow_table = pa.Table.from_arrays([data], ['a']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == result['a'] # Max MSec @@ -62,7 +62,7 @@ def test_max_times(self, duckdb_cursor): result = pa.Table.from_arrays([data], ['a']) data = pa.array([2147483647], type=pa.time32('ms')) arrow_table = pa.Table.from_arrays([data], ['a']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == result['a'] # Max NSec @@ -70,7 +70,7 @@ def test_max_times(self, duckdb_cursor): result = pa.Table.from_arrays([data], ['a']) data = pa.array([9223372036854774000], type=pa.time64('ns')) arrow_table = pa.Table.from_arrays([data], ['a']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() print(rel['a']) print(result['a']) diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 7a2c738d..4fdadf49 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -43,7 +43,7 @@ def test_timestamp_tz_to_arrow(self, duckdb_cursor): for timezone in timezones: con.execute("SET TimeZone = '" + timezone + "'") arrow_table = generate_table(current_time, precision, timezone) - res = con.from_arrow(arrow_table).arrow() + res = con.from_arrow(arrow_table).fetch_arrow_table() assert res[0].type == pa.timestamp('us', tz=timezone) assert res == generate_table(current_time, 'us', timezone) @@ -52,7 +52,7 @@ def test_timestamp_tz_with_null(self, duckdb_cursor): con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") rel = con.table('t') - arrow_tbl = rel.arrow() + arrow_tbl = rel.fetch_arrow_table() con.register('t2', arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() diff --git a/tests/fast/arrow/test_timestamps.py b/tests/fast/arrow/test_timestamps.py index f43ca951..c2529c83 100644 --- a/tests/fast/arrow/test_timestamps.py +++ b/tests/fast/arrow/test_timestamps.py @@ -23,7 +23,7 @@ def test_timestamp_types(self, duckdb_cursor): pa.array([datetime.datetime.now()], pa.timestamp('s')), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == arrow_table['a'] assert rel['b'] == arrow_table['b'] assert rel['c'] == arrow_table['c'] @@ -39,7 +39,7 @@ def test_timestamp_nulls(self, duckdb_cursor): pa.array([None], pa.timestamp('s')), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) - rel = duckdb.from_arrow(arrow_table).arrow() + rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert rel['a'] == arrow_table['a'] assert rel['b'] == arrow_table['b'] assert rel['c'] == arrow_table['c'] @@ -54,7 +54,7 @@ def test_timestamp_overflow(self, duckdb_cursor): pa.array([9223372036854775807], pa.timestamp('us')), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) - arrow_from_duck = duckdb.from_arrow(arrow_table).arrow() + arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() assert arrow_from_duck['a'] == arrow_table['a'] assert arrow_from_duck['b'] == arrow_table['b'] assert arrow_from_duck['c'] == arrow_table['c'] diff --git a/tests/fast/arrow/test_tpch.py b/tests/fast/arrow/test_tpch.py index bef8f309..ff4a0445 100644 --- a/tests/fast/arrow/test_tpch.py +++ b/tests/fast/arrow/test_tpch.py @@ -47,7 +47,7 @@ def test_tpch_arrow(self, duckdb_cursor): for tpch_table in tpch_tables: duck_tbl = duckdb_conn.table(tpch_table) - arrow_tables.append(duck_tbl.arrow()) + arrow_tables.append(duck_tbl.fetch_arrow_table()) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) @@ -77,7 +77,7 @@ def test_tpch_arrow_01(self, duckdb_cursor): for tpch_table in tpch_tables: duck_tbl = duckdb_conn.table(tpch_table) - arrow_tables.append(duck_tbl.arrow()) + arrow_tables.append(duck_tbl.fetch_arrow_table()) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) @@ -105,7 +105,7 @@ def test_tpch_arrow_batch(self, duckdb_cursor): for tpch_table in tpch_tables: duck_tbl = duckdb_conn.table(tpch_table) - arrow_tables.append(pyarrow.Table.from_batches(duck_tbl.arrow().to_batches(10))) + arrow_tables.append(pyarrow.Table.from_batches(duck_tbl.fetch_arrow_table().to_batches(10))) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) diff --git a/tests/fast/relational_api/test_rapi_close.py b/tests/fast/relational_api/test_rapi_close.py index c9c605a1..270c58f5 100644 --- a/tests/fast/relational_api/test_rapi_close.py +++ b/tests/fast/relational_api/test_rapi_close.py @@ -26,7 +26,7 @@ def test_close_conn_rel(self, duckdb_cursor): with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): rel.arg_min("", "") with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.arrow() + rel.fetch_arrow_table() with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): rel.avg("") with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index 980e24a2..fb6e6102 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -63,7 +63,7 @@ def test_all_types_schema(self, spark): struct_of_fixed_array, fixed_array_of_int_list, list_of_fixed_int_array, - varint + bignum ) from test_all_types() """ ) diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index ce56165e..2128f9f1 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -553,16 +553,16 @@ def test_arrow(self, cur_type): conn = duckdb.connect() if cur_type in replacement_values: - arrow_table = conn.execute("select " + replacement_values[cur_type]).arrow() + arrow_table = conn.execute("select " + replacement_values[cur_type]).fetch_arrow_table() else: - arrow_table = conn.execute(f'select "{cur_type}" from test_all_types()').arrow() + arrow_table = conn.execute(f'select "{cur_type}" from test_all_types()').fetch_arrow_table() if cur_type in enum_types: - round_trip_arrow_table = conn.execute("select * from arrow_table").arrow() + round_trip_arrow_table = conn.execute("select * from arrow_table").fetch_arrow_table() result_arrow = conn.execute("select * from arrow_table").fetchall() result_roundtrip = conn.execute("select * from round_trip_arrow_table").fetchall() assert recursive_equality(result_arrow, result_roundtrip) else: - round_trip_arrow_table = conn.execute("select * from arrow_table").arrow() + round_trip_arrow_table = conn.execute("select * from arrow_table").fetch_arrow_table() assert arrow_table.equals(round_trip_arrow_table, check_metadata=True) @pytest.mark.parametrize('cur_type', all_types) diff --git a/tests/fast/test_replacement_scan.py b/tests/fast/test_replacement_scan.py index e96170fa..0cf69356 100644 --- a/tests/fast/test_replacement_scan.py +++ b/tests/fast/test_replacement_scan.py @@ -31,7 +31,7 @@ def fetch_df(rel): def fetch_arrow(rel): - return rel.arrow() + return rel.fetch_arrow_table() def fetch_arrow_table(rel): diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 618b3669..29e81d1e 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -25,7 +25,7 @@ def test_arrow_error(self): con = duckdb.connect() con.execute("create table tbl as select 'hello' i") with pytest.raises(duckdb.ConversionException): - con.execute("select i::int from tbl").arrow() + con.execute("select i::int from tbl").fetch_arrow_table() def test_register_error(self): con = duckdb.connect() @@ -37,7 +37,7 @@ def test_arrow_fetch_table_error(self): pytest.importorskip('pyarrow') con = duckdb.connect() - arrow_object = con.execute("select 1").arrow() + arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() @@ -48,7 +48,7 @@ def test_arrow_record_batch_reader_error(self): pytest.importorskip('pyarrow') con = duckdb.connect() - arrow_object = con.execute("select 1").arrow() + arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index fd7f00f7..5773c474 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -155,7 +155,7 @@ def return_struct(col): """ select {'a': 5, 'b': 'test', 'c': [5,3,2]} """ - ).arrow() + ).fetch_arrow_table() con = duckdb.connect() struct_type = con.struct_type({'a': BIGINT, 'b': VARCHAR, 'c': con.list_type(BIGINT)}) From 7ca7e400feed7bebb0d9e7241a32f86612ed4e9c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 21 Aug 2025 13:08:40 +0200 Subject: [PATCH 094/472] Release workflow stub --- .github/workflows/release.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..5ef47f4c --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,4 @@ +# Release is called by duckdb's InvokeCI -> NotifyExternalRepositories job +name: Release +on: +jobs: From 580f988ac3529e4bb8478d0f385732264a13480a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 21 Aug 2025 13:09:54 +0200 Subject: [PATCH 095/472] Release workflow stub --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5ef47f4c..f988f2c2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,4 +1,5 @@ # Release is called by duckdb's InvokeCI -> NotifyExternalRepositories job name: Release on: + workflow_dispatch: jobs: From ed72b4f4b96752d8a3df06d4b94561ed9e1f7a3b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 21 Aug 2025 12:57:46 +0200 Subject: [PATCH 096/472] Fix and simplify the release workflow --- .github/workflows/cleanup_pypi.yml | 14 +- .github/workflows/on_external_dispatch.yml | 126 ----------------- .github/workflows/on_push_postrelease.yml | 42 ------ .github/workflows/packaging.yml | 15 +- .github/workflows/release.yml | 151 ++++++++++++++++++++ .github/workflows/upload_to_pypi.yml | 84 ------------ README.md | 4 +- duckdb_packaging/pypi_cleanup.py | 152 ++++++++++----------- tests/fast/test_pypi_cleanup.py | 47 ++++--- 9 files changed, 271 insertions(+), 364 deletions(-) delete mode 100644 .github/workflows/on_external_dispatch.yml delete mode 100644 .github/workflows/on_push_postrelease.yml create mode 100644 .github/workflows/release.yml delete mode 100644 .github/workflows/upload_to_pypi.yml diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 5da0c700..70fe13d6 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -3,9 +3,16 @@ on: workflow_call: inputs: environment: - description: CI environment to run in (test.pypi or production.pypi) + description: CI environment to run in (pypi-test or pypi-prod-nightly) type: string required: true + secrets: + PYPI_CLEANUP_OTP: + description: PyPI OTP + required: true + PYPI_CLEANUP_PASSWORD: + description: PyPI password + required: true workflow_dispatch: inputs: dry-run: @@ -16,10 +23,9 @@ on: description: CI environment to run in type: choice required: true - default: test.pypi options: - - test.pypi - - production.pypi + - pypi-prod-nightly + - pypi-test jobs: cleanup_pypi: diff --git a/.github/workflows/on_external_dispatch.yml b/.github/workflows/on_external_dispatch.yml deleted file mode 100644 index 504d6a49..00000000 --- a/.github/workflows/on_external_dispatch.yml +++ /dev/null @@ -1,126 +0,0 @@ -# External Dispatch is called by duckdb's InvokeCI -> NotifyExternalRepositories job -name: External Dispatch -on: - workflow_dispatch: - inputs: - duckdb-sha: - type: string - description: The DuckDB submodule commit to build against - required: true - commit-duckdb-sha: - type: boolean - description: Commit and push the DuckDB submodule ref - default: false - force-version: - type: string - description: Force version (vX.Y.Z-((rc|post)N)) - required: false - publish-packages: - type: boolean - description: Upload packages to S3 - required: true - default: false - -defaults: - run: - shell: bash - -jobs: - commit_submodule: - name: Commit the submodule to the given DuckDB sha - outputs: - sha-after-commit: ${{ steps.git_commit_sha.outputs.commit_sha }} - runs-on: ubuntu-24.04 - permissions: - contents: write - steps: - - name: Checkout DuckDB Python - uses: actions/checkout@v4 - with: - ref: ${{ github.ref }} - fetch-depth: 0 - submodules: true - - - name: Checkout DuckDB - run: | - cd external/duckdb - git fetch origin - git checkout ${{ inputs.duckdb-sha }} - - - name: Commit and push new submodule ref - if: ${{ inputs.commit-duckdb-sha }} - # see https://github.com/actions/checkout?tab=readme-ov-file#push-a-commit-to-a-pr-using-the-built-in-token - run: | - git config user.name "github-actions[bot]" - git config user.email "41898282+github-actions[bot]@users.noreply.github.com" - git add external/duckdb - if git diff --cached --quiet; then - echo "No changes to commit: submodule ref is unchanged." - else - git commit -m "Update submodule ref" - git push - fi - - - name: Get the SHA of the latest commit - id: git_commit_sha - run: | - echo "commit_sha=$( git rev-parse HEAD )" >> $GITHUB_OUTPUT - - externally_triggered_build: - name: Build and test releases - needs: commit_submodule - uses: ./.github/workflows/packaging.yml - with: - minimal: false - testsuite: all - git-ref: ${{ needs.commit_submodule.outputs.sha-after-commit }} - duckdb-git-ref: ${{ inputs.duckdb-sha }} - force-version: ${{ inputs.force-version }} - - upload_s3: - name: Upload Artifacts to the S3 Staging Bucket - runs-on: ubuntu-latest - needs: [commit_submodule, externally_triggered_build] - outputs: - version: ${{ steps.s3_upload.outputs.version }} - if: ${{ github.repository_owner == 'duckdb' && inputs.publish-packages }} - steps: - - name: Fetch artifacts - uses: actions/download-artifact@v4 - with: - pattern: '{sdist,wheel}*' - path: artifacts/ - merge-multiple: true - - - name: Authenticate with AWS - uses: aws-actions/configure-aws-credentials@v4 - with: - aws-region: 'us-east-2' - aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} - aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - - - name: Upload Artifacts - id: s3_upload - run: | - sha=${{ needs.commit_submodule.outputs.sha-after-commit }} - aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/ --recursive - echo "version=${version}" >> $GITHUB_OUTPUT - - - name: S3 Upload Summary - run : | - sha=${{ needs.commit_submodule.outputs.sha-after-commit }} - version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') - echo "## S3 Upload Summary" >> $GITHUB_STEP_SUMMARY - echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY - echo "* SHA: ${sha:0:10}" >> $GITHUB_STEP_SUMMARY - echo "* S3 URL: s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/" >> $GITHUB_STEP_SUMMARY - - publish_to_pypi: - name: Upload Artifacts to PyPI - needs: [ commit_submodule, upload_s3 ] - if: ${{ inputs.force-version == '' }} - uses: ./.github/workflows/upload_to_pypi.yml - secrets: inherit - with: - sha: ${{ needs.commit_submodule.outputs.sha-after-commit }} - environment: production.pypi diff --git a/.github/workflows/on_push_postrelease.yml b/.github/workflows/on_push_postrelease.yml deleted file mode 100644 index 14575754..00000000 --- a/.github/workflows/on_push_postrelease.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: Testing and packaging for post releases -on: - push: - branches: - - v*.*.*-post* - paths-ignore: - - '**.md' - - 'LICENSE' - - '.editorconfig' - - 'scripts/**' - - '.github//**' - - '!.github/workflows/on_push.yml' - - '!.github/workflows/coverage.yml' - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - extract_duckdb_tag: - runs-on: ubuntu-24.04 - outputs: - duckdb_version: ${{ steps.extract_version.outputs.version }} - steps: - - name: Get DuckDB version from branch name - id: extract_version - shell: bash - run: | - BRANCH="${{ github.ref_name }}" - VERSION="${BRANCH%%-*}" - echo "version=$VERSION" >> $GITHUB_OUTPUT - - packaging_test: - name: Build and test post release packages and upload to S3 - needs: extract_duckdb_tag - uses: ./.github/workflows/packaging.yml - with: - minimal: false - testsuite: all - git-ref: ${{ github.ref }} - duckdb-git-ref: ${{ needs.extract_duckdb_tag.outputs.duckdb_version }} - force-version: ${{ github.ref_name }} diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 977bc914..573919c8 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -1,5 +1,5 @@ name: Packaging -run-name: Build ${{ inputs.minimal && 'minimal set of' || 'all' }} packages (version=${{ inputs.force-version != '' && inputs.force-version || 'dev' }}, tests=${{ inputs.testsuite }}, ref=${{ inputs.git-ref }}, duckdb ref=${{ inputs.duckdb-git-ref }}) +run-name: Build ${{ inputs.minimal && 'minimal set of' || 'all' }} packages (version=${{ inputs.set-version != '' && inputs.set-version || 'dev' }}, tests=${{ inputs.testsuite }}, ref=${{ inputs.git-ref }}, duckdb ref=${{ inputs.duckdb-git-ref }}) on: workflow_dispatch: inputs: @@ -25,7 +25,7 @@ on: description: Git ref of DuckDB required: true default: refs/heads/main - force-version: + set-version: type: string description: Force version (vX.Y.Z-((rc|post)N)) required: false @@ -48,7 +48,7 @@ on: type: string description: Git ref of DuckDB required: false - force-version: + set-version: description: Force version (vX.Y.Z-((rc|post)N)) required: false type: string @@ -82,8 +82,8 @@ jobs: git checkout ${{ inputs.duckdb-git-ref }} - name: Set OVERRIDE_GIT_DESCRIBE - if: ${{ inputs.force-version != '' }} - run: echo "OVERRIDE_GIT_DESCRIBE=${{ inputs.force-version }}" >> $GITHUB_ENV + if: ${{ inputs.set-version != '' }} + run: echo "OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV - name: Install Astral UV uses: astral-sh/setup-uv@v6 @@ -149,7 +149,6 @@ jobs: uv run -v pytest ${{ inputs.testsuite == 'fast' && './tests/fast' || './tests' }} --verbose --ignore=./tests/stubs steps: - - name: Checkout DuckDB Python uses: actions/checkout@v4 with: @@ -166,8 +165,8 @@ jobs: # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds - name: Set OVERRIDE_GIT_DESCRIBE - if: ${{ inputs.force-version != '' }} - run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.force-version }}" >> $GITHUB_ENV + if: ${{ inputs.set-version != '' }} + run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV # Install Astral UV, which will be used as build-frontend for cibuildwheel - uses: astral-sh/setup-uv@v6 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..8e27fa3c --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,151 @@ +# Release is called by duckdb's InvokeCI -> NotifyExternalRepositories job +name: Release +on: + workflow_dispatch: + inputs: + duckdb-sha: + type: string + description: The DuckDB submodule commit to build against + required: true + stable-version: + type: string + description: Release a stable version (vX.Y.Z-((rc|post)N)) + required: false + pypi-index: + type: choice + description: Which PyPI to use + required: true + options: + - test + - prod + store-s3: + type: boolean + description: Also store test packages in S3 + default: false + +defaults: + run: + shell: bash + +jobs: + build_and_test: + name: Build and test releases + uses: ./.github/workflows/packaging.yml + with: + minimal: false + testsuite: all + git-ref: ${{ github.ref }} + duckdb-git-ref: ${{ inputs.duckdb-sha }} + set-version: ${{ inputs.stable-version }} + + upload_s3: + name: Upload Artifacts to S3 + runs-on: ubuntu-latest + needs: [build_and_test] + if: ${{ github.repository_owner == 'duckdb' && ( inputs.pypi-index == 'prod' || inputs.store-s3 ) }} + steps: + - name: Fetch artifacts + uses: actions/download-artifact@v4 + with: + pattern: '{sdist,wheel}*' + path: artifacts/ + merge-multiple: true + + - name: Authenticate with AWS + uses: aws-actions/configure-aws-credentials@v4 + with: + aws-region: 'us-east-2' + aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} + aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} + + - name: Upload Artifacts + id: s3_upload + run: | + sha=${{ github.ref }} + aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/ --recursive + + - name: S3 Upload Summary + run : | + sha=${{ github.ref }} + version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') + echo "## S3 Upload Summary" >> $GITHUB_STEP_SUMMARY + echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY + echo "* SHA: ${sha:0:10}" >> $GITHUB_STEP_SUMMARY + echo "* S3 URL: s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/" >> $GITHUB_STEP_SUMMARY + + determine_environment: + name: Determine the Github Actions environment to use + runs-on: ubuntu-latest + needs: build_and_test + outputs: + env_name: ${{ steps.set-env.outputs.env_name }} + steps: + - name: Set environment name + id: set-env + run: | + set -euo pipefail + case "${{ inputs.pypi-index }}" in + test) + echo "env_name=pypi-test" >> "$GITHUB_OUTPUT" + ;; + prod) + if [[ -n "${{ inputs.stable-version }}" ]]; then + echo "env_name=pypi-prod" >> "$GITHUB_OUTPUT" + else + echo "env_name=pypi-prod-nightly" >> "$GITHUB_OUTPUT" + fi + ;; + *) + echo "Error: invalid combination of inputs.pypi-index='${{ inputs.pypi-index }}' and inputs.stable-version='${{ inputs.stable-version }}'" >&2 + exit 1 + ;; + esac + + publish_pypi: + name: Publish Artifacts to PyPI + runs-on: ubuntu-latest + needs: determine_environment + environment: + name: ${{ needs.determine_environment.outputs.env_name }} + permissions: + # this is needed for the OIDC flow that is used with trusted publishing on PyPI + id-token: write + steps: + - if: ${{ vars.PYPI_HOST == '' }} + run: | + echo "Error: PYPI_HOST is not set in CI environment '${{ needs.determine_environment.outputs.env_name }}'" + exit 1 + + - name: Fetch artifacts + uses: actions/download-artifact@v4 + with: + pattern: '{sdist,wheel}*' + path: packages/ + merge-multiple: true + + - name: Upload artifacts to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: 'https://${{ vars.PYPI_HOST }}/legacy/' + packages-dir: packages + verbose: 'true' + + - name: PyPI Upload Summary + run : | + version=$(basename packages/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') + echo "## PyPI Upload Summary" >> $GITHUB_STEP_SUMMARY + echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY + echo "* PyPI Host: ${{ vars.PYPI_HOST }}" >> $GITHUB_STEP_SUMMARY + echo "* CI Environment: ${{ needs.determine_environment.outputs.env_name }}" >> $GITHUB_STEP_SUMMARY + + cleanup_nightlies: + name: Remove Nightlies from PyPI + needs: [determine_environment, publish_pypi] + if: ${{ inputs.stable-version == '' }} + uses: ./.github/workflows/cleanup_pypi.yml + with: + environment: ${{ needs.determine_environment.outputs.env_name }} + secrets: + # reusable workflows and secrets are not great: https://github.com/actions/runner/issues/3206 + PYPI_CLEANUP_OTP: ${{secrets.PYPI_CLEANUP_OTP}} + PYPI_CLEANUP_PASSWORD: ${{secrets.PYPI_CLEANUP_PASSWORD}} diff --git a/.github/workflows/upload_to_pypi.yml b/.github/workflows/upload_to_pypi.yml deleted file mode 100644 index 0452f868..00000000 --- a/.github/workflows/upload_to_pypi.yml +++ /dev/null @@ -1,84 +0,0 @@ -name: Upload Artifacts to PyPI -on: - workflow_call: - inputs: - environment: - description: CI environment to run in (test.pypi or production.pypi) - type: string - required: true - sha: - description: The SHA of the commit that the packages were built from - type: string - required: true - workflow_dispatch: - inputs: - environment: - description: CI environment to run in (test.pypi or production.pypi) - type: choice - required: true - default: test.pypi - options: - - test.pypi - - production.pypi - sha: - description: The SHA of the commit that the packages were built from - type: string - required: true - -concurrency: - group: ${{ inputs.sha }} - cancel-in-progress: true - -jobs: - publish-pypi: - name: Publish Artifacts to PyPI - runs-on: ubuntu-latest - environment: - name: ${{ inputs.environment }} - permissions: - # this is needed for the OIDC flow that is used with trusted publishing on PyPI - id-token: write - steps: - - if: ${{ vars.PYPI_HOST == '' }} - run: | - echo "Error: PYPI_HOST is not set in CI environment '${{ inputs.environment }}'" - exit 1 - - - name: Authenticate With AWS - uses: aws-actions/configure-aws-credentials@v4 - with: - aws-region: 'us-east-2' - aws-access-key-id: ${{ secrets.S3_DUCKDB_STAGING_ID }} - aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - - - name: Download Artifacts From S3 - env: - S3_URL: 's3://duckdb-staging/${{ github.repository }}/${{ inputs.sha }}/' - AWS_ACCESS_KEY_ID: ${{ secrets.S3_DUCKDB_STAGING_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - run: | - sha=${{ inputs.sha }} - mkdir packages - aws s3 cp --recursive s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/ packages - - - name: Upload artifacts to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - repository-url: 'https://${{ vars.PYPI_HOST }}/legacy/' - packages-dir: packages - verbose: 'true' - - - name: PyPI Upload Summary - run : | - version=$(basename packages/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') - echo "## PyPI Upload Summary" >> $GITHUB_STEP_SUMMARY - echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY - echo "* PyPI Host: ${{ vars.PYPI_HOST }}" >> $GITHUB_STEP_SUMMARY - echo "* CI Environment: ${{ inputs.environment }}" >> $GITHUB_STEP_SUMMARY - - cleanup_nightlies: - name: Remove Nightlies from PyPI - needs: publish-pypi - uses: ./.github/workflows/cleanup_pypi.yml - with: - environment: ${{ inputs.environment }} diff --git a/README.md b/README.md index 13e9569d..5a240c76 100644 --- a/README.md +++ b/README.md @@ -67,8 +67,8 @@ git config --local core.hooksPath .githooks/ ### Editable installs (general) - It's good to be aware of the following when creating an editable install: -- `uv sync` or `uv run [tool]` create editable installs by default, however, it work the way you expect. We have + It's good to be aware of the following when performing an editable install: +- `uv sync` or `uv run [tool]` perform an editable install by default. We have configured the project so that scikit-build-core will use a persistent build-dir, but since the build itself happens in an isolated, ephemeral environment, cmake's paths will point to non-existing directories. CMake itself will be missing. diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index d386a606..81d4c8e0 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -10,7 +10,6 @@ import argparse import contextlib -import datetime import heapq import logging import os @@ -19,7 +18,7 @@ import time from collections import defaultdict from html.parser import HTMLParser -from typing import Dict, Optional, Set, Generator +from typing import Optional, Set, Generator from urllib.parse import urlparse import pyotp @@ -252,7 +251,8 @@ def run(self) -> int: logging.info(f"Max development releases to keep per unreleased version: {self._max_dev_releases}") try: - return self._execute_cleanup() + with session_with_retries() as http_session: + return self._execute_cleanup(http_session) except PyPICleanupError as e: logging.error(f"Cleanup failed: {e}") return 1 @@ -260,11 +260,11 @@ def run(self) -> int: logging.error(f"Unexpected error: {e}", exc_info=True) return 1 - def _execute_cleanup(self) -> int: + def _execute_cleanup(self, http_session: Session) -> int: """Execute the main cleanup logic.""" # Get released versions - versions = self._fetch_released_versions() + versions = self._fetch_released_versions(http_session) if not versions: logging.info(f"No releases found for {self._package}") return 0 @@ -284,24 +284,23 @@ def _execute_cleanup(self) -> int: return 0 # Perform authentication and deletion - self._authenticate() - self._delete_versions(versions_to_delete) + self._authenticate(http_session) + self._delete_versions(http_session, versions_to_delete) logging.info(f"Successfully cleaned up {len(versions_to_delete)} development versions") return 0 - def _fetch_released_versions(self) -> Set[str]: + def _fetch_released_versions(self, http_session: Session) -> Set[str]: """Fetch package release information from PyPI API.""" logging.debug(f"Fetching package information for '{self._package}'") try: - with session_with_retries() as session: - req = session.get(f"{self._index_url}/pypi/{self._package}/json") - req.raise_for_status() - data = req.json() - versions = {v for v, files in data["releases"].items() if len(files) > 0} - logging.debug(f"Found {len(versions)} releases with files") - return versions + req = http_session.get(f"{self._index_url}/pypi/{self._package}/json") + req.raise_for_status() + data = req.json() + versions = {v for v, files in data["releases"].items() if len(files) > 0} + logging.debug(f"Found {len(versions)} releases with files") + return versions except RequestException as e: raise PyPICleanupError(f"Failed to fetch package information for '{self._package}': {e}") from e @@ -394,7 +393,7 @@ def _determine_versions_to_delete(self, versions: Set[str]) -> Set[str]: return versions_to_delete - def _authenticate(self) -> None: + def _authenticate(self, http_session: Session) -> None: """Authenticate with PyPI.""" if not self._username or not self._password: raise AuthenticationError("Username and password are required for authentication") @@ -402,63 +401,62 @@ def _authenticate(self) -> None: logging.info(f"Authenticating user '{self._username}' with PyPI") try: - # Get login form and CSRF token - csrf_token = self._get_csrf_token("/account/login/") - # Attempt login - login_response = self._perform_login(csrf_token) - + login_response = self._perform_login(http_session) + # Handle two-factor authentication if required if login_response.url.startswith(f"{self._index_url}/account/two-factor/"): logging.debug("Two-factor authentication required") - self._handle_two_factor_auth(login_response) + self._handle_two_factor_auth(http_session, login_response) logging.info("Authentication successful") - + except RequestException as e: raise AuthenticationError(f"Network error during authentication: {e}") from e - def _get_csrf_token(self, form_action: str) -> str: + def _get_csrf_token(self, http_session: Session, form_action: str) -> str: """Extract CSRF token from a form page.""" - with session_with_retries() as session: - req = session.get(f"{self._index_url}{form_action}") - req.raise_for_status() - parser = CsrfParser(form_action) - parser.feed(req.text) - if not parser.csrf: - raise AuthenticationError(f"No CSRF token found in {form_action}") - return parser.csrf + resp = http_session.get(f"{self._index_url}{form_action}") + resp.raise_for_status() + parser = CsrfParser(form_action) + parser.feed(resp.text) + if not parser.csrf: + raise AuthenticationError(f"No CSRF token found in {form_action}") + return parser.csrf - def _perform_login(self, csrf_token: str) -> requests.Response: + def _perform_login(self, http_session: Session) -> requests.Response: """Perform the initial login with username/password.""" + + # Get login form and CSRF token + csrf_token = self._get_csrf_token(http_session, "/account/login/") + login_data = { "csrf_token": csrf_token, "username": self._username, "password": self._password } - with session_with_retries() as session: - response = session.post( - f"{self._index_url}/account/login/", - data=login_data, - headers={"referer": f"{self._index_url}/account/login/"} - ) - response.raise_for_status() + response = http_session.post( + f"{self._index_url}/account/login/", + data=login_data, + headers={"referer": f"{self._index_url}/account/login/"} + ) + response.raise_for_status() - # Check if login failed (redirected back to login page) - if response.url == f"{self._index_url}/account/login/": - raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") + # Check if login failed (redirected back to login page) + if response.url == f"{self._index_url}/account/login/": + raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") - return response + return response - def _handle_two_factor_auth(self, response: requests.Response) -> None: + def _handle_two_factor_auth(self, http_session: Session, response: requests.Response) -> None: """Handle two-factor authentication.""" if not self._otp: raise AuthenticationError("Two-factor authentication required but no OTP secret provided") two_factor_url = response.url form_action = two_factor_url[len(self._index_url):] - csrf_token = self._get_csrf_token(form_action) + csrf_token = self._get_csrf_token(http_session, form_action) # Try authentication with retries for attempt in range(_LOGIN_RETRY_ATTEMPTS): @@ -466,22 +464,21 @@ def _handle_two_factor_auth(self, response: requests.Response) -> None: auth_code = pyotp.TOTP(self._otp).now() logging.debug(f"Attempting 2FA with code (attempt {attempt + 1}/{_LOGIN_RETRY_ATTEMPTS})") - with session_with_retries() as session: - auth_response = session.post( - two_factor_url, - data={"csrf_token": csrf_token, "method": "totp", "totp_value": auth_code}, - headers={"referer": two_factor_url} - ) - auth_response.raise_for_status() - - # Check if 2FA succeeded (redirected away from 2FA page) - if auth_response.url != two_factor_url: - logging.debug("Two-factor authentication successful") - return - - if attempt < _LOGIN_RETRY_ATTEMPTS - 1: - logging.debug(f"2FA code rejected, retrying in {_LOGIN_RETRY_DELAY} seconds...") - time.sleep(_LOGIN_RETRY_DELAY) + auth_response = http_session.post( + two_factor_url, + data={"csrf_token": csrf_token, "method": "totp", "totp_value": auth_code}, + headers={"referer": two_factor_url} + ) + auth_response.raise_for_status() + + # Check if 2FA succeeded (redirected away from 2FA page) + if auth_response.url != two_factor_url: + logging.debug("Two-factor authentication successful") + return + + if attempt < _LOGIN_RETRY_ATTEMPTS - 1: + logging.debug(f"2FA code rejected, retrying in {_LOGIN_RETRY_DELAY} seconds...") + time.sleep(_LOGIN_RETRY_DELAY) except RequestException as e: if attempt == _LOGIN_RETRY_ATTEMPTS - 1: @@ -491,14 +488,14 @@ def _handle_two_factor_auth(self, response: requests.Response) -> None: raise AuthenticationError("Two-factor authentication failed after all attempts") - def _delete_versions(self, versions_to_delete: Set[str]) -> None: + def _delete_versions(self, http_session: Session, versions_to_delete: Set[str]) -> None: """Delete the specified package versions.""" logging.info(f"Starting deletion of {len(versions_to_delete)} development versions") failed_deletions = list() for version in sorted(versions_to_delete): try: - self._delete_single_version(version) + self._delete_single_version(http_session, version) logging.info(f"Successfully deleted {self._package} version {version}") except Exception as e: # Continue with other versions rather than failing completely @@ -510,7 +507,7 @@ def _delete_versions(self, versions_to_delete: Set[str]) -> None: f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" ) - def _delete_single_version(self, version: str) -> None: + def _delete_single_version(self, http_session: Session, version: str) -> None: """Delete a single package version.""" # Safety check if not self._is_dev_version(version) or self._is_rc_version(version): @@ -522,19 +519,18 @@ def _delete_single_version(self, version: str) -> None: form_action = f"/manage/project/{self._package}/release/{version}/" form_url = f"{self._index_url}{form_action}" - csrf_token = self._get_csrf_token(form_action) - - with session_with_retries() as session: - # Submit deletion request - delete_response = session.post( - form_url, - data={ - "csrf_token": csrf_token, - "confirm_delete_version": version, - }, - headers={"referer": form_url} - ) - delete_response.raise_for_status() + csrf_token = self._get_csrf_token(http_session, form_action) + + # Submit deletion request + delete_response = http_session.post( + form_url, + data={ + "csrf_token": csrf_token, + "confirm_delete_version": version, + }, + headers={"referer": form_url} + ) + delete_response.raise_for_status() def main() -> int: diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index d90f9fff..6e1460e2 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -256,7 +256,8 @@ def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, mock_fetch.return_value = {"1.0.0.dev1"} mock_determine.return_value = {"1.0.0.dev1"} - result = cleanup_dryrun_max_2._execute_cleanup() + with session_with_retries() as session: + result = cleanup_dryrun_max_2._execute_cleanup(session) assert result == 0 mock_fetch.assert_called_once() @@ -266,7 +267,8 @@ def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') def test_execute_cleanup_no_releases(self, mock_fetch, cleanup_dryrun_max_2): mock_fetch.return_value = {} - result = cleanup_dryrun_max_2._execute_cleanup() + with session_with_retries() as session: + result = cleanup_dryrun_max_2._execute_cleanup(session) assert result == 0 @patch('requests.Session.get') @@ -281,7 +283,8 @@ def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): } mock_get.return_value = mock_response - releases = cleanup_dryrun_max_2._fetch_released_versions() + with session_with_retries() as session: + releases = cleanup_dryrun_max_2._fetch_released_versions(session) assert releases == {"1.0.0", "1.0.0.dev1"} @@ -293,38 +296,41 @@ def test_fetch_released_versions_not_found(self, mock_get, cleanup_dryrun_max_2) mock_get.return_value = mock_response with pytest.raises(PyPICleanupError, match="Failed to fetch package information"): - cleanup_dryrun_max_2._fetch_released_versions() + with session_with_retries() as session: + cleanup_dryrun_max_2._fetch_released_versions(session) @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._perform_login') - def test_authenticate_success(self, mock_login, mock_csrf, cleanup_max_2): + @patch('requests.Session.post') + def test_authenticate_success(self, mock_post, mock_csrf, cleanup_max_2): """Test successful authentication.""" mock_csrf.return_value = "csrf123" mock_response = Mock() mock_response.url = "https://test.pypi.org/manage/" - mock_login.return_value = mock_response + mock_post.return_value = mock_response - cleanup_max_2._authenticate() # Should not raise + with session_with_retries() as session: + cleanup_max_2._authenticate(session) # Should not raise + mock_csrf.assert_called_once_with(session, "/account/login/") - mock_csrf.assert_called_once_with("/account/login/") - mock_login.assert_called_once_with("csrf123") + mock_post.assert_called_once() + assert mock_post.call_args.args[0].endswith('/account/login/') @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._perform_login') + @patch('requests.Session.post') @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth') - def test_authenticate_with_2fa(self, mock_2fa, mock_login, mock_csrf, cleanup_max_2): + def test_authenticate_with_2fa(self, mock_2fa, mock_post, mock_csrf, cleanup_max_2): mock_csrf.return_value = "csrf123" mock_response = Mock() mock_response.url = "https://test.pypi.org/account/two-factor/totp" - mock_login.return_value = mock_response + mock_post.return_value = mock_response - cleanup_max_2._authenticate() - - mock_2fa.assert_called_once_with(mock_response) + with session_with_retries() as session: + cleanup_max_2._authenticate(session) + mock_2fa.assert_called_once_with(session, mock_response) def test_authenticate_missing_credentials(self, cleanup_dryrun_max_2): with pytest.raises(AuthenticationError, match="Username and password are required"): - cleanup_dryrun_max_2._authenticate() + cleanup_dryrun_max_2._authenticate(None) @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') def test_delete_versions_success(self, mock_delete, cleanup_max_2): @@ -332,7 +338,8 @@ def test_delete_versions_success(self, mock_delete, cleanup_max_2): versions = {"1.0.0.dev1", "1.0.0.dev2"} mock_delete.side_effect = [None, None] # Successful deletions - cleanup_max_2._delete_versions(versions) + with session_with_retries() as session: + cleanup_max_2._delete_versions(session, versions) assert mock_delete.call_count == 2 @@ -343,12 +350,12 @@ def test_delete_versions_partial_failure(self, mock_delete, cleanup_max_2): mock_delete.side_effect = [None, Exception("Delete failed")] with pytest.raises(PyPICleanupError, match="Failed to delete 1/2 versions"): - cleanup_max_2._delete_versions(versions) + cleanup_max_2._delete_versions(None, versions) def test_delete_single_version_safety_check(self, cleanup_max_2): """Test single version deletion safety check.""" with pytest.raises(PyPICleanupError, match="Refusing to delete non-\\[dev\\|rc\\] version"): - cleanup_max_2._delete_single_version("1.0.0") # Non-dev version + cleanup_max_2._delete_single_version(None, "1.0.0") # Non-dev version class TestArgumentParser: From 0a107416b2fa8e62bb1d0bb3a9e2745c4081a9bf Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 21 Aug 2025 17:15:19 +0200 Subject: [PATCH 097/472] Publish both the dist package version and the duckdb version --- duckdb/__init__.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index 5f997bd3..c3ec0610 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -1,8 +1,17 @@ -_exported_symbols = [] - # Modules import duckdb.functional as functional import duckdb.typing as typing +from _duckdb import __version__ as duckdb_version +from importlib.metadata import version + +# duckdb.__version__ returns the version of the distribution package, i.e. the pypi version +__version__ = version("duckdb") + +# version() is a more human friendly formatted version string of both the distribution package and the bundled duckdb +def version(): + return f"{__version__} (with duckdb {duckdb_version})" + +_exported_symbols = ['__version__', 'version'] _exported_symbols.extend([ "typing", @@ -248,7 +257,6 @@ __interactive__, __jupyter__, __formatted_python_version__, - __version__, apilevel, comment, identifier, @@ -266,7 +274,6 @@ "__interactive__", "__jupyter__", "__formatted_python_version__", - "__version__", "apilevel", "comment", "identifier", From 0ab6530e12d2a238461ffb8c6ef1420ccb0cc88f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 07:47:38 +0000 Subject: [PATCH 098/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 129b1fe5..d229d97f 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 129b1fe55ef24e616754238cb100e3b9a926e4b6 +Subproject commit d229d97f4028e153234647b5a2b65682b321e77d From af79ff52a8f57b9fad8738683935d89a412ce111 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 22 Aug 2025 13:16:24 +0200 Subject: [PATCH 099/472] Include latest python sqlogic changes --- sqllogic/conftest.py | 2 +- sqllogic/skipped_tests.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 7c5ce2e2..fbe55f27 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -10,7 +10,7 @@ SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" SQLLOGIC_TEST_PARAMETER = "test_script_path" -DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / 'external' / 'duckdb').resolve() +DUCKDB_ROOT_DIR = pathlib.Path(__file__).parent.joinpath("../../..").resolve() def pytest_addoption(parser: pytest.Parser): diff --git a/sqllogic/skipped_tests.py b/sqllogic/skipped_tests.py index ac0f73f8..39269c42 100644 --- a/sqllogic/skipped_tests.py +++ b/sqllogic/skipped_tests.py @@ -37,5 +37,6 @@ 'test/sql/tpcds/tpcds_sf0.test', # problems connected to auto installing tpcds from remote 'test/sql/optimizer/plan/test_filter_pushdown_materialized_cte.test', # problems connected to auto installing tpcds from remote 'test/sql/explain/test_explain_analyze.test', # unknown problem with changes in API + 'test/sql/pragma/profiling/test_profiling_all.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement ] ) From 97b54beb8f8e1227668e6e2b343e72e20d59eba3 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 22 Aug 2025 14:08:54 +0200 Subject: [PATCH 100/472] Correct path in sqlogic testconfig --- sqllogic/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index fbe55f27..0d7ea2eb 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -10,7 +10,7 @@ SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" SQLLOGIC_TEST_PARAMETER = "test_script_path" -DUCKDB_ROOT_DIR = pathlib.Path(__file__).parent.joinpath("../../..").resolve() +DUCKDB_ROOT_DIR = pathlib.Path(__file__).parent.joinpath("external/duckdb/").resolve() def pytest_addoption(parser: pytest.Parser): From dcee66cdb3a468f87fb131ce574abbb6a8c917cb Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 22 Aug 2025 15:23:49 +0200 Subject: [PATCH 101/472] Fix S3 url and give better summary output --- .github/workflows/release.yml | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8e27fa3c..e4c92e7f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -51,6 +51,17 @@ jobs: path: artifacts/ merge-multiple: true + - name: Compute upload input + id: input + run: | + sha=${{ github.sha }} + dsha=${{ inputs.duckdb-sha }} + version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') + echo "short_sha=${sha:0:10}" >> $GITHUB_OUTPUT + echo "short_dsha=${dsha:0:10}" >> $GITHUB_OUTPUT + echo "version=${version}" >> $GITHUB_OUTPUT + echo "s3_upload_url="s3://duckdb-staging/python/${version}/${sha:0:10}-duckdb-${dsha:0:10}/" >> $GITHUB_OUTPUT + - name: Authenticate with AWS uses: aws-actions/configure-aws-credentials@v4 with: @@ -59,19 +70,16 @@ jobs: aws-secret-access-key: ${{ secrets.S3_DUCKDB_STAGING_KEY }} - name: Upload Artifacts - id: s3_upload run: | - sha=${{ github.ref }} - aws s3 cp artifacts s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/ --recursive + aws s3 cp artifacts ${{ steps.input.outputs.s3_upload_url }} --recursive - name: S3 Upload Summary run : | - sha=${{ github.ref }} - version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') echo "## S3 Upload Summary" >> $GITHUB_STEP_SUMMARY - echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY - echo "* SHA: ${sha:0:10}" >> $GITHUB_STEP_SUMMARY - echo "* S3 URL: s3://duckdb-staging/${{ github.repository }}/${sha:0:10}/" >> $GITHUB_STEP_SUMMARY + echo "* Version: ${{ steps.input.outputs.version }}" >> $GITHUB_STEP_SUMMARY + echo "* SHA: ${{ steps.input.outputs.short_sha }}" >> $GITHUB_STEP_SUMMARY + echo "* DuckDB SHA: ${{ steps.input.outputs.short_dsha }}" >> $GITHUB_STEP_SUMMARY + echo "* S3 URL: ${{ steps.input.outputs.upload_url }}" >> $GITHUB_STEP_SUMMARY determine_environment: name: Determine the Github Actions environment to use From a77c56a5d5a6c1950c0d655950ed6c80472e6f2c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 22 Aug 2025 15:36:44 +0200 Subject: [PATCH 102/472] Update submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index aaa4635f..aca23dcd 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit aaa4635fff6a92736c6fc5bf4023f75c0414be02 +Subproject commit aca23dcd266cdca3854efa77fdceff471f39d3c3 From 797f509abcd11848ffc39248bcd3377285e10ffa Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 22 Aug 2025 15:50:05 +0200 Subject: [PATCH 103/472] Update submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index d229d97f..18d0c636 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit d229d97f4028e153234647b5a2b65682b321e77d +Subproject commit 18d0c636b012b6822f1003bfa5f75b8941b38384 From 89a3375467e75d26646db9403e7cb9fde482de10 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 22 Aug 2025 15:51:39 +0200 Subject: [PATCH 104/472] Update submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index aca23dcd..c5310ec8 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit aca23dcd266cdca3854efa77fdceff471f39d3c3 +Subproject commit c5310ec83bf0f2a6ce8b6f04656f54fdcf7da3ef From 122b4e72d18a2667e5ce0c079834f1b513d3484b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 22 Aug 2025 19:40:00 +0200 Subject: [PATCH 105/472] Fix s3 url --- .github/workflows/release.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e4c92e7f..3d03bf1d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -57,10 +57,11 @@ jobs: sha=${{ github.sha }} dsha=${{ inputs.duckdb-sha }} version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') + url="s3://duckdb-staging/python/${version}/${sha:0:10}-duckdb-${dsha:0:10}/" echo "short_sha=${sha:0:10}" >> $GITHUB_OUTPUT echo "short_dsha=${dsha:0:10}" >> $GITHUB_OUTPUT echo "version=${version}" >> $GITHUB_OUTPUT - echo "s3_upload_url="s3://duckdb-staging/python/${version}/${sha:0:10}-duckdb-${dsha:0:10}/" >> $GITHUB_OUTPUT + echo "s3_upload_url=${url}" >> $GITHUB_OUTPUT - name: Authenticate with AWS uses: aws-actions/configure-aws-credentials@v4 From 544e7924d1df24889203fcfd3836204cb2e4a115 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 23 Aug 2025 05:32:36 +0000 Subject: [PATCH 106/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 18d0c636..4768277a 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 18d0c636b012b6822f1003bfa5f75b8941b38384 +Subproject commit 4768277a98072c9e8900fce3372ee20ed3321149 From 82194dc8381a101bf76f3829f9ae505d0e640505 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Sat, 23 Aug 2025 10:38:28 +0200 Subject: [PATCH 107/472] S3 upload summary --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3d03bf1d..9a6dbf0b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -80,7 +80,7 @@ jobs: echo "* Version: ${{ steps.input.outputs.version }}" >> $GITHUB_STEP_SUMMARY echo "* SHA: ${{ steps.input.outputs.short_sha }}" >> $GITHUB_STEP_SUMMARY echo "* DuckDB SHA: ${{ steps.input.outputs.short_dsha }}" >> $GITHUB_STEP_SUMMARY - echo "* S3 URL: ${{ steps.input.outputs.upload_url }}" >> $GITHUB_STEP_SUMMARY + echo "* S3 URL: ${{ steps.input.outputs.s3_upload_url }}" >> $GITHUB_STEP_SUMMARY determine_environment: name: Determine the Github Actions environment to use From 64017b4ca5122b4a657792115ae23a868702ad74 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 25 Aug 2025 05:27:50 +0000 Subject: [PATCH 108/472] Update submodule ref --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 4768277a..22e6d1e3 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 4768277a98072c9e8900fce3372ee20ed3321149 +Subproject commit 22e6d1e3751829d2d029636b75ff931710ca7cbc From c70193cd69efbc2b072f4edd4e02540d0be2f7b4 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 25 Aug 2025 15:43:51 +0200 Subject: [PATCH 109/472] Use correct env --- .github/workflows/cleanup_pypi.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 70fe13d6..e33e8b65 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -25,6 +25,7 @@ on: required: true options: - pypi-prod-nightly + - pypi-prod - pypi-test jobs: @@ -59,10 +60,10 @@ jobs: PYTHON_UNBUFFERED: 1 run: | set -x + env_flag=$( [[ ${{ inputs.environment }} == pypi-prod* ]] && echo "--prod" || echo "--test" ) uv sync --only-group pypi --no-install-project - # TODO: set test/prod flag according to env (inputs.environment == 'production.pypi' && '--prod' || '--test') uv run --no-sync python -u -m duckdb_packaging.pypi_cleanup ${{ inputs.dry-run && '--dry' || '' }} \ - --test \ + $env_flag \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} 2>&1 | tee cleanup_output From 379f8fd3432cb8d725279f03df433a284ee560de Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 25 Aug 2025 16:01:46 +0200 Subject: [PATCH 110/472] Remove obsolete PYPI_HOST from summary --- .github/workflows/cleanup_pypi.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index e33e8b65..1aeaa365 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -71,7 +71,6 @@ jobs: run : | echo "## PyPI Cleanup Summary" >> $GITHUB_STEP_SUMMARY echo "* Dry run: ${{ inputs.dry-run }}" >> $GITHUB_STEP_SUMMARY - echo "* PyPI Host: ${{ vars.PYPI_HOST }}" >> $GITHUB_STEP_SUMMARY echo "* CI Environment: ${{ inputs.environment }}" >> $GITHUB_STEP_SUMMARY echo "* Output:" >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY From 6e29f57faeb67dfdb75bc8047e2faf81c2542e02 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 25 Aug 2025 16:02:52 +0200 Subject: [PATCH 111/472] Set submodule commit --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index f8886f66..4d2cc504 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit f8886f66ac712c75bdab5cdad54235a629c8d883 +Subproject commit 4d2cc5049aa6444f98e848abe2a5d4f628d99e4d From 548906ce9c4589b6b50f48f52a67af4ac597d0c7 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 25 Aug 2025 16:03:27 +0200 Subject: [PATCH 112/472] Set submodule commit --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index c5310ec8..df0a3de7 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit c5310ec83bf0f2a6ce8b6f04656f54fdcf7da3ef +Subproject commit df0a3de74429887333ec4af047e7aac2737e52d8 From 57cd93466dc785c2725c372aa5f49256646277d5 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 25 Aug 2025 16:07:35 +0200 Subject: [PATCH 113/472] Disallow pypi-prod for removing dev versions from pypi --- .github/workflows/cleanup_pypi.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 1aeaa365..86e96c49 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -25,7 +25,6 @@ on: required: true options: - pypi-prod-nightly - - pypi-prod - pypi-test jobs: @@ -60,10 +59,9 @@ jobs: PYTHON_UNBUFFERED: 1 run: | set -x - env_flag=$( [[ ${{ inputs.environment }} == pypi-prod* ]] && echo "--prod" || echo "--test" ) uv sync --only-group pypi --no-install-project uv run --no-sync python -u -m duckdb_packaging.pypi_cleanup ${{ inputs.dry-run && '--dry' || '' }} \ - $env_flag \ + ${{ inputs.environment == 'pypi-prod-nightly' && '--prod' || '--test' }} \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} 2>&1 | tee cleanup_output From 040159cb1fcc250f747dcea9b95c5cafa22a0a85 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 27 Aug 2025 14:06:13 +0200 Subject: [PATCH 114/472] Update submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 4d2cc504..050ba71c 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 4d2cc5049aa6444f98e848abe2a5d4f628d99e4d +Subproject commit 050ba71c812fe42cc7f79d601e94fda8d0549e69 From 5410b11433d23ef62c267bea15112a97b37152b3 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 27 Aug 2025 16:41:54 +0200 Subject: [PATCH 115/472] Correct authors --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ade07cf..07eea22e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,8 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Programming Language :: C++", ] -authors = [{name = "DuckDB Labs", email = "info@duckdblabs.nl"}] -maintainers = [{name = "DuckDB Labs", email = "info@duckdblabs.nl"}] +authors = [{name = "DuckDB Foundation"}] +maintainers = [{name = "DuckDB Foundation"}] [project.urls] Documentation = "https://duckdb.org/docs/stable/clients/python/overview" From 5ad9e789e76f88ff41f320858dd8a6127ae55e64 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 27 Aug 2025 16:42:52 +0200 Subject: [PATCH 116/472] Updated submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 050ba71c..1dec2d04 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 050ba71c812fe42cc7f79d601e94fda8d0549e69 +Subproject commit 1dec2d047e2ab845442babbd4adc08f5d45bda4d From 32f3d1b4b075a9053ccb2dc5b104d538f836d01a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 28 Aug 2025 16:38:15 +0200 Subject: [PATCH 117/472] Enable extension autoinstall and autoloading (#21) Fixes https://github.com/duckdb/duckdb/issues/18770 --- cmake/duckdb_loader.cmake | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cmake/duckdb_loader.cmake b/cmake/duckdb_loader.cmake index 5309924b..ec45d497 100644 --- a/cmake/duckdb_loader.cmake +++ b/cmake/duckdb_loader.cmake @@ -47,8 +47,10 @@ _duckdb_set_default(BUILD_UNITTESTS OFF) _duckdb_set_default(BUILD_BENCHMARKS OFF) _duckdb_set_default(DISABLE_UNITY OFF) -# Extension configuration - static linking for Python modules +# Extension configuration _duckdb_set_default(DISABLE_BUILTIN_EXTENSIONS OFF) +_duckdb_set_default(ENABLE_EXTENSION_AUTOINSTALL 1) # todo: set to ON https://github.com/duckdb/duckdb/pull/18778/files +_duckdb_set_default(ENABLE_EXTENSION_AUTOLOADING ON) # Performance options - enable optimizations by default _duckdb_set_default(NATIVE_ARCH OFF) @@ -69,6 +71,8 @@ set(BUILD_UNITTESTS "${BUILD_UNITTESTS}" CACHE BOOL "Build DuckDB unit tests") set(BUILD_BENCHMARKS "${BUILD_BENCHMARKS}" CACHE BOOL "Build DuckDB benchmarks") set(DISABLE_UNITY "${DISABLE_UNITY}" CACHE BOOL "Disable unity builds (slower compilation)") set(DISABLE_BUILTIN_EXTENSIONS "${DISABLE_BUILTIN_EXTENSIONS}" CACHE BOOL "Disable all built-in extensions") +set(ENABLE_EXTENSION_AUTOINSTALL "${ENABLE_EXTENSION_AUTOINSTALL}" CACHE BOOL "Enable extension auto-installing by default.") +set(ENABLE_EXTENSION_AUTOLOADING "${ENABLE_EXTENSION_AUTOLOADING}" CACHE BOOL "Enable extension auto-loading by default.") set(NATIVE_ARCH "${NATIVE_ARCH}" CACHE BOOL "Optimize for native architecture") set(ENABLE_SANITIZER "${ENABLE_SANITIZER}" CACHE BOOL "Enable address sanitizer") set(ENABLE_UBSAN "${ENABLE_UBSAN}" CACHE BOOL "Enable undefined behavior sanitizer") From fee1dc266094e30a5f5ad370737a6d6becd91593 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 28 Aug 2025 16:51:28 +0200 Subject: [PATCH 118/472] Update submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 1dec2d04..4bb3b055 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 1dec2d047e2ab845442babbd4adc08f5d45bda4d +Subproject commit 4bb3b05577ce60934fe4d3dd41fbfbd2b7bcdf11 From 136d8c7d8cd96391555be58fc03dbe621dd035ac Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 28 Aug 2025 17:17:50 +0200 Subject: [PATCH 119/472] Fwd ports of: * #18749 Fixing lazy polars execution on query result * #18682 Make sure parse errors are wrapped in ErrorData --- .../include/duckdb_python/pyrelation.hpp | 4 +- .../include/duckdb_python/pyresult.hpp | 2 + src/duckdb_py/pyconnection.cpp | 8 ++-- src/duckdb_py/pyrelation.cpp | 11 ++++- src/duckdb_py/pyresult.cpp | 4 ++ tests/fast/arrow/test_polars.py | 8 ++++ tests/fast/test_json_logging.py | 48 +++++++++++++++++++ 7 files changed, 77 insertions(+), 8 deletions(-) create mode 100644 tests/fast/test_json_logging.py diff --git a/src/duckdb_py/include/duckdb_python/pyrelation.hpp b/src/duckdb_py/include/duckdb_python/pyrelation.hpp index b1feb8ba..e1f78b5a 100644 --- a/src/duckdb_py/include/duckdb_python/pyrelation.hpp +++ b/src/duckdb_py/include/duckdb_python/pyrelation.hpp @@ -28,7 +28,7 @@ namespace duckdb { struct DuckDBPyRelation { public: explicit DuckDBPyRelation(shared_ptr rel); - explicit DuckDBPyRelation(unique_ptr result); + explicit DuckDBPyRelation(shared_ptr result); ~DuckDBPyRelation(); public: @@ -288,7 +288,7 @@ struct DuckDBPyRelation { shared_ptr rel; vector types; vector names; - unique_ptr result; + shared_ptr result; std::string rendered_result; }; diff --git a/src/duckdb_py/include/duckdb_python/pyresult.hpp b/src/duckdb_py/include/duckdb_python/pyresult.hpp index 2e6b8307..fc3641c4 100644 --- a/src/duckdb_py/include/duckdb_python/pyresult.hpp +++ b/src/duckdb_py/include/duckdb_python/pyresult.hpp @@ -59,6 +59,8 @@ struct DuckDBPyResult { const vector &GetNames(); const vector &GetTypes(); + ClientProperties GetClientProperties(); + private: void FillNumpy(py::dict &res, idx_t col_idx, NumpyResultConversion &conversion, const char *name); diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 90156395..94745b75 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -614,17 +614,17 @@ unique_ptr DuckDBPyConnection::PrepareAndExecuteInternal(unique_ptr } vector> DuckDBPyConnection::GetStatements(const py::object &query) { - vector> result; - auto &connection = con.GetConnection(); - shared_ptr statement_obj; if (py::try_cast(query, statement_obj)) { + vector> result; result.push_back(statement_obj->GetStatement()); return result; } if (py::isinstance(query)) { + auto &connection = con.GetConnection(); auto sql_query = std::string(py::str(query)); - return connection.ExtractStatements(sql_query); + auto statements = connection.ExtractStatements(sql_query); + return std::move(statements); } throw InvalidInputException("Please provide either a DuckDBPyStatement or a string representing the query"); } diff --git a/src/duckdb_py/pyrelation.cpp b/src/duckdb_py/pyrelation.cpp index 60c4fbf5..3553bff0 100644 --- a/src/duckdb_py/pyrelation.cpp +++ b/src/duckdb_py/pyrelation.cpp @@ -66,7 +66,7 @@ DuckDBPyRelation::~DuckDBPyRelation() { rel.reset(); } -DuckDBPyRelation::DuckDBPyRelation(unique_ptr result_p) : rel(nullptr), result(std::move(result_p)) { +DuckDBPyRelation::DuckDBPyRelation(shared_ptr result_p) : rel(nullptr), result(std::move(result_p)) { if (!result) { throw InternalException("DuckDBPyRelation created without a result"); } @@ -984,7 +984,14 @@ PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) { ArrowSchema arrow_schema; auto result_names = names; QueryResult::DeduplicateColumns(result_names); - auto client_properties = rel->context->GetContext()->GetClientProperties(); + ClientProperties client_properties; + if (rel) { + client_properties = rel->context->GetContext()->GetClientProperties(); + } else if (result) { + client_properties = result->GetClientProperties(); + } else { + throw InternalException("DuckDBPyRelation To Polars must have a valid relation or result"); + } ArrowConverter::ToArrowSchema(&arrow_schema, types, result_names, client_properties); py::list batches; // Now we create an empty arrow table diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index aa84cee5..5997d57b 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -41,6 +41,10 @@ DuckDBPyResult::~DuckDBPyResult() { } } +ClientProperties DuckDBPyResult::GetClientProperties() { + return result->client_properties; +} + const vector &DuckDBPyResult::GetNames() { if (!result) { throw InternalException("Calling GetNames without a result object"); diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 1a86c82d..89ccf031 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -101,6 +101,14 @@ def test_polars_from_json_error(self, duckdb_cursor): my_res = duckdb.query("select my_str from my_table where my_str != 'y'") assert my_res.fetchall() == [('x',)] + def test_polars_lazy_from_conn(self, duckdb_cursor): + duckdb_conn = duckdb.connect() + + result = duckdb_conn.execute("SELECT 42 as bla") + + lazy_df = result.pl(lazy=True) + assert lazy_df.collect().to_dicts() == [{'bla': 42}] + def test_polars_lazy(self, duckdb_cursor): con = duckdb.connect() con.execute("Create table names (a varchar, b integer)") diff --git a/tests/fast/test_json_logging.py b/tests/fast/test_json_logging.py new file mode 100644 index 00000000..a7f305f3 --- /dev/null +++ b/tests/fast/test_json_logging.py @@ -0,0 +1,48 @@ +import json + +import duckdb +import pytest + + +def _parse_json_func(error_prefix: str): + """Helper to check that the error message is indeed parsable json""" + + def parse_func(exception): + msg = exception.args[0] + assert msg.startswith(error_prefix) + json_str = msg.split(error_prefix, 1)[1] + try: + json.loads(json_str) + except: + return False + return True + + return parse_func + + +def test_json_syntax_error(): + conn = duckdb.connect() + conn.execute("SET errors_as_json='true'") + with pytest.raises(duckdb.ParserException, match="SYNTAX_ERROR", check=_parse_json_func("Parser Error: ")): + conn.execute("syntax error") + + +def test_json_catalog_error(): + conn = duckdb.connect() + conn.execute("SET errors_as_json='true'") + with pytest.raises(duckdb.CatalogException, match="MISSING_ENTRY", check=_parse_json_func("Catalog Error: ")): + conn.execute("SELECT * FROM nonexistent_table") + + +def test_json_syntax_error_extract_statements(): + conn = duckdb.connect() + conn.execute("SET errors_as_json='true'") + with pytest.raises(duckdb.ParserException, match="SYNTAX_ERROR", check=_parse_json_func("Parser Error: ")): + conn.extract_statements("syntax error") + + +def test_json_syntax_error_get_table_names(): + conn = duckdb.connect() + conn.execute("SET errors_as_json='true'") + with pytest.raises(duckdb.ParserException, match="SYNTAX_ERROR", check=_parse_json_func("Parser Error: ")): + conn.get_table_names("syntax error") From 477e2581b9adf1a2e99130d0a2c028b32c107d76 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 28 Aug 2025 17:25:42 +0200 Subject: [PATCH 120/472] Port support for MERGE INTO --- src/duckdb_py/pystatement.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/duckdb_py/pystatement.cpp b/src/duckdb_py/pystatement.cpp index bfcb96aa..7e84df7e 100644 --- a/src/duckdb_py/pystatement.cpp +++ b/src/duckdb_py/pystatement.cpp @@ -92,6 +92,12 @@ py::list DuckDBPyStatement::ExpectedResultType() const { possibilities.append(StatementReturnType::NOTHING); break; } + case StatementType::MERGE_INTO_STATEMENT: { + possibilities.append(StatementReturnType::CHANGED_ROWS); + possibilities.append(StatementReturnType::QUERY_RESULT); + possibilities.append(StatementReturnType::NOTHING); + break; + } default: { throw InternalException("Unrecognized StatementType in ExpectedResultType: %s", StatementTypeToString(statement->type)); From 728228e3804720dfbe80e6c14a5dd860b198941c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 28 Aug 2025 17:49:41 +0200 Subject: [PATCH 121/472] Fail PRs with incorrect submodule URL --- .github/workflows/on_pr.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index d8bfe825..4a7e6e3e 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -19,6 +19,25 @@ concurrency: cancel-in-progress: true jobs: + ensure_submodule_sanity: + name: Make sure we're not building with a fork + runs-on: ubuntu-latest + steps: + - name: Checkout DuckDB Python + uses: actions/checkout@v4 + + - shell: bash + run: | + submodule_url=$(git config --file .gitmodules --get submodule.external/duckdb.url || true) + expected="github.com/duckdb/duckdb" + if [[ -z "$submodule_url" ]]; then + echo "::error::DuckDB submodule not found in .gitmodules" + exit 1 + fi + if [[ "$submodule_url" != *"$expected"* ]]; then + echo "::error::DuckDB submodule must point to $expected, found: $submodule_url" + exit 1 + fi packaging_test: name: Build a minimal set of packages and run all tests on them From fe6b7a66c0f815933bc760f3ec80f95b81262f1d Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 28 Aug 2025 19:06:17 +0200 Subject: [PATCH 122/472] Fwd port of https://github.com/duckdb/duckdb/pull/15789/files (#20) --- duckdb/experimental/spark/sql/dataframe.py | 30 ++++++++++++++++++++++ tests/fast/spark/test_spark_dataframe.py | 8 ++++++ 2 files changed, 38 insertions(+) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 31b13ded..b8a4698b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1403,5 +1403,35 @@ def construct_row(values, names) -> Row: rows = [construct_row(x, columns) for x in result] return rows + def cache(self) -> "DataFrame": + """Persists the :class:`DataFrame` with the default storage level (`MEMORY_AND_DISK_DESER`). + + .. versionadded:: 1.3.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Notes + ----- + The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0. + + Returns + ------- + :class:`DataFrame` + Cached DataFrame. + + Examples + -------- + >>> df = spark.range(1) + >>> df.cache() + DataFrame[id: bigint] + + >>> df.explain() + == Physical Plan == + InMemoryTableScan ... + """ + cached_relation = self.relation.execute() + return DataFrame(cached_relation, self.session) + __all__ = ["DataFrame"] diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index 5b7492d7..d88b03eb 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -421,3 +421,11 @@ def test_drop(self, spark): assert df.drop("two", "three").columns == expected assert df.drop("two", col("three")).columns == expected assert df.drop("two", col("three"), col("missing")).columns == expected + + def test_cache(self, spark): + data = [(1, 2, 3, 4)] + df = spark.createDataFrame(data, ["one", "two", "three", "four"]) + cached = df.cache() + assert df is not cached + assert cached.collect() == df.collect() + assert cached.collect() == [Row(one=1, two=2, three=3, four=4)] From ee0aaba4aaee22f3780727d17d9eea6162788f15 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 29 Aug 2025 10:14:52 +0200 Subject: [PATCH 123/472] Submodule sanity is a requirement --- .github/workflows/on_pr.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 4a7e6e3e..5d7328fb 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -41,6 +41,7 @@ jobs: packaging_test: name: Build a minimal set of packages and run all tests on them + needs: ensure_submodule_sanity # Skip packaging tests for draft PRs if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }} uses: ./.github/workflows/packaging.yml @@ -51,7 +52,9 @@ jobs: coverage_test: name: Run coverage tests - if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }} + needs: ensure_submodule_sanity + # Only run coverage test for draft PRs + if: ${{ github.event_name == 'pull_request' && github.event.pull_request.draft == true }} uses: ./.github/workflows/coverage.yml with: duckdb_git_ref: ${{ github.base_ref }} From 2fad7319e93c774a8a8d948ec529dbea05ac289d Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 29 Aug 2025 10:32:33 +0200 Subject: [PATCH 124/472] Better contribute instructions --- README.md | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 1546ad51..74936d52 100644 --- a/README.md +++ b/README.md @@ -46,18 +46,22 @@ pip install 'duckdb[all]' ### Cloning -When you clone the repo or your fork, make sure you initialize the duckdb submodule: +After forking this duckdb-python, we recommend you clone your fork as follows: ```shell -git clone --recurse-submodules +git clone --recurse-submodules $REPO_URL +git remote add upstream https://github.com/duckdb/duckdb-python.git +git fetch --all ``` -... or, if you already have the repo locally: +... or, if you have already cloned your fork: ```shell -git clone -cd git submodule update --init --recursive +git remote add upstream https://github.com/duckdb/duckdb-python.git +git fetch --all ``` +### Submodule update hook + If you'll be switching between branches that are have the submodule set to different refs, then make your life easier and add the git hooks in the .githooks directory to your local config: ```shell @@ -78,7 +82,7 @@ git config --local core.hooksPath .githooks/ ```bash # install all dev dependencies without building the project (needed once) -uv sync -p 3.9 --no-install-project +uv sync -p 3.11 --no-install-project # build and install without build isolation uv sync --no-build-isolation ``` From 1d21b07c308d9632a6c9813b6adb5e93ee548a0a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 29 Aug 2025 10:39:36 +0200 Subject: [PATCH 125/472] Add fork icon --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 74936d52..561b3afd 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,10 @@ pip install 'duckdb[all]' ## Development +Start by + +forking duckdb-python. + ### Cloning After forking this duckdb-python, we recommend you clone your fork as follows: From 172cacadd667b8b4b7adc209239a76568a45a0dc Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 29 Aug 2025 10:40:58 +0200 Subject: [PATCH 126/472] Typos --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 561b3afd..c70d6e2b 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ Start by Date: Fri, 29 Aug 2025 10:44:06 +0200 Subject: [PATCH 127/472] Use ENABLE_EXTENSION_AUTOINSTALL as BOOL --- cmake/duckdb_loader.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/duckdb_loader.cmake b/cmake/duckdb_loader.cmake index ec45d497..80ec9a4d 100644 --- a/cmake/duckdb_loader.cmake +++ b/cmake/duckdb_loader.cmake @@ -49,7 +49,7 @@ _duckdb_set_default(DISABLE_UNITY OFF) # Extension configuration _duckdb_set_default(DISABLE_BUILTIN_EXTENSIONS OFF) -_duckdb_set_default(ENABLE_EXTENSION_AUTOINSTALL 1) # todo: set to ON https://github.com/duckdb/duckdb/pull/18778/files +_duckdb_set_default(ENABLE_EXTENSION_AUTOINSTALL ON) _duckdb_set_default(ENABLE_EXTENSION_AUTOLOADING ON) # Performance options - enable optimizations by default From 213a6891d1e82350367e00f5cf1a917859a10a09 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Fri, 29 Aug 2025 11:57:14 +0200 Subject: [PATCH 128/472] Forward port of https://github.com/duckdb/duckdb/pull/15462 and https://github.com/duckdb/duckdb/pull/15036 (#19) --- duckdb/experimental/spark/sql/functions.py | 123 ++++++++++++++++++ .../fast/spark/test_spark_functions_array.py | 64 ++++----- tests/fast/spark/test_spark_functions_date.py | 25 +++- tests/fast/spark/test_spark_functions_null.py | 5 + .../spark/test_spark_functions_numeric.py | 82 ++++++------ 5 files changed, 229 insertions(+), 70 deletions(-) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index a6d67aeb..fecada95 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -1851,6 +1851,30 @@ def isnotnull(col: "ColumnOrName") -> Column: return Column(_to_column_expr(col).isnotnull()) +def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: + """ + Returns same result as the EQUAL(=) operator for non-null operands, + but returns true if both are null, false if one of the them is null. + .. versionadded:: 3.5.0 + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + col2 : :class:`~pyspark.sql.Column` or str + Examples + -------- + >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"]) + >>> df.select(equal_null(df.a, df.b).alias('r')).collect() + [Row(r=True), Row(r=False)] + """ + if isinstance(col1, str): + col1 = col(col1) + + if isinstance(col2, str): + col2 = col(col2) + + return nvl((col1 == col2) | ((col1.isNull() & col2.isNull())), lit(False)) + + def flatten(col: "ColumnOrName") -> Column: """ Collection function: creates a single array from an array of arrays. @@ -2157,6 +2181,33 @@ def e() -> Column: return lit(2.718281828459045) +def negative(col: "ColumnOrName") -> Column: + """ + Returns the negative value. + .. versionadded:: 3.5.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column to calculate negative value for. + Returns + ------- + :class:`~pyspark.sql.Column` + negative value. + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(3).select(sf.negative("id")).show() + +------------+ + |negative(id)| + +------------+ + | 0| + | -1| + | -2| + +------------+ + """ + return abs(col) * -1 + + def pi() -> Column: """Returns Pi. @@ -3774,6 +3825,53 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: return date_part(field, source) +def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: + """ + Returns the number of days from `start` to `end`. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + end : :class:`~pyspark.sql.Column` or column name + to date column to work on. + start : :class:`~pyspark.sql.Column` or column name + from date column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + difference in days between two dates. + + See Also + -------- + :meth:`pyspark.sql.functions.dateadd` + :meth:`pyspark.sql.functions.date_add` + :meth:`pyspark.sql.functions.date_sub` + :meth:`pyspark.sql.functions.datediff` + :meth:`pyspark.sql.functions.timestamp_diff` + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) + >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + +----------+----------+-----------------+ + | d1| d2|date_diff(d1, d2)| + +----------+----------+-----------------+ + |2015-04-08|2015-05-10| -32| + +----------+----------+-----------------+ + + >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + +----------+----------+-----------------+ + | d1| d2|date_diff(d2, d1)| + +----------+----------+-----------------+ + |2015-04-08|2015-05-10| 32| + +----------+----------+-----------------+ + """ + return _invoke_function_over_columns("date_diff", lit("day"), end, start) + + def year(col: "ColumnOrName") -> Column: """ Extract the year of a given date/timestamp as integer. @@ -5685,6 +5783,31 @@ def to_timestamp_ntz( return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format) +def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column: + """ + Parses the `col` with the `format` to a timestamp. The function always + returns null on an invalid input with/without ANSI SQL mode enabled. The result data type is + consistent with the value of configuration `spark.sql.timestampType`. + .. versionadded:: 3.5.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column values to convert. + format: str, optional + format to use to convert timestamp values. + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(try_to_timestamp(df.t).alias('dt')).collect() + [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] + >>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect() + [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] + """ + if format is None: + format = lit(['%Y-%m-%d', '%Y-%m-%d %H:%M:%S']) + + return _invoke_function_over_columns("try_strptime", col, format) + def substr( str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None ) -> Column: diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py index 77c4c21a..f83e0ef2 100644 --- a/tests/fast/spark/test_spark_functions_array.py +++ b/tests/fast/spark/test_spark_functions_array.py @@ -2,7 +2,7 @@ import platform _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql import functions as F +from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row from spark_namespace import USE_ACTUAL_SPARK @@ -19,7 +19,7 @@ def test_array_distinct(self, spark): ([2, 4, 5], 3), ] df = spark.createDataFrame(data, ["firstColumn", "secondColumn"]) - df = df.withColumn("distinct_values", F.array_distinct(F.col("firstColumn"))) + df = df.withColumn("distinct_values", sf.array_distinct(sf.col("firstColumn"))) res = df.select("distinct_values").collect() # Output order can vary across platforms which is why we sort it first assert len(res) == 2 @@ -31,7 +31,7 @@ def test_array_intersect(self, spark): (["b", "a", "c"], ["c", "d", "a", "f"]), ] df = spark.createDataFrame(data, ["c1", "c2"]) - df = df.withColumn("intersect_values", F.array_intersect(F.col("c1"), F.col("c2"))) + df = df.withColumn("intersect_values", sf.array_intersect(sf.col("c1"), sf.col("c2"))) res = df.select("intersect_values").collect() # Output order can vary across platforms which is why we sort it first assert len(res) == 1 @@ -42,7 +42,7 @@ def test_array_union(self, spark): (["b", "a", "c"], ["c", "d", "a", "f"]), ] df = spark.createDataFrame(data, ["c1", "c2"]) - df = df.withColumn("union_values", F.array_union(F.col("c1"), F.col("c2"))) + df = df.withColumn("union_values", sf.array_union(sf.col("c1"), sf.col("c2"))) res = df.select("union_values").collect() # Output order can vary across platforms which is why we sort it first assert len(res) == 1 @@ -54,7 +54,7 @@ def test_array_max(self, spark): ([4, 2, 5], 5), ] df = spark.createDataFrame(data, ["firstColumn", "secondColumn"]) - df = df.withColumn("max_value", F.array_max(F.col("firstColumn"))) + df = df.withColumn("max_value", sf.array_max(sf.col("firstColumn"))) res = df.select("max_value").collect() assert res == [ Row(max_value=3), @@ -67,7 +67,7 @@ def test_array_min(self, spark): ([2, 4, 5], 5), ] df = spark.createDataFrame(data, ["firstColumn", "secondColumn"]) - df = df.withColumn("min_value", F.array_min(F.col("firstColumn"))) + df = df.withColumn("min_value", sf.array_min(sf.col("firstColumn"))) res = df.select("min_value").collect() assert res == [ Row(max_value=1), @@ -77,58 +77,58 @@ def test_array_min(self, spark): def test_get(self, spark): df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) - res = df.select(F.get(df.data, 1).alias("r")).collect() + res = df.select(sf.get(df.data, 1).alias("r")).collect() assert res == [Row(r="b")] - res = df.select(F.get(df.data, -1).alias("r")).collect() + res = df.select(sf.get(df.data, -1).alias("r")).collect() assert res == [Row(r=None)] - res = df.select(F.get(df.data, 3).alias("r")).collect() + res = df.select(sf.get(df.data, 3).alias("r")).collect() assert res == [Row(r=None)] - res = df.select(F.get(df.data, "index").alias("r")).collect() + res = df.select(sf.get(df.data, "index").alias("r")).collect() assert res == [Row(r='b')] - res = df.select(F.get(df.data, F.col("index") - 1).alias("r")).collect() + res = df.select(sf.get(df.data, sf.col("index") - 1).alias("r")).collect() assert res == [Row(r='a')] def test_flatten(self, spark): df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) - res = df.select(F.flatten(df.data).alias("r")).collect() + res = df.select(sf.flatten(df.data).alias("r")).collect() assert res == [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] def test_array_compact(self, spark): df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) - res = df.select(F.array_compact(df.data).alias("v")).collect() + res = df.select(sf.array_compact(df.data).alias("v")).collect() assert [Row(v=[1, 2, 3]), Row(v=[4, 5, 4])] def test_array_remove(self, spark): df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) - res = df.select(F.array_remove(df.data, 1).alias("v")).collect() + res = df.select(sf.array_remove(df.data, 1).alias("v")).collect() assert res == [Row(v=[2, 3]), Row(v=[])] def test_array_agg(self, spark): df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"]) - res = df.groupBy("group").agg(F.array_agg("c").alias("r")).collect() + res = df.groupBy("group").agg(sf.array_agg("c").alias("r")).collect() assert res[0] == Row(group="A", r=[1, 1, 2]) def test_collect_list(self, spark): df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"]) - res = df.groupBy("group").agg(F.collect_list("c").alias("r")).collect() + res = df.groupBy("group").agg(sf.collect_list("c").alias("r")).collect() assert res[0] == Row(group="A", r=[1, 1, 2]) def test_array_append(self, spark): df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")], ["c1", "c2"]) - res = df.select(F.array_append(df.c1, df.c2).alias("r")).collect() + res = df.select(sf.array_append(df.c1, df.c2).alias("r")).collect() assert res == [Row(r=['b', 'a', 'c', 'c'])] - res = df.select(F.array_append(df.c1, 'x')).collect() + res = df.select(sf.array_append(df.c1, 'x')).collect() assert res == [Row(r=['b', 'a', 'c', 'x'])] def test_array_insert(self, spark): @@ -137,21 +137,21 @@ def test_array_insert(self, spark): ['data', 'pos', 'val'], ) - res = df.select(F.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() + res = df.select(sf.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() assert res == [ Row(data=['a', 'd', 'b', 'c']), Row(data=['a', 'd', 'b', 'c', 'e']), Row(data=['c', 'b', 'd', 'a']), ] - res = df.select(F.array_insert(df.data, 5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, 5, 'hello').alias('data')).collect() assert res == [ Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['a', 'b', 'c', 'e', 'hello']), Row(data=['c', 'b', 'a', None, 'hello']), ] - res = df.select(F.array_insert(df.data, -5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, -5, 'hello').alias('data')).collect() assert res == [ Row(data=['hello', None, 'a', 'b', 'c']), Row(data=['hello', 'a', 'b', 'c', 'e']), @@ -160,53 +160,53 @@ def test_array_insert(self, spark): def test_slice(self, spark): df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) - res = df.select(F.slice(df.x, 2, 2).alias("sliced")).collect() + res = df.select(sf.slice(df.x, 2, 2).alias("sliced")).collect() assert res == [Row(sliced=[2, 3]), Row(sliced=[5])] def test_sort_array(self, spark): df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) - res = df.select(F.sort_array(df.data).alias('r')).collect() + res = df.select(sf.sort_array(df.data).alias('r')).collect() assert res == [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] - res = df.select(F.sort_array(df.data, asc=False).alias('r')).collect() + res = df.select(sf.sort_array(df.data, asc=False).alias('r')).collect() assert res == [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] @pytest.mark.parametrize(("null_replacement", "expected_joined_2"), [(None, "a"), ("replaced", "a,replaced")]) def test_array_join(self, spark, null_replacement, expected_joined_2): df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) - res = df.select(F.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect() + res = df.select(sf.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect() assert res == [Row(joined='a,b,c'), Row(joined=expected_joined_2)] def test_array_position(self, spark): df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) - res = df.select(F.array_position(df.data, "a").alias("pos")).collect() + res = df.select(sf.array_position(df.data, "a").alias("pos")).collect() assert res == [Row(pos=3), Row(pos=0)] def test_array_preprend(self, spark): df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) - res = df.select(F.array_prepend(df.data, 1).alias("pre")).collect() + res = df.select(sf.array_prepend(df.data, 1).alias("pre")).collect() assert res == [Row(pre=[1, 2, 3, 4]), Row(pre=[1])] def test_array_repeat(self, spark): df = spark.createDataFrame([('ab',)], ['data']) - res = df.select(F.array_repeat(df.data, 3).alias('r')).collect() + res = df.select(sf.array_repeat(df.data, 3).alias('r')).collect() assert res == [Row(r=['ab', 'ab', 'ab'])] def test_array_size(self, spark): df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) - res = df.select(F.array_size(df.data).alias('r')).collect() + res = df.select(sf.array_size(df.data).alias('r')).collect() assert res == [Row(r=3), Row(r=None)] def test_array_sort(self, spark): df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) - res = df.select(F.array_sort(df.data).alias('r')).collect() + res = df.select(sf.array_sort(df.data).alias('r')).collect() assert res == [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] def test_arrays_overlap(self, spark): @@ -214,13 +214,13 @@ def test_arrays_overlap(self, spark): [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ['x', 'y'] ) - res = df.select(F.arrays_overlap(df.x, df.y).alias("overlap")).collect() + res = df.select(sf.arrays_overlap(df.x, df.y).alias("overlap")).collect() assert res == [Row(overlap=True), Row(overlap=False), Row(overlap=None), Row(overlap=None)] def test_arrays_zip(self, spark): df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3']) - res = df.select(F.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect() + res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect() # FIXME: The structure of the results should be the same if USE_ACTUAL_SPARK: assert res == [ diff --git a/tests/fast/spark/test_spark_functions_date.py b/tests/fast/spark/test_spark_functions_date.py index 8a03fd68..2a51d9b8 100644 --- a/tests/fast/spark/test_spark_functions_date.py +++ b/tests/fast/spark/test_spark_functions_date.py @@ -2,7 +2,7 @@ import pytest _ = pytest.importorskip("duckdb.experimental.spark") -from datetime import date, datetime, timezone +from datetime import date, datetime from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as F @@ -217,3 +217,26 @@ def test_add_months(self, spark): assert result[0].with_literal == date(2024, 6, 12) assert result[0].with_str == date(2024, 7, 12) assert result[0].with_col == date(2024, 7, 12) + + def test_date_diff(self, spark): + df = spark.createDataFrame([('2015-04-08', '2015-05-10')], ["d1", "d2"]) + + result_data = df.select(F.date_diff(col("d2").cast('DATE'), col("d1").cast('DATE')).alias("diff")).collect() + assert result_data[0]["diff"] == -32 + + result_data = df.select(F.date_diff(col("d1").cast('DATE'), col("d2").cast('DATE')).alias("diff")).collect() + assert result_data[0]["diff"] == 32 + + def test_try_to_timestamp(self, spark): + df = spark.createDataFrame([("1997-02-28 10:30:00",), ("2024-01-01",), ("invalid",)], ["t"]) + res = df.select(F.try_to_timestamp(df.t).alias("dt")).collect() + assert res[0].dt == datetime(1997, 2, 28, 10, 30) + assert res[1].dt == datetime(2024, 1, 1, 0, 0) + assert res[2].dt is None + + def test_try_to_timestamp_with_format(self, spark): + df = spark.createDataFrame([("1997-02-28 10:30:00",), ("2024-01-01",), ("invalid",)], ["t"]) + res = df.select(F.try_to_timestamp(df.t, format=F.lit("%Y-%m-%d %H:%M:%S")).alias("dt")).collect() + assert res[0].dt == datetime(1997, 2, 28, 10, 30) + assert res[1].dt is None + assert res[2].dt is None \ No newline at end of file diff --git a/tests/fast/spark/test_spark_functions_null.py b/tests/fast/spark/test_spark_functions_null.py index 39ca4ce2..3f5ee31b 100644 --- a/tests/fast/spark/test_spark_functions_null.py +++ b/tests/fast/spark/test_spark_functions_null.py @@ -112,3 +112,8 @@ def test_isnotnull(self, spark): Row(a=1, b=None, r1=True, r2=False), Row(a=None, b=2, r1=False, r2=True), ] + + def test_equal_null(self, spark): + df = spark.createDataFrame([(1, 1), (None, 2), (None, None)], ("a", "b")) + res = df.select(F.equal_null("a", F.col("b")).alias("r")).collect() + assert res == [Row(r=True), Row(r=False), Row(r=True)] \ No newline at end of file diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 3d7b5c3b..9c4bafb9 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -5,7 +5,7 @@ import math import numpy as np from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql import functions as F +from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row @@ -16,7 +16,7 @@ def test_greatest(self, spark): (4, 3), ] df = spark.createDataFrame(data, ["firstColumn", "secondColumn"]) - df = df.withColumn("greatest_value", F.greatest(F.col("firstColumn"), F.col("secondColumn"))) + df = df.withColumn("greatest_value", sf.greatest(sf.col("firstColumn"), sf.col("secondColumn"))) res = df.select("greatest_value").collect() assert res == [ Row(greatest_value=2), @@ -29,7 +29,7 @@ def test_least(self, spark): (4, 3), ] df = spark.createDataFrame(data, ["firstColumn", "secondColumn"]) - df = df.withColumn("least_value", F.least(F.col("firstColumn"), F.col("secondColumn"))) + df = df.withColumn("least_value", sf.least(sf.col("firstColumn"), sf.col("secondColumn"))) res = df.select("least_value").collect() assert res == [ Row(least_value=1), @@ -42,7 +42,7 @@ def test_ceil(self, spark): (2.9,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("ceil_value", F.ceil(F.col("firstColumn"))) + df = df.withColumn("ceil_value", sf.ceil(sf.col("firstColumn"))) res = df.select("ceil_value").collect() assert res == [ Row(ceil_value=2), @@ -55,7 +55,7 @@ def test_floor(self, spark): (2.9,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("floor_value", F.floor(F.col("firstColumn"))) + df = df.withColumn("floor_value", sf.floor(sf.col("firstColumn"))) res = df.select("floor_value").collect() assert res == [ Row(floor_value=1), @@ -68,7 +68,7 @@ def test_abs(self, spark): (-2.9,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("abs_value", F.abs(F.col("firstColumn"))) + df = df.withColumn("abs_value", sf.abs(sf.col("firstColumn"))) res = df.select("abs_value").collect() assert res == [ Row(abs_value=1.1), @@ -81,7 +81,7 @@ def test_sqrt(self, spark): (9,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("sqrt_value", F.sqrt(F.col("firstColumn"))) + df = df.withColumn("sqrt_value", sf.sqrt(sf.col("firstColumn"))) res = df.select("sqrt_value").collect() assert res == [ Row(sqrt_value=2.0), @@ -94,7 +94,7 @@ def test_cbrt(self, spark): (27,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("cbrt_value", F.cbrt(F.col("firstColumn"))) + df = df.withColumn("cbrt_value", sf.cbrt(sf.col("firstColumn"))) res = df.select("cbrt_value").collect() assert pytest.approx(res[0].cbrt_value) == 2.0 assert pytest.approx(res[1].cbrt_value) == 3.0 @@ -105,7 +105,7 @@ def test_cos(self, spark): (3.14159,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("cos_value", F.cos(F.col("firstColumn"))) + df = df.withColumn("cos_value", sf.cos(sf.col("firstColumn"))) res = df.select("cos_value").collect() assert len(res) == 2 assert res[0].cos_value == pytest.approx(1.0) @@ -117,7 +117,7 @@ def test_acos(self, spark): (-1,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("acos_value", F.acos(F.col("firstColumn"))) + df = df.withColumn("acos_value", sf.acos(sf.col("firstColumn"))) res = df.select("acos_value").collect() assert len(res) == 2 assert res[0].acos_value == pytest.approx(0.0) @@ -129,7 +129,7 @@ def test_exp(self, spark): (0.0,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("exp_value", F.exp(F.col("firstColumn"))) + df = df.withColumn("exp_value", sf.exp(sf.col("firstColumn"))) res = df.select("exp_value").collect() round(res[0].exp_value, 2) == 2 res[1].exp_value == 1 @@ -140,7 +140,7 @@ def test_factorial(self, spark): (5,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("factorial_value", F.factorial(F.col("firstColumn"))) + df = df.withColumn("factorial_value", sf.factorial(sf.col("firstColumn"))) res = df.select("factorial_value").collect() assert res == [ Row(factorial_value=24), @@ -153,7 +153,7 @@ def test_log2(self, spark): (8,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("log2_value", F.log2(F.col("firstColumn"))) + df = df.withColumn("log2_value", sf.log2(sf.col("firstColumn"))) res = df.select("log2_value").collect() assert res == [ Row(log2_value=2.0), @@ -166,7 +166,7 @@ def test_ln(self, spark): (1.0,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("ln_value", F.ln(F.col("firstColumn"))) + df = df.withColumn("ln_value", sf.ln(sf.col("firstColumn"))) res = df.select("ln_value").collect() round(res[0].ln_value, 2) == 1 res[1].ln_value == 0 @@ -177,7 +177,7 @@ def test_degrees(self, spark): (0.0,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("degrees_value", F.degrees(F.col("firstColumn"))) + df = df.withColumn("degrees_value", sf.degrees(sf.col("firstColumn"))) res = df.select("degrees_value").collect() round(res[0].degrees_value, 2) == 180 res[1].degrees_value == 0 @@ -188,7 +188,7 @@ def test_radians(self, spark): (0,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("radians_value", F.radians(F.col("firstColumn"))) + df = df.withColumn("radians_value", sf.radians(sf.col("firstColumn"))) res = df.select("radians_value").collect() round(res[0].radians_value, 2) == 3.14 res[1].radians_value == 0 @@ -199,7 +199,7 @@ def test_atan(self, spark): (0,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("atan_value", F.atan(F.col("firstColumn"))) + df = df.withColumn("atan_value", sf.atan(sf.col("firstColumn"))) res = df.select("atan_value").collect() round(res[0].atan_value, 2) == 0.79 res[1].atan_value == 0 @@ -212,19 +212,19 @@ def test_atan2(self, spark): df = spark.createDataFrame(data, ["firstColumn", "secondColumn"]) # Both columns - df2 = df.withColumn("atan2_value", F.atan2(F.col("firstColumn"), "secondColumn")) + df2 = df.withColumn("atan2_value", sf.atan2(sf.col("firstColumn"), "secondColumn")) res = df2.select("atan2_value").collect() round(res[0].atan2_value, 2) == 0.79 res[1].atan2_value == 0 # Both literals - df2 = df.withColumn("atan2_value_lit", F.atan2(1, 1)) + df2 = df.withColumn("atan2_value_lit", sf.atan2(1, 1)) res = df2.select("atan2_value_lit").collect() round(res[0].atan2_value_lit, 2) == 0.79 round(res[1].atan2_value_lit, 2) == 0.79 # One literal, one column - df2 = df.withColumn("atan2_value_lit_col", F.atan2(1.0, F.col("secondColumn"))) + df2 = df.withColumn("atan2_value_lit_col", sf.atan2(1.0, sf.col("secondColumn"))) res = df2.select("atan2_value_lit_col").collect() round(res[0].atan2_value_lit_col, 2) == 0.79 res[1].atan2_value_lit_col == 0 @@ -235,7 +235,7 @@ def test_tan(self, spark): (1,), ] df = spark.createDataFrame(data, ["firstColumn"]) - df = df.withColumn("tan_value", F.tan(F.col("firstColumn"))) + df = df.withColumn("tan_value", sf.tan(sf.col("firstColumn"))) res = df.select("tan_value").collect() res[0].tan_value == 0 round(res[1].tan_value, 2) == 1.56 @@ -249,9 +249,9 @@ def test_round(self, spark): ] df = spark.createDataFrame(data, ["firstColumn"]) df = ( - df.withColumn("round_value", F.round("firstColumn")) - .withColumn("round_value_1", F.round(F.col("firstColumn"), 1)) - .withColumn("round_value_minus_1", F.round("firstColumn", -1)) + df.withColumn("round_value", sf.round("firstColumn")) + .withColumn("round_value_1", sf.round(sf.col("firstColumn"), 1)) + .withColumn("round_value_minus_1", sf.round("firstColumn", -1)) ) res = df.select("round_value", "round_value_1", "round_value_minus_1").collect() assert res == [ @@ -269,9 +269,9 @@ def test_bround(self, spark): ] df = spark.createDataFrame(data, ["firstColumn"]) df = ( - df.withColumn("round_value", F.bround(F.col("firstColumn"))) - .withColumn("round_value_1", F.bround(F.col("firstColumn"), 1)) - .withColumn("round_value_minus_1", F.bround(F.col("firstColumn"), -1)) + df.withColumn("round_value", sf.bround(sf.col("firstColumn"))) + .withColumn("round_value_1", sf.bround(sf.col("firstColumn"), 1)) + .withColumn("round_value_minus_1", sf.bround(sf.col("firstColumn"), -1)) ) res = df.select("round_value", "round_value_1", "round_value_minus_1").collect() assert res == [ @@ -283,7 +283,7 @@ def test_bround(self, spark): def test_asin(self, spark): df = spark.createDataFrame([(0,), (2,)], ["value"]) - df = df.withColumn("asin_value", F.asin("value")) + df = df.withColumn("asin_value", sf.asin("value")) res = df.select("asin_value").collect() assert res[0].asin_value == 0 @@ -301,36 +301,36 @@ def test_corr(self, spark): # Have to use a groupby to test this as agg is not yet implemented without df = spark.createDataFrame(zip(a, b, ["group1"] * N), ["a", "b", "g"]) - res = df.groupBy("g").agg(F.corr("a", "b").alias('c')).collect() + res = df.groupBy("g").agg(sf.corr("a", "b").alias('c')).collect() assert pytest.approx(res[0].c) == 1 def test_cot(self, spark): df = spark.createDataFrame([(math.radians(45),)], ["value"]) - res = df.select(F.cot(df["value"]).alias("cot")).collect() + res = df.select(sf.cot(df["value"]).alias("cot")).collect() assert pytest.approx(res[0].cot) == 1 def test_e(self, spark): df = spark.createDataFrame([("value",)], ["value"]) - res = df.select(F.e().alias("e")).collect() + res = df.select(sf.e().alias("e")).collect() assert pytest.approx(res[0].e) == math.e def test_pi(self, spark): df = spark.createDataFrame([("value",)], ["value"]) - res = df.select(F.pi().alias("pi")).collect() + res = df.select(sf.pi().alias("pi")).collect() assert pytest.approx(res[0].pi) == math.pi def test_pow(self, spark): df = spark.createDataFrame([(2, 3)], ["a", "b"]) - res = df.select(F.pow(df["a"], df["b"]).alias("pow")).collect() + res = df.select(sf.pow(df["a"], df["b"]).alias("pow")).collect() assert res[0].pow == 8 def test_random(self, spark): df = spark.range(0, 2, 1) - res = df.withColumn('rand', F.rand()).collect() + res = df.withColumn('rand', sf.rand()).collect() assert isinstance(res[0].rand, float) assert res[0].rand >= 0 and res[0].rand < 1 @@ -338,13 +338,21 @@ def test_random(self, spark): assert isinstance(res[1].rand, float) assert res[1].rand >= 0 and res[1].rand < 1 - @pytest.mark.parametrize("sign_func", [F.sign, F.signum]) + @pytest.mark.parametrize("sign_func", [sf.sign, sf.signum]) def test_sign(self, spark, sign_func): - df = spark.range(1).select(sign_func(F.lit(-5).alias("v1")), sign_func(F.lit(6).alias("v2"))) + df = spark.range(1).select(sign_func(sf.lit(-5).alias("v1")), sign_func(sf.lit(6).alias("v2"))) res = df.collect() assert res == [Row(v1=-1.0, v2=1.0)] def test_sin(self, spark): df = spark.range(1) - res = df.select(F.sin(F.lit(math.radians(90))).alias("v")).collect() + res = df.select(sf.sin(sf.lit(math.radians(90))).alias("v")).collect() assert res == [Row(v=1.0)] + + def test_negative(self, spark): + df = spark.createDataFrame([(0,), (2,), (-3,)], ["value"]) + df = df.withColumn("value", sf.negative(sf.col("value"))) + res = df.collect() + assert res[0].value == 0 + assert res[1].value == -2 + assert res[2].value == -3 \ No newline at end of file From 3986c06b4bab85fe7eaeef6d88aeec9385b8c8e8 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 1 Sep 2025 12:29:48 +0200 Subject: [PATCH 129/472] Updated submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 4bb3b055..24d2e45b 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 4bb3b05577ce60934fe4d3dd41fbfbd2b7bcdf11 +Subproject commit 24d2e45b14126a7083a0c01c6a45b75390a46922 From aaf091437f66d2c69715efd9b75110a97b1a33a6 Mon Sep 17 00:00:00 2001 From: Emil Sadek Date: Mon, 1 Sep 2025 09:19:12 -0700 Subject: [PATCH 130/472] Refactor EditorConfig file (#25) * Remove unused sections * Add base properties * Remove redundant properties * Add Python section --------- Co-authored-by: Emil Sadek --- .editorconfig | 43 +++++++++---------------------------------- 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/.editorconfig b/.editorconfig index aee02bd4..052e478e 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,27 +1,19 @@ -# Unix-style newlines with a newline ending every file -[*.{c,cpp,h,hpp}] +root = true + +[*] +charset = utf-8 end_of_line = lf insert_final_newline = true -indent_style = tab -tab_width = 4 -indent_size = tab trim_trailing_whitespace = true -charset = utf-8 -max_line_length = 120 -x-soft-wrap-text = true -x-soft-wrap-mode = CharacterWidth -x-soft-wrap-limit = 120 -x-show-invisibles = false -x-show-spaces = false -[*.{java}] -end_of_line = lf -insert_final_newline = true +[*.{py,pyi}] +indent_style = space +indent_size = 4 + +[*.{c,cpp,h,hpp}] indent_style = tab tab_width = 4 indent_size = tab -trim_trailing_whitespace = false -charset = utf-8 max_line_length = 120 x-soft-wrap-text = true x-soft-wrap-mode = CharacterWidth @@ -29,25 +21,8 @@ x-soft-wrap-limit = 120 x-show-invisibles = false x-show-spaces = false -[*.{test,test_slow,test_coverage,benchmark}] -end_of_line = lf -insert_final_newline = true -indent_style = tab -tab_width = 4 -indent_size = tab -trim_trailing_whitespace = false -charset = utf-8 -x-soft-wrap-text = false - [Makefile] -end_of_line = lf -insert_final_newline = true indent_style = tab tab_width = 4 indent_size = tab -trim_trailing_whitespace = true -charset = utf-8 x-soft-wrap-text = false - -[*keywords.list] -insert_final_newline = false From 716b89501f694d8052fc3b39a382743542888109 Mon Sep 17 00:00:00 2001 From: Emil Sadek Date: Mon, 1 Sep 2025 09:19:59 -0700 Subject: [PATCH 131/472] Update project URLs (#24) * Update issues URL * Update source URL --------- Authored-by: Emil Sadek --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 07eea22e..6291b811 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,8 +37,8 @@ maintainers = [{name = "DuckDB Foundation"}] [project.urls] Documentation = "https://duckdb.org/docs/stable/clients/python/overview" -Source = "https://github.com/duckdb/duckdb/blob/main/tools/pythonpkg" -Issues = "https://github.com/duckdb/duckdb/issues" +Source = "https://github.com/duckdb/duckdb-python" +Issues = "https://github.com/duckdb/duckdb-python/issues" Changelog = "https://github.com/duckdb/duckdb/releases" [project.optional-dependencies] From a5a68c8596f4d827da8fe00685a27f9226b184a4 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 2 Sep 2025 16:49:04 +0200 Subject: [PATCH 132/472] Better streamlined release pipeline --- .github/workflows/cleanup_pypi.yml | 1 + .github/workflows/packaging.yml | 137 ++----------------- .github/workflows/packaging_sdist.yml | 94 +++++++++++++ .github/workflows/packaging_wheels.yml | 96 +++++++++++++ .github/workflows/release.yml | 181 ++++++++++++++++--------- 5 files changed, 320 insertions(+), 189 deletions(-) create mode 100644 .github/workflows/packaging_sdist.yml create mode 100644 .github/workflows/packaging_wheels.yml diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index 86e96c49..c4300be3 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -62,6 +62,7 @@ jobs: uv sync --only-group pypi --no-install-project uv run --no-sync python -u -m duckdb_packaging.pypi_cleanup ${{ inputs.dry-run && '--dry' || '' }} \ ${{ inputs.environment == 'pypi-prod-nightly' && '--prod' || '--test' }} \ + --verbose \ --username "${{ vars.PYPI_CLEANUP_USERNAME }}" \ --max-nightlies ${{ vars.PYPI_MAX_NIGHTLIES }} 2>&1 | tee cleanup_output diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 573919c8..0dafaf75 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -63,127 +63,20 @@ defaults: jobs: build_sdist: - name: Build sdist - runs-on: ubuntu-24.04 - steps: - - - name: Checkout DuckDB Python - uses: actions/checkout@v4 - with: - ref: ${{ inputs.git-ref }} - fetch-depth: 0 - submodules: true - - - name: Checkout DuckDB - shell: bash - run: | - cd external/duckdb - git fetch origin - git checkout ${{ inputs.duckdb-git-ref }} - - - name: Set OVERRIDE_GIT_DESCRIBE - if: ${{ inputs.set-version != '' }} - run: echo "OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV - - - name: Install Astral UV - uses: astral-sh/setup-uv@v6 - with: - version: "0.7.14" - python-version: 3.11 - - - name: Build sdist - run: uv build --sdist - - - name: Install sdist - run: | - cd ${{ runner.temp }} - uv venv - uv pip install ${{ github.workspace }}/dist/duckdb-*.tar.gz - - - name: Test sdist - if: ${{ inputs.testsuite != 'none' }} - run: | - # install the test requirements - uv export --only-group test --no-emit-project --output-file ${{ runner.temp }}/pylock.toml --quiet - cd ${{ runner.temp }} - uv pip install -r pylock.toml - # run tests - tests_root="${{ github.workspace }}/tests" - tests_dir="${tests_root}${{ inputs.testsuite == 'fast' && '/fast' || '/' }}" - uv run --verbose pytest $tests_dir --verbose --ignore=${tests_root}/stubs - - - uses: actions/upload-artifact@v4 - with: - name: sdist - path: dist/*.tar.gz - compression-level: 0 + name: Build an sdist and determine versions + uses: ./.github/workflows/packaging_sdist.yml + with: + testsuite: all + git-ref: ${{ github.ref }} + duckdb-git-ref: ${{ inputs.duckdb-sha }} + set-version: ${{ inputs.stable-version }} build_wheels: - name: 'Wheel: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}' - strategy: - fail-fast: false - matrix: - python: [ cp39, cp310, cp311, cp312, cp313 ] - platform: - - { os: windows-2025, arch: amd64, cibw_system: win } - - { os: ubuntu-24.04, arch: x86_64, cibw_system: manylinux } - - { os: ubuntu-24.04-arm, arch: aarch64, cibw_system: manylinux } - - { os: macos-15, arch: arm64, cibw_system: macosx } - - { os: macos-15, arch: universal2, cibw_system: macosx } - - { os: macos-13, arch: x86_64, cibw_system: macosx } - minimal: - - ${{ inputs.minimal }} - exclude: - - { minimal: true, python: cp310 } - - { minimal: true, python: cp311 } - - { minimal: true, python: cp312 } - - { minimal: true, platform: { arch: universal2 } } - runs-on: ${{ matrix.platform.os }} - env: - CIBW_TEST_SKIP: ${{ inputs.testsuite == 'none' && '*' || '*-macosx_universal2' }} - CIBW_TEST_SOURCES: tests - CIBW_BEFORE_TEST: > - uv export --only-group test --no-emit-project --output-file pylock.toml --directory {project} && - uv pip install -r pylock.toml - CIBW_TEST_COMMAND: > - uv run -v pytest ${{ inputs.testsuite == 'fast' && './tests/fast' || './tests' }} --verbose --ignore=./tests/stubs - - steps: - - name: Checkout DuckDB Python - uses: actions/checkout@v4 - with: - ref: ${{ inputs.git-ref }} - fetch-depth: 0 - submodules: true - - - name: Checkout DuckDB - shell: bash - run: | - cd external/duckdb - git fetch origin - git checkout ${{ inputs.duckdb-git-ref }} - - # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds - - name: Set OVERRIDE_GIT_DESCRIBE - if: ${{ inputs.set-version != '' }} - run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV - - # Install Astral UV, which will be used as build-frontend for cibuildwheel - - uses: astral-sh/setup-uv@v6 - with: - version: "0.7.14" - enable-cache: false - cache-suffix: -${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - - - name: Build${{ inputs.testsuite != 'none' && ' and test ' || ' ' }}wheels - uses: pypa/cibuildwheel@v3.0 - env: - CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }} - CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - - - name: Upload wheel - uses: actions/upload-artifact@v4 - with: - name: wheel-${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - path: wheelhouse/*.whl - compression-level: 0 + name: Build and test releases + uses: ./.github/workflows/packaging_wheels.yml + with: + minimal: false + testsuite: all + git-ref: ${{ github.ref }} + duckdb-git-ref: ${{ inputs.duckdb-sha }} + set-version: ${{ inputs.stable-version }} diff --git a/.github/workflows/packaging_sdist.yml b/.github/workflows/packaging_sdist.yml new file mode 100644 index 00000000..e8ee97f2 --- /dev/null +++ b/.github/workflows/packaging_sdist.yml @@ -0,0 +1,94 @@ +name: Sdist packaging +on: + workflow_call: + inputs: + testsuite: + type: string + description: Testsuite to run (none, fast, all) + required: true + default: all + git-ref: + type: string + description: Git ref of the DuckDB python package + required: false + duckdb-git-ref: + type: string + description: Git ref of DuckDB + required: false + set-version: + description: Force version (vX.Y.Z-((rc|post)N)) + required: false + type: string + outputs: + package-version: + description: The version of the DuckDB Python package + value: ${{ jobs.build_sdist.outputs.pkg_version }} + duckdb-version: + description: The version of DuckDB that was packaged + value: ${{ jobs.build_sdist.outputs.duckdb_version }} + +jobs: + build_sdist: + name: Build sdist + runs-on: ubuntu-24.04 + outputs: + pkg_version: ${{ steps.versioning.outputs.pkg_version }} + duckdb_version: ${{ steps.versioning.outputs.duckdb_version }} + steps: + + - name: Checkout DuckDB Python + uses: actions/checkout@v4 + with: + ref: ${{ inputs.git-ref }} + fetch-depth: 0 + submodules: true + + - name: Checkout DuckDB + shell: bash + run: | + cd external/duckdb + git fetch origin + git checkout ${{ inputs.duckdb-git-ref }} + + - name: Set OVERRIDE_GIT_DESCRIBE + if: ${{ inputs.set-version != '' }} + run: echo "OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV + + - name: Install Astral UV + uses: astral-sh/setup-uv@v6 + with: + version: "0.7.14" + python-version: 3.11 + + - name: Build sdist + run: uv build --sdist + + - name: Install sdist + run: | + cd ${{ runner.temp }} + uv venv + uv pip install ${{ github.workspace }}/dist/duckdb-*.tar.gz + + - name: Test sdist + if: ${{ inputs.testsuite != 'none' }} + run: | + # install the test requirements + uv export --only-group test --no-emit-project --output-file ${{ runner.temp }}/pylock.toml --quiet + cd ${{ runner.temp }} + uv pip install -r pylock.toml + # run tests + tests_root="${{ github.workspace }}/tests" + tests_dir="${tests_root}${{ inputs.testsuite == 'fast' && '/fast' || '/' }}" + uv run --verbose pytest $tests_dir --verbose --ignore=${tests_root}/stubs + + - id: versioning + run: | + cd ${{ runner.temp }} + echo "pkg_version=$( .venv/bin/python -c 'import duckdb; print(duckdb.__version__)' )" >> $GITHUB_OUTPUT + echo "duckdb_version=$( .venv/bin/python -c 'import duckdb; print(duckdb.duckdb_version)' )" >> $GITHUB_OUTPUT + + - uses: actions/upload-artifact@v4 + with: + name: sdist + path: dist/*.tar.gz + compression-level: 0 diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml new file mode 100644 index 00000000..f74b5f09 --- /dev/null +++ b/.github/workflows/packaging_wheels.yml @@ -0,0 +1,96 @@ +name: Wheels packaging +on: + workflow_call: + inputs: + minimal: + type: boolean + description: Build a minimal set of wheels to do a sanity check + default: false + testsuite: + type: string + description: Testsuite to run (none, fast, all) + required: true + default: all + git-ref: + type: string + description: Git ref of the DuckDB python package + required: false + duckdb-git-ref: + type: string + description: Git ref of DuckDB + required: false + set-version: + description: Force version (vX.Y.Z-((rc|post)N)) + required: false + type: string + +jobs: + build_wheels: + name: 'Wheel: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}' + strategy: + fail-fast: false + matrix: + python: [ cp39, cp310, cp311, cp312, cp313 ] + platform: + - { os: windows-2025, arch: amd64, cibw_system: win } + - { os: ubuntu-24.04, arch: x86_64, cibw_system: manylinux } + - { os: ubuntu-24.04-arm, arch: aarch64, cibw_system: manylinux } + - { os: macos-15, arch: arm64, cibw_system: macosx } + - { os: macos-15, arch: universal2, cibw_system: macosx } + - { os: macos-13, arch: x86_64, cibw_system: macosx } + minimal: + - ${{ inputs.minimal }} + exclude: + - { minimal: true, python: cp310 } + - { minimal: true, python: cp311 } + - { minimal: true, python: cp312 } + - { minimal: true, platform: { arch: universal2 } } + runs-on: ${{ matrix.platform.os }} + env: + CIBW_TEST_SKIP: ${{ inputs.testsuite == 'none' && '*' || '*-macosx_universal2' }} + CIBW_TEST_SOURCES: tests + CIBW_BEFORE_TEST: > + uv export --only-group test --no-emit-project --output-file pylock.toml --directory {project} && + uv pip install -r pylock.toml + CIBW_TEST_COMMAND: > + uv run -v pytest ${{ inputs.testsuite == 'fast' && './tests/fast' || './tests' }} --verbose --ignore=./tests/stubs + + steps: + - name: Checkout DuckDB Python + uses: actions/checkout@v4 + with: + ref: ${{ inputs.git-ref }} + fetch-depth: 0 + submodules: true + + - name: Checkout DuckDB + shell: bash + run: | + cd external/duckdb + git fetch origin + git checkout ${{ inputs.duckdb-git-ref }} + + # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds + - name: Set OVERRIDE_GIT_DESCRIBE + if: ${{ inputs.set-version != '' }} + run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV + + # Install Astral UV, which will be used as build-frontend for cibuildwheel + - uses: astral-sh/setup-uv@v6 + with: + version: "0.7.14" + enable-cache: false + cache-suffix: -${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + + - name: Build${{ inputs.testsuite != 'none' && ' and test ' || ' ' }}wheels + uses: pypa/cibuildwheel@v3.0 + env: + CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }} + CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + + - name: Upload wheel + uses: actions/upload-artifact@v4 + with: + name: wheel-${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + path: wheelhouse/*.whl + compression-level: 0 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9a6dbf0b..721c92fa 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ on: - prod store-s3: type: boolean - description: Also store test packages in S3 + description: Also store test packages in S3 (always true for prod) default: false defaults: @@ -28,9 +28,86 @@ defaults: shell: bash jobs: - build_and_test: + build_sdist: + name: Build an sdist and determine versions + uses: ./.github/workflows/packaging_sdist.yml + with: + testsuite: all + git-ref: ${{ github.ref }} + duckdb-git-ref: ${{ inputs.duckdb-sha }} + set-version: ${{ inputs.stable-version }} + + workflow_state: + name: Set state for the release workflow + needs: build_sdist + outputs: + pypi_state: ${{ steps.index_check.outputs.pypi_state }} + ci_env: ${{ steps.ci_env_check.outputs.ci_env }} + s3_url: ${{ steps.s3_check.outputs.s3_url }} + runs-on: ubuntu-latest + steps: + - id: index_check + name: Check ${{ needs.build_sdist.outputs.package-version }} on PyPI + run: | + set -eu + # Check PyPI whether the release we're building is already present + pypi_hostname=${{ inputs.pypi-index == 'test' && 'test.' || '' }}pypi.org + pkg_version=${{ needs.build_sdist.outputs.package-version }} + url=https://${pypi_hostname}/pypi/duckdb/${pkg_version}/json + http_status=$( curl -s -o /dev/null -w "%{http_code}" $url || echo $? ) + if [[ $http_status == "200" ]]; then + echo "::warning::Package version ${pkg_version} is already present on ${pypi_hostname}" + pypi_state=VERSION_FOUND + elif [[ $http_status == 000* ]]; then + echo "::error::Error checking PyPI at ${url}: curl exit code ${http_status#'000'}" + pypi_state=UNKNOWN + else + echo "::notice::Package version ${pkg_version} not found on ${pypi_hostname} (http status: ${http_status})" + pypi_state=VERSION_NOT_FOUND + fi + echo "pypi_state=${pypi_state}" >> $GITHUB_OUTPUT + + - id: ci_env_check + name: Determine CI environment + run: | + set -eu + if [[ test == "${{ inputs.pypi-index }}" ]]; then + ci_env=pypi-test + elif [[ prod == "${{ inputs.pypi-index }}" ]]; then + ci_env=pypi-prod${{ inputs.stable-version && '' || '-nightly' }} + else + echo "::error::Invalid value for inputs.pypi-index: ${{ inputs.pypi-index }}" + exit 1 + fi + echo "ci_env=${ci_env}" >> "$GITHUB_OUTPUT" + echo "::notice::Using CI environment ${ci_env}" + + - id: s3_check + name: Generate S3 upload URL + if: github.repository_owner == 'duckdb' + run: | + set -eu + should_store=${{ (inputs.pypi-index == 'prod' || inputs.store-s3) && '1' || '0' }} + if [[ $should_store == 0 ]]; then + echo "::notice::S3 upload disabled in inputs, not generating S3 URL" + exit 0 + fi + if [[ VERSION_FOUND == "${{ steps.index_check.outputs.pypi_state }}" ]]; then + echo "::warning::S3 upload disabled because package version already uploaded to PyPI" + exit 0 + fi + sha=${{ github.sha }} + dsha=${{ inputs.duckdb-sha }} + version=${{ needs.build_sdist.outputs.package-version }} + s3_url="s3://duckdb-staging/python/${version}/${sha:0:10}-duckdb-${dsha:0:10}/" + echo "::notice::Generated S3 URL: ${s3_url}" + echo "s3_url=${s3_url}" >> $GITHUB_OUTPUT + + build_wheels: name: Build and test releases - uses: ./.github/workflows/packaging.yml + needs: workflow_state + if: ${{ needs.workflow_state.outputs.pypi_state != 'VERSION_FOUND' }} + uses: ./.github/workflows/packaging_wheels.yml with: minimal: false testsuite: all @@ -41,8 +118,8 @@ jobs: upload_s3: name: Upload Artifacts to S3 runs-on: ubuntu-latest - needs: [build_and_test] - if: ${{ github.repository_owner == 'duckdb' && ( inputs.pypi-index == 'prod' || inputs.store-s3 ) }} + needs: [build_sdist, build_wheels, workflow_state] + if: ${{ needs.workflow_state.outputs.s3_url }} steps: - name: Fetch artifacts uses: actions/download-artifact@v4 @@ -51,18 +128,6 @@ jobs: path: artifacts/ merge-multiple: true - - name: Compute upload input - id: input - run: | - sha=${{ github.sha }} - dsha=${{ inputs.duckdb-sha }} - version=$(basename artifacts/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') - url="s3://duckdb-staging/python/${version}/${sha:0:10}-duckdb-${dsha:0:10}/" - echo "short_sha=${sha:0:10}" >> $GITHUB_OUTPUT - echo "short_dsha=${dsha:0:10}" >> $GITHUB_OUTPUT - echo "version=${version}" >> $GITHUB_OUTPUT - echo "s3_upload_url=${url}" >> $GITHUB_OUTPUT - - name: Authenticate with AWS uses: aws-actions/configure-aws-credentials@v4 with: @@ -72,57 +137,21 @@ jobs: - name: Upload Artifacts run: | - aws s3 cp artifacts ${{ steps.input.outputs.s3_upload_url }} --recursive - - - name: S3 Upload Summary - run : | - echo "## S3 Upload Summary" >> $GITHUB_STEP_SUMMARY - echo "* Version: ${{ steps.input.outputs.version }}" >> $GITHUB_STEP_SUMMARY - echo "* SHA: ${{ steps.input.outputs.short_sha }}" >> $GITHUB_STEP_SUMMARY - echo "* DuckDB SHA: ${{ steps.input.outputs.short_dsha }}" >> $GITHUB_STEP_SUMMARY - echo "* S3 URL: ${{ steps.input.outputs.s3_upload_url }}" >> $GITHUB_STEP_SUMMARY - - determine_environment: - name: Determine the Github Actions environment to use - runs-on: ubuntu-latest - needs: build_and_test - outputs: - env_name: ${{ steps.set-env.outputs.env_name }} - steps: - - name: Set environment name - id: set-env - run: | - set -euo pipefail - case "${{ inputs.pypi-index }}" in - test) - echo "env_name=pypi-test" >> "$GITHUB_OUTPUT" - ;; - prod) - if [[ -n "${{ inputs.stable-version }}" ]]; then - echo "env_name=pypi-prod" >> "$GITHUB_OUTPUT" - else - echo "env_name=pypi-prod-nightly" >> "$GITHUB_OUTPUT" - fi - ;; - *) - echo "Error: invalid combination of inputs.pypi-index='${{ inputs.pypi-index }}' and inputs.stable-version='${{ inputs.stable-version }}'" >&2 - exit 1 - ;; - esac + aws s3 cp artifacts ${{ needs.workflow_state.outputs.s3_url }} --recursive publish_pypi: name: Publish Artifacts to PyPI runs-on: ubuntu-latest - needs: determine_environment + needs: [workflow_state, build_sdist, build_wheels] environment: - name: ${{ needs.determine_environment.outputs.env_name }} + name: ${{ needs.workflow_state.outputs.ci_env }} permissions: # this is needed for the OIDC flow that is used with trusted publishing on PyPI id-token: write steps: - if: ${{ vars.PYPI_HOST == '' }} run: | - echo "Error: PYPI_HOST is not set in CI environment '${{ needs.determine_environment.outputs.env_name }}'" + echo "Error: PYPI_HOST is not set in CI environment '${{ needs.workflow_state.outputs.ci_env }}'" exit 1 - name: Fetch artifacts @@ -139,22 +168,40 @@ jobs: packages-dir: packages verbose: 'true' - - name: PyPI Upload Summary - run : | - version=$(basename packages/*.tar.gz | sed 's/duckdb-\(.*\).tar.gz/\1/g') - echo "## PyPI Upload Summary" >> $GITHUB_STEP_SUMMARY - echo "* Version: ${version}" >> $GITHUB_STEP_SUMMARY - echo "* PyPI Host: ${{ vars.PYPI_HOST }}" >> $GITHUB_STEP_SUMMARY - echo "* CI Environment: ${{ needs.determine_environment.outputs.env_name }}" >> $GITHUB_STEP_SUMMARY - cleanup_nightlies: name: Remove Nightlies from PyPI - needs: [determine_environment, publish_pypi] + needs: [workflow_state, publish_pypi] if: ${{ inputs.stable-version == '' }} uses: ./.github/workflows/cleanup_pypi.yml with: - environment: ${{ needs.determine_environment.outputs.env_name }} + environment: ${{ needs.workflow_state.outputs.ci_env }} secrets: # reusable workflows and secrets are not great: https://github.com/actions/runner/issues/3206 PYPI_CLEANUP_OTP: ${{secrets.PYPI_CLEANUP_OTP}} PYPI_CLEANUP_PASSWORD: ${{secrets.PYPI_CLEANUP_PASSWORD}} + + summary: + name: Release summary + runs-on: ubuntu-latest + needs: [build_sdist, workflow_state, build_wheels, upload_s3, publish_pypi, cleanup_nightlies] + if: true + steps: + - run: | + sha=${{ github.sha }} + dsha=${{ inputs.duckdb-sha }} + pversion=${{ needs.build_sdist.outputs.package-version }} + long_pversion="${pversion} (${sha:0:10})" + pypi_host=${{ inputs.pypi-index == 'test' && 'test.' || '' }}pypi.org + pypi_duckdb_url=https://${pypi_host}/project/duckdb/${pversion}/ + was_released=${{ needs.publish_pypi.result == 'success' && '1' || '0' }} + if [[ $was_released == 1 ]]; then + echo "## Version ${long_pversion} successfully released" >> $GITHUB_STEP_SUMMARY + echo "* Package URL: [${pypi_duckdb_url}](${pypi_duckdb_url})" >> $GITHUB_STEP_SUMMARY + else + echo "## Version ${long_pversion} was not released" >> $GITHUB_STEP_SUMMARY + echo "* Package index state before release: ${{ needs.workflow_state.outputs.pypi_state }}" >> $GITHUB_STEP_SUMMARY + fi + echo "* Package index: ${pypi_host}" >> $GITHUB_STEP_SUMMARY + echo "* Vendored DuckDB Version: ${{ needs.build_sdist.outputs.duckdb-version }} (${dsha:0:10})" >> $GITHUB_STEP_SUMMARY + echo "* S3 upload status: ${{ needs.upload_s3.result == 'success' && needs.workflow_state.outputs.s3_url || needs.upload_s3.result }}" >> $GITHUB_STEP_SUMMARY + echo "* CI Environment: ${{ needs.workflow_state.outputs.ci_env }}" >> $GITHUB_STEP_SUMMARY From 814c1b7cb08db879de73d2fd646be5f6396701fd Mon Sep 17 00:00:00 2001 From: Diego Sevilla Ruiz Date: Wed, 3 Sep 2025 14:21:29 +0200 Subject: [PATCH 133/472] Update from_parquet and read_parquet method signatures. from_parquet and read_parquet are incorrectly described as receiving just a str parameter where they allow to receive also a list of files/globs as a list of str. This fixes #26, although more work is needed because this file is auto-generated. --- duckdb/__init__.pyi | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index adf142dd..91945dfd 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -345,8 +345,10 @@ class DuckDBPyConnection: def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... - def from_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... - def read_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... + # stubgen override + def from_parquet(self, file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... + def read_parquet(self, file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... + # end stubgen override def get_table_names(self, query: str, *, qualified: bool = False) -> Set[str]: ... def install_extension(self, extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None) -> None: ... def load_extension(self, extension: str) -> None: ... @@ -693,8 +695,10 @@ def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Option def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def from_arrow(arrow_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +# stubgen override +def from_parquet(file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def read_parquet(file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +# end stubgen override def get_table_names(query: str, *, qualified: bool = False, connection: DuckDBPyConnection = ...) -> Set[str]: ... def install_extension(extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None, connection: DuckDBPyConnection = ...) -> None: ... def load_extension(extension: str, *, connection: DuckDBPyConnection = ...) -> None: ... From b928d17abcf738b1e0867ebf998e8aa783c9dc86 Mon Sep 17 00:00:00 2001 From: Emil Sadek Date: Wed, 3 Sep 2025 18:36:51 -0700 Subject: [PATCH 134/472] Fix capitalization --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index c70d6e2b..5f81ff5e 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@

- discord - PyPi Latest Release + Discord + PyPI Latest Release


@@ -30,7 +30,7 @@ ## Installation -Install the latest release of DuckDB directly from [PyPi](https://pypi.org/project/duckdb/): +Install the latest release of DuckDB directly from [PyPI](https://pypi.org/project/duckdb/): ```bash pip install duckdb @@ -169,11 +169,11 @@ uvx gcovr \ ### Typechecking and linting - We're not running any mypy typechecking tests at the moment -- We're not running any ruff / linting / formatting at the moment +- We're not running any Ruff / linting / formatting at the moment ### Cibuildwheel -You can run cibuildwheel locally for linux. E.g. limited to Python 3.9: +You can run cibuildwheel locally for Linux. E.g. limited to Python 3.9: ```bash CIBW_BUILD='cp39-*' uvx cibuildwheel --platform linux . ``` @@ -186,7 +186,7 @@ CIBW_BUILD='cp39-*' uvx cibuildwheel --platform linux . ### Tooling This codebase is developed with the following tools: -- [Astral UV](https://docs.astral.sh/uv/) - for dependency management across all platforms we provide wheels for, +- [Astral uv](https://docs.astral.sh/uv/) - for dependency management across all platforms we provide wheels for, and for Python environment management. It will be hard to work on this codebase without having UV installed. - [Scikit-build-core](https://scikit-build-core.readthedocs.io/en/latest/index.html) - the build backend for building the extension. On the background, scikit-build-core uses cmake and ninja for compilation. From e7ca70e1943f62368c3e1fbf437936bfa94e49b7 Mon Sep 17 00:00:00 2001 From: Diego Sevilla Ruiz Date: Thu, 4 Sep 2025 11:25:13 +0200 Subject: [PATCH 135/472] Revert "Update from_parquet and read_parquet method signatures." This reverts commit 814c1b7cb08db879de73d2fd646be5f6396701fd. --- duckdb/__init__.pyi | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index 91945dfd..adf142dd 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -345,10 +345,8 @@ class DuckDBPyConnection: def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... - # stubgen override - def from_parquet(self, file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... - def read_parquet(self, file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... - # end stubgen override + def from_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... + def read_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... def get_table_names(self, query: str, *, qualified: bool = False) -> Set[str]: ... def install_extension(self, extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None) -> None: ... def load_extension(self, extension: str) -> None: ... @@ -695,10 +693,8 @@ def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Option def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def from_arrow(arrow_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -# stubgen override -def from_parquet(file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_parquet(file_or_files_glob: Union[str, List[str]], binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -# end stubgen override +def from_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def read_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def get_table_names(query: str, *, qualified: bool = False, connection: DuckDBPyConnection = ...) -> Set[str]: ... def install_extension(extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None, connection: DuckDBPyConnection = ...) -> None: ... def load_extension(extension: str, *, connection: DuckDBPyConnection = ...) -> None: ... From 8ff11afd186588615af4db3ca14ba3edf7cfcb0a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 11:36:39 +0200 Subject: [PATCH 136/472] Protect submodule in workflow --- .github/workflows/on_pr.yml | 24 +++----------- .github/workflows/on_push_protected.yml | 10 ++++++ .github/workflows/packaging.yml | 27 ++++++++-------- .github/workflows/packaging_sdist.yml | 13 ++++---- .github/workflows/packaging_wheels.yml | 13 ++++---- .github/workflows/release.yml | 14 ++++++--- .github/workflows/submodule.yml | 42 +++++++++++++++++++++++++ 7 files changed, 92 insertions(+), 51 deletions(-) create mode 100644 .github/workflows/on_push_protected.yml create mode 100644 .github/workflows/submodule.yml diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 5d7328fb..24640124 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -19,25 +19,9 @@ concurrency: cancel-in-progress: true jobs: - ensure_submodule_sanity: - name: Make sure we're not building with a fork - runs-on: ubuntu-latest - steps: - - name: Checkout DuckDB Python - uses: actions/checkout@v4 - - - shell: bash - run: | - submodule_url=$(git config --file .gitmodules --get submodule.external/duckdb.url || true) - expected="github.com/duckdb/duckdb" - if [[ -z "$submodule_url" ]]; then - echo "::error::DuckDB submodule not found in .gitmodules" - exit 1 - fi - if [[ "$submodule_url" != *"$expected"* ]]; then - echo "::error::DuckDB submodule must point to $expected, found: $submodule_url" - exit 1 - fi + submodule_sanity_guard: + name: Make sure submodule is in a sane state + uses: .github/workflows/submodule_sanity.yml packaging_test: name: Build a minimal set of packages and run all tests on them @@ -48,7 +32,7 @@ jobs: with: minimal: true testsuite: all - duckdb-git-ref: ${{ github.base_ref }} + duckdb-sha: ${{ github.base_ref }} coverage_test: name: Run coverage tests diff --git a/.github/workflows/on_push_protected.yml b/.github/workflows/on_push_protected.yml new file mode 100644 index 00000000..af64f88e --- /dev/null +++ b/.github/workflows/on_push_protected.yml @@ -0,0 +1,10 @@ +name: Guard pushes to protected branches +on: + push: + branches: + - main + - v*.*-* +jobs: + submodule_sanity_guard: + name: Make sure submodule is in a sane state + uses: .github/workflows/submodule_sanity.yml diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 0dafaf75..25cf3bdd 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -16,15 +16,14 @@ on: - none - fast - all - git-ref: + duckdb-python-sha: type: string - description: Git ref of the DuckDB python package + description: The commit to build against (defaults to latest commit of current ref) required: false - duckdb-git-ref: + duckdb-sha: type: string - description: Git ref of DuckDB - required: true - default: refs/heads/main + description: Override the DuckDB submodule commit or ref to build against + required: false set-version: type: string description: Force version (vX.Y.Z-((rc|post)N)) @@ -40,13 +39,13 @@ on: description: Testsuite to run (none, fast, all) required: true default: all - git-ref: + duckdb-python-sha: type: string - description: Git ref of the DuckDB python package + description: The commit or ref to build against (defaults to latest commit of current ref) required: false - duckdb-git-ref: + duckdb-sha: type: string - description: Git ref of DuckDB + description: Override the DuckDB submodule commit or ref to build against required: false set-version: description: Force version (vX.Y.Z-((rc|post)N)) @@ -67,8 +66,8 @@ jobs: uses: ./.github/workflows/packaging_sdist.yml with: testsuite: all - git-ref: ${{ github.ref }} - duckdb-git-ref: ${{ inputs.duckdb-sha }} + duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} + duckdb-sha: ${{ inputs.duckdb-sha }} set-version: ${{ inputs.stable-version }} build_wheels: @@ -77,6 +76,6 @@ jobs: with: minimal: false testsuite: all - git-ref: ${{ github.ref }} - duckdb-git-ref: ${{ inputs.duckdb-sha }} + duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} + duckdb-sha: ${{ inputs.duckdb-sha }} set-version: ${{ inputs.stable-version }} diff --git a/.github/workflows/packaging_sdist.yml b/.github/workflows/packaging_sdist.yml index e8ee97f2..2723b437 100644 --- a/.github/workflows/packaging_sdist.yml +++ b/.github/workflows/packaging_sdist.yml @@ -7,13 +7,13 @@ on: description: Testsuite to run (none, fast, all) required: true default: all - git-ref: + duckdb-python-sha: type: string - description: Git ref of the DuckDB python package + description: The commit or ref to build against (defaults to current ref) required: false - duckdb-git-ref: + duckdb-sha: type: string - description: Git ref of DuckDB + description: Override the DuckDB submodule commit or ref to build against required: false set-version: description: Force version (vX.Y.Z-((rc|post)N)) @@ -39,16 +39,17 @@ jobs: - name: Checkout DuckDB Python uses: actions/checkout@v4 with: - ref: ${{ inputs.git-ref }} + ref: ${{ inputs.duckdb-python-sha }} fetch-depth: 0 submodules: true - name: Checkout DuckDB shell: bash + if: ${{ inputs.duckdb-sha }} run: | cd external/duckdb git fetch origin - git checkout ${{ inputs.duckdb-git-ref }} + git checkout ${{ inputs.duckdb-sha }} - name: Set OVERRIDE_GIT_DESCRIBE if: ${{ inputs.set-version != '' }} diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index f74b5f09..e3a3c08c 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -11,13 +11,13 @@ on: description: Testsuite to run (none, fast, all) required: true default: all - git-ref: + duckdb-python-sha: type: string - description: Git ref of the DuckDB python package + description: The commit or ref to build against (defaults to latest commit of current ref) required: false - duckdb-git-ref: + duckdb-sha: type: string - description: Git ref of DuckDB + description: Override the DuckDB submodule commit or ref to build against required: false set-version: description: Force version (vX.Y.Z-((rc|post)N)) @@ -59,16 +59,17 @@ jobs: - name: Checkout DuckDB Python uses: actions/checkout@v4 with: - ref: ${{ inputs.git-ref }} + ref: ${{ inputs.duckdb-python-sha }} fetch-depth: 0 submodules: true - name: Checkout DuckDB shell: bash + if: ${{ inputs.duckdb-python-sha }} run: | cd external/duckdb git fetch origin - git checkout ${{ inputs.duckdb-git-ref }} + git checkout ${{ inputs.duckdb-python-sha }} # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds - name: Set OVERRIDE_GIT_DESCRIBE diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 721c92fa..f0cfc3f1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,9 +3,13 @@ name: Release on: workflow_dispatch: inputs: + duckdb-python-sha: + type: string + description: The commit to build against (defaults to latest commit of current ref) + required: false duckdb-sha: type: string - description: The DuckDB submodule commit to build against + description: The DuckDB submodule commit or ref to build against required: true stable-version: type: string @@ -33,8 +37,8 @@ jobs: uses: ./.github/workflows/packaging_sdist.yml with: testsuite: all - git-ref: ${{ github.ref }} - duckdb-git-ref: ${{ inputs.duckdb-sha }} + duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} + duckdb-sha: ${{ inputs.duckdb-sha }} set-version: ${{ inputs.stable-version }} workflow_state: @@ -111,8 +115,8 @@ jobs: with: minimal: false testsuite: all - git-ref: ${{ github.ref }} - duckdb-git-ref: ${{ inputs.duckdb-sha }} + duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} + duckdb-sha: ${{ inputs.duckdb-sha }} set-version: ${{ inputs.stable-version }} upload_s3: diff --git a/.github/workflows/submodule.yml b/.github/workflows/submodule.yml new file mode 100644 index 00000000..bb89124c --- /dev/null +++ b/.github/workflows/submodule.yml @@ -0,0 +1,42 @@ +name: Check DuckDB submodule sanity +on: + workflow_call: + workflow_dispatch: +jobs: + submodule_sanity: + name: Make sure submodule is in a sane state + runs-on: ubuntu-latest + steps: + - name: Checkout DuckDB Python + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Verify submodule origin + shell: bash + run: | + set -eux + git submodule update --init + cd external/duckdb + remote_count=$(git remote | wc -l) + if [[ $remote_count -gt 1 ]]; then + echo "::error::Multiple remotes found - only origin allowed" + git remote -v + fi + origin_url=$(git remote get-url origin) + if [[ "$origin_url" != "https://github.com/duckdb/duckdb"* ]]; then + echo "::error::Submodule origin has been tampered with: $origin_url" + exit 1 + fi + + - name: Disallow changes to .gitmodules in PRs and pushes + if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }} + shell: bash + run: | + set -eux + before=${{ github.event_name == 'push' && github.event.before || format('origin/{0}', github.base_ref) }} + after=${{ github.event_name == 'push' && github.event.after || github.head_ref }} + if git diff --name-only $before...$after | grep -q "^\.gitmodules$"; then + echo "::error::.gitmodules may not be modified. If you see a reason to update, please discuss with the maintainers." + exit 1 + fi From 8df5a2a1760cfdb74b3da4e5601c57bf966bccbb Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 11:43:41 +0200 Subject: [PATCH 137/472] Renamed workflow --- .github/workflows/{submodule.yml => submodule_sanity.yml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/workflows/{submodule.yml => submodule_sanity.yml} (100%) diff --git a/.github/workflows/submodule.yml b/.github/workflows/submodule_sanity.yml similarity index 100% rename from .github/workflows/submodule.yml rename to .github/workflows/submodule_sanity.yml From 80202a9370b463a8e71b3e264919783e7336a458 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 11:45:08 +0200 Subject: [PATCH 138/472] Fixed job names --- .github/workflows/on_pr.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 24640124..d8e6acb5 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -25,7 +25,7 @@ jobs: packaging_test: name: Build a minimal set of packages and run all tests on them - needs: ensure_submodule_sanity + needs: submodule_sanity_guard # Skip packaging tests for draft PRs if: ${{ github.event_name != 'pull_request' || github.event.pull_request.draft == false }} uses: ./.github/workflows/packaging.yml @@ -36,7 +36,7 @@ jobs: coverage_test: name: Run coverage tests - needs: ensure_submodule_sanity + needs: submodule_sanity_guard # Only run coverage test for draft PRs if: ${{ github.event_name == 'pull_request' && github.event.pull_request.draft == true }} uses: ./.github/workflows/coverage.yml From 38cc9b7801232dbd238cba4a545c9a7ff18106c9 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 11:46:36 +0200 Subject: [PATCH 139/472] Fixed paths --- .github/workflows/on_pr.yml | 2 +- .github/workflows/on_push_protected.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index d8e6acb5..28e14e45 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -21,7 +21,7 @@ concurrency: jobs: submodule_sanity_guard: name: Make sure submodule is in a sane state - uses: .github/workflows/submodule_sanity.yml + uses: ./.github/workflows/submodule_sanity.yml packaging_test: name: Build a minimal set of packages and run all tests on them diff --git a/.github/workflows/on_push_protected.yml b/.github/workflows/on_push_protected.yml index af64f88e..9d5a52ad 100644 --- a/.github/workflows/on_push_protected.yml +++ b/.github/workflows/on_push_protected.yml @@ -7,4 +7,4 @@ on: jobs: submodule_sanity_guard: name: Make sure submodule is in a sane state - uses: .github/workflows/submodule_sanity.yml + uses: ./.github/workflows/submodule_sanity.yml From fdda027cd05278c87404a51c11c69168d23bbd62 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 12:02:08 +0200 Subject: [PATCH 140/472] No push protection --- .github/workflows/on_push_protected.yml | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 .github/workflows/on_push_protected.yml diff --git a/.github/workflows/on_push_protected.yml b/.github/workflows/on_push_protected.yml deleted file mode 100644 index 9d5a52ad..00000000 --- a/.github/workflows/on_push_protected.yml +++ /dev/null @@ -1,10 +0,0 @@ -name: Guard pushes to protected branches -on: - push: - branches: - - main - - v*.*-* -jobs: - submodule_sanity_guard: - name: Make sure submodule is in a sane state - uses: ./.github/workflows/submodule_sanity.yml From 0e559f259506a135eb0f48631b0dbead0a8723d5 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 12:52:21 +0200 Subject: [PATCH 141/472] Fix versioning --- duckdb_packaging/_versioning.py | 41 ++++++++++++---------- duckdb_packaging/build_backend.py | 30 ++++++++-------- duckdb_packaging/setuptools_scm_version.py | 8 ++--- 3 files changed, 42 insertions(+), 37 deletions(-) diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index 8a76fb42..ca8e7716 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -16,34 +16,34 @@ def parse_version(version: str) -> tuple[int, int, int, int, int]: """Parse a version string into its components. - + Args: version: Version string (e.g., "1.3.1", "1.3.2.rc3" or "1.3.1.post2") - + Returns: Tuple of (major, minor, patch, post, rc) - + Raises: ValueError: If version format is invalid """ match = VERSION_RE.match(version) if not match: raise ValueError(f"Invalid version format: {version} (expected X.Y.Z, X.Y.Z.rcM or X.Y.Z.postN)") - + major, minor, patch, rc, post = match.groups() return int(major), int(minor), int(patch), int(post or 0), int(rc or 0) def format_version(major: int, minor: int, patch: int, post: int = 0, rc: int = 0) -> str: """Format version components into a version string. - + Args: major: Major version number - minor: Minor version number + minor: Minor version number patch: Patch version number post: Post-release number rc: RC number - + Returns: Formatted version string """ @@ -59,31 +59,31 @@ def format_version(major: int, minor: int, patch: int, post: int = 0, rc: int = def git_tag_to_pep440(git_tag: str) -> str: """Convert git tag format to PEP440 format. - + Args: git_tag: Git tag (e.g., "v1.3.1", "v1.3.1-post1") - + Returns: PEP440 version string (e.g., "1.3.1", "1.3.1.post1") """ # Remove 'v' prefix if present version = git_tag[1:] if git_tag.startswith('v') else git_tag - + if "-post" in version: assert 'rc' not in version version = version.replace("-post", ".post") elif '-rc' in version: version = version.replace("-rc", "rc") - + return version def pep440_to_git_tag(version: str) -> str: """Convert PEP440 version to git tag format. - + Args: version: PEP440 version string (e.g., "1.3.1.post1" or "1.3.1rc2") - + Returns: Git tag format (e.g., "v1.3.1-post1") """ @@ -98,7 +98,7 @@ def pep440_to_git_tag(version: str) -> str: def get_current_version() -> Optional[str]: """Get the current version from git tags. - + Returns: Current version string or None if no tags exist """ @@ -149,16 +149,21 @@ def strip_post_from_version(version: str) -> str: return re.sub(r"[\.-]post[0-9]+", "", version) -def get_git_describe(repo_path: Optional[pathlib.Path] = None) -> Optional[str]: +def get_git_describe(repo_path: Optional[pathlib.Path] = None, since_major=False, since_minor=False) -> Optional[str]: """Get git describe output for version determination. - + Returns: Git describe output or None if no tags exist """ cwd = repo_path if repo_path is not None else None + pattern="v*.*.*" + if since_major: + pattern="v*.0.0" + elif since_minor: + pattern="v*.*.0" try: result = subprocess.run( - ["git", "describe", "--tags", "--long", "--match", "v*.*.*"], + ["git", "describe", "--tags", "--long", "--match", pattern], capture_output=True, text=True, check=True, @@ -167,4 +172,4 @@ def get_git_describe(repo_path: Optional[pathlib.Path] = None) -> Optional[str]: result.check_returncode() return result.stdout.strip() except FileNotFoundError: - raise RuntimeError("git executable can't be found") \ No newline at end of file + raise RuntimeError("git executable can't be found") diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index 2928d2e8..d96a4847 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -29,7 +29,7 @@ ) from duckdb_packaging._versioning import create_git_tag, pep440_to_git_tag, get_git_describe, strip_post_from_version -from duckdb_packaging.setuptools_scm_version import forced_version_from_env +from duckdb_packaging.setuptools_scm_version import forced_version_from_env, MAIN_BRANCH_VERSIONING _DUCKDB_VERSION_FILENAME = "duckdb_version.txt" @@ -41,7 +41,7 @@ def _log(msg: str, is_error: bool=False) -> None: """Log a message with build backend prefix. - + Args: msg: The message to log. is_error: If True, log to stderr; otherwise log to stdout. @@ -51,7 +51,7 @@ def _log(msg: str, is_error: bool=False) -> None: def _in_git_repository() -> bool: """Check if the current directory is inside a git repository. - + Returns: True if .git directory exists, False otherwise. """ @@ -129,7 +129,7 @@ def _skbuild_config_add( key: str, value: Union[List, str], config_settings: Dict[str, Union[List[str],str]], fail_if_exists: bool=False ): """Add or modify a configuration setting for scikit-build-core. - + This function handles adding values to scikit-build-core configuration settings, supporting both string and list types with appropriate merging behavior. @@ -145,7 +145,7 @@ def _skbuild_config_add( Behavior Rules: - String value + list setting: value is appended to the list - - String value + string setting: existing value is overridden + - String value + string setting: existing value is overridden - List value + list setting: existing list is extended - List value + string setting: raises RuntimeError @@ -180,18 +180,18 @@ def _skbuild_config_add( def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[List[str],str]]] = None) -> str: """Build a source distribution using the DuckDB submodule. - + This function extracts the DuckDB version from either the git submodule and saves it to a version file before building the sdist with scikit-build-core. If _FORCED_PEP440_VERSION was set then we first create a tag on the submodule. - + Args: sdist_directory: Directory where the sdist will be created. config_settings: Optional build configuration settings. - + Returns: The filename of the created sdist. - + Raises: RuntimeError: If not in a git repository or DuckDB submodule issues. """ @@ -201,7 +201,7 @@ def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[ if _FORCED_PEP440_VERSION is not None: duckdb_version = pep440_to_git_tag(strip_post_from_version(_FORCED_PEP440_VERSION)) else: - duckdb_version = get_git_describe(repo_path=submodule_path) + duckdb_version = get_git_describe(repo_path=submodule_path, since_minor=MAIN_BRANCH_VERSIONING) _write_duckdb_long_version(duckdb_version) return skbuild_build_sdist(sdist_directory, config_settings=config_settings) @@ -212,19 +212,19 @@ def build_wheel( metadata_directory: Optional[str] = None, ) -> str: """Build a wheel from either git submodule or extracted sdist sources. - + This function builds a wheel using scikit-build-core, handling two scenarios: 1. In a git repository: builds directly from the DuckDB submodule 2. In an sdist: reads the saved DuckDB version and passes it to CMake - + Args: wheel_directory: Directory where the wheel will be created. config_settings: Optional build configuration settings. metadata_directory: Optional directory for metadata preparation. - + Returns: The filename of the created wheel. - + Raises: RuntimeError: If not in a git repository or sdist environment. """ @@ -259,4 +259,4 @@ def build_wheel( "get_requires_for_build_editable", "prepare_metadata_for_build_wheel", "prepare_metadata_for_build_editable", -] \ No newline at end of file +] diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 932fcd52..8381e1e2 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -27,10 +27,10 @@ def _main_branch_versioning(): def version_scheme(version: Any) -> str: """ setuptools_scm version scheme that matches DuckDB's original behavior. - + Args: version: setuptools_scm version object - + Returns: PEP440 compliant version string """ @@ -38,11 +38,11 @@ def version_scheme(version: Any) -> str: print(f"[version_scheme] version.tag: {version.tag}") print(f"[version_scheme] version.distance: {version.distance}") print(f"[version_scheme] version.dirty: {version.dirty}") - + # Handle case where tag is None if version.tag is None: raise ValueError("Need a valid version. Did you set a fallback_version in pyproject.toml?") - + try: return _bump_version(str(version.tag), version.distance, version.dirty) except Exception as e: From 958af3e8f829eb51e673bc6c1092d1c2c72ccd0f Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 13:53:58 +0200 Subject: [PATCH 142/472] use correct sha --- .github/workflows/packaging_wheels.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index e3a3c08c..4c7599a6 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -65,11 +65,11 @@ jobs: - name: Checkout DuckDB shell: bash - if: ${{ inputs.duckdb-python-sha }} + if: ${{ inputs.duckdb-sha }} run: | cd external/duckdb git fetch origin - git checkout ${{ inputs.duckdb-python-sha }} + git checkout ${{ inputs.duckdb-sha }} # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds - name: Set OVERRIDE_GIT_DESCRIBE From 3f26c2b1012729116182bf2b816d8c95078856f9 Mon Sep 17 00:00:00 2001 From: Diego Sevilla Ruiz Date: Thu, 4 Sep 2025 17:58:28 +0200 Subject: [PATCH 143/472] Fix for #26: - Changed the type of the file_globs parameter of from_parquet and read_parquet in connection_methods.json - Added the generation of @overload functions in the generation wrappers python code. --- scripts/connection_methods.json | 4 ++-- scripts/generate_connection_stubs.py | 17 +++++++++-------- scripts/generate_connection_wrapper_stubs.py | 16 +++++++++------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index 27705d6a..a87b992f 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -412,7 +412,7 @@ "fetch_record_batch", "arrow" ], - + "function": "FetchRecordBatchReader", "docs": "Fetch an Arrow RecordBatchReader following execute()", "args": [ @@ -992,7 +992,7 @@ "args": [ { "name": "file_globs", - "type": "str" + "type": "List[str]" }, { "name": "binary_as_string", diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index 563ade3d..32831134 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -51,8 +51,12 @@ def create_arguments(arguments) -> list: result.append(argument) return result - def create_definition(name, method) -> str: - definition = f"def {name}(" + def create_definition(name, method, overloaded: bool) -> str: + if overloaded: + definition: str = "@overload\n" + else: + definition: str = "" + definition += f"def {name}(" arguments = ['self'] if 'args' in method: arguments.extend(create_arguments(method['args'])) @@ -66,8 +70,8 @@ def create_definition(name, method) -> str: return definition # We have "duplicate" methods, which are overloaded - # maybe we should add @overload to these instead, but this is easier - written_methods = set() + # We keep note of them to add the @overload decorator. + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} for method in connection_methods: if isinstance(method['name'], list): @@ -75,10 +79,7 @@ def create_definition(name, method) -> str: else: names = [method['name']] for name in names: - if name in written_methods: - continue - body.append(create_definition(name, method)) - written_methods.add(name) + body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 94b0e0ee..d1ce50e3 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -66,8 +66,12 @@ def create_arguments(arguments) -> list: result.append(argument) return result - def create_definition(name, method) -> str: - definition = f"def {name}(" + def create_definition(name, method, overloaded: bool) -> str: + if overloaded: + definition: str = "@overload\n" + else: + definition: str = "" + definition += f"def {name}(" arguments = [] if name in SPECIAL_METHOD_NAMES: arguments.append('df: pandas.DataFrame') @@ -84,7 +88,8 @@ def create_definition(name, method) -> str: # We have "duplicate" methods, which are overloaded # maybe we should add @overload to these instead, but this is easier - written_methods = set() + # We keep note of them to add the @overload decorator. + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} body = [] for method in methods: @@ -99,10 +104,7 @@ def create_definition(name, method) -> str: method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'}) for name in names: - if name in written_methods: - continue - body.append(create_definition(name, method)) - written_methods.add(name) + body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- From d320863f844afad45e96e39432db39493932bba2 Mon Sep 17 00:00:00 2001 From: Diego Sevilla Ruiz Date: Thu, 4 Sep 2025 18:26:20 +0200 Subject: [PATCH 144/472] Fix comment. --- scripts/generate_connection_wrapper_stubs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index d1ce50e3..64912861 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -86,9 +86,7 @@ def create_definition(name, method, overloaded: bool) -> str: definition += f" -> {method['return']}: ..." return definition - # We have "duplicate" methods, which are overloaded - # maybe we should add @overload to these instead, but this is easier - # We keep note of them to add the @overload decorator. + # We have "duplicate" methods, which are overloaded. overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} body = [] @@ -104,7 +102,6 @@ def create_definition(name, method, overloaded: bool) -> str: method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'}) for name in names: - body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- From ff1043e99880043e13876a2391ea07f877cb9323 Mon Sep 17 00:00:00 2001 From: Diego Sevilla Ruiz Date: Thu, 4 Sep 2025 18:29:20 +0200 Subject: [PATCH 145/472] Fix comments again. --- scripts/generate_connection_stubs.py | 2 +- scripts/generate_connection_wrapper_stubs.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index 32831134..fbb66c21 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -69,7 +69,7 @@ def create_definition(name, method, overloaded: bool) -> str: definition += f" -> {method['return']}: ..." return definition - # We have "duplicate" methods, which are overloaded + # We have "duplicate" methods, which are overloaded. # We keep note of them to add the @overload decorator. overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 64912861..62c60a84 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -87,6 +87,7 @@ def create_definition(name, method, overloaded: bool) -> str: return definition # We have "duplicate" methods, which are overloaded. + # We keep note of them to add the @overload decorator. overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} body = [] @@ -102,6 +103,7 @@ def create_definition(name, method, overloaded: bool) -> str: method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'}) for name in names: + body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- From 537e39d65913ea5a15ac49eed101360492f61074 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 20:49:26 +0200 Subject: [PATCH 146/472] always show summary --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f0cfc3f1..d09e6e46 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -188,7 +188,7 @@ jobs: name: Release summary runs-on: ubuntu-latest needs: [build_sdist, workflow_state, build_wheels, upload_s3, publish_pypi, cleanup_nightlies] - if: true + if: always() steps: - run: | sha=${{ github.sha }} From 0c09355de163f27f525816a80decf47ac30951e8 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 20:51:12 +0200 Subject: [PATCH 147/472] bumped submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 24d2e45b..605eaf76 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 24d2e45b14126a7083a0c01c6a45b75390a46922 +Subproject commit 605eaf76be154d5c6d38353f96b23c031795572d From 462a395902c774c2d67dab28484e8d44f1df8931 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 20:52:37 +0200 Subject: [PATCH 148/472] rc branch 1.4-andium --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 605eaf76..03987f96 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 605eaf76be154d5c6d38353f96b23c031795572d +Subproject commit 03987f9660037059fdabd5b052a694a264b927c8 From 35e6c7d6bc4c55d0a1749b308f4992247afe1761 Mon Sep 17 00:00:00 2001 From: Julian Meyers Date: Fri, 5 Sep 2025 09:50:11 -0500 Subject: [PATCH 149/472] Add support for funky column names --- duckdb/polars_io.py | 6 +++--- tests/fast/arrow/test_polars.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index dbe8727b..ad03038a 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -95,7 +95,8 @@ def _pl_tree_to_sql(tree: dict) -> str: ) if node_type == "Column": # A reference to a column name - return subtree + # Wrap in quotes to handle special characters + return f'"{subtree}"' if node_type in ("Literal", "Dyn"): # Recursively process dynamic or literal values @@ -196,7 +197,7 @@ def source_generator( duck_predicate = None relation_final = relation if with_columns is not None: - cols = ",".join(with_columns) + cols = ",".join(f'"{col}"' for col in with_columns) relation_final = relation_final.project(cols) if n_rows is not None: relation_final = relation_final.limit(n_rows) @@ -213,7 +214,6 @@ def source_generator( while True: try: record_batch = results.read_next_batch() - df = pl.from_arrow(record_batch) if predicate is not None and duck_predicate is None: # We have a predicate, but did not manage to push it down, we fallback here yield pl.from_arrow(record_batch).filter(predicate) diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 89ccf031..604c3987 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -131,6 +131,29 @@ def test_polars_lazy(self, duckdb_cursor): ] assert lazy_df.filter(pl.col("b") < 32).select('a').collect().to_dicts() == [{'a': 'Mark'}, {'a': 'Thijs'}] + def test_polars_column_with_tricky_name(self, duckdb_cursor): + # Test that a polars DataFrame with a column name that is non standard still works + df_colon = pl.DataFrame({"x:y": [1, 2]}) + lf = duckdb_cursor.sql("from df_colon").pl(lazy=True) + result = lf.select(pl.all()).collect() + assert result.to_dicts() == [{"x:y": 1}, {"x:y": 2}] + result = lf.select(pl.all()).filter(pl.col("x:y") == 1).collect() + assert result.to_dicts() == [{"x:y": 1}] + + df_space = pl.DataFrame({"x y": [1, 2]}) + lf = duckdb_cursor.sql("from df_space").pl(lazy=True) + result = lf.select(pl.all()).collect() + assert result.to_dicts() == [{"x y": 1}, {"x y": 2}] + result = lf.select(pl.all()).filter(pl.col("x y") == 1).collect() + assert result.to_dicts() == [{"x y": 1}] + + df_dot = pl.DataFrame({"x.y": [1, 2]}) + lf = duckdb_cursor.sql("from df_dot").pl(lazy=True) + result = lf.select(pl.all()).collect() + assert result.to_dicts() == [{"x.y": 1}, {"x.y": 2}] + result = lf.select(pl.all()).filter(pl.col("x.y") == 1).collect() + assert result.to_dicts() == [{"x.y": 1}] + @pytest.mark.parametrize( 'data_type', [ From 71a4026cfa9dad7b004eb8d93f646c4d65e61205 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 13:30:10 +0200 Subject: [PATCH 150/472] update submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 03987f96..16671635 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 03987f9660037059fdabd5b052a694a264b927c8 +Subproject commit 166716352edf482a378ade0f2990abb7237ae841 From 8e11220f751f4f20e35fee08a30ed8bb83ec119c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 13:30:43 +0200 Subject: [PATCH 151/472] update submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 605eaf76..f99fed1e 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 605eaf76be154d5c6d38353f96b23c031795572d +Subproject commit f99fed1e0b16a842573f9dad529f6c170a004f6e From f9d84b23c35b2762296d12d3aee51b1bb800bd5a Mon Sep 17 00:00:00 2001 From: Julian Meyers Date: Mon, 8 Sep 2025 09:58:36 -0500 Subject: [PATCH 152/472] Add handling for identifiers containing double quotes --- duckdb/polars_io.py | 16 ++++++++++++++-- tests/fast/arrow/test_polars.py | 7 +++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index ad03038a..d8d4cfe9 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -58,6 +58,18 @@ def _pl_operation_to_sql(op: str) -> str: raise NotImplementedError(op) +def _escape_sql_identifier(identifier: str) -> str: + """ + Escape SQL identifiers by doubling any double quotes and wrapping in double quotes. + + Example: + >>> _escape_sql_identifier('column"name') + '"column""name"' + """ + escaped = identifier.replace('"', '""') + return f'"{escaped}"' + + def _pl_tree_to_sql(tree: dict) -> str: """ Recursively convert a Polars expression tree (as JSON) to a SQL string. @@ -96,7 +108,7 @@ def _pl_tree_to_sql(tree: dict) -> str: if node_type == "Column": # A reference to a column name # Wrap in quotes to handle special characters - return f'"{subtree}"' + return _escape_sql_identifier(subtree) if node_type in ("Literal", "Dyn"): # Recursively process dynamic or literal values @@ -197,7 +209,7 @@ def source_generator( duck_predicate = None relation_final = relation if with_columns is not None: - cols = ",".join(f'"{col}"' for col in with_columns) + cols = ",".join(map(_escape_sql_identifier, with_columns)) relation_final = relation_final.project(cols) if n_rows is not None: relation_final = relation_final.limit(n_rows) diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 604c3987..87e2f726 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -154,6 +154,13 @@ def test_polars_column_with_tricky_name(self, duckdb_cursor): result = lf.select(pl.all()).filter(pl.col("x.y") == 1).collect() assert result.to_dicts() == [{"x.y": 1}] + df_quote = pl.DataFrame({'"xy"': [1, 2]}) + lf = duckdb_cursor.sql("from df_quote").pl(lazy=True) + result = lf.select(pl.all()).collect() + assert result.to_dicts() == [{'"xy"': 1}, {'"xy"': 2}] + result = lf.select(pl.all()).filter(pl.col('"xy"') == 1).collect() + assert result.to_dicts() == [{'"xy"': 1}] + @pytest.mark.parametrize( 'data_type', [ From f9c3d39fa4d37a5c0bf455d39989b9cc0f11e607 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 20:48:13 +0200 Subject: [PATCH 153/472] Run workflow on edited PR --- .github/workflows/on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 28e14e45..d62dadae 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -4,7 +4,7 @@ on: branches: - main - v*.*-* - types: [opened, reopened, ready_for_review, converted_to_draft] + types: [opened, edited, reopened, ready_for_review, converted_to_draft] paths-ignore: - '**.md' - 'LICENSE' From 725260a0cf5124e4bfbc36469ccf1ddba199d983 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 20:48:13 +0200 Subject: [PATCH 154/472] Run workflow on edited PR --- .github/workflows/on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index 28e14e45..d62dadae 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -4,7 +4,7 @@ on: branches: - main - v*.*-* - types: [opened, reopened, ready_for_review, converted_to_draft] + types: [opened, edited, reopened, ready_for_review, converted_to_draft] paths-ignore: - '**.md' - 'LICENSE' From 94d153b5b54d44f7fd2cf50fd34d644ecc14d3d3 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 21:58:23 +0200 Subject: [PATCH 155/472] Use synchronize pr event type --- .github/workflows/on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index d62dadae..7a4669cb 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -4,7 +4,7 @@ on: branches: - main - v*.*-* - types: [opened, edited, reopened, ready_for_review, converted_to_draft] + types: [opened, reopened, ready_for_review, converted_to_draft, synchronize] paths-ignore: - '**.md' - 'LICENSE' From 3601806dc16c955ea169bdff8d937d0c7f928a31 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 21:58:23 +0200 Subject: [PATCH 156/472] Use synchronize pr event type --- .github/workflows/on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_pr.yml b/.github/workflows/on_pr.yml index d62dadae..7a4669cb 100644 --- a/.github/workflows/on_pr.yml +++ b/.github/workflows/on_pr.yml @@ -4,7 +4,7 @@ on: branches: - main - v*.*-* - types: [opened, edited, reopened, ready_for_review, converted_to_draft] + types: [opened, reopened, ready_for_review, converted_to_draft, synchronize] paths-ignore: - '**.md' - 'LICENSE' From 433a9ea14b45f9994a2f89c20d1d7379211f746b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 4 Sep 2025 21:09:53 +0200 Subject: [PATCH 157/472] Make arrow from relation return record batch reader --- duckdb/__init__.pyi | 2 +- duckdb/experimental/spark/sql/dataframe.py | 2 +- src/duckdb_py/pyrelation/initialize.cpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index adf142dd..b22daef4 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -453,7 +453,7 @@ class DuckDBPyRelation: def set_alias(self, alias: str) -> DuckDBPyRelation: ... def show(self, max_width: Optional[int] = None, max_rows: Optional[int] = None, max_col_width: Optional[int] = None, null_value: Optional[str] = None, render_mode: Optional[RenderMode] = None) -> None: ... def sql_query(self) -> str: ... - def to_arrow_table(self, batch_size: int = ...) -> pyarrow.lib.Table: ... + def to_arrow_table(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... def to_csv( self, file_name: str, diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index b8a4698b..a81a423b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -75,7 +75,7 @@ def toArrow(self) -> "pa.Table": age: [[2,5]] name: [["Alice","Bob"]] """ - return self.relation.arrow() + return self.relation.to_arrow_table() def createOrReplaceTempView(self, name: str) -> None: """Creates or replaces a local temporary view with this :class:`DataFrame`. diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index a93a54b5..794c420b 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -61,7 +61,7 @@ static void InitializeConsumers(py::class_ &m) { py::arg("date_as_object") = false) .def("fetch_df_chunk", &DuckDBPyRelation::FetchDFChunk, "Execute and fetch a chunk of the rows", py::arg("vectors_per_chunk") = 1, py::kw_only(), py::arg("date_as_object") = false) - .def("arrow", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", + .def("arrow", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) .def("fetch_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) @@ -78,7 +78,7 @@ static void InitializeConsumers(py::class_ &m) { )"; m.def("__arrow_c_stream__", &DuckDBPyRelation::ToArrowCapsule, capsule_docs, py::arg("requested_schema") = py::none()); - m.def("record_batch", &DuckDBPyRelation::ToRecordBatch, + m.def("fetch_record_batch", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000); From 834fa435477ef740a560fe9460c4571ab8c25bc5 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 16:36:13 +0200 Subject: [PATCH 158/472] deprecate instead of remove --- duckdb/__init__.pyi | 5 +++-- src/duckdb_py/pyrelation/initialize.cpp | 17 ++++++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index b22daef4..8f27e5e3 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -415,7 +415,7 @@ class DuckDBPyRelation: def variance(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... def list(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arrow(self, batch_size: int = ...) -> pyarrow.lib.Table: ... + def arrow(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: ... def create(self, table_name: str) -> None: ... def create_view(self, view_name: str, replace: bool = ...) -> DuckDBPyRelation: ... @@ -448,12 +448,13 @@ class DuckDBPyRelation: def pl(self, rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... def query(self, virtual_table_name: str, sql_query: str) -> DuckDBPyRelation: ... def record_batch(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... + def fetch_record_batch(self, rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... def select_types(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def select_dtypes(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def set_alias(self, alias: str) -> DuckDBPyRelation: ... def show(self, max_width: Optional[int] = None, max_rows: Optional[int] = None, max_col_width: Optional[int] = None, null_value: Optional[str] = None, render_mode: Optional[RenderMode] = None) -> None: ... def sql_query(self) -> str: ... - def to_arrow_table(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... + def to_arrow_table(self, batch_size: int = ...) -> pyarrow.lib.Table: ... def to_csv( self, file_name: str, diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 794c420b..6f66c563 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -79,9 +79,20 @@ static void InitializeConsumers(py::class_ &m) { m.def("__arrow_c_stream__", &DuckDBPyRelation::ToArrowCapsule, capsule_docs, py::arg("requested_schema") = py::none()); m.def("fetch_record_batch", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) - .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000); + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) + .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) + .def("record_batch", + [](pybind11::object &self, idx_t rows_per_batch) + { + auto warnings = pybind11::module::import("warnings"); + auto builtins = pybind11::module::import("builtins"); + warnings.attr("warn")( + "record_batch() is deprecated, use fetch_record_batch() instead.", + builtins.attr("DeprecationWarning")); + + return self.attr("fetch_record_batch")(rows_per_batch); + }, py::arg("rows_per_batch") = 1000000); } static void InitializeAggregates(py::class_ &m) { From e8a630e22726245f97b3ec3ccee63fcd24100252 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 8 Sep 2025 21:53:54 +0200 Subject: [PATCH 159/472] Ignore deprecationwarnings in tests --- src/duckdb_py/pyrelation/initialize.cpp | 11 ++++------- tests/pytest.ini | 1 + 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 6f66c563..867cd7a6 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -85,13 +85,10 @@ static void InitializeConsumers(py::class_ &m) { .def("record_batch", [](pybind11::object &self, idx_t rows_per_batch) { - auto warnings = pybind11::module::import("warnings"); - auto builtins = pybind11::module::import("builtins"); - warnings.attr("warn")( - "record_batch() is deprecated, use fetch_record_batch() instead.", - builtins.attr("DeprecationWarning")); - - return self.attr("fetch_record_batch")(rows_per_batch); + PyErr_WarnEx(PyExc_DeprecationWarning, + "record_batch() is deprecated, use fetch_record_batch() instead.", + 0); + return self.attr("fetch_record_batch")(rows_per_batch); }, py::arg("rows_per_batch") = 1000000); } diff --git a/tests/pytest.ini b/tests/pytest.ini index 5dd3c306..0c17afd5 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -3,6 +3,7 @@ filterwarnings = error ignore::UserWarning + ignore::DeprecationWarning # Jupyter is throwing DeprecationWarnings ignore:function ham\(\) is deprecated:DeprecationWarning # Pyspark is throwing these warnings From de8c94452972ae0031d70e5a48869ee4c680d819 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 08:48:37 +0200 Subject: [PATCH 160/472] param name shouldnt change --- src/duckdb_py/pyrelation/initialize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 867cd7a6..7992cc17 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -81,7 +81,7 @@ static void InitializeConsumers(py::class_ &m) { m.def("fetch_record_batch", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) .def("record_batch", [](pybind11::object &self, idx_t rows_per_batch) { @@ -89,7 +89,7 @@ static void InitializeConsumers(py::class_ &m) { "record_batch() is deprecated, use fetch_record_batch() instead.", 0); return self.attr("fetch_record_batch")(rows_per_batch); - }, py::arg("rows_per_batch") = 1000000); + }, py::arg("batch_size") = 1000000); } static void InitializeAggregates(py::class_ &m) { From 32e41a89d225477af7c179e420241187a0f81988 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 10:12:44 +0200 Subject: [PATCH 161/472] Fix environment determination --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d09e6e46..39ee631e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -78,7 +78,7 @@ jobs: if [[ test == "${{ inputs.pypi-index }}" ]]; then ci_env=pypi-test elif [[ prod == "${{ inputs.pypi-index }}" ]]; then - ci_env=pypi-prod${{ inputs.stable-version && '' || '-nightly' }} + ci_env=pypi-prod${{ inputs.stable-version != '' && '' || '-nightly' }} else echo "::error::Invalid value for inputs.pypi-index: ${{ inputs.pypi-index }}" exit 1 From ce802de6981cf9b8e89890a12f4ca04053b88347 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 11:46:17 +0200 Subject: [PATCH 162/472] switch condition - empty string evaluates to false --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 39ee631e..f54b0f76 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -78,7 +78,7 @@ jobs: if [[ test == "${{ inputs.pypi-index }}" ]]; then ci_env=pypi-test elif [[ prod == "${{ inputs.pypi-index }}" ]]; then - ci_env=pypi-prod${{ inputs.stable-version != '' && '' || '-nightly' }} + ci_env=pypi-prod${{ inputs.stable-version == '' && '-nightly' || '' }} else echo "::error::Invalid value for inputs.pypi-index: ${{ inputs.pypi-index }}" exit 1 From 7059576d61faf719b22b523b8ddfda544fb9b2e4 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 19:19:19 +0200 Subject: [PATCH 163/472] bumped submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 16671635..c11813f1 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 166716352edf482a378ade0f2990abb7237ae841 +Subproject commit c11813f1b4e89aa7096b8d18ca9eb608e27168a8 From bf1cbebb875cd04781e9831f93fd1d5852a50f67 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 11:28:10 +0200 Subject: [PATCH 164/472] Disable uploading to pypi for main --- .github/workflows/release.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f54b0f76..77d6dcc6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -146,6 +146,7 @@ jobs: publish_pypi: name: Publish Artifacts to PyPI runs-on: ubuntu-latest + if: ${{ !always() }} needs: [workflow_state, build_sdist, build_wheels] environment: name: ${{ needs.workflow_state.outputs.ci_env }} From 20848e2a36d02fc1cd30371c7b903c7bfc453957 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Sep 2025 09:54:52 +0200 Subject: [PATCH 165/472] fix the 'description' attribute, using DuckDBPyType --- duckdb/__init__.py | 20 ++++++++++++++++++++ src/duckdb_py/pyresult.cpp | 2 +- src/duckdb_py/typing/pytype.cpp | 13 +++++++++++-- tests/fast/api/test_duckdb_connection.py | 22 +++++++++++++++++++++- 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index c3ec0610..c174c30e 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -18,6 +18,26 @@ def version(): "functional" ]) +class DBAPITypeObject: + def __init__(self, name, types): + self.name = name + self.types = list(types) + + def __eq__(self, other): + if isinstance(other, typing.DuckDBPyType): + return other in self.types + return False + + def __repr__(self): + return f"" + +# Define the standard DBAPI sentinels +STRING = DBAPITypeObject("STRING", {"VARCHAR", "CHAR", "TEXT"}) +NUMBER = DBAPITypeObject("NUMBER", {"INTEGER", "BIGINT", "DECIMAL", "DOUBLE"}) +DATETIME = DBAPITypeObject("DATETIME", {"DATE", "TIME", "TIMESTAMP"}) +BINARY = DBAPITypeObject("BINARY", {"BLOB"}) +ROWID = None + # Classes from _duckdb import ( DuckDBPyRelation, diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index 5997d57b..68a569a4 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -579,7 +579,7 @@ py::list DuckDBPyResult::GetDescription(const vector &names, const vecto for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { auto py_name = py::str(names[col_idx]); - auto py_type = GetTypeToPython(types[col_idx]); + auto py_type = DuckDBPyType(types[col_idx]); desc.append(py::make_tuple(py_name, py_type, py::none(), py::none(), py::none(), py::none(), py::none())); } return desc; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index de03fa7d..f679896b 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -325,8 +325,17 @@ static LogicalType FromObject(const py::object &object) { void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); - type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); - type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other")); + type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object") + .def("__eq__", [](const DuckDBPyType &self, py::handle other) -> py::object { + if (py::isinstance(other)) { + return py::bool_(self.Equals(other.cast>())); + } + else if (py::isinstance(other)) { + return py::bool_(self.EqualsString(other.cast())); + } + // Return NotImplemented for unsupported types + return py::reinterpret_borrow(Py_NotImplemented); + }); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other")); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index c9f46021..7f0cd390 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -1,4 +1,5 @@ import duckdb +import duckdb.typing import pytest from conftest import NumpyPandas, ArrowPandas @@ -113,9 +114,28 @@ def test_readonly_properties(self): duckdb.execute("select 42") description = duckdb.description() rowcount = duckdb.rowcount() - assert description == [('42', 'NUMBER', None, None, None, None, None)] + assert description == [('42', 'INTEGER', None, None, None, None, None)] assert rowcount == -1 + def test_description(self): + duckdb.execute("select 42 a, 'test' b, true c") + types = [x[1] for x in duckdb.description()] + + STRING = duckdb.STRING + NUMBER = duckdb.NUMBER + DATETIME = duckdb.DATETIME + + assert(types[1] == STRING) + assert(STRING == types[1]) + assert(types[0] != STRING) + assert((types[1] != STRING) == False) + assert((STRING != types[1]) == False) + + assert(types[1] in [STRING]) + assert(types[1] in [STRING, NUMBER]) + assert(types[1] not in [NUMBER, DATETIME]) + + def test_execute(self): assert [([4, 2],)] == duckdb.execute("select [4,2]").fetchall() From 746784fa2fb300bbb6632e14b4f24ad46164fdbf Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Sep 2025 10:05:56 +0200 Subject: [PATCH 166/472] fix tests --- tests/fast/api/test_3728.py | 2 +- tests/fast/api/test_dbapi10.py | 33 +++++++++++++++---- tests/fast/api/test_duckdb_connection.py | 22 ------------- .../relational_api/test_rapi_description.py | 3 +- tests/fast/test_map.py | 4 +-- tests/fast/test_result.py | 8 ++--- 6 files changed, 35 insertions(+), 37 deletions(-) diff --git a/tests/fast/api/test_3728.py b/tests/fast/api/test_3728.py index da0d2015..2df3c156 100644 --- a/tests/fast/api/test_3728.py +++ b/tests/fast/api/test_3728.py @@ -14,6 +14,6 @@ def test_3728_describe_enum(self, duckdb_cursor): # This fails with "RuntimeError: Not implemented Error: unsupported type: mood" assert cursor.table("person").execute().description == [ - ('name', 'STRING', None, None, None, None, None), + ('name', 'VARCHAR', None, None, None, None, None), ('current_mood', "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), ] diff --git a/tests/fast/api/test_dbapi10.py b/tests/fast/api/test_dbapi10.py index aec4b24a..1fbde602 100644 --- a/tests/fast/api/test_dbapi10.py +++ b/tests/fast/api/test_dbapi10.py @@ -1,19 +1,20 @@ # cursor description from datetime import datetime, date from pytest import mark +import duckdb class TestCursorDescription(object): @mark.parametrize( "query,column_name,string_type,real_type", [ - ["SELECT * FROM integers", "i", "NUMBER", int], - ["SELECT * FROM timestamps", "t", "DATETIME", datetime], - ["SELECT DATE '1992-09-20' AS date_col;", "date_col", "Date", date], - ["SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BINARY", bytes], - ["SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", "struct_col", "dict", dict], - ["SELECT [1, 2, 3] AS list_col", "list_col", "list", list], - ["SELECT 'Frank' AS str_col", "str_col", "STRING", str], + ["SELECT * FROM integers", "i", "INTEGER", int], + ["SELECT * FROM timestamps", "t", "TIMESTAMP", datetime], + ["SELECT DATE '1992-09-20' AS date_col;", "date_col", "DATE", date], + ["SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BLOB", bytes], + ["SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", "struct_col", "STRUCT(x INTEGER, y INTEGER, z INTEGER)", dict], + ["SELECT [1, 2, 3] AS list_col", "list_col", "INTEGER[]", list], + ["SELECT 'Frank' AS str_col", "str_col", "VARCHAR", str], ["SELECT [1, 2, 3]::JSON AS json_col", "json_col", "JSON", str], ["SELECT union_value(tag := 1) AS union_col", "union_col", "UNION(tag INTEGER)", int], ], @@ -23,6 +24,24 @@ def test_description(self, query, column_name, string_type, real_type, duckdb_cu assert duckdb_cursor.description == [(column_name, string_type, None, None, None, None, None)] assert isinstance(duckdb_cursor.fetchone()[0], real_type) + def test_description_comparisons(self): + duckdb.execute("select 42 a, 'test' b, true c") + types = [x[1] for x in duckdb.description()] + + STRING = duckdb.STRING + NUMBER = duckdb.NUMBER + DATETIME = duckdb.DATETIME + + assert(types[1] == STRING) + assert(STRING == types[1]) + assert(types[0] != STRING) + assert((types[1] != STRING) == False) + assert((STRING != types[1]) == False) + + assert(types[1] in [STRING]) + assert(types[1] in [STRING, NUMBER]) + assert(types[1] not in [NUMBER, DATETIME]) + def test_none_description(self, duckdb_empty_cursor): assert duckdb_empty_cursor.description is None diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 7f0cd390..4cb565c1 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -117,25 +117,6 @@ def test_readonly_properties(self): assert description == [('42', 'INTEGER', None, None, None, None, None)] assert rowcount == -1 - def test_description(self): - duckdb.execute("select 42 a, 'test' b, true c") - types = [x[1] for x in duckdb.description()] - - STRING = duckdb.STRING - NUMBER = duckdb.NUMBER - DATETIME = duckdb.DATETIME - - assert(types[1] == STRING) - assert(STRING == types[1]) - assert(types[0] != STRING) - assert((types[1] != STRING) == False) - assert((STRING != types[1]) == False) - - assert(types[1] in [STRING]) - assert(types[1] in [STRING, NUMBER]) - assert(types[1] not in [NUMBER, DATETIME]) - - def test_execute(self): assert [([4, 2],)] == duckdb.execute("select [4,2]").fetchall() @@ -369,9 +350,6 @@ def test_view(self): assert [([0, 1, 2, 3, 4],)] == duckdb.view("vw").fetchall() duckdb.execute("drop view vw") - def test_description(self): - assert None != duckdb.description - def test_close(self): assert None != duckdb.close diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 8738b30a..01c8a460 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -9,7 +9,8 @@ def test_rapi_description(self, duckdb_cursor): names = [x[0] for x in desc] types = [x[1] for x in desc] assert names == ['a', 'b'] - assert types == ['NUMBER', 'NUMBER'] + assert types == ['INTEGER', 'BIGINT'] + assert (all([x == duckdb.NUMBER for x in types])) def test_rapi_describe(self, duckdb_cursor): np = pytest.importorskip("numpy") diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index 894f1050..4dbd1a36 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -154,9 +154,9 @@ def cast_to_string(df): con = duckdb.connect() rel = con.sql('select i from range (10) tbl(i)') - assert rel.types[0] == int + assert rel.types[0] == duckdb.NUMBER mapped_rel = rel.map(cast_to_string, schema={'i': str}) - assert mapped_rel.types[0] == str + assert mapped_rel.types[0] == duckdb.STRING def test_explicit_schema_returntype_mismatch(self): def does_nothing(df): diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index 34c8e187..af68e268 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -31,9 +31,9 @@ def test_result_describe_types(self, duckdb_cursor): rel = connection.table("test") res = rel.execute() assert res.description == [ - ('i', 'bool', None, None, None, None, None), - ('j', 'Time', None, None, None, None, None), - ('k', 'STRING', None, None, None, None, None), + ('i', 'BOOLEAN', None, None, None, None, None), + ('j', 'TIME', None, None, None, None, None), + ('k', 'VARCHAR', None, None, None, None, None), ] def test_result_timestamps(self, duckdb_cursor): @@ -64,7 +64,7 @@ def test_result_interval(self): rel = connection.table("intervals") res = rel.execute() - assert res.description == [('ivals', 'TIMEDELTA', None, None, None, None, None)] + assert res.description == [('ivals', 'INTERVAL', None, None, None, None, None)] assert res.fetchall() == [ (datetime.timedelta(days=1.0),), (datetime.timedelta(seconds=2.0),), From 657b825a59eb2d3f825ae97d3cb0e365206875ed Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Sep 2025 10:18:48 +0200 Subject: [PATCH 167/472] add the remaining types, remove dead code --- duckdb/__init__.py | 30 +++++++++++++++++-- src/duckdb_py/pyresult.cpp | 60 -------------------------------------- 2 files changed, 27 insertions(+), 63 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index c174c30e..a73f3e91 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -32,9 +32,33 @@ def __repr__(self): return f"" # Define the standard DBAPI sentinels -STRING = DBAPITypeObject("STRING", {"VARCHAR", "CHAR", "TEXT"}) -NUMBER = DBAPITypeObject("NUMBER", {"INTEGER", "BIGINT", "DECIMAL", "DOUBLE"}) -DATETIME = DBAPITypeObject("DATETIME", {"DATE", "TIME", "TIMESTAMP"}) +STRING = DBAPITypeObject("STRING", {"VARCHAR"}) +NUMBER = DBAPITypeObject("NUMBER", { + "TINYINT", + "UTINYINT", + "SMALLINT", + "USMALLINT", + "INTEGER", + "UINTEGER", + "BIGINT", + "UBIGINT", + "HUGEINT", + "UHUGEINT", + "BIGNUM", + "DECIMAL", + "FLOAT", + "DOUBLE" +}) +DATETIME = DBAPITypeObject("DATETIME", { + "DATE", + "TIME", + "TIME_TZ", + "TIMESTAMP", + "TIMESTAMP_TZ", + "TIMESTAMP_NS", + "TIMESTAMP_MS", + "TIMESTAMP_SEC" +}) BINARY = DBAPITypeObject("BINARY", {"BLOB"}) ROWID = None diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index 68a569a4..a2607a12 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -514,66 +514,6 @@ py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) { return py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor); } -py::str GetTypeToPython(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return py::str("bool"); - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: { - return py::str("NUMBER"); - } - case LogicalTypeId::VARCHAR: { - if (type.HasAlias() && type.GetAlias() == "JSON") { - return py::str("JSON"); - } else { - return py::str("STRING"); - } - } - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - return py::str("BINARY"); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_SEC: { - return py::str("DATETIME"); - } - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: { - return py::str("Time"); - } - case LogicalTypeId::DATE: { - return py::str("Date"); - } - case LogicalTypeId::STRUCT: - case LogicalTypeId::MAP: - return py::str("dict"); - case LogicalTypeId::LIST: { - return py::str("list"); - } - case LogicalTypeId::INTERVAL: { - return py::str("TIMEDELTA"); - } - case LogicalTypeId::UUID: { - return py::str("UUID"); - } - default: - return py::str(type.ToString()); - } -} - py::list DuckDBPyResult::GetDescription(const vector &names, const vector &types) { py::list desc; From 87fa42d904b5a92be6b8eff64d59c3dcad553ab0 Mon Sep 17 00:00:00 2001 From: Thijs Date: Wed, 10 Sep 2025 15:15:13 +0200 Subject: [PATCH 168/472] Update duckdb/__init__.py Co-authored-by: Evert Lammerts --- duckdb/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index a73f3e91..2d9cabef 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -19,9 +19,9 @@ def version(): ]) class DBAPITypeObject: - def __init__(self, name, types): + def __init__(self, name: str, types: set[str]) -> None: self.name = name - self.types = list(types) + self.types = types def __eq__(self, other): if isinstance(other, typing.DuckDBPyType): From b9c0b8043ef950d1ecd11d6fb3144a2d9e9ec940 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Sep 2025 15:25:05 +0200 Subject: [PATCH 169/472] use constants, better __repr__ implementation --- duckdb/__init__.py | 61 +++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index 2d9cabef..b5e994fa 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -19,8 +19,7 @@ def version(): ]) class DBAPITypeObject: - def __init__(self, name: str, types: set[str]) -> None: - self.name = name + def __init__(self, types: list[typing.DuckDBPyType]) -> None: self.types = types def __eq__(self, other): @@ -29,37 +28,37 @@ def __eq__(self, other): return False def __repr__(self): - return f"" + return f"" # Define the standard DBAPI sentinels -STRING = DBAPITypeObject("STRING", {"VARCHAR"}) -NUMBER = DBAPITypeObject("NUMBER", { - "TINYINT", - "UTINYINT", - "SMALLINT", - "USMALLINT", - "INTEGER", - "UINTEGER", - "BIGINT", - "UBIGINT", - "HUGEINT", - "UHUGEINT", - "BIGNUM", - "DECIMAL", - "FLOAT", - "DOUBLE" -}) -DATETIME = DBAPITypeObject("DATETIME", { - "DATE", - "TIME", - "TIME_TZ", - "TIMESTAMP", - "TIMESTAMP_TZ", - "TIMESTAMP_NS", - "TIMESTAMP_MS", - "TIMESTAMP_SEC" -}) -BINARY = DBAPITypeObject("BINARY", {"BLOB"}) +STRING = DBAPITypeObject([typing.VARCHAR]) +NUMBER = DBAPITypeObject([ + typing.TINYINT, + typing.UTINYINT, + typing.SMALLINT, + typing.USMALLINT, + typing.INTEGER, + typing.UINTEGER, + typing.BIGINT, + typing.UBIGINT, + typing.HUGEINT, + typing.UHUGEINT, + typing.DuckDBPyType("BIGNUM"), + typing.DuckDBPyType("DECIMAL"), + typing.FLOAT, + typing.DOUBLE +]) +DATETIME = DBAPITypeObject([ + typing.DATE, + typing.TIME, + typing.TIME_TZ, + typing.TIMESTAMP, + typing.TIMESTAMP_TZ, + typing.TIMESTAMP_NS, + typing.TIMESTAMP_MS, + typing.TIMESTAMP_S +]) +BINARY = DBAPITypeObject([typing.BLOB]) ROWID = None # Classes From 292da0d8bca5415ce2e913978eade8d76eb9cff8 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Sep 2025 16:14:18 +0200 Subject: [PATCH 170/472] detect the error and throw before trying to execute, removing the 'Attempting to execute and unsuccesful or closed pending query result' --- src/duckdb_py/pyconnection.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 94745b75..b88b88ed 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -465,6 +465,9 @@ shared_ptr DuckDBPyConnection::ExecuteMany(const py::object unique_ptr DuckDBPyConnection::CompletePendingQuery(PendingQueryResult &pending_query) { PendingExecutionResult execution_result; + if (pending_query.HasError()) { + pending_query.ThrowError(); + } while (!PendingQueryResult::IsResultReady(execution_result = pending_query.ExecuteTask())) { { py::gil_scoped_acquire gil; From 6f835e663316d8bed52c84fdfa3a9a0cee943aa1 Mon Sep 17 00:00:00 2001 From: Tishj Date: Tue, 9 Sep 2025 17:45:42 +0200 Subject: [PATCH 171/472] update tests --- tests/fast/udf/test_remove_function.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index bb917400..15dd6b2b 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -52,7 +52,7 @@ def func(x: int) -> int: Error: Catalog Error: Scalar Function with name func does not exist! """ with pytest.raises( - duckdb.InvalidInputException, match='Attempting to execute an unsuccessful or closed pending query result' + duckdb.CatalogException, match='Scalar Function with name func does not exist!' ): res = rel.fetchall() @@ -72,7 +72,7 @@ def also_func(x: int) -> int: return x con.create_function('func', also_func) - with pytest.raises(duckdb.InvalidInputException, match='No function matches the given name'): + with pytest.raises(duckdb.BinderException, match='No function matches the given name'): res = rel2.fetchall() def test_overwrite_name(self): @@ -98,7 +98,7 @@ def other_func(x): con.remove_function('func') with pytest.raises( - duckdb.InvalidInputException, match='Catalog Error: Scalar Function with name func does not exist!' + duckdb.CatalogException, match='Catalog Error: Scalar Function with name func does not exist!' ): # Attempted to execute the relation using the 'func' function, but it was deleted rel1.fetchall() From 72428ed9f425b06eea281f74b0b07ce7dfc6535b Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Sep 2025 16:26:34 +0200 Subject: [PATCH 172/472] use py::is_operator(), to keep the NotImplemented exception behavior, but with a much smaller change --- src/duckdb_py/typing/pytype.cpp | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index f679896b..009e3dab 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -325,18 +325,9 @@ static LogicalType FromObject(const py::object &object) { void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); - type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object") - .def("__eq__", [](const DuckDBPyType &self, py::handle other) -> py::object { - if (py::isinstance(other)) { - return py::bool_(self.Equals(other.cast>())); - } - else if (py::isinstance(other)) { - return py::bool_(self.EqualsString(other.cast())); - } - // Return NotImplemented for unsupported types - return py::reinterpret_borrow(Py_NotImplemented); - }); - type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other")); + type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); + type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { @@ -356,7 +347,7 @@ void DuckDBPyType::Initialize(py::handle &m) { return make_shared_ptr(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); - type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); + type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), py::is_operator()); py::implicitly_convertible(); py::implicitly_convertible(); From 7b279b947678a2a4fe9f623ad287ef8473239426 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 10 Sep 2025 16:36:39 +0200 Subject: [PATCH 173/472] add 'py::is_operator()' where appropriate --- src/duckdb_py/pyexpression/initialize.cpp | 57 +++++++++++------------ 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/src/duckdb_py/pyexpression/initialize.cpp b/src/duckdb_py/pyexpression/initialize.cpp index 41f417a4..2d2d6af9 100644 --- a/src/duckdb_py/pyexpression/initialize.cpp +++ b/src/duckdb_py/pyexpression/initialize.cpp @@ -61,9 +61,9 @@ static void InitializeDunderMethods(py::class_' expr )"; - m.def("__gt__", &DuckDBPyExpression::GreaterThan, docs); + m.def("__gt__", &DuckDBPyExpression::GreaterThan, docs, py::is_operator()); docs = R"( Create a greater than or equal expression between two expressions @@ -198,7 +197,7 @@ static void InitializeDunderMethods(py::class_=' expr )"; - m.def("__ge__", &DuckDBPyExpression::GreaterThanOrEqual, docs); + m.def("__ge__", &DuckDBPyExpression::GreaterThanOrEqual, docs, py::is_operator()); docs = R"( Create a less than expression between two expressions @@ -209,7 +208,7 @@ static void InitializeDunderMethods(py::class_> &m) { From 01c972f611e3fc23ee180cc17bd50480eef4ce6a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Thu, 11 Sep 2025 11:23:55 +0200 Subject: [PATCH 174/472] Bumped submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 16671635..1a492426 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 166716352edf482a378ade0f2990abb7237ae841 +Subproject commit 1a49242623f94af2c7da399c616295a00a700078 From ecc372ceebbd83611cae07b3fbfc93c8ffba7427 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 15 Sep 2025 08:38:00 +0200 Subject: [PATCH 175/472] Bumped submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index 1a492426..b3edbac8 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 1a49242623f94af2c7da399c616295a00a700078 +Subproject commit b3edbac8519f8ed04f58a6f30ec349112bdc7d6c From 2c6829ebe8431c1f4b026edc6e98c873d5b065a2 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 15 Sep 2025 10:07:31 +0200 Subject: [PATCH 176/472] Bumped submodule --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index c11813f1..25ebb000 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit c11813f1b4e89aa7096b8d18ca9eb608e27168a8 +Subproject commit 25ebb000e3f18e6346ac7a600280b7eb18624ed1 From 2a37b7c5515c5628a7df67d77d2e341de4a2d14f Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Sep 2025 11:09:32 +0200 Subject: [PATCH 177/472] pandas is not required to sniff objects --- src/duckdb_py/pandas/analyzer.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/duckdb_py/pandas/analyzer.cpp b/src/duckdb_py/pandas/analyzer.cpp index 98fa25cc..ee264524 100644 --- a/src/duckdb_py/pandas/analyzer.cpp +++ b/src/duckdb_py/pandas/analyzer.cpp @@ -502,12 +502,6 @@ bool PandasAnalyzer::Analyze(py::object column) { if (sample_size == 0) { return false; } - auto &import_cache = *DuckDBPyConnection::ImportCache(); - auto pandas = import_cache.pandas(); - if (!pandas) { - //! Pandas is not installed, no need to analyze - return false; - } bool can_convert = true; idx_t increment = GetSampleIncrement(py::len(column)); From 6dde3db33b9f7093a0a3abda752e2a1861046f3e Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Mon, 15 Sep 2025 13:25:32 +0200 Subject: [PATCH 178/472] Make sure we set CIBW_ENVIRONMENT using bash or Windows will not see it --- .github/workflows/packaging.yml | 4 ++-- .github/workflows/packaging_wheels.yml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 25cf3bdd..507c7bda 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -68,7 +68,7 @@ jobs: testsuite: all duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} duckdb-sha: ${{ inputs.duckdb-sha }} - set-version: ${{ inputs.stable-version }} + set-version: ${{ inputs.set-version }} build_wheels: name: Build and test releases @@ -78,4 +78,4 @@ jobs: testsuite: all duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} duckdb-sha: ${{ inputs.duckdb-sha }} - set-version: ${{ inputs.stable-version }} + set-version: ${{ inputs.set-version }} diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index 4c7599a6..b1a393a1 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -73,6 +73,7 @@ jobs: # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds - name: Set OVERRIDE_GIT_DESCRIBE + shell: bash if: ${{ inputs.set-version != '' }} run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV From 95c8f2b49bc92183b4baf258ff10fbe2e7f622df Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Sep 2025 14:41:48 +0200 Subject: [PATCH 179/472] separate the pyarrow filter pushdown to separate file, cleanup direct imports --- Makefile | 4 + scripts/cache_data.json | 86 ++++- scripts/generate_import_cache_cpp.py | 8 +- scripts/generate_import_cache_json.py | 1 - scripts/imports.py | 16 + src/duckdb_py/arrow/CMakeLists.txt | 3 +- src/duckdb_py/arrow/arrow_array_stream.cpp | 340 +----------------- .../arrow/pyarrow_filter_pushdown.cpp | 336 +++++++++++++++++ .../arrow/arrow_array_stream.hpp | 5 - .../arrow/pyarrow_filter_pushdown.hpp | 26 ++ .../import_cache/modules/pyarrow_module.hpp | 17 +- src/duckdb_py/pyrelation/initialize.cpp | 22 +- src/duckdb_py/typing/pytype.cpp | 9 +- tests/fast/api/test_dbapi10.py | 23 +- .../relational_api/test_rapi_description.py | 2 +- tests/fast/udf/test_remove_function.py | 4 +- 16 files changed, 530 insertions(+), 372 deletions(-) create mode 100644 Makefile create mode 100644 src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp create mode 100644 src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..07008f11 --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ +PYTHON ?= python3 + +format-main: + $(PYTHON) external/duckdb/scripts/format.py main --fix --noconfirm \ No newline at end of file diff --git a/scripts/cache_data.json b/scripts/cache_data.json index 640052cd..3dd9a1f1 100644 --- a/scripts/cache_data.json +++ b/scripts/cache_data.json @@ -7,7 +7,19 @@ "pyarrow.dataset", "pyarrow.Table", "pyarrow.RecordBatchReader", - "pyarrow.ipc" + "pyarrow.ipc", + "pyarrow.scalar", + "pyarrow.date32", + "pyarrow.time64", + "pyarrow.timestamp", + "pyarrow.uint8", + "pyarrow.uint16", + "pyarrow.uint32", + "pyarrow.uint64", + "pyarrow.binary_view", + "pyarrow.decimal32", + "pyarrow.decimal64", + "pyarrow.decimal128" ] }, "pyarrow.dataset": { @@ -709,5 +721,77 @@ "name": "duckdb_source", "children": [], "required": false + }, + "pyarrow.scalar": { + "type": "attribute", + "full_path": "pyarrow.scalar", + "name": "scalar", + "children": [] + }, + "pyarrow.date32": { + "type": "attribute", + "full_path": "pyarrow.date32", + "name": "date32", + "children": [] + }, + "pyarrow.time64": { + "type": "attribute", + "full_path": "pyarrow.time64", + "name": "time64", + "children": [] + }, + "pyarrow.timestamp": { + "type": "attribute", + "full_path": "pyarrow.timestamp", + "name": "timestamp", + "children": [] + }, + "pyarrow.uint8": { + "type": "attribute", + "full_path": "pyarrow.uint8", + "name": "uint8", + "children": [] + }, + "pyarrow.uint16": { + "type": "attribute", + "full_path": "pyarrow.uint16", + "name": "uint16", + "children": [] + }, + "pyarrow.uint32": { + "type": "attribute", + "full_path": "pyarrow.uint32", + "name": "uint32", + "children": [] + }, + "pyarrow.uint64": { + "type": "attribute", + "full_path": "pyarrow.uint64", + "name": "uint64", + "children": [] + }, + "pyarrow.binary_view": { + "type": "attribute", + "full_path": "pyarrow.binary_view", + "name": "binary_view", + "children": [] + }, + "pyarrow.decimal32": { + "type": "attribute", + "full_path": "pyarrow.decimal32", + "name": "decimal32", + "children": [] + }, + "pyarrow.decimal64": { + "type": "attribute", + "full_path": "pyarrow.decimal64", + "name": "decimal64", + "children": [] + }, + "pyarrow.decimal128": { + "type": "attribute", + "full_path": "pyarrow.decimal128", + "name": "decimal128", + "children": [] } } \ No newline at end of file diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index f902c5a5..07744e37 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -182,7 +182,7 @@ def to_string(self): for file in files: content = file.to_string() - path = f'src/include/duckdb_python/import_cache/modules/{file.file_name}' + path = f'src/duckdb_py/include/duckdb_python/import_cache/modules/{file.file_name}' import_cache_path = os.path.join(script_dir, '..', path) with open(import_cache_path, "w") as f: f.write(content) @@ -237,7 +237,9 @@ def get_root_modules(files: List[ModuleFile]): """ -import_cache_path = os.path.join(script_dir, '..', 'src/include/duckdb_python/import_cache/python_import_cache.hpp') +import_cache_path = os.path.join( + script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache.hpp' +) with open(import_cache_path, "w") as f: f.write(import_cache_file) @@ -252,7 +254,7 @@ def get_module_file_path_includes(files: List[ModuleFile]): module_includes = get_module_file_path_includes(files) modules_header = os.path.join( - script_dir, '..', 'src/include/duckdb_python/import_cache/python_import_cache_modules.hpp' + script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache_modules.hpp' ) with open(modules_header, "w") as f: f.write(module_includes) diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 7a59e6b7..40e6a773 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -170,7 +170,6 @@ def update_json(existing: dict, new: dict) -> dict: # If both values are dictionaries, update recursively. if isinstance(new_value, dict) and isinstance(old_value, dict): - print(key) updated = update_json(old_value, new_value) existing[key] = updated else: diff --git a/scripts/imports.py b/scripts/imports.py index 6b035768..c51f53b7 100644 --- a/scripts/imports.py +++ b/scripts/imports.py @@ -6,6 +6,22 @@ pyarrow.Table pyarrow.RecordBatchReader pyarrow.ipc.MessageReader +pyarrow.scalar +pyarrow.date32 +pyarrow.time64 +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.timestamp +pyarrow.uint8 +pyarrow.uint16 +pyarrow.uint32 +pyarrow.uint64 +pyarrow.binary_view +pyarrow.decimal32 +pyarrow.decimal64 +pyarrow.decimal128 import pandas diff --git a/src/duckdb_py/arrow/CMakeLists.txt b/src/duckdb_py/arrow/CMakeLists.txt index 29b188c6..9a9188b8 100644 --- a/src/duckdb_py/arrow/CMakeLists.txt +++ b/src/duckdb_py/arrow/CMakeLists.txt @@ -1,4 +1,5 @@ # this is used for clang-tidy checks -add_library(python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp) +add_library(python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp + pyarrow_filter_pushdown.cpp) target_link_libraries(python_arrow PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index 533c31ed..f9cfd1bb 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -1,22 +1,15 @@ #include "duckdb_python/arrow/arrow_array_stream.hpp" +#include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/planner/filter/in_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyrelation.hpp" +#include "duckdb_python/pyresult.hpp" +#include "duckdb/function/table/arrow.hpp" #include "duckdb/common/assert.hpp" #include "duckdb/common/common.hpp" #include "duckdb/common/limits.hpp" #include "duckdb/main/client_config.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" -#include "duckdb/planner/table_filter.hpp" - -#include "duckdb_python/pyconnection/pyconnection.hpp" -#include "duckdb_python/pyrelation.hpp" -#include "duckdb_python/pyresult.hpp" -#include "duckdb/function/table/arrow.hpp" namespace duckdb { @@ -56,8 +49,8 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(DBConfig &config, } if (has_filter) { - auto filter = TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col, - client_properties, arrow_table); + auto filter = PyArrowFilterPushdown::TransformFilter(*filters, parameters.projected_columns.projection_map, + filter_to_col, client_properties, arrow_table); if (!filter.is(py::none())) { kwargs["filter"] = filter; } @@ -171,323 +164,4 @@ void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowS GetSchemaInternal(arrow_obj_handle, schema); } -string ConvertTimestampUnit(ArrowDateTimeType unit) { - switch (unit) { - case ArrowDateTimeType::MICROSECONDS: - return "us"; - case ArrowDateTimeType::MILLISECONDS: - return "ms"; - case ArrowDateTimeType::NANOSECONDS: - return "ns"; - case ArrowDateTimeType::SECONDS: - return "s"; - default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); - } -} - -int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_type) { - auto input = timestamp_t(base_value); - if (!Timestamp::IsFinite(input)) { - return base_value; - } - - switch (datetime_type) { - case ArrowDateTimeType::MICROSECONDS: - return Timestamp::GetEpochMicroSeconds(input); - case ArrowDateTimeType::MILLISECONDS: - return Timestamp::GetEpochMs(input); - case ArrowDateTimeType::NANOSECONDS: - return Timestamp::GetEpochNanoSeconds(input); - case ArrowDateTimeType::SECONDS: - return Timestamp::GetEpochSeconds(input); - default: - throw NotImplementedException("DatetimeType not recognized in ConvertTimestampTZValue"); - } -} - -py::object GetScalar(Value &constant, const string &timezone_config, const ArrowType &type) { - py::object scalar = py::module_::import("pyarrow").attr("scalar"); - auto &import_cache = *DuckDBPyConnection::ImportCache(); - py::object dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); - py::object scalar_value; - switch (constant.type().id()) { - case LogicalTypeId::BOOLEAN: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::TINYINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::SMALLINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::INTEGER: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::BIGINT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::DATE: { - py::object date_type = py::module_::import("pyarrow").attr("date32"); - return dataset_scalar(scalar(constant.GetValue(), date_type())); - } - case LogicalTypeId::TIME: { - py::object date_type = py::module_::import("pyarrow").attr("time64"); - return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); - } - case LogicalTypeId::TIMESTAMP: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); - } - case LogicalTypeId::TIMESTAMP_MS: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("ms"))); - } - case LogicalTypeId::TIMESTAMP_NS: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("ns"))); - } - case LogicalTypeId::TIMESTAMP_SEC: { - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(constant.GetValue(), date_type("s"))); - } - case LogicalTypeId::TIMESTAMP_TZ: { - auto &datetime_info = type.GetTypeInfo(); - auto base_value = constant.GetValue(); - auto arrow_datetime_type = datetime_info.GetDateTimeType(); - auto time_unit_string = ConvertTimestampUnit(arrow_datetime_type); - auto converted_value = ConvertTimestampTZValue(base_value, arrow_datetime_type); - py::object date_type = py::module_::import("pyarrow").attr("timestamp"); - return dataset_scalar(scalar(converted_value, date_type(time_unit_string, py::arg("tz") = timezone_config))); - } - case LogicalTypeId::UTINYINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint8"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::USMALLINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint16"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::UINTEGER: { - py::object integer_type = py::module_::import("pyarrow").attr("uint32"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::UBIGINT: { - py::object integer_type = py::module_::import("pyarrow").attr("uint64"); - return dataset_scalar(scalar(constant.GetValue(), integer_type())); - } - case LogicalTypeId::FLOAT: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::DOUBLE: - return dataset_scalar(constant.GetValue()); - case LogicalTypeId::VARCHAR: - return dataset_scalar(constant.ToString()); - case LogicalTypeId::BLOB: { - if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { - py::object binary_view_type = py::module_::import("pyarrow").attr("binary_view"); - return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); - } - return dataset_scalar(py::bytes(constant.GetValueUnsafe())); - } - case LogicalTypeId::DECIMAL: { - py::object decimal_type; - auto &datetime_info = type.GetTypeInfo(); - auto bit_width = datetime_info.GetBitWidth(); - switch (bit_width) { - case DecimalBitWidth::DECIMAL_32: - decimal_type = py::module_::import("pyarrow").attr("decimal32"); - break; - case DecimalBitWidth::DECIMAL_64: - decimal_type = py::module_::import("pyarrow").attr("decimal64"); - break; - case DecimalBitWidth::DECIMAL_128: - decimal_type = py::module_::import("pyarrow").attr("decimal128"); - break; - default: - throw NotImplementedException("Unsupported precision for Arrow Decimal Type."); - } - - uint8_t width; - uint8_t scale; - constant.type().GetDecimalProperties(width, scale); - // pyarrow only allows 'decimal.Decimal' to be used to construct decimal scalars such as 0.05 - auto val = import_cache.decimal.Decimal()(constant.ToString()); - return dataset_scalar( - scalar(std::move(val), decimal_type(py::arg("precision") = width, py::arg("scale") = scale))); - } - default: - throw NotImplementedException("Unimplemented type \"%s\" for Arrow Filter Pushdown", - constant.type().ToString()); - } -} - -py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, - const ArrowType &type) { - auto &import_cache = *DuckDBPyConnection::ImportCache(); - py::object field = import_cache.pyarrow.dataset().attr("field"); - switch (filter.filter_type) { - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = filter.Cast(); - auto constant_field = field(py::tuple(py::cast(column_ref))); - auto constant_value = GetScalar(constant_filter.constant, timezone_config, type); - - bool is_nan = false; - auto &constant = constant_filter.constant; - auto &constant_type = constant.type(); - if (constant_type.id() == LogicalTypeId::FLOAT) { - is_nan = Value::IsNan(constant.GetValue()); - } else if (constant_type.id() == LogicalTypeId::DOUBLE) { - is_nan = Value::IsNan(constant.GetValue()); - } - - // Special handling for NaN comparisons (to explicitly violate IEEE-754) - if (is_nan) { - switch (constant_filter.comparison_type) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return constant_field.attr("is_nan")(); - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_NOTEQUAL: - return constant_field.attr("is_nan")().attr("__invert__")(); - case ExpressionType::COMPARE_GREATERTHAN: - // Nothing is greater than NaN - return import_cache.pyarrow.dataset().attr("scalar")(false); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - // Everything is less than or equal to NaN - return import_cache.pyarrow.dataset().attr("scalar")(true); - default: - throw NotImplementedException("Unsupported comparison type (%s) for NaN values", - EnumUtil::ToString(constant_filter.comparison_type)); - } - } - - switch (constant_filter.comparison_type) { - case ExpressionType::COMPARE_EQUAL: - return constant_field.attr("__eq__")(constant_value); - case ExpressionType::COMPARE_LESSTHAN: - return constant_field.attr("__lt__")(constant_value); - case ExpressionType::COMPARE_GREATERTHAN: - return constant_field.attr("__gt__")(constant_value); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return constant_field.attr("__le__")(constant_value); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return constant_field.attr("__ge__")(constant_value); - case ExpressionType::COMPARE_NOTEQUAL: - return constant_field.attr("__ne__")(constant_value); - default: - throw NotImplementedException("Comparison Type %s can't be an Arrow Scan Pushdown Filter", - EnumUtil::ToString(constant_filter.comparison_type)); - } - } - //! We do not pushdown is null yet - case TableFilterType::IS_NULL: { - auto constant_field = field(py::tuple(py::cast(column_ref))); - return constant_field.attr("is_null")(); - } - case TableFilterType::IS_NOT_NULL: { - auto constant_field = field(py::tuple(py::cast(column_ref))); - return constant_field.attr("is_valid")(); - } - //! We do not pushdown or conjunctions yet - case TableFilterType::CONJUNCTION_OR: { - auto &or_filter = filter.Cast(); - py::object expression = py::none(); - for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { - auto &child_filter = *or_filter.child_filters[i]; - py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); - if (child_expression.is(py::none())) { - continue; - } - if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__or__")(child_expression); - } - } - return expression; - } - case TableFilterType::CONJUNCTION_AND: { - auto &and_filter = filter.Cast(); - py::object expression = py::none(); - for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { - auto &child_filter = *and_filter.child_filters[i]; - py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); - if (child_expression.is(py::none())) { - continue; - } - if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__and__")(child_expression); - } - } - return expression; - } - case TableFilterType::STRUCT_EXTRACT: { - auto &struct_filter = filter.Cast(); - auto &child_name = struct_filter.child_name; - auto &struct_type_info = type.GetTypeInfo(); - auto &struct_child_type = struct_type_info.GetChild(struct_filter.child_idx); - - column_ref.push_back(child_name); - auto child_expr = TransformFilterRecursive(*struct_filter.child_filter, std::move(column_ref), timezone_config, - struct_child_type); - return child_expr; - } - case TableFilterType::OPTIONAL_FILTER: { - auto &optional_filter = filter.Cast(); - if (!optional_filter.child_filter) { - return py::none(); - } - return TransformFilterRecursive(*optional_filter.child_filter, column_ref, timezone_config, type); - } - case TableFilterType::IN_FILTER: { - auto &in_filter = filter.Cast(); - ConjunctionOrFilter or_filter; - value_set_t unique_values; - for (const auto &value : in_filter.values) { - if (unique_values.find(value) == unique_values.end()) { - unique_values.insert(value); - } - } - for (const auto &value : unique_values) { - or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); - } - return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); - } - case TableFilterType::DYNAMIC_FILTER: { - //! Ignore dynamic filters for now, not necessary for correctness - return py::none(); - } - default: - throw NotImplementedException("Pushdown Filter Type %s is not currently supported in PyArrow Scans", - EnumUtil::ToString(filter.filter_type)); - } -} - -py::object PythonTableArrowArrayStreamFactory::TransformFilter(TableFilterSet &filter_collection, - std::unordered_map &columns, - unordered_map filter_to_col, - const ClientProperties &config, - const ArrowTableSchema &arrow_table) { - auto &filters_map = filter_collection.filters; - - py::object expression = py::none(); - for (auto &it : filters_map) { - auto column_idx = it.first; - auto &column_name = columns[column_idx]; - - vector column_ref; - column_ref.push_back(column_name); - - D_ASSERT(columns.find(column_idx) != columns.end()); - - auto &arrow_type = arrow_table.GetColumns().at(filter_to_col.at(column_idx)); - py::object child_expression = TransformFilterRecursive(*it.second, column_ref, config.time_zone, *arrow_type); - if (child_expression.is(py::none())) { - continue; - } else if (expression.is(py::none())) { - expression = std::move(child_expression); - } else { - expression = expression.attr("__and__")(child_expression); - } - } - return expression; -} - } // namespace duckdb diff --git a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp new file mode 100644 index 00000000..66a6e3fa --- /dev/null +++ b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp @@ -0,0 +1,336 @@ +#include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" + +#include "duckdb/common/types/value_map.hpp" +#include "duckdb/planner/filter/in_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/table_filter.hpp" + +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/pyrelation.hpp" +#include "duckdb_python/pyresult.hpp" +#include "duckdb/function/table/arrow.hpp" + +namespace duckdb { + +string ConvertTimestampUnit(ArrowDateTimeType unit) { + switch (unit) { + case ArrowDateTimeType::MICROSECONDS: + return "us"; + case ArrowDateTimeType::MILLISECONDS: + return "ms"; + case ArrowDateTimeType::NANOSECONDS: + return "ns"; + case ArrowDateTimeType::SECONDS: + return "s"; + default: + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit); + } +} + +int64_t ConvertTimestampTZValue(int64_t base_value, ArrowDateTimeType datetime_type) { + auto input = timestamp_t(base_value); + if (!Timestamp::IsFinite(input)) { + return base_value; + } + + switch (datetime_type) { + case ArrowDateTimeType::MICROSECONDS: + return Timestamp::GetEpochMicroSeconds(input); + case ArrowDateTimeType::MILLISECONDS: + return Timestamp::GetEpochMs(input); + case ArrowDateTimeType::NANOSECONDS: + return Timestamp::GetEpochNanoSeconds(input); + case ArrowDateTimeType::SECONDS: + return Timestamp::GetEpochSeconds(input); + default: + throw NotImplementedException("DatetimeType not recognized in ConvertTimestampTZValue"); + } +} + +py::object GetScalar(Value &constant, const string &timezone_config, const ArrowType &type) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + auto scalar = import_cache.pyarrow.scalar(); + py::handle dataset_scalar = import_cache.pyarrow.dataset().attr("scalar"); + + switch (constant.type().id()) { + case LogicalTypeId::BOOLEAN: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::TINYINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::SMALLINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::INTEGER: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::BIGINT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::DATE: { + py::handle date_type = import_cache.pyarrow.date32(); + return dataset_scalar(scalar(constant.GetValue(), date_type())); + } + case LogicalTypeId::TIME: { + py::handle date_type = import_cache.pyarrow.time64(); + return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); + } + case LogicalTypeId::TIMESTAMP: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("us"))); + } + case LogicalTypeId::TIMESTAMP_MS: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("ms"))); + } + case LogicalTypeId::TIMESTAMP_NS: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("ns"))); + } + case LogicalTypeId::TIMESTAMP_SEC: { + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(constant.GetValue(), date_type("s"))); + } + case LogicalTypeId::TIMESTAMP_TZ: { + auto &datetime_info = type.GetTypeInfo(); + auto base_value = constant.GetValue(); + auto arrow_datetime_type = datetime_info.GetDateTimeType(); + auto time_unit_string = ConvertTimestampUnit(arrow_datetime_type); + auto converted_value = ConvertTimestampTZValue(base_value, arrow_datetime_type); + py::handle date_type = import_cache.pyarrow.timestamp(); + return dataset_scalar(scalar(converted_value, date_type(time_unit_string, py::arg("tz") = timezone_config))); + } + case LogicalTypeId::UTINYINT: { + py::handle integer_type = import_cache.pyarrow.uint8(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::USMALLINT: { + py::handle integer_type = import_cache.pyarrow.uint16(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::UINTEGER: { + py::handle integer_type = import_cache.pyarrow.uint32(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::UBIGINT: { + py::handle integer_type = import_cache.pyarrow.uint64(); + return dataset_scalar(scalar(constant.GetValue(), integer_type())); + } + case LogicalTypeId::FLOAT: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::DOUBLE: + return dataset_scalar(constant.GetValue()); + case LogicalTypeId::VARCHAR: + return dataset_scalar(constant.ToString()); + case LogicalTypeId::BLOB: { + if (type.GetTypeInfo().GetSizeType() == ArrowVariableSizeType::VIEW) { + py::handle binary_view_type = import_cache.pyarrow.binary_view(); + return dataset_scalar(scalar(py::bytes(constant.GetValueUnsafe()), binary_view_type())); + } + return dataset_scalar(py::bytes(constant.GetValueUnsafe())); + } + case LogicalTypeId::DECIMAL: { + py::handle decimal_type; + auto &datetime_info = type.GetTypeInfo(); + auto bit_width = datetime_info.GetBitWidth(); + switch (bit_width) { + case DecimalBitWidth::DECIMAL_32: + decimal_type = import_cache.pyarrow.decimal32(); + break; + case DecimalBitWidth::DECIMAL_64: + decimal_type = import_cache.pyarrow.decimal64(); + break; + case DecimalBitWidth::DECIMAL_128: + decimal_type = import_cache.pyarrow.decimal128(); + break; + default: + throw NotImplementedException("Unsupported precision for Arrow Decimal Type."); + } + + uint8_t width; + uint8_t scale; + constant.type().GetDecimalProperties(width, scale); + // pyarrow only allows 'decimal.Decimal' to be used to construct decimal scalars such as 0.05 + auto val = import_cache.decimal.Decimal()(constant.ToString()); + return dataset_scalar( + scalar(std::move(val), decimal_type(py::arg("precision") = width, py::arg("scale") = scale))); + } + default: + throw NotImplementedException("Unimplemented type \"%s\" for Arrow Filter Pushdown", + constant.type().ToString()); + } +} + +py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, + const ArrowType &type) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + py::object field = import_cache.pyarrow.dataset().attr("field"); + switch (filter.filter_type) { + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + auto constant_field = field(py::tuple(py::cast(column_ref))); + auto constant_value = GetScalar(constant_filter.constant, timezone_config, type); + + bool is_nan = false; + auto &constant = constant_filter.constant; + auto &constant_type = constant.type(); + if (constant_type.id() == LogicalTypeId::FLOAT) { + is_nan = Value::IsNan(constant.GetValue()); + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { + is_nan = Value::IsNan(constant.GetValue()); + } + + // Special handling for NaN comparisons (to explicitly violate IEEE-754) + if (is_nan) { + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("is_nan")(); + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("is_nan")().attr("__invert__")(); + case ExpressionType::COMPARE_GREATERTHAN: + // Nothing is greater than NaN + return import_cache.pyarrow.dataset().attr("scalar")(false); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + // Everything is less than or equal to NaN + return import_cache.pyarrow.dataset().attr("scalar")(true); + default: + throw NotImplementedException("Unsupported comparison type (%s) for NaN values", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } + + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + return constant_field.attr("__eq__")(constant_value); + case ExpressionType::COMPARE_LESSTHAN: + return constant_field.attr("__lt__")(constant_value); + case ExpressionType::COMPARE_GREATERTHAN: + return constant_field.attr("__gt__")(constant_value); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return constant_field.attr("__le__")(constant_value); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return constant_field.attr("__ge__")(constant_value); + case ExpressionType::COMPARE_NOTEQUAL: + return constant_field.attr("__ne__")(constant_value); + default: + throw NotImplementedException("Comparison Type %s can't be an Arrow Scan Pushdown Filter", + EnumUtil::ToString(constant_filter.comparison_type)); + } + } + //! We do not pushdown is null yet + case TableFilterType::IS_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_null")(); + } + case TableFilterType::IS_NOT_NULL: { + auto constant_field = field(py::tuple(py::cast(column_ref))); + return constant_field.attr("is_valid")(); + } + //! We do not pushdown or conjunctions yet + case TableFilterType::CONJUNCTION_OR: { + auto &or_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { + auto &child_filter = *or_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__or__")(child_expression); + } + } + return expression; + } + case TableFilterType::CONJUNCTION_AND: { + auto &and_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { + auto &child_filter = *and_filter.child_filters[i]; + py::object child_expression = TransformFilterRecursive(child_filter, column_ref, timezone_config, type); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; + } + case TableFilterType::STRUCT_EXTRACT: { + auto &struct_filter = filter.Cast(); + auto &child_name = struct_filter.child_name; + auto &struct_type_info = type.GetTypeInfo(); + auto &struct_child_type = struct_type_info.GetChild(struct_filter.child_idx); + + column_ref.push_back(child_name); + auto child_expr = TransformFilterRecursive(*struct_filter.child_filter, std::move(column_ref), timezone_config, + struct_child_type); + return child_expr; + } + case TableFilterType::OPTIONAL_FILTER: { + auto &optional_filter = filter.Cast(); + if (!optional_filter.child_filter) { + return py::none(); + } + return TransformFilterRecursive(*optional_filter.child_filter, column_ref, timezone_config, type); + } + case TableFilterType::IN_FILTER: { + auto &in_filter = filter.Cast(); + ConjunctionOrFilter or_filter; + value_set_t unique_values; + for (const auto &value : in_filter.values) { + if (unique_values.find(value) == unique_values.end()) { + unique_values.insert(value); + } + } + for (const auto &value : unique_values) { + or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); + } + return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); + } + case TableFilterType::DYNAMIC_FILTER: { + //! Ignore dynamic filters for now, not necessary for correctness + return py::none(); + } + default: + throw NotImplementedException("Pushdown Filter Type %s is not currently supported in PyArrow Scans", + EnumUtil::ToString(filter.filter_type)); + } +} + +py::object PyArrowFilterPushdown::TransformFilter(TableFilterSet &filter_collection, + unordered_map &columns, + unordered_map filter_to_col, + const ClientProperties &config, const ArrowTableSchema &arrow_table) { + auto &filters_map = filter_collection.filters; + + py::object expression = py::none(); + for (auto &it : filters_map) { + auto column_idx = it.first; + auto &column_name = columns[column_idx]; + + vector column_ref; + column_ref.push_back(column_name); + + D_ASSERT(columns.find(column_idx) != columns.end()); + + auto &arrow_type = arrow_table.GetColumns().at(filter_to_col.at(column_idx)); + py::object child_expression = TransformFilterRecursive(*it.second, column_ref, config.time_zone, *arrow_type); + if (child_expression.is(py::none())) { + continue; + } else if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; +} + +} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp index 7eb6d20b..a5895b4a 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp @@ -86,11 +86,6 @@ class PythonTableArrowArrayStreamFactory { DBConfig &config; private: - //! We transform a TableFilterSet to an Arrow Expression Object - static py::object TransformFilter(TableFilterSet &filters, std::unordered_map &columns, - unordered_map filter_to_col, - const ClientProperties &client_properties, const ArrowTableSchema &arrow_table); - static py::object ProduceScanner(DBConfig &config, py::object &arrow_scanner, py::handle &arrow_obj_handle, ArrowStreamParameters ¶meters, const ClientProperties &client_properties); }; diff --git a/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp new file mode 100644 index 00000000..4cc85a47 --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/arrow/pyarrow_filter_pushdown.hpp @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/arrow/pyarrow_filter_pushdown.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/arrow/arrow_wrapper.hpp" +#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/main/client_properties.hpp" +#include "duckdb_python/pybind11/pybind_wrapper.hpp" + +namespace duckdb { + +struct PyArrowFilterPushdown { + static py::object TransformFilter(TableFilterSet &filter_collection, unordered_map &columns, + unordered_map filter_to_col, const ClientProperties &config, + const ArrowTableSchema &arrow_table); +}; + +} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp index ccd8a16d..d3331565 100644 --- a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp +++ b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp @@ -56,7 +56,10 @@ struct PyarrowCacheItem : public PythonImportCacheItem { public: PyarrowCacheItem() : PythonImportCacheItem("pyarrow"), dataset(), Table("Table", this), - RecordBatchReader("RecordBatchReader", this), ipc(this) { + RecordBatchReader("RecordBatchReader", this), ipc(this), scalar("scalar", this), date32("date32", this), + time64("time64", this), timestamp("timestamp", this), uint8("uint8", this), uint16("uint16", this), + uint32("uint32", this), uint64("uint64", this), binary_view("binary_view", this), + decimal32("decimal32", this), decimal64("decimal64", this), decimal128("decimal128", this) { } ~PyarrowCacheItem() override { } @@ -65,6 +68,18 @@ struct PyarrowCacheItem : public PythonImportCacheItem { PythonImportCacheItem Table; PythonImportCacheItem RecordBatchReader; PyarrowIpcCacheItem ipc; + PythonImportCacheItem scalar; + PythonImportCacheItem date32; + PythonImportCacheItem time64; + PythonImportCacheItem timestamp; + PythonImportCacheItem uint8; + PythonImportCacheItem uint16; + PythonImportCacheItem uint32; + PythonImportCacheItem uint64; + PythonImportCacheItem binary_view; + PythonImportCacheItem decimal32; + PythonImportCacheItem decimal64; + PythonImportCacheItem decimal128; }; } // namespace duckdb diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 7992cc17..cd1f042c 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -61,8 +61,8 @@ static void InitializeConsumers(py::class_ &m) { py::arg("date_as_object") = false) .def("fetch_df_chunk", &DuckDBPyRelation::FetchDFChunk, "Execute and fetch a chunk of the rows", py::arg("vectors_per_chunk") = 1, py::kw_only(), py::arg("date_as_object") = false) - .def("arrow", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", - py::arg("batch_size") = 1000000) + .def("arrow", &DuckDBPyRelation::ToRecordBatch, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) .def("fetch_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) .def("to_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", @@ -80,16 +80,16 @@ static void InitializeConsumers(py::class_ &m) { py::arg("requested_schema") = py::none()); m.def("fetch_record_batch", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) - .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, + .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) - .def("record_batch", - [](pybind11::object &self, idx_t rows_per_batch) - { - PyErr_WarnEx(PyExc_DeprecationWarning, - "record_batch() is deprecated, use fetch_record_batch() instead.", - 0); - return self.attr("fetch_record_batch")(rows_per_batch); - }, py::arg("batch_size") = 1000000); + .def( + "record_batch", + [](pybind11::object &self, idx_t rows_per_batch) { + PyErr_WarnEx(PyExc_DeprecationWarning, + "record_batch() is deprecated, use fetch_record_batch() instead.", 0); + return self.attr("fetch_record_batch")(rows_per_batch); + }, + py::arg("batch_size") = 1000000); } static void InitializeAggregates(py::class_ &m) { diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 009e3dab..449c4c7d 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -326,8 +326,10 @@ void DuckDBPyType::Initialize(py::handle &m) { auto type_module = py::class_>(m, "DuckDBPyType", py::module_local()); type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); - type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); - type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), + py::is_operator()); + type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), + py::is_operator()); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { @@ -347,7 +349,8 @@ void DuckDBPyType::Initialize(py::handle &m) { return make_shared_ptr(ltype); })); type_module.def("__getattr__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name")); - type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), py::is_operator()); + type_module.def("__getitem__", &DuckDBPyType::GetAttribute, "Get the child type by 'name'", py::arg("name"), + py::is_operator()); py::implicitly_convertible(); py::implicitly_convertible(); diff --git a/tests/fast/api/test_dbapi10.py b/tests/fast/api/test_dbapi10.py index 1fbde602..0ab69e0b 100644 --- a/tests/fast/api/test_dbapi10.py +++ b/tests/fast/api/test_dbapi10.py @@ -12,7 +12,12 @@ class TestCursorDescription(object): ["SELECT * FROM timestamps", "t", "TIMESTAMP", datetime], ["SELECT DATE '1992-09-20' AS date_col;", "date_col", "DATE", date], ["SELECT '\\xAA'::BLOB AS blob_col;", "blob_col", "BLOB", bytes], - ["SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", "struct_col", "STRUCT(x INTEGER, y INTEGER, z INTEGER)", dict], + [ + "SELECT {'x': 1, 'y': 2, 'z': 3} AS struct_col", + "struct_col", + "STRUCT(x INTEGER, y INTEGER, z INTEGER)", + dict, + ], ["SELECT [1, 2, 3] AS list_col", "list_col", "INTEGER[]", list], ["SELECT 'Frank' AS str_col", "str_col", "VARCHAR", str], ["SELECT [1, 2, 3]::JSON AS json_col", "json_col", "JSON", str], @@ -32,15 +37,15 @@ def test_description_comparisons(self): NUMBER = duckdb.NUMBER DATETIME = duckdb.DATETIME - assert(types[1] == STRING) - assert(STRING == types[1]) - assert(types[0] != STRING) - assert((types[1] != STRING) == False) - assert((STRING != types[1]) == False) + assert types[1] == STRING + assert STRING == types[1] + assert types[0] != STRING + assert (types[1] != STRING) == False + assert (STRING != types[1]) == False - assert(types[1] in [STRING]) - assert(types[1] in [STRING, NUMBER]) - assert(types[1] not in [NUMBER, DATETIME]) + assert types[1] in [STRING] + assert types[1] in [STRING, NUMBER] + assert types[1] not in [NUMBER, DATETIME] def test_none_description(self, duckdb_empty_cursor): assert duckdb_empty_cursor.description is None diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 01c8a460..41813d94 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -10,7 +10,7 @@ def test_rapi_description(self, duckdb_cursor): types = [x[1] for x in desc] assert names == ['a', 'b'] assert types == ['INTEGER', 'BIGINT'] - assert (all([x == duckdb.NUMBER for x in types])) + assert all([x == duckdb.NUMBER for x in types]) def test_rapi_describe(self, duckdb_cursor): np = pytest.importorskip("numpy") diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index 15dd6b2b..e67045c4 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -51,9 +51,7 @@ def func(x: int) -> int: """ Error: Catalog Error: Scalar Function with name func does not exist! """ - with pytest.raises( - duckdb.CatalogException, match='Scalar Function with name func does not exist!' - ): + with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): res = rel.fetchall() def test_use_after_remove_and_recreation(self): From 1d90b105ca303f344174b975151b86b860b34307 Mon Sep 17 00:00:00 2001 From: Tishj Date: Mon, 15 Sep 2025 15:29:04 +0200 Subject: [PATCH 180/472] remove Makefile --- Makefile | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 Makefile diff --git a/Makefile b/Makefile deleted file mode 100644 index 07008f11..00000000 --- a/Makefile +++ /dev/null @@ -1,4 +0,0 @@ -PYTHON ?= python3 - -format-main: - $(PYTHON) external/duckdb/scripts/format.py main --fix --noconfirm \ No newline at end of file From d0b3c8744578231d29ed409cdeb292bf3c3fdd6a Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Mon, 15 Sep 2025 09:24:49 -0600 Subject: [PATCH 181/472] docs: update readme and move contributing docs to CONTRIBUTING.md --- CONTRIBUTING.md | 362 +++++++++++++++++++++++++++++++----------------- README.md | 252 +-------------------------------- 2 files changed, 241 insertions(+), 373 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cd1b9854..f7bd4d47 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,128 +1,242 @@ -# Contributing +# Contributing to duckdb-python -## Code of Conduct +Start by + +forking duckdb-python. -This project and everyone participating in it is governed by a [Code of Conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior to [quack@duckdb.org](mailto:quack@duckdb.org). +### Cloning +After forking the duckdb-python repo we recommend you clone your fork as follows: +```shell +git clone --recurse-submodules $REPO_URL +git remote add upstream https://github.com/duckdb/duckdb-python.git +git fetch --all +``` + +... or, if you have already cloned your fork: +```shell +git submodule update --init --recursive +git remote add upstream https://github.com/duckdb/duckdb-python.git +git fetch --all +``` + +### Submodule update hook + +If you'll be switching between branches that are have the submodule set to different refs, then make your life +easier and add the git hooks in the .githooks directory to your local config: +```shell +git config --local core.hooksPath .githooks/ +``` + + +### Editable installs (general) + + It's good to be aware of the following when performing an editable install: +- `uv sync` or `uv run [tool]` perform an editable install by default. We have + configured the project so that scikit-build-core will use a persistent build-dir, but since the build itself + happens in an isolated, ephemeral environment, cmake's paths will point to non-existing directories. CMake itself + will be missing. +- You should install all development dependencies, and then build the project without build isolation, in two separate + steps. After this you can happily keep building and running, as long as you don't forget to pass in the + `--no-build-isolation` flag. + +```bash +# install all dev dependencies without building the project (needed once) +uv sync -p 3.11 --no-install-project +# build and install without build isolation +uv sync --no-build-isolation +``` + +### Editable installs (IDEs) + + If you're using an IDE then life is a little simpler. You install build dependencies and the project in the two + steps outlined above, and from that point on you can rely on e.g. CLion's cmake capabilities to do incremental + compilation and editable rebuilds. This will skip scikit-build-core's build backend and all of uv's dependency + management, so for "real" builds you better revert to the CLI. However, this should work fine for coding and debugging. + + +### Cleaning + +```shell +uv cache clean +rm -rf build .venv uv.lock +``` + + +### Building wheels and sdists + +To build a wheel and sdist for your system and the default Python version: +```bash +uv build +```` + +To build a wheel for a different Python version: +```bash +# E.g. for Python 3.9 +uv build -p 3.9 +``` + +### Running tests + + Run all pytests: +```bash +uv run --no-build-isolation pytest ./tests --verbose +``` + + Exclude the test/slow directory: +```bash +uv run --no-build-isolation pytest ./tests --verbose --ignore=./tests/slow +``` + +### Test coverage + + Run with coverage (during development you probably want to specify which tests to run): +```bash +COVERAGE=1 uv run --no-build-isolation coverage run -m pytest ./tests --verbose +``` + + The `COVERAGE` env var will compile the extension with `--coverage`, allowing us to collect coverage stats of C++ + code as well as Python code. + + Check coverage for Python code: +```bash +uvx coverage html -d htmlcov-python +uvx coverage report --format=markdown +``` + + Check coverage for C++ code (note: this will clutter your project dir with html files, consider saving them in some + other place): +```bash +uvx gcovr \ + --gcov-ignore-errors all \ + --root "$PWD" \ + --filter "${PWD}/src/duckdb_py" \ + --exclude '.*/\.cache/.*' \ + --gcov-exclude '.*/\.cache/.*' \ + --gcov-exclude '.*/external/.*' \ + --gcov-exclude '.*/site-packages/.*' \ + --exclude-unreachable-branches \ + --exclude-throw-branches \ + --html --html-details -o coverage-cpp.html \ + build/coverage/src/duckdb_py \ + --print-summary +``` + +### Typechecking and linting + +- We're not running any mypy typechecking tests at the moment +- We're not running any Ruff / linting / formatting at the moment + +### Cibuildwheel + +You can run cibuildwheel locally for Linux. E.g. limited to Python 3.9: +```bash +CIBW_BUILD='cp39-*' uvx cibuildwheel --platform linux . +``` + +### Code conventions + +* Follow the [Google Python styleguide](https://google.github.io/styleguide/pyguide.html) +* See the section on [Comments and Docstrings](https://google.github.io/styleguide/pyguide.html#s3.8-comments-and-docstrings) + +### Tooling + +This codebase is developed with the following tools: +- [Astral uv](https://docs.astral.sh/uv/) - for dependency management across all platforms we provide wheels for, + and for Python environment management. It will be hard to work on this codebase without having UV installed. +- [Scikit-build-core](https://scikit-build-core.readthedocs.io/en/latest/index.html) - the build backend for + building the extension. On the background, scikit-build-core uses cmake and ninja for compilation. +- [pybind11](https://pybind11.readthedocs.io/en/stable/index.html) - a bridge between C++ and Python. +- [CMake](https://cmake.org/) - the build system for both DuckDB itself and the DuckDB Python module. +- Cibuildwheel + +### Merging changes to pythonpkg from duckdb main + +1. Checkout main +2Identify the merge commits that brought in tags to main: +```bash +git log --graph --oneline --decorate main --simplify-by-decoration +``` + +3. Get the log of commits +```bash +git log --oneline 71c5c07cdd..c9254ecff2 -- tools/pythonpkg/ +``` + +4. Checkout v1.3-ossivalis +5. Get the log of commits +```bash +git log --oneline v1.3.0..v1.3.1 -- tools/pythonpkg/ +``` +git diff --name-status 71c5c07cdd c9254ecff2 -- tools/pythonpkg/ + +```bash +git log --oneline 71c5c07cdd..c9254ecff2 -- tools/pythonpkg/ +git diff --name-status -- tools/pythonpkg/ +``` + + +## Versioning and Releases + +The DuckDB Python package versioning and release scheme follows that of DuckDB itself. This means that a `X.Y.Z[. +postN]` release of the Python package ships the DuckDB stable release `X.Y.Z`. The optional `.postN` releases ship the same stable release of DuckDB as their predecessors plus Python package-specific fixes and / or features. + +| Types | DuckDB Version | Resulting Python Extension Version | +|------------------------------------------------------------------------|----------------|------------------------------------| +| Stable release: DuckDB stable release | `1.3.1` | `1.3.1` | +| Stable post release: DuckDB stable release + Python fixes and features | `1.3.1` | `1.3.1.postX` | +| Nightly micro: DuckDB next micro nightly + Python next micro nightly | `1.3.2.devM` | `1.3.2.devN` | +| Nightly minor: DuckDB next minor nightly + Python next minor nightly | `1.4.0.devM` | `1.4.0.devN` | + +Note that we do not ship nightly post releases (e.g. we don't ship `1.3.1.post2.dev3`). + +### Branch and Tag Strategy + +We cut releases as follows: + +| Type | Tag | How | +|----------------------|--------------|---------------------------------------------------------------------------------| +| Stable minor release | vX.Y.0 | Adding a tag on `main` | +| Stable micro release | vX.Y.Z | Adding a tag on a minor release branch (e.g. `v1.3-ossivalis`) | +| Stable post release | vX.Y.Z-postN | Adding a tag on a post release branch (e.g. `v1.3.1-post`) | +| Nightly micro | _not tagged_ | Combining HEAD of the _micro_ release branches of DuckDB and the Python package | +| Nightly minor | _not tagged_ | Combining HEAD of the _minor_ release branches of DuckDB and the Python package | + +### Release Runbooks + +We cut a new **stable minor release** with the following steps: +1. Create a PR on `main` to pin the DuckDB submodule to the tag of its current release. +1. Iff all tests pass in CI, merge the PR. +1. Manually start the release workflow with the hash of this commit, and the tag name. +1. Iff all goes well, create a new PR to let the submodule track DuckDB main. + +We cut a new **stable micro release** with the following steps: +1. Create a PR on the minor release branch to pin the DuckDB submodule to the tag of its current release. +1. Iff all tests pass in CI, merge the PR. +1. Manually start the release workflow with the hash of this commit, and the tag name. +1. Iff all goes well, create a new PR to let the submodule track DuckDB's minor release branch. + +We cut a new **stable post release** with the following steps: +1. Create a PR on the post release branch to pin the DuckDB submodule to the tag of its current release. +1. Iff all tests pass in CI, merge the PR. +1. Manually start the release workflow with the hash of this commit, and the tag name. +1. Iff all goes well, create a new PR to let the submodule track DuckDB's minor release branch. -## **Did you find a bug?** - -* **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/duckdb/duckdb/issues). -* If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/duckdb/duckdb/issues/new/choose). Be sure to include a **title and clear description**, as much relevant information as possible, and a **code sample** or an **executable test case** demonstrating the expected behavior that is not occurring. - -## **Did you write a patch that fixes a bug?** - -* Great! -* If possible, add a unit test case to make sure the issue does not occur again. -* Make sure you run the code formatter (`make format-fix`). -* Open a new GitHub pull request with the patch. -* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. - -## Outside Contributors - -* Discuss your intended changes with the core team on Github -* Announce that you are working or want to work on a specific issue -* Avoid large pull requests - they are much less likely to be merged as they are incredibly hard to review - -## Pull Requests - -* Do not commit/push directly to the main branch. Instead, create a fork and file a pull request. -* When maintaining a branch, merge frequently with the main. -* When maintaining a branch, submit pull requests to the main frequently. -* If you are working on a bigger issue try to split it up into several smaller issues. -* Please do not open "Draft" pull requests. Rather, use issues or discussion topics to discuss whatever needs discussing. -* We reserve full and final discretion over whether or not we will merge a pull request. Adhering to these guidelines is not a complete guarantee that your pull request will be merged. - -## CI for pull requests - -* Pull requests will need to pass all continuous integration checks before merging. -* For faster iteration and more control, consider running CI on your own fork or when possible directly locally. -* Submitting changes to an open pull request will move it to 'draft' state. -* Pull requests will get a complete run on the main repo CI only when marked as 'ready for review' (via Web UI, button on bottom right). - -## Nightly CI - -* Packages creation and long running tests will be performed during a nightly run -* On your fork you can trigger long running tests (NightlyTests.yml) for any branch following information from https://docs.github.com/en/actions/using-workflows/manually-running-a-workflow#running-a-workflow - -## Building - -* To build the project, run `make`. -* To build the project for debugging, run `make debug`. -* For parallel builds, you can use the [Ninja](https://ninja-build.org/) build system: `GEN=ninja make`. - * The default number of parallel processes can lock up the system depending on the CPU-to-memory ratio. If this happens, restrict the maximum number of build processes: `CMAKE_BUILD_PARALLEL_LEVEL=4 GEN=ninja make`. - * Without using Ninja, build times can still be reduced by setting `CMAKE_BUILD_PARALLEL_LEVEL=$(nproc)`. - -## Testing - -* Unit tests can be written either using the sqllogictest framework (`.test` files) or in C++ directly. We **strongly** prefer tests to be written using the sqllogictest framework. Only write tests in C++ if you absolutely need to (e.g. when testing concurrent connections or other exotic behavior). -* Documentation for the testing framework can be found [here](https://duckdb.org/dev/testing). -* Write many tests. -* Test with different types, especially numerics, strings and complex nested types. -* Try to test unexpected/incorrect usage as well, instead of only the happy path. -* `make unit` runs the **fast** unit tests (~one minute), `make allunit` runs **all** unit tests (~one hour). -* Make sure **all** unit tests pass before sending a PR. -* Slower tests should be added to the **all** unit tests. You can do this by naming the test file `.test_slow` in the sqllogictests, or by adding `[.]` after the test group in the C++ tests. -* Look at the code coverage report of your branch and attempt to cover all code paths in the fast unit tests. Attempt to trigger exceptions as well. It is acceptable to have some exceptions not triggered (e.g. out of memory exceptions or type switch exceptions), but large branches of code should always be either covered or removed. -* DuckDB uses GitHub Actions as its continuous integration (CI) tool. You also have the option to run GitHub Actions on your forked repository. For detailed instructions, you can refer to the [GitHub documentation](https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/enabling-features-for-your-repository/managing-github-actions-settings-for-a-repository). Before running GitHub Actions, please ensure that you have all the Git tags from the duckdb/duckdb repository. To accomplish this, execute the following commands `git fetch --tags` and then -`git push --tags` These commands will fetch all the git tags from the duckdb/duckdb repository and push them to your forked repository. This ensures that you have all the necessary tags available for your GitHub Actions workflow. - -## Formatting - -* Use tabs for indentation, spaces for alignment. -* Lines should not exceed 120 columns. -* To make sure the formatting is consistent, please use version 11.0.1, installable through `python3 -m pip install clang-format==11.0.1` or `pipx install clang-format==11.0.1`. -* `clang_format` and `black` enforce these rules automatically, use `make format-fix` to run the formatter. -* The project also comes with an [`.editorconfig` file](https://editorconfig.org/) that corresponds to these rules. - -## C++ Guidelines - -* Do not use `malloc`, prefer the use of smart pointers. Keywords `new` and `delete` are a code smell. -* Strongly prefer the use of `unique_ptr` over `shared_ptr`, only use `shared_ptr` if you **absolutely** have to. -* Use `const` whenever possible. -* Do **not** import namespaces (e.g. `using std`). -* All functions in source files in the core (`src` directory) should be part of the `duckdb` namespace. -* When overriding a virtual method, avoid repeating virtual and always use `override` or `final`. -* Use `[u]int(8|16|32|64)_t` instead of `int`, `long`, `uint` etc. Use `idx_t` instead of `size_t` for offsets/indices/counts of any kind. -* Prefer using references over pointers as arguments. -* Use `const` references for arguments of non-trivial objects (e.g. `std::vector`, ...). -* Use C++11 for loops when possible: `for (const auto& item : items) {...}` -* Use braces for indenting `if` statements and loops. Avoid single-line if statements and loops, especially nested ones. -* **Class Layout:** Start out with a `public` block containing the constructor and public variables, followed by a `public` block containing public methods of the class. After that follow any private functions and private variables. For example: - ```cpp - class MyClass { - public: - MyClass(); - - int my_public_variable; - - public: - void MyFunction(); - - private: - void MyPrivateFunction(); - - private: - int my_private_variable; - }; - ``` -* Avoid [unnamed magic numbers](https://en.wikipedia.org/wiki/Magic_number_(programming)). Instead, use named variables that are stored in a `constexpr`. -* [Return early](https://medium.com/swlh/return-early-pattern-3d18a41bba8). Avoid deep nested branches. -* Do not include commented out code blocks in pull requests. - -## Error Handling - -* Use exceptions **only** when an error is encountered that terminates a query (e.g. parser error, table not found). Exceptions should only be used for **exceptional** situations. For regular errors that do not break the execution flow (e.g. errors you **expect** might occur) use a return value instead. -* Try to add test cases that trigger exceptions. If an exception cannot be easily triggered using a test case then it should probably be an assertion. This is not always true (e.g. out of memory errors are exceptions, but are very hard to trigger). -* Use `D_ASSERT` to assert. Use **assert** only when failing the assert means a programmer error. Assert should never be triggered by user input. Avoid code like `D_ASSERT(a > b + 3);` without comments or context. -* Assert liberally, but make it clear with comments next to the assert what went wrong when the assert is triggered. - -## Naming Conventions - -* Choose descriptive names. Avoid single-letter variable names. -* Files: lowercase separated by underscores, e.g., abstract_operator.cpp -* Types (classes, structs, enums, typedefs, using): CamelCase starting with uppercase letter, e.g., BaseColumn -* Variables: lowercase separated by underscores, e.g., chunk_size -* Functions: CamelCase starting with uppercase letter, e.g., GetChunk -* Avoid `i`, `j`, etc. in **nested** loops. Prefer to use e.g. **column_idx**, **check_idx**. In a **non-nested** loop it is permissible to use **i** as iterator index. -* These rules are partially enforced by `clang-tidy`. +### Dynamic Versioning Integration + +The package uses `setuptools_scm` with `scikit-build` for automatic version determination, and implements a custom +versioning scheme. + +- **pyproject.toml configuration**: + ```toml + [tool.scikit-build] + metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" + + [tool.setuptools_scm] + version_scheme = "duckdb_packaging._setuptools_scm_version:version_scheme" + ``` + +- **Environment variables**: + - `MAIN_BRANCH_VERSIONING=0`: Use release branch versioning (patch increments) + - `MAIN_BRANCH_VERSIONING=1`: Use main branch versioning (minor increments) + - `OVERRIDE_GIT_DESCRIBE`: Override version detection \ No newline at end of file diff --git a/README.md b/README.md index 5f81ff5e..7dd7895a 100644 --- a/README.md +++ b/README.md @@ -19,14 +19,7 @@ API Docs (Python)

-# DuckDB: A Fast, In-Process, Portable, Open Source, Analytical Database System - -* **Simple**: DuckDB is easy to install and deploy. It has zero external dependencies and runs in-process in its host application or as a single binary. -* **Portable**: DuckDB runs on Linux, macOS, Windows, Android, iOS and all popular hardware architectures. It has idiomatic client APIs for major programming languages. -* **Feature-rich**: DuckDB offers a rich SQL dialect. It can read and write file formats such as CSV, Parquet, and JSON, to and from the local file system and remote endpoints such as S3 buckets. -* **Fast**: DuckDB runs analytical queries at blazing speed thanks to its columnar engine, which supports parallel execution and can process larger-than-memory workloads. -* **Extensible**: DuckDB is extensible by third-party features such as new data types, functions, file formats and new SQL syntax. User contributions are available as community extensions. -* **Free**: DuckDB and its core extensions are open-source under the permissive MIT License. The intellectual property of the project is held by the DuckDB Foundation. +# The Official Python bindings to [DuckDB](https://github.com/duckdb/duckdb) ## Installation @@ -42,245 +35,6 @@ Install with all optional dependencies: pip install 'duckdb[all]' ``` -## Development - -Start by - -forking duckdb-python. - -### Cloning - -After forking the duckdb-python repo we recommend you clone your fork as follows: -```shell -git clone --recurse-submodules $REPO_URL -git remote add upstream https://github.com/duckdb/duckdb-python.git -git fetch --all -``` - -... or, if you have already cloned your fork: -```shell -git submodule update --init --recursive -git remote add upstream https://github.com/duckdb/duckdb-python.git -git fetch --all -``` - -### Submodule update hook - -If you'll be switching between branches that are have the submodule set to different refs, then make your life -easier and add the git hooks in the .githooks directory to your local config: -```shell -git config --local core.hooksPath .githooks/ -``` - - -### Editable installs (general) - - It's good to be aware of the following when performing an editable install: -- `uv sync` or `uv run [tool]` perform an editable install by default. We have - configured the project so that scikit-build-core will use a persistent build-dir, but since the build itself - happens in an isolated, ephemeral environment, cmake's paths will point to non-existing directories. CMake itself - will be missing. -- You should install all development dependencies, and then build the project without build isolation, in two separate - steps. After this you can happily keep building and running, as long as you don't forget to pass in the - `--no-build-isolation` flag. - -```bash -# install all dev dependencies without building the project (needed once) -uv sync -p 3.11 --no-install-project -# build and install without build isolation -uv sync --no-build-isolation -``` - -### Editable installs (IDEs) - - If you're using an IDE then life is a little simpler. You install build dependencies and the project in the two - steps outlined above, and from that point on you can rely on e.g. CLion's cmake capabilities to do incremental - compilation and editable rebuilds. This will skip scikit-build-core's build backend and all of uv's dependency - management, so for "real" builds you better revert to the CLI. However, this should work fine for coding and debugging. - - -### Cleaning - -```shell -uv cache clean -rm -rf build .venv uv.lock -``` - - -### Building wheels and sdists - -To build a wheel and sdist for your system and the default Python version: -```bash -uv build -```` - -To build a wheel for a different Python version: -```bash -# E.g. for Python 3.9 -uv build -p 3.9 -``` - -### Running tests - - Run all pytests: -```bash -uv run --no-build-isolation pytest ./tests --verbose -``` - - Exclude the test/slow directory: -```bash -uv run --no-build-isolation pytest ./tests --verbose --ignore=./tests/slow -``` - -### Test coverage - - Run with coverage (during development you probably want to specify which tests to run): -```bash -COVERAGE=1 uv run --no-build-isolation coverage run -m pytest ./tests --verbose -``` - - The `COVERAGE` env var will compile the extension with `--coverage`, allowing us to collect coverage stats of C++ - code as well as Python code. - - Check coverage for Python code: -```bash -uvx coverage html -d htmlcov-python -uvx coverage report --format=markdown -``` - - Check coverage for C++ code (note: this will clutter your project dir with html files, consider saving them in some - other place): -```bash -uvx gcovr \ - --gcov-ignore-errors all \ - --root "$PWD" \ - --filter "${PWD}/src/duckdb_py" \ - --exclude '.*/\.cache/.*' \ - --gcov-exclude '.*/\.cache/.*' \ - --gcov-exclude '.*/external/.*' \ - --gcov-exclude '.*/site-packages/.*' \ - --exclude-unreachable-branches \ - --exclude-throw-branches \ - --html --html-details -o coverage-cpp.html \ - build/coverage/src/duckdb_py \ - --print-summary -``` - -### Typechecking and linting - -- We're not running any mypy typechecking tests at the moment -- We're not running any Ruff / linting / formatting at the moment - -### Cibuildwheel - -You can run cibuildwheel locally for Linux. E.g. limited to Python 3.9: -```bash -CIBW_BUILD='cp39-*' uvx cibuildwheel --platform linux . -``` - -### Code conventions - -* Follow the [Google Python styleguide](https://google.github.io/styleguide/pyguide.html) -* See the section on [Comments and Docstrings](https://google.github.io/styleguide/pyguide.html#s3.8-comments-and-docstrings) - -### Tooling - -This codebase is developed with the following tools: -- [Astral uv](https://docs.astral.sh/uv/) - for dependency management across all platforms we provide wheels for, - and for Python environment management. It will be hard to work on this codebase without having UV installed. -- [Scikit-build-core](https://scikit-build-core.readthedocs.io/en/latest/index.html) - the build backend for - building the extension. On the background, scikit-build-core uses cmake and ninja for compilation. -- [pybind11](https://pybind11.readthedocs.io/en/stable/index.html) - a bridge between C++ and Python. -- [CMake](https://cmake.org/) - the build system for both DuckDB itself and the DuckDB Python module. -- Cibuildwheel - -### Merging changes to pythonpkg from duckdb main - -1. Checkout main -2Identify the merge commits that brought in tags to main: -```bash -git log --graph --oneline --decorate main --simplify-by-decoration -``` - -3. Get the log of commits -```bash -git log --oneline 71c5c07cdd..c9254ecff2 -- tools/pythonpkg/ -``` - -4. Checkout v1.3-ossivalis -5. Get the log of commits -```bash -git log --oneline v1.3.0..v1.3.1 -- tools/pythonpkg/ -``` -git diff --name-status 71c5c07cdd c9254ecff2 -- tools/pythonpkg/ - -```bash -git log --oneline 71c5c07cdd..c9254ecff2 -- tools/pythonpkg/ -git diff --name-status -- tools/pythonpkg/ -``` - - -## Versioning and Releases - -The DuckDB Python package versioning and release scheme follows that of DuckDB itself. This means that a `X.Y.Z[. -postN]` release of the Python package ships the DuckDB stable release `X.Y.Z`. The optional `.postN` releases ship the same stable release of DuckDB as their predecessors plus Python package-specific fixes and / or features. - -| Types | DuckDB Version | Resulting Python Extension Version | -|------------------------------------------------------------------------|----------------|------------------------------------| -| Stable release: DuckDB stable release | `1.3.1` | `1.3.1` | -| Stable post release: DuckDB stable release + Python fixes and features | `1.3.1` | `1.3.1.postX` | -| Nightly micro: DuckDB next micro nightly + Python next micro nightly | `1.3.2.devM` | `1.3.2.devN` | -| Nightly minor: DuckDB next minor nightly + Python next minor nightly | `1.4.0.devM` | `1.4.0.devN` | - -Note that we do not ship nightly post releases (e.g. we don't ship `1.3.1.post2.dev3`). - -### Branch and Tag Strategy - -We cut releases as follows: - -| Type | Tag | How | -|----------------------|--------------|---------------------------------------------------------------------------------| -| Stable minor release | vX.Y.0 | Adding a tag on `main` | -| Stable micro release | vX.Y.Z | Adding a tag on a minor release branch (e.g. `v1.3-ossivalis`) | -| Stable post release | vX.Y.Z-postN | Adding a tag on a post release branch (e.g. `v1.3.1-post`) | -| Nightly micro | _not tagged_ | Combining HEAD of the _micro_ release branches of DuckDB and the Python package | -| Nightly minor | _not tagged_ | Combining HEAD of the _minor_ release branches of DuckDB and the Python package | - -### Release Runbooks - -We cut a new **stable minor release** with the following steps: -1. Create a PR on `main` to pin the DuckDB submodule to the tag of its current release. -1. Iff all tests pass in CI, merge the PR. -1. Manually start the release workflow with the hash of this commit, and the tag name. -1. Iff all goes well, create a new PR to let the submodule track DuckDB main. - -We cut a new **stable micro release** with the following steps: -1. Create a PR on the minor release branch to pin the DuckDB submodule to the tag of its current release. -1. Iff all tests pass in CI, merge the PR. -1. Manually start the release workflow with the hash of this commit, and the tag name. -1. Iff all goes well, create a new PR to let the submodule track DuckDB's minor release branch. - -We cut a new **stable post release** with the following steps: -1. Create a PR on the post release branch to pin the DuckDB submodule to the tag of its current release. -1. Iff all tests pass in CI, merge the PR. -1. Manually start the release workflow with the hash of this commit, and the tag name. -1. Iff all goes well, create a new PR to let the submodule track DuckDB's minor release branch. - -### Dynamic Versioning Integration - -The package uses `setuptools_scm` with `scikit-build` for automatic version determination, and implements a custom -versioning scheme. - -- **pyproject.toml configuration**: - ```toml - [tool.scikit-build] - metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" - - [tool.setuptools_scm] - version_scheme = "duckdb_packaging._setuptools_scm_version:version_scheme" - ``` +## Contributing -- **Environment variables**: - - `MAIN_BRANCH_VERSIONING=0`: Use release branch versioning (patch increments) - - `MAIN_BRANCH_VERSIONING=1`: Use main branch versioning (minor increments) - - `OVERRIDE_GIT_DESCRIBE`: Override version detection +See the [CONTRIBUTING.md](CONTRIBUTING.md) for instructions on how to set up a development environment. From 9e234f45846eb6cff321ae96136bb881e386bac0 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Mon, 15 Sep 2025 12:18:13 -0600 Subject: [PATCH 182/472] contributing.md: keep some of the old general guidelines, re-organize into sections depending on phases of development --- CONTRIBUTING.md | 125 +++++++++++++++++++++++++++++++----------------- 1 file changed, 82 insertions(+), 43 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f7bd4d47..669f7b62 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,10 +1,50 @@ # Contributing to duckdb-python -Start by - -forking duckdb-python. +## General Guidelines -### Cloning +### **Did you find a bug?** + +* **Ensure the bug was not already reported** by searching on GitHub under [Issues](https://github.com/duckdb/duckdb-python/issues). +* If you're unable to find an open issue addressing the problem, [open a new one](https://github.com/duckdb/duckdb-python/issues/new/choose). Be sure to include a **title and clear description**, as much relevant information as possible, and a **code sample** or an **executable test case** demonstrating the expected behavior that is not occurring. + +### **Did you write a patch that fixes a bug?** + +* Great! +* If possible, add a unit test case to make sure the issue does not occur again. +* Open a new GitHub pull request with the patch. +* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. + +### Outside Contributors + +* Discuss your intended changes with the core team on Github +* Announce that you are working or want to work on a specific issue +* Avoid large pull requests - they are much less likely to be merged as they are incredibly hard to review + +### Pull Requests + +* Do not commit/push directly to the main branch. Instead, create a fork and file a pull request. +* When maintaining a branch, merge frequently with the main. +* When maintaining a branch, submit pull requests to the main frequently. +* If you are working on a bigger issue try to split it up into several smaller issues. +* Please do not open "Draft" pull requests. Rather, use issues or discussion topics to discuss whatever needs discussing. +* We reserve full and final discretion over whether or not we will merge a pull request. Adhering to these guidelines is not a complete guarantee that your pull request will be merged. + +### CI for pull requests + +* Pull requests will need to pass all continuous integration checks before merging. +* For faster iteration and more control, consider running CI on your own fork or when possible directly locally. +* Submitting changes to an open pull request will move it to 'draft' state. +* Pull requests will get a complete run on the main repo CI only when marked as 'ready for review' (via Web UI, button on bottom right). + +### Nightly CI + +* Packages creation and long running tests will be performed during a nightly run +* On your fork you can trigger long running tests (NightlyTests.yml) for any branch following information from https://docs.github.com/en/actions/using-workflows/manually-running-a-workflow#running-a-workflow + +## Setting up a development environment + +Start by [forking duckdb-python](https://github.com/duckdb/duckdb-python/fork) into +a personal repository. After forking the duckdb-python repo we recommend you clone your fork as follows: ```shell @@ -20,6 +60,9 @@ git remote add upstream https://github.com/duckdb/duckdb-python.git git fetch --all ``` +The submodule stuff is needed because we vendor the core DuckDB repository as a git submodule, +and to build the python package we also need to build DuckDB itself. + ### Submodule update hook If you'll be switching between branches that are have the submodule set to different refs, then make your life @@ -28,10 +71,10 @@ easier and add the git hooks in the .githooks directory to your local config: git config --local core.hooksPath .githooks/ ``` - ### Editable installs (general) - It's good to be aware of the following when performing an editable install: +It's good to be aware of the following when performing an editable install: + - `uv sync` or `uv run [tool]` perform an editable install by default. We have configured the project so that scikit-build-core will use a persistent build-dir, but since the build itself happens in an isolated, ephemeral environment, cmake's paths will point to non-existing directories. CMake itself @@ -49,33 +92,31 @@ uv sync --no-build-isolation ### Editable installs (IDEs) - If you're using an IDE then life is a little simpler. You install build dependencies and the project in the two - steps outlined above, and from that point on you can rely on e.g. CLion's cmake capabilities to do incremental - compilation and editable rebuilds. This will skip scikit-build-core's build backend and all of uv's dependency - management, so for "real" builds you better revert to the CLI. However, this should work fine for coding and debugging. +If you're using an IDE then life is a little simpler. You install build dependencies and the project in the two +steps outlined above, and from that point on you can rely on e.g. CLion's cmake capabilities to do incremental +compilation and editable rebuilds. This will skip scikit-build-core's build backend and all of uv's dependency +management, so for "real" builds you better revert to the CLI. However, this should work fine for coding and debugging. +## Day to day development -### Cleaning +After setting up the development environment, these are the most common tasks you'll be performing. +### Tooling +This codebase is developed with the following tools: +- [Astral uv](https://docs.astral.sh/uv/) - for dependency management across all platforms we provide wheels for, + and for Python environment management. It will be hard to work on this codebase without having UV installed. +- [Scikit-build-core](https://scikit-build-core.readthedocs.io/en/latest/index.html) - the build backend for + building the extension. On the background, scikit-build-core uses cmake and ninja for compilation. +- [pybind11](https://pybind11.readthedocs.io/en/stable/index.html) - a bridge between C++ and Python. +- [CMake](https://cmake.org/) - the build system for both DuckDB itself and the DuckDB Python module. +- Cibuildwheel + +### Cleaning ```shell uv cache clean rm -rf build .venv uv.lock ``` - -### Building wheels and sdists - -To build a wheel and sdist for your system and the default Python version: -```bash -uv build -```` - -To build a wheel for a different Python version: -```bash -# E.g. for Python 3.9 -uv build -p 3.9 -``` - ### Running tests Run all pytests: @@ -122,10 +163,25 @@ uvx gcovr \ --print-summary ``` -### Typechecking and linting +### Typechecking, linting, style, and formatting - We're not running any mypy typechecking tests at the moment - We're not running any Ruff / linting / formatting at the moment +- Follow the [Google Python styleguide](https://google.github.io/styleguide/pyguide.html) +- See the section on [Comments and Docstrings](https://google.github.io/styleguide/pyguide.html#s3.8-comments-and-docstrings) + +### Building wheels and sdists + +To build a wheel and sdist for your system and the default Python version: +```bash +uv build +```` + +To build a wheel for a different Python version: +```bash +# E.g. for Python 3.9 +uv build -p 3.9 +``` ### Cibuildwheel @@ -134,22 +190,6 @@ You can run cibuildwheel locally for Linux. E.g. limited to Python 3.9: CIBW_BUILD='cp39-*' uvx cibuildwheel --platform linux . ``` -### Code conventions - -* Follow the [Google Python styleguide](https://google.github.io/styleguide/pyguide.html) -* See the section on [Comments and Docstrings](https://google.github.io/styleguide/pyguide.html#s3.8-comments-and-docstrings) - -### Tooling - -This codebase is developed with the following tools: -- [Astral uv](https://docs.astral.sh/uv/) - for dependency management across all platforms we provide wheels for, - and for Python environment management. It will be hard to work on this codebase without having UV installed. -- [Scikit-build-core](https://scikit-build-core.readthedocs.io/en/latest/index.html) - the build backend for - building the extension. On the background, scikit-build-core uses cmake and ninja for compilation. -- [pybind11](https://pybind11.readthedocs.io/en/stable/index.html) - a bridge between C++ and Python. -- [CMake](https://cmake.org/) - the build system for both DuckDB itself and the DuckDB Python module. -- Cibuildwheel - ### Merging changes to pythonpkg from duckdb main 1. Checkout main @@ -175,7 +215,6 @@ git log --oneline 71c5c07cdd..c9254ecff2 -- tools/pythonpkg/ git diff --name-status -- tools/pythonpkg/ ``` - ## Versioning and Releases The DuckDB Python package versioning and release scheme follows that of DuckDB itself. This means that a `X.Y.Z[. From f9c741a365e62f3f9be12ec93348e94b7d67ffb7 Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Sun, 14 Sep 2025 05:07:12 +0000 Subject: [PATCH 183/472] chore: Add 3.14 and 3.14t builds: update GHA matrix, bump uv and cibuildwheel to include 3.14rc2, and handle new pandas warning, and mark unsupported packages as < 3.14. --- .github/workflows/cleanup_pypi.yml | 2 +- .github/workflows/coverage.yml | 2 +- .github/workflows/packaging_sdist.yml | 2 +- .github/workflows/packaging_wheels.yml | 9 ++++---- pyproject.toml | 32 +++++++++++++++++--------- tests/pytest.ini | 2 ++ 6 files changed, 31 insertions(+), 18 deletions(-) diff --git a/.github/workflows/cleanup_pypi.yml b/.github/workflows/cleanup_pypi.yml index c4300be3..e290faae 100644 --- a/.github/workflows/cleanup_pypi.yml +++ b/.github/workflows/cleanup_pypi.yml @@ -52,7 +52,7 @@ jobs: - name: Install Astral UV uses: astral-sh/setup-uv@v6 with: - version: "0.7.14" + version: "0.8.16" - name: Run Cleanup env: diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index fdd2a838..ab696897 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -70,7 +70,7 @@ jobs: - name: Install Astral UV and enable the cache uses: astral-sh/setup-uv@v6 with: - version: "0.7.14" + version: "0.8.16" python-version: 3.9 enable-cache: true cache-suffix: -${{ github.workflow }} diff --git a/.github/workflows/packaging_sdist.yml b/.github/workflows/packaging_sdist.yml index 2723b437..87923f4c 100644 --- a/.github/workflows/packaging_sdist.yml +++ b/.github/workflows/packaging_sdist.yml @@ -58,7 +58,7 @@ jobs: - name: Install Astral UV uses: astral-sh/setup-uv@v6 with: - version: "0.7.14" + version: "0.8.16" python-version: 3.11 - name: Build sdist diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index 4c7599a6..00e5cdea 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -30,7 +30,7 @@ jobs: strategy: fail-fast: false matrix: - python: [ cp39, cp310, cp311, cp312, cp313 ] + python: [ cp39, cp310, cp311, cp312, cp313, cp314, cp314t ] platform: - { os: windows-2025, arch: amd64, cibw_system: win } - { os: ubuntu-24.04, arch: x86_64, cibw_system: manylinux } @@ -79,16 +79,17 @@ jobs: # Install Astral UV, which will be used as build-frontend for cibuildwheel - uses: astral-sh/setup-uv@v6 with: - version: "0.7.14" + version: "0.8.16" enable-cache: false cache-suffix: -${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + python-version: ${{ matrix.python }} - name: Build${{ inputs.testsuite != 'none' && ' and test ' || ' ' }}wheels - uses: pypa/cibuildwheel@v3.0 + uses: pypa/cibuildwheel@v3.1 env: CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }} CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} - + CIBW_ENVIRONMENT: PYTHON_GIL=1 - name: Upload wheel uses: actions/upload-artifact@v4 with: diff --git a/pyproject.toml b/pyproject.toml index 6291b811..edd71a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ all = [ # users can install duckdb with 'duckdb[all]', which will install this l "fsspec", # used in duckdb.filesystem "numpy", # used in duckdb.experimental.spark and in duckdb.fetchnumpy() "pandas", # used for pandas dataframes all over the place - "pyarrow", # used for pyarrow support - "adbc_driver_manager", # for the adbc driver (TODO: this should live under the duckdb package) + "pyarrow; python_version < '3.14'", # used for pyarrow support + "adbc_driver_manager; python_version < '3.14'", # for the adbc driver (TODO: this should live under the duckdb package) ] ###################################################################################################### @@ -123,6 +123,13 @@ if.env.COVERAGE = false inherit.cmake.define = "append" cmake.define.DISABLE_UNITY = "1" +[[tool.scikit-build.overrides]] +# Windows Free-Threading +if.platform-system = "^win32" +if.abi-flags = "t" +inherit.cmake.define = "append" +cmake.define.CMAKE_C_FLAGS="/DPy_MOD_GIL_USED /DPy_GIL_DISABLED" +cmake.define.CMAKE_CXX_FLAGS="/DPy_MOD_GIL_USED /DPy_GIL_DISABLED" [tool.scikit-build.sdist] include = [ @@ -204,6 +211,7 @@ required-environments = [ # ... but do always resolve for all of them "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'x86_64'", "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'aarch64'", ] +prerelease = "allow" # for 3.14 # We just need pytorch for tests, wihtout GPU acceleration. PyPI doesn't host a cpu-only version for Linux, so we have # to configure the index url for cpu-only pytorch manually @@ -220,8 +228,8 @@ torchvision = [ { index = "pytorch-cpu" } ] stubdeps = [ # dependencies used for typehints in the stubs "fsspec", "pandas", - "polars", - "pyarrow", + "polars; python_version < '3.14'", + "pyarrow; python_version < '3.14'", ] test = [ # dependencies used for running tests "pytest", @@ -229,21 +237,21 @@ test = [ # dependencies used for running tests "pytest-timeout", "mypy", "coverage", - "gcovr", + "gcovr; python_version < '3.14'", "gcsfs", "packaging", - "polars", + "polars; python_version < '3.14'", "psutil", "py4j", "pyotp", - "pyspark", + "pyspark; python_version < '3.14'", "pytz", "requests", "urllib3", "fsspec>=2022.11.0", "pandas>=2.0.0", - "pyarrow>=18.0.0", - "torch>=2.2.2; sys_platform != 'darwin' or platform_machine != 'x86_64' or python_version < '3.13'", + "pyarrow>=18.0.0; python_version < '3.14'", + "torch>=2.2.2; python_version < '3.14' and (sys_platform != 'darwin' or platform_machine != 'x86_64' or python_version < '3.13')", "tensorflow==2.14.0; sys_platform == 'darwin' and python_version < '3.12'", "tensorflow-cpu>=2.14.0; sys_platform == 'linux' and platform_machine != 'aarch64' and python_version < '3.12'", "tensorflow-cpu>=2.14.0; sys_platform == 'win32' and python_version < '3.12'", @@ -258,8 +266,8 @@ scripts = [ # dependencies used for running scripts "numpy", "pandas", "pcpp", - "polars", - "pyarrow", + "polars; python_version < '3.14'", + "pyarrow; python_version < '3.14'", "pytz" ] pypi = [ # dependencies used by the pypi cleanup script @@ -305,6 +313,7 @@ filterwarnings = [ # Pyspark is throwing these warnings "ignore:distutils Version classes are deprecated:DeprecationWarning", "ignore:is_datetime64tz_dtype is deprecated:DeprecationWarning", + "ignore:ChainedAssignmentError.*:FutureWarning" ] [tool.coverage.run] @@ -379,6 +388,7 @@ manylinux-x86_64-image = "manylinux_2_28" manylinux-pypy_x86_64-image = "manylinux_2_28" manylinux-aarch64-image = "manylinux_2_28" manylinux-pypy_aarch64-image = "manylinux_2_28" +enable = ["cpython-freethreading", "cpython-prerelease"] [tool.cibuildwheel.linux] before-build = ["yum install -y ccache"] diff --git a/tests/pytest.ini b/tests/pytest.ini index 0c17afd5..5081ee33 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -2,6 +2,8 @@ [pytest] filterwarnings = error + # Pandas ChainedAssignmentError warnings for 3.0 + ignore:ChainedAssignmentError.*:FutureWarning ignore::UserWarning ignore::DeprecationWarning # Jupyter is throwing DeprecationWarnings From f8a83b45e6941c2db7c617fa0c5581247868d8fe Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Sun, 14 Sep 2025 18:58:38 +0000 Subject: [PATCH 184/472] chore: remove pandas 3.0 warnings -> instead, disable pandas for 3.14 for now. --- pyproject.toml | 3 +- tests/conftest.py | 40 ++++++++++++++++++++----- tests/fast/numpy/test_numpy_new_path.py | 1 + tests/pytest.ini | 2 -- 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index edd71a02..657ab2b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -211,7 +211,7 @@ required-environments = [ # ... but do always resolve for all of them "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'x86_64'", "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'aarch64'", ] -prerelease = "allow" # for 3.14 +prerelease = "if-necessary-or-explicit" # for 3.14 # We just need pytorch for tests, wihtout GPU acceleration. PyPI doesn't host a cpu-only version for Linux, so we have # to configure the index url for cpu-only pytorch manually @@ -313,7 +313,6 @@ filterwarnings = [ # Pyspark is throwing these warnings "ignore:distutils Version classes are deprecated:DeprecationWarning", "ignore:is_datetime64tz_dtype is deprecated:DeprecationWarning", - "ignore:ChainedAssignmentError.*:FutureWarning" ] [tool.coverage.run] diff --git a/tests/conftest.py b/tests/conftest.py index ce2d0e68..6c3cb2fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,24 +6,37 @@ import duckdb import warnings from importlib import import_module +import sys try: # need to ignore warnings that might be thrown deep inside pandas's import tree (from dateutil in this case) - warnings.simplefilter(action='ignore', category=DeprecationWarning) - pandas = import_module('pandas') + warnings.simplefilter(action="ignore", category=DeprecationWarning) + pandas = import_module("pandas") warnings.resetwarnings() - pyarrow_dtype = getattr(pandas, 'ArrowDtype', None) + pyarrow_dtype = getattr(pandas, "ArrowDtype", None) except ImportError: pandas = None pyarrow_dtype = None + # Only install mock after we've failed to import pandas for conftest.py + class MockPandas: + def __getattr__(self, name): + pytest.skip("pandas not available", allow_module_level=True) + + sys.modules["pandas"] = MockPandas() + sys.modules["pandas.testing"] = MockPandas() + sys.modules["pandas._testing"] = MockPandas() + # Check if pandas has arrow dtypes enabled -try: - from pandas.compat import pa_version_under7p0 +if pandas is not None: + try: + from pandas.compat import pa_version_under7p0 - pyarrow_dtypes_enabled = not pa_version_under7p0 -except ImportError: + pyarrow_dtypes_enabled = not pa_version_under7p0 + except (ImportError, AttributeError): + pyarrow_dtypes_enabled = False +else: pyarrow_dtypes_enabled = False @@ -31,7 +44,7 @@ def import_pandas(): if pandas: return pandas else: - pytest.skip("Couldn't import pandas") + pytest.skip("Couldn't import pandas", allow_module_level=True) # https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option @@ -39,6 +52,17 @@ def import_pandas(): def pytest_addoption(parser): parser.addoption("--skiplist", action="append", nargs="+", type=str, help="skip listed tests") +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_call(item): + """Convert pandas requirement exceptions to skips""" + outcome = yield + try: + outcome.get_result() + except Exception as e: + if "'pandas' is required for this operation but it was not installed" in str(e): + pytest.skip("pandas not available - test requires pandas functionality") + + def pytest_collection_modifyitems(config, items): tests_to_skip = config.getoption("--skiplist") diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index 4267085c..6e424c9f 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -6,6 +6,7 @@ import duckdb from datetime import timedelta import pytest +import pandas # https://github.com/duckdb/duckdb-python/issues/48 class TestScanNumpy(object): diff --git a/tests/pytest.ini b/tests/pytest.ini index 5081ee33..0c17afd5 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -2,8 +2,6 @@ [pytest] filterwarnings = error - # Pandas ChainedAssignmentError warnings for 3.0 - ignore:ChainedAssignmentError.*:FutureWarning ignore::UserWarning ignore::DeprecationWarning # Jupyter is throwing DeprecationWarnings From 63a19f3a75098f403e9ac57f252fb91405014e9d Mon Sep 17 00:00:00 2001 From: paultiq <104510378+paultiq@users.noreply.github.com> Date: Sun, 14 Sep 2025 16:45:18 -0400 Subject: [PATCH 185/472] test: Disable Pandas for 3.14 Not yet available --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 657ab2b8..9a1cb980 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ all = [ # users can install duckdb with 'duckdb[all]', which will install this l "ipython", # used in duckdb.query_graph "fsspec", # used in duckdb.filesystem "numpy", # used in duckdb.experimental.spark and in duckdb.fetchnumpy() - "pandas", # used for pandas dataframes all over the place + "pandas; python_version < '3.14'", # used for pandas dataframes all over the place "pyarrow; python_version < '3.14'", # used for pyarrow support "adbc_driver_manager; python_version < '3.14'", # for the adbc driver (TODO: this should live under the duckdb package) ] @@ -227,7 +227,7 @@ torchvision = [ { index = "pytorch-cpu" } ] [dependency-groups] # used for development only, requires pip >=25.1.0 stubdeps = [ # dependencies used for typehints in the stubs "fsspec", - "pandas", + "pandas; python_version < '3.14'", "polars; python_version < '3.14'", "pyarrow; python_version < '3.14'", ] @@ -249,7 +249,7 @@ test = [ # dependencies used for running tests "requests", "urllib3", "fsspec>=2022.11.0", - "pandas>=2.0.0", + "pandas>=2.0.0; python_version < '3.14'", "pyarrow>=18.0.0; python_version < '3.14'", "torch>=2.2.2; python_version < '3.14' and (sys_platform != 'darwin' or platform_machine != 'x86_64' or python_version < '3.13')", "tensorflow==2.14.0; sys_platform == 'darwin' and python_version < '3.12'", @@ -264,7 +264,7 @@ scripts = [ # dependencies used for running scripts "ipython", "ipywidgets", "numpy", - "pandas", + "pandas; python_version < '3.14'", "pcpp", "polars; python_version < '3.14'", "pyarrow; python_version < '3.14'", From c71f85bccc35dbbce3c22d0e637b82722d6ef0e4 Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Mon, 15 Sep 2025 02:15:05 +0000 Subject: [PATCH 186/472] test: disable failing test "Windows fatal exception: access violation" --- tests/fast/api/test_connection_interrupt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index 4efd68b5..eae6cbb8 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -1,12 +1,13 @@ import platform import threading import time - +import sys import duckdb import pytest class TestConnectionInterrupt(object): + @pytest.mark.xfail(sys.platform == "win32" and sys.version_info[:2] == (3, 14) and __import__('sysconfig').get_config_var("Py_GIL_DISABLED") == 1, reason="known issue on Windows 3.14t (free-threaded)", strict=False) @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", From 251fdcc87741c50118c02d8f525f3c96a8117169 Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Mon, 15 Sep 2025 02:43:20 +0000 Subject: [PATCH 187/472] tests: skip, don't xfail --- tests/fast/api/test_connection_interrupt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index eae6cbb8..931ceaeb 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -7,7 +7,7 @@ class TestConnectionInterrupt(object): - @pytest.mark.xfail(sys.platform == "win32" and sys.version_info[:2] == (3, 14) and __import__('sysconfig').get_config_var("Py_GIL_DISABLED") == 1, reason="known issue on Windows 3.14t (free-threaded)", strict=False) + @pytest.mark.skipif(sys.platform == "win32" and sys.version_info[:2] == (3, 14) and __import__('sysconfig').get_config_var("Py_GIL_DISABLED") == 1, reason="known issue on Windows 3.14t (free-threaded)") @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", From 0f39855882a445b87cd74307fc9571255435dcd6 Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Mon, 15 Sep 2025 03:22:05 +0000 Subject: [PATCH 188/472] exclude Windows --- .github/workflows/packaging_wheels.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index 00e5cdea..f1e9ddf0 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -45,6 +45,8 @@ jobs: - { minimal: true, python: cp311 } - { minimal: true, python: cp312 } - { minimal: true, platform: { arch: universal2 } } + - { python: cp314t, platform: { os: windows-2025 } } + runs-on: ${{ matrix.platform.os }} env: CIBW_TEST_SKIP: ${{ inputs.testsuite == 'none' && '*' || '*-macosx_universal2' }} From 9de15b82959fd30f01e7d1ef0206083207b408d5 Mon Sep 17 00:00:00 2001 From: Paul Timmins Date: Mon, 15 Sep 2025 03:23:40 +0000 Subject: [PATCH 189/472] tests: revert the skip since we're excluding Windows 3.14t builds entirely. --- tests/fast/api/test_connection_interrupt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index 931ceaeb..f9fa37d0 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -7,7 +7,6 @@ class TestConnectionInterrupt(object): - @pytest.mark.skipif(sys.platform == "win32" and sys.version_info[:2] == (3, 14) and __import__('sysconfig').get_config_var("Py_GIL_DISABLED") == 1, reason="known issue on Windows 3.14t (free-threaded)") @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", From 920600471342dbb0f02bde72754ebd5c1d632b21 Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Sun, 14 Sep 2025 23:55:11 -0400 Subject: [PATCH 190/472] revert: import that was added, no longer needed --- tests/fast/api/test_connection_interrupt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index f9fa37d0..ce9d2599 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -1,7 +1,6 @@ import platform import threading import time -import sys import duckdb import pytest From 1f95ceac50d14901fd27cf47a3a103d5ed76c820 Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Mon, 15 Sep 2025 00:34:12 -0400 Subject: [PATCH 191/472] revert: exactly to original --- tests/fast/api/test_connection_interrupt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index ce9d2599..4efd68b5 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -1,6 +1,7 @@ import platform import threading import time + import duckdb import pytest From 947a52aad7b29ef1684850c5d39050e532874388 Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Mon, 15 Sep 2025 07:22:11 -0400 Subject: [PATCH 192/472] test: Mark test xfail --- tests/fast/numpy/test_numpy_new_path.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index 6e424c9f..abc09ef5 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -2,14 +2,15 @@ Therefore, we only test the new codes and exec paths. """ +import sys import numpy as np import duckdb from datetime import timedelta import pytest -import pandas # https://github.com/duckdb/duckdb-python/issues/48 class TestScanNumpy(object): + @pytest.mark.skipif(sys.version_info[:2] == (3, 14), reason="Fails when testing without pandas https://github.com/duckdb/duckdb-python/issues/48") def test_scan_numpy(self, duckdb_cursor): z = np.array([1, 2, 3]) res = duckdb_cursor.sql("select * from z").fetchall() From 9f08a81dc205967a487222f86eb7f6024cfa10e3 Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Mon, 15 Sep 2025 07:23:15 -0400 Subject: [PATCH 193/472] test: mark test xfail --- tests/fast/numpy/test_numpy_new_path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index abc09ef5..c1122797 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -10,7 +10,7 @@ class TestScanNumpy(object): - @pytest.mark.skipif(sys.version_info[:2] == (3, 14), reason="Fails when testing without pandas https://github.com/duckdb/duckdb-python/issues/48") + @pytest.mark.xfail(sys.version_info[:2] == (3, 14), reason="Fails when testing without pandas https://github.com/duckdb/duckdb-python/issues/48") def test_scan_numpy(self, duckdb_cursor): z = np.array([1, 2, 3]) res = duckdb_cursor.sql("select * from z").fetchall() From e60db527008c02fb4c154b266e93713c081c8a11 Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Mon, 15 Sep 2025 09:34:33 -0400 Subject: [PATCH 194/472] chore: Add comments and todo's for workflow changes --- .github/workflows/packaging_wheels.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index f1e9ddf0..463faea8 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -45,7 +45,9 @@ jobs: - { minimal: true, python: cp311 } - { minimal: true, python: cp312 } - { minimal: true, platform: { arch: universal2 } } - - { python: cp314t, platform: { os: windows-2025 } } + # Windows+cp314t disabled due to test failures in CI. + # TODO: Diagnose why tests fail (access violations) in some configurations + - { python: cp314t, platform: { os: windows-2025 } } runs-on: ${{ matrix.platform.os }} env: @@ -91,6 +93,8 @@ jobs: env: CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }} CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + # PYTHON_GIL=1: Suppresses the RuntimeWarning that the GIL is enabled on free-threaded builds. + # TODO: Remove PYTHON_GIL=1 when free-threaded is supported. CIBW_ENVIRONMENT: PYTHON_GIL=1 - name: Upload wheel uses: actions/upload-artifact@v4 From 5b01263fd9bead30d7080ffa8dd59ad470b24a3d Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Mon, 15 Sep 2025 09:38:57 -0400 Subject: [PATCH 195/472] chore: Remove unused section for Windows 3.14t builds. --- pyproject.toml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a1cb980..bcbb24f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,14 +123,6 @@ if.env.COVERAGE = false inherit.cmake.define = "append" cmake.define.DISABLE_UNITY = "1" -[[tool.scikit-build.overrides]] -# Windows Free-Threading -if.platform-system = "^win32" -if.abi-flags = "t" -inherit.cmake.define = "append" -cmake.define.CMAKE_C_FLAGS="/DPy_MOD_GIL_USED /DPy_GIL_DISABLED" -cmake.define.CMAKE_CXX_FLAGS="/DPy_MOD_GIL_USED /DPy_GIL_DISABLED" - [tool.scikit-build.sdist] include = [ "README.md", From e3afe7e3f8cc5cfbaf83425b63a6fa2c3a661cc3 Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Mon, 15 Sep 2025 10:14:39 -0400 Subject: [PATCH 196/472] chore: Add version check to only allow no-Pandas for 3.14, plus a TODO --- tests/conftest.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6c3cb2fb..e2f427c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import os +import sys import pytest import shutil from os.path import abspath, join, dirname, normpath @@ -52,16 +53,22 @@ def import_pandas(): def pytest_addoption(parser): parser.addoption("--skiplist", action="append", nargs="+", type=str, help="skip listed tests") + @pytest.hookimpl(hookwrapper=True) def pytest_runtest_call(item): """Convert pandas requirement exceptions to skips""" + outcome = yield - try: - outcome.get_result() - except Exception as e: - if "'pandas' is required for this operation but it was not installed" in str(e): - pytest.skip("pandas not available - test requires pandas functionality") + # TODO: Remove skip when Pandas releases for 3.14. After, consider bumping to 3.15 + if sys.version_info[:2] == (3, 14): + try: + outcome.get_result() + except Exception as e: + if "'pandas' is required for this operation but it was not installed" in str(e): + pytest.skip("pandas not available - test requires pandas functionality") + else: + raise e def pytest_collection_modifyitems(config, items): From 8fd92021d8ce4f1ce755997a640520e1c74f9bfe Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Mon, 15 Sep 2025 16:44:55 -0400 Subject: [PATCH 197/472] tests: Narrow pandas not installed skip to duckdb.InvalidInputException --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index e2f427c3..5e297aee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,7 +64,7 @@ def pytest_runtest_call(item): if sys.version_info[:2] == (3, 14): try: outcome.get_result() - except Exception as e: + except duckdb.InvalidInputException as e: if "'pandas' is required for this operation but it was not installed" in str(e): pytest.skip("pandas not available - test requires pandas functionality") else: From c3873a21bd2ea16d726cf6b38fc8cd7881bff03e Mon Sep 17 00:00:00 2001 From: "paul@iqmo.com" Date: Tue, 16 Sep 2025 07:07:27 -0400 Subject: [PATCH 198/472] tests: revert xfail for 3.14 now that #48 is merged --- tests/fast/numpy/test_numpy_new_path.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index c1122797..3735ff6e 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -10,7 +10,6 @@ class TestScanNumpy(object): - @pytest.mark.xfail(sys.version_info[:2] == (3, 14), reason="Fails when testing without pandas https://github.com/duckdb/duckdb-python/issues/48") def test_scan_numpy(self, duckdb_cursor): z = np.array([1, 2, 3]) res = duckdb_cursor.sql("select * from z").fetchall() From f2c1e377bed3403b861a69ce8f32f648e1db56df Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 16 Sep 2025 16:07:10 +0200 Subject: [PATCH 199/472] submodule at v1.4.0 --- external/duckdb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/duckdb b/external/duckdb index b3edbac8..b8a06e4a 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit b3edbac8519f8ed04f58a6f30ec349112bdc7d6c +Subproject commit b8a06e4a22672e254cd0baa68a3dbed2eb51c56e From bf5f3352c0be48420459b2b98000bb6fd26787fe Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 16 Sep 2025 16:54:25 +0200 Subject: [PATCH 200/472] set MAIN_BRANCH_VERSIONING to False for v1.4-andium --- duckdb_packaging/setuptools_scm_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 8381e1e2..27bedd24 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -13,7 +13,7 @@ from ._versioning import parse_version, format_version # MAIN_BRANCH_VERSIONING should be 'True' on main branch only -MAIN_BRANCH_VERSIONING = True +MAIN_BRANCH_VERSIONING = False SCM_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" SCM_GLOBAL_PRETEND_ENV_VAR = "SETUPTOOLS_SCM_PRETEND_VERSION" From fe60146e20f7e89afc7ce79c169732674db90ae1 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 16 Sep 2025 17:11:20 +0200 Subject: [PATCH 201/472] Re-enable nightlies for main --- .github/workflows/release.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 77d6dcc6..f54b0f76 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -146,7 +146,6 @@ jobs: publish_pypi: name: Publish Artifacts to PyPI runs-on: ubuntu-latest - if: ${{ !always() }} needs: [workflow_state, build_sdist, build_wheels] environment: name: ${{ needs.workflow_state.outputs.ci_env }} From 2a640a6f3b8b73ae03ca4798940b84e7bac3e815 Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Sep 2025 10:14:10 +0200 Subject: [PATCH 202/472] add hash method --- src/duckdb_py/include/duckdb_python/pytype.hpp | 1 + src/duckdb_py/typing/pytype.cpp | 5 +++++ tests/fast/test_type.py | 14 ++++++++++++++ 3 files changed, 20 insertions(+) diff --git a/src/duckdb_py/include/duckdb_python/pytype.hpp b/src/duckdb_py/include/duckdb_python/pytype.hpp index a6e13dfd..6d1e8074 100644 --- a/src/duckdb_py/include/duckdb_python/pytype.hpp +++ b/src/duckdb_py/include/duckdb_python/pytype.hpp @@ -30,6 +30,7 @@ class DuckDBPyType : public enable_shared_from_this { public: bool Equals(const shared_ptr &other) const; + ssize_t Hash() const; bool EqualsString(const string &type_str) const; shared_ptr GetAttribute(const string &name) const; py::list Children() const; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 009e3dab..01357ad3 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -46,6 +46,10 @@ bool DuckDBPyType::Equals(const shared_ptr &other) const { return type == other->type; } +ssize_t DuckDBPyType::Hash() const { + return py::hash(py::str(ToString())); +} + bool DuckDBPyType::EqualsString(const string &type_str) const { return StringUtil::CIEquals(type.ToString(), type_str); } @@ -328,6 +332,7 @@ void DuckDBPyType::Initialize(py::handle &m) { type_module.def("__repr__", &DuckDBPyType::ToString, "Stringified representation of the type object"); type_module.def("__eq__", &DuckDBPyType::Equals, "Compare two types for equality", py::arg("other"), py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); + type_module.def("__hash__", &DuckDBPyType::Hash, "Hashes the type, equal to stringifying+hashing"); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index 6f648179..c5a62694 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -214,6 +214,20 @@ def test_struct_from_dict(self): res = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) assert res == 'STRUCT(a VARCHAR, b VARCHAR)[]' + def test_hash_method(self): + type1 = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) + type2 = duckdb.list_type({'b': VARCHAR, 'a': VARCHAR}) + type3 = VARCHAR + + type_set = set() + type_set.add(type1) + type_set.add(type2) + type_set.add(type3) + + type_set.add(type1) + expected = ['STRUCT(a VARCHAR, b VARCHAR)[]', 'STRUCT(b VARCHAR, a VARCHAR)[]', 'VARCHAR'] + assert sorted([str(x) for x in list(type_set)]) == expected + # NOTE: we can support this, but I don't think going through hoops for an outdated version of python is worth it @pytest.mark.skipif(sys.version_info < (3, 9), reason="python3.7 does not store Optional[..] in a recognized way") def test_optional(self): From 6b86b571ca2e6bad6dd26432cc93ee6c6a3e147b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 17 Sep 2025 10:45:41 +0200 Subject: [PATCH 203/472] Packaging workflow should respect the 'minimal' input param --- .github/workflows/packaging.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 507c7bda..16771deb 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -74,7 +74,7 @@ jobs: name: Build and test releases uses: ./.github/workflows/packaging_wheels.yml with: - minimal: false + minimal: ${{ inputs.minimal }} testsuite: all duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} duckdb-sha: ${{ inputs.duckdb-sha }} From f39ce02f9ef72663b886e66dab39cce3e95f5f23 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 17 Sep 2025 10:45:41 +0200 Subject: [PATCH 204/472] Packaging workflow should respect the 'minimal' input param --- .github/workflows/packaging.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/packaging.yml b/.github/workflows/packaging.yml index 507c7bda..16771deb 100644 --- a/.github/workflows/packaging.yml +++ b/.github/workflows/packaging.yml @@ -74,7 +74,7 @@ jobs: name: Build and test releases uses: ./.github/workflows/packaging_wheels.yml with: - minimal: false + minimal: ${{ inputs.minimal }} testsuite: all duckdb-python-sha: ${{ inputs.duckdb-python-sha != '' && inputs.duckdb-python-sha || github.sha }} duckdb-sha: ${{ inputs.duckdb-sha }} From 7f930d1a85f9bf20a66d15e1340cea334634013c Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Sep 2025 12:23:48 +0200 Subject: [PATCH 205/472] avoid collision with windows define --- src/duckdb_py/include/duckdb_python/pytype.hpp | 2 +- src/duckdb_py/typing/pytype.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/duckdb_py/include/duckdb_python/pytype.hpp b/src/duckdb_py/include/duckdb_python/pytype.hpp index 6d1e8074..fced489e 100644 --- a/src/duckdb_py/include/duckdb_python/pytype.hpp +++ b/src/duckdb_py/include/duckdb_python/pytype.hpp @@ -30,7 +30,7 @@ class DuckDBPyType : public enable_shared_from_this { public: bool Equals(const shared_ptr &other) const; - ssize_t Hash() const; + ssize_t HashType() const; bool EqualsString(const string &type_str) const; shared_ptr GetAttribute(const string &name) const; py::list Children() const; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 35ff81a9..91a95d91 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -46,7 +46,7 @@ bool DuckDBPyType::Equals(const shared_ptr &other) const { return type == other->type; } -ssize_t DuckDBPyType::Hash() const { +ssize_t DuckDBPyType::HashType() const { return py::hash(py::str(ToString())); } @@ -334,7 +334,7 @@ void DuckDBPyType::Initialize(py::handle &m) { py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); - type_module.def("__hash__", &DuckDBPyType::Hash, "Hashes the type, equal to stringifying+hashing"); + type_module.def("__hash__", &DuckDBPyType::HashType, "Hashes the type, equal to stringifying+hashing"); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { From 21122feb3213811d772caf7005c5d51db72a206a Mon Sep 17 00:00:00 2001 From: Tishj Date: Wed, 17 Sep 2025 16:36:40 +0200 Subject: [PATCH 206/472] third attempt at making windows happy --- src/duckdb_py/include/duckdb_python/pytype.hpp | 1 - src/duckdb_py/typing/pytype.cpp | 8 +++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/duckdb_py/include/duckdb_python/pytype.hpp b/src/duckdb_py/include/duckdb_python/pytype.hpp index fced489e..a6e13dfd 100644 --- a/src/duckdb_py/include/duckdb_python/pytype.hpp +++ b/src/duckdb_py/include/duckdb_python/pytype.hpp @@ -30,7 +30,6 @@ class DuckDBPyType : public enable_shared_from_this { public: bool Equals(const shared_ptr &other) const; - ssize_t HashType() const; bool EqualsString(const string &type_str) const; shared_ptr GetAttribute(const string &name) const; py::list Children() const; diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index 91a95d91..f04c14ba 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -46,10 +46,6 @@ bool DuckDBPyType::Equals(const shared_ptr &other) const { return type == other->type; } -ssize_t DuckDBPyType::HashType() const { - return py::hash(py::str(ToString())); -} - bool DuckDBPyType::EqualsString(const string &type_str) const { return StringUtil::CIEquals(type.ToString(), type_str); } @@ -334,7 +330,9 @@ void DuckDBPyType::Initialize(py::handle &m) { py::is_operator()); type_module.def("__eq__", &DuckDBPyType::EqualsString, "Compare two types for equality", py::arg("other"), py::is_operator()); - type_module.def("__hash__", &DuckDBPyType::HashType, "Hashes the type, equal to stringifying+hashing"); + type_module.def("__hash__", [](const DuckDBPyType &type) { + return py::hash(py::str(type.ToString())); + }); type_module.def_property_readonly("id", &DuckDBPyType::GetId); type_module.def_property_readonly("children", &DuckDBPyType::Children); type_module.def(py::init<>([](const string &type_str, shared_ptr connection = nullptr) { From 8b49f73ed6423571136cac9d1aa98868299a3a73 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Wed, 17 Sep 2025 10:21:30 -0600 Subject: [PATCH 207/472] Apply suggestions from code review Co-authored-by: Evert Lammerts --- CONTRIBUTING.md | 12 +++++++----- README.md | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 669f7b62..11191111 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,10 +36,9 @@ * Submitting changes to an open pull request will move it to 'draft' state. * Pull requests will get a complete run on the main repo CI only when marked as 'ready for review' (via Web UI, button on bottom right). -### Nightly CI +### Testing cross-platform and cross-Python -* Packages creation and long running tests will be performed during a nightly run -* On your fork you can trigger long running tests (NightlyTests.yml) for any branch following information from https://docs.github.com/en/actions/using-workflows/manually-running-a-workflow#running-a-workflow +* On your fork you can [run](https://docs.github.com/en/actions/using-workflows/manually-running-a-workflow#running-a-workflow) the Packaging workflow manually for any branch. You can choose whether to build for all platforms or a subset, and to either run the full testsuite, the fast tests only, or no tests at all. ## Setting up a development environment @@ -60,8 +59,11 @@ git remote add upstream https://github.com/duckdb/duckdb-python.git git fetch --all ``` -The submodule stuff is needed because we vendor the core DuckDB repository as a git submodule, -and to build the python package we also need to build DuckDB itself. +Two things to be aware of when cloning this repository: +* DuckDB is vendored as a git submodule and needs to be initialized during or after cloning duckdb-python. +* Currently, for DuckDB to determine its version while building, it depends on the local availability of its tags. + +After forking the duckdb-python repo we recommend you clone your fork as follows: ### Submodule update hook diff --git a/README.md b/README.md index 7dd7895a..627349b2 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ API Docs (Python)

-# The Official Python bindings to [DuckDB](https://github.com/duckdb/duckdb) +# The [DuckDB](https://github.com/duckdb/duckdb) Python Package ## Installation From ba489176d8aa1b7184debdd5419728f9d944142c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 19:30:17 +0200 Subject: [PATCH 208/472] ruff conf: exclude pyi from linting --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6291b811..4da79b50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -320,6 +320,7 @@ fixable = ["ALL"] exclude = ['external/duckdb'] [tool.ruff.lint] +exclude = ['*.pyi'] select = [ "ANN", # flake8-annotations "B", # flake8-bugbear From 49ef8dc4ba38663bb6a3fcfc8a5993caf9a8347b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 19:33:50 +0200 Subject: [PATCH 209/472] Ruff UP006: dont use typing module for dict and list typehints --- duckdb/experimental/spark/conf.py | 6 +-- .../spark/errors/exceptions/base.py | 6 +-- duckdb/experimental/spark/errors/utils.py | 2 +- duckdb/experimental/spark/sql/_typing.py | 2 +- duckdb/experimental/spark/sql/catalog.py | 8 ++-- duckdb/experimental/spark/sql/column.py | 4 +- duckdb/experimental/spark/sql/dataframe.py | 30 ++++++------ duckdb/experimental/spark/sql/functions.py | 2 +- duckdb/experimental/spark/sql/group.py | 6 +-- duckdb/experimental/spark/sql/readwriter.py | 8 ++-- duckdb/experimental/spark/sql/session.py | 2 +- duckdb/experimental/spark/sql/type_utils.py | 6 +-- duckdb/experimental/spark/sql/types.py | 46 +++++++++---------- duckdb/value/constant/__init__.py | 4 +- duckdb_packaging/build_backend.py | 6 +-- duckdb_packaging/pypi_cleanup.py | 6 +-- .../generate_connection_wrapper_methods.py | 2 +- scripts/generate_import_cache_cpp.py | 14 +++--- scripts/generate_import_cache_json.py | 8 ++-- scripts/get_cpp_methods.py | 4 +- sqllogic/conftest.py | 6 +-- tests/fast/test_filesystem.py | 2 +- tests/fast/test_multithread.py | 4 +- 23 files changed, 92 insertions(+), 92 deletions(-) diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index 11680a9a..a04c993b 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -12,20 +12,20 @@ def contains(self, key: str) -> bool: def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: raise ContributionsAcceptedError - def getAll(self) -> List[Tuple[str, str]]: + def getAll(self) -> list[tuple[str, str]]: raise ContributionsAcceptedError def set(self, key: str, value: str) -> "SparkConf": raise ContributionsAcceptedError - def setAll(self, pairs: List[Tuple[str, str]]) -> "SparkConf": + def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": raise ContributionsAcceptedError def setAppName(self, value: str) -> "SparkConf": raise ContributionsAcceptedError def setExecutorEnv( - self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[List[Tuple[str, str]]] = None + self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[list[tuple[str, str]]] = None ) -> "SparkConf": raise ContributionsAcceptedError diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 21dba03b..80e91170 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -13,7 +13,7 @@ def __init__( # The error class, decides the message format, must be one of the valid options listed in 'error_classes.py' error_class: Optional[str] = None, # The dictionary listing the arguments specified in the message (or the error_class) - message_parameters: Optional[Dict[str, str]] = None, + message_parameters: Optional[dict[str, str]] = None, ): # `message` vs `error_class` & `message_parameters` are mutually exclusive. assert (message is not None and (error_class is None and message_parameters is None)) or ( @@ -24,7 +24,7 @@ def __init__( if message is None: self.message = self.error_reader.get_error_message( - cast(str, error_class), cast(Dict[str, str], message_parameters) + cast(str, error_class), cast(dict[str, str], message_parameters) ) else: self.message = message @@ -45,7 +45,7 @@ def getErrorClass(self) -> Optional[str]: """ return self.error_class - def getMessageParameters(self) -> Optional[Dict[str, str]]: + def getMessageParameters(self) -> Optional[dict[str, str]]: """ Returns a message parameters as a dictionary. diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index a375c0c7..3ef418bd 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -29,7 +29,7 @@ class ErrorClassesReader: def __init__(self) -> None: self.error_info_map = ERROR_CLASSES_MAP - def get_error_message(self, error_class: str, message_parameters: Dict[str, str]) -> str: + def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: """ Returns the completed error message by applying message parameters to the message template. """ diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index 7b1f9ad1..645b60bb 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -57,7 +57,7 @@ float, ) -RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], types.Row) +RowLike = TypeVar("RowLike", list[Any], tuple[Any, ...], types.Row) SQLBatchedUDFType = Literal[100] diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index ebedb1a1..d3b857fb 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -36,7 +36,7 @@ class Catalog: def __init__(self, session: SparkSession): self._session = session - def listDatabases(self) -> List[Database]: + def listDatabases(self) -> list[Database]: res = self._session.conn.sql('select database_name from duckdb_databases()').fetchall() def transform_to_database(x) -> Database: @@ -45,7 +45,7 @@ def transform_to_database(x) -> Database: databases = [transform_to_database(x) for x in res] return databases - def listTables(self) -> List[Table]: + def listTables(self) -> list[Table]: res = self._session.conn.sql('select table_name, database_name, sql, temporary from duckdb_tables()').fetchall() def transform_to_table(x) -> Table: @@ -54,7 +54,7 @@ def transform_to_table(x) -> Table: tables = [transform_to_table(x) for x in res] return tables - def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Column]: + def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: query = f""" select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}' """ @@ -68,7 +68,7 @@ def transform_to_column(x) -> Column: columns = [transform_to_column(x) for x in res] return columns - def listFunctions(self, dbName: Optional[str] = None) -> List[Function]: + def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: raise NotImplementedError def setCurrentDatabase(self, dbName: str) -> None: diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 5f0b2b99..de0c95f8 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -234,10 +234,10 @@ def cast(self, dataType: Union[DataType, str]) -> "Column": def isin(self, *cols: Any) -> "Column": if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list - cols = cast(Tuple, cols[0]) + cols = cast(tuple, cols[0]) cols = cast( - Tuple, + tuple, [_get_expr(c) for c in cols], ) return Column(self.expr.isin(*cols)) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index a81a423b..54c220eb 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -143,7 +143,7 @@ def withColumn(self, columnName: str, col: Column) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": + def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": """ Returns a new :class:`DataFrame` by adding multiple columns or replacing the existing columns that have the same names. @@ -218,7 +218,7 @@ def withColumns(self, *colsMap: Dict[str, Column]) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumnsRenamed(self, colsMap: Dict[str, str]) -> "DataFrame": + def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": """ Returns a new :class:`DataFrame` by renaming multiple columns. This is a no-op if the schema doesn't contain the given column names. @@ -356,7 +356,7 @@ def transform( return result def sort( - self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any + self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any ) -> "DataFrame": """Returns a new :class:`DataFrame` sorted by the specified column(s). @@ -487,7 +487,7 @@ def sort( orderBy = sort - def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: + def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: if n is None: rs = self.head(1) return rs[0] if rs else None @@ -495,7 +495,7 @@ def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: first = head - def take(self, num: int) -> List[Row]: + def take(self, num: int) -> list[Row]: return self.limit(num).collect() def filter(self, condition: "ColumnOrName") -> "DataFrame": @@ -579,7 +579,7 @@ def select(self, *cols) -> "DataFrame": return DataFrame(rel, self.session) @property - def columns(self) -> List[str]: + def columns(self) -> list[str]: """Returns all column names as a list. Examples @@ -589,12 +589,12 @@ def columns(self) -> List[str]: """ return [f.name for f in self.schema.fields] - def _ipython_key_completions_(self) -> List[str]: + def _ipython_key_completions_(self) -> list[str]: # Provides tab-completion for column names in PySpark DataFrame # when accessed in bracket notation, e.g. df['] return self.columns - def __dir__(self) -> List[str]: + def __dir__(self) -> list[str]: out = set(super().__dir__()) out.update(c for c in self.columns if c.isidentifier() and not iskeyword(c)) return sorted(out) @@ -602,7 +602,7 @@ def __dir__(self) -> List[str]: def join( self, other: "DataFrame", - on: Optional[Union[str, List[str], Column, List[Column]]] = None, + on: Optional[Union[str, list[str], Column, list[Column]]] = None, how: Optional[str] = None, ) -> "DataFrame": """Joins with another :class:`DataFrame`, using the given join expression. @@ -704,7 +704,7 @@ def join( assert isinstance( on[0], Expression ), "on should be Column or list of Column" - on = reduce(lambda x, y: x.__and__(y), cast(List[Expression], on)) + on = reduce(lambda x, y: x.__and__(y), cast(list[Expression], on)) if on is None and how is None: @@ -893,11 +893,11 @@ def __getitem__(self, item: Union[int, str]) -> Column: ... @overload - def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame": + def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... def __getitem__( - self, item: Union[int, str, Column, List, Tuple] + self, item: Union[int, str, Column, list, tuple] ) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. @@ -942,7 +942,7 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload - def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": + def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": ... def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] @@ -1259,7 +1259,7 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": """ return DataFrame(self.relation.except_(other.relation), self.session) - def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": + def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -1391,7 +1391,7 @@ def toDF(self, *cols) -> "DataFrame": new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) - def collect(self) -> List[Row]: + def collect(self) -> list[Row]: columns = self.relation.columns result = self.relation.fetchall() diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index fecada95..78b14de7 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -111,7 +111,7 @@ def struct(*cols: Column) -> Column: def array( - *cols: Union["ColumnOrName", Union[List["ColumnOrName"], Tuple["ColumnOrName", ...]]] + *cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]] ) -> Column: """Creates a new array column. diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index e6e99beb..ad7e7e2a 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -177,7 +177,7 @@ def avg(self, *cols: str) -> DataFrame: if len(columns) == 0: schema = self._df.schema # Take only the numeric types of the relation - columns: List[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)] + columns: list[str] = [x.name for x in schema.fields if isinstance(x.dataType, NumericType)] return _api_internal(self, "avg", *columns) @df_varargs_api @@ -312,10 +312,10 @@ def agg(self, *exprs: Column) -> DataFrame: ... @overload - def agg(self, __exprs: Dict[str, str]) -> DataFrame: + def agg(self, __exprs: dict[str, str]) -> DataFrame: ... - def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: + def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: """Compute aggregates and returns the result as a :class:`DataFrame`. The available aggregate functions can be: diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 990201cf..6c8b5e7d 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -26,7 +26,7 @@ def parquet( self, path: str, mode: Optional[str] = None, - partitionBy: Union[str, List[str], None] = None, + partitionBy: Union[str, list[str], None] = None, compression: Optional[str] = None, ) -> None: relation = self.dataframe.relation @@ -94,7 +94,7 @@ def __init__(self, session: "SparkSession"): def load( self, - path: Optional[Union[str, List[str]]] = None, + path: Optional[Union[str, list[str]]] = None, format: Optional[str] = None, schema: Optional[Union[StructType, str]] = None, **options: OptionalPrimitiveType, @@ -131,7 +131,7 @@ def load( def csv( self, - path: Union[str, List[str]], + path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, sep: Optional[str] = None, encoding: Optional[str] = None, @@ -263,7 +263,7 @@ def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame def json( self, - path: Union[str, List[str]], + path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, primitivesAsString: Optional[Union[bool, str]] = None, prefersDecimal: Optional[Union[bool, str]] = None, diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index d3cfaa68..91f9cc0e 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -126,7 +126,7 @@ def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> D def createDataFrame( self, data: Union["PandasDataFrame", Iterable[Any]], - schema: Optional[Union[StructType, List[str]]] = None, + schema: Optional[Union[StructType, list[str]]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, ) -> DataFrame: diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index a17d0f53..ecccc014 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -79,7 +79,7 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType: return ArrayType(convert_type(children[0][1])) # TODO: add support for 'union' if id == 'struct': - children: List[Tuple[str, DuckDBPyType]] = dtype.children + children: list[tuple[str, DuckDBPyType]] = dtype.children fields = [StructField(x[0], convert_type(x[1])) for x in children] return StructType(fields) if id == 'map': @@ -92,7 +92,7 @@ def convert_type(dtype: DuckDBPyType) -> DataType: if id in ['list', 'struct', 'map', 'array']: return convert_nested_type(dtype) if id == 'decimal': - children: List[Tuple[str, DuckDBPyType]] = dtype.children + children: list[tuple[str, DuckDBPyType]] = dtype.children precision = cast(int, children[0][1]) scale = cast(int, children[1][1]) return DecimalType(precision, scale) @@ -100,6 +100,6 @@ def convert_type(dtype: DuckDBPyType) -> DataType: return spark_type() -def duckdb_to_spark_schema(names: List[str], types: List[DuckDBPyType]) -> StructType: +def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types])] return StructType(fields) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 13cd8480..d4dcbd9a 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -92,7 +92,7 @@ def typeName(cls) -> str: def simpleString(self) -> str: return self.typeName() - def jsonValue(self) -> Union[str, Dict[str, Any]]: + def jsonValue(self) -> Union[str, dict[str, Any]]: raise ContributionsAcceptedError def json(self) -> str: @@ -124,9 +124,9 @@ def fromInternal(self, obj: Any) -> Any: class DataTypeSingleton(type): """Metaclass for DataType""" - _instances: ClassVar[Dict[Type["DataTypeSingleton"], "DataTypeSingleton"]] = {} + _instances: ClassVar[dict[type["DataTypeSingleton"], "DataTypeSingleton"]] = {} - def __call__(cls: Type[T]) -> T: # type: ignore[override] + def __call__(cls: type[T]) -> T: # type: ignore[override] if cls not in cls._instances: # type: ignore[attr-defined] cls._instances[cls] = super(DataTypeSingleton, cls).__call__() # type: ignore[misc, attr-defined] return cls._instances[cls] # type: ignore[attr-defined] @@ -603,12 +603,12 @@ def __repr__(self) -> str: def needConversion(self) -> bool: return self.elementType.needConversion() - def toInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: + def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: if not self.needConversion(): return obj return obj and [self.elementType.toInternal(v) for v in obj] - def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: + def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: if not self.needConversion(): return obj return obj and [self.elementType.fromInternal(v) for v in obj] @@ -670,12 +670,12 @@ def __repr__(self) -> str: def needConversion(self) -> bool: return self.keyType.needConversion() or self.valueType.needConversion() - def toInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: + def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: if not self.needConversion(): return obj return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) for k, v in obj.items()) - def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: + def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: if not self.needConversion(): return obj return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()) @@ -710,7 +710,7 @@ def __init__( name: str, dataType: DataType, nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ): super().__init__(dataType.duckdb_type) assert isinstance(dataType, DataType), "dataType %s should be an instance of %s" % ( @@ -776,7 +776,7 @@ class StructType(DataType): def _update_internal_duckdb_type(self): self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) - def __init__(self, fields: Optional[List[StructField]] = None): + def __init__(self, fields: Optional[list[StructField]] = None): if not fields: self.fields = [] self.names = [] @@ -795,7 +795,7 @@ def add( field: str, data_type: Union[str, DataType], nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> "StructType": ... @@ -808,7 +808,7 @@ def add( field: Union[str, StructField], data_type: Optional[Union[str, DataType]] = None, nullable: bool = True, - metadata: Optional[Dict[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, ) -> "StructType": """ Construct a :class:`StructType` by adding new elements to it, to define the schema. @@ -900,7 +900,7 @@ def __repr__(self) -> str: def __contains__(self, item: Any) -> bool: return item in self.names - def extract_types_and_names(self) -> Tuple[List[str], List[str]]: + def extract_types_and_names(self) -> tuple[list[str], list[str]]: names = [] types = [] for f in self.fields: @@ -908,7 +908,7 @@ def extract_types_and_names(self) -> Tuple[List[str], List[str]]: names.append(f.name) return (types, names) - def fieldNames(self) -> List[str]: + def fieldNames(self) -> list[str]: """ Returns all field names in a list. @@ -924,7 +924,7 @@ def needConversion(self) -> bool: # We need convert Row()/namedtuple into tuple() return True - def toInternal(self, obj: Tuple) -> Tuple: + def toInternal(self, obj: tuple) -> tuple: if obj is None: return @@ -956,14 +956,14 @@ def toInternal(self, obj: Tuple) -> Tuple: else: raise ValueError("Unexpected tuple %r with StructType" % obj) - def fromInternal(self, obj: Tuple) -> "Row": + def fromInternal(self, obj: tuple) -> "Row": if obj is None: return if isinstance(obj, Row): # it's already converted by pickler return obj - values: Union[Tuple, List] + values: Union[tuple, list] if self._needSerializeAnyField: # Only calling fromInternal function for fields that need conversion values = [f.fromInternal(v) if c else v for f, v, c in zip(self.fields, obj, self._needConversion)] @@ -1052,7 +1052,7 @@ def __eq__(self, other: Any) -> bool: return type(self) == type(other) -_atomic_types: List[Type[DataType]] = [ +_atomic_types: list[type[DataType]] = [ StringType, BinaryType, BooleanType, @@ -1068,14 +1068,14 @@ def __eq__(self, other: Any) -> bool: TimestampNTZType, NullType, ] -_all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) +_all_atomic_types: dict[str, type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) -_complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [ +_complex_types: list[type[Union[ArrayType, MapType, StructType]]] = [ ArrayType, MapType, StructType, ] -_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict( +_all_complex_types: dict[str, type[Union[ArrayType, MapType, StructType]]] = dict( (v.typeName(), v) for v in _complex_types ) @@ -1084,7 +1084,7 @@ def __eq__(self, other: Any) -> bool: _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") -def _create_row(fields: Union["Row", List[str]], values: Union[Tuple[Any, ...], List[Any]]) -> "Row": +def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], list[Any]]) -> "Row": row = Row(*values) row.__fields__ = fields return row @@ -1166,7 +1166,7 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # create row class or objects return tuple.__new__(cls, args) - def asDict(self, recursive: bool = False) -> Dict[str, Any]: + def asDict(self, recursive: bool = False) -> dict[str, Any]: """ Return as a dict @@ -1260,7 +1260,7 @@ def __setattr__(self, key: Any, value: Any) -> None: def __reduce__( self, - ) -> Union[str, Tuple[Any, ...]]: + ) -> Union[str, tuple[Any, ...]]: """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) diff --git a/duckdb/value/constant/__init__.py b/duckdb/value/constant/__init__.py index da2004b9..0a5a62c0 100644 --- a/duckdb/value/constant/__init__.py +++ b/duckdb/value/constant/__init__.py @@ -210,7 +210,7 @@ def __init__(self, object: Any, child_type: DuckDBPyType): class StructValue(Value): - def __init__(self, object: Any, children: Dict[str, DuckDBPyType]): + def __init__(self, object: Any, children: dict[str, DuckDBPyType]): import duckdb struct_type = duckdb.struct_type(children) @@ -226,7 +226,7 @@ def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType class UnionType(Value): - def __init__(self, object: Any, members: Dict[str, DuckDBPyType]): + def __init__(self, object: Any, members: dict[str, DuckDBPyType]): import duckdb union_type = duckdb.union_type(members) diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index d96a4847..de1a9535 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -126,7 +126,7 @@ def _read_duckdb_long_version() -> str: def _skbuild_config_add( - key: str, value: Union[List, str], config_settings: Dict[str, Union[List[str],str]], fail_if_exists: bool=False + key: str, value: Union[list, str], config_settings: dict[str, Union[list[str],str]], fail_if_exists: bool=False ): """Add or modify a configuration setting for scikit-build-core. @@ -178,7 +178,7 @@ def _skbuild_config_add( ) -def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[List[str],str]]] = None) -> str: +def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str],str]]] = None) -> str: """Build a source distribution using the DuckDB submodule. This function extracts the DuckDB version from either the git submodule and saves it @@ -208,7 +208,7 @@ def build_sdist(sdist_directory: str, config_settings: Optional[Dict[str, Union[ def build_wheel( wheel_directory: str, - config_settings: Optional[Dict[str, Union[List[str],str]]] = None, + config_settings: Optional[dict[str, Union[list[str],str]]] = None, metadata_directory: Optional[str] = None, ) -> str: """Build a wheel from either git submodule or extracted sdist sources. diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 81d4c8e0..8236dd1d 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -290,7 +290,7 @@ def _execute_cleanup(self, http_session: Session) -> int: logging.info(f"Successfully cleaned up {len(versions_to_delete)} development versions") return 0 - def _fetch_released_versions(self, http_session: Session) -> Set[str]: + def _fetch_released_versions(self, http_session: Session) -> set[str]: """Fetch package release information from PyPI API.""" logging.debug(f"Fetching package information for '{self._package}'") @@ -330,7 +330,7 @@ def _parse_dev_version(self, version: str) -> tuple[str, int]: raise PyPICleanupError(f"Invalid dev version '{version}'") return match.group("version"), int(match.group("dev_id")) - def _determine_versions_to_delete(self, versions: Set[str]) -> Set[str]: + def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: """Determine which package versions should be deleted.""" logging.debug("Analyzing versions to determine cleanup candidates") @@ -488,7 +488,7 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp raise AuthenticationError("Two-factor authentication failed after all attempts") - def _delete_versions(self, http_session: Session, versions_to_delete: Set[str]) -> None: + def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) -> None: """Delete the specified package versions.""" logging.info(f"Starting deletion of {len(versions_to_delete)} development versions") diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index 7be7256c..af5ad4ac 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -71,7 +71,7 @@ def is_py_kwargs(method): return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True -def remove_section(content, start_marker, end_marker) -> Tuple[List[str], List[str]]: +def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[str]]: start_index = -1 end_index = -1 for i, line in enumerate(content): diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index 07744e37..f1f9d983 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -16,7 +16,7 @@ # deal with leaf nodes?? Those are just PythonImportCacheItem def get_class_name(path: str) -> str: - parts: List[str] = path.replace('_', '').split('.') + parts: list[str] = path.replace('_', '').split('.') parts = [x.title() for x in parts] return ''.join(parts) + 'CacheItem' @@ -31,7 +31,7 @@ def get_variable_name(name: str) -> str: return name -def collect_items_of_module(module: dict, collection: Dict): +def collect_items_of_module(module: dict, collection: dict): global json_data children = module['children'] collection[module['full_path']] = module @@ -122,8 +122,8 @@ def to_string(self): """ -def collect_classes(items: Dict) -> List: - output: List = [] +def collect_classes(items: dict) -> list: + output: list = [] for item in items.values(): if item['children'] == []: continue @@ -174,7 +174,7 @@ def to_string(self): return string -files: List[ModuleFile] = [] +files: list[ModuleFile] = [] for name, value in json_data.items(): if value['full_path'] != value['name']: continue @@ -188,7 +188,7 @@ def to_string(self): f.write(content) -def get_root_modules(files: List[ModuleFile]): +def get_root_modules(files: list[ModuleFile]): modules = [] for file in files: name = file.module['name'] @@ -244,7 +244,7 @@ def get_root_modules(files: List[ModuleFile]): f.write(import_cache_file) -def get_module_file_path_includes(files: List[ModuleFile]): +def get_module_file_path_includes(files: list[ModuleFile]): includes = [] for file in files: includes.append(f'#include "duckdb_python/import_cache/modules/{file.file_name}"') diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 40e6a773..53d98c57 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -4,7 +4,7 @@ from typing import List, Dict, Union import json -lines: List[str] = [file for file in open(f'{script_dir}/imports.py').read().split('\n') if file != ''] +lines: list[str] = [file for file in open(f'{script_dir}/imports.py').read().split('\n') if file != ''] class ImportCacheAttribute: @@ -13,7 +13,7 @@ def __init__(self, full_path: str): self.type = "attribute" self.name = parts[-1] self.full_path = full_path - self.children: Dict[str, "ImportCacheAttribute"] = {} + self.children: dict[str, "ImportCacheAttribute"] = {} def has_item(self, item_name: str) -> bool: return item_name in self.children @@ -46,7 +46,7 @@ def __init__(self, full_path): self.type = "module" self.name = parts[-1] self.full_path = full_path - self.items: Dict[str, Union[ImportCacheAttribute, "ImportCacheModule"]] = {} + self.items: dict[str, Union[ImportCacheAttribute, "ImportCacheModule"]] = {} def add_item(self, item: Union[ImportCacheAttribute, "ImportCacheModule"]): assert self.full_path != item.full_path @@ -79,7 +79,7 @@ def root_module(self) -> bool: class ImportCacheGenerator: def __init__(self): - self.modules: Dict[str, ImportCacheModule] = {} + self.modules: dict[str, ImportCacheModule] = {} def add_module(self, path: str): assert path.startswith('import') diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index e784d054..9f86b4cb 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -16,7 +16,7 @@ def __init__(self, name: str, proto: str): class ConnectionMethod: - def __init__(self, name: str, params: List[FunctionParam], is_void: bool): + def __init__(self, name: str, params: list[FunctionParam], is_void: bool): self.name = name self.params = params self.is_void = is_void @@ -49,7 +49,7 @@ def on_class_method(self, state, node): self.methods_dict[name] = ConnectionMethod(name, params, is_void) -def get_methods(class_name: str) -> Dict[str, ConnectionMethod]: +def get_methods(class_name: str) -> dict[str, ConnectionMethod]: CLASSES = { "DuckDBPyConnection": os.path.join( scripts_folder, diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 73219e0d..64ad8edc 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -90,7 +90,7 @@ def get_test_id(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Confi return str(path.relative_to(root_dir.parent)) -def get_test_marks(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Config) -> typing.List[typing.Any]: +def get_test_marks(path: pathlib.Path, root_dir: pathlib.Path, config: pytest.Config) -> list[typing.Any]: # Tests are tagged with the their category (i.e., name of their parent directory) category = path.parent.name @@ -142,7 +142,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): if metafunc.definition.name != SQLLOGIC_TEST_CASE_NAME: return - test_dirs: typing.List[pathlib.Path] = metafunc.config.getoption("test_dirs") + test_dirs: list[pathlib.Path] = metafunc.config.getoption("test_dirs") test_glob: typing.Optional[pathlib.Path] = metafunc.config.getoption("path") parameters = [] @@ -165,7 +165,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): metafunc.parametrize(SQLLOGIC_TEST_PARAMETER, parameters) -def determine_test_offsets(config: pytest.Config, num_tests: int) -> typing.Tuple[int, int]: +def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, int]: """ If start_offset and end_offset are specified, then these are used. start_offset defaults to 0. end_offset defaults to and is capped to the last test index. diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index eaa86398..195de165 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -20,7 +20,7 @@ logging.basicConfig(level=logging.DEBUG) -def intercept(monkeypatch: MonkeyPatch, obj: object, name: str) -> List[str]: +def intercept(monkeypatch: MonkeyPatch, obj: object, name: str) -> list[str]: error_occurred = [] orig = getattr(obj, name) diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index 1ffdfc25..4b470b84 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -20,7 +20,7 @@ def connect_duck(duckdb_conn): assert out == [(42,), (84,), (None,), (128,)] -def everything_succeeded(results: List[bool]): +def everything_succeeded(results: list[bool]): return all([result == True for result in results]) @@ -501,7 +501,7 @@ def test_description(self, duckdb_cursor, pandas): @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) def test_cursor(self, duckdb_cursor, pandas): - def only_some_succeed(results: List[bool]): + def only_some_succeed(results: list[bool]): if not any([result == True for result in results]): return False if all([result == True for result in results]): From a00166cbea0e3b9403f05ebf4b7d121956156e07 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:12:46 +0200 Subject: [PATCH 210/472] Ruff ANN204: return type annotations --- duckdb/bytes_io_wrapper.py | 4 +- duckdb/experimental/spark/_globals.py | 10 +-- duckdb/experimental/spark/conf.py | 2 +- duckdb/experimental/spark/context.py | 2 +- .../spark/errors/exceptions/base.py | 2 +- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/catalog.py | 2 +- duckdb/experimental/spark/sql/column.py | 4 +- duckdb/experimental/spark/sql/conf.py | 2 +- duckdb/experimental/spark/sql/dataframe.py | 4 +- duckdb/experimental/spark/sql/group.py | 6 +- duckdb/experimental/spark/sql/readwriter.py | 4 +- duckdb/experimental/spark/sql/session.py | 4 +- duckdb/experimental/spark/sql/streaming.py | 4 +- duckdb/experimental/spark/sql/types.py | 70 +++++++++---------- duckdb/experimental/spark/sql/udf.py | 2 +- duckdb/query_graph/__main__.py | 2 +- duckdb/value/constant/__init__.py | 66 ++++++++--------- duckdb_packaging/pypi_cleanup.py | 4 +- scripts/generate_import_cache_cpp.py | 4 +- scripts/generate_import_cache_json.py | 6 +- scripts/get_cpp_methods.py | 10 +-- sqllogic/test_sqllogic.py | 2 +- tests/conftest.py | 14 ++-- tests/fast/api/test_fsspec.py | 2 +- tests/fast/api/test_read_csv.py | 10 +-- tests/fast/arrow/test_arrow_extensions.py | 4 +- tests/fast/arrow/test_arrow_list.py | 2 +- tests/fast/arrow/test_arrow_pycapsule.py | 8 +-- tests/fast/arrow/test_dataset.py | 4 +- .../fast/pandas/test_df_object_resolution.py | 8 +-- tests/fast/pandas/test_pandas_types.py | 2 +- tests/fast/test_expression.py | 2 +- tests/fast/test_multithread.py | 2 +- tests/fast/udf/test_scalar.py | 6 +- 35 files changed, 143 insertions(+), 139 deletions(-) diff --git a/duckdb/bytes_io_wrapper.py b/duckdb/bytes_io_wrapper.py index 829b69cd..0957652b 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -1,5 +1,5 @@ from io import StringIO, TextIOBase -from typing import Union +from typing import Any, Union """ BSD 3-Clause License @@ -48,7 +48,7 @@ def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") # overflow to the front of the bytestring the next time reading is performed self.overflow = b"" - def __getattr__(self, attr: str): + def __getattr__(self, attr: str) -> Any: return getattr(self.buffer, attr) def read(self, n: Union[int, None] = -1) -> bytes: diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index c43287e6..be16be41 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -32,6 +32,8 @@ def foo(arg=pyducdkb.spark._NoValue): Note that this approach is taken after from NumPy. """ +from typing import Type + __ALL__ = ["_NoValue"] @@ -54,23 +56,23 @@ class _NoValueType: __instance = None - def __new__(cls): + def __new__(cls) -> '_NoValueType': # ensure that only one instance exists if not cls.__instance: cls.__instance = super(_NoValueType, cls).__new__(cls) return cls.__instance # Make the _NoValue instance falsey - def __nonzero__(self): + def __nonzero__(self) -> bool: return False __bool__ = __nonzero__ # needed for python 2 to preserve identity through a pickle - def __reduce__(self): + def __reduce__(self) -> tuple[Type, tuple]: return (self.__class__, ()) - def __repr__(self): + def __repr__(self) -> str: return "" diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index a04c993b..79706781 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -3,7 +3,7 @@ class SparkConf: - def __init__(self): + def __init__(self) -> None: raise NotImplementedError def contains(self, key: str) -> bool: diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index a2e7c78f..95227add 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -7,7 +7,7 @@ class SparkContext: - def __init__(self, master: str): + def __init__(self, master: str) -> None: self._connection = duckdb.connect(':memory:') # This aligns the null ordering with Spark. self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'") diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 80e91170..fcdce827 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -14,7 +14,7 @@ def __init__( error_class: Optional[str] = None, # The dictionary listing the arguments specified in the message (or the error_class) message_parameters: Optional[dict[str, str]] = None, - ): + ) -> None: # `message` vs `error_class` & `message_parameters` are mutually exclusive. assert (message is not None and (error_class is None and message_parameters is None)) or ( message is None and (error_class is not None and message_parameters is not None) diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 7cb47650..21668cf5 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -5,7 +5,7 @@ class ContributionsAcceptedError(NotImplementedError): feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb """ - def __init__(self, message=None): + def __init__(self, message=None) -> None: doc = self.__class__.__doc__ if message: doc = message + '\n' + doc diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index d3b857fb..0cd790f7 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -33,7 +33,7 @@ class Function(NamedTuple): class Catalog: - def __init__(self, session: SparkSession): + def __init__(self, session: SparkSession) -> None: self._session = session def listDatabases(self) -> list[Database]: diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index de0c95f8..0dd86178 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -95,11 +95,11 @@ class Column: .. versionadded:: 1.3.0 """ - def __init__(self, expr: Expression): + def __init__(self, expr: Expression) -> None: self.expr = expr # arithmetic operators - def __neg__(self): + def __neg__(self) -> 'Column': return Column(-self.expr) # `and`, `or`, `not` cannot be overloaded in Python, diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index 98b773fb..8e30d7ca 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ b/duckdb/experimental/spark/sql/conf.py @@ -4,7 +4,7 @@ class RuntimeConfig: - def __init__(self, connection: DuckDBPyConnection): + def __init__(self, connection: DuckDBPyConnection) -> None: self._connection = connection def set(self, key: str, value: str) -> None: diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 54c220eb..42a5b8f0 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -37,7 +37,7 @@ class DataFrame: - def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession"): + def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: self.relation = relation self.session = session self._schema = None @@ -870,7 +870,7 @@ def limit(self, num: int) -> "DataFrame": rel = self.relation.limit(num) return DataFrame(rel, self.session) - def __contains__(self, item: str): + def __contains__(self, item: str) -> bool: """ Check if the :class:`DataFrame` contains a column by the name of `item` """ diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index ad7e7e2a..4c4d5bb6 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -53,7 +53,7 @@ def _api(self: "GroupedData", *cols: str) -> DataFrame: class Grouping: - def __init__(self, *cols: "ColumnOrName", **kwargs): + def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: self._type = "" self._cols = [_to_column_expr(x) for x in cols] if 'special' in kwargs: @@ -66,7 +66,7 @@ def get_columns(self) -> str: columns = ",".join([str(x) for x in self._cols]) return columns - def __str__(self): + def __str__(self) -> str: columns = self.get_columns() if self._type: return self._type + '(' + columns + ')' @@ -80,7 +80,7 @@ class GroupedData: """ - def __init__(self, grouping: Grouping, df: DataFrame): + def __init__(self, grouping: Grouping, df: DataFrame) -> None: self._grouping = grouping self._df = df self.session: SparkSession = df.session diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 6c8b5e7d..6e8c72c6 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -15,7 +15,7 @@ class DataFrameWriter: - def __init__(self, dataframe: "DataFrame"): + def __init__(self, dataframe: "DataFrame") -> None: self.dataframe = dataframe def saveAsTable(self, table_name: str) -> None: @@ -89,7 +89,7 @@ def csv( class DataFrameReader: - def __init__(self, session: "SparkSession"): + def __init__(self, session: "SparkSession") -> None: self.session = session def load( diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index 91f9cc0e..744a77e8 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -45,7 +45,7 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType): class SparkSession: - def __init__(self, context: SparkContext): + def __init__(self, context: SparkContext) -> None: self.conn = context.connection self._context = context self._conf = RuntimeConfig(self.conn) @@ -258,7 +258,7 @@ def version(self) -> str: return '1.0.0' class Builder: - def __init__(self): + def __init__(self) -> None: pass def master(self, name: str) -> "SparkSession.Builder": diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 5414344f..cda80602 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -10,7 +10,7 @@ class DataStreamWriter: - def __init__(self, dataframe: "DataFrame"): + def __init__(self, dataframe: "DataFrame") -> None: self.dataframe = dataframe def toTable(self, table_name: str) -> None: @@ -19,7 +19,7 @@ def toTable(self, table_name: str) -> None: class DataStreamReader: - def __init__(self, session: "SparkSession"): + def __init__(self, session: "SparkSession") -> None: self.session = session def load( diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index d4dcbd9a..4b3a4132 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -70,7 +70,7 @@ class DataType: """Base class for data types.""" - def __init__(self, duckdb_type): + def __init__(self, duckdb_type) -> None: self.duckdb_type = duckdb_type def __repr__(self) -> str: @@ -138,7 +138,7 @@ class NullType(DataType, metaclass=DataTypeSingleton): The data type representing None, used for the types that cannot be inferred. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("NULL")) @classmethod @@ -166,42 +166,42 @@ class FractionalType(NumericType): class StringType(AtomicType, metaclass=DataTypeSingleton): """String data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("VARCHAR")) class BitstringType(AtomicType, metaclass=DataTypeSingleton): """Bitstring data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BIT")) class UUIDType(AtomicType, metaclass=DataTypeSingleton): """UUID data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UUID")) class BinaryType(AtomicType, metaclass=DataTypeSingleton): """Binary (byte array) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BLOB")) class BooleanType(AtomicType, metaclass=DataTypeSingleton): """Boolean data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BOOLEAN")) class DateType(AtomicType, metaclass=DataTypeSingleton): """Date (datetime.date) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("DATE")) EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() @@ -221,7 +221,7 @@ def fromInternal(self, v: int) -> datetime.date: class TimestampType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMPTZ")) @classmethod @@ -245,7 +245,7 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with microsecond precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP")) def needConversion(self) -> bool: @@ -269,7 +269,7 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampSecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with second precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP_S")) def needConversion(self) -> bool: @@ -289,7 +289,7 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampMilisecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with milisecond precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP_MS")) def needConversion(self) -> bool: @@ -309,7 +309,7 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampNanosecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with nanosecond precision.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMESTAMP_NS")) def needConversion(self) -> bool: @@ -346,7 +346,7 @@ class DecimalType(FractionalType): the number of digits on right side of dot. (default: 0) """ - def __init__(self, precision: int = 10, scale: int = 0): + def __init__(self, precision: int = 10, scale: int = 0) -> None: super().__init__(duckdb.decimal_type(precision, scale)) self.precision = precision self.scale = scale @@ -362,21 +362,21 @@ def __repr__(self) -> str: class DoubleType(FractionalType, metaclass=DataTypeSingleton): """Double data type, representing double precision floats.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("DOUBLE")) class FloatType(FractionalType, metaclass=DataTypeSingleton): """Float data type, representing single precision floats.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("FLOAT")) class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TINYINT")) def simpleString(self) -> str: @@ -386,7 +386,7 @@ def simpleString(self) -> str: class UnsignedByteType(IntegralType): """Unsigned byte data type, i.e. a unsigned integer in a single byte.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UTINYINT")) def simpleString(self) -> str: @@ -396,7 +396,7 @@ def simpleString(self) -> str: class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("SMALLINT")) def simpleString(self) -> str: @@ -406,7 +406,7 @@ def simpleString(self) -> str: class UnsignedShortType(IntegralType): """Unsigned short data type, i.e. a unsigned 16-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("USMALLINT")) def simpleString(self) -> str: @@ -416,7 +416,7 @@ def simpleString(self) -> str: class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("INTEGER")) def simpleString(self) -> str: @@ -426,7 +426,7 @@ def simpleString(self) -> str: class UnsignedIntegerType(IntegralType): """Unsigned int data type, i.e. a unsigned 32-bit integer.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UINTEGER")) def simpleString(self) -> str: @@ -440,7 +440,7 @@ class LongType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("BIGINT")) def simpleString(self) -> str: @@ -454,7 +454,7 @@ class UnsignedLongType(IntegralType): please use :class:`HugeIntegerType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UBIGINT")) def simpleString(self) -> str: @@ -468,7 +468,7 @@ class HugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("HUGEINT")) def simpleString(self) -> str: @@ -482,7 +482,7 @@ class UnsignedHugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("UHUGEINT")) def simpleString(self) -> str: @@ -492,7 +492,7 @@ def simpleString(self) -> str: class TimeType(IntegralType): """Time (datetime.time) data type.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIMETZ")) def simpleString(self) -> str: @@ -502,7 +502,7 @@ def simpleString(self) -> str: class TimeNTZType(IntegralType): """Time (datetime.time) data type without timezone information.""" - def __init__(self): + def __init__(self) -> None: super().__init__(DuckDBPyType("TIME")) def simpleString(self) -> str: @@ -526,7 +526,7 @@ class DayTimeIntervalType(AtomicType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) - def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): + def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: super().__init__(DuckDBPyType("INTERVAL")) if startField is None and endField is None: # Default matched to scala side. @@ -585,7 +585,7 @@ class ArrayType(DataType): False """ - def __init__(self, elementType: DataType, containsNull: bool = True): + def __init__(self, elementType: DataType, containsNull: bool = True) -> None: super().__init__(duckdb.list_type(elementType.duckdb_type)) assert isinstance(elementType, DataType), "elementType %s should be an instance of %s" % ( elementType, @@ -640,7 +640,7 @@ class MapType(DataType): False """ - def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True): + def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True) -> None: super().__init__(duckdb.map_type(keyType.duckdb_type, valueType.duckdb_type)) assert isinstance(keyType, DataType), "keyType %s should be an instance of %s" % ( keyType, @@ -711,7 +711,7 @@ def __init__( dataType: DataType, nullable: bool = True, metadata: Optional[dict[str, Any]] = None, - ): + ) -> None: super().__init__(dataType.duckdb_type) assert isinstance(dataType, DataType), "dataType %s should be an instance of %s" % ( dataType, @@ -776,7 +776,7 @@ class StructType(DataType): def _update_internal_duckdb_type(self): self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) - def __init__(self, fields: Optional[list[StructField]] = None): + def __init__(self, fields: Optional[list[StructField]] = None) -> None: if not fields: self.fields = [] self.names = [] @@ -973,7 +973,7 @@ def fromInternal(self, obj: tuple) -> "Row": class UnionType(DataType): - def __init__(self): + def __init__(self) -> None: raise ContributionsAcceptedError @@ -983,7 +983,7 @@ class UserDefinedType(DataType): .. note:: WARN: Spark Internal Use Only """ - def __init__(self): + def __init__(self) -> None: raise ContributionsAcceptedError @classmethod diff --git a/duckdb/experimental/spark/sql/udf.py b/duckdb/experimental/spark/sql/udf.py index 61d3bee9..389d43ab 100644 --- a/duckdb/experimental/spark/sql/udf.py +++ b/duckdb/experimental/spark/sql/udf.py @@ -11,7 +11,7 @@ class UDFRegistration: - def __init__(self, sparkSession: "SparkSession"): + def __init__(self, sparkSession: "SparkSession") -> None: self.sparkSession = sparkSession def register( diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index 26038a6f..eab68179 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -95,7 +95,7 @@ def combine_timing(l: object, r: object) -> object: class AllTimings: - def __init__(self): + def __init__(self) -> None: self.phase_to_timings = {} def add_node_timing(self, node_timing: NodeTiming): diff --git a/duckdb/value/constant/__init__.py b/duckdb/value/constant/__init__.py index 0a5a62c0..fb7d7284 100644 --- a/duckdb/value/constant/__init__.py +++ b/duckdb/value/constant/__init__.py @@ -32,7 +32,7 @@ class Value: - def __init__(self, object: Any, type: DuckDBPyType): + def __init__(self, object: Any, type: DuckDBPyType) -> None: self.object = object self.type = type @@ -44,12 +44,12 @@ def __repr__(self) -> str: class NullValue(Value): - def __init__(self): + def __init__(self) -> None: super().__init__(None, SQLNULL) class BooleanValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BOOLEAN) @@ -57,22 +57,22 @@ def __init__(self, object: Any): class UnsignedBinaryValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UTINYINT) class UnsignedShortValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, USMALLINT) class UnsignedIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UINTEGER) class UnsignedLongValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UBIGINT) @@ -80,32 +80,32 @@ def __init__(self, object: Any): class BinaryValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TINYINT) class ShortValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, SMALLINT) class IntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, INTEGER) class LongValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BIGINT) class HugeIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, HUGEINT) class UnsignedHugeIntegerValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UHUGEINT) @@ -113,17 +113,17 @@ def __init__(self, object: Any): class FloatValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, FLOAT) class DoubleValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, DOUBLE) class DecimalValue(Value): - def __init__(self, object: Any, width: int, scale: int): + def __init__(self, object: Any, width: int, scale: int) -> None: import duckdb decimal_type = duckdb.decimal_type(width, scale) @@ -134,22 +134,22 @@ def __init__(self, object: Any, width: int, scale: int): class StringValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, VARCHAR) class UUIDValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, UUID) class BitValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BIT) class BlobValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, BLOB) @@ -157,52 +157,52 @@ def __init__(self, object: Any): class DateValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, DATE) class IntervalValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, INTERVAL) class TimestampValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP) class TimestampSecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_S) class TimestampMilisecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_MS) class TimestampNanosecondValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_NS) class TimestampTimeZoneValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIMESTAMP_TZ) class TimeValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIME) class TimeTimeZoneValue(Value): - def __init__(self, object: Any): + def __init__(self, object: Any) -> None: super().__init__(object, TIME_TZ) class ListValue(Value): - def __init__(self, object: Any, child_type: DuckDBPyType): + def __init__(self, object: Any, child_type: DuckDBPyType) -> None: import duckdb list_type = duckdb.list_type(child_type) @@ -210,7 +210,7 @@ def __init__(self, object: Any, child_type: DuckDBPyType): class StructValue(Value): - def __init__(self, object: Any, children: dict[str, DuckDBPyType]): + def __init__(self, object: Any, children: dict[str, DuckDBPyType]) -> None: import duckdb struct_type = duckdb.struct_type(children) @@ -218,7 +218,7 @@ def __init__(self, object: Any, children: dict[str, DuckDBPyType]): class MapValue(Value): - def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType): + def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType) -> None: import duckdb map_type = duckdb.map_type(key_type, value_type) @@ -226,7 +226,7 @@ def __init__(self, object: Any, key_type: DuckDBPyType, value_type: DuckDBPyType class UnionType(Value): - def __init__(self, object: Any, members: dict[str, DuckDBPyType]): + def __init__(self, object: Any, members: dict[str, DuckDBPyType]) -> None: import duckdb union_type = duckdb.union_type(members) diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 8236dd1d..031adf94 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -183,7 +183,7 @@ class CsrfParser(HTMLParser): Based on pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) """ - def __init__(self, target, contains_input=None): + def __init__(self, target, contains_input=None) -> None: super().__init__() self._target = target self._contains_input = contains_input @@ -223,7 +223,7 @@ class PyPICleanup: """Main class for performing PyPI package cleanup operations.""" def __init__(self, index_url: str, do_delete: bool, max_dev_releases: int=_DEFAULT_MAX_NIGHTLIES, - username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None): + username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None) -> None: parsed_url = urlparse(index_url) self._index_url = parsed_url.geturl().rstrip('/') self._index_host = parsed_url.hostname diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index f1f9d983..f03d8d89 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -40,7 +40,7 @@ def collect_items_of_module(module: dict, collection: dict): class CacheItem: - def __init__(self, module: dict, items): + def __init__(self, module: dict, items) -> None: self.name = module['name'] self.module = module self.items = items @@ -132,7 +132,7 @@ def collect_classes(items: dict) -> list: class ModuleFile: - def __init__(self, module: dict): + def __init__(self, module: dict) -> None: self.module = module self.file_name = get_filename(module['name']) self.items = {} diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 53d98c57..2df33b24 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -8,7 +8,7 @@ class ImportCacheAttribute: - def __init__(self, full_path: str): + def __init__(self, full_path: str) -> None: parts = full_path.split('.') self.type = "attribute" self.name = parts[-1] @@ -41,7 +41,7 @@ def populate_json(self, json_data: dict): class ImportCacheModule: - def __init__(self, full_path): + def __init__(self, full_path) -> None: parts = full_path.split('.') self.type = "module" self.name = parts[-1] @@ -78,7 +78,7 @@ def root_module(self) -> bool: class ImportCacheGenerator: - def __init__(self): + def __init__(self) -> None: self.modules: dict[str, ImportCacheModule] = {} def add_module(self, path: str): diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index 9f86b4cb..97b28af3 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -4,30 +4,30 @@ import cxxheaderparser.parser import cxxheaderparser.visitor import cxxheaderparser.preprocessor -from typing import List, Dict +from typing import List, Dict, Callable scripts_folder = os.path.dirname(os.path.abspath(__file__)) class FunctionParam: - def __init__(self, name: str, proto: str): + def __init__(self, name: str, proto: str) -> None: self.proto = proto self.name = name class ConnectionMethod: - def __init__(self, name: str, params: list[FunctionParam], is_void: bool): + def __init__(self, name: str, params: list[FunctionParam], is_void: bool) -> None: self.name = name self.params = params self.is_void = is_void class Visitor: - def __init__(self, class_name: str): + def __init__(self, class_name: str) -> None: self.methods_dict = {} self.class_name = class_name - def __getattr__(self, name): + def __getattr__(self, name) -> Callable[[...], bool]: return lambda *state: True def on_class_start(self, state): diff --git a/sqllogic/test_sqllogic.py b/sqllogic/test_sqllogic.py index ee7426cd..4e7cead0 100644 --- a/sqllogic/test_sqllogic.py +++ b/sqllogic/test_sqllogic.py @@ -39,7 +39,7 @@ def sigquit_handler(signum, frame): class SQLLogicTestExecutor(SQLLogicRunner): - def __init__(self, test_directory: str, build_directory: Optional[str] = None): + def __init__(self, test_directory: str, build_directory: Optional[str] = None) -> None: super().__init__(build_directory) self.test_directory = test_directory # TODO: get this from the `duckdb` package diff --git a/tests/conftest.py b/tests/conftest.py index ce2d0e68..b9950ee7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import os +from typing import Any + import pytest import shutil from os.path import abspath, join, dirname, normpath @@ -121,12 +123,12 @@ def arrow_pandas_df(*args, **kwargs): class NumpyPandas: - def __init__(self): + def __init__(self) -> None: self.backend = 'numpy_nullable' self.DataFrame = numpy_pandas_df self.pandas = import_pandas() - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: return getattr(self.pandas, name) @@ -156,11 +158,11 @@ def convert_and_equal(df1, df2, **kwargs): class ArrowMockTesting: - def __init__(self): + def __init__(self) -> None: self.testing = import_pandas().testing self.assert_frame_equal = convert_and_equal - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: return getattr(self.testing, name) @@ -168,7 +170,7 @@ def __getattr__(self, name: str): # Assert equal does the opposite, turning all pyarrow backed dataframes into numpy backed ones # this is done because we don't produce pyarrow backed dataframes yet class ArrowPandas: - def __init__(self): + def __init__(self) -> None: self.pandas = import_pandas() if pandas_2_or_higher() and pyarrow_dtypes_enabled: self.backend = 'pyarrow' @@ -179,7 +181,7 @@ def __init__(self): self.DataFrame = self.pandas.DataFrame self.testing = ArrowMockTesting() - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: return getattr(self.pandas, name) diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index 0a289972..a878fda5 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -44,7 +44,7 @@ def modified(self, path): def _open(self, path, **kwargs): return io.BytesIO(self._data[path]) - def __init__(self): + def __init__(self) -> None: super().__init__() self._data = {"a": parquet_data, "b": parquet_data} diff --git a/tests/fast/api/test_read_csv.py b/tests/fast/api/test_read_csv.py index 1a297109..7337515d 100644 --- a/tests/fast/api/test_read_csv.py +++ b/tests/fast/api/test_read_csv.py @@ -327,7 +327,7 @@ def test_filelike_exception(self, duckdb_cursor): _ = pytest.importorskip("fsspec") class ReadError: - def __init__(self): + def __init__(self) -> None: pass def read(self, amount=-1): @@ -337,7 +337,7 @@ def seek(self, loc): return 0 class SeekError: - def __init__(self): + def __init__(self) -> None: pass def read(self, amount=-1): @@ -359,7 +359,7 @@ def test_filelike_custom(self, duckdb_cursor): _ = pytest.importorskip("fsspec") class CustomIO: - def __init__(self): + def __init__(self) -> None: self.loc = 0 pass @@ -398,11 +398,11 @@ def test_internal_object_filesystem_cleanup(self, duckdb_cursor): class CountedObject(StringIO): instance_count = 0 - def __init__(self, str): + def __init__(self, str) -> None: CountedObject.instance_count += 1 super().__init__(str) - def __del__(self): + def __del__(self) -> None: CountedObject.instance_count -= 1 def scoped_objects(duckdb_cursor): diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index 9180fa90..95a2108a 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -116,10 +116,10 @@ def test_function(x): def test_unimplemented_extension(self, duckdb_cursor): class MyType(pa.ExtensionType): - def __init__(self): + def __init__(self) -> None: pa.ExtensionType.__init__(self, pa.binary(5), "pedro.binary") - def __arrow_ext_serialize__(self): + def __arrow_ext_serialize__(self) -> bytes: return b'' @classmethod diff --git a/tests/fast/arrow/test_arrow_list.py b/tests/fast/arrow/test_arrow_list.py index e2449fd3..556f614a 100644 --- a/tests/fast/arrow/test_arrow_list.py +++ b/tests/fast/arrow/test_arrow_list.py @@ -41,7 +41,7 @@ def create_and_register_comparison_result(column_list, duckdb_cursor): class ListGenerationResult: - def __init__(self, list, list_view): + def __init__(self, list, list_view) -> None: self.list = list self.list_view = list_view diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index c293344d..8310c58b 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -17,11 +17,11 @@ def polars_supports_capsule(): class TestArrowPyCapsule(object): def test_polars_pycapsule_scan(self, duckdb_cursor): class MyObject: - def __init__(self, obj): + def __init__(self, obj) -> None: self.obj = obj self.count = 0 - def __arrow_c_stream__(self, requested_schema=None): + def __arrow_c_stream__(self, requested_schema=None) -> object: self.count += 1 return self.obj.__arrow_c_stream__(requested_schema=requested_schema) @@ -71,11 +71,11 @@ def test_automatic_reexecution(self, duckdb_cursor): def test_consumer_interface_roundtrip(self, duckdb_cursor): def create_table(): class MyTable: - def __init__(self, rel, conn): + def __init__(self, rel, conn) -> None: self.rel = rel self.conn = conn - def __arrow_c_stream__(self, requested_schema=None): + def __arrow_c_stream__(self, requested_schema=None) -> object: return self.rel.__arrow_c_stream__(requested_schema=requested_schema) conn = duckdb.connect() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 2f3d7a53..521ec8f7 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -102,7 +102,7 @@ class CustomDataset(pyarrow.dataset.Dataset): SCHEMA = pyarrow.schema([pyarrow.field("a", pyarrow.int64(), True), pyarrow.field("b", pyarrow.float64(), True)]) DATA = pyarrow.Table.from_arrays([pyarrow.array(range(100)), pyarrow.array(np.arange(100) * 1.0)], schema=SCHEMA) - def __init__(self): + def __init__(self) -> None: pass def scanner(self, **kwargs): @@ -114,7 +114,7 @@ def schema(self): class CustomScanner(pyarrow.dataset.Scanner): - def __init__(self, filter=None, columns=None, **kwargs): + def __init__(self, filter=None, columns=None, **kwargs) -> None: self.filter = filter self.columns = columns self.kwargs = kwargs diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index ed89f324..d54db072 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ b/tests/fast/pandas/test_df_object_resolution.py @@ -30,10 +30,10 @@ def create_trailing_non_null(size): class IntString: - def __init__(self, value: int): + def __init__(self, value: int) -> None: self.value = value - def __str__(self): + def __str__(self) -> str: return str(self.value) @@ -48,11 +48,11 @@ def ConvertStringToDecimal(data: list, pandas): class ObjectPair: - def __init__(self, obj1, obj2): + def __init__(self, obj1, obj2) -> None: self.first = obj1 self.second = obj2 - def __repr__(self): + def __repr__(self) -> str: return str([self.first, self.second]) diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index aeb33ea4..b21c7f14 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -185,7 +185,7 @@ def test_pandas_encoded_utf8(self, duckdb_cursor): ) def test_producing_nullable_dtypes(self, duckdb_cursor, dtype): class Input: - def __init__(self, value, expected_dtype): + def __init__(self, value, expected_dtype) -> None: self.value = value self.expected_dtype = expected_dtype diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index 289d88a9..e0f830c5 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -987,7 +987,7 @@ def test_aggregate_error(self): ): class MyClass: - def __init__(self): + def __init__(self) -> None: pass res = rel.aggregate([MyClass()]).fetchone()[0] diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index 4b470b84..ad2d56fd 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -25,7 +25,7 @@ def everything_succeeded(results: list[bool]): class DuckDBThreaded: - def __init__(self, duckdb_insert_thread_count, thread_function, pandas): + def __init__(self, duckdb_insert_thread_count, thread_function, pandas) -> None: self.duckdb_insert_thread_count = duckdb_insert_thread_count self.threads = [] self.thread_function = thread_function diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index 61648c20..8e0eb8b1 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -4,7 +4,7 @@ pd = pytest.importorskip("pandas") pa = pytest.importorskip('pyarrow', '18.0.0') -from typing import Union +from typing import Union, Any import pyarrow.compute as pc import uuid import datetime @@ -156,10 +156,10 @@ def test_non_callable(self): con.create_function('func', 5, [BIGINT], BIGINT, type='arrow') class MyCallable: - def __init__(self): + def __init__(self) -> None: pass - def __call__(self, x): + def __call__(self, x) -> Any: return x my_callable = MyCallable() From 517df46c255ca3dce3177bf8c8e83ae4cdf1f7e2 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:24:46 +0200 Subject: [PATCH 211/472] Ruff config: line-length to 120 and fixable no longer top level --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4da79b50..a53f9eb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -312,14 +312,14 @@ branch = true source = ["duckdb"] [tool.ruff] -line-length = 88 +line-length = 120 indent-width = 4 target-version = "py39" fix = true -fixable = ["ALL"] exclude = ['external/duckdb'] [tool.ruff.lint] +fixable = ["ALL"] exclude = ['*.pyi'] select = [ "ANN", # flake8-annotations From c368456a1b8f8cb6b6d07ef6f22b87a9cde0b64a Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:26:45 +0200 Subject: [PATCH 212/472] Ruff format fixes --- duckdb/__init__.py | 571 +++++----- duckdb/__init__.pyi | 784 ++++++++++---- duckdb/bytes_io_wrapper.py | 1 - duckdb/experimental/__init__.py | 1 + duckdb/experimental/spark/_globals.py | 2 +- duckdb/experimental/spark/_typing.py | 9 +- duckdb/experimental/spark/context.py | 2 +- duckdb/experimental/spark/errors/__init__.py | 1 + .../spark/errors/exceptions/base.py | 2 + duckdb/experimental/spark/errors/utils.py | 3 +- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/_typing.py | 19 +- duckdb/experimental/spark/sql/catalog.py | 8 +- duckdb/experimental/spark/sql/column.py | 10 +- duckdb/experimental/spark/sql/dataframe.py | 178 ++-- duckdb/experimental/spark/sql/functions.py | 995 +++++++++++------- duckdb/experimental/spark/sql/group.py | 57 +- duckdb/experimental/spark/sql/readwriter.py | 80 +- duckdb/experimental/spark/sql/session.py | 24 +- duckdb/experimental/spark/sql/streaming.py | 2 +- duckdb/experimental/spark/sql/type_utils.py | 70 +- duckdb/experimental/spark/sql/types.py | 54 +- duckdb/filesystem.py | 3 +- duckdb/functional/__init__.py | 18 +- duckdb/polars_io.py | 60 +- duckdb/query_graph/__main__.py | 104 +- duckdb/typing/__init__.py | 4 +- duckdb/typing/__init__.pyi | 6 +- duckdb/value/constant/__init__.pyi | 7 +- duckdb_packaging/_versioning.py | 32 +- duckdb_packaging/build_backend.py | 20 +- duckdb_packaging/pypi_cleanup.py | 173 ++- duckdb_packaging/setuptools_scm_version.py | 15 +- scripts/generate_connection_code.py | 2 +- scripts/generate_connection_methods.py | 66 +- scripts/generate_connection_stubs.py | 34 +- .../generate_connection_wrapper_methods.py | 132 +-- scripts/generate_connection_wrapper_stubs.py | 48 +- scripts/generate_import_cache_cpp.py | 96 +- scripts/generate_import_cache_json.py | 26 +- sqllogic/conftest.py | 6 +- sqllogic/skipped_tests.py | 76 +- sqllogic/test_sqllogic.py | 14 +- tests/conftest.py | 50 +- .../test_pandas_categorical_coverage.py | 16 +- tests/extensions/json/test_read_json.py | 106 +- tests/extensions/test_extensions_loading.py | 22 +- tests/extensions/test_httpfs.py | 30 +- tests/fast/adbc/test_adbc.py | 34 +- tests/fast/adbc/test_statement_bind.py | 32 +- tests/fast/api/test_3324.py | 2 +- tests/fast/api/test_3654.py | 8 +- tests/fast/api/test_3728.py | 4 +- tests/fast/api/test_6315.py | 6 +- tests/fast/api/test_attribute_getter.py | 28 +- tests/fast/api/test_config.py | 58 +- tests/fast/api/test_connection_close.py | 6 +- tests/fast/api/test_cursor.py | 22 +- tests/fast/api/test_dbapi00.py | 36 +- tests/fast/api/test_dbapi01.py | 18 +- tests/fast/api/test_dbapi04.py | 2 +- tests/fast/api/test_dbapi05.py | 24 +- tests/fast/api/test_dbapi07.py | 4 +- tests/fast/api/test_dbapi08.py | 4 +- tests/fast/api/test_dbapi09.py | 8 +- tests/fast/api/test_dbapi12.py | 48 +- tests/fast/api/test_dbapi13.py | 4 +- tests/fast/api/test_dbapi_fetch.py | 88 +- tests/fast/api/test_duckdb_connection.py | 86 +- tests/fast/api/test_duckdb_execute.py | 14 +- tests/fast/api/test_duckdb_query.py | 58 +- tests/fast/api/test_explain.py | 24 +- tests/fast/api/test_fsspec.py | 2 +- tests/fast/api/test_insert_into.py | 14 +- tests/fast/api/test_join.py | 22 +- tests/fast/api/test_native_tz.py | 36 +- tests/fast/api/test_query_interrupt.py | 2 +- tests/fast/api/test_read_csv.py | 394 +++---- tests/fast/api/test_relation_to_view.py | 28 +- tests/fast/api/test_streaming_result.py | 12 +- tests/fast/api/test_to_csv.py | 132 +-- tests/fast/api/test_to_parquet.py | 65 +- .../api/test_with_propagating_exceptions.py | 12 +- tests/fast/arrow/parquet_write_roundtrip.py | 38 +- tests/fast/arrow/test_10795.py | 6 +- tests/fast/arrow/test_12384.py | 10 +- tests/fast/arrow/test_14344.py | 2 +- tests/fast/arrow/test_2426.py | 6 +- tests/fast/arrow/test_5547.py | 2 +- tests/fast/arrow/test_6584.py | 2 +- tests/fast/arrow/test_6796.py | 4 +- tests/fast/arrow/test_7652.py | 4 +- tests/fast/arrow/test_7699.py | 2 +- tests/fast/arrow/test_arrow_batch_index.py | 8 +- tests/fast/arrow/test_arrow_binary_view.py | 4 +- tests/fast/arrow/test_arrow_case_sensitive.py | 16 +- tests/fast/arrow/test_arrow_decimal_32_64.py | 20 +- tests/fast/arrow/test_arrow_extensions.py | 123 ++- tests/fast/arrow/test_arrow_fetch.py | 6 +- .../arrow/test_arrow_fetch_recordbatch.py | 26 +- tests/fast/arrow/test_arrow_fixed_binary.py | 6 +- tests/fast/arrow/test_arrow_ipc.py | 8 +- tests/fast/arrow/test_arrow_list.py | 22 +- tests/fast/arrow/test_arrow_offsets.py | 128 +-- tests/fast/arrow/test_arrow_pycapsule.py | 6 +- .../arrow/test_arrow_recordbatchreader.py | 36 +- .../fast/arrow/test_arrow_replacement_scan.py | 20 +- .../fast/arrow/test_arrow_run_end_encoding.py | 150 ++- tests/fast/arrow/test_arrow_scanner.py | 20 +- tests/fast/arrow/test_arrow_string_view.py | 16 +- tests/fast/arrow/test_arrow_types.py | 8 +- tests/fast/arrow/test_arrow_union.py | 14 +- tests/fast/arrow/test_arrow_version_format.py | 12 +- tests/fast/arrow/test_buffer_size_option.py | 2 +- tests/fast/arrow/test_dataset.py | 12 +- tests/fast/arrow/test_date.py | 20 +- tests/fast/arrow/test_dictionary_arrow.py | 56 +- tests/fast/arrow/test_filter_pushdown.py | 179 ++-- tests/fast/arrow/test_integration.py | 52 +- tests/fast/arrow/test_interval.py | 46 +- tests/fast/arrow/test_large_offsets.py | 4 +- tests/fast/arrow/test_large_string.py | 2 +- tests/fast/arrow/test_multiple_reads.py | 4 +- tests/fast/arrow/test_nested_arrow.py | 58 +- tests/fast/arrow/test_parallel.py | 12 +- tests/fast/arrow/test_polars.py | 84 +- tests/fast/arrow/test_progress.py | 16 +- tests/fast/arrow/test_time.py | 70 +- tests/fast/arrow/test_timestamp_timezone.py | 28 +- tests/fast/arrow/test_timestamps.py | 56 +- tests/fast/arrow/test_tpch.py | 8 +- tests/fast/arrow/test_unregister.py | 16 +- tests/fast/arrow/test_view.py | 6 +- tests/fast/numpy/test_numpy_new_path.py | 14 +- tests/fast/pandas/test_2304.py | 72 +- tests/fast/pandas/test_append_df.py | 44 +- tests/fast/pandas/test_bug2281.py | 6 +- tests/fast/pandas/test_bug5922.py | 14 +- tests/fast/pandas/test_copy_on_write.py | 10 +- .../pandas/test_create_table_from_pandas.py | 4 +- tests/fast/pandas/test_date_as_datetime.py | 8 +- tests/fast/pandas/test_datetime_time.py | 20 +- tests/fast/pandas/test_datetime_timestamp.py | 52 +- tests/fast/pandas/test_df_analyze.py | 22 +- .../fast/pandas/test_df_object_resolution.py | 396 +++---- tests/fast/pandas/test_df_recursive_nested.py | 68 +- tests/fast/pandas/test_fetch_df_chunk.py | 14 +- tests/fast/pandas/test_fetch_nested.py | 8 +- .../fast/pandas/test_implicit_pandas_scan.py | 10 +- tests/fast/pandas/test_import_cache.py | 18 +- tests/fast/pandas/test_issue_1767.py | 4 +- tests/fast/pandas/test_limit.py | 10 +- tests/fast/pandas/test_pandas_arrow.py | 94 +- tests/fast/pandas/test_pandas_category.py | 56 +- tests/fast/pandas/test_pandas_enum.py | 8 +- tests/fast/pandas/test_pandas_limit.py | 8 +- tests/fast/pandas/test_pandas_na.py | 22 +- tests/fast/pandas/test_pandas_object.py | 62 +- tests/fast/pandas/test_pandas_string.py | 23 +- tests/fast/pandas/test_pandas_timestamp.py | 16 +- tests/fast/pandas/test_pandas_types.py | 110 +- tests/fast/pandas/test_pandas_unregister.py | 12 +- tests/fast/pandas/test_pandas_update.py | 12 +- .../fast/pandas/test_parallel_pandas_scan.py | 52 +- .../pandas/test_partitioned_pandas_scan.py | 4 +- tests/fast/pandas/test_progress_bar.py | 16 +- .../test_pyarrow_projection_pushdown.py | 4 +- tests/fast/pandas/test_same_name.py | 50 +- tests/fast/pandas/test_stride.py | 22 +- tests/fast/pandas/test_timedelta.py | 20 +- tests/fast/pandas/test_timestamp.py | 34 +- tests/fast/relational_api/test_groupings.py | 6 +- tests/fast/relational_api/test_joins.py | 52 +- tests/fast/relational_api/test_pivot.py | 2 +- .../relational_api/test_rapi_aggregations.py | 4 +- tests/fast/relational_api/test_rapi_close.py | 162 +-- .../relational_api/test_rapi_description.py | 26 +- .../relational_api/test_rapi_functions.py | 4 +- tests/fast/relational_api/test_rapi_query.py | 60 +- .../fast/relational_api/test_rapi_windows.py | 18 +- .../relational_api/test_table_function.py | 6 +- tests/fast/spark/test_replace_column_value.py | 22 +- tests/fast/spark/test_replace_empty_value.py | 32 +- tests/fast/spark/test_spark_catalog.py | 30 +- tests/fast/spark/test_spark_column.py | 20 +- tests/fast/spark/test_spark_dataframe.py | 134 +-- tests/fast/spark/test_spark_dataframe_sort.py | 20 +- .../fast/spark/test_spark_drop_duplicates.py | 34 +- tests/fast/spark/test_spark_except.py | 1 - tests/fast/spark/test_spark_filter.py | 74 +- .../fast/spark/test_spark_functions_array.py | 82 +- .../fast/spark/test_spark_functions_base64.py | 2 +- tests/fast/spark/test_spark_functions_date.py | 46 +- tests/fast/spark/test_spark_functions_hex.py | 4 +- .../test_spark_functions_miscellaneous.py | 30 +- tests/fast/spark/test_spark_functions_null.py | 6 +- .../spark/test_spark_functions_numeric.py | 6 +- .../fast/spark/test_spark_functions_string.py | 164 +-- tests/fast/spark/test_spark_group_by.py | 12 +- tests/fast/spark/test_spark_intersect.py | 2 - tests/fast/spark/test_spark_join.py | 254 ++--- tests/fast/spark/test_spark_order_by.py | 94 +- .../fast/spark/test_spark_pandas_dataframe.py | 12 +- tests/fast/spark/test_spark_readcsv.py | 4 +- tests/fast/spark/test_spark_readjson.py | 4 +- tests/fast/spark/test_spark_readparquet.py | 4 +- tests/fast/spark/test_spark_session.py | 16 +- tests/fast/spark/test_spark_to_csv.py | 40 +- tests/fast/spark/test_spark_transform.py | 12 +- tests/fast/spark/test_spark_types.py | 90 +- tests/fast/spark/test_spark_udf.py | 1 - tests/fast/spark/test_spark_union.py | 32 +- tests/fast/spark/test_spark_union_by_name.py | 32 +- tests/fast/spark/test_spark_with_column.py | 26 +- .../spark/test_spark_with_column_renamed.py | 56 +- tests/fast/spark/test_spark_with_columns.py | 22 +- .../spark/test_spark_with_columns_renamed.py | 38 +- tests/fast/sqlite/test_types.py | 20 +- tests/fast/test_alex_multithread.py | 32 +- tests/fast/test_all_types.py | 298 +++--- tests/fast/test_case_alias.py | 12 +- tests/fast/test_context_manager.py | 2 +- tests/fast/test_duckdb_api.py | 2 +- tests/fast/test_expression.py | 212 ++-- tests/fast/test_filesystem.py | 104 +- tests/fast/test_get_table_names.py | 44 +- tests/fast/test_import_export.py | 12 +- tests/fast/test_insert.py | 16 +- tests/fast/test_many_con_same_file.py | 10 +- tests/fast/test_map.py | 102 +- tests/fast/test_metatransaction.py | 4 +- tests/fast/test_multi_statement.py | 20 +- tests/fast/test_multithread.py | 134 +-- tests/fast/test_non_default_conn.py | 28 +- tests/fast/test_parameter_list.py | 8 +- tests/fast/test_parquet.py | 50 +- tests/fast/test_pypi_cleanup.py | 320 ++++-- tests/fast/test_pytorch.py | 16 +- tests/fast/test_relation.py | 282 ++--- tests/fast/test_relation_dependency_leak.py | 14 +- tests/fast/test_replacement_scan.py | 100 +- tests/fast/test_result.py | 34 +- tests/fast/test_runtime_error.py | 44 +- tests/fast/test_sql_expression.py | 28 +- tests/fast/test_string_annotation.py | 12 +- tests/fast/test_tf.py | 16 +- tests/fast/test_transaction.py | 16 +- tests/fast/test_type.py | 170 +-- tests/fast/test_type_explicit.py | 7 +- tests/fast/test_unicode.py | 8 +- tests/fast/test_value.py | 30 +- tests/fast/test_versioning.py | 46 +- tests/fast/test_windows_abs_path.py | 20 +- tests/fast/types/test_blob.py | 4 +- tests/fast/types/test_datetime_datetime.py | 20 +- tests/fast/types/test_decimal.py | 14 +- tests/fast/types/test_hugeint.py | 6 +- tests/fast/types/test_nan.py | 36 +- tests/fast/types/test_nested.py | 12 +- tests/fast/types/test_numpy.py | 8 +- tests/fast/types/test_object_int.py | 54 +- tests/fast/types/test_time_tz.py | 2 +- tests/fast/types/test_unsigned.py | 6 +- tests/fast/udf/test_null_filtering.py | 78 +- tests/fast/udf/test_remove_function.py | 48 +- tests/fast/udf/test_scalar.py | 114 +- tests/fast/udf/test_scalar_arrow.py | 76 +- tests/fast/udf/test_scalar_native.py | 68 +- tests/fast/udf/test_transactionality.py | 6 +- tests/slow/test_h2oai_arrow.py | 50 +- tests/stubs/test_stubs.py | 8 +- 271 files changed, 6750 insertions(+), 6184 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index b5e994fa..bf50be5b 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -7,16 +7,15 @@ # duckdb.__version__ returns the version of the distribution package, i.e. the pypi version __version__ = version("duckdb") + # version() is a more human friendly formatted version string of both the distribution package and the bundled duckdb def version(): return f"{__version__} (with duckdb {duckdb_version})" -_exported_symbols = ['__version__', 'version'] -_exported_symbols.extend([ - "typing", - "functional" -]) +_exported_symbols = ["__version__", "version"] + +_exported_symbols.extend(["typing", "functional"]) class DBAPITypeObject: def __init__(self, types: list[typing.DuckDBPyType]) -> None: @@ -69,7 +68,7 @@ def __repr__(self): ExplainType, StatementType, ExpectedResultType, - CSVLineTerminator, + CSVLineTerminator, PythonExceptionHandling, RenderMode, Expression, @@ -81,217 +80,205 @@ def __repr__(self): StarExpression, FunctionExpression, CaseExpression, - SQLExpression + SQLExpression, ) -_exported_symbols.extend([ - "DuckDBPyRelation", - "DuckDBPyConnection", - "ExplainType", - "PythonExceptionHandling", - "Expression", - "ConstantExpression", - "ColumnExpression", - "DefaultExpression", - "CoalesceOperator", - "LambdaExpression", - "StarExpression", - "FunctionExpression", - "CaseExpression", - "SQLExpression" -]) -# These are overloaded twice, we define them inside of C++ so pybind can deal with it -_exported_symbols.extend([ - 'df', - 'arrow' -]) -from _duckdb import ( - df, - arrow +_exported_symbols.extend( + [ + "DuckDBPyRelation", + "DuckDBPyConnection", + "ExplainType", + "PythonExceptionHandling", + "Expression", + "ConstantExpression", + "ColumnExpression", + "DefaultExpression", + "CoalesceOperator", + "LambdaExpression", + "StarExpression", + "FunctionExpression", + "CaseExpression", + "SQLExpression", + ] ) +# These are overloaded twice, we define them inside of C++ so pybind can deal with it +_exported_symbols.extend(["df", "arrow"]) +from _duckdb import df, arrow + # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_methods.py. # Do not edit this section manually, your changes will be overwritten! # START OF CONNECTION WRAPPER from _duckdb import ( - cursor, - register_filesystem, - unregister_filesystem, - list_filesystems, - filesystem_is_registered, - create_function, - remove_function, - sqltype, - dtype, - type, - array_type, - list_type, - union_type, - string_type, - enum_type, - decimal_type, - struct_type, - row_type, - map_type, - duplicate, - execute, - executemany, - close, - interrupt, - query_progress, - fetchone, - fetchmany, - fetchall, - fetchnumpy, - fetchdf, - fetch_df, - df, - fetch_df_chunk, - pl, - fetch_arrow_table, - arrow, - fetch_record_batch, - torch, - tf, - begin, - commit, - rollback, - checkpoint, - append, - register, - unregister, - table, - view, - values, - table_function, - read_json, - extract_statements, - sql, - query, - from_query, - read_csv, - from_csv_auto, - from_df, - from_arrow, - from_parquet, - read_parquet, - from_parquet, - read_parquet, - get_table_names, - install_extension, - load_extension, - project, - distinct, - write_csv, - aggregate, - alias, - filter, - limit, - order, - query_df, - description, - rowcount, + cursor, + register_filesystem, + unregister_filesystem, + list_filesystems, + filesystem_is_registered, + create_function, + remove_function, + sqltype, + dtype, + type, + array_type, + list_type, + union_type, + string_type, + enum_type, + decimal_type, + struct_type, + row_type, + map_type, + duplicate, + execute, + executemany, + close, + interrupt, + query_progress, + fetchone, + fetchmany, + fetchall, + fetchnumpy, + fetchdf, + fetch_df, + df, + fetch_df_chunk, + pl, + fetch_arrow_table, + arrow, + fetch_record_batch, + torch, + tf, + begin, + commit, + rollback, + checkpoint, + append, + register, + unregister, + table, + view, + values, + table_function, + read_json, + extract_statements, + sql, + query, + from_query, + read_csv, + from_csv_auto, + from_df, + from_arrow, + from_parquet, + read_parquet, + from_parquet, + read_parquet, + get_table_names, + install_extension, + load_extension, + project, + distinct, + write_csv, + aggregate, + alias, + filter, + limit, + order, + query_df, + description, + rowcount, ) -_exported_symbols.extend([ - 'cursor', - 'register_filesystem', - 'unregister_filesystem', - 'list_filesystems', - 'filesystem_is_registered', - 'create_function', - 'remove_function', - 'sqltype', - 'dtype', - 'type', - 'array_type', - 'list_type', - 'union_type', - 'string_type', - 'enum_type', - 'decimal_type', - 'struct_type', - 'row_type', - 'map_type', - 'duplicate', - 'execute', - 'executemany', - 'close', - 'interrupt', - 'query_progress', - 'fetchone', - 'fetchmany', - 'fetchall', - 'fetchnumpy', - 'fetchdf', - 'fetch_df', - 'df', - 'fetch_df_chunk', - 'pl', - 'fetch_arrow_table', - 'arrow', - 'fetch_record_batch', - 'torch', - 'tf', - 'begin', - 'commit', - 'rollback', - 'checkpoint', - 'append', - 'register', - 'unregister', - 'table', - 'view', - 'values', - 'table_function', - 'read_json', - 'extract_statements', - 'sql', - 'query', - 'from_query', - 'read_csv', - 'from_csv_auto', - 'from_df', - 'from_arrow', - 'from_parquet', - 'read_parquet', - 'from_parquet', - 'read_parquet', - 'get_table_names', - 'install_extension', - 'load_extension', - 'project', - 'distinct', - 'write_csv', - 'aggregate', - 'alias', - 'filter', - 'limit', - 'order', - 'query_df', - 'description', - 'rowcount', -]) +_exported_symbols.extend( + [ + "cursor", + "register_filesystem", + "unregister_filesystem", + "list_filesystems", + "filesystem_is_registered", + "create_function", + "remove_function", + "sqltype", + "dtype", + "type", + "array_type", + "list_type", + "union_type", + "string_type", + "enum_type", + "decimal_type", + "struct_type", + "row_type", + "map_type", + "duplicate", + "execute", + "executemany", + "close", + "interrupt", + "query_progress", + "fetchone", + "fetchmany", + "fetchall", + "fetchnumpy", + "fetchdf", + "fetch_df", + "df", + "fetch_df_chunk", + "pl", + "fetch_arrow_table", + "arrow", + "fetch_record_batch", + "torch", + "tf", + "begin", + "commit", + "rollback", + "checkpoint", + "append", + "register", + "unregister", + "table", + "view", + "values", + "table_function", + "read_json", + "extract_statements", + "sql", + "query", + "from_query", + "read_csv", + "from_csv_auto", + "from_df", + "from_arrow", + "from_parquet", + "read_parquet", + "from_parquet", + "read_parquet", + "get_table_names", + "install_extension", + "load_extension", + "project", + "distinct", + "write_csv", + "aggregate", + "alias", + "filter", + "limit", + "order", + "query_df", + "description", + "rowcount", + ] +) # END OF CONNECTION WRAPPER # Enums -from _duckdb import ( - ANALYZE, - DEFAULT, - RETURN_NULL, - STANDARD, - COLUMNS, - ROWS -) -_exported_symbols.extend([ - "ANALYZE", - "DEFAULT", - "RETURN_NULL", - "STANDARD" -]) +from _duckdb import ANALYZE, DEFAULT, RETURN_NULL, STANDARD, COLUMNS, ROWS + +_exported_symbols.extend(["ANALYZE", "DEFAULT", "RETURN_NULL", "STANDARD"]) # read-only properties @@ -310,25 +297,28 @@ def __repr__(self): string_const, threadsafety, token_type, - tokenize + tokenize, +) + +_exported_symbols.extend( + [ + "__standard_vector_size__", + "__interactive__", + "__jupyter__", + "__formatted_python_version__", + "apilevel", + "comment", + "identifier", + "keyword", + "numeric_const", + "operator", + "paramstyle", + "string_const", + "threadsafety", + "token_type", + "tokenize", + ] ) -_exported_symbols.extend([ - "__standard_vector_size__", - "__interactive__", - "__jupyter__", - "__formatted_python_version__", - "apilevel", - "comment", - "identifier", - "keyword", - "numeric_const", - "operator", - "paramstyle", - "string_const", - "threadsafety", - "token_type", - "tokenize" -]) from _duckdb import ( @@ -337,11 +327,13 @@ def __repr__(self): set_default_connection, ) -_exported_symbols.extend([ - "connect", - "default_connection", - "set_default_connection", -]) +_exported_symbols.extend( + [ + "connect", + "default_connection", + "set_default_connection", + ] +) # Exceptions from _duckdb import ( @@ -374,40 +366,43 @@ def __repr__(self): ParserException, SyntaxException, SequenceException, - Warning + Warning, +) + +_exported_symbols.extend( + [ + "Error", + "DataError", + "ConversionException", + "OutOfRangeException", + "TypeMismatchException", + "FatalException", + "IntegrityError", + "ConstraintException", + "InternalError", + "InternalException", + "InterruptException", + "NotSupportedError", + "NotImplementedException", + "OperationalError", + "ConnectionException", + "IOException", + "HTTPException", + "OutOfMemoryException", + "SerializationException", + "TransactionException", + "PermissionException", + "ProgrammingError", + "BinderException", + "CatalogException", + "InvalidInputException", + "InvalidTypeException", + "ParserException", + "SyntaxException", + "SequenceException", + "Warning", + ] ) -_exported_symbols.extend([ - "Error", - "DataError", - "ConversionException", - "OutOfRangeException", - "TypeMismatchException", - "FatalException", - "IntegrityError", - "ConstraintException", - "InternalError", - "InternalException", - "InterruptException", - "NotSupportedError", - "NotImplementedException", - "OperationalError", - "ConnectionException", - "IOException", - "HTTPException", - "OutOfMemoryException", - "SerializationException", - "TransactionException", - "PermissionException", - "ProgrammingError", - "BinderException", - "CatalogException", - "InvalidInputException", - "InvalidTypeException", - "ParserException", - "SyntaxException", - "SequenceException", - "Warning" -]) # Value from duckdb.value.constant import ( @@ -441,35 +436,37 @@ def __repr__(self): TimeTimeZoneValue, ) -_exported_symbols.extend([ - "Value", - "NullValue", - "BooleanValue", - "UnsignedBinaryValue", - "UnsignedShortValue", - "UnsignedIntegerValue", - "UnsignedLongValue", - "BinaryValue", - "ShortValue", - "IntegerValue", - "LongValue", - "HugeIntegerValue", - "FloatValue", - "DoubleValue", - "DecimalValue", - "StringValue", - "UUIDValue", - "BitValue", - "BlobValue", - "DateValue", - "IntervalValue", - "TimestampValue", - "TimestampSecondValue", - "TimestampMilisecondValue", - "TimestampNanosecondValue", - "TimestampTimeZoneValue", - "TimeValue", - "TimeTimeZoneValue", -]) +_exported_symbols.extend( + [ + "Value", + "NullValue", + "BooleanValue", + "UnsignedBinaryValue", + "UnsignedShortValue", + "UnsignedIntegerValue", + "UnsignedLongValue", + "BinaryValue", + "ShortValue", + "IntegerValue", + "LongValue", + "HugeIntegerValue", + "FloatValue", + "DoubleValue", + "DecimalValue", + "StringValue", + "UUIDValue", + "BitValue", + "BlobValue", + "DateValue", + "IntervalValue", + "TimestampValue", + "TimestampSecondValue", + "TimestampMilisecondValue", + "TimestampNanosecondValue", + "TimestampTimeZoneValue", + "TimeValue", + "TimeTimeZoneValue", + ] +) __all__ = _exported_symbols diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index 8f27e5e3..0c597d11 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -41,6 +41,7 @@ from duckdb.value.constant import ( # We also run this in python3.7, where this is needed from typing_extensions import Literal + # stubgen override - missing import of Set from typing import Any, ClassVar, Set, Optional, Callable from io import StringIO, TextIOBase @@ -48,11 +49,13 @@ from pathlib import Path from typing import overload, Dict, List, Union, Tuple import pandas + # stubgen override - unfortunately we need this for version checks import sys import fsspec import pyarrow.lib import polars + # stubgen override - This should probably not be exposed apilevel: str comment: token_type @@ -78,15 +81,10 @@ __jupyter__: bool __formatted_python_version__: str class BinderException(ProgrammingError): ... - class CatalogException(ProgrammingError): ... - class ConnectionException(OperationalError): ... - class ConstraintException(IntegrityError): ... - class ConversionException(DataError): ... - class DataError(Error): ... class ExplainType: @@ -204,46 +202,37 @@ class Statement: class Expression: def __init__(self, *args, **kwargs) -> None: ... def __neg__(self) -> "Expression": ... - def __add__(self, expr: "Expression") -> "Expression": ... def __radd__(self, expr: "Expression") -> "Expression": ... - def __sub__(self, expr: "Expression") -> "Expression": ... def __rsub__(self, expr: "Expression") -> "Expression": ... - def __mul__(self, expr: "Expression") -> "Expression": ... def __rmul__(self, expr: "Expression") -> "Expression": ... - def __div__(self, expr: "Expression") -> "Expression": ... def __rdiv__(self, expr: "Expression") -> "Expression": ... - def __truediv__(self, expr: "Expression") -> "Expression": ... def __rtruediv__(self, expr: "Expression") -> "Expression": ... - def __floordiv__(self, expr: "Expression") -> "Expression": ... def __rfloordiv__(self, expr: "Expression") -> "Expression": ... - def __mod__(self, expr: "Expression") -> "Expression": ... def __rmod__(self, expr: "Expression") -> "Expression": ... - def __pow__(self, expr: "Expression") -> "Expression": ... def __rpow__(self, expr: "Expression") -> "Expression": ... - def __and__(self, expr: "Expression") -> "Expression": ... def __rand__(self, expr: "Expression") -> "Expression": ... def __or__(self, expr: "Expression") -> "Expression": ... def __ror__(self, expr: "Expression") -> "Expression": ... def __invert__(self) -> "Expression": ... - - def __eq__(# type: ignore[override] - self, expr: "Expression") -> "Expression": ... - def __ne__(# type: ignore[override] - self, expr: "Expression") -> "Expression": ... + def __eq__( # type: ignore[override] + self, expr: "Expression" + ) -> "Expression": ... + def __ne__( # type: ignore[override] + self, expr: "Expression" + ) -> "Expression": ... def __gt__(self, expr: "Expression") -> "Expression": ... def __ge__(self, expr: "Expression") -> "Expression": ... def __lt__(self, expr: "Expression") -> "Expression": ... def __le__(self, expr: "Expression") -> "Expression": ... - def show(self) -> None: ... def __repr__(self) -> str: ... def get_name(self) -> str: ... @@ -291,7 +280,18 @@ class DuckDBPyConnection: def unregister_filesystem(self, name: str) -> None: ... def list_filesystems(self) -> list: ... def filesystem_is_registered(self, name: str) -> bool: ... - def create_function(self, name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False) -> DuckDBPyConnection: ... + def create_function( + self, + name: str, + function: function, + parameters: Optional[List[DuckDBPyType]] = None, + return_type: Optional[DuckDBPyType] = None, + *, + type: Optional[PythonUDFType] = PythonUDFType.NATIVE, + null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, + exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, + side_effects: bool = False, + ) -> DuckDBPyConnection: ... def remove_function(self, name: str) -> DuckDBPyConnection: ... def sqltype(self, type_str: str) -> DuckDBPyType: ... def dtype(self, type_str: str) -> DuckDBPyType: ... @@ -334,21 +334,152 @@ class DuckDBPyConnection: def unregister(self, view_name: str) -> DuckDBPyConnection: ... def table(self, table_name: str) -> DuckDBPyRelation: ... def view(self, view_name: str) -> DuckDBPyRelation: ... - def values(self, *args: Union[List[Any],Expression, Tuple[Expression]]) -> DuckDBPyRelation: ... + def values(self, *args: Union[List[Any], Expression, Tuple[Expression]]) -> DuckDBPyRelation: ... def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ... - def read_json(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, compression: Optional[str] = None, maximum_object_size: Optional[int] = None, ignore_errors: Optional[bool] = None, convert_strings_to_integers: Optional[bool] = None, field_appearance_threshold: Optional[float] = None, map_inference_threshold: Optional[int] = None, maximum_sample_files: Optional[int] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... + def read_json( + self, + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + columns: Optional[Dict[str, str]] = None, + sample_size: Optional[int] = None, + maximum_depth: Optional[int] = None, + records: Optional[str] = None, + format: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + compression: Optional[str] = None, + maximum_object_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + convert_strings_to_integers: Optional[bool] = None, + field_appearance_threshold: Optional[float] = None, + map_inference_threshold: Optional[int] = None, + maximum_sample_files: Optional[int] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + ) -> DuckDBPyRelation: ... def extract_statements(self, query: str) -> List[Statement]: ... def sql(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... def from_query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... - def read_csv(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... - def from_csv_auto(self, path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None) -> DuckDBPyRelation: ... + def read_csv( + self, + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + ) -> DuckDBPyRelation: ... + def from_csv_auto( + self, + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + ) -> DuckDBPyRelation: ... def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... - def from_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... - def read_parquet(self, file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None) -> DuckDBPyRelation: ... + def from_parquet( + self, + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + ) -> DuckDBPyRelation: ... + def read_parquet( + self, + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + ) -> DuckDBPyRelation: ... def get_table_names(self, query: str, *, qualified: bool = False) -> Set[str]: ... - def install_extension(self, extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None) -> None: ... + def install_extension( + self, + extension: str, + *, + force_install: bool = False, + repository: Optional[str] = None, + repository_url: Optional[str] = None, + version: Optional[str] = None, + ) -> None: ... def load_extension(self, extension: str) -> None: ... # END OF CONNECTION METHODS @@ -359,19 +490,41 @@ class DuckDBPyRelation: def __init__(self, *args, **kwargs) -> None: ... def __contains__(self, name: str) -> bool: ... def aggregate(self, aggr_expr: str, group_expr: str = ...) -> DuckDBPyRelation: ... - def apply(self, function_name: str, function_aggr: str, group_expr: str = ..., function_parameter: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - + def apply( + self, + function_name: str, + function_aggr: str, + group_expr: str = ..., + function_parameter: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... def cume_dist(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def dense_rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def percent_rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def rank(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def rank_dense(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... def row_number(self, window_spec: str, projected_columns: str = ...) -> DuckDBPyRelation: ... - - def lag(self, column: str, window_spec: str, offset: int, default_value: str, ignore_nulls: bool, projected_columns: str = ...) -> DuckDBPyRelation: ... - def lead(self, column: str, window_spec: str, offset: int, default_value: str, ignore_nulls: bool, projected_columns: str = ...) -> DuckDBPyRelation: ... - def nth_value(self, column: str, window_spec: str, offset: int, ignore_nulls: bool = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - + def lag( + self, + column: str, + window_spec: str, + offset: int, + default_value: str, + ignore_nulls: bool, + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def lead( + self, + column: str, + window_spec: str, + offset: int, + default_value: str, + ignore_nulls: bool, + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def nth_value( + self, column: str, window_spec: str, offset: int, ignore_nulls: bool = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... def value_counts(self, column: str, groups: str = ...) -> DuckDBPyRelation: ... def geomean(self, column: str, groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... def first(self, column: str, groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... @@ -380,41 +533,119 @@ class DuckDBPyRelation: def last_value(self, column: str, window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... def mode(self, aggregation_columns: str, group_columns: str = ...) -> DuckDBPyRelation: ... def n_tile(self, window_spec: str, num_buckets: int, projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile_cont(self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile_disc(self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... + def quantile_cont( + self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def quantile_disc( + self, column: str, q: Union[float, List[float]] = ..., groups: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... def sum(self, sum_aggr: str, group_expr: str = ...) -> DuckDBPyRelation: ... - - def any_value(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arg_max(self, arg_column: str, value_column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arg_min(self, arg_column: str, value_column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def avg(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_and(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_or(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bit_xor(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bitstring_agg(self, column: str, min: Optional[int], max: Optional[int], groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bool_and(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def bool_or(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def count(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def favg(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def fsum(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def histogram(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def max(self, max_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def min(self, min_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def mean(self, mean_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def median(self, median_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def product(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def quantile(self, q: str, quantile_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def std(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev_pop(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def stddev_samp(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def string_agg(self, column: str, sep: str = ..., groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var_pop(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def var_samp(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def variance(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def list(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - + def any_value( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def arg_max( + self, + arg_column: str, + value_column: str, + groups: str = ..., + window_spec: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def arg_min( + self, + arg_column: str, + value_column: str, + groups: str = ..., + window_spec: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def avg( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bit_and( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bit_or( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bit_xor( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bitstring_agg( + self, + column: str, + min: Optional[int], + max: Optional[int], + groups: str = ..., + window_spec: str = ..., + projected_columns: str = ..., + ) -> DuckDBPyRelation: ... + def bool_and( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def bool_or( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def count( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def favg( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def fsum( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def histogram( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def max( + self, max_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def min( + self, min_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def mean( + self, mean_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def median( + self, median_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def product( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def quantile( + self, q: str, quantile_aggr: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def std( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def stddev( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def stddev_pop( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def stddev_samp( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def string_agg( + self, column: str, sep: str = ..., groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def var( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def var_pop( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def var_samp( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def variance( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... + def list( + self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ... + ) -> DuckDBPyRelation: ... def arrow(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: ... def create(self, table_name: str) -> None: ... @@ -424,7 +655,7 @@ class DuckDBPyRelation: def distinct(self) -> DuckDBPyRelation: ... def except_(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def execute(self, *args, **kwargs) -> DuckDBPyRelation: ... - def explain(self, type: Optional[Literal['standard', 'analyze'] | int] = 'standard') -> str: ... + def explain(self, type: Optional[Literal["standard", "analyze"] | int] = "standard") -> str: ... def fetchall(self) -> List[Any]: ... def fetchmany(self, size: int = ...) -> List[Any]: ... def fetchnumpy(self) -> dict: ... @@ -437,7 +668,9 @@ class DuckDBPyRelation: def update(self, set: Dict[str, Expression], condition: Optional[Expression] = None) -> None: ... def insert_into(self, table_name: str) -> None: ... def intersect(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... - def join(self, other_rel: DuckDBPyRelation, condition: Union[str, Expression], how: str = ...) -> DuckDBPyRelation: ... + def join( + self, other_rel: DuckDBPyRelation, condition: Union[str, Expression], how: str = ... + ) -> DuckDBPyRelation: ... def cross(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def limit(self, n: int, offset: int = ...) -> DuckDBPyRelation: ... def map(self, map_function: function, schema: Optional[Dict[str, DuckDBPyType]] = None) -> DuckDBPyRelation: ... @@ -448,46 +681,55 @@ class DuckDBPyRelation: def pl(self, rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... def query(self, virtual_table_name: str, sql_query: str) -> DuckDBPyRelation: ... def record_batch(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... - def fetch_record_batch(self, rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... + def fetch_record_batch( + self, rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ... + ) -> pyarrow.lib.RecordBatchReader: ... def select_types(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def select_dtypes(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def set_alias(self, alias: str) -> DuckDBPyRelation: ... - def show(self, max_width: Optional[int] = None, max_rows: Optional[int] = None, max_col_width: Optional[int] = None, null_value: Optional[str] = None, render_mode: Optional[RenderMode] = None) -> None: ... + def show( + self, + max_width: Optional[int] = None, + max_rows: Optional[int] = None, + max_col_width: Optional[int] = None, + null_value: Optional[str] = None, + render_mode: Optional[RenderMode] = None, + ) -> None: ... def sql_query(self) -> str: ... def to_arrow_table(self, batch_size: int = ...) -> pyarrow.lib.Table: ... def to_csv( - self, - file_name: str, - sep: Optional[str] = None, - na_rep: Optional[str] = None, - header: Optional[bool] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - quoting: Optional[str | int] = None, - encoding: Optional[str] = None, - compression: Optional[str] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - partition_by: Optional[List[str]] = None + self, + file_name: str, + sep: Optional[str] = None, + na_rep: Optional[str] = None, + header: Optional[bool] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + quoting: Optional[str | int] = None, + encoding: Optional[str] = None, + compression: Optional[str] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + partition_by: Optional[List[str]] = None, ) -> None: ... def to_df(self, *args, **kwargs) -> pandas.DataFrame: ... def to_parquet( - self, - file_name: str, - compression: Optional[str] = None, - field_ids: Optional[dict | str] = None, - row_group_size_bytes: Optional[int | str] = None, - row_group_size: Optional[int] = None, - partition_by: Optional[List[str]] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - append: Optional[bool] = None + self, + file_name: str, + compression: Optional[str] = None, + field_ids: Optional[dict | str] = None, + row_group_size_bytes: Optional[int | str] = None, + row_group_size: Optional[int] = None, + partition_by: Optional[List[str]] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + append: Optional[bool] = None, ) -> None: ... def fetch_df_chunk(self, vectors_per_chunk: int = 1, *, date_as_object: bool = False) -> pandas.DataFrame: ... def to_table(self, table_name: str) -> None: ... @@ -497,37 +739,37 @@ class DuckDBPyRelation: def union(self, union_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def unique(self, unique_aggr: str) -> DuckDBPyRelation: ... def write_csv( - self, - file_name: str, - sep: Optional[str] = None, - na_rep: Optional[str] = None, - header: Optional[bool] = None, - quotechar: Optional[str] = None, - escapechar: Optional[str] = None, - date_format: Optional[str] = None, - timestamp_format: Optional[str] = None, - quoting: Optional[str | int] = None, - encoding: Optional[str] = None, - compression: Optional[str] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - partition_by: Optional[List[str]] = None + self, + file_name: str, + sep: Optional[str] = None, + na_rep: Optional[str] = None, + header: Optional[bool] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + quoting: Optional[str | int] = None, + encoding: Optional[str] = None, + compression: Optional[str] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + partition_by: Optional[List[str]] = None, ) -> None: ... def write_parquet( - self, - file_name: str, - compression: Optional[str] = None, - field_ids: Optional[dict | str] = None, - row_group_size_bytes: Optional[int | str] = None, - row_group_size: Optional[int] = None, - partition_by: Optional[List[str]] = None, - write_partition_columns: Optional[bool] = None, - overwrite: Optional[bool] = None, - per_thread_output: Optional[bool] = None, - use_tmp_file: Optional[bool] = None, - append: Optional[bool] = None + self, + file_name: str, + compression: Optional[str] = None, + field_ids: Optional[dict | str] = None, + row_group_size_bytes: Optional[int | str] = None, + row_group_size: Optional[int] = None, + partition_by: Optional[List[str]] = None, + write_partition_columns: Optional[bool] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + append: Optional[bool] = None, ) -> None: ... def __len__(self) -> int: ... @property @@ -546,7 +788,6 @@ class DuckDBPyRelation: def types(self) -> List[DuckDBPyType]: ... class Error(Exception): ... - class FatalException(Error): ... class HTTPException(IOException): @@ -556,51 +797,31 @@ class HTTPException(IOException): headers: Dict[str, str] class IOException(OperationalError): ... - class IntegrityError(Error): ... - class InternalError(Error): ... - class InternalException(InternalError): ... - class InterruptException(Error): ... - class InvalidInputException(ProgrammingError): ... - class InvalidTypeException(ProgrammingError): ... - class NotImplementedException(NotSupportedError): ... - class NotSupportedError(Error): ... - class OperationalError(Error): ... - class OutOfMemoryException(OperationalError): ... - class OutOfRangeException(DataError): ... - class ParserException(ProgrammingError): ... - class PermissionException(Error): ... - class ProgrammingError(Error): ... - class SequenceException(Error): ... - class SerializationException(OperationalError): ... - class SyntaxException(ProgrammingError): ... - class TransactionException(OperationalError): ... - class TypeMismatchException(DataError): ... - class Warning(Exception): ... class token_type: # stubgen override - these make mypy sad - #__doc__: ClassVar[str] = ... # read-only - #__members__: ClassVar[dict] = ... # read-only + # __doc__: ClassVar[str] = ... # read-only + # __members__: ClassVar[dict] = ... # read-only __entries: ClassVar[dict] = ... comment: ClassVar[token_type] = ... identifier: ClassVar[token_type] = ... @@ -640,7 +861,18 @@ def register_filesystem(filesystem: fsspec.AbstractFileSystem, *, connection: Du def unregister_filesystem(name: str, *, connection: DuckDBPyConnection = ...) -> None: ... def list_filesystems(*, connection: DuckDBPyConnection = ...) -> list: ... def filesystem_is_registered(name: str, *, connection: DuckDBPyConnection = ...) -> bool: ... -def create_function(name: str, function: function, parameters: Optional[List[DuckDBPyType]] = None, return_type: Optional[DuckDBPyType] = None, *, type: Optional[PythonUDFType] = PythonUDFType.NATIVE, null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, side_effects: bool = False, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... +def create_function( + name: str, + function: function, + parameters: Optional[List[DuckDBPyType]] = None, + return_type: Optional[DuckDBPyType] = None, + *, + type: Optional[PythonUDFType] = PythonUDFType.NATIVE, + null_handling: Optional[FunctionNullHandling] = FunctionNullHandling.DEFAULT, + exception_handling: Optional[PythonExceptionHandling] = PythonExceptionHandling.DEFAULT, + side_effects: bool = False, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyConnection: ... def remove_function(name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def sqltype(type_str: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def dtype(type_str: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... @@ -649,14 +881,24 @@ def array_type(type: DuckDBPyType, size: int, *, connection: DuckDBPyConnection def list_type(type: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def union_type(members: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def string_type(collation: str = "", *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def enum_type(name: str, type: DuckDBPyType, values: List[Any], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... +def enum_type( + name: str, type: DuckDBPyType, values: List[Any], *, connection: DuckDBPyConnection = ... +) -> DuckDBPyType: ... def decimal_type(width: int, scale: int, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def struct_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... -def row_type(fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... +def struct_type( + fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ... +) -> DuckDBPyType: ... +def row_type( + fields: Union[Dict[str, DuckDBPyType], List[str]], *, connection: DuckDBPyConnection = ... +) -> DuckDBPyType: ... def map_type(key: DuckDBPyType, value: DuckDBPyType, *, connection: DuckDBPyConnection = ...) -> DuckDBPyType: ... def duplicate(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def execute(query: object, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def executemany(query: object, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... +def execute( + query: object, parameters: object = None, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyConnection: ... +def executemany( + query: object, parameters: object = None, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyConnection: ... def close(*, connection: DuckDBPyConnection = ...) -> None: ... def interrupt(*, connection: DuckDBPyConnection = ...) -> None: ... def query_progress(*, connection: DuckDBPyConnection = ...) -> float: ... @@ -667,10 +909,16 @@ def fetchnumpy(*, connection: DuckDBPyConnection = ...) -> dict: ... def fetchdf(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... def fetch_df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... def df(*, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def fetch_df_chunk(vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection = ...) -> pandas.DataFrame: ... -def pl(rows_per_batch: int = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... +def fetch_df_chunk( + vectors_per_chunk: int = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection = ... +) -> pandas.DataFrame: ... +def pl( + rows_per_batch: int = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection = ... +) -> polars.DataFrame: ... def fetch_arrow_table(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.Table: ... -def fetch_record_batch(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... +def fetch_record_batch( + rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ... +) -> pyarrow.lib.RecordBatchReader: ... def arrow(rows_per_batch: int = 1000000, *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... def torch(*, connection: DuckDBPyConnection = ...) -> dict: ... def tf(*, connection: DuckDBPyConnection = ...) -> dict: ... @@ -678,36 +926,212 @@ def begin(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def commit(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def rollback(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def checkpoint(*, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... -def append(table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... +def append( + table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection = ... +) -> DuckDBPyConnection: ... def register(view_name: str, python_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def unregister(view_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyConnection: ... def table(table_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def view(view_name: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def values(*args: Union[List[Any],Expression, Tuple[Expression]], connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def table_function(name: str, parameters: object = None, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_json(path_or_buffer: Union[str, StringIO, TextIOBase], *, columns: Optional[Dict[str,str]] = None, sample_size: Optional[int] = None, maximum_depth: Optional[int] = None, records: Optional[str] = None, format: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, compression: Optional[str] = None, maximum_object_size: Optional[int] = None, ignore_errors: Optional[bool] = None, convert_strings_to_integers: Optional[bool] = None, field_appearance_threshold: Optional[float] = None, map_inference_threshold: Optional[int] = None, maximum_sample_files: Optional[int] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def values( + *args: Union[List[Any], Expression, Tuple[Expression]], connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def table_function( + name: str, parameters: object = None, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def read_json( + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + columns: Optional[Dict[str, str]] = None, + sample_size: Optional[int] = None, + maximum_depth: Optional[int] = None, + records: Optional[str] = None, + format: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + compression: Optional[str] = None, + maximum_object_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + convert_strings_to_integers: Optional[bool] = None, + field_appearance_threshold: Optional[float] = None, + map_inference_threshold: Optional[int] = None, + maximum_sample_files: Optional[int] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def extract_statements(query: str, *, connection: DuckDBPyConnection = ...) -> List[Statement]: ... -def sql(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_query(query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_csv(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_csv_auto(path_or_buffer: Union[str, StringIO, TextIOBase], *, header: Optional[bool | int] = None, compression: Optional[str] = None, sep: Optional[str] = None, delimiter: Optional[str] = None, dtype: Optional[Dict[str, str] | List[str]] = None, na_values: Optional[str| List[str]] = None, skiprows: Optional[int] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, encoding: Optional[str] = None, parallel: Optional[bool] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, sample_size: Optional[int] = None, all_varchar: Optional[bool] = None, normalize_names: Optional[bool] = None, null_padding: Optional[bool] = None, names: Optional[List[str]] = None, lineterminator: Optional[str] = None, columns: Optional[Dict[str, str]] = None, auto_type_candidates: Optional[List[str]] = None, max_line_size: Optional[int] = None, ignore_errors: Optional[bool] = None, store_rejects: Optional[bool] = None, rejects_table: Optional[str] = None, rejects_scan: Optional[str] = None, rejects_limit: Optional[int] = None, force_not_null: Optional[List[str]] = None, buffer_size: Optional[int] = None, decimal: Optional[str] = None, allow_quoted_nulls: Optional[bool] = None, filename: Optional[bool | str] = None, hive_partitioning: Optional[bool] = None, union_by_name: Optional[bool] = None, hive_types: Optional[Dict[str, str]] = None, hive_types_autocast: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def sql( + query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def query( + query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def from_query( + query: str, *, alias: str = "", params: object = None, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... +def read_csv( + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... +def from_csv_auto( + path_or_buffer: Union[str, StringIO, TextIOBase], + *, + header: Optional[bool | int] = None, + compression: Optional[str] = None, + sep: Optional[str] = None, + delimiter: Optional[str] = None, + dtype: Optional[Dict[str, str] | List[str]] = None, + na_values: Optional[str | List[str]] = None, + skiprows: Optional[int] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + encoding: Optional[str] = None, + parallel: Optional[bool] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + sample_size: Optional[int] = None, + all_varchar: Optional[bool] = None, + normalize_names: Optional[bool] = None, + null_padding: Optional[bool] = None, + names: Optional[List[str]] = None, + lineterminator: Optional[str] = None, + columns: Optional[Dict[str, str]] = None, + auto_type_candidates: Optional[List[str]] = None, + max_line_size: Optional[int] = None, + ignore_errors: Optional[bool] = None, + store_rejects: Optional[bool] = None, + rejects_table: Optional[str] = None, + rejects_scan: Optional[str] = None, + rejects_limit: Optional[int] = None, + force_not_null: Optional[List[str]] = None, + buffer_size: Optional[int] = None, + decimal: Optional[str] = None, + allow_quoted_nulls: Optional[bool] = None, + filename: Optional[bool | str] = None, + hive_partitioning: Optional[bool] = None, + union_by_name: Optional[bool] = None, + hive_types: Optional[Dict[str, str]] = None, + hive_types_autocast: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def from_arrow(arrow_object: object, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def from_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def read_parquet(file_glob: str, binary_as_string: bool = False, *, file_row_number: bool = False, filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, compression: Optional[str] = None, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def from_parquet( + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... +def read_parquet( + file_glob: str, + binary_as_string: bool = False, + *, + file_row_number: bool = False, + filename: bool = False, + hive_partitioning: bool = False, + union_by_name: bool = False, + compression: Optional[str] = None, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def get_table_names(query: str, *, qualified: bool = False, connection: DuckDBPyConnection = ...) -> Set[str]: ... -def install_extension(extension: str, *, force_install: bool = False, repository: Optional[str] = None, repository_url: Optional[str] = None, version: Optional[str] = None, connection: DuckDBPyConnection = ...) -> None: ... +def install_extension( + extension: str, + *, + force_install: bool = False, + repository: Optional[str] = None, + repository_url: Optional[str] = None, + version: Optional[str] = None, + connection: DuckDBPyConnection = ..., +) -> None: ... def load_extension(extension: str, *, connection: DuckDBPyConnection = ...) -> None: ... -def project(df: pandas.DataFrame, *args: str, groups: str = "", connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def project( + df: pandas.DataFrame, *args: str, groups: str = "", connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... def distinct(df: pandas.DataFrame, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def write_csv(df: pandas.DataFrame, filename: str, *, sep: Optional[str] = None, na_rep: Optional[str] = None, header: Optional[bool] = None, quotechar: Optional[str] = None, escapechar: Optional[str] = None, date_format: Optional[str] = None, timestamp_format: Optional[str] = None, quoting: Optional[str | int] = None, encoding: Optional[str] = None, compression: Optional[str] = None, overwrite: Optional[bool] = None, per_thread_output: Optional[bool] = None, use_tmp_file: Optional[bool] = None, partition_by: Optional[List[str]] = None, write_partition_columns: Optional[bool] = None, connection: DuckDBPyConnection = ...) -> None: ... -def aggregate(df: pandas.DataFrame, aggr_expr: str | List[Expression], group_expr: str = "", *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def write_csv( + df: pandas.DataFrame, + filename: str, + *, + sep: Optional[str] = None, + na_rep: Optional[str] = None, + header: Optional[bool] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + date_format: Optional[str] = None, + timestamp_format: Optional[str] = None, + quoting: Optional[str | int] = None, + encoding: Optional[str] = None, + compression: Optional[str] = None, + overwrite: Optional[bool] = None, + per_thread_output: Optional[bool] = None, + use_tmp_file: Optional[bool] = None, + partition_by: Optional[List[str]] = None, + write_partition_columns: Optional[bool] = None, + connection: DuckDBPyConnection = ..., +) -> None: ... +def aggregate( + df: pandas.DataFrame, + aggr_expr: str | List[Expression], + group_expr: str = "", + *, + connection: DuckDBPyConnection = ..., +) -> DuckDBPyRelation: ... def alias(df: pandas.DataFrame, alias: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... def filter(df: pandas.DataFrame, filter_expr: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def limit(df: pandas.DataFrame, n: int, offset: int = 0, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def limit( + df: pandas.DataFrame, n: int, offset: int = 0, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... def order(df: pandas.DataFrame, order_expr: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... -def query_df(df: pandas.DataFrame, virtual_table_name: str, sql_query: str, *, connection: DuckDBPyConnection = ...) -> DuckDBPyRelation: ... +def query_df( + df: pandas.DataFrame, virtual_table_name: str, sql_query: str, *, connection: DuckDBPyConnection = ... +) -> DuckDBPyRelation: ... def description(*, connection: DuckDBPyConnection = ...) -> Optional[List[Any]]: ... def rowcount(*, connection: DuckDBPyConnection = ...) -> int: ... + # END OF CONNECTION WRAPPER diff --git a/duckdb/bytes_io_wrapper.py b/duckdb/bytes_io_wrapper.py index 0957652b..763fd8b7 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -63,4 +63,3 @@ def read(self, n: Union[int, None] = -1) -> bytes: to_return = combined_bytestring[:n] self.overflow = combined_bytestring[n:] return to_return - diff --git a/duckdb/experimental/__init__.py b/duckdb/experimental/__init__.py index 0ab3305b..a88a6170 100644 --- a/duckdb/experimental/__init__.py +++ b/duckdb/experimental/__init__.py @@ -1,2 +1,3 @@ from . import spark + __all__ = spark.__all__ diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index be16be41..d6a02326 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -56,7 +56,7 @@ class _NoValueType: __instance = None - def __new__(cls) -> '_NoValueType': + def __new__(cls) -> "_NoValueType": # ensure that only one instance exists if not cls.__instance: cls.__instance = super(_NoValueType, cls).__new__(cls) diff --git a/duckdb/experimental/spark/_typing.py b/duckdb/experimental/spark/_typing.py index 0c06fed5..251ef695 100644 --- a/duckdb/experimental/spark/_typing.py +++ b/duckdb/experimental/spark/_typing.py @@ -30,17 +30,14 @@ class SupportsIAdd(Protocol): - def __iadd__(self, other: "SupportsIAdd") -> "SupportsIAdd": - ... + def __iadd__(self, other: "SupportsIAdd") -> "SupportsIAdd": ... class SupportsOrdering(Protocol): - def __lt__(self, other: "SupportsOrdering") -> bool: - ... + def __lt__(self, other: "SupportsOrdering") -> bool: ... -class SizedIterable(Protocol, Sized, Iterable[T_co]): - ... +class SizedIterable(Protocol, Sized, Iterable[T_co]): ... S = TypeVar("S", bound=SupportsOrdering) diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index 95227add..dd4b016c 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -8,7 +8,7 @@ class SparkContext: def __init__(self, master: str) -> None: - self._connection = duckdb.connect(':memory:') + self._connection = duckdb.connect(":memory:") # This aligns the null ordering with Spark. self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'") diff --git a/duckdb/experimental/spark/errors/__init__.py b/duckdb/experimental/spark/errors/__init__.py index 5f2af443..6aac49d7 100644 --- a/duckdb/experimental/spark/errors/__init__.py +++ b/duckdb/experimental/spark/errors/__init__.py @@ -18,6 +18,7 @@ """ PySpark exceptions. """ + from .exceptions.base import ( # noqa: F401 PySparkException, AnalysisException, diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index fcdce827..48a3ea95 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -2,6 +2,7 @@ from ..utils import ErrorClassesReader + class PySparkException(Exception): """ Base Exception for handling errors generated from PySpark. @@ -79,6 +80,7 @@ def __str__(self) -> str: else: return self.message + class AnalysisException(PySparkException): """ Failed to analyze a SQL query plan. diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index 3ef418bd..f1b37f75 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -37,8 +37,7 @@ def get_error_message(self, error_class: str, message_parameters: dict[str, str] # Verify message parameters. message_parameters_from_template = re.findall("<([a-zA-Z0-9_-]+)>", message_template) assert set(message_parameters_from_template) == set(message_parameters), ( - f"Undefined error message parameter for error class: {error_class}. " - f"Parameters: {message_parameters}" + f"Undefined error message parameter for error class: {error_class}. Parameters: {message_parameters}" ) table = str.maketrans("<>", "{}") diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 21668cf5..60495d88 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -8,7 +8,7 @@ class ContributionsAcceptedError(NotImplementedError): def __init__(self, message=None) -> None: doc = self.__class__.__doc__ if message: - doc = message + '\n' + doc + doc = message + "\n" + doc super().__init__(doc) diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index 645b60bb..b5a8b079 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -25,6 +25,7 @@ TypeVar, Union, ) + try: from typing import Literal, Protocol except ImportError: @@ -63,18 +64,15 @@ class SupportsOpen(Protocol): - def open(self, partition_id: int, epoch_id: int) -> bool: - ... + def open(self, partition_id: int, epoch_id: int) -> bool: ... class SupportsProcess(Protocol): - def process(self, row: types.Row) -> None: - ... + def process(self, row: types.Row) -> None: ... class SupportsClose(Protocol): - def close(self, error: Exception) -> None: - ... + def close(self, error: Exception) -> None: ... class UserDefinedFunctionLike(Protocol): @@ -83,11 +81,8 @@ class UserDefinedFunctionLike(Protocol): deterministic: bool @property - def returnType(self) -> types.DataType: - ... + def returnType(self) -> types.DataType: ... - def __call__(self, *args: ColumnOrName) -> Column: - ... + def __call__(self, *args: ColumnOrName) -> Column: ... - def asNondeterministic(self) -> "UserDefinedFunctionLike": - ... + def asNondeterministic(self) -> "UserDefinedFunctionLike": ... diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index 0cd790f7..3cc96f45 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -37,19 +37,19 @@ def __init__(self, session: SparkSession) -> None: self._session = session def listDatabases(self) -> list[Database]: - res = self._session.conn.sql('select database_name from duckdb_databases()').fetchall() + res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall() def transform_to_database(x) -> Database: - return Database(name=x[0], description=None, locationUri='') + return Database(name=x[0], description=None, locationUri="") databases = [transform_to_database(x) for x in res] return databases def listTables(self) -> list[Table]: - res = self._session.conn.sql('select table_name, database_name, sql, temporary from duckdb_tables()').fetchall() + res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall() def transform_to_table(x) -> Table: - return Table(name=x[0], database=x[1], description=x[2], tableType='', isTemporary=x[3]) + return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3]) tables = [transform_to_table(x) for x in res] return tables diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 0dd86178..f78b31ae 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -99,7 +99,7 @@ def __init__(self, expr: Expression) -> None: self.expr = expr # arithmetic operators - def __neg__(self) -> 'Column': + def __neg__(self) -> "Column": return Column(-self.expr) # `and`, `or`, `not` cannot be overloaded in Python, @@ -161,8 +161,8 @@ def __getitem__(self, k: Any) -> "Column": Examples -------- - >>> df = spark.createDataFrame([('abcedfg', {"key": "value"})], ["l", "d"]) - >>> df.select(df.l[slice(1, 3)], df.d['key']).show() + >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) + >>> df.select(df.l[slice(1, 3)], df.d["key"]).show() +------------------+------+ |substring(l, 1, 3)|d[key]| +------------------+------+ @@ -196,7 +196,7 @@ def __getattr__(self, item: Any) -> "Column": Examples -------- - >>> df = spark.createDataFrame([('abcedfg', {"key": "value"})], ["l", "d"]) + >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) >>> df.select(df.d.key).show() +------+ |d[key]| @@ -347,7 +347,6 @@ def __ne__( # type: ignore[override] nulls_first = _unary_op("nulls_first") nulls_last = _unary_op("nulls_last") - def asc_nulls_first(self) -> "Column": return self.asc().nulls_first() @@ -365,4 +364,3 @@ def isNull(self) -> "Column": def isNotNull(self) -> "Column": return Column(self.expr.isnotnull()) - diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 42a5b8f0..19f5576b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -170,7 +170,7 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": Examples -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - >>> df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}).show() + >>> df.withColumns({"age2": df.age + 2, "age3": df.age + 3}).show() +---+-----+----+----+ |age| name|age2|age3| +---+-----+----+----+ @@ -248,8 +248,8 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": Examples -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - >>> df = df.withColumns({'age2': df.age + 2, 'age3': df.age + 3}) - >>> df.withColumnsRenamed({'age2': 'age4', 'age3': 'age5'}).show() + >>> df = df.withColumns({"age2": df.age + 2, "age3": df.age + 3}) + >>> df.withColumnsRenamed({"age2": "age4", "age3": "age5"}).show() +---+-----+----+----+ |age| name|age4|age5| +---+-----+----+----+ @@ -265,9 +265,7 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": unknown_columns = set(colsMap.keys()) - set(self.relation.columns) if unknown_columns: - raise ValueError( - f"DataFrame does not contain column(s): {', '.join(unknown_columns)}" - ) + raise ValueError(f"DataFrame does not contain column(s): {', '.join(unknown_columns)}") # Compute this only once old_column_names = list(colsMap.keys()) @@ -289,11 +287,7 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - - - def transform( - self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any - ) -> "DataFrame": + def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame": """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations. .. versionadded:: 3.0.0 @@ -325,10 +319,8 @@ def transform( >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"]) >>> def cast_all_to_int(input_df): ... return input_df.select([col(col_name).cast("int") for col_name in input_df.columns]) - ... >>> def sort_columns_asc(input_df): ... return input_df.select(*sorted(input_df.columns)) - ... >>> df.transform(cast_all_to_int).transform(sort_columns_asc).show() +-----+---+ |float|int| @@ -338,8 +330,9 @@ def transform( +-----+---+ >>> def add_n(input_df, n): - ... return input_df.select([(col(col_name) + n).alias(col_name) - ... for col_name in input_df.columns]) + ... return input_df.select( + ... [(col(col_name) + n).alias(col_name) for col_name in input_df.columns] + ... ) >>> df.transform(add_n, 1).transform(add_n, n=10).show() +---+-----+ |int|float| @@ -350,14 +343,11 @@ def transform( """ result = func(self, *args, **kwargs) assert isinstance(result, DataFrame), ( - "Func returned an instance of type [%s], " - "should have been DataFrame." % type(result) + "Func returned an instance of type [%s], should have been DataFrame." % type(result) ) return result - def sort( - self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any - ) -> "DataFrame": + def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any) -> "DataFrame": """Returns a new :class:`DataFrame` sorted by the specified column(s). Parameters @@ -380,8 +370,7 @@ def sort( Examples -------- >>> from pyspark.sql.functions import desc, asc - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) Sort the DataFrame in ascending order. @@ -419,8 +408,9 @@ def sort( Specify multiple columns - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] + ... ) >>> df.orderBy(desc("age"), "name").show() +---+-----+ |age| name| @@ -516,8 +506,7 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) Filter by :class:`Column` instances. @@ -568,13 +557,9 @@ def select(self, *cols) -> "DataFrame": if len(cols) == 1: cols = cols[0] if isinstance(cols, list): - projections = [ - x.expr if isinstance(x, Column) else ColumnExpression(x) for x in cols - ] + projections = [x.expr if isinstance(x, Column) else ColumnExpression(x) for x in cols] else: - projections = [ - cols.expr if isinstance(cols, Column) else ColumnExpression(cols) - ] + projections = [cols.expr if isinstance(cols, Column) else ColumnExpression(cols)] rel = self.relation.select(*projections) return DataFrame(rel, self.session) @@ -636,22 +621,24 @@ def join( >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")]).toDF("age", "name") >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) >>> df3 = spark.createDataFrame([Row(age=2, name="Alice"), Row(age=5, name="Bob")]) - >>> df4 = spark.createDataFrame([ - ... Row(age=10, height=80, name="Alice"), - ... Row(age=5, height=None, name="Bob"), - ... Row(age=None, height=None, name="Tom"), - ... Row(age=None, height=None, name=None), - ... ]) + >>> df4 = spark.createDataFrame( + ... [ + ... Row(age=10, height=80, name="Alice"), + ... Row(age=5, height=None, name="Bob"), + ... Row(age=None, height=None, name="Tom"), + ... Row(age=None, height=None, name=None), + ... ] + ... ) Inner join on columns (default) - >>> df.join(df2, 'name').select(df.name, df2.height).show() + >>> df.join(df2, "name").select(df.name, df2.height).show() +----+------+ |name|height| +----+------+ | Bob| 85| +----+------+ - >>> df.join(df4, ['name', 'age']).select(df.name, df.age).show() + >>> df.join(df4, ["name", "age"]).select(df.name, df.age).show() +----+---+ |name|age| +----+---+ @@ -660,8 +647,9 @@ def join( Outer join for both DataFrames on the 'name' column. - >>> df.join(df2, df.name == df2.name, 'outer').select( - ... df.name, df2.height).sort(desc("name")).show() + >>> df.join(df2, df.name == df2.name, "outer").select(df.name, df2.height).sort( + ... desc("name") + ... ).show() +-----+------+ | name|height| +-----+------+ @@ -669,7 +657,7 @@ def join( |Alice| NULL| | NULL| 80| +-----+------+ - >>> df.join(df2, 'name', 'outer').select('name', 'height').sort(desc("name")).show() + >>> df.join(df2, "name", "outer").select("name", "height").sort(desc("name")).show() +-----+------+ | name|height| +-----+------+ @@ -680,11 +668,9 @@ def join( Outer join for both DataFrams with multiple columns. - >>> df.join( - ... df3, - ... [df.name == df3.name, df.age == df3.age], - ... 'outer' - ... ).select(df.name, df3.age).show() + >>> df.join(df3, [df.name == df3.name, df.age == df3.age], "outer").select( + ... df.name, df3.age + ... ).show() +-----+---+ | name|age| +-----+---+ @@ -701,12 +687,9 @@ def join( on = [_to_column_expr(x) for x in on] # & all the Expressions together to form one Expression - assert isinstance( - on[0], Expression - ), "on should be Column or list of Column" + assert isinstance(on[0], Expression), "on should be Column or list of Column" on = reduce(lambda x, y: x.__and__(y), cast(list[Expression], on)) - if on is None and how is None: result = self.relation.join(other.relation) else: @@ -765,10 +748,8 @@ def crossJoin(self, other: "DataFrame") -> "DataFrame": Examples -------- >>> from pyspark.sql import Row - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) - >>> df2 = spark.createDataFrame( - ... [Row(height=80, name="Tom"), Row(height=85, name="Bob")]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) >>> df.crossJoin(df2.select("height")).select("age", "name", "height").show() +---+-----+------+ |age| name|height| @@ -799,13 +780,13 @@ def alias(self, alias: str) -> "DataFrame": Examples -------- >>> from pyspark.sql.functions import col, desc - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df_as1 = df.alias("df_as1") >>> df_as2 = df.alias("df_as2") - >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') - >>> joined_df.select( - ... "df_as1.name", "df_as2.name", "df_as2.age").sort(desc("df_as1.name")).show() + >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), "inner") + >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").sort( + ... desc("df_as1.name") + ... ).show() +-----+-----+---+ | name| name|age| +-----+-----+---+ @@ -853,8 +834,7 @@ def limit(self, num: int) -> "DataFrame": Examples -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df.limit(1).show() +---+----+ |age|name| @@ -889,25 +869,21 @@ def schema(self) -> StructType: return self._schema @overload - def __getitem__(self, item: Union[int, str]) -> Column: - ... + def __getitem__(self, item: Union[int, str]) -> Column: ... @overload - def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": - ... + def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... - def __getitem__( - self, item: Union[int, str, Column, list, tuple] - ) -> Union[Column, "DataFrame"]: + def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. Examples -------- - >>> df.select(df['age']).collect() + >>> df.select(df["age"]).collect() [Row(age=2), Row(age=5)] - >>> df[ ["name", "age"]].collect() + >>> df[["name", "age"]].collect() [Row(name='Alice', age=2), Row(name='Bob', age=5)] - >>> df[ df.age > 3 ].collect() + >>> df[df.age > 3].collect() [Row(age=5, name='Bob')] >>> df[df[0] > 3].collect() [Row(age=5, name='Bob')] @@ -932,18 +908,14 @@ def __getattr__(self, name: str) -> Column: [Row(age=2), Row(age=5)] """ if name not in self.relation.columns: - raise AttributeError( - "'%s' object has no attribute '%s'" % (self.__class__.__name__, name) - ) + raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) return Column(duckdb.ColumnExpression(self.relation.alias, name)) @overload - def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": - ... + def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload - def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": - ... + def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": ... def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] """Groups the :class:`DataFrame` using the specified columns, @@ -966,8 +938,9 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] + ... ) Empty grouping columns triggers a global aggregation. @@ -1073,9 +1046,7 @@ def union(self, other: "DataFrame") -> "DataFrame": unionAll = union - def unionByName( - self, other: "DataFrame", allowMissingColumns: bool = False - ) -> "DataFrame": + def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame": """Returns a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`. @@ -1244,7 +1215,8 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": Examples -------- >>> df1 = spark.createDataFrame( - ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"]) + ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"] + ... ) >>> df2 = spark.createDataFrame([("a", 1), ("b", 3)], ["C1", "C2"]) >>> df1.exceptAll(df2).show() +---+---+ @@ -1284,11 +1256,13 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": Examples -------- >>> from pyspark.sql import Row - >>> df = spark.createDataFrame([ - ... Row(name='Alice', age=5, height=80), - ... Row(name='Alice', age=5, height=80), - ... Row(name='Alice', age=10, height=80) - ... ]) + >>> df = spark.createDataFrame( + ... [ + ... Row(name="Alice", age=5, height=80), + ... Row(name="Alice", age=5, height=80), + ... Row(name="Alice", age=10, height=80), + ... ] + ... ) Deduplicate the same rows. @@ -1302,7 +1276,7 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": Deduplicate values on 'name' and 'height' columns. - >>> df.dropDuplicates(['name', 'height']).show() + >>> df.dropDuplicates(["name", "height"]).show() +-----+---+------+ | name|age|height| +-----+---+------+ @@ -1311,7 +1285,7 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": """ if subset: rn_col = f"tmp_col_{uuid.uuid1().hex}" - subset_str = ', '.join([f'"{c}"' for c in subset]) + subset_str = ", ".join([f'"{c}"' for c in subset]) window_spec = f"OVER(PARTITION BY {subset_str}) AS {rn_col}" df = DataFrame(self.relation.row_number(window_spec, "*"), self.session) return df.filter(f"{rn_col} = 1").drop(rn_col) @@ -1320,7 +1294,6 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": drop_duplicates = dropDuplicates - def distinct(self) -> "DataFrame": """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -1331,8 +1304,7 @@ def distinct(self) -> "DataFrame": Examples -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) Return the number of distinct rows in the :class:`DataFrame` @@ -1352,8 +1324,7 @@ def count(self) -> int: Examples -------- - >>> df = spark.createDataFrame( - ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) + >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) Return the number of rows in the :class:`DataFrame`. @@ -1369,8 +1340,7 @@ def _cast_types(self, *types) -> "DataFrame": assert types_count == len(existing_columns) cast_expressions = [ - f"{existing}::{target_type} as {existing}" - for existing, target_type in zip(existing_columns, types) + f"{existing}::{target_type} as {existing}" for existing, target_type in zip(existing_columns, types) ] cast_expressions = ", ".join(cast_expressions) new_rel = self.relation.project(cast_expressions) @@ -1380,14 +1350,10 @@ def toDF(self, *cols) -> "DataFrame": existing_columns = self.relation.columns column_count = len(cols) if column_count != len(existing_columns): - raise PySparkValueError( - message="Provided column names and number of columns in the DataFrame don't match" - ) + raise PySparkValueError(message="Provided column names and number of columns in the DataFrame don't match") existing_columns = [ColumnExpression(x) for x in existing_columns] - projections = [ - existing.alias(new) for existing, new in zip(existing_columns, cols) - ] + projections = [existing.alias(new) for existing, new in zip(existing_columns, cols)] new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 78b14de7..dfcf7e2e 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -11,6 +11,7 @@ LambdaExpression, SQLExpression, ) + if TYPE_CHECKING: from .dataframe import DataFrame @@ -105,14 +106,10 @@ def _inner_expr_or_val(val): def struct(*cols: Column) -> Column: - return Column( - FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols]) - ) + return Column(FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols])) -def array( - *cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]] -) -> Column: +def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]]) -> Column: """Creates a new array column. .. versionadded:: 1.4.0 @@ -134,11 +131,11 @@ def array( Examples -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) - >>> df.select(array('age', 'age').alias("arr")).collect() + >>> df.select(array("age", "age").alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] >>> df.select(array([df.age, df.age]).alias("arr")).collect() [Row(arr=[2, 2]), Row(arr=[5, 5])] - >>> df.select(array('age', 'age').alias("col")).printSchema() + >>> df.select(array("age", "age").alias("col")).printSchema() root |-- col: array (nullable = false) | |-- element: long (containsNull = true) @@ -167,6 +164,7 @@ def _to_column_expr(col: ColumnOrName) -> Expression: message_parameters={"arg_name": "col", "arg_type": type(col).__name__}, ) + def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: r"""Replace all substrings of the specified string value that match regexp with rep. @@ -174,8 +172,8 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum Examples -------- - >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() + >>> df = spark.createDataFrame([("100-200",)], ["str"]) + >>> df.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() [Row(d='-----')] """ return _invoke_function( @@ -187,9 +185,7 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum ) -def slice( - x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] -) -> Column: +def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]) -> Column: """ Collection function: returns an array containing all the elements in `x` from index `start` (array indices start at 1, or from the end if `start` is negative) with the specified `length`. @@ -215,7 +211,7 @@ def slice( Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() [Row(sliced=[2, 3]), Row(sliced=[5])] """ @@ -301,9 +297,7 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(1, "Bob"), - ... (0, None), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(1, "Bob"), (0, None), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_first(df1.name)).show() +---+-----+ |age| name| @@ -339,9 +333,7 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_last(df1.name)).show() +---+-----+ |age| name| @@ -414,9 +406,7 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_first(df1.name)).show() +---+-----+ |age| name| @@ -452,9 +442,7 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: Examples -------- - >>> df1 = spark.createDataFrame([(0, None), - ... (1, "Bob"), - ... (2, "Alice")], ["age", "name"]) + >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_last(df1.name)).show() +---+-----+ |age| name| @@ -484,16 +472,22 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", 3,)], ['a', 'b']) - >>> df.select(left(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... 3, + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(left(df.a, df.b).alias("r")).collect() [Row(r='Spa')] """ len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( - FunctionExpression( - "array_slice", _to_column_expr(str), ConstantExpression(0), len - ) + FunctionExpression("array_slice", _to_column_expr(str), ConstantExpression(0), len) ) ) @@ -514,23 +508,27 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", 3,)], ['a', 'b']) - >>> df.select(right(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... 3, + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(right(df.a, df.b).alias("r")).collect() [Row(r='SQL')] """ len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( - FunctionExpression( - "array_slice", _to_column_expr(str), -len, ConstantExpression(-1) - ) + FunctionExpression("array_slice", _to_column_expr(str), -len, ConstantExpression(-1)) ) ) -def levenshtein( - left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None -) -> Column: +def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None) -> Column: """Computes the Levenshtein distance of the two given strings. .. versionadded:: 1.5.0 @@ -558,10 +556,18 @@ def levenshtein( Examples -------- - >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) - >>> df0.select(levenshtein('l', 'r').alias('d')).collect() + >>> df0 = spark.createDataFrame( + ... [ + ... ( + ... "kitten", + ... "sitting", + ... ) + ... ], + ... ["l", "r"], + ... ) + >>> df0.select(levenshtein("l", "r").alias("d")).collect() [Row(d=3)] - >>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect() + >>> df0.select(levenshtein("l", "r", 2).alias("d")).collect() [Row(d=-1)] """ distance = _invoke_function_over_columns("levenshtein", left, right) @@ -569,7 +575,9 @@ def levenshtein( return distance else: distance = _to_column_expr(distance) - return Column(CaseExpression(distance <= ConstantExpression(threshold), distance).otherwise(ConstantExpression(-1))) + return Column( + CaseExpression(distance <= ConstantExpression(threshold), distance).otherwise(ConstantExpression(-1)) + ) def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: @@ -597,8 +605,13 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(lpad(df.s, 6, "#").alias("s")).collect() [Row(s='##abcd')] """ return _invoke_function("lpad", _to_column_expr(col), ConstantExpression(len), ConstantExpression(pad)) @@ -629,8 +642,13 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(rpad(df.s, 6, "#").alias("s")).collect() [Row(s='abcd##')] """ return _invoke_function("rpad", _to_column_expr(col), ConstantExpression(len), ConstantExpression(pad)) @@ -702,12 +720,14 @@ def asin(col: "ColumnOrName") -> Column: """ col = _to_column_expr(col) # FIXME: ConstantExpression(float("nan")) gives NULL and not NaN - return Column(CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise(FunctionExpression("asin", col))) + return Column( + CaseExpression((col < -1.0) | (col > 1.0), ConstantExpression(float("nan"))).otherwise( + FunctionExpression("asin", col) + ) + ) -def like( - str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None -) -> Column: +def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: """ Returns true if str matches `pattern` with `escape`, null if any arguments are null, false otherwise. @@ -728,15 +748,14 @@ def like( Examples -------- - >>> df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - >>> df.select(like(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + >>> df.select(like(df.a, df.b).alias("r")).collect() [Row(r=True)] >>> df = spark.createDataFrame( - ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], - ... ['a', 'b'] + ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"] ... ) - >>> df.select(like(df.a, df.b, lit('/')).alias('r')).collect() + >>> df.select(like(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] """ if escapeChar is None: @@ -746,9 +765,7 @@ def like( return _invoke_function("like_escape", _to_column_expr(str), _to_column_expr(pattern), escapeChar) -def ilike( - str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None -) -> Column: +def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: """ Returns true if str matches `pattern` with `escape` case-insensitively, null if any arguments are null, false otherwise. @@ -769,15 +786,14 @@ def ilike( Examples -------- - >>> df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - >>> df.select(ilike(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + >>> df.select(ilike(df.a, df.b).alias("r")).collect() [Row(r=True)] >>> df = spark.createDataFrame( - ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], - ... ['a', 'b'] + ... [("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"] ... ) - >>> df.select(ilike(df.a, df.b, lit('/')).alias('r')).collect() + >>> df.select(ilike(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] """ if escapeChar is None: @@ -805,8 +821,8 @@ def array_agg(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) - >>> df.agg(array_agg('c').alias('r')).collect() + >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) + >>> df.agg(array_agg("c").alias("r")).collect() [Row(r=[1, 1, 2])] """ return _invoke_function_over_columns("list", col) @@ -838,8 +854,8 @@ def collect_list(col: "ColumnOrName") -> Column: Examples -------- - >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ('age',)) - >>> df2.agg(collect_list('age')).collect() + >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ("age",)) + >>> df2.agg(collect_list("age")).collect() [Row(collect_list(age)=[2, 5, 5])] """ return array_agg(col) @@ -874,15 +890,13 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")]) >>> df.select(array_append(df.c1, df.c2)).collect() [Row(array_append(c1, c2)=['b', 'a', 'c', 'c'])] - >>> df.select(array_append(df.c1, 'x')).collect() + >>> df.select(array_append(df.c1, "x")).collect() [Row(array_append(c1, x)=['b', 'a', 'c', 'x'])] """ return _invoke_function("list_append", _to_column_expr(col), _get_expr(value)) -def array_insert( - arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any -) -> Column: +def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column: """ Collection function: adds an item into a given array at a specified array index. Array indices start at 1, or start from the end if index is negative. @@ -913,12 +927,11 @@ def array_insert( Examples -------- >>> df = spark.createDataFrame( - ... [(['a', 'b', 'c'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')], - ... ['data', 'pos', 'val'] + ... [(["a", "b", "c"], 2, "d"), (["c", "b", "a"], -2, "d")], ["data", "pos", "val"] ... ) - >>> df.select(array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() + >>> df.select(array_insert(df.data, df.pos.cast("integer"), df.val).alias("data")).collect() [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'b', 'd', 'a'])] - >>> df.select(array_insert(df.data, 5, 'hello').alias('data')).collect() + >>> df.select(array_insert(df.data, 5, "hello").alias("data")).collect() [Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])] """ pos = _get_expr(pos) @@ -944,9 +957,7 @@ def array_insert( FunctionExpression( "list_resize", FunctionExpression("list_value", None), - FunctionExpression( - "subtract", FunctionExpression("abs", pos), list_length_plus_1 - ), + FunctionExpression("subtract", FunctionExpression("abs", pos), list_length_plus_1), ), arr, ), @@ -964,9 +975,7 @@ def array_insert( "list_slice", list_, 1, - CaseExpression( - pos_is_positive, FunctionExpression("subtract", pos, 1) - ).otherwise(pos), + CaseExpression(pos_is_positive, FunctionExpression("subtract", pos, 1)).otherwise(pos), ), # Here we insert the value at the specified position FunctionExpression("list_value", _get_expr(value)), @@ -975,9 +984,7 @@ def array_insert( FunctionExpression( "list_slice", list_, - CaseExpression(pos_is_positive, pos).otherwise( - FunctionExpression("add", pos, 1) - ), + CaseExpression(pos_is_positive, pos).otherwise(FunctionExpression("add", pos, 1)), -1, ), ) @@ -1002,7 +1009,7 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ["data"]) >>> df.select(array_contains(df.data, "a")).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] >>> df.select(array_contains(df.data, lit("a"))).collect() @@ -1033,7 +1040,7 @@ def array_distinct(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ["data"]) >>> df.select(array_distinct(df.data)).collect() [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] """ @@ -1125,11 +1132,13 @@ def array_max(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_max(df.data).alias('max')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) + >>> df.select(array_max(df.data).alias("max")).collect() [Row(max=3), Row(max=10)] """ - return _invoke_function("array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(-1)) + return _invoke_function( + "array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(-1) + ) def array_min(col: "ColumnOrName") -> Column: @@ -1153,11 +1162,13 @@ def array_min(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) - >>> df.select(array_min(df.data).alias('min')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) + >>> df.select(array_min(df.data).alias("min")).collect() [Row(min=1), Row(min=-1)] """ - return _invoke_function("array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(1)) + return _invoke_function( + "array_extract", _to_column_expr(_invoke_function_over_columns("array_sort", col)), _get_expr(1) + ) def avg(col: "ColumnOrName") -> Column: @@ -1311,11 +1322,17 @@ def median(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([ - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("Java", 2012, 22000), ("dotNET", 2012, 10000), - ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], - ... schema=("course", "year", "earnings")) + >>> df = spark.createDataFrame( + ... [ + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("Java", 2012, 22000), + ... ("dotNET", 2012, 10000), + ... ("dotNET", 2013, 48000), + ... ("Java", 2013, 30000), + ... ], + ... schema=("course", "year", "earnings"), + ... ) >>> df.groupby("course").agg(median("earnings")).show() +------+----------------+ |course|median(earnings)| @@ -1349,11 +1366,17 @@ def mode(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([ - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("Java", 2012, 20000), ("dotNET", 2012, 5000), - ... ("dotNET", 2013, 48000), ("Java", 2013, 30000)], - ... schema=("course", "year", "earnings")) + >>> df = spark.createDataFrame( + ... [ + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("Java", 2012, 20000), + ... ("dotNET", 2012, 5000), + ... ("dotNET", 2013, 48000), + ... ("Java", 2013, 30000), + ... ], + ... schema=("course", "year", "earnings"), + ... ) >>> df.groupby("course").agg(mode("year")).show() +------+----------+ |course|mode(year)| @@ -1416,14 +1439,12 @@ def any_value(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(None, 1), - ... ("a", 2), - ... ("a", 3), - ... ("b", 8), - ... ("b", 2)], ["c1", "c2"]) - >>> df.select(any_value('c1'), any_value('c2')).collect() + >>> df = spark.createDataFrame( + ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] + ... ) + >>> df.select(any_value("c1"), any_value("c2")).collect() [Row(any_value(c1)=None, any_value(c2)=1)] - >>> df.select(any_value('c1', True), any_value('c2', True)).collect() + >>> df.select(any_value("c1", True), any_value("c2", True)).collect() [Row(any_value(c1)='a', any_value(c2)=1)] """ return _invoke_function_over_columns("any_value", col) @@ -1486,8 +1507,8 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C Examples -------- - >>> df = spark.createDataFrame([1,2,2,3], "INT") - >>> df.agg(approx_count_distinct("value").alias('distinct_values')).show() + >>> df = spark.createDataFrame([1, 2, 2, 3], "INT") + >>> df.agg(approx_count_distinct("value").alias("distinct_values")).show() +---------------+ |distinct_values| +---------------+ @@ -1567,7 +1588,6 @@ def transform( >>> def alternate(x, i): ... return when(i % 2 == 0, x).otherwise(-x) - ... >>> df.select(transform("values", alternate).alias("alternated")).show() +--------------+ | alternated| @@ -1602,8 +1622,8 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": Examples -------- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() + >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) + >>> df.select(concat_ws("-", df.s, df.d).alias("s")).collect() [Row(s='abcd-123')] """ cols = [_to_column_expr(expr) for expr in cols] @@ -1788,7 +1808,7 @@ def isnan(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df = spark.createDataFrame([(1.0, float("nan")), (float("nan"), 2.0)], ("a", "b")) >>> df.select("a", "b", isnan("a").alias("r1"), isnan(df.b).alias("r2")).show() +---+---+-----+-----+ | a| b| r1| r2| @@ -1845,7 +1865,7 @@ def isnotnull(col: "ColumnOrName") -> Column: Examples -------- >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) - >>> df.select(isnotnull(df.e).alias('r')).collect() + >>> df.select(isnotnull(df.e).alias("r")).collect() [Row(r=False), Row(r=True)] """ return Column(_to_column_expr(col).isnotnull()) @@ -1862,8 +1882,20 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str Examples -------- - >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"]) - >>> df.select(equal_null(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... None, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(equal_null(df.a, df.b).alias("r")).collect() [Row(r=True), Row(r=False)] """ if isinstance(col1, str): @@ -1872,7 +1904,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: if isinstance(col2, str): col2 = col(col2) - return nvl((col1 == col2) | ((col1.isNull() & col2.isNull())), lit(False)) + return nvl((col1 == col2) | (col1.isNull() & col2.isNull()), lit(False)) def flatten(col: "ColumnOrName") -> Column: @@ -1898,7 +1930,7 @@ def flatten(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) >>> df.show(truncate=False) +------------------------+ |data | @@ -1906,7 +1938,7 @@ def flatten(col: "ColumnOrName") -> Column: |[[1, 2, 3], [4, 5], [6]]| |[NULL, [4, 5]] | +------------------------+ - >>> df.select(flatten(df.data).alias('r')).show() + >>> df.select(flatten(df.data).alias("r")).show() +------------------+ | r| +------------------+ @@ -1916,11 +1948,7 @@ def flatten(col: "ColumnOrName") -> Column: """ col = _to_column_expr(col) contains_null = _list_contains_null(col) - return Column( - CaseExpression(contains_null, None).otherwise( - FunctionExpression("flatten", col) - ) - ) + return Column(CaseExpression(contains_null, None).otherwise(FunctionExpression("flatten", col))) def array_compact(col: "ColumnOrName") -> Column: @@ -1945,7 +1973,7 @@ def array_compact(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) + >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) >>> df.select(array_compact(df.data)).collect() [Row(array_compact(data)=[1, 2, 3]), Row(array_compact(data)=[4, 5, 4])] """ @@ -1977,11 +2005,13 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) >>> df.select(array_remove(df.data, 1)).collect() [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] """ - return _invoke_function("list_filter", _to_column_expr(col), LambdaExpression("x", ColumnExpression("x") != ConstantExpression(element))) + return _invoke_function( + "list_filter", _to_column_expr(col), LambdaExpression("x", ColumnExpression("x") != ConstantExpression(element)) + ) def last_day(date: "ColumnOrName") -> Column: @@ -2005,14 +2035,13 @@ def last_day(date: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-10',)], ['d']) - >>> df.select(last_day(df.d).alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-10",)], ["d"]) + >>> df.select(last_day(df.d).alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] """ return _invoke_function("last_day", _to_column_expr(date)) - def sqrt(col: "ColumnOrName") -> Column: """ Computes the square root of the specified float value. @@ -2129,7 +2158,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> a = range(20) >>> b = [2 * x for x in range(20)] >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(corr("a", "b").alias('c')).collect() + >>> df.agg(corr("a", "b").alias("c")).collect() [Row(c=1.0)] """ return _invoke_function_over_columns("corr", col1, col2) @@ -2243,7 +2272,7 @@ def positive(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ['v']) + >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) >>> df.select(positive("v").alias("p")).show() +---+ | p| @@ -2303,7 +2332,14 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("aa%d%s", 123, "cc",)], ["a", "b", "c"] + ... [ + ... ( + ... "aa%d%s", + ... 123, + ... "cc", + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.printf("a", "b", "c")).show() +---------------+ |printf(a, b, c)| @@ -2335,9 +2371,9 @@ def product(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3) - >>> prods = df.groupBy('mod3').agg(product('x').alias('product')) - >>> prods.orderBy('mod3').show() + >>> df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) + >>> prods = df.groupBy("mod3").agg(product("x").alias("product")) + >>> prods.orderBy("mod3").show() +----+-------+ |mod3|product| +----+-------+ @@ -2375,7 +2411,7 @@ def rand(seed: Optional[int] = None) -> Column: Examples -------- >>> from pyspark.sql import functions as sf - >>> spark.range(0, 2, 1, 1).withColumn('rand', sf.rand(seed=42) * 3).show() + >>> spark.range(0, 2, 1, 1).withColumn("rand", sf.rand(seed=42) * 3).show() +---+------------------+ | id| rand| +---+------------------+ @@ -2409,9 +2445,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.lit(r'(\d+)'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.lit(r"(\d+)")) + ... ).show() +------------------+ |REGEXP(str, (\d+))| +------------------+ @@ -2419,9 +2455,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.lit(r'\d{2}b'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.lit(r"\d{2}b")) + ... ).show() +-------------------+ |REGEXP(str, \d{2}b)| +-------------------+ @@ -2429,9 +2465,9 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +-------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp('str', sf.col("regexp"))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp("str", sf.col("regexp")) + ... ).show() +-------------------+ |REGEXP(str, regexp)| +-------------------+ @@ -2462,11 +2498,11 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - >>> df.select(regexp_count('str', lit(r'\d+')).alias('d')).collect() + >>> df.select(regexp_count("str", lit(r"\d+")).alias("d")).collect() [Row(d=3)] - >>> df.select(regexp_count('str', lit(r'mmm')).alias('d')).collect() + >>> df.select(regexp_count("str", lit(r"mmm")).alias("d")).collect() [Row(d=0)] - >>> df.select(regexp_count("str", col("regexp")).alias('d')).collect() + >>> df.select(regexp_count("str", col("regexp")).alias("d")).collect() [Row(d=3)] """ return _invoke_function_over_columns("len", _invoke_function_over_columns("regexp_extract_all", str, regexp)) @@ -2497,22 +2533,22 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: Examples -------- - >>> df = spark.createDataFrame([('100-200',)], ['str']) - >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() + >>> df = spark.createDataFrame([("100-200",)], ["str"]) + >>> df.select(regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() [Row(d='100')] - >>> df = spark.createDataFrame([('foo',)], ['str']) - >>> df.select(regexp_extract('str', r'(\d+)', 1).alias('d')).collect() + >>> df = spark.createDataFrame([("foo",)], ["str"]) + >>> df.select(regexp_extract("str", r"(\d+)", 1).alias("d")).collect() [Row(d='')] - >>> df = spark.createDataFrame([('aaaac',)], ['str']) - >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() + >>> df = spark.createDataFrame([("aaaac",)], ["str"]) + >>> df.select(regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() [Row(d='')] """ - return _invoke_function("regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx)) + return _invoke_function( + "regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx) + ) -def regexp_extract_all( - str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None -) -> Column: +def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: r"""Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. @@ -2535,18 +2571,20 @@ def regexp_extract_all( Examples -------- >>> df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)')).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)")).alias("d")).collect() [Row(d=['100', '300'])] - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)'), 1).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)"), 1).alias("d")).collect() [Row(d=['100', '300'])] - >>> df.select(regexp_extract_all('str', lit(r'(\d+)-(\d+)'), 2).alias('d')).collect() + >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)"), 2).alias("d")).collect() [Row(d=['200', '400'])] - >>> df.select(regexp_extract_all('str', col("regexp")).alias('d')).collect() + >>> df.select(regexp_extract_all("str", col("regexp")).alias("d")).collect() [Row(d=['100', '300'])] """ if idx is None: idx = 1 - return _invoke_function("regexp_extract_all", _to_column_expr(str), _to_column_expr(regexp), ConstantExpression(idx)) + return _invoke_function( + "regexp_extract_all", _to_column_expr(str), _to_column_expr(regexp), ConstantExpression(idx) + ) def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: @@ -2569,9 +2607,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.lit(r'(\d+)'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.lit(r"(\d+)")) + ... ).show() +-----------------------+ |REGEXP_LIKE(str, (\d+))| +-----------------------+ @@ -2579,9 +2617,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +-----------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.lit(r'\d{2}b'))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.lit(r"\d{2}b")) + ... ).show() +------------------------+ |REGEXP_LIKE(str, \d{2}b)| +------------------------+ @@ -2589,9 +2627,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: +------------------------+ >>> import pyspark.sql.functions as sf - >>> spark.createDataFrame( - ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] - ... ).select(sf.regexp_like('str', sf.col("regexp"))).show() + >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( + ... sf.regexp_like("str", sf.col("regexp")) + ... ).show() +------------------------+ |REGEXP_LIKE(str, regexp)| +------------------------+ @@ -2622,14 +2660,20 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - >>> df.select(regexp_substr('str', lit(r'\d+')).alias('d')).collect() + >>> df.select(regexp_substr("str", lit(r"\d+")).alias("d")).collect() [Row(d='1')] - >>> df.select(regexp_substr('str', lit(r'mmm')).alias('d')).collect() + >>> df.select(regexp_substr("str", lit(r"mmm")).alias("d")).collect() [Row(d=None)] - >>> df.select(regexp_substr("str", col("regexp")).alias('d')).collect() + >>> df.select(regexp_substr("str", col("regexp")).alias("d")).collect() [Row(d='1')] """ - return Column(FunctionExpression("nullif", FunctionExpression("regexp_extract", _to_column_expr(str), _to_column_expr(regexp)), ConstantExpression(""))) + return Column( + FunctionExpression( + "nullif", + FunctionExpression("regexp_extract", _to_column_expr(str), _to_column_expr(regexp)), + ConstantExpression(""), + ) + ) def repeat(col: "ColumnOrName", n: int) -> Column: @@ -2655,16 +2699,19 @@ def repeat(col: "ColumnOrName", n: int) -> Column: Examples -------- - >>> df = spark.createDataFrame([('ab',)], ['s',]) - >>> df.select(repeat(df.s, 3).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("ab",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(repeat(df.s, 3).alias("s")).collect() [Row(s='ababab')] """ return _invoke_function("repeat", _to_column_expr(col), ConstantExpression(n)) -def sequence( - start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None -) -> Column: +def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None) -> Column: """ Generate a sequence of integers from `start` to `stop`, incrementing by `step`. If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, @@ -2691,11 +2738,11 @@ def sequence( Examples -------- - >>> df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) - >>> df1.select(sequence('C1', 'C2').alias('r')).collect() + >>> df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) + >>> df1.select(sequence("C1", "C2").alias("r")).collect() [Row(r=[-2, -1, 0, 1, 2])] - >>> df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) - >>> df2.select(sequence('C1', 'C2', 'C3').alias('r')).collect() + >>> df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) + >>> df2.select(sequence("C1", "C2", "C3").alias("r")).collect() [Row(r=[4, 2, 0, -2, -4])] """ if step is None: @@ -2726,10 +2773,7 @@ def sign(col: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select( - ... sf.sign(sf.lit(-5)), - ... sf.sign(sf.lit(6)) - ... ).show() + >>> spark.range(1).select(sf.sign(sf.lit(-5)), sf.sign(sf.lit(6))).show() +--------+-------+ |sign(-5)|sign(6)| +--------+-------+ @@ -2761,10 +2805,7 @@ def signum(col: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select( - ... sf.signum(sf.lit(-5)), - ... sf.signum(sf.lit(6)) - ... ).show() + >>> spark.range(1).select(sf.signum(sf.lit(-5)), sf.signum(sf.lit(6))).show() +----------+---------+ |SIGNUM(-5)|SIGNUM(6)| +----------+---------+ @@ -2824,7 +2865,7 @@ def skewness(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) + >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) >>> df.select(skewness(df.c)).first() Row(skewness(c)=0.70710...) """ @@ -2855,7 +2896,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['c']) + >>> df = spark.createDataFrame([("abcd",)], ["c"]) >>> df.select(encode("c", "UTF-8")).show() +----------------+ |encode(c, UTF-8)| @@ -2885,24 +2926,20 @@ def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ['a', 'b']) - >>> df.select(find_in_set(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ["a", "b"]) + >>> df.select(find_in_set(df.a, df.b).alias("r")).collect() [Row(r=3)] """ str_array = _to_column_expr(str_array) str = _to_column_expr(str) return Column( - CaseExpression( - FunctionExpression("contains", str, ConstantExpression(",")), 0 - ).otherwise( + CaseExpression(FunctionExpression("contains", str, ConstantExpression(",")), 0).otherwise( CoalesceOperator( FunctionExpression( - "list_position", - FunctionExpression("string_split", str_array, ConstantExpression(",")), - str + "list_position", FunctionExpression("string_split", str_array, ConstantExpression(",")), str ), # If the element cannot be found, list_position returns null but we want to return 0 - ConstantExpression(0) + ConstantExpression(0), ) ) ) @@ -3018,7 +3055,6 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: return _invoke_function_over_columns("last", col) - def greatest(*cols: "ColumnOrName") -> Column: """ Returns the greatest value of the list of column names, skipping null values. @@ -3041,7 +3077,7 @@ def greatest(*cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] """ @@ -3075,7 +3111,7 @@ def least(*cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() [Row(least=1)] """ @@ -3203,12 +3239,20 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: Examples -------- - >>> df = spark.createDataFrame([("SSparkSQLS", "SL", )], ['a', 'b']) - >>> df.select(btrim(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "SSparkSQLS", + ... "SL", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(btrim(df.a, df.b).alias("r")).collect() [Row(r='parkSQ')] - >>> df = spark.createDataFrame([(" SparkSQL ",)], ['a']) - >>> df.select(btrim(df.a).alias('r')).collect() + >>> df = spark.createDataFrame([(" SparkSQL ",)], ["a"]) + >>> df.select(btrim(df.a).alias("r")).collect() [Row(r='SparkSQL')] """ if trim is not None: @@ -3234,11 +3278,27 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark",)], ["a", "b"]) - >>> df.select(endswith(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... "Spark", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(endswith(df.a, df.b).alias("r")).collect() [Row(r=False)] - >>> df = spark.createDataFrame([("414243", "4243",)], ["e", "f"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4243", + ... ) + ... ], + ... ["e", "f"], + ... ) >>> df = df.select(to_binary("e").alias("e"), to_binary("f").alias("f")) >>> df.printSchema() root @@ -3271,11 +3331,27 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark",)], ["a", "b"]) - >>> df.select(startswith(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "Spark SQL", + ... "Spark", + ... ) + ... ], + ... ["a", "b"], + ... ) + >>> df.select(startswith(df.a, df.b).alias("r")).collect() [Row(r=True)] - >>> df = spark.createDataFrame([("414243", "4142",)], ["e", "f"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4142", + ... ) + ... ], + ... ["e", "f"], + ... ) >>> df = df.select(to_binary("e").alias("e"), to_binary("f").alias("f")) >>> df.printSchema() root @@ -3313,7 +3389,7 @@ def length(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ABC ',)], ['a']).select(length('a').alias('length')).collect() + >>> spark.createDataFrame([("ABC ",)], ["a"]).select(length("a").alias("length")).collect() [Row(length=4)] """ return _invoke_function_over_columns("length", col) @@ -3351,7 +3427,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1| | 2| +--------------+ - >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() + >>> cDf.select("*", coalesce(cDf["a"], lit(0.0))).show() +----+----+----------------+ | a| b|coalesce(a, 0.0)| +----+----+----------------+ @@ -3375,8 +3451,20 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str Examples -------- - >>> df = spark.createDataFrame([(None, 8,), (1, 9,)], ["a", "b"]) - >>> df.select(nvl(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... 8, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] """ @@ -3397,8 +3485,22 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co Examples -------- - >>> df = spark.createDataFrame([(None, 8, 6,), (1, 9, 9,)], ["a", "b", "c"]) - >>> df.select(nvl2(df.a, df.b, df.c).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... 8, + ... 6, + ... ), + ... ( + ... 1, + ... 9, + ... 9, + ... ), + ... ], + ... ["a", "b", "c"], + ... ) + >>> df.select(nvl2(df.a, df.b, df.c).alias("r")).collect() [Row(r=6), Row(r=9)] """ col1 = _to_column_expr(col1) @@ -3443,8 +3545,20 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"]) - >>> df.select(nullif(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... None, + ... None, + ... ), + ... ( + ... 1, + ... 9, + ... ), + ... ], + ... ["a", "b"], + ... ) + >>> df.select(nullif(df.a, df.b).alias("r")).collect() [Row(r=None), Row(r=1)] """ return _invoke_function_over_columns("nullif", col1, col2) @@ -3470,7 +3584,7 @@ def md5(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + >>> spark.createDataFrame([("ABC",)], ["a"]).select(md5("a").alias("hash")).collect() [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] """ return _invoke_function_over_columns("md5", col) @@ -3517,9 +3631,7 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: if numBits == 256: return _invoke_function_over_columns("sha256", col) - raise ContributionsAcceptedError( - "SHA-224, SHA-384, and SHA-512 are not supported yet." - ) + raise ContributionsAcceptedError("SHA-224, SHA-384, and SHA-512 are not supported yet.") def curdate() -> Column: @@ -3537,7 +3649,7 @@ def curdate() -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP + >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ @@ -3565,7 +3677,7 @@ def current_date() -> Column: Examples -------- >>> df = spark.range(1) - >>> df.select(current_date()).show() # doctest: +SKIP + >>> df.select(current_date()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ @@ -3589,7 +3701,7 @@ def now() -> Column: Examples -------- >>> df = spark.range(1) - >>> df.select(now()).show(truncate=False) # doctest: +SKIP + >>> df.select(now()).show(truncate=False) # doctest: +SKIP +-----------------------+ |now() | +-----------------------+ @@ -3598,6 +3710,7 @@ def now() -> Column: """ return _invoke_function("now") + def desc(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the descending order of the given column name. @@ -3634,6 +3747,7 @@ def desc(col: "ColumnOrName") -> Column: """ return Column(_to_column_expr(col).desc()) + def asc(col: "ColumnOrName") -> Column: """ Returns a sort expression based on the ascending order of the given column name. @@ -3685,6 +3799,7 @@ def asc(col: "ColumnOrName") -> Column: """ return Column(_to_column_expr(col).asc()) + def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: """ Returns timestamp truncated to the unit specified by the format. @@ -3700,10 +3815,10 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t']) - >>> df.select(date_trunc('year', df.t).alias('year')).collect() + >>> df = spark.createDataFrame([("1997-02-28 05:02:11",)], ["t"]) + >>> df.select(date_trunc("year", df.t).alias("year")).collect() [Row(year=datetime.datetime(1997, 1, 1, 0, 0))] - >>> df.select(date_trunc('mon', df.t).alias('month')).collect() + >>> df.select(date_trunc("mon", df.t).alias("month")).collect() [Row(month=datetime.datetime(1997, 2, 1, 0, 0))] """ format = format.lower() @@ -3740,14 +3855,14 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... date_part(lit('YEAR'), 'ts').alias('year'), - ... date_part(lit('month'), 'ts').alias('month'), - ... date_part(lit('WEEK'), 'ts').alias('week'), - ... date_part(lit('D'), 'ts').alias('day'), - ... date_part(lit('M'), 'ts').alias('minute'), - ... date_part(lit('S'), 'ts').alias('second') + ... date_part(lit("YEAR"), "ts").alias("year"), + ... date_part(lit("month"), "ts").alias("month"), + ... date_part(lit("WEEK"), "ts").alias("week"), + ... date_part(lit("D"), "ts").alias("day"), + ... date_part(lit("M"), "ts").alias("minute"), + ... date_part(lit("S"), "ts").alias("second"), ... ).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3775,14 +3890,14 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... extract(lit('YEAR'), 'ts').alias('year'), - ... extract(lit('month'), 'ts').alias('month'), - ... extract(lit('WEEK'), 'ts').alias('week'), - ... extract(lit('D'), 'ts').alias('day'), - ... extract(lit('M'), 'ts').alias('minute'), - ... extract(lit('S'), 'ts').alias('second') + ... extract(lit("YEAR"), "ts").alias("year"), + ... extract(lit("month"), "ts").alias("month"), + ... extract(lit("WEEK"), "ts").alias("week"), + ... extract(lit("D"), "ts").alias("day"), + ... extract(lit("M"), "ts").alias("minute"), + ... extract(lit("S"), "ts").alias("second"), ... ).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3811,14 +3926,14 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) >>> df.select( - ... datepart(lit('YEAR'), 'ts').alias('year'), - ... datepart(lit('month'), 'ts').alias('month'), - ... datepart(lit('WEEK'), 'ts').alias('week'), - ... datepart(lit('D'), 'ts').alias('day'), - ... datepart(lit('M'), 'ts').alias('minute'), - ... datepart(lit('S'), 'ts').alias('second') + ... datepart(lit("YEAR"), "ts").alias("year"), + ... datepart(lit("month"), "ts").alias("month"), + ... datepart(lit("WEEK"), "ts").alias("week"), + ... datepart(lit("D"), "ts").alias("day"), + ... datepart(lit("M"), "ts").alias("minute"), + ... datepart(lit("S"), "ts").alias("second"), ... ).collect() [Row(year=2015, month=4, week=15, day=8, minute=8, second=Decimal('15.000000'))] """ @@ -3854,15 +3969,19 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) - >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + >>> df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) + >>> df.select( + ... "*", sf.date_diff(sf.col("d1").cast("DATE"), sf.col("d2").cast("DATE")) + ... ).show() +----------+----------+-----------------+ | d1| d2|date_diff(d1, d2)| +----------+----------+-----------------+ |2015-04-08|2015-05-10| -32| +----------+----------+-----------------+ - >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show() + >>> df.select( + ... "*", sf.date_diff(sf.col("d1").cast("DATE"), sf.col("d2").cast("DATE")) + ... ).show() +----------+----------+-----------------+ | d1| d2|date_diff(d2, d1)| +----------+----------+-----------------+ @@ -3893,8 +4012,8 @@ def year(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(year('dt').alias('year')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(year("dt").alias("year")).collect() [Row(year=2015)] """ return _invoke_function_over_columns("year", col) @@ -3921,8 +4040,8 @@ def quarter(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(quarter('dt').alias('quarter')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(quarter("dt").alias("quarter")).collect() [Row(quarter=2)] """ return _invoke_function_over_columns("quarter", col) @@ -3949,8 +4068,8 @@ def month(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(month('dt').alias('month')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(month("dt").alias("month")).collect() [Row(month=4)] """ return _invoke_function_over_columns("month", col) @@ -3978,8 +4097,8 @@ def dayofweek(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofweek('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofweek("dt").alias("day")).collect() [Row(day=4)] """ return _invoke_function_over_columns("dayofweek", col) + lit(1) @@ -4003,8 +4122,8 @@ def day(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(day('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(day("dt").alias("day")).collect() [Row(day=8)] """ return _invoke_function_over_columns("day", col) @@ -4031,8 +4150,8 @@ def dayofmonth(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofmonth('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofmonth("dt").alias("day")).collect() [Row(day=8)] """ return day(col) @@ -4059,8 +4178,8 @@ def dayofyear(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(dayofyear('dt').alias('day')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(dayofyear("dt").alias("day")).collect() [Row(day=98)] """ return _invoke_function_over_columns("dayofyear", col) @@ -4088,8 +4207,8 @@ def hour(col: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(hour('ts').alias('hour')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(hour("ts").alias("hour")).collect() [Row(hour=13)] """ return _invoke_function_over_columns("hour", col) @@ -4117,8 +4236,8 @@ def minute(col: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(minute('ts').alias('minute')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(minute("ts").alias("minute")).collect() [Row(minute=8)] """ return _invoke_function_over_columns("minute", col) @@ -4146,8 +4265,8 @@ def second(col: "ColumnOrName") -> Column: Examples -------- >>> import datetime - >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) - >>> df.select(second('ts').alias('second')).collect() + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) + >>> df.select(second("ts").alias("second")).collect() [Row(second=15)] """ return _invoke_function_over_columns("second", col) @@ -4176,8 +4295,8 @@ def weekofyear(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(weekofyear(df.dt).alias('week')).collect() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(weekofyear(df.dt).alias("week")).collect() [Row(week=15)] """ return _invoke_function_over_columns("weekofyear", col) @@ -4267,7 +4386,7 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column: -------- >>> from pyspark.sql.functions import call_udf, col >>> from pyspark.sql.types import IntegerType, StringType - >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "c")],["id", "name"]) + >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "name"]) >>> _ = spark.udf.register("intX2", lambda i: i * 2, IntegerType()) >>> df.select(call_function("intX2", "id")).show() +---------+ @@ -4338,7 +4457,7 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> a = [1] * 10 >>> b = [1] * 10 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(covar_pop("a", "b").alias('c')).collect() + >>> df.agg(covar_pop("a", "b").alias("c")).collect() [Row(c=0.0)] """ return _invoke_function_over_columns("covar_pop", col1, col2) @@ -4370,7 +4489,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> a = [1] * 10 >>> b = [1] * 10 >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) - >>> df.agg(covar_samp("a", "b").alias('c')).collect() + >>> df.agg(covar_samp("a", "b").alias("c")).collect() [Row(c=0.0)] """ return _invoke_function_over_columns("covar_samp", col1, col2) @@ -4429,8 +4548,8 @@ def factorial(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(5,)], ['n']) - >>> df.select(factorial(df.n).alias('f')).collect() + >>> df = spark.createDataFrame([(5,)], ["n"]) + >>> df.select(factorial(df.n).alias("f")).collect() [Row(f=120)] """ return _invoke_function_over_columns("factorial", col) @@ -4456,8 +4575,8 @@ def log2(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(log2('a').alias('log2')).show() + >>> df = spark.createDataFrame([(4,)], ["a"]) + >>> df.select(log2("a").alias("log2")).show() +----+ |log2| +----+ @@ -4484,8 +4603,8 @@ def ln(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(4,)], ['a']) - >>> df.select(ln('a')).show() + >>> df = spark.createDataFrame([(4,)], ["a"]) + >>> df.select(ln("a")).show() +------------------+ | ln(a)| +------------------+ @@ -4525,7 +4644,6 @@ def degrees(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("degrees", col) - def radians(col: "ColumnOrName") -> Column: """ Converts an angle measured in degrees to an approximately equivalent angle @@ -4616,10 +4734,12 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] >>> df.select(atan2(lit(1), lit(2))).first() Row(ATAN2(1, 2)=0.46364...) """ + def lit_or_column(x: Union["ColumnOrName", float]) -> Column: if isinstance(x, (int, float)): return lit(x) return x + return _invoke_function_over_columns("atan2", lit_or_column(col1), lit_or_column(col2)) @@ -4676,7 +4796,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: Examples -------- - >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ["a"]).select(round("a", 0).alias("r")).collect() [Row(r=3.0)] """ return _invoke_function_over_columns("round", col, lit(scale)) @@ -4706,7 +4826,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: Examples -------- - >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ["a"]).select(bround("a", 0).alias("r")).collect() [Row(r=2.0)] """ return _invoke_function_over_columns("round_even", col, lit(scale)) @@ -4743,7 +4863,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: Examples -------- - >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) >>> df.select(get(df.data, 1)).show() +------------+ |get(data, 1)| @@ -4806,7 +4926,7 @@ def initcap(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() + >>> spark.createDataFrame([("ab cd",)], ["a"]).select(initcap("a").alias("v")).collect() [Row(v='Ab Cd')] """ return Column( @@ -4814,18 +4934,14 @@ def initcap(col: "ColumnOrName") -> Column: "array_to_string", FunctionExpression( "list_transform", - FunctionExpression( - "string_split", _to_column_expr(col), ConstantExpression(" ") - ), + FunctionExpression("string_split", _to_column_expr(col), ConstantExpression(" ")), LambdaExpression( "x", FunctionExpression( "concat", FunctionExpression( "upper", - FunctionExpression( - "array_extract", ColumnExpression("x"), 1 - ), + FunctionExpression("array_extract", ColumnExpression("x"), 1), ), FunctionExpression("array_slice", ColumnExpression("x"), 2, -1), ), @@ -4858,7 +4974,7 @@ def octet_length(col: "ColumnOrName") -> Column: Examples -------- >>> from pyspark.sql.functions import octet_length - >>> spark.createDataFrame([('cat',), ( '\U0001F408',)], ['cat']) \\ + >>> spark.createDataFrame([('cat',), ( '\U0001f408',)], ['cat']) \\ ... .select(octet_length('cat')).collect() [Row(octet_length(cat)=3), Row(octet_length(cat)=4)] """ @@ -4886,7 +5002,7 @@ def hex(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + >>> spark.createDataFrame([("ABC", 3)], ["a", "b"]).select(hex("a"), hex("b")).collect() [Row(hex(a)='414243', hex(b)='3')] """ return _invoke_function_over_columns("hex", col) @@ -4913,7 +5029,7 @@ def unhex(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + >>> spark.createDataFrame([("414243",)], ["a"]).select(unhex("a")).collect() [Row(unhex(a)=bytearray(b'ABC'))] """ return _invoke_function_over_columns("unhex", col) @@ -4950,7 +5066,7 @@ def base64(col: "ColumnOrName") -> Column: |UGFuZGFzIEFQSQ==| +----------------+ """ - if isinstance(col,str): + if isinstance(col, str): col = Column(ColumnExpression(col)) return _invoke_function_over_columns("base64", col.cast("BLOB")) @@ -4976,9 +5092,7 @@ def unbase64(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame(["U3Bhcms=", - ... "UHlTcGFyaw==", - ... "UGFuZGFzIEFQSQ=="], "STRING") + >>> df = spark.createDataFrame(["U3Bhcms=", "UHlTcGFyaw==", "UGFuZGFzIEFQSQ=="], "STRING") >>> df.select(unbase64("value")).show() +--------------------+ | unbase64(value)| @@ -5016,21 +5130,19 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col Examples -------- - >>> df = spark.createDataFrame([('2015-04-08', 2)], ['dt', 'add']) - >>> df.select(add_months(df.dt, 1).alias('next_month')).collect() + >>> df = spark.createDataFrame([("2015-04-08", 2)], ["dt", "add"]) + >>> df.select(add_months(df.dt, 1).alias("next_month")).collect() [Row(next_month=datetime.date(2015, 5, 8))] - >>> df.select(add_months(df.dt, df.add.cast('integer')).alias('next_month')).collect() + >>> df.select(add_months(df.dt, df.add.cast("integer")).alias("next_month")).collect() [Row(next_month=datetime.date(2015, 6, 8))] - >>> df.select(add_months('dt', -2).alias('prev_month')).collect() + >>> df.select(add_months("dt", -2).alias("prev_month")).collect() [Row(prev_month=datetime.date(2015, 2, 8))] """ months = ConstantExpression(months) if isinstance(months, int) else _to_column_expr(months) return _invoke_function("date_add", _to_column_expr(start), FunctionExpression("to_months", months)).cast("date") -def array_join( - col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None -) -> Column: +def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None) -> Column: """ Concatenates the elements of `column` using the `delimiter`. Null values are replaced with `null_replacement` if set, otherwise they are ignored. @@ -5056,7 +5168,7 @@ def array_join( Examples -------- - >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) >>> df.select(array_join(df.data, ",").alias("joined")).collect() [Row(joined='a,b,c'), Row(joined='a')] >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() @@ -5065,7 +5177,14 @@ def array_join( col = _to_column_expr(col) if null_replacement is not None: col = FunctionExpression( - "list_transform", col, LambdaExpression("x", CaseExpression(ColumnExpression("x").isnull(), ConstantExpression(null_replacement)).otherwise(ColumnExpression("x"))) + "list_transform", + col, + LambdaExpression( + "x", + CaseExpression(ColumnExpression("x").isnull(), ConstantExpression(null_replacement)).otherwise( + ColumnExpression("x") + ), + ), ) return _invoke_function("array_to_string", col, ConstantExpression(delimiter)) @@ -5099,11 +5218,15 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) >>> df.select(array_position(df.data, "a")).collect() [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] """ - return Column(CoalesceOperator(_to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0))) + return Column( + CoalesceOperator( + _to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0) + ) + ) def array_prepend(col: "ColumnOrName", value: Any) -> Column: @@ -5128,7 +5251,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] """ @@ -5158,8 +5281,8 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu Examples -------- - >>> df = spark.createDataFrame([('ab',)], ['data']) - >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + >>> df = spark.createDataFrame([("ab",)], ["data"]) + >>> df.select(array_repeat(df.data, 3).alias("r")).collect() [Row(r=['ab', 'ab', 'ab'])] """ count = lit(count) if isinstance(count, int) else count @@ -5185,15 +5308,14 @@ def array_size(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) - >>> df.select(array_size(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) + >>> df.select(array_size(df.data).alias("r")).collect() [Row(r=3), Row(r=None)] """ return _invoke_function_over_columns("len", col) -def array_sort( - col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None -) -> Column: + +def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None) -> Column: """ Collection function: sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. @@ -5224,14 +5346,20 @@ def array_sort( Examples -------- - >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) - >>> df.select(array_sort(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) + >>> df.select(array_sort(df.data).alias("r")).collect() [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] - >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data']) - >>> df.select(array_sort( - ... "data", - ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x)) - ... ).alias("r")).collect() + >>> df = spark.createDataFrame( + ... [(["foo", "foobar", None, "bar"],), (["foo"],), ([],)], ["data"] + ... ) + >>> df.select( + ... array_sort( + ... "data", + ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise( + ... length(y) - length(x) + ... ), + ... ).alias("r") + ... ).collect() [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] """ if comparator is not None: @@ -5267,10 +5395,10 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: Examples -------- - >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) - >>> df.select(sort_array(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) + >>> df.select(sort_array(df.data).alias("r")).collect() [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] - >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() + >>> df.select(sort_array(df.data, asc=False).alias("r")).collect() [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] """ if asc: @@ -5317,10 +5445,15 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: Examples -------- - >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',]) - >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("oneAtwoBthreeC",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(split(df.s, "[ABC]", 2).alias("s")).collect() [Row(s=['one', 'twoBthreeC'])] - >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect() + >>> df.select(split(df.s, "[ABC]", -1).alias("s")).collect() [Row(s=['one', 'two', 'three', ''])] """ if limit > 0: @@ -5351,8 +5484,17 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO Examples -------- - >>> df = spark.createDataFrame([("11.12.13", ".", 3,)], ["a", "b", "c"]) - >>> df.select(split_part(df.a, df.b, df.c).alias('r')).collect() + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "11.12.13", + ... ".", + ... 3, + ... ) + ... ], + ... ["a", "b", "c"], + ... ) + >>> df.select(split_part(df.a, df.b, df.c).alias("r")).collect() [Row(r='13')] """ src = _to_column_expr(src) @@ -5360,7 +5502,11 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO partNum = _to_column_expr(partNum) part = FunctionExpression("split_part", src, delimiter, partNum) - return Column(CaseExpression(src.isnull() | delimiter.isnull() | partNum.isnull(), ConstantExpression(None)).otherwise(CaseExpression(delimiter == ConstantExpression(""), ConstantExpression("")).otherwise(part))) + return Column( + CaseExpression(src.isnull() | delimiter.isnull() | partNum.isnull(), ConstantExpression(None)).otherwise( + CaseExpression(delimiter == ConstantExpression(""), ConstantExpression("")).otherwise(part) + ) + ) def stddev_samp(col: "ColumnOrName") -> Column: @@ -5427,6 +5573,7 @@ def stddev(col: "ColumnOrName") -> Column: """ return stddev_samp(col) + def std(col: "ColumnOrName") -> Column: """ Aggregate function: alias for stddev_samp. @@ -5600,8 +5747,8 @@ def weekday(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) - >>> df.select(weekday('dt').alias('day')).show() + >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + >>> df.select(weekday("dt").alias("day")).show() +---+ |day| +---+ @@ -5634,6 +5781,7 @@ def zeroifnull(col: "ColumnOrName") -> Column: """ return coalesce(col, lit(0)) + def _to_date_or_timestamp(col: "ColumnOrName", spark_datatype: _types.DataType, format: Optional[str] = None) -> Column: if format is not None: raise ContributionsAcceptedError( @@ -5670,12 +5818,12 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_date(df.t).alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_date(df.t).alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_date(df.t, "yyyy-MM-dd HH:mm:ss").alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] """ return _to_date_or_timestamp(col, _types.DateType(), format) @@ -5708,12 +5856,12 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_timestamp(df.t).alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_timestamp(df.t).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(to_timestamp(df.t, "yyyy-MM-dd HH:mm:ss").alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] """ return _to_date_or_timestamp(col, _types.TimestampNTZType(), format) @@ -5739,12 +5887,12 @@ def to_timestamp_ltz( Examples -------- >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) - >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias('r')).collect() + >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) - >>> df.select(to_timestamp_ltz(df.e).alias('r')).collect() + >>> df.select(to_timestamp_ltz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] """ @@ -5771,12 +5919,12 @@ def to_timestamp_ntz( Examples -------- >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias('r')).collect() + >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_timestamp_ntz(df.e).alias('r')).collect() + >>> df.select(to_timestamp_ntz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] """ @@ -5797,20 +5945,19 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non format to use to convert timestamp values. Examples -------- - >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(try_to_timestamp(df.t).alias('dt')).collect() + >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + >>> df.select(try_to_timestamp(df.t).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - >>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect() + >>> df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] """ if format is None: - format = lit(['%Y-%m-%d', '%Y-%m-%d %H:%M:%S']) + format = lit(["%Y-%m-%d", "%Y-%m-%d %H:%M:%S"]) return _invoke_function_over_columns("try_strptime", col, format) -def substr( - str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None -) -> Column: + +def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None) -> Column: """ Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of byte array that starts at `pos` and is of length `len`. @@ -5830,7 +5977,14 @@ def substr( -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... [ + ... ( + ... "Spark SQL", + ... 5, + ... 1, + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.substr("a", "b", "c")).show() +---------------+ |substr(a, b, c)| @@ -5840,7 +5994,14 @@ def substr( >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( - ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... [ + ... ( + ... "Spark SQL", + ... 5, + ... 1, + ... ) + ... ], + ... ["a", "b", "c"], ... ).select(sf.substr("a", "b")).show() +------------------------+ |substr(a, b, 2147483647)| @@ -5855,7 +6016,10 @@ def substr( def _unix_diff(col: "ColumnOrName", part: str) -> Column: - return _invoke_function_over_columns("date_diff", lit(part), lit("1970-01-01 00:00:00+00:00").cast("timestamp"), col) + return _invoke_function_over_columns( + "date_diff", lit(part), lit("1970-01-01 00:00:00+00:00").cast("timestamp"), col + ) + def unix_date(col: "ColumnOrName") -> Column: """Returns the number of days since 1970-01-01. @@ -5865,8 +6029,8 @@ def unix_date(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('1970-01-02',)], ['t']) - >>> df.select(unix_date(to_date(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("1970-01-02",)], ["t"]) + >>> df.select(unix_date(to_date(df.t)).alias("n")).collect() [Row(n=1)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5881,8 +6045,8 @@ def unix_micros(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_micros(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_micros(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000000)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5898,8 +6062,8 @@ def unix_millis(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_millis(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_millis(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5915,8 +6079,8 @@ def unix_seconds(col: "ColumnOrName") -> Column: Examples -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") - >>> df = spark.createDataFrame([('2015-07-22 10:00:00',)], ['t']) - >>> df.select(unix_seconds(to_timestamp(df.t)).alias('n')).collect() + >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) + >>> df.select(unix_seconds(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -5941,7 +6105,7 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ["x", "y"]) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() [Row(overlap=True), Row(overlap=False)] """ @@ -5952,21 +6116,19 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: a2_has_null = _list_contains_null(a2) return Column( - CaseExpression( - FunctionExpression("list_has_any", a1, a2), ConstantExpression(True) - ).otherwise( + CaseExpression(FunctionExpression("list_has_any", a1, a2), ConstantExpression(True)).otherwise( CaseExpression( - (FunctionExpression("len", a1) > 0) & (FunctionExpression("len", a2) > 0) & (a1_has_null | a2_has_null), ConstantExpression(None) - ).otherwise(ConstantExpression(False))) + (FunctionExpression("len", a1) > 0) & (FunctionExpression("len", a2) > 0) & (a1_has_null | a2_has_null), + ConstantExpression(None), + ).otherwise(ConstantExpression(False)) + ) ) def _list_contains_null(c: ColumnExpression) -> Expression: return FunctionExpression( "list_contains", - FunctionExpression( - "list_transform", c, LambdaExpression("x", ColumnExpression("x").isnull()) - ), + FunctionExpression("list_transform", c, LambdaExpression("x", ColumnExpression("x").isnull())), True, ) @@ -5995,8 +6157,10 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: Examples -------- >>> from pyspark.sql.functions import arrays_zip - >>> df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3']) - >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')) + >>> df = spark.createDataFrame( + ... [([1, 2, 3], [2, 4, 6], [3, 6])], ["vals1", "vals2", "vals3"] + ... ) + >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias("zipped")) >>> df.show(truncate=False) +------------------------------------+ |zipped | @@ -6039,8 +6203,13 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: substring of given value. Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(substring(df.s, 1, 2).alias("s")).collect() [Row(s='ab')] """ return _invoke_function( @@ -6065,10 +6234,18 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: The input column or strings to find, may be NULL. Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ['a', 'b']) - >>> df.select(contains(df.a, df.b).alias('r')).collect() + >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ["a", "b"]) + >>> df.select(contains(df.a, df.b).alias("r")).collect() [Row(r=True)] - >>> df = spark.createDataFrame([("414243", "4243",)], ["c", "d"]) + >>> df = spark.createDataFrame( + ... [ + ... ( + ... "414243", + ... "4243", + ... ) + ... ], + ... ["c", "d"], + ... ) >>> df = df.select(to_binary("c").alias("c"), to_binary("d").alias("d")) >>> df.printSchema() root @@ -6100,15 +6277,16 @@ def reverse(col: "ColumnOrName") -> Column: array of elements in reverse order. Examples -------- - >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) - >>> df.select(reverse(df.data).alias('s')).collect() + >>> df = spark.createDataFrame([("Spark SQL",)], ["data"]) + >>> df.select(reverse(df.data).alias("s")).collect() [Row(s='LQS krapS')] - >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) - >>> df.select(reverse(df.data).alias('r')).collect() + >>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ["data"]) + >>> df.select(reverse(df.data).alias("r")).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] """ return _invoke_function("reverse", _to_column_expr(col)) + def concat(*cols: "ColumnOrName") -> Column: """ Concatenates multiple input columns together into a single column. @@ -6129,13 +6307,15 @@ def concat(*cols: "ColumnOrName") -> Column: :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter Examples -------- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df = df.select(concat(df.s, df.d).alias('s')) + >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) + >>> df = df.select(concat(df.s, df.d).alias("s")) >>> df.collect() [Row(s='abcd123')] >>> df DataFrame[s: string] - >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df = spark.createDataFrame( + ... [([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ["a", "b", "c"] + ... ) >>> df = df.select(concat(df.a, df.b, df.c).alias("arr")) >>> df.collect() [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] @@ -6174,12 +6354,18 @@ def instr(str: "ColumnOrName", substr: str) -> Column: Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(instr(df.s, 'b').alias('s')).collect() + >>> df = spark.createDataFrame( + ... [("abcd",)], + ... [ + ... "s", + ... ], + ... ) + >>> df.select(instr(df.s, "b").alias("s")).collect() [Row(s=2)] """ return _invoke_function("instr", _to_column_expr(str), ConstantExpression(substr)) + def expr(str: str) -> Column: """Parses the expression string into the column that it represents @@ -6211,6 +6397,7 @@ def expr(str: str) -> Column: """ return Column(SQLExpression(str)) + def broadcast(df: "DataFrame") -> "DataFrame": """ The broadcast function in Spark is used to optimize joins by broadcasting a smaller diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index 4c4d5bb6..29210e29 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -30,6 +30,7 @@ __all__ = ["GroupedData", "Grouping"] + def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: expressions = ",".join(list(cols)) group_by = str(self._grouping) if self._grouping else "" @@ -42,6 +43,7 @@ def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: ) return DataFrame(jdf, self.session) + def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]: def _api(self: "GroupedData", *cols: str) -> DataFrame: name = f.__name__ @@ -56,8 +58,8 @@ class Grouping: def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: self._type = "" self._cols = [_to_column_expr(x) for x in cols] - if 'special' in kwargs: - special = kwargs['special'] + if "special" in kwargs: + special = kwargs["special"] accepted_special = ["cube", "rollup"] assert special in accepted_special self._type = special @@ -69,7 +71,7 @@ def get_columns(self) -> str: def __str__(self) -> str: columns = self.get_columns() if self._type: - return self._type + '(' + columns + ')' + return self._type + "(" + columns + ")" return columns @@ -94,7 +96,8 @@ def count(self) -> DataFrame: Examples -------- >>> df = spark.createDataFrame( - ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] + ... ) >>> df.show() +---+-----+ |age| name| @@ -115,7 +118,7 @@ def count(self) -> DataFrame: | Bob| 2| +-----+-----+ """ - return _api_internal(self, "count").withColumnRenamed('count_star()', 'count') + return _api_internal(self, "count").withColumnRenamed("count_star()", "count") @df_varargs_api def mean(self, *cols: str) -> DataFrame: @@ -141,9 +144,10 @@ def avg(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -156,7 +160,7 @@ def avg(self, *cols: str) -> DataFrame: Group-by name, and calculate the mean of the age in each group. - >>> df.groupBy("name").avg('age').sort("name").show() + >>> df.groupBy("name").avg("age").sort("name").show() +-----+--------+ | name|avg(age)| +-----+--------+ @@ -166,7 +170,7 @@ def avg(self, *cols: str) -> DataFrame: Calculate the mean of the age and height in all data. - >>> df.groupBy().avg('age', 'height').show() + >>> df.groupBy().avg("age", "height").show() +--------+-----------+ |avg(age)|avg(height)| +--------+-----------+ @@ -186,9 +190,10 @@ def max(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -230,9 +235,10 @@ def min(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -274,9 +280,10 @@ def sum(self, *cols: str) -> DataFrame: Examples -------- - >>> df = spark.createDataFrame([ - ... (2, "Alice", 80), (3, "Alice", 100), - ... (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"]) + >>> df = spark.createDataFrame( + ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], + ... ["age", "name", "height"], + ... ) >>> df.show() +---+-----+------+ |age| name|height| @@ -308,12 +315,10 @@ def sum(self, *cols: str) -> DataFrame: """ @overload - def agg(self, *exprs: Column) -> DataFrame: - ... + def agg(self, *exprs: Column) -> DataFrame: ... @overload - def agg(self, __exprs: dict[str, str]) -> DataFrame: - ... + def agg(self, __exprs: dict[str, str]) -> DataFrame: ... def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: """Compute aggregates and returns the result as a :class:`DataFrame`. @@ -357,7 +362,8 @@ def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: >>> from pyspark.sql import functions as F >>> from pyspark.sql.functions import pandas_udf, PandasUDFType >>> df = spark.createDataFrame( - ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"]) + ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] + ... ) >>> df.show() +---+-----+ |age| name| @@ -393,10 +399,9 @@ def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: Same as above but uses pandas UDF. - >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP + >>> @pandas_udf("int", PandasUDFType.GROUPED_AGG) # doctest: +SKIP ... def min_udf(v): ... return v.min() - ... >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show() # doctest: +SKIP +-----+------------+ | name|min_udf(age)| diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 6e8c72c6..18095ab6 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -328,9 +328,9 @@ def json( >>> import tempfile >>> with tempfile.TemporaryDirectory() as d: ... # Write a DataFrame into a JSON file - ... spark.createDataFrame( - ... [{"age": 100, "name": "Hyukjin Kwon"}] - ... ).write.mode("overwrite").format("json").save(d) + ... spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode( + ... "overwrite" + ... ).format("json").save(d) ... ... # Read the JSON file as a DataFrame. ... spark.read.json(d).show() @@ -344,98 +344,62 @@ def json( if schema is not None: raise ContributionsAcceptedError("The 'schema' option is not supported") if primitivesAsString is not None: - raise ContributionsAcceptedError( - "The 'primitivesAsString' option is not supported" - ) + raise ContributionsAcceptedError("The 'primitivesAsString' option is not supported") if prefersDecimal is not None: - raise ContributionsAcceptedError( - "The 'prefersDecimal' option is not supported" - ) + raise ContributionsAcceptedError("The 'prefersDecimal' option is not supported") if allowComments is not None: - raise ContributionsAcceptedError( - "The 'allowComments' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowComments' option is not supported") if allowUnquotedFieldNames is not None: - raise ContributionsAcceptedError( - "The 'allowUnquotedFieldNames' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowUnquotedFieldNames' option is not supported") if allowSingleQuotes is not None: - raise ContributionsAcceptedError( - "The 'allowSingleQuotes' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowSingleQuotes' option is not supported") if allowNumericLeadingZero is not None: - raise ContributionsAcceptedError( - "The 'allowNumericLeadingZero' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowNumericLeadingZero' option is not supported") if allowBackslashEscapingAnyCharacter is not None: - raise ContributionsAcceptedError( - "The 'allowBackslashEscapingAnyCharacter' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowBackslashEscapingAnyCharacter' option is not supported") if mode is not None: raise ContributionsAcceptedError("The 'mode' option is not supported") if columnNameOfCorruptRecord is not None: - raise ContributionsAcceptedError( - "The 'columnNameOfCorruptRecord' option is not supported" - ) + raise ContributionsAcceptedError("The 'columnNameOfCorruptRecord' option is not supported") if dateFormat is not None: raise ContributionsAcceptedError("The 'dateFormat' option is not supported") if timestampFormat is not None: - raise ContributionsAcceptedError( - "The 'timestampFormat' option is not supported" - ) + raise ContributionsAcceptedError("The 'timestampFormat' option is not supported") if multiLine is not None: raise ContributionsAcceptedError("The 'multiLine' option is not supported") if allowUnquotedControlChars is not None: - raise ContributionsAcceptedError( - "The 'allowUnquotedControlChars' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowUnquotedControlChars' option is not supported") if lineSep is not None: raise ContributionsAcceptedError("The 'lineSep' option is not supported") if samplingRatio is not None: - raise ContributionsAcceptedError( - "The 'samplingRatio' option is not supported" - ) + raise ContributionsAcceptedError("The 'samplingRatio' option is not supported") if dropFieldIfAllNull is not None: - raise ContributionsAcceptedError( - "The 'dropFieldIfAllNull' option is not supported" - ) + raise ContributionsAcceptedError("The 'dropFieldIfAllNull' option is not supported") if encoding is not None: raise ContributionsAcceptedError("The 'encoding' option is not supported") if locale is not None: raise ContributionsAcceptedError("The 'locale' option is not supported") if pathGlobFilter is not None: - raise ContributionsAcceptedError( - "The 'pathGlobFilter' option is not supported" - ) + raise ContributionsAcceptedError("The 'pathGlobFilter' option is not supported") if recursiveFileLookup is not None: - raise ContributionsAcceptedError( - "The 'recursiveFileLookup' option is not supported" - ) + raise ContributionsAcceptedError("The 'recursiveFileLookup' option is not supported") if modifiedBefore is not None: - raise ContributionsAcceptedError( - "The 'modifiedBefore' option is not supported" - ) + raise ContributionsAcceptedError("The 'modifiedBefore' option is not supported") if modifiedAfter is not None: - raise ContributionsAcceptedError( - "The 'modifiedAfter' option is not supported" - ) + raise ContributionsAcceptedError("The 'modifiedAfter' option is not supported") if allowNonNumericNumbers is not None: - raise ContributionsAcceptedError( - "The 'allowNonNumericNumbers' option is not supported" - ) + raise ContributionsAcceptedError("The 'allowNonNumericNumbers' option is not supported") if isinstance(path, str): path = [path] - if isinstance(path, list): + if isinstance(path, list): if len(path) == 1: rel = self.session.conn.read_json(path[0]) from .dataframe import DataFrame df = DataFrame(rel, self.session) return df - raise PySparkNotImplementedError( - message="Only a single path is supported for now" - ) + raise PySparkNotImplementedError(message="Only a single path is supported for now") else: raise PySparkTypeError( error_class="NOT_STR_OR_LIST_OF_RDD", diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index 744a77e8..c83c7e82 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -16,10 +16,7 @@ from .streaming import DataStreamReader import duckdb -from ..errors import ( - PySparkTypeError, - PySparkValueError -) +from ..errors import PySparkTypeError, PySparkValueError from ..errors.error_classes import * @@ -53,11 +50,12 @@ def __init__(self, context: SparkContext) -> None: def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> DataFrame: try: import pandas + has_pandas = True except ImportError: has_pandas = False if has_pandas and isinstance(data, pandas.DataFrame): - unique_name = f'pyspark_pandas_df_{uuid.uuid1()}' + unique_name = f"pyspark_pandas_df_{uuid.uuid1()}" self.conn.register(unique_name, data) return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self) @@ -73,9 +71,9 @@ def verify_tuple_integrity(tuples): error_class="LENGTH_SHOULD_BE_THE_SAME", message_parameters={ "arg1": f"data{i}", - "arg2": f"data{i+1}", + "arg2": f"data{i + 1}", "arg1_length": str(expected_length), - "arg2_length": str(actual_length) + "arg2_length": str(actual_length), }, ) @@ -86,13 +84,13 @@ def verify_tuple_integrity(tuples): def construct_query(tuples) -> str: def construct_values_list(row, start_param_idx): parameter_count = len(row) - parameters = [f'${x+start_param_idx}' for x in range(parameter_count)] - parameters = '(' + ', '.join(parameters) + ')' + parameters = [f"${x + start_param_idx}" for x in range(parameter_count)] + parameters = "(" + ", ".join(parameters) + ")" return parameters row_size = len(tuples[0]) values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)] - values_list = ', '.join(values_list) + values_list = ", ".join(values_list) query = f""" select * from (values {values_list}) @@ -175,7 +173,7 @@ def createDataFrame( if is_empty: rel = df.relation # Add impossible where clause - rel = rel.filter('1=0') + rel = rel.filter("1=0") df = DataFrame(rel, self) # Cast to types @@ -203,7 +201,7 @@ def range( end = start start = 0 - return DataFrame(self.conn.table_function("range", parameters=[start, end, step]),self) + return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self) def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: if kwargs: @@ -255,7 +253,7 @@ def udf(self) -> UDFRegistration: @property def version(self) -> str: - return '1.0.0' + return "1.0.0" class Builder: def __init__(self) -> None: diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index cda80602..4dcba01f 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -27,7 +27,7 @@ def load( path: Optional[str] = None, format: Optional[str] = None, schema: Union[StructType, str, None] = None, - **options: OptionalPrimitiveType + **options: OptionalPrimitiveType, ) -> "DataFrame": from duckdb.experimental.spark.sql.dataframe import DataFrame diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index ecccc014..f8c8ce4f 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -36,62 +36,62 @@ ) _sqltype_to_spark_class = { - 'boolean': BooleanType, - 'utinyint': UnsignedByteType, - 'tinyint': ByteType, - 'usmallint': UnsignedShortType, - 'smallint': ShortType, - 'uinteger': UnsignedIntegerType, - 'integer': IntegerType, - 'ubigint': UnsignedLongType, - 'bigint': LongType, - 'hugeint': HugeIntegerType, - 'uhugeint': UnsignedHugeIntegerType, - 'varchar': StringType, - 'blob': BinaryType, - 'bit': BitstringType, - 'uuid': UUIDType, - 'date': DateType, - 'time': TimeNTZType, - 'time with time zone': TimeType, - 'timestamp': TimestampNTZType, - 'timestamp with time zone': TimestampType, - 'timestamp_ms': TimestampNanosecondNTZType, - 'timestamp_ns': TimestampMilisecondNTZType, - 'timestamp_s': TimestampSecondNTZType, - 'interval': DayTimeIntervalType, - 'list': ArrayType, - 'struct': StructType, - 'map': MapType, + "boolean": BooleanType, + "utinyint": UnsignedByteType, + "tinyint": ByteType, + "usmallint": UnsignedShortType, + "smallint": ShortType, + "uinteger": UnsignedIntegerType, + "integer": IntegerType, + "ubigint": UnsignedLongType, + "bigint": LongType, + "hugeint": HugeIntegerType, + "uhugeint": UnsignedHugeIntegerType, + "varchar": StringType, + "blob": BinaryType, + "bit": BitstringType, + "uuid": UUIDType, + "date": DateType, + "time": TimeNTZType, + "time with time zone": TimeType, + "timestamp": TimestampNTZType, + "timestamp with time zone": TimestampType, + "timestamp_ms": TimestampNanosecondNTZType, + "timestamp_ns": TimestampMilisecondNTZType, + "timestamp_s": TimestampSecondNTZType, + "interval": DayTimeIntervalType, + "list": ArrayType, + "struct": StructType, + "map": MapType, # union # enum # null (???) - 'float': FloatType, - 'double': DoubleType, - 'decimal': DecimalType, + "float": FloatType, + "double": DoubleType, + "decimal": DecimalType, } def convert_nested_type(dtype: DuckDBPyType) -> DataType: id = dtype.id - if id == 'list' or id == 'array': + if id == "list" or id == "array": children = dtype.children return ArrayType(convert_type(children[0][1])) # TODO: add support for 'union' - if id == 'struct': + if id == "struct": children: list[tuple[str, DuckDBPyType]] = dtype.children fields = [StructField(x[0], convert_type(x[1])) for x in children] return StructType(fields) - if id == 'map': + if id == "map": return MapType(convert_type(dtype.key), convert_type(dtype.value)) raise NotImplementedError def convert_type(dtype: DuckDBPyType) -> DataType: id = dtype.id - if id in ['list', 'struct', 'map', 'array']: + if id in ["list", "struct", "map", "array"]: return convert_nested_type(dtype) - if id == 'decimal': + if id == "decimal": children: list[tuple[str, DuckDBPyType]] = dtype.children precision = cast(int, children[0][1]) scale = cast(int, children[1][1]) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 4b3a4132..81293caf 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -632,11 +632,9 @@ class MapType(DataType): Examples -------- - >>> (MapType(StringType(), IntegerType()) - ... == MapType(StringType(), IntegerType(), True)) + >>> (MapType(StringType(), IntegerType()) == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType(), IntegerType(), False) - ... == MapType(StringType(), FloatType())) + >>> (MapType(StringType(), IntegerType(), False) == MapType(StringType(), FloatType())) False """ @@ -697,11 +695,9 @@ class StructField(DataType): Examples -------- - >>> (StructField("f1", StringType(), True) - ... == StructField("f1", StringType(), True)) + >>> (StructField("f1", StringType(), True) == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType(), True) - ... == StructField("f2", StringType(), True)) + >>> (StructField("f1", StringType(), True) == StructField("f2", StringType(), True)) False """ @@ -743,7 +739,7 @@ def fromInternal(self, obj: T) -> T: return self.dataType.fromInternal(obj) def typeName(self) -> str: # type: ignore[override] - raise TypeError("StructField does not have typeName. " "Use typeName on its type explicitly instead.") + raise TypeError("StructField does not have typeName. Use typeName on its type explicitly instead.") class StructType(DataType): @@ -767,8 +763,9 @@ class StructType(DataType): >>> struct1 == struct2 True >>> struct1 = StructType([StructField("f1", StringType(), True)]) - >>> struct2 = StructType([StructField("f1", StringType(), True), - ... StructField("f2", IntegerType(), False)]) + >>> struct2 = StructType( + ... [StructField("f1", StringType(), True), StructField("f2", IntegerType(), False)] + ... ) >>> struct1 == struct2 False """ @@ -796,12 +793,10 @@ def add( data_type: Union[str, DataType], nullable: bool = True, metadata: Optional[dict[str, Any]] = None, - ) -> "StructType": - ... + ) -> "StructType": ... @overload - def add(self, field: StructField) -> "StructType": - ... + def add(self, field: StructField) -> "StructType": ... def add( self, @@ -1091,7 +1086,6 @@ def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], class Row(tuple): - """ A row in :class:`DataFrame`. The fields in it can be accessed: @@ -1115,13 +1109,13 @@ class Row(tuple): >>> row = Row(name="Alice", age=11) >>> row Row(name='Alice', age=11) - >>> row['name'], row['age'] + >>> row["name"], row["age"] ('Alice', 11) >>> row.name, row.age ('Alice', 11) - >>> 'name' in row + >>> "name" in row True - >>> 'wrong_key' in row + >>> "wrong_key" in row False Row also can be used to create another Row like class, then it @@ -1130,9 +1124,9 @@ class Row(tuple): >>> Person = Row("name", "age") >>> Person - >>> 'name' in Person + >>> "name" in Person True - >>> 'wrong_key' in Person + >>> "wrong_key" in Person False >>> Person("Alice", 11) Row(name='Alice', age=11) @@ -1147,16 +1141,14 @@ class Row(tuple): """ @overload - def __new__(cls, *args: str) -> "Row": - ... + def __new__(cls, *args: str) -> "Row": ... @overload - def __new__(cls, **kwargs: Any) -> "Row": - ... + def __new__(cls, **kwargs: Any) -> "Row": ... def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": if args and kwargs: - raise ValueError("Can not use both args " "and kwargs to create Row") + raise ValueError("Can not use both args and kwargs to create Row") if kwargs: # create row objects row = tuple.__new__(cls, list(kwargs.values())) @@ -1185,12 +1177,12 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]: Examples -------- - >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} + >>> Row(name="Alice", age=11).asDict() == {"name": "Alice", "age": 11} True - >>> row = Row(key=1, value=Row(name='a', age=2)) - >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)} + >>> row = Row(key=1, value=Row(name="a", age=2)) + >>> row.asDict() == {"key": 1, "value": Row(name="a", age=2)} True - >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} + >>> row.asDict(True) == {"key": 1, "value": {"name": "a", "age": 2}} True """ if not hasattr(self, "__fields__"): @@ -1223,7 +1215,7 @@ def __call__(self, *args: Any) -> "Row": """create new Row object""" if len(args) > len(self): raise ValueError( - "Can not create Row with fields %s, expected %d values " "but got %s" % (self, len(self), args) + "Can not create Row with fields %s, expected %d values but got %s" % (self, len(self), args) ) return _create_row(self, args) diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index fbef757d..ea4ba540 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -3,13 +3,14 @@ from .bytes_io_wrapper import BytesIOWrapper from io import TextIOBase + def is_file_like(obj): # We only care that we can read from the file return hasattr(obj, "read") and hasattr(obj, "seek") class ModifiedMemoryFileSystem(MemoryFileSystem): - protocol = ('DUCKDB_INTERNAL_OBJECTSTORE',) + protocol = ("DUCKDB_INTERNAL_OBJECTSTORE",) # defer to the original implementation that doesn't hardcode the protocol _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index ac4a6495..90c2a561 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,17 +1,3 @@ -from _duckdb.functional import ( - FunctionNullHandling, - PythonUDFType, - SPECIAL, - DEFAULT, - NATIVE, - ARROW -) +from _duckdb.functional import FunctionNullHandling, PythonUDFType, SPECIAL, DEFAULT, NATIVE, ARROW -__all__ = [ - "FunctionNullHandling", - "PythonUDFType", - "SPECIAL", - "DEFAULT", - "NATIVE", - "ARROW" -] +__all__ = ["FunctionNullHandling", "PythonUDFType", "SPECIAL", "DEFAULT", "NATIVE", "ARROW"] diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index d8d4cfe9..ef87f03a 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -8,13 +8,14 @@ from decimal import Decimal import datetime + def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: """ Convert a Polars predicate expression to a DuckDB-compatible SQL expression. - + Parameters: predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5) - + Returns: SQLExpression: A DuckDB SQL expression string equivalent. None: If conversion fails. @@ -25,7 +26,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: """ # Serialize the Polars expression tree to JSON tree = json.loads(predicate.meta.serialize(format="json")) - + try: # Convert the tree to SQL sql_filter = _pl_tree_to_sql(tree) @@ -38,7 +39,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: def _pl_operation_to_sql(op: str) -> str: """ Map Polars binary operation strings to SQL equivalents. - + Example: >>> _pl_operation_to_sql("Eq") '=' @@ -73,13 +74,13 @@ def _escape_sql_identifier(identifier: str) -> str: def _pl_tree_to_sql(tree: dict) -> str: """ Recursively convert a Polars expression tree (as JSON) to a SQL string. - + Parameters: tree (dict): JSON-deserialized expression tree from Polars - + Returns: str: SQL expression string - + Example: Input tree: { @@ -97,13 +98,15 @@ def _pl_tree_to_sql(tree: dict) -> str: if node_type == "BinaryExpr": # Binary expressions: left OP right return ( - "(" + - " ".join(( - _pl_tree_to_sql(subtree['left']), - _pl_operation_to_sql(subtree['op']), - _pl_tree_to_sql(subtree['right']) - )) + - ")" + "(" + + " ".join( + ( + _pl_tree_to_sql(subtree["left"]), + _pl_operation_to_sql(subtree["op"]), + _pl_tree_to_sql(subtree["right"]), + ) + ) + + ")" ) if node_type == "Column": # A reference to a column name @@ -147,20 +150,30 @@ def _pl_tree_to_sql(tree: dict) -> str: # Decimal support if dtype.startswith("{'Decimal'") or dtype == "Decimal": - decimal_value = value['Decimal'] + decimal_value = value["Decimal"] decimal_value = Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]) return str(decimal_value) # Datetime with microseconds since epoch if dtype.startswith("{'Datetime'") or dtype == "Datetime": - micros = value['Datetime'][0] + micros = value["Datetime"][0] dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC) return f"'{str(dt_timestamp)}'::TIMESTAMP" # Match simple numeric/boolean types - if dtype in ("Int8", "Int16", "Int32", "Int64", - "UInt8", "UInt16", "UInt32", "UInt64", - "Float32", "Float64", "Boolean"): + if dtype in ( + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "Float32", + "Float64", + "Boolean", + ): return str(value[dtype]) # Time type @@ -168,9 +181,7 @@ def _pl_tree_to_sql(tree: dict) -> str: nanoseconds = value["Time"] seconds = nanoseconds // 1_000_000_000 microseconds = (nanoseconds % 1_000_000_000) // 1_000 - dt_time = (datetime.datetime.min + datetime.timedelta( - seconds=seconds, microseconds=microseconds - )).time() + dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time() return f"'{dt_time}'::TIME" # Date type @@ -182,7 +193,7 @@ def _pl_tree_to_sql(tree: dict) -> str: # Binary type if dtype == "Binary": binary_data = bytes(value["Binary"]) - escaped = ''.join(f'\\x{b:02x}' for b in binary_data) + escaped = "".join(f"\\x{b:02x}" for b in binary_data) return f"'{escaped}'::BLOB" # String type @@ -191,15 +202,16 @@ def _pl_tree_to_sql(tree: dict) -> str: string_val = value.get("StringOwned", value.get("String", None)) return f"'{string_val}'" - raise NotImplementedError(f"Unsupported scalar type {str(dtype)}, with value {value}") raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") + def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: """ A polars IO plugin for DuckDB. """ + def source_generator( with_columns: Optional[list[str]], predicate: Optional[pl.Expr], diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index eab68179..aa67b42f 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -77,7 +77,6 @@ class NodeTiming: - def __init__(self, phase: str, time: float) -> object: self.phase = phase self.time = time @@ -94,7 +93,6 @@ def combine_timing(l: object, r: object) -> object: class AllTimings: - def __init__(self) -> None: self.phase_to_timings = {} @@ -128,37 +126,38 @@ def open_utf8(fpath: str, flags: str) -> object: def get_child_timings(top_node: object, query_timings: object) -> str: - node_timing = NodeTiming(top_node['operator_type'], float(top_node['operator_timing'])) + node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"])) query_timings.add_node_timing(node_timing) - for child in top_node['children']: + for child in top_node["children"]: get_child_timings(child, query_timings) def get_pink_shade_hex(fraction: float): fraction = max(0, min(1, fraction)) - + # Define the RGB values for very light pink (almost white) and dark pink light_pink = (255, 250, 250) # Very light pink - dark_pink = (255, 20, 147) # Dark pink - + dark_pink = (255, 20, 147) # Dark pink + # Calculate the RGB values for the given fraction r = int(light_pink[0] + (dark_pink[0] - light_pink[0]) * fraction) g = int(light_pink[1] + (dark_pink[1] - light_pink[1]) * fraction) b = int(light_pink[2] + (dark_pink[2] - light_pink[2]) * fraction) - + # Return as hexadecimal color code return f"#{r:02x}{g:02x}{b:02x}" + def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: - node_style = f"background-color: {get_pink_shade_hex(float(result)/cpu_time)};" + node_style = f"background-color: {get_pink_shade_hex(float(result) / cpu_time)};" - body = f"" - body += "
" + body = f'' + body += '
' new_name = "BRIDGE" if (name == "INVALID") else name.replace("_", " ") formatted_num = f"{float(result):.4f}" body += f"

{new_name}

time: {formatted_num} seconds

" - body += f" {extra_info} " - if (width > 0): + body += f' {extra_info} ' + if width > 0: body += f"

cardinality: {card}

" body += f"

estimate: {est}

" body += f"

width: {width} bytes

" @@ -174,26 +173,31 @@ def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: extra_info = "" estimate = 0 - for key in json_graph['extra_info']: - value = json_graph['extra_info'][key] - if (key == "Estimated Cardinality"): + for key in json_graph["extra_info"]: + value = json_graph["extra_info"][key] + if key == "Estimated Cardinality": estimate = int(value) else: extra_info += f"{key}: {value}
" cardinality = json_graph["operator_cardinality"] - width = int(json_graph["result_set_size"]/max(1,cardinality)) + width = int(json_graph["result_set_size"] / max(1, cardinality)) # get rid of some typically long names extra_info = re.sub(r"__internal_\s*", "__", extra_info) extra_info = re.sub(r"compress_integral\s*", "compress", extra_info) - node_body = get_node_body(json_graph["operator_type"], - json_graph["operator_timing"], - cpu_time, cardinality, estimate, width, - re.sub(r",\s*", ", ", extra_info)) + node_body = get_node_body( + json_graph["operator_type"], + json_graph["operator_timing"], + cpu_time, + cardinality, + estimate, + width, + re.sub(r",\s*", ", ", extra_info), + ) children_html = "" - if len(json_graph['children']) >= 1: + if len(json_graph["children"]) >= 1: children_html += "
    " for child in json_graph["children"]: children_html += generate_tree_recursive(child, cpu_time) @@ -205,7 +209,7 @@ def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: def generate_timing_html(graph_json: object, query_timings: object) -> object: json_graph = json.loads(graph_json) gather_timing_information(json_graph, query_timings) - total_time = float(json_graph.get('operator_timing') or json_graph.get('latency')) + total_time = float(json_graph.get("operator_timing") or json_graph.get("latency")) table_head = """ @@ -242,12 +246,12 @@ def generate_timing_html(graph_json: object, query_timings: object) -> object: def generate_tree_html(graph_json: object) -> str: json_graph = json.loads(graph_json) - cpu_time = float(json_graph['cpu_time']) - tree_prefix = "
    \n
      " + cpu_time = float(json_graph["cpu_time"]) + tree_prefix = '
      \n
        ' tree_suffix = "
      " # first level of json is general overview # FIXME: make sure json output first level always has only 1 level - tree_body = generate_tree_recursive(json_graph['children'][0], cpu_time) + tree_body = generate_tree_recursive(json_graph["children"][0], cpu_time) return tree_prefix + tree_body + tree_suffix @@ -256,39 +260,32 @@ def generate_ipython(json_input: str) -> str: html_output = generate_html(json_input, False) - return HTML(("\n" - " ${CSS}\n" - " ${LIBRARIES}\n" - "
      \n" - " ${CHART_SCRIPT}\n" - " ").replace("${CSS}", html_output['css']).replace('${CHART_SCRIPT}', - html_output['chart_script']).replace( - '${LIBRARIES}', html_output['libraries'])) + return HTML( + ('\n ${CSS}\n ${LIBRARIES}\n
      \n ${CHART_SCRIPT}\n ') + .replace("${CSS}", html_output["css"]) + .replace("${CHART_SCRIPT}", html_output["chart_script"]) + .replace("${LIBRARIES}", html_output["libraries"]) + ) def generate_style_html(graph_json: str, include_meta_info: bool) -> None: - treeflex_css = "\n" + treeflex_css = '\n' css = "\n" - return { - 'treeflex_css': treeflex_css, - 'duckdb_css': css, - 'libraries': '', - 'chart_script': '' - } + return {"treeflex_css": treeflex_css, "duckdb_css": css, "libraries": "", "chart_script": ""} def gather_timing_information(json: str, query_timings: object) -> None: # add up all of the times # measure each time as a percentage of the total time. # then you can return a list of [phase, time, percentage] - get_child_timings(json['children'][0], query_timings) + get_child_timings(json["children"][0], query_timings) def translate_json_to_html(input_file: str, output_file: str) -> None: query_timings = AllTimings() - with open_utf8(input_file, 'r') as f: + with open_utf8(input_file, "r") as f: text = f.read() html_output = generate_style_html(text, True) @@ -317,10 +314,10 @@ def translate_json_to_html(input_file: str, output_file: str) -> None: """ - html = html.replace("${TREEFLEX_CSS}", html_output['treeflex_css']) - html = html.replace("${DUCKDB_CSS}", html_output['duckdb_css']) + html = html.replace("${TREEFLEX_CSS}", html_output["treeflex_css"]) + html = html.replace("${DUCKDB_CSS}", html_output["duckdb_css"]) html = html.replace("${TIMING_TABLE}", timing_table) - html = html.replace('${TREE}', tree_output) + html = html.replace("${TREE}", tree_output) f.write(html) @@ -329,11 +326,12 @@ def main() -> None: print("Please use python3") exit(1) parser = argparse.ArgumentParser( - prog='Query Graph Generator', - description='Given a json profile output, generate a html file showing the query graph and timings of operators') - parser.add_argument('profile_input', help='profile input in json') - parser.add_argument('--out', required=False, default=False) - parser.add_argument('--open', required=False, action='store_true', default=True) + prog="Query Graph Generator", + description="Given a json profile output, generate a html file showing the query graph and timings of operators", + ) + parser.add_argument("profile_input", help="profile input in json") + parser.add_argument("--out", required=False, default=False) + parser.add_argument("--open", required=False, action="store_true", default=True) args = parser.parse_args() input = args.profile_input @@ -356,8 +354,8 @@ def main() -> None: translate_json_to_html(input, output) if open_output: - webbrowser.open('file://' + os.path.abspath(output), new=2) + webbrowser.open("file://" + os.path.abspath(output), new=2) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/duckdb/typing/__init__.py b/duckdb/typing/__init__.py index d0e95b50..33cf4cd7 100644 --- a/duckdb/typing/__init__.py +++ b/duckdb/typing/__init__.py @@ -26,7 +26,7 @@ USMALLINT, UTINYINT, UUID, - VARCHAR + VARCHAR, ) __all__ = [ @@ -57,5 +57,5 @@ "USMALLINT", "UTINYINT", "UUID", - "VARCHAR" + "VARCHAR", ] diff --git a/duckdb/typing/__init__.pyi b/duckdb/typing/__init__.pyi index 69435c05..8a3cef79 100644 --- a/duckdb/typing/__init__.pyi +++ b/duckdb/typing/__init__.pyi @@ -32,5 +32,7 @@ class DuckDBPyType: def __init__(self, type_str: str, connection: DuckDBPyConnection = ...) -> None: ... def __repr__(self) -> str: ... def __eq__(self, other) -> bool: ... - def __getattr__(self, name: str): DuckDBPyType - def __getitem__(self, name: str): DuckDBPyType \ No newline at end of file + def __getattr__(self, name: str): + DuckDBPyType + def __getitem__(self, name: str): + DuckDBPyType diff --git a/duckdb/value/constant/__init__.pyi b/duckdb/value/constant/__init__.pyi index 8cea58cf..f5190345 100644 --- a/duckdb/value/constant/__init__.pyi +++ b/duckdb/value/constant/__init__.pyi @@ -54,9 +54,9 @@ class DoubleValue(Value): def __repr__(self) -> str: ... class DecimalValue(Value): - def __init__(self, object: Any, width: int, scale: int) -> None: ... - def __repr__(self) -> str: ... - + def __init__(self, object: Any, width: int, scale: int) -> None: ... + def __repr__(self) -> str: ... + class StringValue(Value): def __init__(self, object: Any) -> None: ... def __repr__(self) -> str: ... @@ -109,7 +109,6 @@ class TimeTimeZoneValue(Value): def __init__(self, object: Any) -> None: ... def __repr__(self) -> str: ... - class Value: def __init__(self, object: Any, type: DuckDBPyType) -> None: ... def __repr__(self) -> str: ... diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index ca8e7716..3709dac0 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -5,13 +5,16 @@ - Git tag creation and management - Version parsing and validation """ + import pathlib import subprocess from typing import Optional import re -VERSION_RE = re.compile(r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$") +VERSION_RE = re.compile( + r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$" +) def parse_version(version: str) -> tuple[int, int, int, int, int]: @@ -67,12 +70,12 @@ def git_tag_to_pep440(git_tag: str) -> str: PEP440 version string (e.g., "1.3.1", "1.3.1.post1") """ # Remove 'v' prefix if present - version = git_tag[1:] if git_tag.startswith('v') else git_tag + version = git_tag[1:] if git_tag.startswith("v") else git_tag if "-post" in version: - assert 'rc' not in version + assert "rc" not in version version = version.replace("-post", ".post") - elif '-rc' in version: + elif "-rc" in version: version = version.replace("-rc", "rc") return version @@ -87,10 +90,10 @@ def pep440_to_git_tag(version: str) -> str: Returns: Git tag format (e.g., "v1.3.1-post1") """ - if '.post' in version: - assert 'rc' not in version + if ".post" in version: + assert "rc" not in version version = version.replace(".post", "-post") - elif 'rc' in version: + elif "rc" in version: version = version.replace("rc", "-rc") return f"v{version}" @@ -104,12 +107,7 @@ def get_current_version() -> Optional[str]: """ try: # Get the latest tag - result = subprocess.run( - ["git", "describe", "--tags", "--abbrev=0"], - capture_output=True, - text=True, - check=True - ) + result = subprocess.run(["git", "describe", "--tags", "--abbrev=0"], capture_output=True, text=True, check=True) tag = result.stdout.strip() return git_tag_to_pep440(tag) except subprocess.CalledProcessError: @@ -156,18 +154,18 @@ def get_git_describe(repo_path: Optional[pathlib.Path] = None, since_major=False Git describe output or None if no tags exist """ cwd = repo_path if repo_path is not None else None - pattern="v*.*.*" + pattern = "v*.*.*" if since_major: - pattern="v*.0.0" + pattern = "v*.0.0" elif since_minor: - pattern="v*.*.0" + pattern = "v*.*.0" try: result = subprocess.run( ["git", "describe", "--tags", "--long", "--match", pattern], capture_output=True, text=True, check=True, - cwd=cwd + cwd=cwd, ) result.check_returncode() return result.stdout.strip() diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index de1a9535..b9a005db 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -12,6 +12,7 @@ Also see https://peps.python.org/pep-0517/#in-tree-build-backends. """ + import sys import os import subprocess @@ -39,7 +40,7 @@ _FORCED_PEP440_VERSION = forced_version_from_env() -def _log(msg: str, is_error: bool=False) -> None: +def _log(msg: str, is_error: bool = False) -> None: """Log a message with build backend prefix. Args: @@ -84,9 +85,9 @@ def _duckdb_submodule_path() -> Path: cur_module_reponame = None cur_module_path = None elif line.strip().startswith("path"): - cur_module_path = line.split('=')[-1].strip() + cur_module_path = line.split("=")[-1].strip() elif line.strip().startswith("url"): - basename = os.path.basename(line.split('=')[-1].strip()) + basename = os.path.basename(line.split("=")[-1].strip()) cur_module_reponame = basename[:-4] if basename.endswith(".git") else basename if cur_module_reponame is not None and cur_module_path is not None: modules[cur_module_reponame] = cur_module_path @@ -115,7 +116,7 @@ def _version_file_path() -> Path: return package_dir / _DUCKDB_VERSION_FILENAME -def _write_duckdb_long_version(long_version: str)-> None: +def _write_duckdb_long_version(long_version: str) -> None: """Write the given version string to a file in the same directory as this module.""" _version_file_path().write_text(long_version, encoding="utf-8") @@ -126,7 +127,7 @@ def _read_duckdb_long_version() -> str: def _skbuild_config_add( - key: str, value: Union[list, str], config_settings: dict[str, Union[list[str],str]], fail_if_exists: bool=False + key: str, value: Union[list, str], config_settings: dict[str, Union[list[str], str]], fail_if_exists: bool = False ): """Add or modify a configuration setting for scikit-build-core. @@ -178,7 +179,7 @@ def _skbuild_config_add( ) -def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str],str]]] = None) -> str: +def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str], str]]] = None) -> str: """Build a source distribution using the DuckDB submodule. This function extracts the DuckDB version from either the git submodule and saves it @@ -207,9 +208,9 @@ def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[ def build_wheel( - wheel_directory: str, - config_settings: Optional[dict[str, Union[list[str],str]]] = None, - metadata_directory: Optional[str] = None, + wheel_directory: str, + config_settings: Optional[dict[str, Union[list[str], str]]] = None, + metadata_directory: Optional[str] = None, ) -> str: """Build a wheel from either git submodule or extracted sdist sources. @@ -246,7 +247,6 @@ def build_wheel( else: _log("No explicit DuckDB submodule version provided. Letting CMake figure it out.") - return skbuild_build_wheel(wheel_directory, config_settings=config_settings, metadata_directory=metadata_directory) diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 031adf94..80073c0e 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -28,8 +28,8 @@ from requests.exceptions import RequestException from urllib3 import Retry -_PYPI_URL_PROD = 'https://pypi.org/' -_PYPI_URL_TEST = 'https://test.pypi.org/' +_PYPI_URL_PROD = "https://pypi.org/" +_PYPI_URL_TEST = "https://test.pypi.org/" _DEFAULT_MAX_NIGHTLIES = 2 _LOGIN_RETRY_ATTEMPTS = 3 _LOGIN_RETRY_DELAY = 5 @@ -50,88 +50,70 @@ def create_argument_parser() -> argparse.ArgumentParser: * Keep the configured amount of dev releases per version, and remove older dev releases """, epilog="Environment variables required (unless --dry-run): PYPI_CLEANUP_PASSWORD, PYPI_CLEANUP_OTP", - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument( - "--dry-run", - action="store_true", - help="Show what would be deleted but don't actually do it" - ) + parser.add_argument("--dry-run", action="store_true", help="Show what would be deleted but don't actually do it") host_group = parser.add_mutually_exclusive_group(required=True) - host_group.add_argument( - "--prod", - action="store_true", - help="Use production PyPI (pypi.org)" - ) - host_group.add_argument( - "--test", - action="store_true", - help="Use test PyPI (test.pypi.org)" - ) + host_group.add_argument("--prod", action="store_true", help="Use production PyPI (pypi.org)") + host_group.add_argument("--test", action="store_true", help="Use test PyPI (test.pypi.org)") parser.add_argument( - "-m", "--max-nightlies", + "-m", + "--max-nightlies", type=int, default=_DEFAULT_MAX_NIGHTLIES, - help=f"Max number of nightlies of unreleased versions (default={_DEFAULT_MAX_NIGHTLIES})" + help=f"Max number of nightlies of unreleased versions (default={_DEFAULT_MAX_NIGHTLIES})", ) - parser.add_argument( - "-u", "--username", - type=validate_username, - help="PyPI username (required unless --dry-run)" - ) + parser.add_argument("-u", "--username", type=validate_username, help="PyPI username (required unless --dry-run)") - parser.add_argument( - "-v", "--verbose", - action="store_true", - help="Enable verbose debug logging" - ) + parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose debug logging") return parser + class PyPICleanupError(Exception): """Base exception for PyPI cleanup operations.""" + pass class AuthenticationError(PyPICleanupError): """Raised when authentication fails.""" + pass class ValidationError(PyPICleanupError): """Raised when input validation fails.""" + pass def setup_logging(verbose: bool = False) -> None: """Configure logging with appropriate level and format.""" level = logging.DEBUG if verbose else logging.INFO - logging.basicConfig( - level=level, - format='%(asctime)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) + logging.basicConfig(level=level, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S") def validate_username(value: str) -> str: """Validate and sanitize username input.""" if not value or not value.strip(): raise argparse.ArgumentTypeError("Username cannot be empty") - + username = value.strip() if len(username) > 100: # Reasonable limit raise argparse.ArgumentTypeError("Username too long (max 100 characters)") - + # Basic validation - PyPI usernames are alphanumeric with limited special chars - if not re.match(r'^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$', username): + if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$", username): raise argparse.ArgumentTypeError("Invalid username format") - + return username + @contextlib.contextmanager def session_with_retries() -> Generator[Session, None, None]: """Create a requests session with retry strategy for ephemeral errors.""" @@ -154,19 +136,20 @@ def session_with_retries() -> Generator[Session, None, None]: session.mount("https://", adapter) yield session + def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: """Load credentials from environment variables.""" if dry_run: return None, None - - password = os.getenv('PYPI_CLEANUP_PASSWORD') - otp = os.getenv('PYPI_CLEANUP_OTP') - + + password = os.getenv("PYPI_CLEANUP_PASSWORD") + otp = os.getenv("PYPI_CLEANUP_OTP") + if not password: raise ValidationError("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") if not otp: raise ValidationError("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") - + return password, otp @@ -174,15 +157,17 @@ def validate_arguments(args: argparse.Namespace) -> None: """Validate parsed arguments.""" if not args.dry_run and not args.username: raise ValidationError("--username is required when not in dry-run mode") - + if args.max_nightlies < 0: raise ValidationError("--max-nightlies must be non-negative") + class CsrfParser(HTMLParser): """HTML parser to extract CSRF tokens from PyPI forms. - + Based on pypi-cleanup package (https://github.com/arcivanov/pypi-cleanup/tree/master) """ + def __init__(self, target, contains_input=None) -> None: super().__init__() self._target = target @@ -222,24 +207,31 @@ def handle_endtag(self, tag): class PyPICleanup: """Main class for performing PyPI package cleanup operations.""" - def __init__(self, index_url: str, do_delete: bool, max_dev_releases: int=_DEFAULT_MAX_NIGHTLIES, - username: Optional[str]=None, password: Optional[str]=None, otp: Optional[str]=None) -> None: + def __init__( + self, + index_url: str, + do_delete: bool, + max_dev_releases: int = _DEFAULT_MAX_NIGHTLIES, + username: Optional[str] = None, + password: Optional[str] = None, + otp: Optional[str] = None, + ) -> None: parsed_url = urlparse(index_url) - self._index_url = parsed_url.geturl().rstrip('/') + self._index_url = parsed_url.geturl().rstrip("/") self._index_host = parsed_url.hostname self._do_delete = do_delete self._max_dev_releases = max_dev_releases self._username = username self._password = password self._otp = otp - self._package = 'duckdb' + self._package = "duckdb" self._dev_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.dev(?P\d+)$") self._rc_version_pattern = re.compile(r"^(?P\d+\.\d+\.\d+)\.rc\d+$") self._stable_version_pattern = re.compile(r"^\d+\.\d+\.\d+(\.post\d+)?$") def run(self) -> int: """Execute the cleanup process. - + Returns: int: Exit code (0 for success, non-zero for failure) """ @@ -268,17 +260,17 @@ def _execute_cleanup(self, http_session: Session) -> int: if not versions: logging.info(f"No releases found for {self._package}") return 0 - + # Determine versions to delete versions_to_delete = self._determine_versions_to_delete(versions) if not versions_to_delete: logging.info("No versions to delete (no stale rc's or dev releases)") return 0 - + logging.warning(f"Found {len(versions_to_delete)} versions to clean up:") for version in sorted(versions_to_delete): logging.warning(version) - + if not self._do_delete: logging.info("Dry run complete - no packages were deleted") return 0 @@ -286,14 +278,14 @@ def _execute_cleanup(self, http_session: Session) -> int: # Perform authentication and deletion self._authenticate(http_session) self._delete_versions(http_session, versions_to_delete) - + logging.info(f"Successfully cleaned up {len(versions_to_delete)} development versions") return 0 - + def _fetch_released_versions(self, http_session: Session) -> set[str]: """Fetch package release information from PyPI API.""" logging.debug(f"Fetching package information for '{self._package}'") - + try: req = http_session.get(f"{self._index_url}/pypi/{self._package}/json") req.raise_for_status() @@ -392,12 +384,12 @@ def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: logging.warning(f"Found version string(s) in an unsupported format: {unknown_versions}") return versions_to_delete - + def _authenticate(self, http_session: Session) -> None: """Authenticate with PyPI.""" if not self._username or not self._password: raise AuthenticationError("Username and password are required for authentication") - + logging.info(f"Authenticating user '{self._username}' with PyPI") try: @@ -408,12 +400,12 @@ def _authenticate(self, http_session: Session) -> None: if login_response.url.startswith(f"{self._index_url}/account/two-factor/"): logging.debug("Two-factor authentication required") self._handle_two_factor_auth(http_session, login_response) - + logging.info("Authentication successful") except RequestException as e: raise AuthenticationError(f"Network error during authentication: {e}") from e - + def _get_csrf_token(self, http_session: Session, form_action: str) -> str: """Extract CSRF token from a form page.""" resp = http_session.get(f"{self._index_url}{form_action}") @@ -423,23 +415,19 @@ def _get_csrf_token(self, http_session: Session, form_action: str) -> str: if not parser.csrf: raise AuthenticationError(f"No CSRF token found in {form_action}") return parser.csrf - + def _perform_login(self, http_session: Session) -> requests.Response: """Perform the initial login with username/password.""" # Get login form and CSRF token csrf_token = self._get_csrf_token(http_session, "/account/login/") - login_data = { - "csrf_token": csrf_token, - "username": self._username, - "password": self._password - } + login_data = {"csrf_token": csrf_token, "username": self._username, "password": self._password} response = http_session.post( f"{self._index_url}/account/login/", data=login_data, - headers={"referer": f"{self._index_url}/account/login/"} + headers={"referer": f"{self._index_url}/account/login/"}, ) response.raise_for_status() @@ -448,16 +436,16 @@ def _perform_login(self, http_session: Session) -> requests.Response: raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") return response - + def _handle_two_factor_auth(self, http_session: Session, response: requests.Response) -> None: """Handle two-factor authentication.""" if not self._otp: raise AuthenticationError("Two-factor authentication required but no OTP secret provided") - + two_factor_url = response.url - form_action = two_factor_url[len(self._index_url):] + form_action = two_factor_url[len(self._index_url) :] csrf_token = self._get_csrf_token(http_session, form_action) - + # Try authentication with retries for attempt in range(_LOGIN_RETRY_ATTEMPTS): try: @@ -467,7 +455,7 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp auth_response = http_session.post( two_factor_url, data={"csrf_token": csrf_token, "method": "totp", "totp_value": auth_code}, - headers={"referer": two_factor_url} + headers={"referer": two_factor_url}, ) auth_response.raise_for_status() @@ -479,19 +467,19 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp if attempt < _LOGIN_RETRY_ATTEMPTS - 1: logging.debug(f"2FA code rejected, retrying in {_LOGIN_RETRY_DELAY} seconds...") time.sleep(_LOGIN_RETRY_DELAY) - + except RequestException as e: if attempt == _LOGIN_RETRY_ATTEMPTS - 1: raise AuthenticationError(f"Network error during 2FA: {e}") from e logging.debug(f"Network error during 2FA attempt {attempt + 1}, retrying...") time.sleep(_LOGIN_RETRY_DELAY) - + raise AuthenticationError("Two-factor authentication failed after all attempts") - + def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) -> None: """Delete the specified package versions.""" logging.info(f"Starting deletion of {len(versions_to_delete)} development versions") - + failed_deletions = list() for version in sorted(versions_to_delete): try: @@ -501,24 +489,24 @@ def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) # Continue with other versions rather than failing completely logging.error(f"Failed to delete version {version}: {e}") failed_deletions.append(version) - + if failed_deletions: raise PyPICleanupError( f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" ) - + def _delete_single_version(self, http_session: Session, version: str) -> None: """Delete a single package version.""" # Safety check if not self._is_dev_version(version) or self._is_rc_version(version): raise PyPICleanupError(f"Refusing to delete non-[dev|rc] version: {version}") - + logging.debug(f"Deleting {self._package} version {version}") - + # Get deletion form and CSRF token form_action = f"/manage/project/{self._package}/release/{version}/" form_url = f"{self._index_url}{form_action}" - + csrf_token = self._get_csrf_token(http_session, form_action) # Submit deletion request @@ -528,7 +516,7 @@ def _delete_single_version(self, http_session: Session, version: str) -> None: "csrf_token": csrf_token, "confirm_delete_version": version, }, - headers={"referer": form_url} + headers={"referer": form_url}, ) delete_response.raise_for_status() @@ -537,26 +525,27 @@ def main() -> int: """Main entry point for the script.""" parser = create_argument_parser() args = parser.parse_args() - + # Setup logging setup_logging(args.verbose) - + try: # Validate arguments validate_arguments(args) - + # Load credentials password, otp = load_credentials(args.dry_run) - + # Determine PyPI URL pypi_url = _PYPI_URL_PROD if args.prod else _PYPI_URL_TEST - + # Create and run cleanup - cleanup = PyPICleanup(pypi_url, not args.dry_run, args.max_nightlies, username=args.username, - password=password, otp=otp) - + cleanup = PyPICleanup( + pypi_url, not args.dry_run, args.max_nightlies, username=args.username, password=password, otp=otp + ) + return cleanup.run() - + except ValidationError as e: logging.error(f"Configuration error: {e}") return 2 diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 27bedd24..217b2ffe 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -21,9 +21,10 @@ def _main_branch_versioning(): - from_env = os.getenv('MAIN_BRANCH_VERSIONING') + from_env = os.getenv("MAIN_BRANCH_VERSIONING") return from_env == "1" if from_env is not None else MAIN_BRANCH_VERSIONING + def version_scheme(version: Any) -> str: """ setuptools_scm version scheme that matches DuckDB's original behavior. @@ -65,13 +66,13 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: # Otherwise we're at a distance and / or dirty, and need to bump if post != 0: # We're developing on top of a post-release - return f"{format_version(major, minor, patch, post=post+1)}.dev{distance}" + return f"{format_version(major, minor, patch, post=post + 1)}.dev{distance}" elif rc != 0: # We're developing on top of an rc - return f"{format_version(major, minor, patch, rc=rc+1)}.dev{distance}" + return f"{format_version(major, minor, patch, rc=rc + 1)}.dev{distance}" elif _main_branch_versioning(): - return f"{format_version(major, minor+1, 0)}.dev{distance}" - return f"{format_version(major, minor, patch+1)}.dev{distance}" + return f"{format_version(major, minor + 1, 0)}.dev{distance}" + return f"{format_version(major, minor, patch + 1)}.dev{distance}" def forced_version_from_env(): @@ -117,9 +118,9 @@ def _git_describe_override_to_pep_440(override_value: str) -> str: version, distance, commit_hash = match.groups() # Convert version format to PEP440 format (v1.3.1-post1 -> 1.3.1.post1) - if '-post' in version: + if "-post" in version: version = version.replace("-post", ".post") - elif '-rc' in version: + elif "-rc" in version: version = version.replace("-rc", "rc") # Bump version and format according to PEP440 diff --git a/scripts/generate_connection_code.py b/scripts/generate_connection_code.py index 3737f83a..8e2bace9 100644 --- a/scripts/generate_connection_code.py +++ b/scripts/generate_connection_code.py @@ -3,7 +3,7 @@ import generate_connection_wrapper_methods import generate_connection_wrapper_stubs -if __name__ == '__main__': +if __name__ == "__main__": generate_connection_methods.generate() generate_connection_stubs.generate() generate_connection_wrapper_methods.generate() diff --git a/scripts/generate_connection_methods.py b/scripts/generate_connection_methods.py index c1f01e54..a48b6142 100644 --- a/scripts/generate_connection_methods.py +++ b/scripts/generate_connection_methods.py @@ -13,23 +13,23 @@ def is_py_kwargs(method): - return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True + return "kwargs_as_dict" in method and method["kwargs_as_dict"] == True def is_py_args(method): - if 'args' not in method: + if "args" not in method: return False - args = method['args'] + args = method["args"] if len(args) == 0: return False - if args[0]['name'] != '*args': + if args[0]["name"] != "*args": return False return True def generate(): # Read the PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, 'r') as source_file: + with open(PYCONNECTION_SOURCE, "r") as source_file: source_code = source_file.readlines() start_index = -1 @@ -52,16 +52,16 @@ def generate(): # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) DEFAULT_ARGUMENT_MAP = { - 'True': 'true', - 'False': 'false', - 'None': 'py::none()', - 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', - 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', - 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', + "True": "true", + "False": "false", + "None": "py::none()", + "PythonUDFType.NATIVE": "PythonUDFType::NATIVE", + "PythonExceptionHandling.DEFAULT": "PythonExceptionHandling::FORWARD_ERROR", + "FunctionNullHandling.DEFAULT": "FunctionNullHandling::DEFAULT_NULL_HANDLING", } def map_default(val): @@ -72,61 +72,61 @@ def map_default(val): def create_arguments(arguments) -> list: result = [] for arg in arguments: - if arg['name'] == '*args': + if arg["name"] == "*args": break - argument = f"py::arg(\"{arg['name']}\")" - if 'allow_none' in arg: - value = str(arg['allow_none']).lower() + argument = f'py::arg("{arg["name"]}")' + if "allow_none" in arg: + value = str(arg["allow_none"]).lower() argument += f".none({value})" # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) + if "default" in arg: + default = map_default(arg["default"]) argument += f" = {default}" result.append(argument) return result def create_definition(name, method) -> str: - definition = f"m.def(\"{name}\"" + definition = f'm.def("{name}"' definition += ", " - definition += f"""&DuckDBPyConnection::{method['function']}""" + definition += f"""&DuckDBPyConnection::{method["function"]}""" definition += ", " - definition += f"\"{method['docs']}\"" - if 'args' in method and not is_py_args(method): + definition += f'"{method["docs"]}"' + if "args" in method and not is_py_args(method): definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: + arguments = create_arguments(method["args"]) + definition += ", ".join(arguments) + if "kwargs" in method: definition += ", " if is_py_kwargs(method): definition += "py::kw_only()" else: definition += "py::kw_only(), " - arguments = create_arguments(method['kwargs']) - definition += ', '.join(arguments) + arguments = create_arguments(method["kwargs"]) + definition += ", ".join(arguments) definition += ");" return definition body = [] for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] + names = [method["name"]] for name in names: body.append(create_definition(name, method)) # ---- End of generation code ---- - with_newlines = ['\t' + x + '\n' for x in body] + with_newlines = ["\t" + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, 'w') as source_file: + with open(PYCONNECTION_SOURCE, "w") as source_file: source_file.write("".join(new_content)) -if __name__ == '__main__': +if __name__ == "__main__": raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") # generate() diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index fbb66c21..e3831173 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -12,7 +12,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'r') as source_file: + with open(DUCKDB_STUBS_FILE, "r") as source_file: source_code = source_file.readlines() start_index = -1 @@ -35,7 +35,7 @@ def generate(): # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) body = [] @@ -45,8 +45,8 @@ def create_arguments(arguments) -> list: for arg in arguments: argument = f"{arg['name']}: {arg['type']}" # Add the default argument if present - if 'default' in arg: - default = arg['default'] + if "default" in arg: + default = arg["default"] argument += f" = {default}" result.append(argument) return result @@ -57,13 +57,13 @@ def create_definition(name, method, overloaded: bool) -> str: else: definition: str = "" definition += f"def {name}(" - arguments = ['self'] - if 'args' in method: - arguments.extend(create_arguments(method['args'])) - if 'kwargs' in method: - if not any(x.startswith('*') for x in arguments): + arguments = ["self"] + if "args" in method: + arguments.extend(create_arguments(method["args"])) + if "kwargs" in method: + if not any(x.startswith("*") for x in arguments): arguments.append("*") - arguments.extend(create_arguments(method['kwargs'])) + arguments.extend(create_arguments(method["kwargs"])) definition += ", ".join(arguments) definition += ")" definition += f" -> {method['return']}: ..." @@ -71,28 +71,28 @@ def create_definition(name, method, overloaded: bool) -> str: # We have "duplicate" methods, which are overloaded. # We keep note of them to add the @overload decorator. - overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m["name"], list)} for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] + names = [method["name"]] for name in names: body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- - with_newlines = [' ' + x + '\n' for x in body] + with_newlines = [" " + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'w') as source_file: + with open(DUCKDB_STUBS_FILE, "w") as source_file: source_file.write("".join(new_content)) -if __name__ == '__main__': +if __name__ == "__main__": raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") # generate() diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index af5ad4ac..45ac45cc 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -40,16 +40,16 @@ INIT_PY_END = "# END OF CONNECTION WRAPPER" # Read the JSON file -with open(WRAPPER_JSON_PATH, 'r') as json_file: +with open(WRAPPER_JSON_PATH, "r") as json_file: wrapper_methods = json.load(json_file) # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) -READONLY_PROPERTY_NAMES = ['description', 'rowcount'] +READONLY_PROPERTY_NAMES = ["description", "rowcount"] # These methods are not directly DuckDBPyConnection methods, # they first call 'FromDF' and then call a method on the created DuckDBPyRelation -SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] +SPECIAL_METHOD_NAMES = [x["name"] for x in wrapper_methods if x["name"] not in READONLY_PROPERTY_NAMES] RETRIEVE_CONN_FROM_DICT = """auto connection_arg = kwargs.contains("conn") ? kwargs["conn"] : py::none(); auto conn = py::cast>(connection_arg); @@ -57,18 +57,18 @@ def is_py_args(method): - if 'args' not in method: + if "args" not in method: return False - args = method['args'] + args = method["args"] if len(args) == 0: return False - if args[0]['name'] != '*args': + if args[0]["name"] != "*args": return False return True def is_py_kwargs(method): - return 'kwargs_as_dict' in method and method['kwargs_as_dict'] == True + return "kwargs_as_dict" in method and method["kwargs_as_dict"] == True def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[str]]: @@ -94,33 +94,33 @@ def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[s def generate(): # Read the DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, 'r') as source_file: + with open(DUCKDB_PYTHON_SOURCE, "r") as source_file: source_code = source_file.readlines() start_section, end_section = remove_section(source_code, START_MARKER, END_MARKER) # Read the DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, 'r') as source_file: + with open(DUCKDB_INIT_FILE, "r") as source_file: source_code = source_file.readlines() py_start, py_end = remove_section(source_code, INIT_PY_START, INIT_PY_END) # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) # Collect the definitions from the pyconnection.hpp header - cpp_connection_defs = get_methods('DuckDBPyConnection') - cpp_relation_defs = get_methods('DuckDBPyRelation') + cpp_connection_defs = get_methods("DuckDBPyConnection") + cpp_relation_defs = get_methods("DuckDBPyRelation") DEFAULT_ARGUMENT_MAP = { - 'True': 'true', - 'False': 'false', - 'None': 'py::none()', - 'PythonUDFType.NATIVE': 'PythonUDFType::NATIVE', - 'PythonExceptionHandling.DEFAULT': 'PythonExceptionHandling::FORWARD_ERROR', - 'FunctionNullHandling.DEFAULT': 'FunctionNullHandling::DEFAULT_NULL_HANDLING', + "True": "true", + "False": "false", + "None": "py::none()", + "PythonUDFType.NATIVE": "PythonUDFType::NATIVE", + "PythonExceptionHandling.DEFAULT": "PythonExceptionHandling::FORWARD_ERROR", + "FunctionNullHandling.DEFAULT": "FunctionNullHandling::DEFAULT_NULL_HANDLING", } def map_default(val): @@ -131,16 +131,16 @@ def map_default(val): def create_arguments(arguments) -> list: result = [] for arg in arguments: - if arg['name'] == '*args': + if arg["name"] == "*args": # py::args() should not have a corresponding py::arg() continue - argument = f"py::arg(\"{arg['name']}\")" - if 'allow_none' in arg: - value = str(arg['allow_none']).lower() + argument = f'py::arg("{arg["name"]}")' + if "allow_none" in arg: + value = str(arg["allow_none"]).lower() argument += f".none({value})" # Add the default argument if present - if 'default' in arg: - default = map_default(arg['default']) + if "default" in arg: + default = map_default(arg["default"]) argument += f" = {default}" result.append(argument) return result @@ -148,11 +148,11 @@ def create_arguments(arguments) -> list: def get_lambda_definition(name, method, definition: ConnectionMethod) -> str: param_definitions = [] if name in SPECIAL_METHOD_NAMES: - param_definitions.append('const PandasDataFrame &df') + param_definitions.append("const PandasDataFrame &df") param_definitions.extend([x.proto for x in definition.params]) if not is_py_kwargs(method): - param_definitions.append('shared_ptr conn = nullptr') + param_definitions.append("shared_ptr conn = nullptr") param_definitions = ", ".join(param_definitions) param_names = [x.name for x in definition.params] @@ -160,73 +160,73 @@ def get_lambda_definition(name, method, definition: ConnectionMethod) -> str: function_name = definition.name if name in SPECIAL_METHOD_NAMES: - function_name = 'FromDF(df)->' + function_name + function_name = "FromDF(df)->" + function_name format_dict = { - 'param_definitions': param_definitions, - 'opt_retrieval': '', - 'opt_return': '' if definition.is_void else 'return ', - 'function_name': function_name, - 'parameter_names': param_names, + "param_definitions": param_definitions, + "opt_retrieval": "", + "opt_return": "" if definition.is_void else "return ", + "function_name": function_name, + "parameter_names": param_names, } if is_py_kwargs(method): - format_dict['opt_retrieval'] += RETRIEVE_CONN_FROM_DICT + format_dict["opt_retrieval"] += RETRIEVE_CONN_FROM_DICT return LAMBDA_FORMAT.format_map(format_dict) def create_definition(name, method, lambda_def) -> str: - definition = f"m.def(\"{name}\"" + definition = f'm.def("{name}"' definition += ", " definition += lambda_def definition += ", " - definition += f"\"{method['docs']}\"" - if 'args' in method and not is_py_args(method): + definition += f'"{method["docs"]}"' + if "args" in method and not is_py_args(method): definition += ", " - arguments = create_arguments(method['args']) - definition += ', '.join(arguments) - if 'kwargs' in method: + arguments = create_arguments(method["args"]) + definition += ", ".join(arguments) + if "kwargs" in method: definition += ", " if is_py_kwargs(method): definition += "py::kw_only()" else: definition += "py::kw_only(), " - arguments = create_arguments(method['kwargs']) - definition += ', '.join(arguments) + arguments = create_arguments(method["kwargs"]) + definition += ", ".join(arguments) definition += ");" return definition body = [] all_names = [] for method in connection_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'Optional[DuckDBPyConnection]', 'default': 'None'}) + names = [method["name"]] + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "Optional[DuckDBPyConnection]", "default": "None"}) for name in names: - function_name = method['function'] + function_name = method["function"] cpp_definition = cpp_connection_defs[function_name] lambda_def = get_lambda_definition(name, method, cpp_definition) body.append(create_definition(name, method, lambda_def)) all_names.append(name) for method in wrapper_methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'Optional[DuckDBPyConnection]', 'default': 'None'}) + names = [method["name"]] + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "Optional[DuckDBPyConnection]", "default": "None"}) for name in names: - function_name = method['function'] + function_name = method["function"] if name in SPECIAL_METHOD_NAMES: cpp_definition = cpp_relation_defs[function_name] - if 'args' not in method: - method['args'] = [] - method['args'].insert(0, {'name': 'df', 'type': 'DataFrame'}) + if "args" not in method: + method["args"] = [] + method["args"].insert(0, {"name": "df", "type": "DataFrame"}) else: cpp_definition = cpp_connection_defs[function_name] lambda_def = get_lambda_definition(name, method, cpp_definition) @@ -235,24 +235,24 @@ def create_definition(name, method, lambda_def) -> str: # ---- End of generation code ---- - with_newlines = ['\t' + x + '\n' for x in body] + with_newlines = ["\t" + x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, 'w') as source_file: + with open(DUCKDB_PYTHON_SOURCE, "w") as source_file: source_file.write("".join(new_content)) - item_list = '\n'.join([f'\t{name},' for name in all_names]) - str_item_list = '\n'.join([f"\t'{name}'," for name in all_names]) - imports = PY_INIT_FORMAT.format(item_list=item_list, str_item_list=str_item_list).split('\n') - imports = [x + '\n' for x in imports] + item_list = "\n".join([f"\t{name}," for name in all_names]) + str_item_list = "\n".join([f"\t'{name}'," for name in all_names]) + imports = PY_INIT_FORMAT.format(item_list=item_list, str_item_list=str_item_list).split("\n") + imports = [x + "\n" for x in imports] init_py_content = py_start + imports + py_end # Write out the modified DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, 'w') as source_file: + with open(DUCKDB_INIT_FILE, "w") as source_file: source_file.write("".join(init_py_content)) -if __name__ == '__main__': +if __name__ == "__main__": # raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") generate() diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 62c60a84..02e36c4e 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -13,7 +13,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'r') as source_file: + with open(DUCKDB_STUBS_FILE, "r") as source_file: source_code = source_file.readlines() start_index = -1 @@ -38,10 +38,10 @@ def generate(): methods = [] # Read the JSON file - with open(JSON_PATH, 'r') as json_file: + with open(JSON_PATH, "r") as json_file: connection_methods = json.load(json_file) - with open(WRAPPER_JSON_PATH, 'r') as json_file: + with open(WRAPPER_JSON_PATH, "r") as json_file: wrapper_methods = json.load(json_file) methods.extend(connection_methods) @@ -49,19 +49,19 @@ def generate(): # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke # that's not possible on 'duckdb' so it becomes a function call with no arguments (i.e duckdb.description()) - READONLY_PROPERTY_NAMES = ['description', 'rowcount'] + READONLY_PROPERTY_NAMES = ["description", "rowcount"] # These methods are not directly DuckDBPyConnection methods, # they first call 'from_df' and then call a method on the created DuckDBPyRelation - SPECIAL_METHOD_NAMES = [x['name'] for x in wrapper_methods if x['name'] not in READONLY_PROPERTY_NAMES] + SPECIAL_METHOD_NAMES = [x["name"] for x in wrapper_methods if x["name"] not in READONLY_PROPERTY_NAMES] def create_arguments(arguments) -> list: result = [] for arg in arguments: argument = f"{arg['name']}: {arg['type']}" # Add the default argument if present - if 'default' in arg: - default = arg['default'] + if "default" in arg: + default = arg["default"] argument += f" = {default}" result.append(argument) return result @@ -74,49 +74,49 @@ def create_definition(name, method, overloaded: bool) -> str: definition += f"def {name}(" arguments = [] if name in SPECIAL_METHOD_NAMES: - arguments.append('df: pandas.DataFrame') - if 'args' in method: - arguments.extend(create_arguments(method['args'])) - if 'kwargs' in method: - if not any(x.startswith('*') for x in arguments): + arguments.append("df: pandas.DataFrame") + if "args" in method: + arguments.extend(create_arguments(method["args"])) + if "kwargs" in method: + if not any(x.startswith("*") for x in arguments): arguments.append("*") - arguments.extend(create_arguments(method['kwargs'])) - definition += ', '.join(arguments) + arguments.extend(create_arguments(method["kwargs"])) + definition += ", ".join(arguments) definition += ")" definition += f" -> {method['return']}: ..." return definition # We have "duplicate" methods, which are overloaded. # We keep note of them to add the @overload decorator. - overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m["name"], list)} body = [] for method in methods: - if isinstance(method['name'], list): - names = method['name'] + if isinstance(method["name"], list): + names = method["name"] else: - names = [method['name']] + names = [method["name"]] # Artificially add 'connection' keyword argument - if 'kwargs' not in method: - method['kwargs'] = [] - method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'}) + if "kwargs" not in method: + method["kwargs"] = [] + method["kwargs"].append({"name": "connection", "type": "DuckDBPyConnection", "default": "..."}) for name in names: body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- - with_newlines = [x + '\n' for x in body] + with_newlines = [x + "\n" for x in body] # Recreate the file content by concatenating all the pieces together new_content = start_section + with_newlines + end_section # Write out the modified DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, 'w') as source_file: + with open(DUCKDB_STUBS_FILE, "w") as source_file: source_file.write("".join(new_content)) -if __name__ == '__main__': +if __name__ == "__main__": raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") # generate() diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index f03d8d89..8a4b0c36 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -16,97 +16,97 @@ # deal with leaf nodes?? Those are just PythonImportCacheItem def get_class_name(path: str) -> str: - parts: list[str] = path.replace('_', '').split('.') + parts: list[str] = path.replace("_", "").split(".") parts = [x.title() for x in parts] - return ''.join(parts) + 'CacheItem' + return "".join(parts) + "CacheItem" def get_filename(name: str) -> str: - return name.replace('_', '').lower() + '_module.hpp' + return name.replace("_", "").lower() + "_module.hpp" def get_variable_name(name: str) -> str: - if name in ['short', 'ushort']: - return name + '_' + if name in ["short", "ushort"]: + return name + "_" return name def collect_items_of_module(module: dict, collection: dict): global json_data - children = module['children'] - collection[module['full_path']] = module + children = module["children"] + collection[module["full_path"]] = module for child in children: collect_items_of_module(json_data[child], collection) class CacheItem: def __init__(self, module: dict, items) -> None: - self.name = module['name'] + self.name = module["name"] self.module = module self.items = items - self.class_name = get_class_name(module['full_path']) + self.class_name = get_class_name(module["full_path"]) def get_full_module_path(self): - if self.module['type'] != 'module': - return '' - full_path = self.module['full_path'] + if self.module["type"] != "module": + return "" + full_path = self.module["full_path"] return f""" public: \tstatic constexpr const char *Name = "{full_path}"; """ def get_optionally_required(self): - if 'required' not in self.module: - return '' + if "required" not in self.module: + return "" string = f""" protected: \tbool IsRequired() const override final {{ -\t\treturn {str(self.module['required']).lower()}; +\t\treturn {str(self.module["required"]).lower()}; \t}} """ return string def get_variables(self): variables = [] - for key in self.module['children']: + for key in self.module["children"]: item = self.items[key] - name = item['name'] + name = item["name"] var_name = get_variable_name(name) - if item['children'] == []: - class_name = 'PythonImportCacheItem' + if item["children"] == []: + class_name = "PythonImportCacheItem" else: - class_name = get_class_name(item['full_path']) - variables.append(f'\t{class_name} {var_name};') - return '\n'.join(variables) + class_name = get_class_name(item["full_path"]) + variables.append(f"\t{class_name} {var_name};") + return "\n".join(variables) def get_initializer(self): variables = [] - for key in self.module['children']: + for key in self.module["children"]: item = self.items[key] - name = item['name'] + name = item["name"] var_name = get_variable_name(name) - if item['children'] == []: + if item["children"] == []: initialization = f'{var_name}("{name}", this)' variables.append(initialization) else: - if item['type'] == 'module': - arguments = '' + if item["type"] == "module": + arguments = "" else: - arguments = 'this' - initialization = f'{var_name}({arguments})' + arguments = "this" + initialization = f"{var_name}({arguments})" variables.append(initialization) - if self.module['type'] != 'module': + if self.module["type"] != "module": constructor_params = f'"{self.name}"' - constructor_params += ', parent' + constructor_params += ", parent" else: - full_path = self.module['full_path'] + full_path = self.module["full_path"] constructor_params = f'"{full_path}"' - return f'PythonImportCacheItem({constructor_params}), ' + ', '.join(variables) + '{}' + return f"PythonImportCacheItem({constructor_params}), " + ", ".join(variables) + "{}" def get_constructor(self): - if self.module['type'] == 'module': - return f'{self.class_name}()' - return f'{self.class_name}(optional_ptr parent)' + if self.module["type"] == "module": + return f"{self.class_name}()" + return f"{self.class_name}(optional_ptr parent)" def to_string(self): return f""" @@ -125,7 +125,7 @@ def to_string(self): def collect_classes(items: dict) -> list: output: list = [] for item in items.values(): - if item['children'] == []: + if item["children"] == []: continue output.append(CacheItem(item, items)) return output @@ -134,7 +134,7 @@ def collect_classes(items: dict) -> list: class ModuleFile: def __init__(self, module: dict) -> None: self.module = module - self.file_name = get_filename(module['name']) + self.file_name = get_filename(module["name"]) self.items = {} collect_items_of_module(module, self.items) self.classes = collect_classes(self.items) @@ -144,7 +144,7 @@ def get_classes(self): classes = [] for item in self.classes: classes.append(item.to_string()) - return ''.join(classes) + return "".join(classes) def to_string(self): string = f""" @@ -176,13 +176,13 @@ def to_string(self): files: list[ModuleFile] = [] for name, value in json_data.items(): - if value['full_path'] != value['name']: + if value["full_path"] != value["name"]: continue files.append(ModuleFile(value)) for file in files: content = file.to_string() - path = f'src/duckdb_py/include/duckdb_python/import_cache/modules/{file.file_name}' + path = f"src/duckdb_py/include/duckdb_python/import_cache/modules/{file.file_name}" import_cache_path = os.path.join(script_dir, '..', path) with open(import_cache_path, "w") as f: f.write(content) @@ -191,10 +191,10 @@ def to_string(self): def get_root_modules(files: list[ModuleFile]): modules = [] for file in files: - name = file.module['name'] + name = file.module["name"] class_name = get_class_name(name) - modules.append(f'\t{class_name} {name};') - return '\n'.join(modules) + modules.append(f"\t{class_name} {name};") + return "\n".join(modules) # Generate the python_import_cache.hpp file @@ -237,9 +237,7 @@ def get_root_modules(files: list[ModuleFile]): """ -import_cache_path = os.path.join( - script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache.hpp' -) +import_cache_path = os.path.join(script_dir, "..", "src/duckdb_py/include/duckdb_python/import_cache/python_import_cache.hpp") with open(import_cache_path, "w") as f: f.write(import_cache_file) @@ -248,13 +246,13 @@ def get_module_file_path_includes(files: list[ModuleFile]): includes = [] for file in files: includes.append(f'#include "duckdb_python/import_cache/modules/{file.file_name}"') - return '\n'.join(includes) + return "\n".join(includes) module_includes = get_module_file_path_includes(files) modules_header = os.path.join( - script_dir, '..', 'src/duckdb_py/include/duckdb_python/import_cache/python_import_cache_modules.hpp' + script_dir, "..", "src/duckdb_py/include/duckdb_python/import_cache/python_import_cache_modules.hpp" ) with open(modules_header, "w") as f: f.write(module_includes) diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 2df33b24..099db841 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -4,12 +4,12 @@ from typing import List, Dict, Union import json -lines: list[str] = [file for file in open(f'{script_dir}/imports.py').read().split('\n') if file != ''] +lines: list[str] = [file for file in open(f"{script_dir}/imports.py").read().split("\n") if file != ""] class ImportCacheAttribute: def __init__(self, full_path: str) -> None: - parts = full_path.split('.') + parts = full_path.split(".") self.type = "attribute" self.name = parts[-1] self.full_path = full_path @@ -42,7 +42,7 @@ def populate_json(self, json_data: dict): class ImportCacheModule: def __init__(self, full_path) -> None: - parts = full_path.split('.') + parts = full_path.split(".") self.type = "module" self.name = parts[-1] self.full_path = full_path @@ -82,27 +82,27 @@ def __init__(self) -> None: self.modules: dict[str, ImportCacheModule] = {} def add_module(self, path: str): - assert path.startswith('import') + assert path.startswith("import") path = path[7:] module = ImportCacheModule(path) self.modules[module.full_path] = module # Add it to the parent module if present - parts = path.split('.') + parts = path.split(".") if len(parts) == 1: return # This works back from the furthest child module to the top level module child_module = module for i in range(1, len(parts)): - parent_path = '.'.join(parts[: len(parts) - i]) + parent_path = ".".join(parts[: len(parts) - i]) parent_module = self.add_or_get_module(parent_path) parent_module.add_item(child_module) child_module = parent_module def add_or_get_module(self, module_name: str) -> ImportCacheModule: if module_name not in self.modules: - self.add_module(f'import {module_name}') + self.add_module(f"import {module_name}") return self.get_module(module_name) def get_module(self, module_name: str) -> ImportCacheModule: @@ -111,13 +111,13 @@ def get_module(self, module_name: str) -> ImportCacheModule: return self.modules[module_name] def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttribute]: - parts = item_name.split('.') + parts = item_name.split(".") if len(parts) == 1: return self.get_module(item_name) parent = self.get_module(parts[0]) for i in range(1, len(parts)): - child_path = '.'.join(parts[: i + 1]) + child_path = ".".join(parts[: i + 1]) if parent.has_item(child_path): parent = parent.get_item(child_path) else: @@ -127,8 +127,8 @@ def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttrib return parent def add_attribute(self, path: str): - assert not path.startswith('import') - parts = path.split('.') + assert not path.startswith("import") + parts = path.split(".") assert len(parts) >= 2 self.get_item(path) @@ -145,9 +145,9 @@ def to_json(self): generator = ImportCacheGenerator() for line in lines: - if line.startswith('#'): + if line.startswith("#"): continue - if line.startswith('import'): + if line.startswith("import"): generator.add_module(line) else: generator.add_attribute(line) diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 64ad8edc..b8d913ea 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -10,7 +10,7 @@ SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" SQLLOGIC_TEST_PARAMETER = "test_script_path" -DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / 'external' / 'duckdb').resolve() +DUCKDB_ROOT_DIR = (pathlib.Path(__file__).parent.parent / "external" / "duckdb").resolve() def pytest_addoption(parser: pytest.Parser): @@ -65,8 +65,8 @@ def pytest_keyboard_interrupt(excinfo: pytest.ExceptionInfo): # Ensure all tests are properly cleaned up on keyboard interrupt from .test_sqllogic import test_sqllogic - if hasattr(test_sqllogic, 'executor') and test_sqllogic.executor: - if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, 'connection'): + if hasattr(test_sqllogic, "executor") and test_sqllogic.executor: + if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, "connection"): test_sqllogic.executor.database.connection.interrupt() test_sqllogic.executor.cleanup() test_sqllogic.executor = None diff --git a/sqllogic/skipped_tests.py b/sqllogic/skipped_tests.py index 39269c42..485ed9b9 100644 --- a/sqllogic/skipped_tests.py +++ b/sqllogic/skipped_tests.py @@ -1,42 +1,42 @@ SKIPPED_TESTS = set( [ - 'test/sql/timezone/disable_timestamptz_casts.test', # <-- ICU extension is always loaded - 'test/sql/copy/return_stats_truncate.test', # <-- handling was changed - 'test/sql/copy/return_stats.test', # <-- handling was changed - 'test/sql/copy/parquet/writer/skip_empty_write.test', # <-- handling was changed - 'test/sql/types/map/map_empty.test', - 'test/extension/wrong_function_type.test', # <-- JSON is always loaded - 'test/sql/insert/test_insert_invalid.test', # <-- doesn't parse properly - 'test/sql/cast/cast_error_location.test', # <-- python exception doesn't contain error location yet - 'test/sql/pragma/test_query_log.test', # <-- query_log gets filled with NULL when con.query(...) is used - 'test/sql/json/table/read_json_objects.test', # <-- Python client is always loaded with JSON available - 'test/sql/copy/csv/zstd_crash.test', # <-- Python client is always loaded with Parquet available - 'test/sql/error/extension_function_error.test', # <-- Python client is always loaded with TPCH available - 'test/optimizer/joins/tpcds_nofail.test', # <-- Python client is always loaded with TPCDS available - 'test/sql/settings/errors_as_json.test', # <-- errors_as_json not currently supported in Python - 'test/sql/parallelism/intraquery/depth_first_evaluation_union_and_join.test', # <-- Python client is always loaded with TPCDS available - 'test/sql/types/timestamp/test_timestamp_tz.test', # <-- Python client is always loaded wih ICU available - making the TIMESTAMPTZ::DATE cast pass - 'test/sql/parser/invisible_spaces.test', # <-- Parser is getting tripped up on the invisible spaces - 'test/sql/copy/csv/code_cov/csv_state_machine_invalid_utf.test', # <-- ConversionException is empty, see Python Mega Issue (duckdb-internal #1488) - 'test/sql/copy/csv/test_csv_timestamp_tz.test', # <-- ICU is always loaded - 'test/fuzzer/duckfuzz/duck_fuzz_column_binding_tests.test', # <-- ICU is always loaded - 'test/sql/pragma/test_custom_optimizer_profiling.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement - 'test/sql/pragma/test_custom_profiling_settings.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement - 'test/sql/copy/csv/test_copy.test', # JSON is always loaded - 'test/sql/copy/csv/test_timestamptz_12926.test', # ICU is always loaded - 'test/fuzzer/pedro/in_clause_optimization_error.test', # error message differs due to a different execution path - 'test/sql/order/test_limit_parameter.test', # error message differs due to a different execution path - 'test/sql/catalog/test_set_search_path.test', # current_query() is not the same - 'test/sql/catalog/table/create_table_parameters.test', # prepared statement error quirks - 'test/sql/pragma/profiling/test_custom_profiling_rows_scanned.test', # we perform additional queries that mess with the expected metrics - 'test/sql/pragma/profiling/test_custom_profiling_disable_metrics.test', # we perform additional queries that mess with the expected metrics - 'test/sql/pragma/profiling/test_custom_profiling_result_set_size.test', # we perform additional queries that mess with the expected metrics - 'test/sql/pragma/profiling/test_custom_profiling_result_set_size.test', # we perform additional queries that mess with the expected metrics - 'test/sql/cte/materialized/materialized_cte_modifiers.test', # problems connected to auto installing tpcds from remote - 'test/sql/tpcds/dsdgen_readonly.test', # problems connected to auto installing tpcds from remote - 'test/sql/tpcds/tpcds_sf0.test', # problems connected to auto installing tpcds from remote - 'test/sql/optimizer/plan/test_filter_pushdown_materialized_cte.test', # problems connected to auto installing tpcds from remote - 'test/sql/explain/test_explain_analyze.test', # unknown problem with changes in API - 'test/sql/pragma/profiling/test_profiling_all.test', # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/timezone/disable_timestamptz_casts.test", # <-- ICU extension is always loaded + "test/sql/copy/return_stats_truncate.test", # <-- handling was changed + "test/sql/copy/return_stats.test", # <-- handling was changed + "test/sql/copy/parquet/writer/skip_empty_write.test", # <-- handling was changed + "test/sql/types/map/map_empty.test", + "test/extension/wrong_function_type.test", # <-- JSON is always loaded + "test/sql/insert/test_insert_invalid.test", # <-- doesn't parse properly + "test/sql/cast/cast_error_location.test", # <-- python exception doesn't contain error location yet + "test/sql/pragma/test_query_log.test", # <-- query_log gets filled with NULL when con.query(...) is used + "test/sql/json/table/read_json_objects.test", # <-- Python client is always loaded with JSON available + "test/sql/copy/csv/zstd_crash.test", # <-- Python client is always loaded with Parquet available + "test/sql/error/extension_function_error.test", # <-- Python client is always loaded with TPCH available + "test/optimizer/joins/tpcds_nofail.test", # <-- Python client is always loaded with TPCDS available + "test/sql/settings/errors_as_json.test", # <-- errors_as_json not currently supported in Python + "test/sql/parallelism/intraquery/depth_first_evaluation_union_and_join.test", # <-- Python client is always loaded with TPCDS available + "test/sql/types/timestamp/test_timestamp_tz.test", # <-- Python client is always loaded wih ICU available - making the TIMESTAMPTZ::DATE cast pass + "test/sql/parser/invisible_spaces.test", # <-- Parser is getting tripped up on the invisible spaces + "test/sql/copy/csv/code_cov/csv_state_machine_invalid_utf.test", # <-- ConversionException is empty, see Python Mega Issue (duckdb-internal #1488) + "test/sql/copy/csv/test_csv_timestamp_tz.test", # <-- ICU is always loaded + "test/fuzzer/duckfuzz/duck_fuzz_column_binding_tests.test", # <-- ICU is always loaded + "test/sql/pragma/test_custom_optimizer_profiling.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/pragma/test_custom_profiling_settings.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement + "test/sql/copy/csv/test_copy.test", # JSON is always loaded + "test/sql/copy/csv/test_timestamptz_12926.test", # ICU is always loaded + "test/fuzzer/pedro/in_clause_optimization_error.test", # error message differs due to a different execution path + "test/sql/order/test_limit_parameter.test", # error message differs due to a different execution path + "test/sql/catalog/test_set_search_path.test", # current_query() is not the same + "test/sql/catalog/table/create_table_parameters.test", # prepared statement error quirks + "test/sql/pragma/profiling/test_custom_profiling_rows_scanned.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_disable_metrics.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_result_set_size.test", # we perform additional queries that mess with the expected metrics + "test/sql/pragma/profiling/test_custom_profiling_result_set_size.test", # we perform additional queries that mess with the expected metrics + "test/sql/cte/materialized/materialized_cte_modifiers.test", # problems connected to auto installing tpcds from remote + "test/sql/tpcds/dsdgen_readonly.test", # problems connected to auto installing tpcds from remote + "test/sql/tpcds/tpcds_sf0.test", # problems connected to auto installing tpcds from remote + "test/sql/optimizer/plan/test_filter_pushdown_materialized_cte.test", # problems connected to auto installing tpcds from remote + "test/sql/explain/test_explain_analyze.test", # unknown problem with changes in API + "test/sql/pragma/profiling/test_profiling_all.test", # Because of logic related to enabling 'restart' statement capabilities, this will not measure the right statement ] ) diff --git a/sqllogic/test_sqllogic.py b/sqllogic/test_sqllogic.py index 4e7cead0..6f55e931 100644 --- a/sqllogic/test_sqllogic.py +++ b/sqllogic/test_sqllogic.py @@ -6,7 +6,7 @@ import sys from typing import Any, Generator, Optional -sys.path.append(str(pathlib.Path(__file__).parent.parent / 'external' / 'duckdb' / 'scripts')) +sys.path.append(str(pathlib.Path(__file__).parent.parent / "external" / "duckdb" / "scripts")) from sqllogictest import ( SQLParserException, SQLLogicParser, @@ -24,8 +24,8 @@ def sigquit_handler(signum, frame): # Access the executor from the test_sqllogic function - if hasattr(test_sqllogic, 'executor') and test_sqllogic.executor: - if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, 'connection'): + if hasattr(test_sqllogic, "executor") and test_sqllogic.executor: + if test_sqllogic.executor.database and hasattr(test_sqllogic.executor.database, "connection"): test_sqllogic.executor.database.connection.interrupt() test_sqllogic.executor.cleanup() test_sqllogic.executor = None @@ -85,13 +85,13 @@ def execute_test(self, test: SQLLogicTest) -> ExecuteResult: self.original_sqlite_test = self.test.is_sqlite_test() # Top level keywords - keywords = {'__TEST_DIR__': self.get_test_directory(), '__WORKING_DIRECTORY__': os.getcwd()} + keywords = {"__TEST_DIR__": self.get_test_directory(), "__WORKING_DIRECTORY__": os.getcwd()} def update_value(_: SQLLogicContext) -> Generator[Any, Any, Any]: # Yield once to represent one iteration, do not touch the keywords yield None - self.database = SQLLogicDatabase(':memory:', None) + self.database = SQLLogicDatabase(":memory:", None) pool = self.database.connect() context = SQLLogicContext(pool, self, test.statements, keywords, update_value) pool.initialize_connection(context, pool.get_connection()) @@ -126,7 +126,7 @@ def update_value(_: SQLLogicContext) -> Generator[Any, Any, Any]: def cleanup(self): if self.database: - if hasattr(self.database, 'connection'): + if hasattr(self.database, "connection"): self.database.connection.interrupt() self.database.reset() self.database = None @@ -160,6 +160,6 @@ def test_sqllogic(test_script_path: pathlib.Path, pytestconfig: pytest.Config, t test_sqllogic.executor = None -if __name__ == '__main__': +if __name__ == "__main__": # Pass all arguments including the script name to pytest sys.exit(pytest.main(sys.argv)) diff --git a/tests/conftest.py b/tests/conftest.py index b9950ee7..d69cdfce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,11 +11,11 @@ try: # need to ignore warnings that might be thrown deep inside pandas's import tree (from dateutil in this case) - warnings.simplefilter(action='ignore', category=DeprecationWarning) - pandas = import_module('pandas') + warnings.simplefilter(action="ignore", category=DeprecationWarning) + pandas = import_module("pandas") warnings.resetwarnings() - pyarrow_dtype = getattr(pandas, 'ArrowDtype', None) + pyarrow_dtype = getattr(pandas, "ArrowDtype", None) except ImportError: pandas = None pyarrow_dtype = None @@ -65,7 +65,7 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope="function") def duckdb_empty_cursor(request): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() return cursor @@ -99,7 +99,7 @@ def makeTimeSeries(nper=None, freq: Frequency = "B", name=None) -> Series: def pandas_2_or_higher(): from packaging.version import Version - return Version(import_pandas().__version__) >= Version('2.0.0') + return Version(import_pandas().__version__) >= Version("2.0.0") def pandas_supports_arrow_backend(): @@ -124,7 +124,7 @@ def arrow_pandas_df(*args, **kwargs): class NumpyPandas: def __init__(self) -> None: - self.backend = 'numpy_nullable' + self.backend = "numpy_nullable" self.DataFrame = numpy_pandas_df self.pandas = import_pandas() @@ -173,11 +173,11 @@ class ArrowPandas: def __init__(self) -> None: self.pandas = import_pandas() if pandas_2_or_higher() and pyarrow_dtypes_enabled: - self.backend = 'pyarrow' + self.backend = "pyarrow" self.DataFrame = arrow_pandas_df else: # For backwards compatible reasons, just mock regular pandas - self.backend = 'numpy_nullable' + self.backend = "numpy_nullable" self.DataFrame = self.pandas.DataFrame self.testing = ArrowMockTesting() @@ -187,7 +187,7 @@ def __getattr__(self, name: str) -> Any: @pytest.fixture(scope="function") def require(): - def _require(extension_name, db_name=''): + def _require(extension_name, db_name=""): # Paths to search for extensions build = normpath(join(dirname(__file__), "../../../build/")) @@ -199,11 +199,11 @@ def _require(extension_name, db_name=''): ] # DUCKDB_PYTHON_TEST_EXTENSION_PATH can be used to add a path for the extension test to search for extensions - if 'DUCKDB_PYTHON_TEST_EXTENSION_PATH' in os.environ: - env_extension_path = os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_PATH') - env_extension_path = env_extension_path.rstrip('/') - extension_search_patterns.append(env_extension_path + '/*/*.duckdb_extension') - extension_search_patterns.append(env_extension_path + '/*.duckdb_extension') + if "DUCKDB_PYTHON_TEST_EXTENSION_PATH" in os.environ: + env_extension_path = os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_PATH") + env_extension_path = env_extension_path.rstrip("/") + extension_search_patterns.append(env_extension_path + "/*/*.duckdb_extension") + extension_search_patterns.append(env_extension_path + "/*.duckdb_extension") extension_paths_found = [] for pattern in extension_search_patterns: @@ -215,39 +215,39 @@ def _require(extension_name, db_name=''): for path in extension_paths_found: print(path) if path.endswith(extension_name + ".duckdb_extension"): - conn = duckdb.connect(db_name, config={'allow_unsigned_extensions': 'true'}) + conn = duckdb.connect(db_name, config={"allow_unsigned_extensions": "true"}) conn.execute(f"LOAD '{path}'") return conn - pytest.skip(f'could not load {extension_name}') + pytest.skip(f"could not load {extension_name}") return _require # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def spark(): from spark_namespace import USE_ACTUAL_SPARK - if not hasattr(spark, 'session'): + if not hasattr(spark, "session"): # Cache the import from spark_namespace.sql import SparkSession as session spark.session = session - return spark.session.builder.appName('pyspark').getOrCreate() + return spark.session.builder.appName("pyspark").getOrCreate() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def duckdb_cursor(): - connection = duckdb.connect('') + connection = duckdb.connect("") yield connection connection.close() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def integers(duckdb_cursor): cursor = duckdb_cursor - cursor.execute('CREATE TABLE integers (i integer)') + cursor.execute("CREATE TABLE integers (i integer)") cursor.execute( """ INSERT INTO integers VALUES @@ -268,10 +268,10 @@ def integers(duckdb_cursor): cursor.execute("drop table integers") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def timestamps(duckdb_cursor): cursor = duckdb_cursor - cursor.execute('CREATE TABLE timestamps (t timestamp)') + cursor.execute("CREATE TABLE timestamps (t timestamp)") cursor.execute("INSERT INTO timestamps VALUES ('1992-10-03 18:34:45'), ('2010-01-01 00:00:01'), (NULL)") yield cursor.execute("drop table timestamps") diff --git a/tests/coverage/test_pandas_categorical_coverage.py b/tests/coverage/test_pandas_categorical_coverage.py index e20afa72..15eee10a 100644 --- a/tests/coverage/test_pandas_categorical_coverage.py +++ b/tests/coverage/test_pandas_categorical_coverage.py @@ -15,17 +15,17 @@ def check_create_table(category, pandas): conn.execute("PRAGMA enable_verification") df_in = pandas.DataFrame( { - 'x': pandas.Categorical(category, ordered=True), - 'y': pandas.Categorical(category, ordered=True), - 'z': category, + "x": pandas.Categorical(category, ordered=True), + "y": pandas.Categorical(category, ordered=True), + "z": category, } ) - category.append('bla') + category.append("bla") df_in_diff = pandas.DataFrame( { - 'k': pandas.Categorical(category, ordered=True), + "k": pandas.Categorical(category, ordered=True), } ) @@ -44,7 +44,7 @@ def check_create_table(category, pandas): conn.execute("INSERT INTO t1 VALUES ('2','2','2')") res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall() - assert res == [('1',)] + assert res == [("1",)] res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x) order by t1.x").fetchall() assert res == conn.execute("SELECT x FROM t1 order by t1.x").fetchall() @@ -70,14 +70,14 @@ def check_create_table(category, pandas): # TODO: extend tests with ArrowPandas class TestCategory(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint16(self, duckdb_cursor, pandas): category = [] for i in range(300): category.append(str(i)) check_create_table(category, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint32(self, duckdb_cursor, pandas): category = [] for i in range(70000): diff --git a/tests/extensions/json/test_read_json.py b/tests/extensions/json/test_read_json.py index 48590175..f0fd809f 100644 --- a/tests/extensions/json/test_read_json.py +++ b/tests/extensions/json/test_read_json.py @@ -10,50 +10,50 @@ def TestFile(name): import os - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', name) + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", name) return filename class TestReadJSON(object): def test_read_json_columns(self): - rel = duckdb.read_json(TestFile('example.json'), columns={'id': 'integer', 'name': 'varchar'}) + rel = duckdb.read_json(TestFile("example.json"), columns={"id": "integer", "name": "varchar"}) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_auto(self): - rel = duckdb.read_json(TestFile('example.json')) + rel = duckdb.read_json(TestFile("example.json")) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_maximum_depth(self): - rel = duckdb.read_json(TestFile('example.json'), maximum_depth=4) + rel = duckdb.read_json(TestFile("example.json"), maximum_depth=4) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_sample_size(self): - rel = duckdb.read_json(TestFile('example.json'), sample_size=2) + rel = duckdb.read_json(TestFile("example.json"), sample_size=2) res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") def test_read_json_format(self): # Wrong option with pytest.raises(duckdb.BinderException, match="format must be one of .* not 'test'"): - rel = duckdb.read_json(TestFile('example.json'), format='test') + rel = duckdb.read_json(TestFile("example.json"), format="test") - rel = duckdb.read_json(TestFile('example.json'), format='unstructured') + rel = duckdb.read_json(TestFile("example.json"), format="unstructured") res = rel.fetchone() print(res) assert res == ( [ - {'id': 1, 'name': 'O Brother, Where Art Thou?'}, - {'id': 2, 'name': 'Home for the Holidays'}, - {'id': 3, 'name': 'The Firm'}, - {'id': 4, 'name': 'Broadcast News'}, - {'id': 5, 'name': 'Raising Arizona'}, + {"id": 1, "name": "O Brother, Where Art Thou?"}, + {"id": 2, "name": "Home for the Holidays"}, + {"id": 3, "name": "The Firm"}, + {"id": 4, "name": "Broadcast News"}, + {"id": 5, "name": "Raising Arizona"}, ], ) @@ -63,13 +63,13 @@ def test_read_filelike(self, duckdb_cursor): duckdb_cursor.execute("set threads=1") string = StringIO("""{"id":1,"name":"O Brother, Where Art Thou?"}\n{"id":2,"name":"Home for the Holidays"}""") res = duckdb_cursor.read_json(string).fetchall() - assert res == [(1, 'O Brother, Where Art Thou?'), (2, 'Home for the Holidays')] + assert res == [(1, "O Brother, Where Art Thou?"), (2, "Home for the Holidays")] string1 = StringIO("""{"id":1,"name":"O Brother, Where Art Thou?"}""") string2 = StringIO("""{"id":2,"name":"Home for the Holidays"}""") res = duckdb_cursor.read_json([string1, string2], filename=True).fetchall() - assert res[0][1] == 'O Brother, Where Art Thou?' - assert res[1][1] == 'Home for the Holidays' + assert res[0][1] == "O Brother, Where Art Thou?" + assert res[1][1] == "Home for the Holidays" # filenames are different assert res[0][2] != res[1][2] @@ -77,51 +77,51 @@ def test_read_filelike(self, duckdb_cursor): def test_read_json_records(self): # Wrong option with pytest.raises(duckdb.BinderException, match="""read_json requires "records" to be one of"""): - rel = duckdb.read_json(TestFile('example.json'), records='none') + rel = duckdb.read_json(TestFile("example.json"), records="none") - rel = duckdb.read_json(TestFile('example.json'), records='true') + rel = duckdb.read_json(TestFile("example.json"), records="true") res = rel.fetchone() print(res) - assert res == (1, 'O Brother, Where Art Thou?') + assert res == (1, "O Brother, Where Art Thou?") @pytest.mark.parametrize( - 'option', + "option", [ - ('filename', True), - ('filename', 'test'), - ('date_format', '%m-%d-%Y'), - ('date_format', '%m-%d-%y'), - ('date_format', '%d-%m-%Y'), - ('date_format', '%d-%m-%y'), - ('date_format', '%Y-%m-%d'), - ('date_format', '%y-%m-%d'), - ('timestamp_format', '%H:%M:%S%y-%m-%d'), - ('compression', 'AUTO_DETECT'), - ('compression', 'UNCOMPRESSED'), - ('maximum_object_size', 5), - ('ignore_errors', False), - ('ignore_errors', True), - ('convert_strings_to_integers', False), - ('convert_strings_to_integers', True), - ('field_appearance_threshold', 0.534), - ('map_inference_threshold', 34234), - ('maximum_sample_files', 5), - ('hive_partitioning', True), - ('hive_partitioning', False), - ('union_by_name', True), - ('union_by_name', False), - ('hive_types_autocast', False), - ('hive_types_autocast', True), - ('hive_types', {'id': 'INTEGER', 'name': 'VARCHAR'}), + ("filename", True), + ("filename", "test"), + ("date_format", "%m-%d-%Y"), + ("date_format", "%m-%d-%y"), + ("date_format", "%d-%m-%Y"), + ("date_format", "%d-%m-%y"), + ("date_format", "%Y-%m-%d"), + ("date_format", "%y-%m-%d"), + ("timestamp_format", "%H:%M:%S%y-%m-%d"), + ("compression", "AUTO_DETECT"), + ("compression", "UNCOMPRESSED"), + ("maximum_object_size", 5), + ("ignore_errors", False), + ("ignore_errors", True), + ("convert_strings_to_integers", False), + ("convert_strings_to_integers", True), + ("field_appearance_threshold", 0.534), + ("map_inference_threshold", 34234), + ("maximum_sample_files", 5), + ("hive_partitioning", True), + ("hive_partitioning", False), + ("union_by_name", True), + ("union_by_name", False), + ("hive_types_autocast", False), + ("hive_types_autocast", True), + ("hive_types", {"id": "INTEGER", "name": "VARCHAR"}), ], ) def test_read_json_options(self, duckdb_cursor, option): keyword_arguments = dict() option_name, option_value = option keyword_arguments[option_name] = option_value - if option_name == 'hive_types': - with pytest.raises(duckdb.InvalidInputException, match=r'Unknown hive_type:'): - rel = duckdb_cursor.read_json(TestFile('example.json'), **keyword_arguments) + if option_name == "hive_types": + with pytest.raises(duckdb.InvalidInputException, match=r"Unknown hive_type:"): + rel = duckdb_cursor.read_json(TestFile("example.json"), **keyword_arguments) else: - rel = duckdb_cursor.read_json(TestFile('example.json'), **keyword_arguments) + rel = duckdb_cursor.read_json(TestFile("example.json"), **keyword_arguments) res = rel.fetchall() diff --git a/tests/extensions/test_extensions_loading.py b/tests/extensions/test_extensions_loading.py index 2b4eab0c..f35366ba 100644 --- a/tests/extensions/test_extensions_loading.py +++ b/tests/extensions/test_extensions_loading.py @@ -13,9 +13,9 @@ def test_extension_loading(require): - if not os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False): + if not os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED", False): return - extensions_list = ['json', 'excel', 'httpfs', 'tpch', 'tpcds', 'icu', 'fts'] + extensions_list = ["json", "excel", "httpfs", "tpch", "tpcds", "icu", "fts"] for extension in extensions_list: connection = require(extension) assert connection is not None @@ -26,16 +26,16 @@ def test_install_non_existent_extension(): conn.execute("set custom_extension_repository = 'http://example.com'") with raises(duckdb.IOException) as exc: - conn.install_extension('non-existent') + conn.install_extension("non-existent") if not isinstance(exc, duckdb.HTTPException): - pytest.skip(reason='This test does not throw an HTTPException, only an IOException') + pytest.skip(reason="This test does not throw an HTTPException, only an IOException") value = exc.value assert value.status_code == 404 - assert value.reason == 'Not Found' - assert 'Example Domain' in value.body - assert 'Content-Length' in value.headers + assert value.reason == "Not Found" + assert "Example Domain" in value.body + assert "Content-Length" in value.headers def test_install_misuse_errors(duckdb_cursor): @@ -43,17 +43,17 @@ def test_install_misuse_errors(duckdb_cursor): duckdb.InvalidInputException, match="Both 'repository' and 'repository_url' are set which is not allowed, please pick one or the other", ): - duckdb_cursor.install_extension('name', repository='hello', repository_url='hello.com') + duckdb_cursor.install_extension("name", repository="hello", repository_url="hello.com") with pytest.raises( duckdb.InvalidInputException, match="The provided 'repository' or 'repository_url' can not be empty!" ): - duckdb_cursor.install_extension('name', repository_url='') + duckdb_cursor.install_extension("name", repository_url="") with pytest.raises( duckdb.InvalidInputException, match="The provided 'repository' or 'repository_url' can not be empty!" ): - duckdb_cursor.install_extension('name', repository='') + duckdb_cursor.install_extension("name", repository="") with pytest.raises(duckdb.InvalidInputException, match="The provided 'version' can not be empty!"): - duckdb_cursor.install_extension('name', version='') + duckdb_cursor.install_extension("name", version="") diff --git a/tests/extensions/test_httpfs.py b/tests/extensions/test_httpfs.py index 6366e07f..866491f0 100644 --- a/tests/extensions/test_httpfs.py +++ b/tests/extensions/test_httpfs.py @@ -9,33 +9,33 @@ # FIXME: we can add a custom command line argument to pytest to provide an extension directory # We can use that instead of checking this environment variable inside of conftest.py's 'require' method pytestmark = mark.skipif( - not os.getenv('DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED', False), - reason='DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED is not set', + not os.getenv("DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED", False), + reason="DUCKDB_PYTHON_TEST_EXTENSION_REQUIRED is not set", ) class TestHTTPFS(object): def test_read_json_httpfs(self, require): - connection = require('httpfs') + connection = require("httpfs") try: - res = connection.read_json('https://jsonplaceholder.typicode.com/todos') + res = connection.read_json("https://jsonplaceholder.typicode.com/todos") assert len(res.types) == 4 except duckdb.Error as e: - if '403' in e: + if "403" in e: pytest.skip(reason="Test is flaky, sometimes returns 403") else: pytest.fail(str(e)) def test_s3fs(self, require): - connection = require('httpfs') + connection = require("httpfs") rel = connection.read_csv(f"s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) res = rel.fetchone() assert res == (1, 0, datetime.date(1965, 2, 28), 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 6, 0, 0, 0, 0) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_httpfs(self, require, pandas): - connection = require('httpfs') + connection = require("httpfs") try: connection.execute( "SELECT id, first_name, last_name FROM PARQUET_SCAN('https://raw.githubusercontent.com/duckdb/duckdb/main/data/parquet-testing/userdata1.parquet') LIMIT 3;" @@ -52,15 +52,15 @@ def test_httpfs(self, require, pandas): result_df = connection.fetchdf() exp_result = pandas.DataFrame( { - 'id': pandas.Series([1, 2, 3], dtype="int32"), - 'first_name': ['Amanda', 'Albert', 'Evelyn'], - 'last_name': ['Jordan', 'Freeman', 'Morgan'], + "id": pandas.Series([1, 2, 3], dtype="int32"), + "first_name": ["Amanda", "Albert", "Evelyn"], + "last_name": ["Jordan", "Freeman", "Morgan"], } ) pandas.testing.assert_frame_equal(result_df, exp_result) def test_http_exception(self, require): - connection = require('httpfs') + connection = require("httpfs") # Read from a bogus HTTPS url, assert that it errors with a non-successful status code with raises(duckdb.HTTPException) as exc: @@ -68,15 +68,15 @@ def test_http_exception(self, require): value = exc.value assert value.status_code != 200 - assert value.body == '' - assert 'Content-Length' in value.headers + assert value.body == "" + assert "Content-Length" in value.headers def test_fsspec_priority(self, require): pytest.importorskip("fsspec") pytest.importorskip("gscfs") import fsspec - connection = require('httpfs') + connection = require("httpfs") gcs = fsspec.filesystem("gcs") connection.register_filesystem(gcs) diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py index 663563cf..80b6b385 100644 --- a/tests/fast/adbc/test_adbc.py +++ b/tests/fast/adbc/test_adbc.py @@ -47,7 +47,7 @@ def test_connection_get_table_types(duck_conn): with duck_conn.cursor() as cursor: # Test Default Schema cursor.execute("CREATE TABLE tableschema (ints BIGINT)") - assert duck_conn.adbc_get_table_types() == ['BASE TABLE'] + assert duck_conn.adbc_get_table_types() == ["BASE TABLE"] def test_connection_get_objects(duck_conn): @@ -124,7 +124,7 @@ def test_commit(tmp_path): # This errors because the table does not exist with pytest.raises( adbc_driver_manager_lib.InternalError, - match=r'Table with name ingest does not exist!', + match=r"Table with name ingest does not exist!", ): cur.execute("SELECT count(*) from ingest") @@ -138,7 +138,7 @@ def test_commit(tmp_path): ) as conn: with conn.cursor() as cur: cur.execute("SELECT count(*) from ingest") - assert cur.fetch_arrow_table().to_pydict() == {'count_star()': [4]} + assert cur.fetch_arrow_table().to_pydict() == {"count_star()": [4]} def test_connection_get_table_schema(duck_conn): @@ -310,17 +310,17 @@ def test_large_chunk(tmp_path): with conn.cursor() as cur: cur.adbc_ingest("ingest", table, "create") cur.execute("SELECT count(*) from ingest") - assert cur.fetch_arrow_table().to_pydict() == {'count_star()': [30_000]} + assert cur.fetch_arrow_table().to_pydict() == {"count_star()": [30_000]} def test_dictionary_data(tmp_path): - data = ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] + data = ["apple", "banana", "apple", "orange", "banana", "banana"] dict_type = pyarrow.dictionary(index_type=pyarrow.int32(), value_type=pyarrow.string()) dict_array = pyarrow.array(data, type=dict_type) # Wrap in a table - table = pyarrow.table({'fruits': dict_array}) + table = pyarrow.table({"fruits": dict_array}) db = os.path.join(tmp_path, "tmp.db") if os.path.exists(db): os.remove(db) @@ -335,7 +335,7 @@ def test_dictionary_data(tmp_path): cur.adbc_ingest("ingest", table, "create") cur.execute("from ingest") assert cur.fetch_arrow_table().to_pydict() == { - 'fruits': ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'] + "fruits": ["apple", "banana", "apple", "orange", "banana", "banana"] } @@ -361,36 +361,36 @@ def test_ree_data(tmp_path): cur.adbc_ingest("ingest", table, "create") cur.execute("from ingest") assert cur.fetch_arrow_table().to_pydict() == { - 'fruits': ['apple', 'apple', 'apple', 'banana', 'banana', 'orange'] + "fruits": ["apple", "apple", "apple", "banana", "banana", "orange"] } def sorted_get_objects(catalogs): res = [] - for catalog in sorted(catalogs, key=lambda cat: cat['catalog_name']): + for catalog in sorted(catalogs, key=lambda cat: cat["catalog_name"]): new_catalog = { - "catalog_name": catalog['catalog_name'], + "catalog_name": catalog["catalog_name"], "catalog_db_schemas": [], } - for db_schema in sorted(catalog['catalog_db_schemas'] or [], key=lambda sch: sch['db_schema_name']): + for db_schema in sorted(catalog["catalog_db_schemas"] or [], key=lambda sch: sch["db_schema_name"]): new_db_schema = { - "db_schema_name": db_schema['db_schema_name'], + "db_schema_name": db_schema["db_schema_name"], "db_schema_tables": [], } - for table in sorted(db_schema['db_schema_tables'] or [], key=lambda tab: tab['table_name']): + for table in sorted(db_schema["db_schema_tables"] or [], key=lambda tab: tab["table_name"]): new_table = { - "table_name": table['table_name'], - "table_type": table['table_type'], + "table_name": table["table_name"], + "table_type": table["table_type"], "table_columns": [], "table_constraints": [], } - for column in sorted(table['table_columns'] or [], key=lambda col: col['ordinal_position']): + for column in sorted(table["table_columns"] or [], key=lambda col: col["ordinal_position"]): new_table["table_columns"].append(column) - for constraint in sorted(table['table_constraints'] or [], key=lambda con: con['constraint_name']): + for constraint in sorted(table["table_constraints"] or [], key=lambda con: con["constraint_name"]): new_table["table_constraints"].append(constraint) new_db_schema["db_schema_tables"].append(new_table) diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index 5e9d7d45..d1919cb1 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -70,30 +70,30 @@ def test_bind_single_row(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] _bind(statement, data) res, _ = statement.execute_query() table = _import(res).read_all() - result = table['i'] + result = table["i"] assert result.num_chunks == 1 result_values = result.chunk(0) assert result_values == expected_result def test_multiple_parameters(self): int_data = pa.array([5]) - varchar_data = pa.array(['not a short string']) + varchar_data = pa.array(["not a short string"]) bool_data = pa.array([True]) # Create the schema - schema = pa.schema([('a', pa.int64()), ('b', pa.string()), ('c', pa.bool_())]) + schema = pa.schema([("a", pa.int64()), ("b", pa.string()), ("c", pa.bool_())]) # Create the PyArrow table expected_res = pa.Table.from_arrays([int_data, varchar_data, bool_data], schema=schema) data = pa.record_batch( - [[5], ['not a short string'], [True]], + [[5], ["not a short string"], [True]], names=["ints", "strings", "bools"], ) @@ -105,7 +105,7 @@ def test_multiple_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0', '1', '2'] + assert schema.names == ["0", "1", "2"] _bind(statement, data) res, _ = statement.execute_query() @@ -115,14 +115,14 @@ def test_multiple_parameters(self): def test_bind_composite_type(self): data_dict = { - 'field1': pa.array([10], type=pa.int64()), - 'field2': pa.array([3.14], type=pa.float64()), - 'field3': pa.array(['example with long string'], type=pa.string()), + "field1": pa.array([10], type=pa.int64()), + "field2": pa.array([3.14], type=pa.float64()), + "field3": pa.array(["example with long string"], type=pa.string()), } # Create the StructArray struct_array = pa.StructArray.from_arrays(arrays=data_dict.values(), names=data_dict.keys()) - schema = pa.schema([(name, array.type) for name, array in zip(['a'], [struct_array])]) + schema = pa.schema([(name, array.type) for name, array in zip(["a"], [struct_array])]) # Create the RecordBatch record_batch = pa.RecordBatch.from_arrays([struct_array], schema=schema) @@ -135,18 +135,18 @@ def test_bind_composite_type(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] _bind(statement, record_batch) res, _ = statement.execute_query() table = _import(res).read_all() - result = table['a'] + result = table["a"] result = result.chunk(0) assert result == struct_array def test_too_many_parameters(self): data = pa.record_batch( - [[12423], ['not a short string']], + [[12423], ["not a short string"]], names=["ints", "strings"], ) @@ -158,7 +158,7 @@ def test_too_many_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0'] + assert schema.names == ["0"] array = adbc_driver_manager.ArrowArrayHandle() schema = adbc_driver_manager.ArrowSchemaHandle() @@ -174,7 +174,7 @@ def test_too_many_parameters(self): def test_not_enough_parameters(self): data = pa.record_batch( - [['not a short string']], + [["not a short string"]], names=["strings"], ) @@ -186,7 +186,7 @@ def test_not_enough_parameters(self): raw_schema = statement.get_parameter_schema() schema = _import(raw_schema) - assert schema.names == ['0', '1'] + assert schema.names == ["0", "1"] array = adbc_driver_manager.ArrowArrayHandle() schema = adbc_driver_manager.ArrowSchemaHandle() diff --git a/tests/fast/api/test_3324.py b/tests/fast/api/test_3324.py index e8f6085f..f3cd235b 100644 --- a/tests/fast/api/test_3324.py +++ b/tests/fast/api/test_3324.py @@ -27,4 +27,4 @@ def test_3324(self, duckdb_cursor): ).fetch_df() with pytest.raises(duckdb.BinderException, match="Unexpected prepared parameter"): - duckdb_cursor.execute("""execute v1(?)""", ('test1',)).fetch_df() + duckdb_cursor.execute("""execute v1(?)""", ("test1",)).fetch_df() diff --git a/tests/fast/api/test_3654.py b/tests/fast/api/test_3654.py index e63f0cd1..8fad47e6 100644 --- a/tests/fast/api/test_3654.py +++ b/tests/fast/api/test_3654.py @@ -11,11 +11,11 @@ class Test3654(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_pandas(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1, 1, 2], + "id": [1, 1, 2], } ) con = duckdb.connect() @@ -24,14 +24,14 @@ def test_3654_pandas(self, duckdb_cursor, pandas): print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(1,), (1,), (2,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_arrow(self, duckdb_cursor, pandas): if not can_run: return df1 = pandas.DataFrame( { - 'id': [1, 1, 2], + "id": [1, 1, 2], } ) table = pa.Table.from_pandas(df1) diff --git a/tests/fast/api/test_3728.py b/tests/fast/api/test_3728.py index 2df3c156..37b50ee6 100644 --- a/tests/fast/api/test_3728.py +++ b/tests/fast/api/test_3728.py @@ -14,6 +14,6 @@ def test_3728_describe_enum(self, duckdb_cursor): # This fails with "RuntimeError: Not implemented Error: unsupported type: mood" assert cursor.table("person").execute().description == [ - ('name', 'VARCHAR', None, None, None, None, None), - ('current_mood', "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), + ("name", "VARCHAR", None, None, None, None, None), + ("current_mood", "ENUM('sad', 'ok', 'happy')", None, None, None, None, None), ] diff --git a/tests/fast/api/test_6315.py b/tests/fast/api/test_6315.py index e8eaff59..b9e7c0cf 100644 --- a/tests/fast/api/test_6315.py +++ b/tests/fast/api/test_6315.py @@ -9,15 +9,15 @@ def test_6315(self, duckdb_cursor): rv.fetchall() desc = rv.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] # description of relation rel = c.sql("select * from sqlite_master where type = 'table'") desc = rel.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] rel.fetchall() desc = rel.description names = [x[0] for x in desc] - assert names == ['type', 'name', 'tbl_name', 'rootpage', 'sql'] + assert names == ["type", "name", "tbl_name", "rootpage", "sql"] diff --git a/tests/fast/api/test_attribute_getter.py b/tests/fast/api/test_attribute_getter.py index 958e8892..eda6845a 100644 --- a/tests/fast/api/test_attribute_getter.py +++ b/tests/fast/api/test_attribute_getter.py @@ -11,43 +11,43 @@ class TestGetAttribute(object): def test_basic_getattr(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") assert rel.a.fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] assert rel.b.fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] assert rel.c.fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] def test_basic_getitem(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') - assert rel['a'].fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] - assert rel['b'].fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] - assert rel['c'].fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") + assert rel["a"].fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] + assert rel["b"].fetchmany(5) == [(5,), (6,), (7,), (8,), (9,)] + assert rel["c"].fetchmany(5) == [(2,), (0,), (1,), (2,), (0,)] def test_getitem_nonexistant(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") with pytest.raises(AttributeError): - rel['d'] + rel["d"] def test_getattr_nonexistant(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") with pytest.raises(AttributeError): rel.d def test_getattr_collision(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as df from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as df from range(100) tbl(i)") # 'df' also exists as a method on DuckDBPyRelation assert rel.df.__class__ != duckdb.DuckDBPyRelation def test_getitem_collision(self, duckdb_cursor): - rel = duckdb_cursor.sql('select i as df from range(100) tbl(i)') + rel = duckdb_cursor.sql("select i as df from range(100) tbl(i)") # this case is not an issue on __getitem__ - assert rel['df'].__class__ == duckdb.DuckDBPyRelation + assert rel["df"].__class__ == duckdb.DuckDBPyRelation def test_getitem_struct(self, duckdb_cursor): rel = duckdb_cursor.sql("select {'a':5, 'b':6} as a, 5 as b") - assert rel['a']['a'].fetchall()[0][0] == 5 - assert rel['a']['b'].fetchall()[0][0] == 6 + assert rel["a"]["a"].fetchall()[0][0] == 5 + assert rel["a"]["b"].fetchall()[0][0] == 6 def test_getattr_struct(self, duckdb_cursor): rel = duckdb_cursor.sql("select {'a':5, 'b':6} as a, 5 as b") @@ -56,7 +56,7 @@ def test_getattr_struct(self, duckdb_cursor): def test_getattr_spaces(self, duckdb_cursor): rel = duckdb_cursor.sql('select 42 as "hello world"') - assert rel['hello world'].fetchall()[0][0] == 42 + assert rel["hello world"].fetchall()[0][0] == 42 def test_getattr_doublequotes(self, duckdb_cursor): rel = duckdb_cursor.sql('select 1 as "tricky"", ""quotes", 2 as tricky, 3 as quotes') diff --git a/tests/fast/api/test_config.py b/tests/fast/api/test_config.py index 5db5f77b..4a0a0445 100644 --- a/tests/fast/api/test_config.py +++ b/tests/fast/api/test_config.py @@ -9,54 +9,54 @@ class TestDBConfig(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_default_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3]}) - con = duckdb.connect(':memory:', config={'default_order': 'desc'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3]}) + con = duckdb.connect(":memory:", config={"default_order": "desc"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_null_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3, None]}) - con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3, None]}) + con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(1,), (2,), (3,), (None,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_options(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3, None]}) - con = duckdb.connect(':memory:', config={'default_null_order': 'nulls_last', 'default_order': 'desc'}) - result = con.execute('select * from df order by a').fetchall() + df = pandas.DataFrame({"a": [1, 2, 3, None]}) + con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last", "default_order": "desc"}) + result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,), (None,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_external_access(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'a': [1, 2, 3]}) + df = pandas.DataFrame({"a": [1, 2, 3]}) # this works (replacement scan) - con_regular = duckdb.connect(':memory:', config={}) - con_regular.execute('select * from df') + con_regular = duckdb.connect(":memory:", config={}) + con_regular.execute("select * from df") # disable external access: this also disables pandas replacement scans - con = duckdb.connect(':memory:', config={'enable_external_access': False}) + con = duckdb.connect(":memory:", config={"enable_external_access": False}) # this should fail query_failed = False try: - con.execute('select * from df').fetchall() + con.execute("select * from df").fetchall() except: query_failed = True assert query_failed == True def test_extension_setting(self): - repository = os.environ.get('LOCAL_EXTENSION_REPO') + repository = os.environ.get("LOCAL_EXTENSION_REPO") if not repository: return - con = duckdb.connect(config={"TimeZone": "UTC", 'autoinstall_extension_repository': repository}) - assert 'UTC' == con.sql("select current_setting('TimeZone')").fetchone()[0] + con = duckdb.connect(config={"TimeZone": "UTC", "autoinstall_extension_repository": repository}) + assert "UTC" == con.sql("select current_setting('TimeZone')").fetchone()[0] def test_unrecognized_option(self, duckdb_cursor): success = True try: - con_regular = duckdb.connect(':memory:', config={'thisoptionisprobablynotthere': '42'}) + con_regular = duckdb.connect(":memory:", config={"thisoptionisprobablynotthere": "42"}) except: success = False assert success == False @@ -64,27 +64,27 @@ def test_unrecognized_option(self, duckdb_cursor): def test_incorrect_parameter(self, duckdb_cursor): success = True try: - con_regular = duckdb.connect(':memory:', config={'default_null_order': '42'}) + con_regular = duckdb.connect(":memory:", config={"default_null_order": "42"}) except: success = False assert success == False def test_user_agent_default(self, duckdb_cursor): - con_regular = duckdb.connect(':memory:') + con_regular = duckdb.connect(":memory:") regex = re.compile("duckdb/.* python/.*") # Expands to: SELECT * FROM pragma_user_agent() assert regex.match(con_regular.sql("PRAGMA user_agent").fetchone()[0]) is not None custom_user_agent = con_regular.sql("SELECT current_setting('custom_user_agent')").fetchone() - assert custom_user_agent[0] == '' + assert custom_user_agent[0] == "" def test_user_agent_custom(self, duckdb_cursor): - con_regular = duckdb.connect(':memory:', config={'custom_user_agent': 'CUSTOM_STRING'}) + con_regular = duckdb.connect(":memory:", config={"custom_user_agent": "CUSTOM_STRING"}) regex = re.compile("duckdb/.* python/.* CUSTOM_STRING") assert regex.match(con_regular.sql("PRAGMA user_agent").fetchone()[0]) is not None custom_user_agent = con_regular.sql("SELECT current_setting('custom_user_agent')").fetchone() - assert custom_user_agent[0] == 'CUSTOM_STRING' + assert custom_user_agent[0] == "CUSTOM_STRING" def test_secret_manager_option(self, duckdb_cursor): - con = duckdb.connect(':memory:', config={'allow_persistent_secrets': False}) - result = con.execute('select count(*) from duckdb_secrets()').fetchall() + con = duckdb.connect(":memory:", config={"allow_persistent_secrets": False}) + result = con.execute("select count(*) from duckdb_secrets()").fetchall() assert result == [(0,)] diff --git a/tests/fast/api/test_connection_close.py b/tests/fast/api/test_connection_close.py index e7a47404..f71a02bb 100644 --- a/tests/fast/api/test_connection_close.py +++ b/tests/fast/api/test_connection_close.py @@ -54,7 +54,7 @@ def test_get_closed_default_conn(self, duckdb_cursor): duckdb.close() # 'duckdb.close()' closes this connection, because we explicitly set it as the default - with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection Error: Connection already closed"): con.sql("select 42").fetchall() default_con = duckdb.default_connection() @@ -65,11 +65,11 @@ def test_get_closed_default_conn(self, duckdb_cursor): duckdb.sql("select 42").fetchall() # Show that the 'default_con' is still closed - with pytest.raises(duckdb.ConnectionException, match='Connection Error: Connection already closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection Error: Connection already closed"): default_con.sql("select 42").fetchall() duckdb.close() # This also does not error because we silently receive a new connection - con2 = duckdb.connect(':default:') + con2 = duckdb.connect(":default:") con2.sql("select 42").fetchall() diff --git a/tests/fast/api/test_cursor.py b/tests/fast/api/test_cursor.py index 9510fbd9..69c3fe79 100644 --- a/tests/fast/api/test_cursor.py +++ b/tests/fast/api/test_cursor.py @@ -7,7 +7,7 @@ class TestDBAPICursor(object): def test_cursor_basic(self): # Create a connection - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # Then create a cursor on the connection cursor = con.cursor() # Use the cursor for queries @@ -15,14 +15,14 @@ def test_cursor_basic(self): assert res == [([1, 2, 3, None, 4],)] def test_cursor_preexisting(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.execute("create table tbl as select i a, i+1 b, i+2 c from range(5) tbl(i)") cursor = con.cursor() res = cursor.execute("select * from tbl").fetchall() assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_after_creation(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # First create the cursor cursor = con.cursor() # Then create table on the source connection @@ -31,7 +31,7 @@ def test_cursor_after_creation(self): assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_mixed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # First create the cursor cursor = con.cursor() # Then create table on the cursor @@ -43,7 +43,7 @@ def test_cursor_mixed(self): assert res == [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)] def test_cursor_temp_schema_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.execute("create temp table tbl as select * from range(100)") other_cursor = con.cursor() @@ -54,7 +54,7 @@ def test_cursor_temp_schema_closed(self): res = other_cursor.execute("select * from tbl").fetchall() def test_cursor_temp_schema_open(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.execute("create temp table tbl as select * from range(100)") other_cursor = con.cursor() @@ -65,7 +65,7 @@ def test_cursor_temp_schema_open(self): res = other_cursor.execute("select * from tbl").fetchall() def test_cursor_temp_schema_both(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor1 = con.cursor() cursor2 = con.cursor() cursor3 = con.cursor() @@ -92,23 +92,23 @@ def test_cursor_timezone(self): # Because the 'timezone' setting was not explicitly set for the connection # the setting of the DBConfig is used instead res = con1.execute("SELECT make_timestamptz(2000,01,20,03,30,59)").fetchone() - assert str(res) == '(datetime.datetime(2000, 1, 20, 3, 30, 59, tzinfo=),)' + assert str(res) == "(datetime.datetime(2000, 1, 20, 3, 30, 59, tzinfo=),)" def test_cursor_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.close() with pytest.raises(duckdb.ConnectionException): cursor = con.cursor() def test_cursor_used_after_connection_closed(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() con.close() with pytest.raises(duckdb.ConnectionException): cursor.execute("select [1,2,3,4]") def test_cursor_used_after_close(self): - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") cursor = con.cursor() cursor.close() with pytest.raises(duckdb.ConnectionException): diff --git a/tests/fast/api/test_dbapi00.py b/tests/fast/api/test_dbapi00.py index 815a81b9..38d87887 100644 --- a/tests/fast/api/test_dbapi00.py +++ b/tests/fast/api/test_dbapi00.py @@ -12,7 +12,7 @@ def assert_result_equal(result): class TestSimpleDBAPI(object): def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert_result_equal(result) @@ -20,7 +20,7 @@ def test_fetchmany_default(self, duckdb_cursor, integers): # Get truth-value truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) - duckdb_cursor.execute('Select * from integers') + duckdb_cursor.execute("Select * from integers") # by default 'size' is 1 arraysize = 1 list_of_results = [] @@ -40,7 +40,7 @@ def test_fetchmany_default(self, duckdb_cursor, integers): def test_fetchmany(self, duckdb_cursor, integers): # Get truth value truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) - duckdb_cursor.execute('select * from integers') + duckdb_cursor.execute("select * from integers") list_of_results = [] arraysize = 3 expected_iteration_count = 1 + (int)(truth_value / arraysize) + (1 if truth_value % arraysize else 0) @@ -63,8 +63,8 @@ def test_fetchmany(self, duckdb_cursor, integers): assert len(res) == 0 def test_fetchmany_too_many(self, duckdb_cursor, integers): - truth_value = len(duckdb_cursor.execute('select * from integers').fetchall()) - duckdb_cursor.execute('select * from integers') + truth_value = len(duckdb_cursor.execute("select * from integers").fetchall()) + duckdb_cursor.execute("select * from integers") res = duckdb_cursor.fetchmany(truth_value * 5) assert len(res) == truth_value assert_result_equal(res) @@ -74,48 +74,48 @@ def test_fetchmany_too_many(self, duckdb_cursor, integers): assert len(res) == 0 def test_numpy_selection(self, duckdb_cursor, integers, timestamps): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchnumpy() arr = numpy.ma.masked_array(numpy.arange(11)) arr.mask = [False] * 10 + [True] - numpy.testing.assert_array_equal(result['i'], arr, "Incorrect result returned") - duckdb_cursor.execute('SELECT * FROM timestamps') + numpy.testing.assert_array_equal(result["i"], arr, "Incorrect result returned") + duckdb_cursor.execute("SELECT * FROM timestamps") result = duckdb_cursor.fetchnumpy() - arr = numpy.array(['1992-10-03 18:34:45', '2010-01-01 00:00:01', None], dtype="datetime64[ms]") + arr = numpy.array(["1992-10-03 18:34:45", "2010-01-01 00:00:01", None], dtype="datetime64[ms]") arr = numpy.ma.masked_array(arr) arr.mask = [False, False, True] - numpy.testing.assert_array_equal(result['t'], arr, "Incorrect result returned") + numpy.testing.assert_array_equal(result["t"], arr, "Incorrect result returned") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_selection(self, duckdb_cursor, pandas, integers, timestamps): import datetime from packaging.version import Version # I don't know when this exactly changed, but 2.0.3 does not support this, recent versions do - if Version(pandas.__version__) <= Version('2.0.3'): + if Version(pandas.__version__) <= Version("2.0.3"): pytest.skip("The resulting dtype is 'object' when given a Series with dtype Int32DType") - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchdf() array = numpy.ma.masked_array(numpy.arange(11)) array.mask = [False] * 10 + [True] - arr = {'i': pandas.Series(array.data, dtype=pandas.Int32Dtype)} - arr['i'][array.mask] = pandas.NA + arr = {"i": pandas.Series(array.data, dtype=pandas.Int32Dtype)} + arr["i"][array.mask] = pandas.NA arr = pandas.DataFrame(arr) pandas.testing.assert_frame_equal(result, arr) - duckdb_cursor.execute('SELECT * FROM timestamps') + duckdb_cursor.execute("SELECT * FROM timestamps") result = duckdb_cursor.fetchdf() df = pandas.DataFrame( { - 't': pandas.Series( + "t": pandas.Series( data=[ datetime.datetime(year=1992, month=10, day=3, hour=18, minute=34, second=45), datetime.datetime(year=2010, month=1, day=1, hour=0, minute=0, second=1), None, ], - dtype='datetime64[us]', + dtype="datetime64[us]", ) } ) diff --git a/tests/fast/api/test_dbapi01.py b/tests/fast/api/test_dbapi01.py index dd0d2b4e..f7f00a10 100644 --- a/tests/fast/api/test_dbapi01.py +++ b/tests/fast/api/test_dbapi01.py @@ -6,8 +6,8 @@ class TestMultipleResultSets(object): def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert result == [ (0,), @@ -24,18 +24,18 @@ def test_regular_selection(self, duckdb_cursor, integers): ], "Incorrect result returned" def test_numpy_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchnumpy() expected = numpy.ma.masked_array(numpy.arange(11), mask=([False] * 10 + [True])) - numpy.testing.assert_array_equal(result['i'], expected) + numpy.testing.assert_array_equal(result["i"], expected) def test_numpy_materialized(self, duckdb_cursor, integers): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE integers (i integer)') - cursor.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + cursor.execute("CREATE TABLE integers (i integer)") + cursor.execute("INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)") rel = connection.table("integers") res = rel.aggregate("sum(i)").execute().fetchnumpy() - assert res['sum(i)'][0] == 45 + assert res["sum(i)"][0] == 45 diff --git a/tests/fast/api/test_dbapi04.py b/tests/fast/api/test_dbapi04.py index b2c9173a..1125f819 100644 --- a/tests/fast/api/test_dbapi04.py +++ b/tests/fast/api/test_dbapi04.py @@ -3,7 +3,7 @@ class TestSimpleDBAPI(object): def test_regular_selection(self, duckdb_cursor, integers): - duckdb_cursor.execute('SELECT * FROM integers') + duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() assert result == [ (0,), diff --git a/tests/fast/api/test_dbapi05.py b/tests/fast/api/test_dbapi05.py index 0de217f2..234fb2ec 100644 --- a/tests/fast/api/test_dbapi05.py +++ b/tests/fast/api/test_dbapi05.py @@ -3,7 +3,7 @@ class TestSimpleDBAPI(object): def test_prepare(self, duckdb_cursor): - result = duckdb_cursor.execute('SELECT CAST(? AS INTEGER), CAST(? AS INTEGER)', ['42', '84']).fetchall() + result = duckdb_cursor.execute("SELECT CAST(? AS INTEGER), CAST(? AS INTEGER)", ["42", "84"]).fetchall() assert result == [ ( 42, @@ -15,26 +15,26 @@ def test_prepare(self, duckdb_cursor): # from python docs c.execute( - '''CREATE TABLE stocks - (date text, trans text, symbol text, qty real, price real)''' + """CREATE TABLE stocks + (date text, trans text, symbol text, qty real, price real)""" ) c.execute("INSERT INTO stocks VALUES ('2006-01-05','BUY','RHAT',100,35.14)") - t = ('RHAT',) - result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() + t = ("RHAT",) + result = c.execute("SELECT COUNT(*) FROM stocks WHERE symbol=?", t).fetchone() assert result == (1,) - t = ['RHAT'] - result = c.execute('SELECT COUNT(*) FROM stocks WHERE symbol=?', t).fetchone() + t = ["RHAT"] + result = c.execute("SELECT COUNT(*) FROM stocks WHERE symbol=?", t).fetchone() assert result == (1,) # Larger example that inserts many records at a time purchases = [ - ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ("2006-03-28", "BUY", "IBM", 1000, 45.00), + ("2006-04-05", "BUY", "MSFT", 1000, 72.00), + ("2006-04-06", "SELL", "IBM", 500, 53.00), ] - c.executemany('INSERT INTO stocks VALUES (?,?,?,?,?)', purchases) + c.executemany("INSERT INTO stocks VALUES (?,?,?,?,?)", purchases) - result = c.execute('SELECT count(*) FROM stocks').fetchone() + result = c.execute("SELECT count(*) FROM stocks").fetchone() assert result == (4,) diff --git a/tests/fast/api/test_dbapi07.py b/tests/fast/api/test_dbapi07.py index 7792b8de..238f30fc 100644 --- a/tests/fast/api/test_dbapi07.py +++ b/tests/fast/api/test_dbapi07.py @@ -7,10 +7,10 @@ class TestNumpyTimestampMilliseconds(object): def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchnumpy() - assert res['test_time'] == numpy.datetime64('2019-11-26 21:11:42.501') + assert res["test_time"] == numpy.datetime64("2019-11-26 21:11:42.501") class TestTimestampMilliseconds(object): def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchone()[0] - assert res == datetime.strptime('2019-11-26 21:11:42.501', '%Y-%m-%d %H:%M:%S.%f') + assert res == datetime.strptime("2019-11-26 21:11:42.501", "%Y-%m-%d %H:%M:%S.%f") diff --git a/tests/fast/api/test_dbapi08.py b/tests/fast/api/test_dbapi08.py index a81acfd1..457a9e78 100644 --- a/tests/fast/api/test_dbapi08.py +++ b/tests/fast/api/test_dbapi08.py @@ -6,7 +6,7 @@ class TestType(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_fetchdf(self, pandas): con = duckdb.connect() con.execute("CREATE TABLE items(item VARCHAR)") @@ -14,7 +14,7 @@ def test_fetchdf(self, pandas): res = con.execute("SELECT item FROM items").fetchdf() assert isinstance(res, pandas.core.frame.DataFrame) - df = pandas.DataFrame({'item': ['jeans', '', None]}) + df = pandas.DataFrame({"item": ["jeans", "", None]}) print(res) print(df) diff --git a/tests/fast/api/test_dbapi09.py b/tests/fast/api/test_dbapi09.py index dde8ebff..538e7fc3 100644 --- a/tests/fast/api/test_dbapi09.py +++ b/tests/fast/api/test_dbapi09.py @@ -12,11 +12,11 @@ def test_fetchall_date(self, duckdb_cursor): def test_fetchnumpy_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchnumpy() - arr = numpy.array(['2020-01-10'], dtype="datetime64[s]") + arr = numpy.array(["2020-01-10"], dtype="datetime64[s]") arr = numpy.ma.masked_array(arr) - numpy.testing.assert_array_equal(res['test_date'], arr) + numpy.testing.assert_array_equal(res["test_date"], arr) def test_fetchdf_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchdf() - ser = pandas.Series(numpy.array(['2020-01-10'], dtype="datetime64[us]"), name="test_date") - pandas.testing.assert_series_equal(res['test_date'], ser) + ser = pandas.Series(numpy.array(["2020-01-10"], dtype="datetime64[us]"), name="test_date") + pandas.testing.assert_series_equal(res["test_date"], ser) diff --git a/tests/fast/api/test_dbapi12.py b/tests/fast/api/test_dbapi12.py index 78881f5e..833d231c 100644 --- a/tests/fast/api/test_dbapi12.py +++ b/tests/fast/api/test_dbapi12.py @@ -10,45 +10,45 @@ def test_readonly(self, duckdb_cursor): def test_rel(rel, duckdb_cursor): res = ( - rel.filter('i < 3') - .order('j') - .project('i') - .union(rel.filter('i > 2').project('i')) - .join(rel.set_alias('a1'), 'i') - .project('CAST(i as BIGINT) i, j') - .order('i') + rel.filter("i < 3") + .order("j") + .project("i") + .union(rel.filter("i > 2").project("i")) + .join(rel.set_alias("a1"), "i") + .project("CAST(i as BIGINT) i, j") + .order("i") ) pd.testing.assert_frame_equal(res.to_df(), test_df) res3 = duckdb_cursor.from_df(res.to_df()).to_df() pd.testing.assert_frame_equal(res3, test_df) - df_sql = res.query('x', 'select CAST(i as BIGINT) i, j from x') + df_sql = res.query("x", "select CAST(i as BIGINT) i, j from x") pd.testing.assert_frame_equal(df_sql.df(), test_df) - res2 = res.aggregate('i, count(j) as cj', 'i').order('i') + res2 = res.aggregate("i, count(j) as cj", "i").order("i") cmp_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "cj": [1, 1, 1]}) pd.testing.assert_frame_equal(res2.to_df(), cmp_df) - duckdb_cursor.execute('DROP TABLE IF EXISTS a2') - rel.create('a2') - rel_a2 = duckdb_cursor.table('a2').project('CAST(i as BIGINT) i, j').to_df() + duckdb_cursor.execute("DROP TABLE IF EXISTS a2") + rel.create("a2") + rel_a2 = duckdb_cursor.table("a2").project("CAST(i as BIGINT) i, j").to_df() pd.testing.assert_frame_equal(rel_a2, test_df) - duckdb_cursor.execute('DROP TABLE IF EXISTS a3') - duckdb_cursor.execute('CREATE TABLE a3 (i INTEGER, j STRING)') - rel.insert_into('a3') - rel_a3 = duckdb_cursor.table('a3').project('CAST(i as BIGINT) i, j').to_df() + duckdb_cursor.execute("DROP TABLE IF EXISTS a3") + duckdb_cursor.execute("CREATE TABLE a3 (i INTEGER, j STRING)") + rel.insert_into("a3") + rel_a3 = duckdb_cursor.table("a3").project("CAST(i as BIGINT) i, j").to_df() pd.testing.assert_frame_equal(rel_a3, test_df) - duckdb_cursor.execute('CREATE TABLE a (i INTEGER, j STRING)') + duckdb_cursor.execute("CREATE TABLE a (i INTEGER, j STRING)") duckdb_cursor.execute("INSERT INTO a VALUES (1, 'one'), (2, 'two'), (3, 'three')") - duckdb_cursor.execute('CREATE VIEW v AS SELECT * FROM a') + duckdb_cursor.execute("CREATE VIEW v AS SELECT * FROM a") - duckdb_cursor.execute('CREATE TEMPORARY TABLE at_ (i INTEGER)') - duckdb_cursor.execute('CREATE TEMPORARY VIEW vt AS SELECT * FROM at_') + duckdb_cursor.execute("CREATE TEMPORARY TABLE at_ (i INTEGER)") + duckdb_cursor.execute("CREATE TEMPORARY VIEW vt AS SELECT * FROM at_") - rel_a = duckdb_cursor.table('a') - rel_v = duckdb_cursor.view('v') + rel_a = duckdb_cursor.table("a") + rel_v = duckdb_cursor.view("v") # rel_at = duckdb_cursor.table('at') # rel_vt = duckdb_cursor.view('vt') @@ -59,8 +59,8 @@ def test_rel(rel, duckdb_cursor): test_rel(rel_df, duckdb_cursor) def test_fromquery(self, duckdb_cursor): - assert duckdb.from_query('select 42').fetchone()[0] == 42 - assert duckdb_cursor.query('select 43').fetchone()[0] == 43 + assert duckdb.from_query("select 42").fetchone()[0] == 42 + assert duckdb_cursor.query("select 43").fetchone()[0] == 43 # assert duckdb_cursor.from_query('select 44').execute().fetchone()[0] == 44 # assert duckdb_cursor.from_query('select 45').execute().fetchone()[0] == 45 diff --git a/tests/fast/api/test_dbapi13.py b/tests/fast/api/test_dbapi13.py index fb7fbaa8..ffdb4884 100644 --- a/tests/fast/api/test_dbapi13.py +++ b/tests/fast/api/test_dbapi13.py @@ -14,9 +14,9 @@ def test_fetchnumpy_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchnumpy() arr = numpy.array([datetime.time(13, 6, 40)], dtype="object") arr = numpy.ma.masked_array(arr) - numpy.testing.assert_array_equal(res['test_time'], arr) + numpy.testing.assert_array_equal(res["test_time"], arr) def test_fetchdf_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchdf() ser = pandas.Series(numpy.array([datetime.time(13, 6, 40)], dtype="object"), name="test_time") - pandas.testing.assert_series_equal(res['test_time'], ser) + pandas.testing.assert_series_equal(res["test_time"], ser) diff --git a/tests/fast/api/test_dbapi_fetch.py b/tests/fast/api/test_dbapi_fetch.py index 6eda4b9d..9c47c54c 100644 --- a/tests/fast/api/test_dbapi_fetch.py +++ b/tests/fast/api/test_dbapi_fetch.py @@ -8,21 +8,21 @@ class TestDBApiFetch(object): def test_multiple_fetch_one(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchone() == (42,) assert c.fetchone() is None assert c.fetchone() is None def test_multiple_fetch_all(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchall() == [(42,)] assert c.fetchall() == [] assert c.fetchall() == [] def test_multiple_fetch_many(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") assert c.fetchmany(1000) == [(42,)] assert c.fetchmany(1000) == [] assert c.fetchmany(1000) == [] @@ -30,8 +30,8 @@ def test_multiple_fetch_many(self, duckdb_cursor): def test_multiple_fetch_df(self, duckdb_cursor): pd = pytest.importorskip("pandas") con = duckdb.connect() - c = con.execute('SELECT 42::BIGINT AS a') - pd.testing.assert_frame_equal(c.df(), pd.DataFrame.from_dict({'a': [42]})) + c = con.execute("SELECT 42::BIGINT AS a") + pd.testing.assert_frame_equal(c.df(), pd.DataFrame.from_dict({"a": [42]})) assert c.df() is None assert c.df() is None @@ -39,36 +39,36 @@ def test_multiple_fetch_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") arrow = pytest.importorskip("pyarrow") con = duckdb.connect() - c = con.execute('SELECT 42::BIGINT AS a') + c = con.execute("SELECT 42::BIGINT AS a") table = c.fetch_arrow_table() df = table.to_pandas() - pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict({'a': [42]})) + pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict({"a": [42]})) assert c.fetch_arrow_table() is None assert c.fetch_arrow_table() is None def test_multiple_close(self, duckdb_cursor): con = duckdb.connect() - c = con.execute('SELECT 42') + c = con.execute("SELECT 42") c.close() c.close() c.close() - with pytest.raises(duckdb.InvalidInputException, match='No open result set'): + with pytest.raises(duckdb.InvalidInputException, match="No open result set"): c.fetchall() def test_multiple_fetch_all_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT 42') + res = duckdb_cursor.query("SELECT 42") assert res.fetchall() == [(42,)] assert res.fetchall() == [(42,)] assert res.fetchall() == [(42,)] def test_multiple_fetch_many_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT 42') + res = duckdb_cursor.query("SELECT 42") assert res.fetchmany(10000) == [(42,)] assert res.fetchmany(10000) == [] assert res.fetchmany(10000) == [] def test_fetch_one_relation(self, duckdb_cursor): - res = duckdb_cursor.query('SELECT * FROM range(3)') + res = duckdb_cursor.query("SELECT * FROM range(3)") assert res.fetchone() == (0,) assert res.fetchone() == (1,) assert res.fetchone() == (2,) @@ -86,40 +86,40 @@ def test_fetch_one_relation(self, duckdb_cursor): assert res.fetchone() is None @pytest.mark.parametrize( - 'test_case', + "test_case", [ - (False, 'BOOLEAN', False), - (-128, 'TINYINT', -128), - (-32768, 'SMALLINT', -32768), - (-2147483648, 'INTEGER', -2147483648), - (-9223372036854775808, 'BIGINT', -9223372036854775808), - (-170141183460469231731687303715884105728, 'HUGEINT', -170141183460469231731687303715884105728), - (0, 'UTINYINT', 0), - (0, 'USMALLINT', 0), - (0, 'UINTEGER', 0), - (0, 'UBIGINT', 0), - (0, 'UHUGEINT', 0), - (1.3423423767089844, 'FLOAT', 1.3423424), - (1.3423424, 'DOUBLE', 1.3423424), - (Decimal('1.342342'), 'DECIMAL(10, 6)', 1.342342), - ('hello', "ENUM('world', 'hello')", 'hello'), - ('🦆🦆🦆🦆🦆🦆', 'VARCHAR', '🦆🦆🦆🦆🦆🦆'), - (b'thisisalongblob\x00withnullbytes', 'BLOB', 'thisisalongblob\\x00withnullbytes'), - ('0010001001011100010101011010111', 'BITSTRING', '0010001001011100010101011010111'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP', '290309-12-22 (BC) 00:00:00'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP_MS', '290309-12-22 (BC) 00:00:00'), - (datetime.datetime(1677, 9, 22, 0, 0), 'TIMESTAMP_NS', '1677-09-22 00:00:00'), - ('290309-12-22 (BC) 00:00:00', 'TIMESTAMP_S', '290309-12-22 (BC) 00:00:00'), - ('290309-12-22 (BC) 00:00:30+00', 'TIMESTAMPTZ', '290309-12-22 (BC) 00:17:30+00:17'), + (False, "BOOLEAN", False), + (-128, "TINYINT", -128), + (-32768, "SMALLINT", -32768), + (-2147483648, "INTEGER", -2147483648), + (-9223372036854775808, "BIGINT", -9223372036854775808), + (-170141183460469231731687303715884105728, "HUGEINT", -170141183460469231731687303715884105728), + (0, "UTINYINT", 0), + (0, "USMALLINT", 0), + (0, "UINTEGER", 0), + (0, "UBIGINT", 0), + (0, "UHUGEINT", 0), + (1.3423423767089844, "FLOAT", 1.3423424), + (1.3423424, "DOUBLE", 1.3423424), + (Decimal("1.342342"), "DECIMAL(10, 6)", 1.342342), + ("hello", "ENUM('world', 'hello')", "hello"), + ("🦆🦆🦆🦆🦆🦆", "VARCHAR", "🦆🦆🦆🦆🦆🦆"), + (b"thisisalongblob\x00withnullbytes", "BLOB", "thisisalongblob\\x00withnullbytes"), + ("0010001001011100010101011010111", "BITSTRING", "0010001001011100010101011010111"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP", "290309-12-22 (BC) 00:00:00"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP_MS", "290309-12-22 (BC) 00:00:00"), + (datetime.datetime(1677, 9, 22, 0, 0), "TIMESTAMP_NS", "1677-09-22 00:00:00"), + ("290309-12-22 (BC) 00:00:00", "TIMESTAMP_S", "290309-12-22 (BC) 00:00:00"), + ("290309-12-22 (BC) 00:00:30+00", "TIMESTAMPTZ", "290309-12-22 (BC) 00:17:30+00:17"), ( datetime.time(0, 0, tzinfo=datetime.timezone(datetime.timedelta(seconds=57599))), - 'TIMETZ', - '00:00:00+15:59:59', + "TIMETZ", + "00:00:00+15:59:59", ), - ('5877642-06-25 (BC)', 'DATE', '5877642-06-25 (BC)'), - (UUID('cd57dfbd-d65f-4e15-991e-2a92e74b9f79'), 'UUID', 'cd57dfbd-d65f-4e15-991e-2a92e74b9f79'), - (datetime.timedelta(days=90), 'INTERVAL', '3 months'), - ('🦆🦆🦆🦆🦆🦆', 'UNION(a int, b bool, c varchar)', '🦆🦆🦆🦆🦆🦆'), + ("5877642-06-25 (BC)", "DATE", "5877642-06-25 (BC)"), + (UUID("cd57dfbd-d65f-4e15-991e-2a92e74b9f79"), "UUID", "cd57dfbd-d65f-4e15-991e-2a92e74b9f79"), + (datetime.timedelta(days=90), "INTERVAL", "3 months"), + ("🦆🦆🦆🦆🦆🦆", "UNION(a int, b bool, c varchar)", "🦆🦆🦆🦆🦆🦆"), ], ) def test_fetch_dict_coverage(self, duckdb_cursor, test_case): @@ -138,7 +138,7 @@ def test_fetch_dict_coverage(self, duckdb_cursor, test_case): print(res[0].keys()) assert res[0][python_key] == -2147483648 - @pytest.mark.parametrize('test_case', ['VARCHAR[]']) + @pytest.mark.parametrize("test_case", ["VARCHAR[]"]) def test_fetch_dict_key_not_hashable(self, duckdb_cursor, test_case): key_type = test_case query = f""" @@ -153,4 +153,4 @@ def test_fetch_dict_key_not_hashable(self, duckdb_cursor, test_case): select a from map_cte; """ res = duckdb_cursor.sql(query).fetchone() - assert 'key' in res[0].keys() + assert "key" in res[0].keys() diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 4cb565c1..4b0dc4d6 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -9,9 +9,9 @@ def is_dunder_method(method_name: str) -> bool: if len(method_name) < 4: return False - if method_name.startswith('_pybind11'): + if method_name.startswith("_pybind11"): return True - return method_name[:2] == '__' and method_name[:-3:-1] == '__' + return method_name[:2] == "__" and method_name[:-3:-1] == "__" @pytest.fixture(scope="session") @@ -23,32 +23,32 @@ def tmp_database(tmp_path_factory): # This file contains tests for DuckDBPyConnection methods, # wrapped by the 'duckdb' module, to execute with the 'default_connection' class TestDuckDBConnection(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append(self, pandas): duckdb.execute("Create table integers (i integer)") df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) - duckdb.append('integers', df_in) - assert duckdb.execute('select count(*) from integers').fetchone()[0] == 5 + duckdb.append("integers", df_in) + assert duckdb.execute("select count(*) from integers").fetchone()[0] == 5 # cleanup duckdb.execute("drop table integers") def test_default_connection_from_connect(self): - duckdb.sql('create or replace table connect_default_connect (i integer)') - con = duckdb.connect(':default:') - con.sql('select i from connect_default_connect') - duckdb.sql('drop table connect_default_connect') + duckdb.sql("create or replace table connect_default_connect (i integer)") + con = duckdb.connect(":default:") + con.sql("select i from connect_default_connect") + duckdb.sql("drop table connect_default_connect") with pytest.raises(duckdb.Error): - con.sql('select i from connect_default_connect') + con.sql("select i from connect_default_connect") # not allowed with additional options with pytest.raises( - duckdb.InvalidInputException, match='Default connection fetching is only allowed without additional options' + duckdb.InvalidInputException, match="Default connection fetching is only allowed without additional options" ): - con = duckdb.connect(':default:', read_only=True) + con = duckdb.connect(":default:", read_only=True) def test_arrow(self): pyarrow = pytest.importorskip("pyarrow") @@ -114,7 +114,7 @@ def test_readonly_properties(self): duckdb.execute("select 42") description = duckdb.description() rowcount = duckdb.rowcount() - assert description == [('42', 'INTEGER', None, None, None, None, None)] + assert description == [("42", "INTEGER", None, None, None, None, None)] assert rowcount == -1 def test_execute(self): @@ -124,29 +124,29 @@ def test_executemany(self): # executemany does not keep an open result set # TODO: shouldn't we also have a version that executes a query multiple times with different parameters, returning all of the results? duckdb.execute("create table tbl (i integer, j varchar)") - duckdb.executemany("insert into tbl VALUES (?, ?)", [(5, 'test'), (2, 'duck'), (42, 'quack')]) + duckdb.executemany("insert into tbl VALUES (?, ?)", [(5, "test"), (2, "duck"), (42, "quack")]) res = duckdb.table("tbl").fetchall() - assert res == [(5, 'test'), (2, 'duck'), (42, 'quack')] + assert res == [(5, "test"), (2, "duck"), (42, "quack")] duckdb.execute("drop table tbl") def test_pystatement(self): - with pytest.raises(duckdb.ParserException, match='seledct'): - statements = duckdb.extract_statements('seledct 42; select 21') + with pytest.raises(duckdb.ParserException, match="seledct"): + statements = duckdb.extract_statements("seledct 42; select 21") - statements = duckdb.extract_statements('select $1; select 21') + statements = duckdb.extract_statements("select $1; select 21") assert len(statements) == 2 - assert statements[0].query == 'select $1' + assert statements[0].query == "select $1" assert statements[0].type == duckdb.StatementType.SELECT - assert statements[0].named_parameters == set('1') + assert statements[0].named_parameters == set("1") assert statements[0].expected_result_type == [duckdb.ExpectedResultType.QUERY_RESULT] - assert statements[1].query == ' select 21' + assert statements[1].query == " select 21" assert statements[1].type == duckdb.StatementType.SELECT assert statements[1].named_parameters == set() with pytest.raises( duckdb.InvalidInputException, - match='Please provide either a DuckDBPyStatement or a string representing the query', + match="Please provide either a DuckDBPyStatement or a string representing the query", ): rel = duckdb.query(statements) @@ -158,23 +158,23 @@ def test_pystatement(self): with pytest.raises( duckdb.InvalidInputException, - match='Values were not provided for the following prepared statement parameters: 1', + match="Values were not provided for the following prepared statement parameters: 1", ): duckdb.execute(statements[0]) - assert duckdb.execute(statements[0], {'1': 42}).fetchall() == [(42,)] + assert duckdb.execute(statements[0], {"1": 42}).fetchall() == [(42,)] duckdb.execute("create table tbl(a integer)") - statements = duckdb.extract_statements('insert into tbl select $1') + statements = duckdb.extract_statements("insert into tbl select $1") assert statements[0].expected_result_type == [ duckdb.ExpectedResultType.CHANGED_ROWS, duckdb.ExpectedResultType.QUERY_RESULT, ] with pytest.raises( - duckdb.InvalidInputException, match='executemany requires a non-empty list of parameter sets to be provided' + duckdb.InvalidInputException, match="executemany requires a non-empty list of parameter sets to be provided" ): duckdb.executemany(statements[0]) duckdb.executemany(statements[0], [(21,), (22,), (23,)]) - assert duckdb.table('tbl').fetchall() == [(21,), (22,), (23,)] + assert duckdb.table("tbl").fetchall() == [(21,), (22,), (23,)] duckdb.execute("drop table tbl") def test_fetch_arrow_table(self): @@ -188,18 +188,18 @@ def test_fetch_arrow_table(self): duckdb.execute("Insert Into test values ('" + str(i) + "')") duckdb.execute("Insert Into test values ('5000')") duckdb.execute("Insert Into test values ('6000')") - sql = ''' + sql = """ SELECT a, COUNT(*) AS repetitions FROM test GROUP BY a - ''' + """ result_df = duckdb.execute(sql).df() arrow_table = duckdb.execute(sql).fetch_arrow_table() arrow_df = arrow_table.to_pandas() - assert result_df['repetitions'].sum() == arrow_df['repetitions'].sum() + assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() duckdb.execute("drop table test") def test_fetch_df(self): @@ -213,10 +213,10 @@ def test_fetch_df_chunk(self): duckdb.execute("CREATE table t as select range a from range(3000);") query = duckdb.execute("SELECT a FROM t") cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == 2048 cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 2048 + assert cur_chunk["a"][0] == 2048 assert len(cur_chunk) == 952 duckdb.execute("DROP TABLE t") @@ -247,11 +247,11 @@ def test_fetchnumpy(self): numpy = pytest.importorskip("numpy") duckdb.execute("SELECT BLOB 'hello'") results = duckdb.fetchall() - assert results[0][0] == b'hello' + assert results[0][0] == b"hello" duckdb.execute("SELECT BLOB 'hello' AS a") results = duckdb.fetchnumpy() - assert results['a'] == numpy.array([b'hello'], dtype=object) + assert results["a"] == numpy.array([b"hello"], dtype=object) def test_fetchone(self): assert (0,) == duckdb.execute("select * from range(5)").fetchone() @@ -288,11 +288,11 @@ def test_register(self): def test_register_relation(self): con = duckdb.connect() - rel = con.sql('select [5,4,3]') + rel = con.sql("select [5,4,3]") con.register("relation", rel) con.sql("create table tbl as select * from relation") - assert con.table('tbl').fetchall() == [([5, 4, 3],)] + assert con.table("tbl").fetchall() == [([5, 4, 3],)] def test_unregister_problematic_behavior(self, duckdb_cursor): # We have a VIEW called 'vw' in the Catalog @@ -302,33 +302,33 @@ def test_unregister_problematic_behavior(self, duckdb_cursor): # Create a registered object called 'vw' arrow_result = duckdb_cursor.execute("select 42").fetch_arrow_table() with pytest.raises(duckdb.CatalogException, match='View with name "vw" already exists'): - duckdb_cursor.register('vw', arrow_result) + duckdb_cursor.register("vw", arrow_result) # Temporary views take precedence over registered objects assert duckdb_cursor.execute("select * from vw").fetchone() == (0,) # Decide that we're done with this registered object.. - duckdb_cursor.unregister('vw') + duckdb_cursor.unregister("vw") # This should not have affected the existing view: assert duckdb_cursor.execute("select * from vw").fetchone() == (0,) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_out_of_scope(self, pandas): def temporary_scope(): # Create a connection, we will return this con = duckdb.connect() # Create a dataframe - df = pandas.DataFrame({'a': [1, 2, 3]}) + df = pandas.DataFrame({"a": [1, 2, 3]}) # The dataframe has to be registered as well # making sure it does not go out of scope con.register("df", df) - rel = con.sql('select * from df') + rel = con.sql("select * from df") con.register("relation", rel) return con con = temporary_scope() - res = con.sql('select * from relation').fetchall() + res = con.sql("select * from relation").fetchall() print(res) def test_table(self): diff --git a/tests/fast/api/test_duckdb_execute.py b/tests/fast/api/test_duckdb_execute.py index fba01a0c..a025fc42 100644 --- a/tests/fast/api/test_duckdb_execute.py +++ b/tests/fast/api/test_duckdb_execute.py @@ -4,8 +4,8 @@ class TestDuckDBExecute(object): def test_execute_basic(self, duckdb_cursor): - duckdb_cursor.execute('create table t as select 5') - res = duckdb_cursor.table('t').fetchall() + duckdb_cursor.execute("create table t as select 5") + res = duckdb_cursor.table("t").fetchall() assert res == [(5,)] def test_execute_many_basic(self, duckdb_cursor): @@ -19,11 +19,11 @@ def test_execute_many_basic(self, duckdb_cursor): """, (99,), ) - res = duckdb_cursor.table('t').fetchall() + res = duckdb_cursor.table("t").fetchall() assert res == [(99,)] @pytest.mark.parametrize( - 'rowcount', + "rowcount", [ 50, 2048, @@ -53,7 +53,7 @@ def test_execute_many_error(self, duckdb_cursor): # Prepared parameter used in a statement that is not the last with pytest.raises( - duckdb.NotImplementedException, match='Prepared parameters are only supported for the last statement' + duckdb.NotImplementedException, match="Prepared parameters are only supported for the last statement" ): duckdb_cursor.execute( """ @@ -73,11 +73,11 @@ def to_insert_from_generator(what): gen = to_insert_from_generator(to_insert) duckdb_cursor.execute("CREATE TABLE unittest_generator (a INTEGER);") duckdb_cursor.executemany("INSERT into unittest_generator (a) VALUES (?)", gen) - assert duckdb_cursor.table('unittest_generator').fetchall() == [(1,), (2,), (3,)] + assert duckdb_cursor.table("unittest_generator").fetchall() == [(1,), (2,), (3,)] def test_execute_multiple_statements(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [5, 6, 7, 8]}) + df = pd.DataFrame({"a": [5, 6, 7, 8]}) sql = """ select * from df; select * from VALUES (1),(2),(3),(4) t(a); diff --git a/tests/fast/api/test_duckdb_query.py b/tests/fast/api/test_duckdb_query.py index 43f36603..2ecfd8f3 100644 --- a/tests/fast/api/test_duckdb_query.py +++ b/tests/fast/api/test_duckdb_query.py @@ -7,38 +7,38 @@ class TestDuckDBQuery(object): def test_duckdb_query(self, duckdb_cursor): # we can use duckdb_cursor.sql to run both DDL statements and select statements - duckdb_cursor.sql('create view v1 as select 42 i') - rel = duckdb_cursor.sql('select * from v1') + duckdb_cursor.sql("create view v1 as select 42 i") + rel = duckdb_cursor.sql("select * from v1") assert rel.fetchall()[0][0] == 42 # also multiple statements - duckdb_cursor.sql('create view v2 as select i*2 j from v1; create view v3 as select j * 2 from v2;') - rel = duckdb_cursor.sql('select * from v3') + duckdb_cursor.sql("create view v2 as select i*2 j from v1; create view v3 as select j * 2 from v2;") + rel = duckdb_cursor.sql("select * from v3") assert rel.fetchall()[0][0] == 168 # we can run multiple select statements - we get only the last result - res = duckdb_cursor.sql('select 42; select 84;').fetchall() + res = duckdb_cursor.sql("select 42; select 84;").fetchall() assert res == [(84,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_duckdb_from_query_multiple_statements(self, pandas): - tst_df = pandas.DataFrame({'a': [1, 23, 3, 5]}) + tst_df = pandas.DataFrame({"a": [1, 23, 3, 5]}) res = duckdb.sql( - ''' + """ select 42; select * from tst_df union all select * from tst_df; - ''' + """ ).fetchall() assert res == [(1,), (23,), (3,), (5,), (1,), (23,), (3,), (5,)] def test_duckdb_query_empty_result(self): con = duckdb.connect() # show tables on empty connection does not produce any tuples - res = con.query('show tables').fetchall() + res = con.query("show tables").fetchall() assert res == [] def test_parametrized_explain(self, duckdb_cursor): @@ -57,7 +57,7 @@ def test_parametrized_explain(self, duckdb_cursor): duckdb_cursor.execute(query, params) results = duckdb_cursor.fetchall() - assert 'EXPLAIN_ANALYZE' in results[0][1] + assert "EXPLAIN_ANALYZE" in results[0][1] def test_named_param(self): con = duckdb.connect() @@ -83,7 +83,7 @@ def test_named_param(self): from range(100) tbl(i) """, - {'param': 5, 'other_param': 10}, + {"param": 5, "other_param": 10}, ).fetchall() assert res == original_res @@ -95,14 +95,14 @@ def test_named_param_not_dict(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: name1, name2, name3", ): - con.execute("select $name1, $name2, $name3", ['name1', 'name2', 'name3']) + con.execute("select $name1, $name2, $name3", ["name1", "name2", "name3"]) def test_named_param_basic(self): con = duckdb.connect() - res = con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'name3': 'a'}).fetchall() + res = con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3, "name3": "a"}).fetchall() assert res == [ - (5, 3, 'a'), + (5, 3, "a"), ] def test_named_param_not_exhaustive(self): @@ -112,7 +112,7 @@ def test_named_param_not_exhaustive(self): duckdb.InvalidInputException, match="Invalid Input Error: Values were not provided for the following prepared statement parameters: name3", ): - con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3}) + con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3}) def test_named_param_excessive(self): con = duckdb.connect() @@ -121,7 +121,7 @@ def test_named_param_excessive(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: name3", ): - con.execute("select $name1, $name2, $name3", {'name1': 5, 'name2': 3, 'not_a_named_param': 5}) + con.execute("select $name1, $name2, $name3", {"name1": 5, "name2": 3, "not_a_named_param": 5}) def test_named_param_not_named(self): con = duckdb.connect() @@ -130,7 +130,7 @@ def test_named_param_not_named(self): duckdb.InvalidInputException, match="Values were not provided for the following prepared statement parameters: 1, 2", ): - con.execute("select $1, $1, $2", {'name1': 5, 'name2': 3}) + con.execute("select $1, $1, $2", {"name1": 5, "name2": 3}) def test_named_param_mixed(self): con = duckdb.connect() @@ -138,13 +138,13 @@ def test_named_param_mixed(self): with pytest.raises( duckdb.NotImplementedException, match="Mixing named and positional parameters is not supported yet" ): - con.execute("select $name1, $1, $2", {'name1': 5, 'name2': 3}) + con.execute("select $name1, $1, $2", {"name1": 5, "name2": 3}) def test_named_param_strings_with_dollarsign(self): con = duckdb.connect() - res = con.execute("select '$name1', $name1, $name1, '$name1'", {'name1': 5}).fetchall() - assert res == [('$name1', 5, 5, '$name1')] + res = con.execute("select '$name1', $name1, $name1, '$name1'", {"name1": 5}).fetchall() + assert res == [("$name1", 5, 5, "$name1")] def test_named_param_case_insensivity(self): con = duckdb.connect() @@ -153,10 +153,10 @@ def test_named_param_case_insensivity(self): """ select $NaMe1, $NAME2, $name3 """, - {'name1': 5, 'nAmE2': 3, 'NAME3': 'a'}, + {"name1": 5, "nAmE2": 3, "NAME3": "a"}, ).fetchall() assert res == [ - (5, 3, 'a'), + (5, 3, "a"), ] def test_named_param_keyword(self): @@ -176,16 +176,16 @@ def test_conversion_from_tuple(self): assert result == [([21, 22, 42],)] # If wrapped in a Value, it can convert to a struct - result = con.execute("select $1", [Value(('a', 21, True), {'a': str, 'b': int, 'c': bool})]).fetchall() - assert result == [({'a': 'a', 'b': 21, 'c': True},)] + result = con.execute("select $1", [Value(("a", 21, True), {"a": str, "b": int, "c": bool})]).fetchall() + assert result == [({"a": "a", "b": 21, "c": True},)] # If the amount of items in the tuple and the children of the struct don't match # we throw an error with pytest.raises( duckdb.InvalidInputException, - match='Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children', + match="Tried to create a STRUCT value from a tuple containing 3 elements, but the STRUCT consists of 2 children", ): - result = con.execute("select $1", [Value(('a', 21, True), {'a': str, 'b': int})]).fetchall() + result = con.execute("select $1", [Value(("a", 21, True), {"a": str, "b": int})]).fetchall() # If we try to create anything other than a STRUCT or a LIST out of the tuple, we throw an error with pytest.raises(duckdb.InvalidInputException, match="Can't convert tuple to a Value of type VARCHAR"): @@ -194,12 +194,12 @@ def test_conversion_from_tuple(self): def test_column_name_behavior(self, duckdb_cursor): _ = pytest.importorskip("pandas") - expected_names = ['one', 'ONE_1'] + expected_names = ["one", "ONE_1"] df = duckdb_cursor.execute('select 1 as one, 2 as "ONE"').fetchdf() assert expected_names == list(df.columns) - duckdb_cursor.register('tbl', df) + duckdb_cursor.register("tbl", df) df = duckdb_cursor.execute("select * from tbl").fetchdf() assert expected_names == list(df.columns) diff --git a/tests/fast/api/test_explain.py b/tests/fast/api/test_explain.py index 73c198b9..feedc134 100644 --- a/tests/fast/api/test_explain.py +++ b/tests/fast/api/test_explain.py @@ -4,40 +4,40 @@ class TestExplain(object): def test_explain_basic(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain() + res = duckdb_cursor.sql("select 42").explain() assert isinstance(res, str) def test_explain_standard(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain('standard') + res = duckdb_cursor.sql("select 42").explain("standard") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain('STANDARD') + res = duckdb_cursor.sql("select 42").explain("STANDARD") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.STANDARD) + res = duckdb_cursor.sql("select 42").explain(duckdb.STANDARD) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.ExplainType.STANDARD) + res = duckdb_cursor.sql("select 42").explain(duckdb.ExplainType.STANDARD) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(0) + res = duckdb_cursor.sql("select 42").explain(0) assert isinstance(res, str) def test_explain_analyze(self, duckdb_cursor): - res = duckdb_cursor.sql('select 42').explain('analyze') + res = duckdb_cursor.sql("select 42").explain("analyze") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain('ANALYZE') + res = duckdb_cursor.sql("select 42").explain("ANALYZE") assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(duckdb.ExplainType.ANALYZE) + res = duckdb_cursor.sql("select 42").explain(duckdb.ExplainType.ANALYZE) assert isinstance(res, str) - res = duckdb_cursor.sql('select 42').explain(1) + res = duckdb_cursor.sql("select 42").explain(1) assert isinstance(res, str) def test_explain_df(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [42]}) - res = duckdb_cursor.sql('select * from df').explain('ANALYZE') + df = pd.DataFrame({"a": [42]}) + res = duckdb_cursor.sql("select * from df").explain("ANALYZE") assert isinstance(res, str) diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index a878fda5..7b797598 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -49,7 +49,7 @@ def __init__(self) -> None: self._data = {"a": parquet_data, "b": parquet_data} fsspec.register_implementation("deadlock", TestFileSystem, clobber=True) - fs = fsspec.filesystem('deadlock') + fs = fsspec.filesystem("deadlock") duckdb_cursor.register_filesystem(fs) result = duckdb_cursor.read_parquet(file_globs=["deadlock://a", "deadlock://b"], union_by_name=True) diff --git a/tests/fast/api/test_insert_into.py b/tests/fast/api/test_insert_into.py index e6d4c6ba..2537c182 100644 --- a/tests/fast/api/test_insert_into.py +++ b/tests/fast/api/test_insert_into.py @@ -7,22 +7,22 @@ class TestInsertInto(object): def test_insert_into_schema(self, duckdb_cursor): # open connection con = duckdb.connect() - con.execute('CREATE SCHEMA s') - con.execute('CREATE TABLE s.t (id INTEGER PRIMARY KEY)') + con.execute("CREATE SCHEMA s") + con.execute("CREATE TABLE s.t (id INTEGER PRIMARY KEY)") # make relation - df = DataFrame([1], columns=['id']) + df = DataFrame([1], columns=["id"]) rel = con.from_df(df) - rel.insert_into('s.t') + rel.insert_into("s.t") assert con.execute("select * from s.t").fetchall() == [(1,)] # This should fail since this will go to default schema with pytest.raises(duckdb.CatalogException): - rel.insert_into('t') + rel.insert_into("t") # If we add t in the default schema it should work. - con.execute('CREATE TABLE t (id INTEGER PRIMARY KEY)') - rel.insert_into('t') + con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)") + rel.insert_into("t") assert con.execute("select * from t").fetchall() == [(1,)] diff --git a/tests/fast/api/test_join.py b/tests/fast/api/test_join.py index 7d7f45c2..5e2a148f 100644 --- a/tests/fast/api/test_join.py +++ b/tests/fast/api/test_join.py @@ -8,7 +8,7 @@ def test_alias_from_sql(self): rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") rel2 = con.sql("SELECT 1 AS col1, 3 AS col3") - rel = con.sql('select * from rel1 JOIN rel2 USING (col1)') + rel = con.sql("select * from rel1 JOIN rel2 USING (col1)") rel.show() res = rel.fetchall() assert res == [(1, 2, 3)] @@ -19,27 +19,27 @@ def test_relational_join(self): rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") rel2 = con.sql("SELECT 1 AS col1, 3 AS col3") - rel = rel1.join(rel2, 'col1') + rel = rel1.join(rel2, "col1") res = rel.fetchall() assert res == [(1, 2, 3)] def test_relational_join_alias_collision(self): con = duckdb.connect() - rel1 = con.sql("SELECT 1 AS col1, 2 AS col2").set_alias('a') - rel2 = con.sql("SELECT 1 AS col1, 3 AS col3").set_alias('a') + rel1 = con.sql("SELECT 1 AS col1, 2 AS col2").set_alias("a") + rel2 = con.sql("SELECT 1 AS col1, 3 AS col3").set_alias("a") - with pytest.raises(duckdb.InvalidInputException, match='Both relations have the same alias'): - rel = rel1.join(rel2, 'col1') + with pytest.raises(duckdb.InvalidInputException, match="Both relations have the same alias"): + rel = rel1.join(rel2, "col1") def test_relational_join_with_condition(self): con = duckdb.connect() - rel1 = con.sql("SELECT 1 AS col1, 2 AS col2", alias='rel1') - rel2 = con.sql("SELECT 1 AS col1, 3 AS col3", alias='rel2') + rel1 = con.sql("SELECT 1 AS col1, 2 AS col2", alias="rel1") + rel2 = con.sql("SELECT 1 AS col1, 3 AS col3", alias="rel2") # This makes a USING clause, which is kind of unexpected behavior - rel = rel1.join(rel2, 'rel1.col1 = rel2.col1') + rel = rel1.join(rel2, "rel1.col1 = rel2.col1") rel.show() res = rel.fetchall() assert res == [(1, 2, 1, 3)] @@ -49,8 +49,8 @@ def test_deduplicated_bindings(self, duckdb_cursor): duckdb_cursor.execute("create table old as select * from (values ('42', 1), ('21', 2)) t(a, b)") duckdb_cursor.execute("create table old_1 as select * from (values ('42', 3), ('21', 4)) t(a, b)") - old = duckdb_cursor.table('old') - old_1 = duckdb_cursor.table('old_1') + old = duckdb_cursor.table("old") + old_1 = duckdb_cursor.table("old_1") join_one = old.join(old_1, "old.a == old_1.a") join_two = old.join(old_1, "old.a == old_1.a") diff --git a/tests/fast/api/test_native_tz.py b/tests/fast/api/test_native_tz.py index 6098ca08..f4a9d716 100644 --- a/tests/fast/api/test_native_tz.py +++ b/tests/fast/api/test_native_tz.py @@ -8,7 +8,7 @@ pa = pytest.importorskip("pyarrow") from packaging.version import Version -filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'tz.parquet') +filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "data", "tz.parquet") class TestNativeTimeZone(object): @@ -16,20 +16,20 @@ def test_native_python_timestamp_timezone(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].tzinfo.zone == "America/Los_Angeles" res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchall()[0] assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].tzinfo.zone == "America/Los_Angeles" res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchmany(1)[0] assert res[0].hour == 14 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'America/Los_Angeles' + assert res[0].tzinfo.zone == "America/Los_Angeles" duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() assert res[0].hour == 21 and res[0].minute == 52 - assert res[0].tzinfo.zone == 'UTC' + assert res[0].tzinfo.zone == "UTC" def test_native_python_time_timezone(self, duckdb_cursor): res = duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").fetchone() @@ -41,33 +41,33 @@ def test_native_python_time_timezone(self, duckdb_cursor): def test_pandas_timestamp_timezone(self, duckdb_cursor): res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").df() - assert res.dtypes["tz"].tz.zone == 'America/Los_Angeles' - assert res['tz'][0].hour == 14 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert res["tz"][0].hour == 14 and res["tz"][0].minute == 52 duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").df() - assert res['tz'][0].hour == 21 and res['tz'][0].minute == 52 + assert res["tz"][0].hour == 21 and res["tz"][0].minute == 52 def test_pandas_timestamp_time(self, duckdb_cursor): with pytest.raises( - duckdb.NotImplementedException, match="Not implemented Error: Unsupported type \"TIME WITH TIME ZONE\"" + duckdb.NotImplementedException, match='Not implemented Error: Unsupported type "TIME WITH TIME ZONE"' ): duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").df() @pytest.mark.skipif( - Version(pa.__version__) < Version('15.0.0'), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" + Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) def test_arrow_timestamp_timezone(self, duckdb_cursor): res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") table = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table() res = table.to_pandas() - assert res.dtypes["tz"].tz.zone == 'America/Los_Angeles' - assert res['tz'][0].hour == 14 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert res["tz"][0].hour == 14 and res["tz"][0].minute == 52 duckdb_cursor.execute("SET timezone='UTC';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table().to_pandas() - assert res.dtypes["tz"].tz.zone == 'UTC' - assert res['tz'][0].hour == 21 and res['tz'][0].minute == 52 + assert res.dtypes["tz"].tz.zone == "UTC" + assert res["tz"][0].hour == 21 and res["tz"][0].minute == 52 def test_arrow_timestamp_time(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") @@ -81,8 +81,8 @@ def test_arrow_timestamp_time(self, duckdb_cursor): .fetch_arrow_table() .to_pandas() ) - assert res1['tz'][0].hour == 14 and res1['tz'][0].minute == 52 - assert res2['tz'][0].hour == res2['tz'][0].hour and res2['tz'][0].minute == res1['tz'][0].minute + assert res1["tz"][0].hour == 14 and res1["tz"][0].minute == 52 + assert res2["tz"][0].hour == res2["tz"][0].hour and res2["tz"][0].minute == res1["tz"][0].minute duckdb_cursor.execute("SET timezone='UTC';") res1 = ( @@ -95,5 +95,5 @@ def test_arrow_timestamp_time(self, duckdb_cursor): .fetch_arrow_table() .to_pandas() ) - assert res1['tz'][0].hour == 21 and res1['tz'][0].minute == 52 - assert res2['tz'][0].hour == res2['tz'][0].hour and res2['tz'][0].minute == res1['tz'][0].minute + assert res1["tz"][0].hour == 21 and res1["tz"][0].minute == 52 + assert res2["tz"][0].hour == res2["tz"][0].hour and res2["tz"][0].minute == res1["tz"][0].minute diff --git a/tests/fast/api/test_query_interrupt.py b/tests/fast/api/test_query_interrupt.py index 6334e475..e6d2b998 100644 --- a/tests/fast/api/test_query_interrupt.py +++ b/tests/fast/api/test_query_interrupt.py @@ -25,7 +25,7 @@ def test_query_interruption(self): # Start the thread thread.start() try: - res = con.execute('select count(*) from range(100000000000)').fetchall() + res = con.execute("select count(*) from range(100000000000)").fetchall() except RuntimeError: # If this is not reached, we could not cancel the query before it completed # indicating that the query interruption functionality is broken diff --git a/tests/fast/api/test_read_csv.py b/tests/fast/api/test_read_csv.py index 7337515d..dff90869 100644 --- a/tests/fast/api/test_read_csv.py +++ b/tests/fast/api/test_read_csv.py @@ -11,7 +11,7 @@ def TestFile(name): import os - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', name) + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "data", name) return filename @@ -35,262 +35,262 @@ def create_temp_csv(tmp_path): class TestReadCSV(object): def test_using_connection_wrapper(self): - rel = duckdb.read_csv(TestFile('category.csv')) + rel = duckdb.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_using_connection_wrapper_with_keyword(self): - rel = duckdb.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + rel = duckdb.read_csv(TestFile("category.csv"), dtype={"category_id": "string"}) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_no_options(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv')) + rel = duckdb_cursor.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_dtype(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype={'category_id': 'string'}) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype={"category_id": "string"}) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_dtype_as_list(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['string']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype=["string"]) res = rel.fetchone() print(res) - assert res == ('1', 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == ("1", "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) - rel = duckdb_cursor.read_csv(TestFile('category.csv'), dtype=['double']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), dtype=["double"]) res = rel.fetchone() print(res) - assert res == (1.0, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1.0, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_sep(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), sep=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), sep=" ") res = rel.fetchone() print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + assert res == ("1|Action|2006-02-15", datetime.time(4, 46, 27)) def test_delimiter(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), delimiter=" ") res = rel.fetchone() print(res) - assert res == ('1|Action|2006-02-15', datetime.time(4, 46, 27)) + assert res == ("1|Action|2006-02-15", datetime.time(4, 46, 27)) def test_delimiter_and_sep(self, duckdb_cursor): with pytest.raises(duckdb.InvalidInputException, match="read_csv takes either 'delimiter' or 'sep', not both"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), delimiter=" ", sep=" ") + rel = duckdb_cursor.read_csv(TestFile("category.csv"), delimiter=" ", sep=" ") def test_header_true(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv')) + rel = duckdb_cursor.read_csv(TestFile("category.csv")) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) @pytest.mark.skip(reason="Issue #6011 needs to be fixed first, header=False doesn't work correctly") def test_header_false(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), header=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), header=False) def test_na_values(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values='Action') + rel = duckdb_cursor.read_csv(TestFile("category.csv"), na_values="Action") res = rel.fetchone() print(res) assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_na_values_list(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), na_values=['Action', 'Animation']) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), na_values=["Action", "Animation"]) res = rel.fetchone() assert res == (1, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) res = rel.fetchone() assert res == (2, None, datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_skiprows(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), skiprows=1) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), skiprows=1) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) # We want to detect this at bind time def test_compression_wrong(self, duckdb_cursor): with pytest.raises(duckdb.Error, match="Input is not a GZIP stream"): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), compression='gzip') + rel = duckdb_cursor.read_csv(TestFile("category.csv"), compression="gzip") def test_quotechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quotechar="", header=False) + rel = duckdb_cursor.read_csv(TestFile("unquote_without_delimiter.csv"), quotechar="", header=False) res = rel.fetchone() print(res) assert res == ('"AAA"BB',) def test_quote(self, duckdb_cursor): with pytest.raises( - duckdb.Error, match="The methods read_csv and read_csv_auto do not have the \"quote\" argument." + duckdb.Error, match='The methods read_csv and read_csv_auto do not have the "quote" argument.' ): - rel = duckdb_cursor.read_csv(TestFile('unquote_without_delimiter.csv'), quote="", header=False) + rel = duckdb_cursor.read_csv(TestFile("unquote_without_delimiter.csv"), quote="", header=False) def test_escapechar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), escapechar=";", header=False) + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), escapechar=";", header=False) res = rel.limit(1, 1).fetchone() print(res) - assert res == ('345', 'TEST6', '"text""2""text"') + assert res == ("345", "TEST6", '"text""2""text"') def test_encoding_wrong(self, duckdb_cursor): with pytest.raises( duckdb.BinderException, match="Copy is only supported for UTF-8 encoded files, ENCODING 'UTF-8'" ): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding=";") + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), encoding=";") def test_encoding_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('quote_escape.csv'), encoding="UTF-8") + rel = duckdb_cursor.read_csv(TestFile("quote_escape.csv"), encoding="UTF-8") res = rel.limit(1, 1).fetchone() print(res) - assert res == (345, 'TEST6', 'text"2"text') + assert res == (345, "TEST6", 'text"2"text') def test_date_format_as_datetime(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv')) + rel = duckdb_cursor.read_csv(TestFile("datetime.csv")) res = rel.fetchone() print(res) assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_date_format_as_date(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), date_format='%Y-%m-%d') + rel = duckdb_cursor.read_csv(TestFile("datetime.csv"), date_format="%Y-%m-%d") res = rel.fetchone() print(res) assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_timestamp_format(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('datetime.csv'), timestamp_format='%Y-%m-%d %H:%M:%S') + rel = duckdb_cursor.read_csv(TestFile("datetime.csv"), timestamp_format="%Y-%m-%d %H:%M:%S") res = rel.fetchone() assert res == ( 123, - 'TEST2', + "TEST2", datetime.time(12, 12, 12), datetime.date(2000, 1, 1), datetime.datetime(2000, 1, 1, 12, 12), ) def test_sample_size_correct(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('problematic.csv'), sample_size=-1) + rel = duckdb_cursor.read_csv(TestFile("problematic.csv"), sample_size=-1) res = rel.fetchone() print(res) - assert res == ('1', '1', '1') + assert res == ("1", "1", "1") def test_all_varchar(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), all_varchar=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), all_varchar=True) res = rel.fetchone() print(res) - assert res == ('1', 'Action', '2006-02-15 04:46:27') + assert res == ("1", "Action", "2006-02-15 04:46:27") def test_null_padding(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb_cursor.read_csv(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb_cursor.read_csv(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb_cursor.read_csv(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb.read_csv(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb.read_csv(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb.read_csv(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=False, header=False) + rel = duckdb_cursor.from_csv_auto(TestFile("nullpadding.csv"), null_padding=False, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top',), - ('one,two,three,four',), - ('1,a,alice',), - ('2,b,bob',), + ("# this file has a bunch of gunk at the top",), + ("one,two,three,four",), + ("1,a,alice",), + ("2,b,bob",), ] - rel = duckdb_cursor.from_csv_auto(TestFile('nullpadding.csv'), null_padding=True, header=False) + rel = duckdb_cursor.from_csv_auto(TestFile("nullpadding.csv"), null_padding=True, header=False) res = rel.fetchall() assert res == [ - ('# this file has a bunch of gunk at the top', None, None, None), - ('one', 'two', 'three', 'four'), - ('1', 'a', 'alice', None), - ('2', 'b', 'bob', None), + ("# this file has a bunch of gunk at the top", None, None, None), + ("one", "two", "three", "four"), + ("1", "a", "alice", None), + ("2", "b", "bob", None), ] def test_normalize_names(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), normalize_names=False) df = rel.df() column_names = list(df.columns.values) # The names are not normalized, so they are capitalized - assert 'CATEGORY_ID' in column_names + assert "CATEGORY_ID" in column_names - rel = duckdb_cursor.read_csv(TestFile('category.csv'), normalize_names=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), normalize_names=True) df = rel.df() column_names = list(df.columns.values) # The capitalized names are normalized to lowercase instead - assert 'CATEGORY_ID' not in column_names + assert "CATEGORY_ID" not in column_names def test_filename(self, duckdb_cursor): - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=False) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), filename=False) df = rel.df() column_names = list(df.columns.values) # The filename is not included in the returned columns - assert 'filename' not in column_names + assert "filename" not in column_names - rel = duckdb_cursor.read_csv(TestFile('category.csv'), filename=True) + rel = duckdb_cursor.read_csv(TestFile("category.csv"), filename=True) df = rel.df() column_names = list(df.columns.values) # The filename is included in the returned columns - assert 'filename' in column_names + assert "filename" in column_names def test_read_pathlib_path(self, duckdb_cursor): pathlib = pytest.importorskip("pathlib") - path = pathlib.Path(TestFile('category.csv')) + path = pathlib.Path(TestFile("category.csv")) rel = duckdb_cursor.read_csv(path) res = rel.fetchone() print(res) - assert res == (1, 'Action', datetime.datetime(2006, 2, 15, 4, 46, 27)) + assert res == (1, "Action", datetime.datetime(2006, 2, 15, 4, 46, 27)) def test_read_filelike(self, duckdb_cursor): pytest.importorskip("fsspec") string = StringIO("c1,c2,c3\na,b,c") res = duckdb_cursor.read_csv(string).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_read_filelike_rel_out_of_scope(self, duckdb_cursor): _ = pytest.importorskip("fsspec") @@ -321,7 +321,7 @@ def test_filelike_bytesio(self, duckdb_cursor): _ = pytest.importorskip("fsspec") string = BytesIO(b"c1,c2,c3\na,b,c") res = duckdb_cursor.read_csv(string).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_filelike_exception(self, duckdb_cursor): _ = pytest.importorskip("fsspec") @@ -341,7 +341,7 @@ def __init__(self) -> None: pass def read(self, amount=-1): - return b'test' + return b"test" def seek(self, loc): raise ValueError(loc) @@ -377,7 +377,7 @@ def read(self, amount=-1): obj = CustomIO() res = duckdb_cursor.read_csv(obj).fetchall() - assert res == [('a', 'b', 'c')] + assert res == [("a", "b", "c")] def test_filelike_non_readable(self, duckdb_cursor): _ = pytest.importorskip("fsspec") @@ -410,9 +410,9 @@ def scoped_objects(duckdb_cursor): rel1 = duckdb_cursor.read_csv(obj) assert rel1.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 1 @@ -421,9 +421,9 @@ def scoped_objects(duckdb_cursor): rel2 = duckdb_cursor.read_csv(obj) assert rel2.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 2 @@ -432,9 +432,9 @@ def scoped_objects(duckdb_cursor): rel3 = duckdb_cursor.read_csv(obj) assert rel3.fetchall() == [ ( - 'a', - 'b', - 'c', + "a", + "b", + "c", ) ] assert CountedObject.instance_count == 3 @@ -448,24 +448,24 @@ def test_read_csv_glob(self, tmp_path, create_temp_csv): # Use the temporary file paths to read CSV files con = duckdb.connect() - rel = con.read_csv(f'{tmp_path}/file*.csv') + rel = con.read_csv(f"{tmp_path}/file*.csv") res = con.sql("select * from rel order by all").fetchall() assert res == [(1,), (2,), (3,), (4,), (5,), (6,)] @pytest.mark.xfail(condition=platform.system() == "Emscripten", reason="time zones not working") def test_read_csv_combined(self, duckdb_cursor): - CSV_FILE = TestFile('stress_test.csv') + CSV_FILE = TestFile("stress_test.csv") COLUMNS = { - 'result': 'VARCHAR', - 'table': 'BIGINT', - '_time': 'TIMESTAMPTZ', - '_measurement': 'VARCHAR', - 'bench_test': 'VARCHAR', - 'flight_id': 'VARCHAR', - 'flight_status': 'VARCHAR', - 'log_level': 'VARCHAR', - 'sys_uuid': 'VARCHAR', - 'message': 'VARCHAR', + "result": "VARCHAR", + "table": "BIGINT", + "_time": "TIMESTAMPTZ", + "_measurement": "VARCHAR", + "bench_test": "VARCHAR", + "flight_id": "VARCHAR", + "flight_status": "VARCHAR", + "log_level": "VARCHAR", + "sys_uuid": "VARCHAR", + "message": "VARCHAR", } rel = duckdb.read_csv(CSV_FILE, skiprows=1, delimiter=",", quotechar='"', escapechar="\\", dtype=COLUMNS) @@ -483,39 +483,39 @@ def test_read_csv_combined(self, duckdb_cursor): def test_read_csv_names(self, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") con = duckdb.connect() - rel = con.read_csv(str(file), names=['a', 'b', 'c']) - assert rel.columns == ['a', 'b', 'c', 'four'] + rel = con.read_csv(str(file), names=["a", "b", "c"]) + assert rel.columns == ["a", "b", "c", "four"] with pytest.raises(duckdb.InvalidInputException, match="read_csv only accepts 'names' as a list of strings"): rel = con.read_csv(file, names=True) with pytest.raises(duckdb.InvalidInputException, match="not possible to detect the CSV Header"): - rel = con.read_csv(file, names=['a', 'b', 'c', 'd', 'e']) + rel = con.read_csv(file, names=["a", "b", "c", "d", "e"]) # Duplicates are not okay with pytest.raises(duckdb.BinderException, match="names must have unique values"): - rel = con.read_csv(file, names=['a', 'b', 'a', 'b']) - assert rel.columns == ['a', 'b', 'a', 'b'] + rel = con.read_csv(file, names=["a", "b", "a", "b"]) + assert rel.columns == ["a", "b", "a", "b"] def test_read_csv_names_mixed_with_dtypes(self, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") con = duckdb.connect() rel = con.read_csv( file, - names=['a', 'b', 'c'], + names=["a", "b", "c"], dtype={ - 'a': int, - 'b': bool, - 'c': str, + "a": int, + "b": bool, + "c": str, }, ) - assert rel.columns == ['a', 'b', 'c', 'four'] - assert rel.types == ['BIGINT', 'BOOLEAN', 'VARCHAR', 'BIGINT'] + assert rel.columns == ["a", "b", "c", "four"] + assert rel.types == ["BIGINT", "BOOLEAN", "VARCHAR", "BIGINT"] # dtypes and names dont match # FIXME: seems the order columns are named in this error is non-deterministic @@ -524,23 +524,23 @@ def test_read_csv_names_mixed_with_dtypes(self, tmp_path): with pytest.raises(duckdb.BinderException, match=expected_error): rel = con.read_csv( file, - names=['a', 'b', 'c'], + names=["a", "b", "c"], dtype={ - 'd': int, - 'e': bool, - 'f': str, + "d": int, + "e": bool, + "f": str, }, ) def test_read_csv_multi_file(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file1.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") file2 = tmp_path / "file2.csv" - file2.write_text('one,two,three,four\n5,6,7,8\n5,6,7,8\n5,6,7,8') + file2.write_text("one,two,three,four\n5,6,7,8\n5,6,7,8\n5,6,7,8") file3 = tmp_path / "file3.csv" - file3.write_text('one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12') + file3.write_text("one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12") con = duckdb.connect() files = [str(file1), str(file2), str(file3)] @@ -562,72 +562,72 @@ def test_read_csv_empty_list(self): con = duckdb.connect() files = [] with pytest.raises( - duckdb.InvalidInputException, match='Please provide a non-empty list of paths or file-like objects' + duckdb.InvalidInputException, match="Please provide a non-empty list of paths or file-like objects" ): rel = con.read_csv(files) res = rel.fetchall() def test_read_auto_detect(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4') + file1.write_text("one|two|three|four\n1|2|3|4") con = duckdb.connect() - rel = con.read_csv(str(file1), columns={'a': 'VARCHAR'}, auto_detect=False, header=False) - assert rel.fetchall() == [('one|two|three|four',), ('1|2|3|4',)] + rel = con.read_csv(str(file1), columns={"a": "VARCHAR"}, auto_detect=False, header=False) + assert rel.fetchall() == [("one|two|three|four",), ("1|2|3|4",)] def test_read_csv_list_invalid_path(self, tmp_path): con = duckdb.connect() file1 = tmp_path / "file1.csv" - file1.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file1.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") file3 = tmp_path / "file3.csv" - file3.write_text('one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12') + file3.write_text("one,two,three,four\n9,10,11,12\n9,10,11,12\n9,10,11,12") - files = [str(file1), 'not_valid_path', str(file3)] + files = [str(file1), "not_valid_path", str(file3)] with pytest.raises(duckdb.IOException, match='No files found that match the pattern "not_valid_path"'): rel = con.read_csv(files) res = rel.fetchall() @pytest.mark.parametrize( - 'options', + "options", [ - {'lineterminator': '\\n'}, - {'lineterminator': 'LINE_FEED'}, - {'lineterminator': CSVLineTerminator.LINE_FEED}, - {'columns': {'id': 'INTEGER', 'name': 'INTEGER', 'c': 'integer', 'd': 'INTEGER'}}, - {'auto_type_candidates': ['INTEGER', 'INTEGER']}, - {'max_line_size': 10000}, - {'ignore_errors': True}, - {'ignore_errors': False}, - {'store_rejects': True}, - {'store_rejects': False}, - {'rejects_table': 'my_rejects_table'}, - {'rejects_scan': 'my_rejects_scan'}, - {'rejects_table': 'my_rejects_table', 'rejects_limit': 50}, - {'force_not_null': ['one', 'two']}, - {'buffer_size': 2097153}, - {'decimal': '.'}, - {'allow_quoted_nulls': True}, - {'allow_quoted_nulls': False}, - {'filename': True}, - {'filename': 'test'}, - {'hive_partitioning': True}, - {'hive_partitioning': False}, - {'union_by_name': True}, - {'union_by_name': False}, - {'hive_types_autocast': False}, - {'hive_types_autocast': True}, - {'hive_types': {'one': 'INTEGER', 'two': 'VARCHAR'}}, + {"lineterminator": "\\n"}, + {"lineterminator": "LINE_FEED"}, + {"lineterminator": CSVLineTerminator.LINE_FEED}, + {"columns": {"id": "INTEGER", "name": "INTEGER", "c": "integer", "d": "INTEGER"}}, + {"auto_type_candidates": ["INTEGER", "INTEGER"]}, + {"max_line_size": 10000}, + {"ignore_errors": True}, + {"ignore_errors": False}, + {"store_rejects": True}, + {"store_rejects": False}, + {"rejects_table": "my_rejects_table"}, + {"rejects_scan": "my_rejects_scan"}, + {"rejects_table": "my_rejects_table", "rejects_limit": 50}, + {"force_not_null": ["one", "two"]}, + {"buffer_size": 2097153}, + {"decimal": "."}, + {"allow_quoted_nulls": True}, + {"allow_quoted_nulls": False}, + {"filename": True}, + {"filename": "test"}, + {"hive_partitioning": True}, + {"hive_partitioning": False}, + {"union_by_name": True}, + {"union_by_name": False}, + {"hive_types_autocast": False}, + {"hive_types_autocast": True}, + {"hive_types": {"one": "INTEGER", "two": "VARCHAR"}}, ], ) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Skipping on Windows because of lineterminator option") def test_read_csv_options(self, duckdb_cursor, options, tmp_path): file = tmp_path / "file.csv" - file.write_text('one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4') + file.write_text("one,two,three,four\n1,2,3,4\n1,2,3,4\n1,2,3,4") print(options) - if 'hive_types' in options: - with pytest.raises(duckdb.InvalidInputException, match=r'Unknown hive_type:'): + if "hive_types" in options: + with pytest.raises(duckdb.InvalidInputException, match=r"Unknown hive_type:"): rel = duckdb_cursor.read_csv(file, **options) else: rel = duckdb_cursor.read_csv(file, **options) @@ -635,73 +635,73 @@ def test_read_csv_options(self, duckdb_cursor, options, tmp_path): def test_read_comment(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4#|5|6\n#bla\n1|2|3|4\n') + file1.write_text("one|two|three|four\n1|2|3|4#|5|6\n#bla\n1|2|3|4\n") con = duckdb.connect() - rel = con.read_csv(str(file1), columns={'a': 'VARCHAR'}, auto_detect=False, header=False, comment='#') - assert rel.fetchall() == [('one|two|three|four',), ('1|2|3|4',), ('1|2|3|4',)] + rel = con.read_csv(str(file1), columns={"a": "VARCHAR"}, auto_detect=False, header=False, comment="#") + assert rel.fetchall() == [("one|two|three|four",), ("1|2|3|4",), ("1|2|3|4",)] def test_read_enum(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('feelings\nhappy\nsad\nangry\nhappy\n') + file1.write_text("feelings\nhappy\nsad\nangry\nhappy\n") con = duckdb.connect() con.execute("CREATE TYPE mood AS ENUM ('happy', 'sad', 'angry')") - rel = con.read_csv(str(file1), dtype=['mood']) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), dtype=["mood"]) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] - rel = con.read_csv(str(file1), dtype={'feelings': 'mood'}) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), dtype={"feelings": "mood"}) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] - rel = con.read_csv(str(file1), columns={'feelings': 'mood'}) - assert rel.fetchall() == [('happy',), ('sad',), ('angry',), ('happy',)] + rel = con.read_csv(str(file1), columns={"feelings": "mood"}) + assert rel.fetchall() == [("happy",), ("sad",), ("angry",), ("happy",)] with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), columns={'feelings': 'mood_2'}) + rel = con.read_csv(str(file1), columns={"feelings": "mood_2"}) with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), dtype={'feelings': 'mood_2'}) + rel = con.read_csv(str(file1), dtype={"feelings": "mood_2"}) with pytest.raises(duckdb.CatalogException, match="Type with name mood_2 does not exist!"): - rel = con.read_csv(str(file1), dtype=['mood_2']) + rel = con.read_csv(str(file1), dtype=["mood_2"]) def test_strict_mode(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4\n1|2|3|4|5\n1|2|3|4\n') + file1.write_text("one|two|three|four\n1|2|3|4\n1|2|3|4|5\n1|2|3|4\n") con = duckdb.connect() with pytest.raises(duckdb.InvalidInputException, match="CSV Error on Line"): rel = con.read_csv( str(file1), header=True, - delimiter='|', - columns={'a': 'INTEGER', 'b': 'INTEGER', 'c': 'INTEGER', 'd': 'INTEGER'}, + delimiter="|", + columns={"a": "INTEGER", "b": "INTEGER", "c": "INTEGER", "d": "INTEGER"}, auto_detect=False, ) rel.fetchall() rel = con.read_csv( str(file1), header=True, - delimiter='|', + delimiter="|", strict_mode=False, - columns={'a': 'INTEGER', 'b': 'INTEGER', 'c': 'INTEGER', 'd': 'INTEGER'}, + columns={"a": "INTEGER", "b": "INTEGER", "c": "INTEGER", "d": "INTEGER"}, auto_detect=False, ) assert rel.fetchall() == [(1, 2, 3, 4), (1, 2, 3, 4), (1, 2, 3, 4)] def test_union_by_name(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('one|two|three|four\n1|2|3|4') + file1.write_text("one|two|three|four\n1|2|3|4") file1 = tmp_path / "file2.csv" - file1.write_text('two|three|four|five\n2|3|4|5') + file1.write_text("two|three|four|five\n2|3|4|5") con = duckdb.connect() file_path = tmp_path / "file*.csv" rel = con.read_csv(file_path, union_by_name=True) - assert rel.columns == ['one', 'two', 'three', 'four', 'five'] + assert rel.columns == ["one", "two", "three", "four", "five"] assert rel.fetchall() == [(1, 2, 3, 4, None), (None, 2, 3, 4, 5)] def test_thousands_separator(self, tmp_path): @@ -709,27 +709,27 @@ def test_thousands_separator(self, tmp_path): file.write_text('money\n"10,000.23"\n"1,000,000,000.01"') con = duckdb.connect() - rel = con.read_csv(file, thousands=',') + rel = con.read_csv(file, thousands=",") assert rel.fetchall() == [(10000.23,), (1000000000.01,)] with pytest.raises( duckdb.BinderException, match="Unsupported parameter for THOUSANDS: should be max one character" ): - con.read_csv(file, thousands=',,,') + con.read_csv(file, thousands=",,,") def test_skip_comment_option(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('skip this line\n# comment\nx,y,z\n1,2,3\n4,5,6') + file1.write_text("skip this line\n# comment\nx,y,z\n1,2,3\n4,5,6") con = duckdb.connect() - rel = con.read_csv(file1, comment='#', skiprows=1, all_varchar=True) - assert rel.columns == ['x', 'y', 'z'] - assert rel.fetchall() == [('1', '2', '3'), ('4', '5', '6')] + rel = con.read_csv(file1, comment="#", skiprows=1, all_varchar=True) + assert rel.columns == ["x", "y", "z"] + assert rel.fetchall() == [("1", "2", "3"), ("4", "5", "6")] def test_files_to_sniff_option(self, tmp_path): file1 = tmp_path / "file1.csv" - file1.write_text('bar,baz\n2025-05-12,baz') + file1.write_text("bar,baz\n2025-05-12,baz") file2 = tmp_path / "file2.csv" - file2.write_text('bar,baz\nbar,baz') + file2.write_text("bar,baz\nbar,baz") file_path = tmp_path / "file*.csv" con = duckdb.connect() @@ -737,4 +737,4 @@ def test_files_to_sniff_option(self, tmp_path): rel = con.read_csv(file_path, files_to_sniff=1) rel.fetchall() rel = con.read_csv(file_path, files_to_sniff=-1) - assert rel.fetchall() == [('2025-05-12', 'baz'), ('bar', 'baz')] + assert rel.fetchall() == [("2025-05-12", "baz"), ("bar", "baz")] diff --git a/tests/fast/api/test_relation_to_view.py b/tests/fast/api/test_relation_to_view.py index f4a43d54..31a19d54 100644 --- a/tests/fast/api/test_relation_to_view.py +++ b/tests/fast/api/test_relation_to_view.py @@ -4,27 +4,27 @@ class TestRelationToView(object): def test_values_to_view(self, duckdb_cursor): - rel = duckdb_cursor.values(['test', 'this is a long string']) + rel = duckdb_cursor.values(["test", "this is a long string"]) res = rel.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] - rel.to_view('vw1') + rel.to_view("vw1") - view = duckdb_cursor.table('vw1') + view = duckdb_cursor.table("vw1") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] def test_relation_to_view(self, duckdb_cursor): rel = duckdb_cursor.sql("select 'test', 'this is a long string'") res = rel.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] - rel.to_view('vw1') + rel.to_view("vw1") - view = duckdb_cursor.table('vw1') + view = duckdb_cursor.table("vw1") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] def test_registered_relation(self, duckdb_cursor): rel = duckdb_cursor.sql("select 'test', 'this is a long string'") @@ -33,12 +33,12 @@ def test_registered_relation(self, duckdb_cursor): # Register on a different connection is not allowed with pytest.raises( duckdb.InvalidInputException, - match='was created by another Connection and can therefore not be used by this Connection', + match="was created by another Connection and can therefore not be used by this Connection", ): - con.register('cross_connection', rel) + con.register("cross_connection", rel) # Register on the same connection just creates a view - duckdb_cursor.register('same_connection', rel) - view = duckdb_cursor.table('same_connection') + duckdb_cursor.register("same_connection", rel) + view = duckdb_cursor.table("same_connection") res = view.fetchall() - assert res == [('test', 'this is a long string')] + assert res == [("test", "this is a long string")] diff --git a/tests/fast/api/test_streaming_result.py b/tests/fast/api/test_streaming_result.py index e51f62e4..739fd17a 100644 --- a/tests/fast/api/test_streaming_result.py +++ b/tests/fast/api/test_streaming_result.py @@ -5,7 +5,7 @@ class TestStreamingResult(object): def test_fetch_one(self, duckdb_cursor): # fetch one - res = duckdb_cursor.sql('SELECT * FROM range(100000)') + res = duckdb_cursor.sql("SELECT * FROM range(100000)") result = [] while len(result) < 5000: tpl = res.fetchone() @@ -24,7 +24,7 @@ def test_fetch_one(self, duckdb_cursor): def test_fetch_many(self, duckdb_cursor): # fetch many - res = duckdb_cursor.sql('SELECT * FROM range(100000)') + res = duckdb_cursor.sql("SELECT * FROM range(100000)") result = [] while len(result) < 5000: tpl = res.fetchmany(10) @@ -45,11 +45,11 @@ def test_record_batch_reader(self, duckdb_cursor): pytest.importorskip("pyarrow") pytest.importorskip("pyarrow.dataset") # record batch reader - res = duckdb_cursor.sql('SELECT * FROM range(100000) t(i)') + res = duckdb_cursor.sql("SELECT * FROM range(100000) t(i)") reader = res.fetch_arrow_reader(batch_size=16_384) result = [] for batch in reader: - result += batch.to_pydict()['i'] + result += batch.to_pydict()["i"] assert result == list(range(100000)) # record batch reader with error @@ -60,9 +60,9 @@ def test_record_batch_reader(self, duckdb_cursor): reader = res.fetch_arrow_reader(batch_size=16_384) def test_9801(self, duckdb_cursor): - duckdb_cursor.execute('CREATE TABLE test(id INTEGER , name VARCHAR NOT NULL);') + duckdb_cursor.execute("CREATE TABLE test(id INTEGER , name VARCHAR NOT NULL);") - words = ['aaaaaaaaaaaaaaaaaaaaaaa', 'bbbb', 'ccccccccc', 'ííííííííí'] + words = ["aaaaaaaaaaaaaaaaaaaaaaa", "bbbb", "ccccccccc", "ííííííííí"] lines = [(i, words[i % 4]) for i in range(1000)] duckdb_cursor.executemany("INSERT INTO TEST (id, name) VALUES (?, ?)", lines) diff --git a/tests/fast/api/test_to_csv.py b/tests/fast/api/test_to_csv.py index e48ae1b8..5f8000a9 100644 --- a/tests/fast/api/test_to_csv.py +++ b/tests/fast/api/test_to_csv.py @@ -9,10 +9,10 @@ class TestToCSV(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_basic_to_csv(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -20,21 +20,21 @@ def test_basic_to_csv(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_sep(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, sep=',') + rel.to_csv(temp_file_name, sep=",") - csv_rel = duckdb.read_csv(temp_file_name, sep=',') + csv_rel = duckdb.read_csv(temp_file_name, sep=",") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_na_rep(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, na_rep="test") @@ -42,10 +42,10 @@ def test_to_csv_na_rep(self, pandas): csv_rel = duckdb.read_csv(temp_file_name, na_values="test") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_header(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -53,18 +53,18 @@ def test_to_csv_header(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quotechar(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]}) + df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, quotechar='\'', sep=',') + rel.to_csv(temp_file_name, quotechar="'", sep=",") - csv_rel = duckdb.read_csv(temp_file_name, sep=',', quotechar='\'') + csv_rel = duckdb.read_csv(temp_file_name, sep=",", quotechar="'") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_escapechar(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( @@ -76,11 +76,11 @@ def test_to_csv_escapechar(self, pandas): } ) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, quotechar='"', escapechar='!') - csv_rel = duckdb.read_csv(temp_file_name, quotechar='"', escapechar='!') + rel.to_csv(temp_file_name, quotechar='"', escapechar="!") + csv_rel = duckdb.read_csv(temp_file_name, quotechar='"', escapechar="!") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_date_format(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame(getTimeSeriesData()) @@ -93,82 +93,82 @@ def test_to_csv_date_format(self, pandas): assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_timestamp_format(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) rel = duckdb.from_df(df) - rel.to_csv(temp_file_name, timestamp_format='%m/%d/%Y') + rel.to_csv(temp_file_name, timestamp_format="%m/%d/%Y") - csv_rel = duckdb.read_csv(temp_file_name, timestamp_format='%m/%d/%Y') + csv_rel = duckdb.read_csv(temp_file_name, timestamp_format="%m/%d/%Y") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_off(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=None) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_on(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting="force") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quoting_quote_all(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=csv.QUOTE_ALL) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_encoding_incorrect(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) with pytest.raises( duckdb.InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8" ): rel.to_csv(temp_file_name, encoding="nope") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_encoding_correct(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, encoding="UTF-8") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_compression_gzip(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, compression="gzip") csv_rel = duckdb.read_csv(temp_file_name, compression="gzip") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_partition(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -178,23 +178,23 @@ def test_to_csv_partition(self, pandas): rel = duckdb.from_df(df) rel.to_csv(temp_file_name, header=True, partition_by=["c_category"]) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);""" ) expected = [ - (True, 1.0, 42.0, 'a', 'a'), - (False, 3.2, None, 'b,c', 'a'), - (True, 3.0, 123.0, 'e', 'b'), - (True, 4.0, 321.0, 'f', 'b'), + (True, 1.0, 42.0, "a", "a"), + (False, 3.2, None, "b,c", "a"), + (True, 3.0, 123.0, "e", "b"), + (True, 4.0, 321.0, "f", "b"), ] assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_partition_with_columns_written(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -205,17 +205,17 @@ def test_to_csv_partition_with_columns_written(self, pandas): res = duckdb.sql("FROM rel order by all") rel.to_csv(temp_file_name, header=True, partition_by=["c_category"], write_partition_columns=True) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;""" ) assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -226,24 +226,24 @@ def test_to_csv_overwrite(self, pandas): rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"]) # csv to be overwritten rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"], overwrite=True) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE);""" ) # When partition columns are read from directory names, column order become different from original expected = [ - ('c', True, 1.0, 42.0, 'a', 'a'), - ('c', False, 3.2, None, 'b,c', 'a'), - ('d', True, 3.0, 123.0, 'e', 'b'), - ('d', True, 4.0, 321.0, 'f', 'b'), + ("c", True, 1.0, 42.0, "a", "a"), + ("c", False, 3.2, None, "b,c", "a"), + ("d", True, 3.0, 123.0, "e", "b"), + ("d", True, 4.0, 321.0, "f", "b"), ] assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite_with_columns_written(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -258,18 +258,18 @@ def test_to_csv_overwrite_with_columns_written(self, pandas): temp_file_name, header=True, partition_by=["c_category_1"], overwrite=True, write_partition_columns=True ) csv_rel = duckdb.sql( - f'''FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;''' + f"""FROM read_csv_auto('{temp_file_name}/*/*.csv', hive_partitioning=TRUE, header=TRUE) order by all;""" ) res = duckdb.sql("FROM rel order by all") assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_overwrite_not_enabled(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -281,14 +281,14 @@ def test_to_csv_overwrite_not_enabled(self, pandas): with pytest.raises(duckdb.IOException, match="OVERWRITE"): rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"]) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_per_thread_output(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) num_threads = duckdb.sql("select current_setting('threads')").fetchone()[0] - print('num_threads:', num_threads) + print("num_threads:", num_threads) df = pandas.DataFrame( { - "c_category": ['a', 'a', 'b', 'b'], + "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], @@ -297,16 +297,16 @@ def test_to_csv_per_thread_output(self, pandas): ) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, header=True, per_thread_output=True) - csv_rel = duckdb.read_csv(f'{temp_file_name}/*.csv', header=True) + csv_rel = duckdb.read_csv(f"{temp_file_name}/*.csv", header=True) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_use_tmp_file(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pandas.DataFrame( { - "c_category_1": ['a', 'a', 'b', 'b'], - "c_category_2": ['c', 'c', 'd', 'd'], + "c_category_1": ["a", "a", "b", "b"], + "c_category_2": ["c", "c", "d", "d"], "c_bool": [True, False, True, True], "c_float": [1.0, 3.2, 3.0, 4.0], "c_int": [42, None, 123, 321], diff --git a/tests/fast/api/test_to_parquet.py b/tests/fast/api/test_to_parquet.py index d778aba3..c13ac011 100644 --- a/tests/fast/api/test_to_parquet.py +++ b/tests/fast/api/test_to_parquet.py @@ -13,7 +13,7 @@ class TestToParquet(object): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_basic_to_parquet(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + df = pd.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name) @@ -24,7 +24,7 @@ def test_basic_to_parquet(self, pd): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_compression_gzip(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, compression="gzip") csv_rel = duckdb.read_parquet(temp_file_name, compression="gzip") @@ -32,37 +32,32 @@ def test_compression_gzip(self, pd): def test_field_ids_auto(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - rel = duckdb.sql('''SELECT {i: 128} AS my_struct''') - rel.to_parquet(temp_file_name, field_ids='auto') + rel = duckdb.sql("""SELECT {i: 128} AS my_struct""") + rel.to_parquet(temp_file_name, field_ids="auto") parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() def test_field_ids(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - rel = duckdb.sql('''SELECT 1 as i, {j: 128} AS my_struct''') - rel.to_parquet(temp_file_name, field_ids=dict(i=42, my_struct={'__duckdb_field_id': 43, 'j': 44})) + rel = duckdb.sql("""SELECT 1 as i, {j: 128} AS my_struct""") + rel.to_parquet(temp_file_name, field_ids=dict(i=42, my_struct={"__duckdb_field_id": 43, "j": 44})) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() - assert ( - [('duckdb_schema', None), ('i', 42), ('my_struct', 43), ('j', 44)] - == duckdb.sql( - f''' + assert [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] == duckdb.sql( + f""" select name,field_id from parquet_schema('{temp_file_name}') - ''' - ) - .execute() - .fetchall() - ) + """ + ).execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('row_group_size_bytes', [122880 * 1024, '2MB']) + @pytest.mark.parametrize("row_group_size_bytes", [122880 * 1024, "2MB"]) def test_row_group_size_bytes(self, pd, row_group_size_bytes): con = duckdb.connect() con.execute("SET preserve_insertion_order=false;") temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = con.from_df(df) rel.to_parquet(temp_file_name, row_group_size_bytes=row_group_size_bytes) parquet_rel = con.read_parquet(temp_file_name) @@ -71,21 +66,21 @@ def test_row_group_size_bytes(self, pd, row_group_size_bytes): @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_row_group_size(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) - df = pd.DataFrame({'a': ['string1', 'string2', 'string3']}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, row_group_size=122880) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('write_columns', [None, True, False]) + @pytest.mark.parametrize("write_columns", [None, True, False]) def test_partition(self, pd, write_columns): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -95,14 +90,14 @@ def test_partition(self, pd, write_columns): assert result.execute().fetchall() == expected @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('write_columns', [None, True, False]) + @pytest.mark.parametrize("write_columns", [None, True, False]) def test_overwrite(self, pd, write_columns): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -120,7 +115,7 @@ def test_use_tmp_file(self, pd): { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) @@ -133,17 +128,17 @@ def test_use_tmp_file(self, pd): def test_per_thread_output(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) num_threads = duckdb.sql("select current_setting('threads')").fetchone()[0] - print('threads:', num_threads) + print("threads:", num_threads) df = pd.DataFrame( { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) rel.to_parquet(temp_file_name, per_thread_output=True) - result = duckdb.read_parquet(f'{temp_file_name}/*.parquet') + result = duckdb.read_parquet(f"{temp_file_name}/*.parquet") assert rel.execute().fetchall() == result.execute().fetchall() @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @@ -153,27 +148,27 @@ def test_append(self, pd): { "name": ["rei", "shinji", "asuka", "kaworu"], "float": [321.0, 123.0, 23.0, 340.0], - "category": ['a', 'a', 'b', 'c'], + "category": ["a", "a", "b", "c"], } ) rel = duckdb.from_df(df) - rel.to_parquet(temp_file_name, partition_by=['category']) + rel.to_parquet(temp_file_name, partition_by=["category"]) df_to_append = pd.DataFrame( { "name": ["random"], "float": [420], - "category": ['a'], + "category": ["a"], } ) rel_to_append = duckdb.from_df(df_to_append) - rel_to_append.to_parquet(temp_file_name, partition_by=['category'], append=True) + rel_to_append.to_parquet(temp_file_name, partition_by=["category"], append=True) result = duckdb.sql(f"FROM read_parquet('{temp_file_name}/*/*.parquet', hive_partitioning=TRUE) ORDER BY name") result.show() expected = [ - ('asuka', 23.0, 'b'), - ('kaworu', 340.0, 'c'), - ('random', 420.0, 'a'), - ('rei', 321.0, 'a'), - ('shinji', 123.0, 'a'), + ("asuka", 23.0, "b"), + ("kaworu", 340.0, "c"), + ("random", 420.0, "a"), + ("rei", 321.0, "a"), + ("shinji", 123.0, "a"), ] assert result.execute().fetchall() == expected diff --git a/tests/fast/api/test_with_propagating_exceptions.py b/tests/fast/api/test_with_propagating_exceptions.py index e9cfb3c0..8613d6f4 100644 --- a/tests/fast/api/test_with_propagating_exceptions.py +++ b/tests/fast/api/test_with_propagating_exceptions.py @@ -7,12 +7,12 @@ def test_with(self): # Should propagate exception raised in the 'with duckdb.connect() ..' with pytest.raises(duckdb.ParserException, match="syntax error at or near *"): with duckdb.connect() as con: - print('before') - con.execute('invalid') - print('after') + print("before") + con.execute("invalid") + print("after") # Does not raise an exception with duckdb.connect() as con: - print('before') - con.execute('select 1') - print('after') + print("before") + con.execute("select 1") + print("after") diff --git a/tests/fast/arrow/parquet_write_roundtrip.py b/tests/fast/arrow/parquet_write_roundtrip.py index 093040c0..5dbf3949 100644 --- a/tests/fast/arrow/parquet_write_roundtrip.py +++ b/tests/fast/arrow/parquet_write_roundtrip.py @@ -17,13 +17,13 @@ def parquet_types_test(type_list): sql_type = type_pair[2] add_cast = len(type_pair) > 3 and type_pair[3] add_sql_cast = len(type_pair) > 4 and type_pair[4] - df = pandas.DataFrame.from_dict({'val': numpy.array(value_list, dtype=numpy_type)}) + df = pandas.DataFrame.from_dict({"val": numpy.array(value_list, dtype=numpy_type)}) duckdb_cursor = duckdb.connect() duckdb_cursor.execute(f"CREATE TABLE tmp AS SELECT val::{sql_type} val FROM df") duckdb_cursor.execute(f"COPY tmp TO '{temp_name}' (FORMAT PARQUET)") read_df = pandas.read_parquet(temp_name) if add_cast: - read_df['val'] = read_df['val'].astype(numpy_type) + read_df["val"] = read_df["val"].astype(numpy_type) assert df.equals(read_df) read_from_duckdb = duckdb_cursor.execute(f"SELECT * FROM parquet_scan('{temp_name}')").df() @@ -40,16 +40,16 @@ def parquet_types_test(type_list): class TestParquetRoundtrip(object): def test_roundtrip_numeric(self, duckdb_cursor): type_list = [ - ([-(2**7), 0, 2**7 - 1], numpy.int8, 'TINYINT'), - ([-(2**15), 0, 2**15 - 1], numpy.int16, 'SMALLINT'), - ([-(2**31), 0, 2**31 - 1], numpy.int32, 'INTEGER'), - ([-(2**63), 0, 2**63 - 1], numpy.int64, 'BIGINT'), - ([0, 42, 2**8 - 1], numpy.uint8, 'UTINYINT'), - ([0, 42, 2**16 - 1], numpy.uint16, 'USMALLINT'), - ([0, 42, 2**32 - 1], numpy.uint32, 'UINTEGER', False, True), - ([0, 42, 2**64 - 1], numpy.uint64, 'UBIGINT'), - ([0, 0.5, -0.5], numpy.float32, 'REAL'), - ([0, 0.5, -0.5], numpy.float64, 'DOUBLE'), + ([-(2**7), 0, 2**7 - 1], numpy.int8, "TINYINT"), + ([-(2**15), 0, 2**15 - 1], numpy.int16, "SMALLINT"), + ([-(2**31), 0, 2**31 - 1], numpy.int32, "INTEGER"), + ([-(2**63), 0, 2**63 - 1], numpy.int64, "BIGINT"), + ([0, 42, 2**8 - 1], numpy.uint8, "UTINYINT"), + ([0, 42, 2**16 - 1], numpy.uint16, "USMALLINT"), + ([0, 42, 2**32 - 1], numpy.uint32, "UINTEGER", False, True), + ([0, 42, 2**64 - 1], numpy.uint64, "UBIGINT"), + ([0, 0.5, -0.5], numpy.float32, "REAL"), + ([0, 0.5, -0.5], numpy.float64, "DOUBLE"), ] parquet_types_test(type_list) @@ -61,15 +61,15 @@ def test_roundtrip_timestamp(self, duckdb_cursor): datetime.datetime(1992, 7, 9, 7, 5, 33), ] type_list = [ - (date_time_list, 'datetime64[ns]', 'TIMESTAMP_NS'), - (date_time_list, 'datetime64[us]', 'TIMESTAMP'), - (date_time_list, 'datetime64[ms]', 'TIMESTAMP_MS'), - (date_time_list, 'datetime64[s]', 'TIMESTAMP_S'), - (date_time_list, 'datetime64[D]', 'DATE', True), + (date_time_list, "datetime64[ns]", "TIMESTAMP_NS"), + (date_time_list, "datetime64[us]", "TIMESTAMP"), + (date_time_list, "datetime64[ms]", "TIMESTAMP_MS"), + (date_time_list, "datetime64[s]", "TIMESTAMP_S"), + (date_time_list, "datetime64[D]", "DATE", True), ] parquet_types_test(type_list) def test_roundtrip_varchar(self, duckdb_cursor): - varchar_list = ['hello', 'this is a very long string', 'hello', None] - type_list = [(varchar_list, object, 'VARCHAR')] + varchar_list = ["hello", "this is a very long string", "hello", None] + type_list = [(varchar_list, object, "VARCHAR")] parquet_types_test(type_list) diff --git a/tests/fast/arrow/test_10795.py b/tests/fast/arrow/test_10795.py index 043bf4ff..5503e529 100644 --- a/tests/fast/arrow/test_10795.py +++ b/tests/fast/arrow/test_10795.py @@ -1,12 +1,12 @@ import duckdb import pytest -pyarrow = pytest.importorskip('pyarrow') +pyarrow = pytest.importorskip("pyarrow") -@pytest.mark.parametrize('arrow_large_buffer_size', [True, False]) +@pytest.mark.parametrize("arrow_large_buffer_size", [True, False]) def test_10795(arrow_large_buffer_size): conn = duckdb.connect() conn.sql(f"set arrow_large_buffer_size={arrow_large_buffer_size}") arrow = conn.sql("select map(['non-inlined string', 'test', 'duckdb'], [42, 1337, 123]) as map").to_arrow_table() - assert arrow.to_pydict() == {'map': [[('non-inlined string', 42), ('test', 1337), ('duckdb', 123)]]} + assert arrow.to_pydict() == {"map": [[("non-inlined string", 42), ("test", 1337), ("duckdb", 123)]]} diff --git a/tests/fast/arrow/test_12384.py b/tests/fast/arrow/test_12384.py index af9c8ed2..d2d4a7fc 100644 --- a/tests/fast/arrow/test_12384.py +++ b/tests/fast/arrow/test_12384.py @@ -2,17 +2,17 @@ import pytest import os -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") def test_10795(): - arrow_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'arrow_table') - with pa.memory_map(arrow_filename, 'r') as source: + arrow_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "arrow_table") + with pa.memory_map(arrow_filename, "r") as source: reader = pa.ipc.RecordBatchFileReader(source) taxi_fhvhv_arrow = reader.read_all() - con = duckdb.connect(database=':memory:') + con = duckdb.connect(database=":memory:") con.execute("SET TimeZone='UTC';") - con.register('taxi_fhvhv', taxi_fhvhv_arrow) + con.register("taxi_fhvhv", taxi_fhvhv_arrow) res = con.execute( "SELECT PULocationID, pickup_datetime FROM taxi_fhvhv WHERE pickup_datetime >= '2023-01-01T00:00:00-05:00' AND PULocationID = 244" ).fetchall() diff --git a/tests/fast/arrow/test_14344.py b/tests/fast/arrow/test_14344.py index 522228c0..86f8728b 100644 --- a/tests/fast/arrow/test_14344.py +++ b/tests/fast/arrow/test_14344.py @@ -22,4 +22,4 @@ def test_14344(duckdb_cursor): USING (foo) """ ).fetchall() - assert res == [('123',)] + assert res == [("123",)] diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index cdef8da7..6d760500 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -22,15 +22,15 @@ def test_2426(self, duckdb_cursor): con.execute("Insert Into test values ('" + str(i) + "')") con.execute("Insert Into test values ('5000')") con.execute("Insert Into test values ('6000')") - sql = ''' + sql = """ SELECT a, COUNT(*) AS repetitions FROM test GROUP BY a - ''' + """ result_df = con.execute(sql).df() arrow_table = con.execute(sql).fetch_arrow_table() arrow_df = arrow_table.to_pandas() - assert result_df['repetitions'].sum() == arrow_df['repetitions'].sum() + assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() diff --git a/tests/fast/arrow/test_5547.py b/tests/fast/arrow/test_5547.py index b27b29b2..eb77ab83 100644 --- a/tests/fast/arrow/test_5547.py +++ b/tests/fast/arrow/test_5547.py @@ -3,7 +3,7 @@ from pandas.testing import assert_frame_equal import pytest -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") def test_5547(): diff --git a/tests/fast/arrow/test_6584.py b/tests/fast/arrow/test_6584.py index 9a6241f9..6f96bf2d 100644 --- a/tests/fast/arrow/test_6584.py +++ b/tests/fast/arrow/test_6584.py @@ -2,7 +2,7 @@ import duckdb import pytest -pyarrow = pytest.importorskip('pyarrow') +pyarrow = pytest.importorskip("pyarrow") def f(cur, i, data): diff --git a/tests/fast/arrow/test_6796.py b/tests/fast/arrow/test_6796.py index 6690f22c..ef464f49 100644 --- a/tests/fast/arrow/test_6796.py +++ b/tests/fast/arrow/test_6796.py @@ -2,10 +2,10 @@ import pytest from conftest import NumpyPandas, ArrowPandas -pyarrow = pytest.importorskip('pyarrow') +pyarrow = pytest.importorskip("pyarrow") -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_6796(pandas): conn = duckdb.connect() input_df = pandas.DataFrame({"foo": ["bar"]}) diff --git a/tests/fast/arrow/test_7652.py b/tests/fast/arrow/test_7652.py index afe3b738..857d871d 100644 --- a/tests/fast/arrow/test_7652.py +++ b/tests/fast/arrow/test_7652.py @@ -9,7 +9,7 @@ class Test7652(object): def test_7652(self, duckdb_cursor): - temp_file_name = tempfile.NamedTemporaryFile(suffix='.parquet').name + temp_file_name = tempfile.NamedTemporaryFile(suffix=".parquet").name # Generate a list of values that aren't uniform in changes. generated_list = [1, 0, 2] @@ -17,7 +17,7 @@ def test_7652(self, duckdb_cursor): print(f"Min value: {min(generated_list)} max value: {max(generated_list)}") # Convert list of values to a PyArrow table with a single column. - fake_table = pa.Table.from_arrays([pa.array(generated_list, pa.int64())], names=['n0']) + fake_table = pa.Table.from_arrays([pa.array(generated_list, pa.int64())], names=["n0"]) # Write that column with DELTA_BINARY_PACKED encoding with pq.ParquetWriter( diff --git a/tests/fast/arrow/test_7699.py b/tests/fast/arrow/test_7699.py index c8c234ef..a4de66b9 100644 --- a/tests/fast/arrow/test_7699.py +++ b/tests/fast/arrow/test_7699.py @@ -22,4 +22,4 @@ def test_7699(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from df1234") res = rel.fetchall() - assert res == [('K',), ('L',), ('K',), ('L',), ('M',)] + assert res == [("K",), ("L",), ("K",), ("L",), ("M",)] diff --git a/tests/fast/arrow/test_arrow_batch_index.py b/tests/fast/arrow/test_arrow_batch_index.py index dadf6f89..a8dc2c7f 100644 --- a/tests/fast/arrow/test_arrow_batch_index.py +++ b/tests/fast/arrow/test_arrow_batch_index.py @@ -9,13 +9,13 @@ class TestArrowBatchIndex(object): def test_arrow_batch_index(self, duckdb_cursor): con = duckdb.connect() - df = con.execute('SELECT * FROM range(10000000) t(i)').df() + df = con.execute("SELECT * FROM range(10000000) t(i)").df() arrow_tbl = pa.Table.from_pandas(df) - con.execute('CREATE TABLE tbl AS SELECT * FROM arrow_tbl') + con.execute("CREATE TABLE tbl AS SELECT * FROM arrow_tbl") - result = con.execute('SELECT * FROM tbl LIMIT 5').fetchall() + result = con.execute("SELECT * FROM tbl LIMIT 5").fetchall() assert [x[0] for x in result] == [0, 1, 2, 3, 4] - result = con.execute('SELECT * FROM tbl LIMIT 5 OFFSET 777778').fetchall() + result = con.execute("SELECT * FROM tbl LIMIT 5 OFFSET 777778").fetchall() assert [x[0] for x in result] == [777778, 777779, 777780, 777781, 777782] diff --git a/tests/fast/arrow/test_arrow_binary_view.py b/tests/fast/arrow/test_arrow_binary_view.py index 7d9d0afc..31107f67 100644 --- a/tests/fast/arrow/test_arrow_binary_view.py +++ b/tests/fast/arrow/test_arrow_binary_view.py @@ -8,7 +8,7 @@ class TestArrowBinaryView(object): def test_arrow_binary_view(self, duckdb_cursor): con = duckdb.connect() tab = pa.table({"x": pa.array([b"abc", b"thisisaverybigbinaryyaymorethanfifteen", None], pa.binary_view())}) - assert con.execute("FROM tab").fetchall() == [(b'abc',), (b'thisisaverybigbinaryyaymorethanfifteen',), (None,)] + assert con.execute("FROM tab").fetchall() == [(b"abc",), (b"thisisaverybigbinaryyaymorethanfifteen",), (None,)] # By default we won't export a view assert not con.execute("FROM tab").fetch_arrow_table().equals(tab) # We do the binary view from 1.4 onwards @@ -16,5 +16,5 @@ def test_arrow_binary_view(self, duckdb_cursor): assert con.execute("FROM tab").fetch_arrow_table().equals(tab) assert con.execute("FROM tab where x = 'thisisaverybigbinaryyaymorethanfifteen'").fetchall() == [ - (b'thisisaverybigbinaryyaymorethanfifteen',) + (b"thisisaverybigbinaryyaymorethanfifteen",) ] diff --git a/tests/fast/arrow/test_arrow_case_sensitive.py b/tests/fast/arrow/test_arrow_case_sensitive.py index 6106cc75..ef60046a 100644 --- a/tests/fast/arrow/test_arrow_case_sensitive.py +++ b/tests/fast/arrow/test_arrow_case_sensitive.py @@ -7,18 +7,18 @@ class TestArrowCaseSensitive(object): def test_arrow_case_sensitive(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['A1', 'a1']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["A1", "a1"]) - duckdb_cursor.register('arrow_tbl', arrow_table) - assert duckdb_cursor.table("arrow_tbl").columns == ['A1', 'a1_1'] + duckdb_cursor.register("arrow_tbl", arrow_table) + assert duckdb_cursor.table("arrow_tbl").columns == ["A1", "a1_1"] assert duckdb_cursor.execute("select A1 from arrow_tbl;").fetchall() == [(1,)] assert duckdb_cursor.execute("select a1_1 from arrow_tbl;").fetchall() == [(1000,)] - assert arrow_table.column_names == ['A1', 'a1'] + assert arrow_table.column_names == ["A1", "a1"] def test_arrow_case_sensitive_repeated(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[1]], ['A1', 'a1_1', 'a1']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[1]], ["A1", "a1_1", "a1"]) - duckdb_cursor.register('arrow_tbl', arrow_table) - assert duckdb_cursor.table("arrow_tbl").columns == ['A1', 'a1_1', 'a1_2'] - assert arrow_table.column_names == ['A1', 'a1_1', 'a1'] + duckdb_cursor.register("arrow_tbl", arrow_table) + assert duckdb_cursor.table("arrow_tbl").columns == ["A1", "a1_1", "a1_2"] + assert arrow_table.column_names == ["A1", "a1_1", "a1"] diff --git a/tests/fast/arrow/test_arrow_decimal_32_64.py b/tests/fast/arrow/test_arrow_decimal_32_64.py index 4a960454..39b6e43a 100644 --- a/tests/fast/arrow/test_arrow_decimal_32_64.py +++ b/tests/fast/arrow/test_arrow_decimal_32_64.py @@ -8,7 +8,7 @@ class TestArrowDecimalTypes(object): def test_decimal_32(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute('SET arrow_output_version = 1.5') + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_32 = pa.Table.from_pylist( [ {"data": Decimal("100.20")}, @@ -20,10 +20,10 @@ def test_decimal_32(self, duckdb_cursor): ) # Test scan assert duckdb_cursor.execute("FROM decimal_32").fetchall() == [ - (Decimal('100.20'),), - (Decimal('110.21'),), - (Decimal('31.20'),), - (Decimal('500.20'),), + (Decimal("100.20"),), + (Decimal("110.21"),), + (Decimal("31.20"),), + (Decimal("500.20"),), ] # Test filter pushdown assert duckdb_cursor.execute("SELECT COUNT(*) FROM decimal_32 where data > 100 and data < 200 ").fetchall() == [ @@ -37,7 +37,7 @@ def test_decimal_32(self, duckdb_cursor): def test_decimal_64(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute('SET arrow_output_version = 1.5') + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_64 = pa.Table.from_pylist( [ {"data": Decimal("1000.231")}, @@ -50,10 +50,10 @@ def test_decimal_64(self, duckdb_cursor): # Test scan assert duckdb_cursor.execute("FROM decimal_64").fetchall() == [ - (Decimal('1000.231'),), - (Decimal('1100.231'),), - (Decimal('999999999999.231'),), - (Decimal('500.200'),), + (Decimal("1000.231"),), + (Decimal("1100.231"),), + (Decimal("999999999999.231"),), + (Decimal("500.200"),), ] # Test Filter pushdown diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index 95a2108a..43c995bb 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -5,11 +5,10 @@ from uuid import UUID import datetime -pa = pytest.importorskip('pyarrow', '18.0.0') +pa = pytest.importorskip("pyarrow", "18.0.0") class TestCanonicalExtensionTypes(object): - def test_uuid(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") @@ -17,9 +16,9 @@ def test_uuid(self): storage_array = pa.array([uuid.uuid4().bytes for _ in range(4)], pa.binary(16)) storage_array = pa.uuid().wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['uuid_col']) + arrow_table = pa.Table.from_arrays([storage_array], names=["uuid_col"]) - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() assert duck_arrow.equals(arrow_table) @@ -30,14 +29,14 @@ def test_uuid_from_duck(self): arrow_table = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() assert arrow_table.to_pylist() == [ - {'uuid': UUID('00000000-0000-0000-0000-000000000000')}, - {'uuid': UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')}, - {'uuid': None}, + {"uuid": UUID("00000000-0000-0000-0000-000000000000")}, + {"uuid": UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")}, + {"uuid": None}, ] assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -45,8 +44,8 @@ def test_uuid_from_duck(self): "select '00000000-0000-0000-0000-000000000100'::UUID as uuid" ).fetch_arrow_table() - assert arrow_table.to_pylist() == [{'uuid': UUID('00000000-0000-0000-0000-000000000100')}] - assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(UUID('00000000-0000-0000-0000-000000000100'),)] + assert arrow_table.to_pylist() == [{"uuid": UUID("00000000-0000-0000-0000-000000000100")}] + assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(UUID("00000000-0000-0000-0000-000000000100"),)] def test_json(self, duckdb_cursor): data = {"name": "Pedro", "age": 28, "car": "VW Fox"} @@ -56,10 +55,10 @@ def test_json(self, duckdb_cursor): storage_array = pa.array([json_string], pa.string()) - arrow_table = pa.Table.from_arrays([storage_array], names=['json_col']) + arrow_table = pa.Table.from_arrays([storage_array], names=["json_col"]) duckdb_cursor.execute("SET arrow_lossless_conversion = true") - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() assert duck_arrow.equals(arrow_table) @@ -70,8 +69,8 @@ def test_uuid_no_def(self): res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() res_duck = duckdb_cursor.execute("from res_arrow").fetchall() assert res_duck == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -79,15 +78,15 @@ def test_uuid_no_def_lossless(self): duckdb_cursor = duckdb.connect() res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() assert res_arrow.to_pylist() == [ - {'uuid': '00000000-0000-0000-0000-000000000000'}, - {'uuid': 'ffffffff-ffff-ffff-ffff-ffffffffffff'}, - {'uuid': None}, + {"uuid": "00000000-0000-0000-0000-000000000000"}, + {"uuid": "ffffffff-ffff-ffff-ffff-ffffffffffff"}, + {"uuid": None}, ] res_duck = duckdb_cursor.execute("from res_arrow").fetchall() assert res_duck == [ - ('00000000-0000-0000-0000-000000000000',), - ('ffffffff-ffff-ffff-ffff-ffffffffffff',), + ("00000000-0000-0000-0000-000000000000",), + ("ffffffff-ffff-ffff-ffff-ffffffffffff",), (None,), ] @@ -98,8 +97,8 @@ def test_uuid_no_def_stream(self): res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_record_batch() res_duck = duckdb.execute("from res_arrow").fetchall() assert res_duck == [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ] @@ -109,9 +108,9 @@ def test_function(x): return x con = duckdb.connect() - con.create_function('test', test_function, ['UUID'], 'UUID', type='arrow') + con.create_function("test", test_function, ["UUID"], "UUID", type="arrow") - rel = con.sql("select ? as x", params=[uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')]) + rel = con.sql("select ? as x", params=[uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")]) rel.project("test(x) from t").fetchall() def test_unimplemented_extension(self, duckdb_cursor): @@ -120,51 +119,51 @@ def __init__(self) -> None: pa.ExtensionType.__init__(self, pa.binary(5), "pedro.binary") def __arrow_ext_serialize__(self) -> bytes: - return b'' + return b"" @classmethod def __arrow_ext_deserialize__(cls, storage_type, serialized): return UuidTypeWrong() - storage_array = pa.array(['pedro'], pa.binary(5)) + storage_array = pa.array(["pedro"], pa.binary(5)) my_type = MyType() storage_array = my_type.wrap_array(storage_array) age_array = pa.array([29], pa.int32()) - arrow_table = pa.Table.from_arrays([storage_array, age_array], names=['pedro_pedro_pedro', 'age']) + arrow_table = pa.Table.from_arrays([storage_array, age_array], names=["pedro_pedro_pedro", "age"]) - duck_arrow = duckdb_cursor.execute('FROM arrow_table').fetch_arrow_table() - assert duckdb_cursor.execute('FROM duck_arrow').fetchall() == [(b'pedro', 29)] + duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() + assert duckdb_cursor.execute("FROM duck_arrow").fetchall() == [(b"pedro", 29)] def test_hugeint(self): con = duckdb.connect() con.execute("SET arrow_lossless_conversion = true") - storage_array = pa.array([b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff'], pa.binary(16)) + storage_array = pa.array([b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"], pa.binary(16)) hugeint_type = pa.opaque(pa.binary(16), "hugeint", "DuckDB") storage_array = hugeint_type.wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['numbers']) + arrow_table = pa.Table.from_arrays([storage_array], names=["numbers"]) - assert con.execute('FROM arrow_table').fetchall() == [(-1,)] + assert con.execute("FROM arrow_table").fetchall() == [(-1,)] - assert con.execute('FROM arrow_table').fetch_arrow_table().equals(arrow_table) + assert con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) con.execute("SET arrow_lossless_conversion = false") - assert not con.execute('FROM arrow_table').fetch_arrow_table().equals(arrow_table) + assert not con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) def test_uhugeint(self, duckdb_cursor): - storage_array = pa.array([b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff'], pa.binary(16)) + storage_array = pa.array([b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"], pa.binary(16)) uhugeint_type = pa.opaque(pa.binary(16), "uhugeint", "DuckDB") storage_array = uhugeint_type.wrap_array(storage_array) - arrow_table = pa.Table.from_arrays([storage_array], names=['numbers']) + arrow_table = pa.Table.from_arrays([storage_array], names=["numbers"]) - assert duckdb_cursor.execute('FROM arrow_table').fetchall() == [(340282366920938463463374607431768211455,)] + assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(340282366920938463463374607431768211455,)] def test_bit(self): con = duckdb.connect() @@ -176,18 +175,18 @@ def test_bit(self): res_bit = con.execute("SELECT '0101011'::BIT str FROM range(5) tbl(i)").fetch_arrow_table() assert con.execute("FROM res_blob").fetchall() == [ - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), - (b'\x01\xab',), + (b"\x01\xab",), + (b"\x01\xab",), + (b"\x01\xab",), + (b"\x01\xab",), + (b"\x01\xab",), ] assert con.execute("FROM res_bit").fetchall() == [ - ('0101011',), - ('0101011',), - ('0101011',), - ('0101011',), - ('0101011',), + ("0101011",), + ("0101011",), + ("0101011",), + ("0101011",), + ("0101011",), ] def test_timetz(self): @@ -209,12 +208,12 @@ def test_bignum(self): res_bignum = con.execute( "SELECT '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368'::bignum a FROM range(1) tbl(i)" ).fetch_arrow_table() - assert res_bignum.column("a").type.type_name == 'bignum' + assert res_bignum.column("a").type.type_name == "bignum" assert res_bignum.column("a").type.vendor_name == "DuckDB" assert con.execute("FROM res_bignum").fetchall() == [ ( - '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368', + "179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368", ) ] @@ -235,9 +234,9 @@ def test_extension_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array( [ - b'\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', - b'\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', - b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff', + b"\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", + b"\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", + b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff", ], pa.binary(16), ) @@ -245,7 +244,7 @@ def test_extension_dictionary(self, duckdb_cursor): dictionary = uhugeint_type.wrap_array(dictionary) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [ (340282366920938463463374607431768211200,), @@ -263,13 +262,13 @@ def test_boolean(self): con.execute("SET arrow_lossless_conversion = true") storage_array = pa.array([-1, 0, 1, 2, None], pa.int8()) bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), storage_array) - arrow_table = pa.Table.from_arrays([bool8_array], names=['bool8']) - assert con.execute('FROM arrow_table').fetchall() == [(True,), (False,), (True,), (True,), (None,)] - result_table = con.execute('FROM arrow_table').fetch_arrow_table() + arrow_table = pa.Table.from_arrays([bool8_array], names=["bool8"]) + assert con.execute("FROM arrow_table").fetchall() == [(True,), (False,), (True,), (True,), (None,)] + result_table = con.execute("FROM arrow_table").fetch_arrow_table() res_storage_array = pa.array([1, 0, 1, 1, None], pa.int8()) res_bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), res_storage_array) - res_arrow_table = pa.Table.from_arrays([res_bool8_array], names=['bool8']) + res_arrow_table = pa.Table.from_arrays([res_bool8_array], names=["bool8"]) assert result_table.equals(res_arrow_table) @@ -279,7 +278,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): pa.binary(), metadata={ "ARROW:extension:name": "foofyfoo", - "ARROW:extension:metadata": 'this is not valid json', + "ARROW:extension:metadata": "this is not valid json", }, ) schema = pa.schema([field]) @@ -296,7 +295,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): pa.binary(), metadata={ "ARROW:extension:name": "arrow.opaque", - "ARROW:extension:metadata": 'this is not valid json', + "ARROW:extension:metadata": "this is not valid json", }, ) schema = pa.schema([field]) @@ -337,9 +336,9 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): schema=schema, ) assert duckdb_cursor.sql("""DESCRIBE FROM bignum_table;""").fetchone() == ( - 'bignum_value', - 'BIGNUM', - 'YES', + "bignum_value", + "BIGNUM", + "YES", None, None, None, diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index 04a34595..a969da21 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -83,8 +83,8 @@ def test_to_arrow_chunk_size(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") - relation = duckdb_cursor.table('t') + relation = duckdb_cursor.table("t") arrow_tbl = relation.fetch_arrow_table() - assert arrow_tbl['a'].num_chunks == 1 + assert arrow_tbl["a"].num_chunks == 1 arrow_tbl = relation.fetch_arrow_table(2048) - assert arrow_tbl['a'].num_chunks == 2 + assert arrow_tbl["a"].num_chunks == 2 diff --git a/tests/fast/arrow/test_arrow_fetch_recordbatch.py b/tests/fast/arrow/test_arrow_fetch_recordbatch.py index 24d7c2c7..8915d886 100644 --- a/tests/fast/arrow/test_arrow_fetch_recordbatch.py +++ b/tests/fast/arrow/test_arrow_fetch_recordbatch.py @@ -1,7 +1,7 @@ import duckdb import pytest -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") class TestArrowFetchRecordBatch(object): @@ -12,7 +12,7 @@ def test_record_batch_next_batch_numeric(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -38,7 +38,7 @@ def test_record_batch_next_batch_bool(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -63,7 +63,7 @@ def test_record_batch_next_batch_varchar(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select range::varchar a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -90,7 +90,7 @@ def test_record_batch_next_batch_struct(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -115,7 +115,7 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select [i,i+1] as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -141,7 +141,7 @@ def test_record_batch_next_batch_map(self, duckdb_cursor): duckdb_cursor.execute("CREATE table t as select map([i], [i+1]) as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -169,7 +169,7 @@ def test_record_batch_next_batch_with_null(self, duckdb_cursor): ) query = duckdb_cursor.execute("SELECT a FROM t") record_batch_reader = query.fetch_record_batch(1024) - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 chunk = record_batch_reader.read_next_batch() @@ -224,15 +224,15 @@ def test_record_batch_next_batch_multiple_vectors_per_chunk_error(self, duckdb_c duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(5000);") query = duckdb_cursor.execute("SELECT a FROM t") - with pytest.raises(RuntimeError, match='Approximate Batch Size of Record Batch MUST be higher than 0'): + with pytest.raises(RuntimeError, match="Approximate Batch Size of Record Batch MUST be higher than 0"): record_batch_reader = query.fetch_record_batch(0) - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): record_batch_reader = query.fetch_record_batch(-1) def test_record_batch_reader_from_relation(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") - relation = duckdb_cursor.table('t') + relation = duckdb_cursor.table("t") record_batch_reader = relation.record_batch() chunk = record_batch_reader.read_next_batch() assert len(chunk) == 3000 @@ -249,7 +249,7 @@ def test_record_coverage(self, duckdb_cursor): def test_record_batch_query_error(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select 'foo' as a;") - with pytest.raises(duckdb.ConversionException, match='Conversion Error'): + with pytest.raises(duckdb.ConversionException, match="Conversion Error"): # 'execute' materializes the result, causing the error directly query = duckdb_cursor.execute("SELECT cast(a as double) FROM t") record_batch_reader = query.fetch_record_batch(1024) @@ -282,7 +282,7 @@ def test_many_chunk_sizes(self): record_batch_reader = query.fetch_record_batch(i) num_loops = int(object_size / i) for j in range(num_loops): - assert record_batch_reader.schema.names == ['a'] + assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == i remainder = object_size % i diff --git a/tests/fast/arrow/test_arrow_fixed_binary.py b/tests/fast/arrow/test_arrow_fixed_binary.py index aa0047a8..cec8d520 100644 --- a/tests/fast/arrow/test_arrow_fixed_binary.py +++ b/tests/fast/arrow/test_arrow_fixed_binary.py @@ -7,8 +7,8 @@ class TestArrowFixedBinary(object): def test_arrow_fixed_binary(self, duckdb_cursor): ids = [ None, - b'\x66\x4d\xf4\xae\xb1\x5c\xb0\x4a\xdd\x5d\x1d\x54', - b'\x66\x4d\xf4\xf0\xa3\xfc\xec\x5b\x26\x81\x4e\x1d', + b"\x66\x4d\xf4\xae\xb1\x5c\xb0\x4a\xdd\x5d\x1d\x54", + b"\x66\x4d\xf4\xf0\xa3\xfc\xec\x5b\x26\x81\x4e\x1d", ] id_array = pa.array(ids, type=pa.binary(12)) @@ -18,4 +18,4 @@ def test_arrow_fixed_binary(self, duckdb_cursor): SELECT lower(hex(id)) as id FROM arrow_table """ ).fetchall() - assert res == [(None,), ('664df4aeb15cb04add5d1d54',), ('664df4f0a3fcec5b26814e1d',)] + assert res == [(None,), ("664df4aeb15cb04add5d1d54",), ("664df4f0a3fcec5b26814e1d",)] diff --git a/tests/fast/arrow/test_arrow_ipc.py b/tests/fast/arrow/test_arrow_ipc.py index 1d71eaa4..24718bbc 100644 --- a/tests/fast/arrow/test_arrow_ipc.py +++ b/tests/fast/arrow/test_arrow_ipc.py @@ -1,14 +1,14 @@ import pytest import duckdb -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") -ipc = pytest.importorskip('pyarrow.ipc') +ipc = pytest.importorskip("pyarrow.ipc") def get_record_batch(): - data = [pa.array([1, 2, 3, 4]), pa.array(['foo', 'bar', 'baz', None]), pa.array([True, None, False, True])] - return pa.record_batch(data, names=['f0', 'f1', 'f2']) + data = [pa.array([1, 2, 3, 4]), pa.array(["foo", "bar", "baz", None]), pa.array([True, None, False, True])] + return pa.record_batch(data, names=["f0", "f1", "f2"]) class TestArrowIPCExtension(object): diff --git a/tests/fast/arrow/test_arrow_list.py b/tests/fast/arrow/test_arrow_list.py index 556f614a..47b8cb2a 100644 --- a/tests/fast/arrow/test_arrow_list.py +++ b/tests/fast/arrow/test_arrow_list.py @@ -21,7 +21,7 @@ def create_and_register_arrow_table(column_list, duckdb_cursor): def create_and_register_comparison_result(column_list, duckdb_cursor): - columns = ",".join([f'{name} {dtype}' for (name, dtype, _) in column_list]) + columns = ",".join([f"{name} {dtype}" for (name, dtype, _) in column_list]) column_amount = len(column_list) assert column_amount row_amount = len(column_list[0][2]) @@ -31,7 +31,7 @@ def create_and_register_comparison_result(column_list, duckdb_cursor): inserted_values.append(column_list[col][2][row]) inserted_values = tuple(inserted_values) - column_format = ",".join(['?' for _ in range(column_amount)]) + column_format = ",".join(["?" for _ in range(column_amount)]) row_format = ",".join([f"({column_format})" for _ in range(row_amount)]) query = f"""CREATE TABLE test ({columns}); INSERT INTO test VALUES {row_format}; @@ -73,7 +73,7 @@ def generate_list(child_size) -> ListGenerationResult: # Create a regular ListArray list_arr = pa.ListArray.from_arrays(offsets=offsets, values=input, mask=pa.array(mask, type=pa.bool_())) - if not hasattr(pa, 'ListViewArray'): + if not hasattr(pa, "ListViewArray"): return ListGenerationResult(list_arr, None) lists = list(reversed(lists)) @@ -102,13 +102,13 @@ def test_regular_list(self, duckdb_cursor): create_and_register_arrow_table( [ - ('a', list_type, data), + ("a", list_type, data), ], duckdb_cursor, ) create_and_register_comparison_result( [ - ('a', 'FLOAT[]', data), + ("a", "FLOAT[]", data), ], duckdb_cursor, ) @@ -125,26 +125,26 @@ def test_fixedsize_list(self, duckdb_cursor): create_and_register_arrow_table( [ - ('a', list_type, data), + ("a", list_type, data), ], duckdb_cursor, ) create_and_register_comparison_result( [ - ('a', f'FLOAT[{list_size}]', data), + ("a", f"FLOAT[{list_size}]", data), ], duckdb_cursor, ) check_equal(duckdb_cursor) - @pytest.mark.skipif(not hasattr(pa, 'ListViewArray'), reason='The pyarrow version does not support ListViewArrays') - @pytest.mark.parametrize('child_size', [100000]) + @pytest.mark.skipif(not hasattr(pa, "ListViewArray"), reason="The pyarrow version does not support ListViewArrays") + @pytest.mark.parametrize("child_size", [100000]) def test_list_view(self, duckdb_cursor, child_size): res = generate_list(child_size) - list_tbl = pa.Table.from_arrays([res.list], ['x']) - list_view_tbl = pa.Table.from_arrays([res.list_view], ['x']) + list_tbl = pa.Table.from_arrays([res.list], ["x"]) + list_view_tbl = pa.Table.from_arrays([res.list_view], ["x"]) assert res.list_view.to_pylist() == res.list.to_pylist() original = duckdb_cursor.query("select * from list_tbl").fetchall() diff --git a/tests/fast/arrow/test_arrow_offsets.py b/tests/fast/arrow/test_arrow_offsets.py index 6bc94530..0ddc0f7d 100644 --- a/tests/fast/arrow/test_arrow_offsets.py +++ b/tests/fast/arrow/test_arrow_offsets.py @@ -62,7 +62,7 @@ def decimal_value(value, precision, scale): val = str(value) actual_width = precision - scale if len(val) > actual_width: - return decimal.Decimal('9' * actual_width) + return decimal.Decimal("9" * actual_width) return decimal.Decimal(val) @@ -76,7 +76,7 @@ def expected_result(col1_null, col2_null, expected): null_test_parameters = lambda: mark.parametrize( - ['col1_null', 'col2_null'], [(False, True), (True, False), (True, True), (False, False)] + ["col1_null", "col2_null"], [(False, True), (True, False), (True, True), (False, False)] ) @@ -100,10 +100,10 @@ def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, '131072') + assert res == expected_result(col1_null, col2_null, "131072") @null_test_parameters() def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): @@ -126,7 +126,7 @@ def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, True) @@ -158,7 +158,7 @@ def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -167,8 +167,8 @@ def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): enum_type = pa.dictionary(pa.int64(), pa.utf8()) - tuples = ['red' for i in range(MAGIC_ARRAY_SIZE)] - tuples[-1] = 'green' + tuples = ["red" for i in range(MAGIC_ARRAY_SIZE)] + tuples[-1] = "green" if col1_null: tuples[-1] = None @@ -177,7 +177,7 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): struct_tuples[-1] = None arrow_table = pa.Table.from_pydict( - {'col1': pa.array(tuples, enum_type), 'col2': pa.array(struct_tuples, pa.struct({"a": enum_type}))}, + {"col1": pa.array(tuples, enum_type), "col2": pa.array(struct_tuples, pa.struct({"a": enum_type}))}, schema=pa.schema([("col1", enum_type), ("col2", pa.struct({"a": enum_type}))]), ) res = duckdb_cursor.sql( @@ -185,10 +185,10 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, 'green') + assert res == expected_result(col1_null, col2_null, "green") @null_test_parameters() def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): @@ -209,24 +209,24 @@ def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, b'131072') + assert res == expected_result(col1_null, col2_null, b"131072") @null_test_parameters() @pytest.mark.parametrize( ["constructor", "unit", "expected"], [ - (pa_time32(), 'ms', datetime.time(0, 2, 11, 72000)), - (pa_time32(), 's', datetime.time(23, 59, 59)), - (pa_time64(), 'ns', datetime.time(0, 0, 0, 131)), - (pa_time64(), 'us', datetime.time(0, 0, 0, 131072)), + (pa_time32(), "ms", datetime.time(0, 2, 11, 72000)), + (pa_time32(), "s", datetime.time(23, 59, 59)), + (pa_time64(), "ns", datetime.time(0, 0, 0, 131)), + (pa_time64(), "us", datetime.time(0, 0, 0, 131072)), ], ) def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - if unit == 's': + if unit == "s": # FIXME: We limit the size because we don't support time values > 24 hours size = 86400 # The amount of seconds in a day @@ -247,7 +247,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -282,7 +282,7 @@ def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converte SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -291,10 +291,10 @@ def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converte @pytest.mark.parametrize( ["constructor", "unit", "expected"], [ - (pa_duration(), 'ms', datetime.timedelta(seconds=131, microseconds=72000)), - (pa_duration(), 's', datetime.timedelta(days=1, seconds=44672)), - (pa_duration(), 'ns', datetime.timedelta(microseconds=131)), - (pa_duration(), 'us', datetime.timedelta(microseconds=131072)), + (pa_duration(), "ms", datetime.timedelta(seconds=131, microseconds=72000)), + (pa_duration(), "s", datetime.timedelta(days=1, seconds=44672)), + (pa_duration(), "ns", datetime.timedelta(microseconds=131)), + (pa_duration(), "us", datetime.timedelta(microseconds=131072)), ], ) def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): @@ -317,7 +317,7 @@ def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, co SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -326,10 +326,10 @@ def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, co @pytest.mark.parametrize( ["constructor", "unit", "expected"], [ - (pa_timestamp(), 'ms', datetime.datetime(1970, 1, 1, 0, 2, 11, 72000, tzinfo=pytz.utc)), - (pa_timestamp(), 's', datetime.datetime(1970, 1, 2, 12, 24, 32, 0, tzinfo=pytz.utc)), - (pa_timestamp(), 'ns', datetime.datetime(1970, 1, 1, 0, 0, 0, 131, tzinfo=pytz.utc)), - (pa_timestamp(), 'us', datetime.datetime(1970, 1, 1, 0, 0, 0, 131072, tzinfo=pytz.utc)), + (pa_timestamp(), "ms", datetime.datetime(1970, 1, 1, 0, 2, 11, 72000, tzinfo=pytz.utc)), + (pa_timestamp(), "s", datetime.datetime(1970, 1, 2, 12, 24, 32, 0, tzinfo=pytz.utc)), + (pa_timestamp(), "ns", datetime.datetime(1970, 1, 1, 0, 0, 0, 131, tzinfo=pytz.utc)), + (pa_timestamp(), "us", datetime.datetime(1970, 1, 1, 0, 0, 0, 131072, tzinfo=pytz.utc)), ], ) def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): @@ -346,7 +346,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected arrow_table = pa.Table.from_pydict( {"col1": col1, "col2": col2}, schema=pa.schema( - [("col1", constructor(unit, 'UTC')), ("col2", pa.struct({"a": constructor(unit, 'UTC')}))] + [("col1", constructor(unit, "UTC")), ("col2", pa.struct({"a": constructor(unit, "UTC")}))] ), ) @@ -355,7 +355,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected SELECT col1, col2.a - FROM arrow_table offset {size-1} + FROM arrow_table offset {size - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -379,23 +379,23 @@ def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - assert res == expected_result(col1_null, col2_null, b'131072') + assert res == expected_result(col1_null, col2_null, b"131072") @null_test_parameters() @pytest.mark.parametrize( ["precision_scale", "expected"], [ - ((38, 37), decimal.Decimal('9.0000000000000000000000000000000000000')), - ((38, 24), decimal.Decimal('131072.000000000000000000000000')), - ((18, 14), decimal.Decimal('9999.00000000000000')), - ((18, 5), decimal.Decimal('131072.00000')), - ((9, 7), decimal.Decimal('99.0000000')), - ((9, 3), decimal.Decimal('131072.000')), - ((4, 2), decimal.Decimal('99.00')), - ((4, 0), decimal.Decimal('9999')), + ((38, 37), decimal.Decimal("9.0000000000000000000000000000000000000")), + ((38, 24), decimal.Decimal("131072.000000000000000000000000")), + ((18, 14), decimal.Decimal("9999.00000000000000")), + ((18, 5), decimal.Decimal("131072.00000")), + ((9, 7), decimal.Decimal("99.0000000")), + ((9, 3), decimal.Decimal("131072.000")), + ((4, 2), decimal.Decimal("99.00")), + ((4, 0), decimal.Decimal("9999")), ], ) def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_null, col2_null): @@ -420,7 +420,7 @@ def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_ SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() assert res == expected_result(col1_null, col2_null, expected) @@ -443,16 +443,16 @@ def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else '131072' + res1 = None if col1_null else "131072" if col2_null: res2 = None elif col1_null: res2 = [None, None, None] else: - res2 = ['131072', '131072', '131072'] + res2 = ["131072", "131072", "131072"] assert res == [(res1, res2)] @null_test_parameters() @@ -473,16 +473,16 @@ def test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else '131072' + res1 = None if col1_null else "131072" if col2_null: res2 = None elif col1_null: res2 = (None, None, None) else: - res2 = ('131072', '131072', '131072') + res2 = ("131072", "131072", "131072") assert res == [(res1, res2)] @null_test_parameters() @@ -504,16 +504,16 @@ def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else b'131072' + res1 = None if col1_null else b"131072" if col2_null: res2 = None elif col1_null: res2 = (None, None, None) else: - res2 = (b'131072', b'131073', b'131074') + res2 = (b"131072", b"131073", b"131074") assert res == [(res1, res2)] @null_test_parameters() @@ -535,16 +535,16 @@ def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() - res1 = None if col1_null else b'131072' + res1 = None if col1_null else b"131072" if col2_null: res2 = None elif col1_null: res2 = [None, None, None] else: - res2 = [b'131072', b'131073', b'131074'] + res2 = [b"131072", b"131073", b"131074"] assert res == [(res1, res2)] @null_test_parameters() @@ -566,7 +566,7 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): SELECT col1, col2.a - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() res1 = None if col1_null else 131072 @@ -578,7 +578,7 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): res2 = [[131072, 131072, 131072], [], None, [131072]] assert res == [(res1, res2)] - @pytest.mark.parametrize('col1_null', [True, False]) + @pytest.mark.parametrize("col1_null", [True, False]) def test_list_of_struct(self, duckdb_cursor, col1_null): # One single tuple containing a very big list tuples = [{"a": i} for i in range(0, MAGIC_ARRAY_SIZE)] @@ -599,19 +599,19 @@ def test_list_of_struct(self, duckdb_cursor, col1_null): res = res[0][0] for i, x in enumerate(res[:-1]): assert x.__class__ == dict - assert x['a'] == i + assert x["a"] == i if col1_null: assert res[-1] == None else: - assert res[-1]['a'] == len(res) - 1 + assert res[-1]["a"] == len(res) - 1 - @pytest.mark.parametrize(['outer_null', 'inner_null'], [(True, False), (False, True)]) + @pytest.mark.parametrize(["outer_null", "inner_null"], [(True, False), (False, True)]) def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): tuples = [[[{"a": str(i), "b": None, "c": [i]}]] for i in range(MAGIC_ARRAY_SIZE)] if outer_null: tuples[-1] = None else: - inner = [[{"a": 'aaaaaaaaaaaaaaa', "b": 'test', "c": [1, 2, 3]}] for _ in range(MAGIC_ARRAY_SIZE)] + inner = [[{"a": "aaaaaaaaaaaaaaa", "b": "test", "c": [1, 2, 3]}] for _ in range(MAGIC_ARRAY_SIZE)] if inner_null: inner[-1] = None tuples[-1] = inner @@ -635,7 +635,7 @@ def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): f""" SELECT col1 - FROM arrow_table OFFSET {MAGIC_ARRAY_SIZE-1} + FROM arrow_table OFFSET {MAGIC_ARRAY_SIZE - 1} """ ).fetchall() if outer_null: @@ -646,7 +646,7 @@ def test_list_of_list_of_struct(self, duckdb_cursor, outer_null, inner_null): else: assert res[-1][-1][-1] == 131072 - @pytest.mark.parametrize('col1_null', [True, False]) + @pytest.mark.parametrize("col1_null", [True, False]) def test_struct_of_list(self, duckdb_cursor, col1_null): # All elements are of size 1 tuples = [{"a": [str(i)]} for i in range(MAGIC_ARRAY_SIZE)] @@ -664,13 +664,13 @@ def test_struct_of_list(self, duckdb_cursor, col1_null): f""" SELECT col1 - FROM arrow_table offset {MAGIC_ARRAY_SIZE-1} + FROM arrow_table offset {MAGIC_ARRAY_SIZE - 1} """ ).fetchone() if col1_null: assert res[0] == None else: - assert res[0]['a'][-1] == '131072' + assert res[0]["a"][-1] == "131072" def test_bools_with_offset(self, duckdb_cursor): bools = [False, False, False, False, True, False, False, False, False, False] diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index 8310c58b..6df5053f 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -8,11 +8,11 @@ def polars_supports_capsule(): from packaging.version import Version - return Version(pl.__version__) >= Version('1.4.1') + return Version(pl.__version__) >= Version("1.4.1") @pytest.mark.skipif( - not polars_supports_capsule(), reason='Polars version does not support the Arrow PyCapsule interface' + not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface" ) class TestArrowPyCapsule(object): def test_polars_pycapsule_scan(self, duckdb_cursor): @@ -25,7 +25,7 @@ def __arrow_c_stream__(self, requested_schema=None) -> object: self.count += 1 return self.obj.__arrow_c_stream__(requested_schema=requested_schema) - df = pl.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) obj = MyObject(df) # Call the __arrow_c_stream__ from within DuckDB diff --git a/tests/fast/arrow/test_arrow_recordbatchreader.py b/tests/fast/arrow/test_arrow_recordbatchreader.py index 0f8a701d..a9523d43 100644 --- a/tests/fast/arrow/test_arrow_recordbatchreader.py +++ b/tests/fast/arrow/test_arrow_recordbatchreader.py @@ -10,11 +10,10 @@ class TestArrowRecordBatchReader(object): def test_parallel_reader(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -31,19 +30,16 @@ def test_parallel_reader(self, duckdb_cursor): rel = duckdb_conn.from_arrow(reader) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) # The reader is already consumed so this should be 0 - assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 - ) + assert rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 0 def test_parallel_reader_replacement_scans(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -59,23 +55,22 @@ def test_parallel_reader_replacement_scans(self, duckdb_cursor): assert ( duckdb_conn.execute( - "select count(*) r1 from reader where first_name=\'Jose\' and salary > 134708.82" + "select count(*) r1 from reader where first_name='Jose' and salary > 134708.82" ).fetchone()[0] == 12 ) assert ( duckdb_conn.execute( - "select count(*) r2 from reader where first_name=\'Jose\' and salary > 134708.82" + "select count(*) r2 from reader where first_name='Jose' and salary > 134708.82" ).fetchone()[0] == 0 ) def test_parallel_reader_register(self, duckdb_cursor): - duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -92,21 +87,16 @@ def test_parallel_reader_register(self, duckdb_cursor): duckdb_conn.register("bla", reader) assert ( - duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ - 0 - ] + duckdb_conn.execute("select count(*) from bla where first_name='Jose' and salary > 134708.82").fetchone()[0] == 12 ) assert ( - duckdb_conn.execute("select count(*) from bla where first_name=\'Jose\' and salary > 134708.82").fetchone()[ - 0 - ] + duckdb_conn.execute("select count(*) from bla where first_name='Jose' and salary > 134708.82").fetchone()[0] == 0 ) def test_parallel_reader_default_conn(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -123,9 +113,7 @@ def test_parallel_reader_default_conn(self, duckdb_cursor): rel = duckdb.from_arrow(reader) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) # The reader is already consumed so this should be 0 - assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 0 - ) + assert rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 0 diff --git a/tests/fast/arrow/test_arrow_replacement_scan.py b/tests/fast/arrow/test_arrow_replacement_scan.py index a02bac10..f2a9c13b 100644 --- a/tests/fast/arrow/test_arrow_replacement_scan.py +++ b/tests/fast/arrow/test_arrow_replacement_scan.py @@ -10,8 +10,7 @@ class TestArrowReplacementScan(object): def test_arrow_table_replacement_scan(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) df = userdata_parquet_table.to_pandas() @@ -22,11 +21,11 @@ def test_arrow_table_replacement_scan(self, duckdb_cursor): assert con.execute("select count(*) from df").fetchone() == (1000,) @pytest.mark.skipif( - not hasattr(pa.Table, '__arrow_c_stream__'), - reason='This version of pyarrow does not support the Arrow Capsule Interface', + not hasattr(pa.Table, "__arrow_c_stream__"), + reason="This version of pyarrow does not support the Arrow Capsule Interface", ) def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): - tbl = pa.Table.from_pydict({'a': [1, 2, 3, 4, 5, 6, 7, 8, 9]}) + tbl = pa.Table.from_pydict({"a": [1, 2, 3, 4, 5, 6, 7, 8, 9]}) capsule = tbl.__arrow_c_stream__() rel = duckdb_cursor.sql("select * from capsule") @@ -36,13 +35,13 @@ def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from capsule where a > 3 and a < 5") assert rel.fetchall() == [(4,)] - tbl = pa.Table.from_pydict({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9], 'd': [10, 11, 12]}) + tbl = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [10, 11, 12]}) capsule = tbl.__arrow_c_stream__() rel = duckdb_cursor.sql("select b, d from capsule") assert rel.fetchall() == [(i, i + 6) for i in range(4, 7)] - with pytest.raises(duckdb.InvalidInputException, match='The ArrowArrayStream was already released'): + with pytest.raises(duckdb.InvalidInputException, match="The ArrowArrayStream was already released"): rel = duckdb_cursor.sql("select b, d from capsule") schema_obj = tbl.schema @@ -53,19 +52,18 @@ def test_arrow_pycapsule_replacement_scan(self, duckdb_cursor): rel = duckdb_cursor.sql("select b, d from schema_capsule") def test_arrow_table_replacement_scan_view(self, duckdb_cursor): - - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) con = duckdb.connect() con.execute("create view x as select * from userdata_parquet_table") del userdata_parquet_table - with pytest.raises(duckdb.CatalogException, match='Table with name userdata_parquet_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name userdata_parquet_table does not exist"): assert con.execute("select count(*) from x").fetchone() def test_arrow_dataset_replacement_scan(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) userdata_parquet_dataset = ds.dataset(parquet_filename) diff --git a/tests/fast/arrow/test_arrow_run_end_encoding.py b/tests/fast/arrow/test_arrow_run_end_encoding.py index 6315d1b7..c6f9fad5 100644 --- a/tests/fast/arrow/test_arrow_run_end_encoding.py +++ b/tests/fast/arrow/test_arrow_run_end_encoding.py @@ -3,7 +3,7 @@ import pandas as pd import duckdb -pa = pytest.importorskip("pyarrow", '21.0.0', reason="Needs pyarrow >= 21") +pa = pytest.importorskip("pyarrow", "21.0.0", reason="Needs pyarrow >= 21") pc = pytest.importorskip("pyarrow.compute") @@ -25,14 +25,14 @@ def create_list(offsets, values): def list_constructors(): result = [] result.append(create_list) - if hasattr(pa, 'ListViewArray'): + if hasattr(pa, "ListViewArray"): result.append(create_list_view) return result class TestArrowREE(object): @pytest.mark.parametrize( - 'query', + "query", [ """ select @@ -46,22 +46,22 @@ class TestArrowREE(object): """, ], ) - @pytest.mark.parametrize('run_length', [4, 1, 10, 1000, 2048, 3000]) - @pytest.mark.parametrize('size', [100, 10000]) + @pytest.mark.parametrize("run_length", [4, 1, 10, 1000, 2048, 3000]) + @pytest.mark.parametrize("size", [100, 10000]) @pytest.mark.parametrize( - 'value_type', - ['UTINYINT', 'USMALLINT', 'UINTEGER', 'UBIGINT', 'TINYINT', 'SMALLINT', 'INTEGER', 'BIGINT', 'HUGEINT'], + "value_type", + ["UTINYINT", "USMALLINT", "UINTEGER", "UBIGINT", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "HUGEINT"], ) def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, size, value_type): - if value_type == 'UTINYINT': + if value_type == "UTINYINT": if size > 255: size = 255 - if value_type == 'TINYINT': + if value_type == "TINYINT": if size > 127: size = 127 query = query.format(run_length, value_type, size) rel = duckdb_cursor.sql(query) - array = rel.fetch_arrow_table()['ree'] + array = rel.fetch_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) @@ -72,31 +72,31 @@ def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, assert res == expected @pytest.mark.parametrize( - ['dbtype', 'val1', 'val2'], + ["dbtype", "val1", "val2"], [ - ('TINYINT', '(-128)', '127'), - ('SMALLINT', '(-32768)', '32767'), - ('INTEGER', '(-2147483648)', '2147483647'), - ('BIGINT', '(-9223372036854775808)', '9223372036854775807'), - ('UTINYINT', '0', '255'), - ('USMALLINT', '0', '65535'), - ('UINTEGER', '0', '4294967295'), - ('UBIGINT', '0', '18446744073709551615'), - ('BOOL', 'true', 'false'), - ('VARCHAR', "'test'", "'this is a long string'"), - ('BLOB', "'\\xE0\\x9F\\x98\\x84'", "'\\xF0\\x9F\\xA6\\x86'"), - ('DATE', "'1992-03-27'", "'2204-11-01'"), - ('TIME', "'01:02:03'", "'23:41:35'"), - ('TIMESTAMP_S', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP_MS', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('TIMESTAMP_NS', "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), - ('DECIMAL(4,2)', "'12.23'", "'99.99'"), - ('DECIMAL(7,6)', "'1.234234'", "'0.000001'"), - ('DECIMAL(14,7)', "'134523.234234'", "'999999.000001'"), - ('DECIMAL(28,1)', "'12345678910111234123456789.1'", "'999999999999999999999999999.9'"), - ('UUID', "'10acd298-15d7-417c-8b59-eabb5a2bacab'", "'eeccb8c5-9943-b2bb-bb5e-222f4e14b687'"), - ('BIT', "'01010101010000'", "'01010100010101010101010101111111111'"), + ("TINYINT", "(-128)", "127"), + ("SMALLINT", "(-32768)", "32767"), + ("INTEGER", "(-2147483648)", "2147483647"), + ("BIGINT", "(-9223372036854775808)", "9223372036854775807"), + ("UTINYINT", "0", "255"), + ("USMALLINT", "0", "65535"), + ("UINTEGER", "0", "4294967295"), + ("UBIGINT", "0", "18446744073709551615"), + ("BOOL", "true", "false"), + ("VARCHAR", "'test'", "'this is a long string'"), + ("BLOB", "'\\xE0\\x9F\\x98\\x84'", "'\\xF0\\x9F\\xA6\\x86'"), + ("DATE", "'1992-03-27'", "'2204-11-01'"), + ("TIME", "'01:02:03'", "'23:41:35'"), + ("TIMESTAMP_S", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP_MS", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("TIMESTAMP_NS", "'1992-03-22 01:02:03'", "'2022-11-07 08:43:04.123456'"), + ("DECIMAL(4,2)", "'12.23'", "'99.99'"), + ("DECIMAL(7,6)", "'1.234234'", "'0.000001'"), + ("DECIMAL(14,7)", "'134523.234234'", "'999999.000001'"), + ("DECIMAL(28,1)", "'12345678910111234123456789.1'", "'999999999999999999999999999.9'"), + ("UUID", "'10acd298-15d7-417c-8b59-eabb5a2bacab'", "'eeccb8c5-9943-b2bb-bb5e-222f4e14b687'"), + ("BIT", "'01010101010000'", "'01010100010101010101010101111111111'"), ], ) @pytest.mark.parametrize( @@ -107,7 +107,7 @@ def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, ], ) def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter): - if dbtype in ['BIT', 'UUID']: + if dbtype in ["BIT", "UUID"]: pytest.skip("BIT and UUID are currently broken (FIXME)") projection = "a, b, ree" query = """ @@ -135,25 +135,25 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) # Create an Arrow Table from the table arrow_conversion = rel.fetch_arrow_table() arrays = { - 'ree': arrow_conversion['ree'], - 'a': arrow_conversion['a'], - 'b': arrow_conversion['b'], + "ree": arrow_conversion["ree"], + "a": arrow_conversion["a"], + "b": arrow_conversion["b"], } encoded_arrays = { - 'ree': pc.run_end_encode(arrays['ree']), - 'a': pc.run_end_encode(arrays['a']), - 'b': pc.run_end_encode(arrays['b']), + "ree": pc.run_end_encode(arrays["ree"]), + "a": pc.run_end_encode(arrays["a"]), + "b": pc.run_end_encode(arrays["b"]), } schema = pa.schema( [ - ("ree", encoded_arrays['ree'].type), - ("a", encoded_arrays['a'].type), - ("b", encoded_arrays['b'].type), + ("ree", encoded_arrays["ree"].type), + ("a", encoded_arrays["a"].type), + ("b", encoded_arrays["b"].type), ] ) - tbl = pa.Table.from_arrays([encoded_arrays['ree'], encoded_arrays['a'], encoded_arrays['b']], schema=schema) + tbl = pa.Table.from_arrays([encoded_arrays["ree"], encoded_arrays["a"], encoded_arrays["b"]], schema=schema) # Scan the Arrow Table and verify that the results are the same res = duckdb_cursor.sql("select {} from tbl where {}".format(projection, filter)).fetchall() @@ -161,8 +161,8 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) def test_arrow_ree_empty_table(self, duckdb_cursor): duckdb_cursor.query("create table tbl (ree integer)") - rel = duckdb_cursor.table('tbl') - array = rel.fetch_arrow_table()['ree'] + rel = duckdb_cursor.table("tbl") + array = rel.fetch_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) @@ -172,7 +172,7 @@ def test_arrow_ree_empty_table(self, duckdb_cursor): res = duckdb_cursor.sql("select * from pa_res").fetchall() assert res == expected - @pytest.mark.parametrize('projection', ['*', 'a, c, b', 'ree, a, b, c', 'c, b, a, ree', 'c', 'b, ree, c, a']) + @pytest.mark.parametrize("projection", ["*", "a, c, b", "ree, a, b, c", "c, b, a, ree", "c", "b, ree, c, a"]) def test_arrow_ree_projections(self, duckdb_cursor, projection): # Create the schema duckdb_cursor.query( @@ -199,28 +199,28 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): ) # Fetch the result as an Arrow Table - result = duckdb_cursor.table('tbl').fetch_arrow_table() + result = duckdb_cursor.table("tbl").fetch_arrow_table() # Turn 'ree' into a run-end-encoded array and reconstruct a table from it arrays = { - 'ree': pc.run_end_encode(result['ree']), - 'a': result['a'], - 'b': result['b'], - 'c': result['c'], + "ree": pc.run_end_encode(result["ree"]), + "a": result["a"], + "b": result["b"], + "c": result["c"], } schema = pa.schema( [ - ("ree", arrays['ree'].type), - ("a", arrays['a'].type), - ("b", arrays['b'].type), - ("c", arrays['c'].type), + ("ree", arrays["ree"].type), + ("a", arrays["a"].type), + ("b", arrays["b"].type), + ("c", arrays["c"].type), ] ) - arrow_tbl = pa.Table.from_arrays([arrays['ree'], arrays['a'], arrays['b'], arrays['c']], schema=schema) + arrow_tbl = pa.Table.from_arrays([arrays["ree"], arrays["a"], arrays["b"], arrays["c"]], schema=schema) # Verify that the array is run end encoded - ar_type = arrow_tbl['ree'].type + ar_type = arrow_tbl["ree"].type assert pa.types.is_run_end_encoded(ar_type) == True # Scan the arrow table, making projections that don't cover the entire table @@ -229,9 +229,7 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): res = duckdb_cursor.query( """ select {} from arrow_tbl - """.format( - projection - ) + """.format(projection) ).fetch_arrow_table() # Verify correctness by fetching from the original table and the constructed result @@ -239,7 +237,7 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): actual = duckdb_cursor.query("select {} from res".format(projection)).fetchall() assert expected == actual - @pytest.mark.parametrize('create_list', list_constructors()) + @pytest.mark.parametrize("create_list", list_constructors()) def test_arrow_ree_list(self, duckdb_cursor, create_list): size = 1000 duckdb_cursor.query( @@ -248,9 +246,7 @@ def test_arrow_ree_list(self, duckdb_cursor, create_list): as select i // 4 as ree, FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -281,7 +277,7 @@ def test_arrow_ree_list(self, duckdb_cursor, create_list): structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() assert arrow_tbl.to_pylist() == result.to_pylist() @@ -317,7 +313,7 @@ def test_arrow_ree_struct(self, duckdb_cursor): structured_chunks = [pa.StructArray.from_arrays([y for y in x], names=names) for x in zipped] structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() expected = duckdb_cursor.query("select {'ree': ree, 'a': a, 'b': b, 'c': c} as s from tbl").fetchall() @@ -337,9 +333,7 @@ def test_arrow_ree_union(self, duckdb_cursor): i % 2 == 0 as b, i::VARCHAR as c FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -368,7 +362,7 @@ def test_arrow_ree_union(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Recreate the same result set @@ -395,9 +389,7 @@ def test_arrow_ree_map(self, duckdb_cursor): i // 4 as ree, i as a, FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -431,7 +423,7 @@ def test_arrow_ree_map(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Verify that the resulting scan is the same as the input @@ -446,9 +438,7 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): as select i // 4 as ree, FROM range({}) t(i) - """.format( - size - ) + """.format(size) ) # Populate the table with data @@ -473,7 +463,7 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): structured_chunks.append(new_array) structured = pa.chunked_array(structured_chunks) - arrow_tbl = pa.Table.from_arrays([structured], names=['ree']) + arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # Verify that the resulting scan is the same as the input diff --git a/tests/fast/arrow/test_arrow_scanner.py b/tests/fast/arrow/test_arrow_scanner.py index 6d74ddb5..2e8b1296 100644 --- a/tests/fast/arrow/test_arrow_scanner.py +++ b/tests/fast/arrow/test_arrow_scanner.py @@ -22,7 +22,7 @@ def test_parallel_scanner(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -33,13 +33,13 @@ def test_parallel_scanner(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) rel = duckdb_conn.from_arrow(arrow_scanner) - assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 + assert rel.aggregate("count(*)").execute().fetchone()[0] == 12 def test_parallel_scanner_replacement_scans(self, duckdb_cursor): if not can_run: @@ -48,7 +48,7 @@ def test_parallel_scanner_replacement_scans(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -59,7 +59,7 @@ def test_parallel_scanner_replacement_scans(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) @@ -72,7 +72,7 @@ def test_parallel_scanner_register(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -83,7 +83,7 @@ def test_parallel_scanner_register(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) @@ -95,7 +95,7 @@ def test_parallel_scanner_default_conn(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") arrow_dataset = pyarrow.dataset.dataset( [ @@ -106,10 +106,10 @@ def test_parallel_scanner_default_conn(self, duckdb_cursor): format="parquet", ) - scanner_filter = (pc.field("first_name") == pc.scalar('Jose')) & (pc.field("salary") > pc.scalar(134708.82)) + scanner_filter = (pc.field("first_name") == pc.scalar("Jose")) & (pc.field("salary") > pc.scalar(134708.82)) arrow_scanner = Scanner.from_dataset(arrow_dataset, filter=scanner_filter) rel = duckdb.from_arrow(arrow_scanner) - assert rel.aggregate('count(*)').execute().fetchone()[0] == 12 + assert rel.aggregate("count(*)").execute().fetchone()[0] == 12 diff --git a/tests/fast/arrow/test_arrow_string_view.py b/tests/fast/arrow/test_arrow_string_view.py index fc4bbd40..a1b46e5b 100644 --- a/tests/fast/arrow/test_arrow_string_view.py +++ b/tests/fast/arrow/test_arrow_string_view.py @@ -2,10 +2,10 @@ import pytest from packaging import version -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") pytestmark = pytest.mark.skipif( - not hasattr(pa, 'string_view'), reason="This version of PyArrow does not support StringViews" + not hasattr(pa, "string_view"), reason="This version of PyArrow does not support StringViews" ) @@ -20,7 +20,7 @@ def RoundTripStringView(query, array): # Generate an arrow table # Create a field for the array with a specific data type - field = pa.field('str_val', pa.string_view()) + field = pa.field("str_val", pa.string_view()) # Create a schema for the table using the field schema = pa.schema([field]) @@ -103,26 +103,26 @@ def test_not_inlined_string_view(self): # Test Over-Vector Size def test_large_string_view_inlined(self): - RoundTripDuckDBInternal('''select * from (SELECT i::varchar str FROM range(10000) tbl(i)) order by str''') + RoundTripDuckDBInternal("""select * from (SELECT i::varchar str FROM range(10000) tbl(i)) order by str""") def test_large_string_view_inlined_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_large_string_view_not_inlined(self): RoundTripDuckDBInternal( - '''select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_large_string_view_not_inlined_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_large_string_view_mixed_with_null(self): RoundTripDuckDBInternal( - '''select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str''' + """select * from (SELECT i::varchar str FROM range(10000) tbl(i) UNION SELECT 'Imaverybigstringmuchbiggerthanfourbytes'||i::varchar str FROM range(10000) tbl(i) UNION select null) order by str""" ) def test_multiple_data_buffers(self): diff --git a/tests/fast/arrow/test_arrow_types.py b/tests/fast/arrow/test_arrow_types.py index 97f747ef..f2bf71c7 100644 --- a/tests/fast/arrow/test_arrow_types.py +++ b/tests/fast/arrow/test_arrow_types.py @@ -17,7 +17,7 @@ def test_null_type(self, duckdb_cursor): inputs = [pa.array([None, None, None], type=pa.null())] arrow_table = pa.Table.from_arrays(inputs, schema=schema) - assert rel['data'] == arrow_table['data'] + assert rel["data"] == arrow_table["data"] def test_invalid_struct(self, duckdb_cursor): empty_struct_type = pa.struct([]) @@ -27,7 +27,7 @@ def test_invalid_struct(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([empty_array], schema=pa.schema([("data", empty_struct_type)])) with pytest.raises( duckdb.InvalidInputException, - match='Attempted to convert a STRUCT with no fields to DuckDB which is not supported', + match="Attempted to convert a STRUCT with no fields to DuckDB which is not supported", ): duckdb_cursor.sql("select * from arrow_table").fetchall() @@ -39,9 +39,9 @@ def test_invalid_union(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([sparse_union_array], schema=pa.schema([("data", sparse_union_array.type)])) with pytest.raises( duckdb.InvalidInputException, - match='Attempted to convert a UNION with no fields to DuckDB which is not supported', + match="Attempted to convert a UNION with no fields to DuckDB which is not supported", ): - duckdb_cursor.register('invalid_union', arrow_table) + duckdb_cursor.register("invalid_union", arrow_table) res = duckdb_cursor.sql("select * from invalid_union").fetchall() print(res) diff --git a/tests/fast/arrow/test_arrow_union.py b/tests/fast/arrow/test_arrow_union.py index 1d853a1b..c0a5d568 100644 --- a/tests/fast/arrow/test_arrow_union.py +++ b/tests/fast/arrow/test_arrow_union.py @@ -1,13 +1,13 @@ from pytest import importorskip -importorskip('pyarrow') +importorskip("pyarrow") import duckdb from pyarrow import scalar, string, large_string, list_, int32, types def test_nested(duckdb_cursor): - res = run(duckdb_cursor, 'select 42::UNION(name VARCHAR, attr UNION(age INT, veteran BOOL)) as res') + res = run(duckdb_cursor, "select 42::UNION(name VARCHAR, attr UNION(age INT, veteran BOOL)) as res") assert types.is_union(res.type) assert res.value.value == scalar(42, type=int32()) @@ -16,14 +16,14 @@ def test_union_contains_nested_data(duckdb_cursor): _ = importorskip("pyarrow", minversion="11") res = run(duckdb_cursor, "select ['hello']::UNION(first_name VARCHAR, middle_names VARCHAR[]) as res") assert types.is_union(res.type) - assert res.value == scalar(['hello'], type=list_(string())) + assert res.value == scalar(["hello"], type=list_(string())) def test_unions_inside_lists_structs_maps(duckdb_cursor): res = run(duckdb_cursor, "select [union_value(name := 'Frank')] as res") assert types.is_list(res.type) assert types.is_union(res.type.value_type) - assert res[0].value == scalar('Frank', type=string()) + assert res[0].value == scalar("Frank", type=string()) def test_unions_with_struct(duckdb_cursor): @@ -38,13 +38,13 @@ def test_unions_with_struct(duckdb_cursor): """ ) - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") arrow = rel.fetch_arrow_table() duckdb_cursor.execute("create table other as select * from arrow") - rel2 = duckdb_cursor.table('other') + rel2 = duckdb_cursor.table("other") res = rel2.fetchall() - assert res == [({'a': 42, 'b': True},)] + assert res == [({"a": 42, "b": True},)] def run(conn, query): diff --git a/tests/fast/arrow/test_arrow_version_format.py b/tests/fast/arrow/test_arrow_version_format.py index ff8699eb..fd169ce0 100644 --- a/tests/fast/arrow/test_arrow_version_format.py +++ b/tests/fast/arrow/test_arrow_version_format.py @@ -32,20 +32,20 @@ def test_decimal_v1_5(self, duckdb_cursor): ) col_type = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table().schema.field("data").type assert col_type.bit_width == 64 and pa.types.is_decimal(col_type) - for version in ['1.0', '1.1', '1.2', '1.3', '1.4']: + for version in ["1.0", "1.1", "1.2", "1.3", "1.4"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") result = duckdb_cursor.execute("FROM decimal_32").fetch_arrow_table() col_type = result.schema.field("data").type assert col_type.bit_width == 128 and pa.types.is_decimal(col_type) assert result.to_pydict() == { - 'data': [Decimal('100.20'), Decimal('110.21'), Decimal('31.20'), Decimal('500.20')] + "data": [Decimal("100.20"), Decimal("110.21"), Decimal("31.20"), Decimal("500.20")] } result = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table() col_type = result.schema.field("data").type assert col_type.bit_width == 128 and pa.types.is_decimal(col_type) assert result.to_pydict() == { - 'data': [Decimal('1000.231'), Decimal('1100.231'), Decimal('999999999999.231'), Decimal('500.200')] + "data": [Decimal("1000.231"), Decimal("1100.231"), Decimal("999999999999.231"), Decimal("500.200")] } def test_invalide_opt(self, duckdb_cursor): @@ -63,14 +63,14 @@ def test_view_v1_4(self, duckdb_cursor): col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type assert pa.types.is_list_view(col_type) - for version in ['1.0', '1.1', '1.2', '1.3']: + for version in ["1.0", "1.1", "1.2", "1.3"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_list_view(col_type) - for version in ['1.4', '1.5']: + for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert pa.types.is_string_view(col_type) @@ -80,7 +80,7 @@ def test_view_v1_4(self, duckdb_cursor): duckdb_cursor.execute("SET produce_arrow_string_view=False") duckdb_cursor.execute("SET arrow_output_list_view=False") - for version in ['1.4', '1.5']: + for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) diff --git a/tests/fast/arrow/test_buffer_size_option.py b/tests/fast/arrow/test_buffer_size_option.py index 46047e21..7d5131e5 100644 --- a/tests/fast/arrow/test_buffer_size_option.py +++ b/tests/fast/arrow/test_buffer_size_option.py @@ -34,7 +34,7 @@ def just_return(x): return x con = duckdb.connect() - con.create_function('just_return', just_return, [VARCHAR], VARCHAR, type='arrow') + con.create_function("just_return", just_return, [VARCHAR], VARCHAR, type="arrow") res = con.query("select just_return('bla')").fetch_arrow_table() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 521ec8f7..8ec0094e 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -14,7 +14,7 @@ def test_parallel_dataset(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -28,7 +28,7 @@ def test_parallel_dataset(self, duckdb_cursor): rel = duckdb_conn.from_arrow(userdata_parquet_dataset) assert ( - rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)').execute().fetchone()[0] == 12 + rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)").execute().fetchone()[0] == 12 ) def test_parallel_dataset_register(self, duckdb_cursor): @@ -36,7 +36,7 @@ def test_parallel_dataset_register(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -61,7 +61,7 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_dataset = pyarrow.dataset.dataset( [ @@ -79,7 +79,7 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): arrow_table = record_batch_reader.read_all() # reorder since order of rows isn't deterministic - df = userdata_parquet_dataset.to_table().to_pandas().sort_values('id').reset_index(drop=True) + df = userdata_parquet_dataset.to_table().to_pandas().sort_values("id").reset_index(drop=True) # turn it into an arrow table arrow_table_2 = pyarrow.Table.from_pandas(df) result_1 = duckdb_conn.execute("select * from arrow_table order by all").fetchall() @@ -94,7 +94,7 @@ def test_ducktyping(self, duckdb_cursor): query = duckdb_conn.execute("SELECT b FROM dataset WHERE a < 5") record_batch_reader = query.fetch_record_batch(2048) arrow_table = record_batch_reader.read_all() - assert arrow_table.equals(CustomDataset.DATA[:5].select(['b'])) + assert arrow_table.equals(CustomDataset.DATA[:5].select(["b"])) class CustomDataset(pyarrow.dataset.Dataset): diff --git a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 316fc689..9649ffa6 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -18,30 +18,30 @@ def test_date_types(self, duckdb_cursor): return data = (pa.array([1000 * 60 * 60 * 24], type=pa.date64()), pa.array([1], type=pa.date32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['b'] - assert rel['b'] == arrow_table['b'] + assert rel["a"] == arrow_table["b"] + assert rel["b"] == arrow_table["b"] def test_date_null(self, duckdb_cursor): if not can_run: return data = (pa.array([None], type=pa.date64()), pa.array([None], type=pa.date32())) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['b'] - assert rel['b'] == arrow_table['b'] + assert rel["a"] == arrow_table["b"] + assert rel["b"] == arrow_table["b"] def test_max_date(self, duckdb_cursor): if not can_run: return data = (pa.array([2147483647], type=pa.date32()), pa.array([2147483647], type=pa.date32())) - result = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + result = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) data = ( pa.array([2147483647 * (1000 * 60 * 60 * 24)], type=pa.date64()), pa.array([2147483647], type=pa.date32()), ) - arrow_table = pa.Table.from_arrays([data[0], data[1]], ['a', 'b']) + arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] - assert rel['b'] == result['b'] + assert rel["a"] == result["a"] + assert rel["b"] == result["b"] diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index 823d6b05..e4319f7c 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -17,7 +17,7 @@ def test_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)] @@ -27,14 +27,14 @@ def test_dictionary(self, duckdb_cursor): indices = pa.array(indices_list) dictionary = pa.array([10, 100, None, 999999]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(10,), (100,), (10,), (100,), (None,), (100,), (10,), (None,), (999999,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, pa.array(indices_list)], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, pa.array(indices_list)], ["a", "b"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(10, 0), (100, 1), (10, 0), (100, 1), (None, 2), (100, 1), (10, 0), (None, 2), (999999, 3)] * 10000 assert rel.execute().fetchall() == result @@ -43,7 +43,7 @@ def test_dictionary_null_index(self, duckdb_cursor): indices = pa.array([None, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) assert rel.execute().fetchall() == [(None,), (100,), (10,), (100,), (None,), (100,), (10,), (None,)] @@ -51,7 +51,7 @@ def test_dictionary_null_index(self, duckdb_cursor): indices = pa.array([None, 1, None, 1, 2, 1, 0]) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] @@ -61,19 +61,19 @@ def test_dictionary_null_index(self, duckdb_cursor): indices = pa.array(indices_list * 1000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 1000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 1000 assert rel.execute().fetchall() == result @pytest.mark.parametrize( - 'element', + "element", [ # list """ @@ -110,7 +110,7 @@ def test_dictionary_null_index(self, duckdb_cursor): ], ) @pytest.mark.parametrize( - 'count', + "count", [ 1, 10, @@ -123,14 +123,14 @@ def test_dictionary_null_index(self, duckdb_cursor): 5000, ], ) - @pytest.mark.parametrize('query', ["select {} as a from range({})", "select [{} for x in range({})] as a"]) + @pytest.mark.parametrize("query", ["select {} as a from range({})", "select [{} for x in range({})] as a"]) def test_dictionary_roundtrip(self, query, element, duckdb_cursor, count): query = query.format(element, count) original_rel = duckdb_cursor.sql(query) expected = original_rel.fetchall() arrow_res = original_rel.fetch_arrow_table() - roundtrip_rel = duckdb_cursor.sql('select * from arrow_res') + roundtrip_rel = duckdb_cursor.sql("select * from arrow_res") actual = roundtrip_rel.fetchall() assert expected == actual assert original_rel.columns == roundtrip_rel.columns @@ -142,14 +142,14 @@ def test_dictionary_batches(self, duckdb_cursor): indices = pa.array(indices_list * 10000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000 @@ -157,14 +157,14 @@ def test_dictionary_batches(self, duckdb_cursor): def test_dictionary_lifetime(self, duckdb_cursor): tables = [] - expected = '' + expected = "" for i in range(100): if i % 3 == 0: - input = 'ABCD' * 17000 + input = "ABCD" * 17000 elif i % 3 == 1: - input = 'FOOO' * 17000 + input = "FOOO" * 17000 else: - input = 'BARR' * 17000 + input = "BARR" * 17000 expected += input array = pa.array( input, @@ -186,14 +186,14 @@ def test_dictionary_batches_parallel(self, duckdb_cursor): indices = pa.array(indices_list * 10000) dictionary = pa.array([10, 100, 100]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result # Table with dictionary and normal array - arrow_table = pa.Table.from_arrays([dict_array, indices], ['a', 'b']) + arrow_table = pa.Table.from_arrays([dict_array, indices], ["a", "b"]) batch_arrow_table = pa.Table.from_batches(arrow_table.to_batches(10)) rel = duckdb_cursor.from_arrow(batch_arrow_table) result = [(None, None), (100, 1), (None, None), (100, 1), (100, 2), (100, 1), (10, 0)] * 10000 @@ -214,7 +214,7 @@ def test_dictionary_index_types(self, duckdb_cursor): for index_type in index_types: dict_array = pa.DictionaryArray.from_arrays(index_type, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [(None,), (100,), (None,), (100,), (100,), (100,), (10,)] * 10000 assert rel.execute().fetchall() == result @@ -222,17 +222,17 @@ def test_dictionary_index_types(self, duckdb_cursor): def test_dictionary_strings(self, duckdb_cursor): indices_list = [None, 0, 1, 2, 3, 4, None] indices = pa.array(indices_list * 1000) - dictionary = pa.array(['Matt Daaaaaaaaamon', 'Alec Baldwin', 'Sean Penn', 'Tim Robbins', 'Samuel L. Jackson']) + dictionary = pa.array(["Matt Daaaaaaaaamon", "Alec Baldwin", "Sean Penn", "Tim Robbins", "Samuel L. Jackson"]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) result = [ (None,), - ('Matt Daaaaaaaaamon',), - ('Alec Baldwin',), - ('Sean Penn',), - ('Tim Robbins',), - ('Samuel L. Jackson',), + ("Matt Daaaaaaaaamon",), + ("Alec Baldwin",), + ("Sean Penn",), + ("Tim Robbins",), + ("Samuel L. Jackson",), (None,), ] * 1000 assert rel.execute().fetchall() == result @@ -249,7 +249,7 @@ def test_dictionary_timestamps(self, duckdb_cursor): ] ) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) print(rel.execute().fetchall()) expected = [ diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index dffa9631..026b52f4 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -17,7 +17,7 @@ def create_pyarrow_pandas(rel): if not pandas_supports_arrow_backend(): pytest.skip(reason="Pandas version doesn't support 'pyarrow' backend") - return rel.df().convert_dtypes(dtype_backend='pyarrow') + return rel.df().convert_dtypes(dtype_backend="pyarrow") def create_pyarrow_table(rel): @@ -34,7 +34,7 @@ def test_decimal_filter_pushdown(duckdb_cursor): np = pytest.importorskip("numpy") np.random.seed(10) - df = pl.DataFrame({'x': pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) + df = pl.DataFrame({"x": pl.Series(np.random.uniform(-10, 10, 1000)).cast(pl.Decimal(precision=18, scale=4))}) query = """ SELECT @@ -179,34 +179,33 @@ def string_check_or_pushdown(connection, tbl_name, create_table): class TestArrowFilterPushdown(object): - @pytest.mark.parametrize( - 'data_type', + "data_type", [ - 'TINYINT', - 'SMALLINT', - 'INTEGER', - 'BIGINT', - 'UTINYINT', - 'USMALLINT', - 'UINTEGER', - 'UBIGINT', - 'FLOAT', - 'DOUBLE', - 'HUGEINT', - 'DECIMAL(4,1)', - 'DECIMAL(9,1)', - 'DECIMAL(18,4)', - 'DECIMAL(30,12)', + "TINYINT", + "SMALLINT", + "INTEGER", + "BIGINT", + "UTINYINT", + "USMALLINT", + "UINTEGER", + "UBIGINT", + "FLOAT", + "DOUBLE", + "HUGEINT", + "DECIMAL(4,1)", + "DECIMAL(9,1)", + "DECIMAL(18,4)", + "DECIMAL(30,12)", ], ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_numeric(self, data_type, duckdb_cursor, create_table): tbl_name = "tbl" numeric_operators(duckdb_cursor, data_type, tbl_name, create_table) numeric_check_or_pushdown(duckdb_cursor, tbl_name, create_table) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_varchar(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -259,7 +258,7 @@ def test_filter_pushdown_varchar(self, duckdb_cursor, create_table): # More complex tests for OR pushed down on string string_check_or_pushdown(duckdb_cursor, "test_varchar", create_table) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_bool(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -294,7 +293,7 @@ def test_filter_pushdown_bool(self, duckdb_cursor, create_table): # Try Or assert duckdb_cursor.execute("SELECT count(*) from arrow_table where a = True or b = True").fetchone()[0] == 3 - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_time(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -352,7 +351,7 @@ def test_filter_pushdown_time(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -422,7 +421,7 @@ def test_filter_pushdown_timestamp(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -494,18 +493,18 @@ def test_filter_pushdown_timestamp_TZ(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) @pytest.mark.parametrize( - ['data_type', 'value'], + ["data_type", "value"], [ - ['TINYINT', 127], - ['SMALLINT', 32767], - ['INTEGER', 2147483647], - ['BIGINT', 9223372036854775807], - ['UTINYINT', 255], - ['USMALLINT', 65535], - ['UINTEGER', 4294967295], - ['UBIGINT', 18446744073709551615], + ["TINYINT", 127], + ["SMALLINT", 32767], + ["INTEGER", 2147483647], + ["BIGINT", 9223372036854775807], + ["UTINYINT", 255], + ["USMALLINT", 65535], + ["UINTEGER", 4294967295], + ["UBIGINT", 18446744073709551615], ], ) def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_table): @@ -514,9 +513,9 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ CREATE TABLE tbl as select {value}::{data_type} as i """ ) - expected = duckdb_cursor.table('tbl').fetchall() + expected = duckdb_cursor.table("tbl").fetchall() filter = "i > 0" - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") arrow_table = create_table(rel) actual = duckdb_cursor.sql(f"select * from arrow_table where {filter}").fetchall() assert expected == actual @@ -529,7 +528,7 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ assert expected == actual @pytest.mark.skipif( - Version(pa.__version__) < Version('15.0.0'), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" + Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) def test_9371(self, duckdb_cursor, tmp_path): import datetime @@ -546,7 +545,7 @@ def test_9371(self, duckdb_cursor, tmp_path): # Example data dt = datetime.datetime(2023, 8, 29, 1, tzinfo=datetime.timezone.utc) - my_arrow_table = pa.Table.from_pydict({'ts': [dt, dt, dt], 'value': [1, 2, 3]}) + my_arrow_table = pa.Table.from_pydict({"ts": [dt, dt, dt], "value": [1, 2, 3]}) df = my_arrow_table.to_pandas() df = df.set_index("ts") # SET INDEX! (It all works correctly when the index is not set) df.to_parquet(str(file_path)) @@ -557,7 +556,7 @@ def test_9371(self, duckdb_cursor, tmp_path): expected = [(1, dt), (2, dt), (3, dt)] assert output == expected - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_date(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -617,15 +616,15 @@ def test_filter_pushdown_date(self, duckdb_cursor, create_table): == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_blob(self, duckdb_cursor, create_table): import pandas df = pandas.DataFrame( { - 'a': [bytes([1]), bytes([2]), bytes([3]), None], - 'b': [bytes([1]), bytes([2]), bytes([3]), None], - 'c': [bytes([1]), bytes([2]), bytes([3]), None], + "a": [bytes([1]), bytes([2]), bytes([3]), None], + "b": [bytes([1]), bytes([2]), bytes([3]), None], + "c": [bytes([1]), bytes([2]), bytes([3]), None], } ) rel = duckdb.from_df(df) @@ -660,7 +659,7 @@ def test_filter_pushdown_blob(self, duckdb_cursor, create_table): duckdb_cursor.execute("SELECT count(*) from arrow_table where a = '\x01' or b = '\x02'").fetchone()[0] == 2 ) - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table, create_pyarrow_dataset]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table, create_pyarrow_dataset]) def test_filter_pushdown_no_projection(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -685,7 +684,7 @@ def test_filter_pushdown_no_projection(self, duckdb_cursor, create_table): assert duckdb_cursor.execute("SELECT * FROM arrow_table VALUES where a = 1").fetchall() == [(1, 1, 1)] - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): import pandas @@ -697,12 +696,12 @@ def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): df2 = pandas.DataFrame(np.random.randn(date2.shape[0], 5), columns=list("ABCDE")) df2["date"] = date2 - data1 = tmp_path / 'data1.parquet' - data2 = tmp_path / 'data2.parquet' + data1 = tmp_path / "data1.parquet" + data2 = tmp_path / "data2.parquet" duckdb_cursor.execute(f"copy (select * from df1) to '{data1.as_posix()}'") duckdb_cursor.execute(f"copy (select * from df2) to '{data2.as_posix()}'") - glob_pattern = tmp_path / 'data*.parquet' + glob_pattern = tmp_path / "data*.parquet" table = duckdb_cursor.read_parquet(glob_pattern.as_posix()).fetch_arrow_table() output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df() @@ -710,7 +709,7 @@ def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): pandas.testing.assert_frame_equal(expected_df, output_df) # https://github.com/duckdb/duckdb/pull/4817/files#r1339973721 - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_filter_column_removal(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -738,7 +737,7 @@ def test_filter_column_removal(self, duckdb_cursor, create_table): assert not match @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -768,7 +767,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): ).fetchall() input = query_res[0][1] - if 'PANDAS_SCAN' in input: + if "PANDAS_SCAN" in input: pytest.skip(reason="This version of pandas does not produce an Arrow object") match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a<2.*", input, flags=re.DOTALL) assert match @@ -809,7 +808,7 @@ def test_struct_filter_pushdown(self, duckdb_cursor, create_table): assert not match @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") - @pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table]) + @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): duckdb_cursor.execute( """ @@ -838,15 +837,15 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): ).fetchall() input = query_res[0][1] - if 'PANDAS_SCAN' in input: + if "PANDAS_SCAN" in input: pytest.skip(reason="This version of pandas does not produce an Arrow object") match = re.search(r".*ARROW_SCAN.*Filters:.*s\.a\.b<2.*", input, flags=re.DOTALL) assert match # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.b < 2").fetchone()[0] == { - 'a': {'b': 1, 'c': False}, - 'd': {'e': 2, 'f': 'foo'}, + "a": {"b": 1, "c": False}, + "d": {"e": 2, "f": "foo"}, } query_res = duckdb_cursor.execute( @@ -866,8 +865,8 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT COUNT(*) FROM arrow_table WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == 1 assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.a.c=true AND s.d.e=5").fetchone()[0] == { - 'a': {'b': None, 'c': True}, - 'd': {'e': 5, 'f': 'qux'}, + "a": {"b": None, "c": True}, + "d": {"e": 5, "f": "qux"}, } query_res = duckdb_cursor.execute( @@ -887,8 +886,8 @@ def test_nested_struct_filter_pushdown(self, duckdb_cursor, create_table): # Check that the filter is applied correctly assert duckdb_cursor.execute("SELECT * FROM arrow_table WHERE s.d.f = 'bar'").fetchone()[0] == { - 'a': {'b': 3, 'c': True}, - 'd': {'e': 4, 'f': 'bar'}, + "a": {"b": 3, "c": True}, + "d": {"e": 4, "f": "bar"}, } def test_filter_pushdown_not_supported(self): @@ -899,21 +898,21 @@ def test_filter_pushdown_not_supported(self): arrow_tbl = con.execute("FROM T").fetch_arrow_table() # No projection just unsupported filter - assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, "3", 3, 3)] # No projection unsupported + supported filter - assert con.execute("from arrow_tbl where c < 4 and a > 2").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where c < 4 and a > 2").fetchall() == [(3, "3", 3, 3)] # No projection supported + unsupported + supported filter - assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, '3', 3, 3)] + assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, "3", 3, 3)] assert con.execute("from arrow_tbl where a > 2 and c < 4 and b == '0' ").fetchall() == [] # Projection with unsupported filter column + unsupported + supported filter - assert con.execute("select c, b from arrow_tbl where c < 4 and b == '3' and a > 2 ").fetchall() == [(3, '3')] - assert con.execute("select c, b from arrow_tbl where a > 2 and c < 4 and b == '3'").fetchall() == [(3, '3')] + assert con.execute("select c, b from arrow_tbl where c < 4 and b == '3' and a > 2 ").fetchall() == [(3, "3")] + assert con.execute("select c, b from arrow_tbl where a > 2 and c < 4 and b == '3'").fetchall() == [(3, "3")] # Projection without unsupported filter column + unsupported + supported filter - assert con.execute("select a, b from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, '3')] + assert con.execute("select a, b from arrow_tbl where a > 2 and c < 4 and b == '3' ").fetchall() == [(3, "3")] # Lets also experiment with multiple unpush-able filters con.execute( @@ -924,7 +923,7 @@ def test_filter_pushdown_not_supported(self): assert con.execute( "select a, b from arrow_tbl where a > 2 and c < 40 and b == '28' and g > 15 and e < 30" - ).fetchall() == [(28, '28')] + ).fetchall() == [(28, "28")] def test_join_filter_pushdown(self, duckdb_cursor): duckdb_conn = duckdb.connect() @@ -951,18 +950,18 @@ def test_in_filter_pushdown(self, duckdb_cursor): def test_pushdown_of_optional_filter(self, duckdb_cursor): cardinality_table = pa.Table.from_pydict( { - 'column_name': [ - 'id', - 'product_code', - 'price', - 'quantity', - 'category', - 'is_available', - 'rating', - 'discount', - 'color', + "column_name": [ + "id", + "product_code", + "price", + "quantity", + "category", + "is_available", + "rating", + "discount", + "color", ], - 'cardinality': [100, 100, 100, 45, 5, 3, 6, 39, 5], + "cardinality": [100, 100, 100, 45, 5, 3, 6, 39, 5], } ) @@ -976,15 +975,15 @@ def test_pushdown_of_optional_filter(self, duckdb_cursor): ) res = result.fetchall() assert res == [ - ('is_available', 3), - ('category', 5), - ('color', 5), - ('rating', 6), - ('discount', 39), - ('quantity', 45), - ('id', 100), - ('product_code', 100), - ('price', 100), + ("is_available", 3), + ("category", 5), + ("color", 5), + ("rating", 6), + ("discount", 39), + ("quantity", 45), + ("id", 100), + ("product_code", 100), + ("price", 100), ] # DuckDB intentionally violates IEEE-754 when it comes to NaNs, ensuring a total ordering where NaN is the greatest value @@ -1002,11 +1001,11 @@ def test_nan_filter_pushdown(self, duckdb_cursor): ) def assert_equal_results(con, arrow_table, query): - duckdb_res = con.sql(query.format(table='test')).fetchall() - arrow_res = con.sql(query.format(table='arrow_table')).fetchall() + duckdb_res = con.sql(query.format(table="test")).fetchall() + arrow_res = con.sql(query.format(table="arrow_table")).fetchall() assert len(duckdb_res) == len(arrow_res) - arrow_table = duckdb_cursor.table('test').fetch_arrow_table() + arrow_table = duckdb_cursor.table("test").fetch_arrow_table() assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a > 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a >= 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a < 'NaN'::FLOAT") diff --git a/tests/fast/arrow/test_integration.py b/tests/fast/arrow/test_integration.py index d9006758..6ab7350d 100644 --- a/tests/fast/arrow/test_integration.py +++ b/tests/fast/arrow/test_integration.py @@ -10,8 +10,8 @@ class TestArrowIntegration(object): def test_parquet_roundtrip(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" # TODO timestamp @@ -35,8 +35,8 @@ def test_parquet_roundtrip(self, duckdb_cursor): assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True) def test_unsigned_roundtrip(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'unsigned.parquet') - cols = 'a, b, c, d' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "unsigned.parquet") + cols = "a, b, c, d" unsigned_parquet_table = pq.read_table(parquet_filename) unsigned_parquet_table.validate(full=True) @@ -82,16 +82,16 @@ def test_decimals_roundtrip(self, duckdb_cursor): "SELECT typeof(a), typeof(b), typeof(c),typeof(d) from testarrow" ).fetchone() - assert arrow_result[0] == 'DECIMAL(4,2)' - assert arrow_result[1] == 'DECIMAL(9,2)' - assert arrow_result[2] == 'DECIMAL(18,2)' - assert arrow_result[3] == 'DECIMAL(30,2)' + assert arrow_result[0] == "DECIMAL(4,2)" + assert arrow_result[1] == "DECIMAL(9,2)" + assert arrow_result[2] == "DECIMAL(18,2)" + assert arrow_result[3] == "DECIMAL(30,2)" # Lets also test big number comming from arrow land data = pa.array(np.array([9999999999999999999999999999999999]), type=pa.decimal128(38, 0)) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("bigdecimal") - result = duckdb_cursor.execute('select * from bigdecimal') + result = duckdb_cursor.execute("select * from bigdecimal") assert result.fetchone()[0] == 9999999999999999999999999999999999 def test_intervals_roundtrip(self, duckdb_cursor): @@ -110,9 +110,9 @@ def test_intervals_roundtrip(self, duckdb_cursor): arr = [expected_value] data = pa.array(arr, pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervaltbl") - duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()["a"] assert duck_arrow_tbl[0].value == expected_value @@ -120,7 +120,7 @@ def test_intervals_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a INTERVAL)") duckdb_cursor.execute("INSERT INTO test VALUES (INTERVAL 1 YEAR + INTERVAL 1 DAY + INTERVAL 1 SECOND)") expected_value = pa.MonthDayNano([12, 1, 1000000000]) - duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()['a'] + duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()["a"] assert duck_tbl_arrow[0].value.months == expected_value.months assert duck_tbl_arrow[0].value.days == expected_value.days assert duck_tbl_arrow[0].value.nanoseconds == expected_value.nanoseconds @@ -140,9 +140,9 @@ def test_null_intervals_roundtrip(self, duckdb_cursor): ) arr = [None, expected_value] data = pa.array(arr, pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervalnulltbl") - duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()['a'] + duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()["a"] assert duckdb_tbl_arrow[0].value == None assert duckdb_tbl_arrow[1].value == expected_value @@ -154,9 +154,9 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): second_value = pa.MonthDayNano([90, 12, 0]) dictionary = pa.array([first_value, second_value, None]) dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) - arrow_table = pa.Table.from_arrays([dict_array], ['a']) + arrow_table = pa.Table.from_arrays([dict_array], ["a"]) duckdb_cursor.from_arrow(arrow_table).create("dictionarytbl") - duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()['a'] + duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()["a"] assert duckdb_tbl_arrow[0].value == first_value assert duckdb_tbl_arrow[1].value == second_value @@ -170,7 +170,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): # List query = duckdb_cursor.sql( "SELECT a from (select list_value(INTERVAL 3 MONTHS, INTERVAL 5 DAYS, INTERVAL 10 SECONDS, NULL) as a) as t" - ).fetch_arrow_table()['a'] + ).fetch_arrow_table()["a"] assert query[0][0].value == pa.MonthDayNano([3, 0, 0]) assert query[0][1].value == pa.MonthDayNano([0, 5, 0]) assert query[0][2].value == pa.MonthDayNano([0, 0, 10000000000]) @@ -180,25 +180,25 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): query = "SELECT a from (SELECT STRUCT_PACK(a := INTERVAL 1 MONTHS, b := INTERVAL 10 DAYS, c:= INTERVAL 20 SECONDS) as a) as t" true_answer = duckdb_cursor.sql(query).fetchall() from_arrow = duckdb_cursor.from_arrow(duckdb_cursor.sql(query).fetch_arrow_table()).fetchall() - assert true_answer[0][0]['a'] == from_arrow[0][0]['a'] - assert true_answer[0][0]['b'] == from_arrow[0][0]['b'] - assert true_answer[0][0]['c'] == from_arrow[0][0]['c'] + assert true_answer[0][0]["a"] == from_arrow[0][0]["a"] + assert true_answer[0][0]["b"] == from_arrow[0][0]["b"] + assert true_answer[0][0]["c"] == from_arrow[0][0]["c"] def test_min_max_interval_roundtrip(self, duckdb_cursor): interval_min_value = pa.MonthDayNano([0, 0, 0]) interval_max_value = pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) data = pa.array([interval_min_value, interval_max_value], pa.month_day_nano_interval()) - arrow_tbl = pa.Table.from_arrays([data], ['a']) + arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervalminmaxtbl") - duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()['a'] + duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()["a"] assert duck_arrow_tbl[0].value == pa.MonthDayNano([0, 0, 0]) assert duck_arrow_tbl[1].value == pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) def test_duplicate_column_names(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df_a = pd.DataFrame({'join_key': [1, 2, 3], 'col_a': ['a', 'b', 'c']}) - df_b = pd.DataFrame({'join_key': [1, 3, 4], 'col_a': ['x', 'y', 'z']}) + df_a = pd.DataFrame({"join_key": [1, 2, 3], "col_a": ["a", "b", "c"]}) + df_b = pd.DataFrame({"join_key": [1, 3, 4], "col_a": ["x", "y", "z"]}) res = duckdb_cursor.execute( """ @@ -210,7 +210,7 @@ def test_duplicate_column_names(self, duckdb_cursor): table1.join_key = table2.join_key """ ).fetch_arrow_table() - assert res.schema.names == ['join_key', 'col_a', 'join_key', 'col_a'] + assert res.schema.names == ["join_key", "col_a", "join_key", "col_a"] def test_strings_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a varchar)") diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index a548818f..32b7fa64 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -17,45 +17,45 @@ def test_duration_types(self, duckdb_cursor): if not can_run: return expected_arrow = pa.Table.from_arrays( - [pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())], ['a'] + [pa.array([pa.MonthDayNano([0, 0, 1000000000])], type=pa.month_day_nano_interval())], ["a"] ) data = ( - pa.array([1000000000], type=pa.duration('ns')), - pa.array([1000000], type=pa.duration('us')), - pa.array([1000], pa.duration('ms')), - pa.array([1], pa.duration('s')), + pa.array([1000000000], type=pa.duration("ns")), + pa.array([1000000], type=pa.duration("us")), + pa.array([1000], pa.duration("ms")), + pa.array([1], pa.duration("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == expected_arrow['a'] - assert rel['b'] == expected_arrow['a'] - assert rel['c'] == expected_arrow['a'] - assert rel['d'] == expected_arrow['a'] + assert rel["a"] == expected_arrow["a"] + assert rel["b"] == expected_arrow["a"] + assert rel["c"] == expected_arrow["a"] + assert rel["d"] == expected_arrow["a"] def test_duration_null(self, duckdb_cursor): if not can_run: return - expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())], ['a']) + expected_arrow = pa.Table.from_arrays([pa.array([None], type=pa.month_day_nano_interval())], ["a"]) data = ( - pa.array([None], type=pa.duration('ns')), - pa.array([None], type=pa.duration('us')), - pa.array([None], pa.duration('ms')), - pa.array([None], pa.duration('s')), + pa.array([None], type=pa.duration("ns")), + pa.array([None], type=pa.duration("us")), + pa.array([None], pa.duration("ms")), + pa.array([None], pa.duration("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == expected_arrow['a'] - assert rel['b'] == expected_arrow['a'] - assert rel['c'] == expected_arrow['a'] - assert rel['d'] == expected_arrow['a'] + assert rel["a"] == expected_arrow["a"] + assert rel["b"] == expected_arrow["a"] + assert rel["c"] == expected_arrow["a"] + assert rel["d"] == expected_arrow["a"] def test_duration_overflow(self, duckdb_cursor): if not can_run: return # Only seconds can overflow - data = pa.array([9223372036854775807], pa.duration('s')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([9223372036854775807], pa.duration("s")) + arrow_table = pa.Table.from_arrays([data], ["a"]) - with pytest.raises(duckdb.ConversionException, match='Could not convert Interval to Microsecond'): + with pytest.raises(duckdb.ConversionException, match="Could not convert Interval to Microsecond"): arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() diff --git a/tests/fast/arrow/test_large_offsets.py b/tests/fast/arrow/test_large_offsets.py index 1bcdd1b7..dccfa101 100644 --- a/tests/fast/arrow/test_large_offsets.py +++ b/tests/fast/arrow/test_large_offsets.py @@ -18,7 +18,7 @@ def test_large_lists(self, duckdb_cursor): tbl = pa.Table.from_pydict(dict(col=ary)) with pytest.raises( duckdb.InvalidInputException, - match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', + match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.", ): res = duckdb_cursor.sql("SELECT col FROM tbl").fetch_arrow_table() @@ -34,7 +34,7 @@ def test_large_maps(self, duckdb_cursor): with pytest.raises( duckdb.InvalidInputException, - match='Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.', + match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the offset of 2147481000 exceeds this.", ): arrow_map = duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() diff --git a/tests/fast/arrow/test_large_string.py b/tests/fast/arrow/test_large_string.py index 4836048d..308785af 100644 --- a/tests/fast/arrow/test_large_string.py +++ b/tests/fast/arrow/test_large_string.py @@ -22,4 +22,4 @@ def test_large_string_type(self, duckdb_cursor): rel = duckdb.from_arrow(arrow_table) res = rel.execute().fetchall() - assert res == [('foo',), ('baaaar',), ('b',)] + assert res == [("foo",), ("baaaar",), ("b",)] diff --git a/tests/fast/arrow/test_multiple_reads.py b/tests/fast/arrow/test_multiple_reads.py index 935a8a9c..36fb8f59 100644 --- a/tests/fast/arrow/test_multiple_reads.py +++ b/tests/fast/arrow/test_multiple_reads.py @@ -14,8 +14,8 @@ class TestArrowReads(object): def test_multiple_queries_same_relation(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename) userdata_parquet_table.validate(full=True) diff --git a/tests/fast/arrow/test_nested_arrow.py b/tests/fast/arrow/test_nested_arrow.py index 693a5155..a906324f 100644 --- a/tests/fast/arrow/test_nested_arrow.py +++ b/tests/fast/arrow/test_nested_arrow.py @@ -16,13 +16,13 @@ def compare_results(duckdb_cursor, query): def arrow_to_pandas(duckdb_cursor, query): - return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()['a'].values.tolist() + return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()["a"].values.tolist() def get_use_list_view_options(): result = [] result.append(False) - if hasattr(pa, 'ListViewArray'): + if hasattr(pa, "ListViewArray"): result.append(True) return result @@ -32,7 +32,7 @@ def test_lists_basic(self, duckdb_cursor): # Test Constant List query = ( duckdb_cursor.query("SELECT a from (select list_value(3,5,10) as a) as t") - .fetch_arrow_table()['a'] + .fetch_arrow_table()["a"] .to_numpy() ) assert query[0][0] == 3 @@ -40,32 +40,32 @@ def test_lists_basic(self, duckdb_cursor): assert query[0][2] == 10 # Empty List - query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").fetch_arrow_table()['a'].to_numpy() + query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").fetch_arrow_table()["a"].to_numpy() assert len(query[0]) == 0 # Test Constant List With Null query = ( duckdb_cursor.query("SELECT a from (select list_value(3,NULL) as a) as t") - .fetch_arrow_table()['a'] + .fetch_arrow_table()["a"] .to_numpy() ) assert query[0][0] == 3 assert np.isnan(query[0][1]) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_list_types(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") # Large Lists data = pa.array([[1], None, [2]], type=pa.large_list(pa.int64())) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [([1],), (None,), ([2],)] # Fixed Size Lists data = pa.array([[1], None, [2]], type=pa.list_(pa.int64(), 1)) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [((1,),), (None,), ((2,),)] @@ -76,27 +76,27 @@ def test_list_types(self, duckdb_cursor, use_list_view): pa.array([[1], None, [2]], type=pa.large_list(pa.int64())), pa.array([[1, 2, 3], None, [2, 1]], type=pa.list_(pa.int64())), ] - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ["a", "b", "c"]) rel = duckdb_cursor.from_arrow(arrow_table) - res = rel.project('a').execute().fetchall() + res = rel.project("a").execute().fetchall() assert res == [((1,),), (None,), ((2,),)] - res = rel.project('b').execute().fetchall() + res = rel.project("b").execute().fetchall() assert res == [([1],), (None,), ([2],)] - res = rel.project('c').execute().fetchall() + res = rel.project("c").execute().fetchall() assert res == [([1, 2, 3],), (None,), ([2, 1],)] # Struct Holding different List Types - struct = [pa.StructArray.from_arrays(data, ['fixed', 'large', 'normal'])] - arrow_table = pa.Table.from_arrays(struct, ['a']) + struct = [pa.StructArray.from_arrays(data, ["fixed", "large", "normal"])] + arrow_table = pa.Table.from_arrays(struct, ["a"]) rel = duckdb_cursor.from_arrow(arrow_table) res = rel.execute().fetchall() assert res == [ - ({'fixed': (1,), 'large': [1], 'normal': [1, 2, 3]},), - ({'fixed': None, 'large': None, 'normal': None},), - ({'fixed': (2,), 'large': [2], 'normal': [2, 1]},), + ({"fixed": (1,), "large": [1], "normal": [1, 2, 3]},), + ({"fixed": None, "large": None, "normal": None},), + ({"fixed": (2,), "large": [2], "normal": [2, 1]},), ] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) @pytest.mark.skip(reason="FIXME: this fails on CI") def test_lists_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -132,8 +132,8 @@ def test_lists_roundtrip(self, duckdb_cursor, use_list_view): compare_results( duckdb_cursor, - '''SELECT grp,lst,cs FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as cs - from (SELECT a%4 as grp, list(a order by a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T order by all;''', + """SELECT grp,lst,cs FROM (select grp, lst, case when grp>1 then lst else list_value(null) end as cs + from (SELECT a%4 as grp, list(a order by a) as lst FROM range(7) tbl(a) group by grp) as lst_tbl) as T order by all;""", ) # Tests for converting multiple lists to/from Arrow with NULL values and/or strings compare_results( @@ -141,7 +141,7 @@ def test_lists_roundtrip(self, duckdb_cursor, use_list_view): "SELECT list(st order by st) from (select i, case when i%10 then NULL else i::VARCHAR end as st from range(1000) tbl(i)) as t group by i%5 order by all", ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_struct_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -156,7 +156,7 @@ def test_struct_roundtrip(self, duckdb_cursor, use_list_view): "SELECT a from (SELECT STRUCT_PACK(a := LIST_VALUE(1,2,3), b := i) as a FROM range(10000) tbl(i)) as t", ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_roundtrip(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") @@ -185,13 +185,13 @@ def test_map_roundtrip(self, duckdb_cursor, use_list_view): "SELECT m from (select MAP(lsta,lstb) as m from (SELECT list(i) as lsta, list(i) as lstb from range(10000) tbl(i) group by i%5 order by all) as lst_tbl) as T", ) - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_arrow_to_duckdb(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") map_type = pa.map_(pa.int32(), pa.int32()) values = [[(3, 12), (3, 21)], [(5, 42)]] - arrow_table = pa.table({'detail': pa.array(values, map_type)}) + arrow_table = pa.table({"detail": pa.array(values, map_type)}) with pytest.raises( duckdb.InvalidInputException, match="Arrow map contains duplicate key, which isn't supported by DuckDB map type", @@ -201,11 +201,11 @@ def test_map_arrow_to_duckdb(self, duckdb_cursor, use_list_view): def test_null_map_arrow_to_duckdb(self, duckdb_cursor): map_type = pa.map_(pa.int32(), pa.int32()) values = [None, [(5, 42)]] - arrow_table = pa.table({'detail': pa.array(values, map_type)}) + arrow_table = pa.table({"detail": pa.array(values, map_type)}) res = duckdb_cursor.sql("select * from arrow_table").fetchall() assert res == [(None,), ({5: 42},)] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_map_arrow_to_pandas(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") assert arrow_to_pandas( @@ -215,16 +215,16 @@ def test_map_arrow_to_pandas(self, duckdb_cursor, use_list_view): assert arrow_to_pandas( duckdb_cursor, "SELECT a from (select MAP(LIST_VALUE('Jon Lajoie', 'Backstreet Boys', 'Tenacious D'),LIST_VALUE(10,9,10)) as a) as t", - ) == [[('Jon Lajoie', 10), ('Backstreet Boys', 9), ('Tenacious D', 10)]] + ) == [[("Jon Lajoie", 10), ("Backstreet Boys", 9), ("Tenacious D", 10)]] assert arrow_to_pandas( duckdb_cursor, "SELECT a from (select MAP(list_value(1), list_value(2)) from range(5) tbl(i)) tbl(a)" ) == [[(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)], [(1, 2)]] assert arrow_to_pandas( duckdb_cursor, "SELECT MAP(LIST_VALUE({'i':1,'j':2},{'i':3,'j':4}),LIST_VALUE({'i':1,'j':2},{'i':3,'j':4})) as a", - ) == [[({'i': 1, 'j': 2}, {'i': 1, 'j': 2}), ({'i': 3, 'j': 4}, {'i': 3, 'j': 4})]] + ) == [[({"i": 1, "j": 2}, {"i": 1, "j": 2}), ({"i": 3, "j": 4}, {"i": 3, "j": 4})]] - @pytest.mark.parametrize('use_list_view', get_use_list_view_options()) + @pytest.mark.parametrize("use_list_view", get_use_list_view_options()) def test_frankstein_nested(self, duckdb_cursor, use_list_view): duckdb_cursor.execute(f"pragma arrow_output_list_view={use_list_view};") diff --git a/tests/fast/arrow/test_parallel.py b/tests/fast/arrow/test_parallel.py index 2609d1ae..c768a1dd 100644 --- a/tests/fast/arrow/test_parallel.py +++ b/tests/fast/arrow/test_parallel.py @@ -19,7 +19,7 @@ def test_parallel_run(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") data = pyarrow.array(np.random.randint(800, size=1000000), type=pyarrow.int32()) - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(10000)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(10000)) rel = duckdb_conn.from_arrow(tbl) # Also test multiple reads assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000000 @@ -32,17 +32,17 @@ def test_parallel_types_and_different_batches(self, duckdb_cursor): duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" userdata_parquet_table = pyarrow.parquet.read_table(parquet_filename) for i in [7, 51, 99, 100, 101, 500, 1000, 2000]: data = pyarrow.array(np.arange(3, 7), type=pyarrow.int32()) - tbl = pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel_id = duckdb_conn.from_arrow(tbl) userdata_parquet_table2 = pyarrow.Table.from_batches(userdata_parquet_table.to_batches(i)) rel = duckdb_conn.from_arrow(userdata_parquet_table2) - result = rel.filter("first_name=\'Jose\' and salary > 134708.82").aggregate('count(*)') + result = rel.filter("first_name='Jose' and salary > 134708.82").aggregate("count(*)") assert result.execute().fetchone()[0] == 4 def test_parallel_fewer_batches_than_threads(self, duckdb_cursor): @@ -53,7 +53,7 @@ def test_parallel_fewer_batches_than_threads(self, duckdb_cursor): duckdb_conn.execute("PRAGMA verify_parallelism") data = pyarrow.array(np.random.randint(800, size=1000), type=pyarrow.int32()) - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(2)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(2)) rel = duckdb_conn.from_arrow(tbl) # Also test multiple reads assert rel.aggregate("(count(a))::INT").execute().fetchone()[0] == 1000 diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 87e2f726..a4e94d18 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -31,21 +31,21 @@ def test_polars(self, duckdb_cursor): } ) # scan plus return a polars dataframe - polars_result = duckdb_cursor.sql('SELECT * FROM df').pl() + polars_result = duckdb_cursor.sql("SELECT * FROM df").pl() pl_testing.assert_frame_equal(df, polars_result) # now do the same for a lazy dataframe lazy_df = df.lazy() - lazy_result = duckdb_cursor.sql('SELECT * FROM lazy_df').pl() + lazy_result = duckdb_cursor.sql("SELECT * FROM lazy_df").pl() pl_testing.assert_frame_equal(df, lazy_result) con = duckdb.connect() - con_result = con.execute('SELECT * FROM df').pl() + con_result = con.execute("SELECT * FROM df").pl() pl_testing.assert_frame_equal(df, con_result) def test_execute_polars(self, duckdb_cursor): res1 = duckdb_cursor.execute("SELECT 1 AS a, 2 AS a").pl() - assert res1.columns == ['a', 'a_1'] + assert res1.columns == ["a", "a_1"] def test_register_polars(self, duckdb_cursor): con = duckdb.connect() @@ -58,21 +58,21 @@ def test_register_polars(self, duckdb_cursor): } ) # scan plus return a polars dataframe - con.register('polars_df', df) - polars_result = con.execute('select * from polars_df').pl() + con.register("polars_df", df) + polars_result = con.execute("select * from polars_df").pl() pl_testing.assert_frame_equal(df, polars_result) - con.unregister('polars_df') - with pytest.raises(duckdb.CatalogException, match='Table with name polars_df does not exist'): + con.unregister("polars_df") + with pytest.raises(duckdb.CatalogException, match="Table with name polars_df does not exist"): con.execute("SELECT * FROM polars_df;").pl() - con.register('polars_df', df.lazy()) - polars_result = con.execute('select * from polars_df').pl() + con.register("polars_df", df.lazy()) + polars_result = con.execute("select * from polars_df").pl() pl_testing.assert_frame_equal(df, polars_result) def test_empty_polars_dataframe(self, duckdb_cursor): polars_empty_df = pl.DataFrame() with pytest.raises( - duckdb.InvalidInputException, match='Provided table/dataframe must have at least one column' + duckdb.InvalidInputException, match="Provided table/dataframe must have at least one column" ): duckdb_cursor.sql("from polars_empty_df") @@ -82,7 +82,7 @@ def test_polars_from_json(self, duckdb_cursor): duckdb_cursor.sql("set arrow_lossless_conversion=false") string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""") res = duckdb_cursor.read_json(string).pl() - assert str(res['entry'][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" + assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" @pytest.mark.skipif( not hasattr(pl.exceptions, "PanicException"), reason="Polars has no PanicException in this version" @@ -93,13 +93,13 @@ def test_polars_from_json_error(self, duckdb_cursor): duckdb_cursor.sql("set arrow_lossless_conversion=true") string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""") res = duckdb_cursor.read_json(string).pl() - assert duckdb_cursor.execute("FROM res").fetchall() == [([{'content': {'ManagedSystem': {'test': None}}}],)] + assert duckdb_cursor.execute("FROM res").fetchall() == [([{"content": {"ManagedSystem": {"test": None}}}],)] def test_polars_from_json_error(self, duckdb_cursor): conn = duckdb.connect() my_table = conn.query("select 'x' my_str").pl() my_res = duckdb.query("select my_str from my_table where my_str != 'y'") - assert my_res.fetchall() == [('x',)] + assert my_res.fetchall() == [("x",)] def test_polars_lazy_from_conn(self, duckdb_cursor): duckdb_conn = duckdb.connect() @@ -107,7 +107,7 @@ def test_polars_lazy_from_conn(self, duckdb_cursor): result = duckdb_conn.execute("SELECT 42 as bla") lazy_df = result.pl(lazy=True) - assert lazy_df.collect().to_dicts() == [{'bla': 42}] + assert lazy_df.collect().to_dicts() == [{"bla": 42}] def test_polars_lazy(self, duckdb_cursor): con = duckdb.connect() @@ -118,18 +118,18 @@ def test_polars_lazy(self, duckdb_cursor): assert isinstance(lazy_df, pl.LazyFrame) assert lazy_df.collect().to_dicts() == [ - {'a': 'Pedro', 'b': 32}, - {'a': 'Mark', 'b': 31}, - {'a': 'Thijs', 'b': 29}, + {"a": "Pedro", "b": 32}, + {"a": "Mark", "b": 31}, + {"a": "Thijs", "b": 29}, ] - assert lazy_df.select('a').collect().to_dicts() == [{'a': 'Pedro'}, {'a': 'Mark'}, {'a': 'Thijs'}] - assert lazy_df.limit(1).collect().to_dicts() == [{'a': 'Pedro', 'b': 32}] + assert lazy_df.select("a").collect().to_dicts() == [{"a": "Pedro"}, {"a": "Mark"}, {"a": "Thijs"}] + assert lazy_df.limit(1).collect().to_dicts() == [{"a": "Pedro", "b": 32}] assert lazy_df.filter(pl.col("b") < 32).collect().to_dicts() == [ - {'a': 'Mark', 'b': 31}, - {'a': 'Thijs', 'b': 29}, + {"a": "Mark", "b": 31}, + {"a": "Thijs", "b": 29}, ] - assert lazy_df.filter(pl.col("b") < 32).select('a').collect().to_dicts() == [{'a': 'Mark'}, {'a': 'Thijs'}] + assert lazy_df.filter(pl.col("b") < 32).select("a").collect().to_dicts() == [{"a": "Mark"}, {"a": "Thijs"}] def test_polars_column_with_tricky_name(self, duckdb_cursor): # Test that a polars DataFrame with a column name that is non standard still works @@ -162,23 +162,23 @@ def test_polars_column_with_tricky_name(self, duckdb_cursor): assert result.to_dicts() == [{'"xy"': 1}] @pytest.mark.parametrize( - 'data_type', + "data_type", [ - 'TINYINT', - 'SMALLINT', - 'INTEGER', - 'BIGINT', - 'UTINYINT', - 'USMALLINT', - 'UINTEGER', - 'UBIGINT', - 'FLOAT', - 'DOUBLE', - 'HUGEINT', - 'DECIMAL(4,1)', - 'DECIMAL(9,1)', - 'DECIMAL(18,4)', - 'DECIMAL(30,12)', + "TINYINT", + "SMALLINT", + "INTEGER", + "BIGINT", + "UTINYINT", + "USMALLINT", + "UINTEGER", + "UBIGINT", + "FLOAT", + "DOUBLE", + "HUGEINT", + "DECIMAL(4,1)", + "DECIMAL(9,1)", + "DECIMAL(18,4)", + "DECIMAL(30,12)", ], ) def test_polars_lazy_pushdown_numeric(self, data_type, duckdb_cursor): @@ -524,9 +524,9 @@ def test_polars_lazy_pushdown_blob(self, duckdb_cursor): df = pandas.DataFrame( { - 'a': [bytes([1]), bytes([2]), bytes([3]), None], - 'b': [bytes([1]), bytes([2]), bytes([3]), None], - 'c': [bytes([1]), bytes([2]), bytes([3]), None], + "a": [bytes([1]), bytes([2]), bytes([3]), None], + "b": [bytes([1]), bytes([2]), bytes([3]), None], + "c": [bytes([1]), bytes([2]), bytes([3]), None], } ) duck_tbl = duckdb.from_df(df) diff --git a/tests/fast/arrow/test_progress.py b/tests/fast/arrow/test_progress.py index c20ebe51..6f056937 100644 --- a/tests/fast/arrow/test_progress.py +++ b/tests/fast/arrow/test_progress.py @@ -8,7 +8,7 @@ class TestProgressBarArrow(object): def test_progress_arrow(self): - if os.name == 'nt': + if os.name == "nt": return np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") @@ -18,9 +18,9 @@ def test_progress_arrow(self): duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == 49999995000000 # Multiple Threads duckdb_conn.execute("PRAGMA threads=4") @@ -28,9 +28,9 @@ def test_progress_arrow(self): assert result.execute().fetchone()[0] == 49999995000000 # More than one batch - tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ['a']).to_batches(100)) + tbl = pyarrow.Table.from_batches(pyarrow.Table.from_arrays([data], ["a"]).to_batches(100)) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == 49999995000000 # Single Thread @@ -40,7 +40,7 @@ def test_progress_arrow(self): assert py_res == 49999995000000 def test_progress_arrow_empty(self): - if os.name == 'nt': + if os.name == "nt": return np = pytest.importorskip("numpy") pyarrow = pytest.importorskip("pyarrow") @@ -50,7 +50,7 @@ def test_progress_arrow_empty(self): duckdb_conn.execute("PRAGMA progress_bar_time=1") duckdb_conn.execute("PRAGMA disable_print_progress_bar") - tbl = pyarrow.Table.from_arrays([data], ['a']) + tbl = pyarrow.Table.from_arrays([data], ["a"]) rel = duckdb_conn.from_arrow(tbl) - result = rel.aggregate('sum(a)') + result = rel.aggregate("sum(a)") assert result.execute().fetchone()[0] == None diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index 726b0f6a..e7c4404e 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -18,60 +18,60 @@ def test_time_types(self, duckdb_cursor): return data = ( - pa.array([1], type=pa.time32('s')), - pa.array([1000], type=pa.time32('ms')), - pa.array([1000000], pa.time64('us')), - pa.array([1000000000], pa.time64('ns')), + pa.array([1], type=pa.time32("s")), + pa.array([1000], type=pa.time32("ms")), + pa.array([1000000], pa.time64("us")), + pa.array([1000000000], pa.time64("ns")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['c'] - assert rel['b'] == arrow_table['c'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['c'] + assert rel["a"] == arrow_table["c"] + assert rel["b"] == arrow_table["c"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["c"] def test_time_null(self, duckdb_cursor): if not can_run: return data = ( - pa.array([None], type=pa.time32('s')), - pa.array([None], type=pa.time32('ms')), - pa.array([None], pa.time64('us')), - pa.array([None], pa.time64('ns')), + pa.array([None], type=pa.time32("s")), + pa.array([None], type=pa.time32("ms")), + pa.array([None], pa.time64("us")), + pa.array([None], pa.time64("ns")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['c'] - assert rel['b'] == arrow_table['c'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['c'] + assert rel["a"] == arrow_table["c"] + assert rel["b"] == arrow_table["c"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["c"] def test_max_times(self, duckdb_cursor): if not can_run: return - data = pa.array([2147483647000000], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647000000], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) # Max Sec - data = pa.array([2147483647], type=pa.time32('s')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647], type=pa.time32("s")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] + assert rel["a"] == result["a"] # Max MSec - data = pa.array([2147483647000], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) - data = pa.array([2147483647], type=pa.time32('ms')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([2147483647000], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) + data = pa.array([2147483647], type=pa.time32("ms")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == result['a'] + assert rel["a"] == result["a"] # Max NSec - data = pa.array([9223372036854774], type=pa.time64('us')) - result = pa.Table.from_arrays([data], ['a']) - data = pa.array([9223372036854774000], type=pa.time64('ns')) - arrow_table = pa.Table.from_arrays([data], ['a']) + data = pa.array([9223372036854774], type=pa.time64("us")) + result = pa.Table.from_arrays([data], ["a"]) + data = pa.array([9223372036854774000], type=pa.time64("ns")) + arrow_table = pa.Table.from_arrays([data], ["a"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - print(rel['a']) - print(result['a']) - assert rel['a'] == result['a'] + print(rel["a"]) + print(result["a"]) + assert rel["a"] == result["a"] diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 4fdadf49..08816be1 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -3,7 +3,7 @@ import datetime import pytz -pa = pytest.importorskip('pyarrow') +pa = pytest.importorskip("pyarrow") def generate_table(current_time, precision, timezone): @@ -13,30 +13,30 @@ def generate_table(current_time, precision, timezone): return pa.Table.from_arrays(inputs, schema=schema) -timezones = ['UTC', 'BET', 'CET', 'Asia/Kathmandu'] +timezones = ["UTC", "BET", "CET", "Asia/Kathmandu"] class TestArrowTimestampsTimezone(object): def test_timestamp_timezone(self, duckdb_cursor): - precisions = ['us', 's', 'ns', 'ms'] + precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59, tzinfo=pytz.UTC) con = duckdb.connect() con.execute("SET TimeZone = 'UTC'") for precision in precisions: - arrow_table = generate_table(current_time, precision, 'UTC') + arrow_table = generate_table(current_time, precision, "UTC") res_utc = con.from_arrow(arrow_table).execute().fetchall() assert res_utc[0][0] == current_time def test_timestamp_timezone_overflow(self, duckdb_cursor): - precisions = ['s', 'ms'] + precisions = ["s", "ms"] current_time = 9223372036854775807 for precision in precisions: - with pytest.raises(duckdb.ConversionException, match='Could not convert'): - arrow_table = generate_table(current_time, precision, 'UTC') + with pytest.raises(duckdb.ConversionException, match="Could not convert"): + arrow_table = generate_table(current_time, precision, "UTC") res_utc = duckdb.from_arrow(arrow_table).execute().fetchall() def test_timestamp_tz_to_arrow(self, duckdb_cursor): - precisions = ['us', 's', 'ns', 'ms'] + precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59) con = duckdb.connect() for precision in precisions: @@ -44,16 +44,16 @@ def test_timestamp_tz_to_arrow(self, duckdb_cursor): con.execute("SET TimeZone = '" + timezone + "'") arrow_table = generate_table(current_time, precision, timezone) res = con.from_arrow(arrow_table).fetch_arrow_table() - assert res[0].type == pa.timestamp('us', tz=timezone) - assert res == generate_table(current_time, 'us', timezone) + assert res[0].type == pa.timestamp("us", tz=timezone) + assert res == generate_table(current_time, "us", timezone) def test_timestamp_tz_with_null(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') + rel = con.table("t") arrow_tbl = rel.fetch_arrow_table() - con.register('t2', arrow_tbl) + con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() @@ -61,8 +61,8 @@ def test_timestamp_stream(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") - rel = con.table('t') + rel = con.table("t") arrow_tbl = rel.record_batch().read_all() - con.register('t2', arrow_tbl) + con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() diff --git a/tests/fast/arrow/test_timestamps.py b/tests/fast/arrow/test_timestamps.py index c2529c83..684a333c 100644 --- a/tests/fast/arrow/test_timestamps.py +++ b/tests/fast/arrow/test_timestamps.py @@ -17,61 +17,61 @@ def test_timestamp_types(self, duckdb_cursor): if not can_run: return data = ( - pa.array([datetime.datetime.now()], type=pa.timestamp('ns')), - pa.array([datetime.datetime.now()], type=pa.timestamp('us')), - pa.array([datetime.datetime.now()], pa.timestamp('ms')), - pa.array([datetime.datetime.now()], pa.timestamp('s')), + pa.array([datetime.datetime.now()], type=pa.timestamp("ns")), + pa.array([datetime.datetime.now()], type=pa.timestamp("us")), + pa.array([datetime.datetime.now()], pa.timestamp("ms")), + pa.array([datetime.datetime.now()], pa.timestamp("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['a'] - assert rel['b'] == arrow_table['b'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['d'] + assert rel["a"] == arrow_table["a"] + assert rel["b"] == arrow_table["b"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_timestamp_nulls(self, duckdb_cursor): if not can_run: return data = ( - pa.array([None], type=pa.timestamp('ns')), - pa.array([None], type=pa.timestamp('us')), - pa.array([None], pa.timestamp('ms')), - pa.array([None], pa.timestamp('s')), + pa.array([None], type=pa.timestamp("ns")), + pa.array([None], type=pa.timestamp("us")), + pa.array([None], pa.timestamp("ms")), + pa.array([None], pa.timestamp("s")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ['a', 'b', 'c', 'd']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert rel['a'] == arrow_table['a'] - assert rel['b'] == arrow_table['b'] - assert rel['c'] == arrow_table['c'] - assert rel['d'] == arrow_table['d'] + assert rel["a"] == arrow_table["a"] + assert rel["b"] == arrow_table["b"] + assert rel["c"] == arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_timestamp_overflow(self, duckdb_cursor): if not can_run: return data = ( - pa.array([9223372036854775807], pa.timestamp('s')), - pa.array([9223372036854775807], pa.timestamp('ms')), - pa.array([9223372036854775807], pa.timestamp('us')), + pa.array([9223372036854775807], pa.timestamp("s")), + pa.array([9223372036854775807], pa.timestamp("ms")), + pa.array([9223372036854775807], pa.timestamp("us")), ) - arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ['a', 'b', 'c']) + arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ["a", "b", "c"]) arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() - assert arrow_from_duck['a'] == arrow_table['a'] - assert arrow_from_duck['b'] == arrow_table['b'] - assert arrow_from_duck['c'] == arrow_table['c'] + assert arrow_from_duck["a"] == arrow_table["a"] + assert arrow_from_duck["b"] == arrow_table["b"] + assert arrow_from_duck["c"] == arrow_table["c"] expected = (datetime.datetime(9999, 12, 31, 23, 59, 59, 999999),) duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('a::TIMESTAMP_US') + res = duck_rel.project("a::TIMESTAMP_US") result = res.fetchone() assert result == expected duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('b::TIMESTAMP_US') + res = duck_rel.project("b::TIMESTAMP_US") result = res.fetchone() assert result == expected duck_rel = duckdb.from_arrow(arrow_table) - res = duck_rel.project('c::TIMESTAMP_NS') + res = duck_rel.project("c::TIMESTAMP_NS") result = res.fetchone() assert result == expected diff --git a/tests/fast/arrow/test_tpch.py b/tests/fast/arrow/test_tpch.py index ff4a0445..d5d13b20 100644 --- a/tests/fast/arrow/test_tpch.py +++ b/tests/fast/arrow/test_tpch.py @@ -24,7 +24,7 @@ def check_result(result, answers): db_result = result.fetchone() cq_results = q_res.split("|") # The end of the rows, continue - if cq_results == [''] and str(db_result) == 'None' or str(db_result[0]) == 'None': + if cq_results == [""] and str(db_result) == "None" or str(db_result[0]) == "None": continue ans_result = [munge(cell) for cell in cq_results] db_result = [munge(cell) for cell in db_result] @@ -39,7 +39,7 @@ def test_tpch_arrow(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() @@ -69,7 +69,7 @@ def test_tpch_arrow_01(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() @@ -97,7 +97,7 @@ def test_tpch_arrow_batch(self, duckdb_cursor): if not can_run: return - tpch_tables = ['part', 'partsupp', 'supplier', 'customer', 'lineitem', 'orders', 'nation', 'region'] + tpch_tables = ["part", "partsupp", "supplier", "customer", "lineitem", "orders", "nation", "region"] arrow_tables = [] duckdb_conn = duckdb.connect() diff --git a/tests/fast/arrow/test_unregister.py b/tests/fast/arrow/test_unregister.py index c63ef0d6..8ff37b5a 100644 --- a/tests/fast/arrow/test_unregister.py +++ b/tests/fast/arrow/test_unregister.py @@ -17,8 +17,8 @@ class TestArrowUnregister(object): def test_arrow_unregister1(self, duckdb_cursor): if not can_run: return - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) connection = duckdb.connect(":memory:") @@ -26,9 +26,9 @@ def test_arrow_unregister1(self, duckdb_cursor): arrow_table_2 = connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.unregister("arrow_table") - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() - with pytest.raises(duckdb.CatalogException, match='View with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="View with name arrow_table does not exist"): connection.execute("DROP VIEW arrow_table;") connection.execute("DROP VIEW IF EXISTS arrow_table;") @@ -40,8 +40,8 @@ def test_arrow_unregister2(self, duckdb_cursor): os.remove(db) connection = duckdb.connect(db) - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') - cols = 'id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments' + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") + cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" arrow_table_obj = pyarrow.parquet.read_table(parquet_filename) connection.register("arrow_table", arrow_table_obj) connection.unregister("arrow_table") # Attempting to unregister. @@ -49,7 +49,7 @@ def test_arrow_unregister2(self, duckdb_cursor): # Reconnecting while Arrow Table still in mem. connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.close() del arrow_table_obj @@ -57,6 +57,6 @@ def test_arrow_unregister2(self, duckdb_cursor): # Reconnecting after Arrow Table is freed. connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name arrow_table does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() connection.close() diff --git a/tests/fast/arrow/test_view.py b/tests/fast/arrow/test_view.py index 54acb336..7f1410aa 100644 --- a/tests/fast/arrow/test_view.py +++ b/tests/fast/arrow/test_view.py @@ -8,9 +8,9 @@ class TestArrowView(object): def test_arrow_view(self, duckdb_cursor): - parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'userdata1.parquet') + parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pa.parquet.read_table(parquet_filename) userdata_parquet_table.validate(full=True) - duckdb_cursor.from_arrow(userdata_parquet_table).create_view('arrow_view') - assert duckdb_cursor.execute("PRAGMA show_tables").fetchone() == ('arrow_view',) + duckdb_cursor.from_arrow(userdata_parquet_table).create_view("arrow_view") + assert duckdb_cursor.execute("PRAGMA show_tables").fetchone() == ("arrow_view",) assert duckdb_cursor.execute("select avg(salary)::INT from arrow_view").fetchone()[0] == 149005 diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index 4267085c..b872d4d9 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -28,11 +28,11 @@ def test_scan_numpy(self, duckdb_cursor): z = np.array(["zzz", "xxx"]) res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [('zzz',), ('xxx',)] + assert res == [("zzz",), ("xxx",)] z = [np.array(["zzz", "xxx"]), np.array([1, 2])] res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [('zzz', 1), ('xxx', 2)] + assert res == [("zzz", 1), ("xxx", 2)] # test ndarray with dtype = object (python dict) z = [] @@ -41,9 +41,9 @@ def test_scan_numpy(self, duckdb_cursor): z = np.array(z) res = duckdb_cursor.sql("select * from z").fetchall() assert res == [ - ({'3': 0},), - ({'2': 1},), - ({'1': 2},), + ({"3": 0},), + ({"2": 1},), + ({"1": 2},), ] # test timedelta @@ -74,12 +74,12 @@ def test_scan_numpy(self, duckdb_cursor): # dict of mixed types z = {"z": np.array([1, 2, 3]), "x": np.array(["z", "x", "c"])} res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [(1, 'z'), (2, 'x'), (3, 'c')] + assert res == [(1, "z"), (2, "x"), (3, "c")] # list of mixed types z = [np.array([1, 2, 3]), np.array(["z", "x", "c"])] res = duckdb_cursor.sql("select * from z").fetchall() - assert res == [(1, 'z'), (2, 'x'), (3, 'c')] + assert res == [(1, "z"), (2, "x"), (3, "c")] # currently unsupported formats, will throw duckdb.InvalidInputException diff --git a/tests/fast/pandas/test_2304.py b/tests/fast/pandas/test_2304.py index 6fc355e5..11344df8 100644 --- a/tests/fast/pandas/test_2304.py +++ b/tests/fast/pandas/test_2304.py @@ -5,37 +5,37 @@ class TestPandasMergeSameName(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_2304(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id_1': [1, 1, 1, 2, 2], - 'agedate': np.array(['2010-01-01', '2010-02-01', '2010-03-01', '2020-02-01', '2020-03-01']).astype( - 'datetime64[D]' + "id_1": [1, 1, 1, 2, 2], + "agedate": np.array(["2010-01-01", "2010-02-01", "2010-03-01", "2020-02-01", "2020-03-01"]).astype( + "datetime64[D]" ), - 'age': [1, 2, 3, 1, 2], - 'v': [1.1, 1.2, 1.3, 2.1, 2.2], + "age": [1, 2, 3, 1, 2], + "v": [1.1, 1.2, 1.3, 2.1, 2.2], } ) df2 = pandas.DataFrame( { - 'id_1': [1, 1, 2], - 'agedate': np.array(['2010-01-01', '2010-02-01', '2020-03-01']).astype('datetime64[D]'), - 'v2': [11.1, 11.2, 21.2], + "id_1": [1, 1, 2], + "agedate": np.array(["2010-01-01", "2010-02-01", "2020-03-01"]).astype("datetime64[D]"), + "v2": [11.1, 11.2, 21.2], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) query = """SELECT * from df1 LEFT OUTER JOIN df2 ON (df1.id_1=df2.id_1 and df1.agedate=df2.agedate) order by df1.id_1, df1.agedate, df1.age, df1.v, df2.id_1,df2.agedate,df2.v2""" result_df = con.execute(query).fetchdf() expected_result = con.execute(query).fetchall() - con.register('result_df', result_df) + con.register("result_df", result_df) rel = con.sql( """ select * from result_df order by @@ -52,32 +52,32 @@ def test_2304(self, duckdb_cursor, pandas): assert result == expected_result - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pd_names(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1, 1, 2], - 'id_1': [1, 1, 2], - 'id_3': [1, 1, 2], + "id": [1, 1, 2], + "id_1": [1, 1, 2], + "id_3": [1, 1, 2], } ) - df2 = pandas.DataFrame({'id': [1, 1, 2], 'id_1': [1, 1, 2], 'id_2': [1, 1, 1]}) + df2 = pandas.DataFrame({"id": [1, 1, 2], "id_1": [1, 1, 2], "id_2": [1, 1, 1]}) exp_result = pandas.DataFrame( { - 'id': [1, 1, 2, 1, 1], - 'id_1': [1, 1, 2, 1, 1], - 'id_3': [1, 1, 2, 1, 1], - 'id_2': [1, 1, 2, 1, 1], - 'id_1_1': [1, 1, 2, 1, 1], - 'id_2_1': [1, 1, 1, 1, 1], + "id": [1, 1, 2, 1, 1], + "id_1": [1, 1, 2, 1, 1], + "id_3": [1, 1, 2, 1, 1], + "id_2": [1, 1, 2, 1, 1], + "id_1_1": [1, 1, 2, 1, 1], + "id_2_1": [1, 1, 1, 1, 1], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) query = """SELECT * from df1 LEFT OUTER JOIN df2 ON (df1.id_1=df2.id_1)""" @@ -85,30 +85,30 @@ def test_pd_names(self, duckdb_cursor, pandas): result_df = con.execute(query).fetchdf() pandas.testing.assert_frame_equal(exp_result, result_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_repeat_name(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( { - 'id': [1], - 'id_1': [1], - 'id_2': [1], + "id": [1], + "id_1": [1], + "id_2": [1], } ) - df2 = pandas.DataFrame({'id': [1]}) + df2 = pandas.DataFrame({"id": [1]}) exp_result = pandas.DataFrame( { - 'id': [1], - 'id_1': [1], - 'id_2': [1], - 'id_3': [1], + "id": [1], + "id_1": [1], + "id_2": [1], + "id_3": [1], } ) con = duckdb.connect() - con.register('df1', df1) - con.register('df2', df2) + con.register("df1", df1) + con.register("df2", df2) result_df = con.execute( """ diff --git a/tests/fast/pandas/test_append_df.py b/tests/fast/pandas/test_append_df.py index 18805a5a..e6d64776 100644 --- a/tests/fast/pandas/test_append_df.py +++ b/tests/fast/pandas/test_append_df.py @@ -4,35 +4,35 @@ class TestAppendDF(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_to_table_append(self, duckdb_cursor, pandas): conn = duckdb.connect() conn.execute("Create table integers (i integer)") df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) - conn.append('integers', df_in) - assert conn.execute('select count(*) from integers').fetchone()[0] == 5 + conn.append("integers", df_in) + assert conn.execute("select count(*) from integers").fetchone()[0] == 5 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool, c varchar)") - df_in = pandas.DataFrame({'c': ['duck', 'db'], 'b': [False, True], 'a': [4, 2]}) + df_in = pandas.DataFrame({"c": ["duck", "db"], "b": [False, True], "a": [4, 2]}) # By default we append by position, causing the following exception: with pytest.raises( duckdb.ConversionException, match="Conversion Error: Could not convert string 'duck' to INT32" ): - con.append('tbl', df_in) + con.append("tbl", df_in) # When we use 'by_name' we instead append by name - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() - assert res == [(4, False, 'duck'), (2, True, 'db')] + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() + assert res == [(4, False, "duck"), (2, True, "db")] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name_quoted(self, pandas): con = duckdb.connect() con.execute( @@ -41,32 +41,32 @@ def test_append_by_name_quoted(self, pandas): """ ) df_in = pandas.DataFrame({"needs to be quoted": [1, 2, 3]}) - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() assert res == [(1, None), (2, None), (3, None)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append_by_name_no_exact_match(self, pandas): con = duckdb.connect() con.execute("create table tbl (a integer, b bool)") - df_in = pandas.DataFrame({'c': ['a', 'b'], 'b': [True, False], 'a': [42, 1337]}) + df_in = pandas.DataFrame({"c": ["a", "b"], "b": [True, False], "a": [42, 1337]}) # Too many columns raises an error, because the columns cant be found in the targeted table with pytest.raises(duckdb.BinderException, match='Table "tbl" does not have a column with name "c"'): - con.append('tbl', df_in, by_name=True) + con.append("tbl", df_in, by_name=True) - df_in = pandas.DataFrame({'b': [False, False, False]}) + df_in = pandas.DataFrame({"b": [False, False, False]}) # Not matching all columns is not a problem, as they will be filled with NULL instead - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() # 'a' got filled by NULL automatically because it wasn't inserted into assert res == [(None, False), (None, False), (None, False)] # Empty the table con.execute("create or replace table tbl (a integer, b bool)") - df_in = pandas.DataFrame({'a': [1, 2, 3]}) - con.append('tbl', df_in, by_name=True) - res = con.table('tbl').fetchall() + df_in = pandas.DataFrame({"a": [1, 2, 3]}) + con.append("tbl", df_in, by_name=True) + res = con.table("tbl").fetchall() # Also works for missing columns *after* the supplied ones assert res == [(1, None), (2, None), (3, None)] diff --git a/tests/fast/pandas/test_bug2281.py b/tests/fast/pandas/test_bug2281.py index 703baf4b..98a90937 100644 --- a/tests/fast/pandas/test_bug2281.py +++ b/tests/fast/pandas/test_bug2281.py @@ -8,11 +8,11 @@ class TestPandasStringNull(object): def test_pandas_string_null(self, duckdb_cursor): - csv = u'''what,is_control,is_test + csv = """what,is_control,is_test ,0,0 -foo,1,0''' +foo,1,0""" df = pd.read_csv(io.StringIO(csv)) duckdb_cursor.register("c", df) - duckdb_cursor.execute('select what, count(*) from c group by what') + duckdb_cursor.execute("select what, count(*) from c group by what") df_result = duckdb_cursor.fetchdf() assert True # Should not crash ^^ diff --git a/tests/fast/pandas/test_bug5922.py b/tests/fast/pandas/test_bug5922.py index af9be167..28daabe9 100644 --- a/tests/fast/pandas/test_bug5922.py +++ b/tests/fast/pandas/test_bug5922.py @@ -4,13 +4,13 @@ class TestPandasAcceptFloat16(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_accept_float16(self, duckdb_cursor, pandas): - df = pandas.DataFrame({'col': [1, 2, 3]}) - df16 = df.astype({'col': 'float16'}) + df = pandas.DataFrame({"col": [1, 2, 3]}) + df16 = df.astype({"col": "float16"}) con = duckdb.connect() - con.execute('CREATE TABLE tbl AS SELECT * FROM df16') - con.execute('select * from tbl') + con.execute("CREATE TABLE tbl AS SELECT * FROM df16") + con.execute("select * from tbl") df_result = con.fetchdf() - df32 = df.astype({'col': 'float32'}) - assert (df32['col'] == df_result['col']).all() + df32 = df.astype({"col": "float32"}) + assert (df32["col"] == df_result["col"]).all() diff --git a/tests/fast/pandas/test_copy_on_write.py b/tests/fast/pandas/test_copy_on_write.py index dc484f1b..ec1b8786 100644 --- a/tests/fast/pandas/test_copy_on_write.py +++ b/tests/fast/pandas/test_copy_on_write.py @@ -2,7 +2,7 @@ import pytest # https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html -pandas = pytest.importorskip('pandas', '1.5', reason='copy_on_write does not exist in earlier versions') +pandas = pytest.importorskip("pandas", "1.5", reason="copy_on_write does not exist in earlier versions") import datetime @@ -23,9 +23,9 @@ def convert_to_result(col): class TestCopyOnWrite(object): @pytest.mark.parametrize( - 'col', + "col", [ - ['a', 'b', 'this is a long string'], + ["a", "b", "this is a long string"], [1.2334, None, 234.12], [123234, -213123, 2324234], [datetime.date(1990, 12, 7), None, datetime.date(1940, 1, 13)], @@ -37,10 +37,10 @@ def test_copy_on_write(self, col): con = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': col, + "numbers": col, } ) - rel = con.sql('select * from df_in') + rel = con.sql("select * from df_in") res = rel.fetchall() print(res) expected = convert_to_result(col) diff --git a/tests/fast/pandas/test_create_table_from_pandas.py b/tests/fast/pandas/test_create_table_from_pandas.py index 69234dc7..2194d964 100644 --- a/tests/fast/pandas/test_create_table_from_pandas.py +++ b/tests/fast/pandas/test_create_table_from_pandas.py @@ -26,12 +26,12 @@ def assert_create_register(internal_data, expected_result, data_type, pandas): class TestCreateTableFromPandas(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_integer_create_table(self, duckdb_cursor, pandas): if sys.version_info.major < 3: return # FIXME: This should work with other data types e.g., int8... - data_types = ['Int8', 'Int16', 'Int32', 'Int64'] + data_types = ["Int8", "Int16", "Int32", "Int64"] internal_data = [1, 2, 3, 4] expected_result = [(1,), (2,), (3,), (4,)] for data_type in data_types: diff --git a/tests/fast/pandas/test_date_as_datetime.py b/tests/fast/pandas/test_date_as_datetime.py index 038f24a8..b738b2e1 100644 --- a/tests/fast/pandas/test_date_as_datetime.py +++ b/tests/fast/pandas/test_date_as_datetime.py @@ -5,9 +5,9 @@ def run_checks(df): - assert type(df['d'][0]) is datetime.date - assert df['d'][0] == datetime.date(1992, 7, 30) - assert pd.isnull(df['d'][1]) + assert type(df["d"][0]) is datetime.date + assert df["d"][0] == datetime.date(1992, 7, 30) + assert pd.isnull(df["d"][1]) def test_date_as_datetime(): @@ -22,7 +22,7 @@ def test_date_as_datetime(): run_checks(con.execute("Select * from t").fetch_df(date_as_object=True)) # Relation Methods - rel = con.table('t') + rel = con.table("t") run_checks(rel.df(date_as_object=True)) run_checks(rel.to_df(date_as_object=True)) diff --git a/tests/fast/pandas/test_datetime_time.py b/tests/fast/pandas/test_datetime_time.py index cda96e6b..1a5a3f7a 100644 --- a/tests/fast/pandas/test_datetime_time.py +++ b/tests/fast/pandas/test_datetime_time.py @@ -8,24 +8,24 @@ class TestDateTimeTime(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_high(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(23, 1, 34.234345) AS '0'").df() data = [time(hour=23, minute=1, second=34, microsecond=234345)] - df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_low(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(00, 01, 1.000) AS '0'").df() data = [time(hour=0, minute=1, second=1)] - df_in = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('input', ['2263-02-28', '9999-01-01']) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("input", ["2263-02-28", "9999-01-01"]) def test_pandas_datetime_big(self, pandas, input): duckdb_con = duckdb.connect() @@ -33,8 +33,8 @@ def test_pandas_datetime_big(self, pandas, input): duckdb_con.execute(f"INSERT INTO TEST VALUES ('{input}')") res = duckdb_con.execute("select * from test").df() - date_value = np.array([f'{input}'], dtype='datetime64[us]') - df = pandas.DataFrame({'date': date_value}) + date_value = np.array([f"{input}"], dtype="datetime64[us]") + df = pandas.DataFrame({"date": date_value}) pandas.testing.assert_frame_equal(res, df) def test_timezone_datetime(self): @@ -45,6 +45,6 @@ def test_timezone_datetime(self): original = dt stringified = str(dt) - original_res = con.execute('select ?::TIMESTAMPTZ', [original]).fetchone() - stringified_res = con.execute('select ?::TIMESTAMPTZ', [stringified]).fetchone() + original_res = con.execute("select ?::TIMESTAMPTZ", [original]).fetchone() + stringified_res = con.execute("select ?::TIMESTAMPTZ", [stringified]).fetchone() assert original_res == stringified_res diff --git a/tests/fast/pandas/test_datetime_timestamp.py b/tests/fast/pandas/test_datetime_timestamp.py index e3b26501..ffc1b7d8 100644 --- a/tests/fast/pandas/test_datetime_timestamp.py +++ b/tests/fast/pandas/test_datetime_timestamp.py @@ -9,21 +9,21 @@ class TestDateTimeTimeStamp(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_high(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql("SELECT '2260-01-01 23:59:00'::TIMESTAMP AS '0'").df() df_in = pandas.DataFrame( { 0: pandas.Series( data=[datetime.datetime(year=2260, month=1, day=1, hour=23, minute=59)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) } ) df_out = duckdb_cursor.sql("select * from df_in").df() pandas.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_low(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -32,27 +32,27 @@ def test_timestamp_low(self, pandas, duckdb_cursor): ).df() df_in = pandas.DataFrame( { - '0': pandas.Series( + "0": pandas.Series( data=[ pandas.Timestamp( datetime.datetime(year=1680, month=1, day=1, hour=23, minute=59, microsecond=234243), - unit='us', + unit="us", ) ], - dtype='datetime64[us]', + dtype="datetime64[us]", ) } ) - print('original:', duckdb_time['0'].dtype) - print('df_in:', df_in['0'].dtype) + print("original:", duckdb_time["0"].dtype) + print("df_in:", df_in["0"].dtype) df_out = duckdb_cursor.sql("select * from df_in").df() - print('df_out:', df_out['0'].dtype) + print("df_out:", df_out["0"].dtype) pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -65,7 +65,7 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): df_in = pandas.DataFrame( { 0: pandas.Series( - data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype="object" ) } ) @@ -75,9 +75,9 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -91,7 +91,7 @@ def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): df_in = pandas.DataFrame( { 0: pandas.Series( - data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype="object" ) } ) @@ -99,9 +99,9 @@ def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ @@ -115,7 +115,7 @@ def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): df_in = pandas.DataFrame( { 0: pandas.Series( - data=[datetime.datetime(year=2021, month=12, day=31, hour=23, tzinfo=timezone)], dtype='object' + data=[datetime.datetime(year=2021, month=12, day=31, hour=23, tzinfo=timezone)], dtype="object" ) } ) @@ -123,16 +123,16 @@ def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): pandas.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( - Version(pd.__version__) < Version('2.0.2'), reason="pandas < 2.0.2 does not properly convert timezones" + Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize('unit', ['ms', 'ns', 's']) + @pytest.mark.parametrize("unit", ["ms", "ns", "s"]) def test_timestamp_timezone_coverage(self, unit, duckdb_cursor): pd = pytest.importorskip("pandas") ts_df = pd.DataFrame( - {'ts': pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype=f'datetime64[{unit}]')} + {"ts": pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype=f"datetime64[{unit}]")} ) usecond_df = pd.DataFrame( - {'ts': pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype='datetime64[us]')} + {"ts": pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype="datetime64[us]")} ) query = """ @@ -142,12 +142,12 @@ def test_timestamp_timezone_coverage(self, unit, duckdb_cursor): """ duckdb_cursor.sql("set TimeZone = 'UTC'") - utc_usecond = duckdb_cursor.sql(query.format('usecond_df')).df() - utc_other = duckdb_cursor.sql(query.format('ts_df')).df() + utc_usecond = duckdb_cursor.sql(query.format("usecond_df")).df() + utc_other = duckdb_cursor.sql(query.format("ts_df")).df() duckdb_cursor.sql("set TimeZone = 'America/Los_Angeles'") - us_usecond = duckdb_cursor.sql(query.format('usecond_df')).df() - us_other = duckdb_cursor.sql(query.format('ts_df')).df() + us_usecond = duckdb_cursor.sql(query.format("usecond_df")).df() + us_other = duckdb_cursor.sql(query.format("ts_df")).df() pd.testing.assert_frame_equal(utc_usecond, utc_other) pd.testing.assert_frame_equal(us_usecond, us_other) diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 114f8e3f..8e67da4a 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -6,11 +6,11 @@ def create_generic_dataframe(data, pandas): - return pandas.DataFrame({'col0': pandas.Series(data=data, dtype='object')}) + return pandas.DataFrame({"col0": pandas.Series(data=data, dtype="object")}) class TestResolveObjectColumns(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_correct(self, duckdb_cursor, pandas): print(pandas.backend) duckdb_conn = duckdb.connect() @@ -21,7 +21,7 @@ def test_sample_low_correct(self, duckdb_cursor, pandas): duckdb_df = duckdb_conn.query("select * FROM (VALUES (1000008), (6), (9), (4), (1), (6)) as '0'").df() pandas.testing.assert_frame_equal(duckdb_df, roundtripped_df, check_dtype=False) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=2") @@ -31,9 +31,9 @@ def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): df = create_generic_dataframe(data, pandas) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() # Sample high enough to detect mismatch in types, fallback to VARCHAR - assert roundtripped_df['col0'].dtype == np.dtype('object') + assert roundtripped_df["col0"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_zero(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() # Disable dataframe analyze @@ -42,12 +42,12 @@ def test_sample_zero(self, duckdb_cursor, pandas): df = create_generic_dataframe(data, pandas) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() # Always converts to VARCHAR - if pandas.backend == 'pyarrow': - assert roundtripped_df['col0'].dtype == np.dtype('int64') + if pandas.backend == "pyarrow": + assert roundtripped_df["col0"].dtype == np.dtype("int64") else: - assert roundtripped_df['col0'].dtype == np.dtype('object') + assert roundtripped_df["col0"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_incorrect_undetected(self, duckdb_cursor, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=1") @@ -65,10 +65,10 @@ def test_reset_analyze_sample_setting(self, duckdb_cursor): res = duckdb_cursor.execute("select current_setting('pandas_analyze_sample')").fetchall() assert res == [(1000,)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_10750(self, duckdb_cursor, pandas): max_row_number = 2000 - data = {'id': [i for i in range(max_row_number + 1)], 'content': [None for _ in range(max_row_number + 1)]} + data = {"id": [i for i in range(max_row_number + 1)], "content": [None for _ in range(max_row_number + 1)]} pdf = pandas.DataFrame(data=data) duckdb_cursor.register("content", pdf) diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index d54db072..73470818 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ b/tests/fast/pandas/test_df_object_resolution.py @@ -13,7 +13,7 @@ def create_generic_dataframe(data, pandas): - return pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + return pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) def create_repeated_nulls(size): @@ -25,7 +25,7 @@ def create_repeated_nulls(size): def create_trailing_non_null(size): data = [None for _ in range(size - 1)] - data.append('this is a long string') + data.append("this is a long string") return data @@ -43,7 +43,7 @@ def ConvertStringToDecimal(data: list, pandas): for i in range(len(data)): if isinstance(data[i], str): data[i] = decimal.Decimal(data[i]) - data = pandas.Series(data=data, dtype='object') + data = pandas.Series(data=data, dtype="object") return data @@ -61,13 +61,13 @@ def construct_list(pair): def construct_struct(pair): - return [{'v1': pair.first}, {'v1': pair.second}] + return [{"v1": pair.first}, {"v1": pair.second}] def construct_map(pair): return [ - {'key': ['v1', 'v2'], "value": [pair.first, pair.first]}, - {'key': ['v1', 'v2'], "value": [pair.second, pair.second]}, + {"key": ["v1", "v2"], "value": [pair.first, pair.first]}, + {"key": ["v1", "v2"], "value": [pair.second, pair.second]}, ] @@ -83,157 +83,157 @@ def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, class TestResolveObjectColumns(object): # TODO: add support for ArrowPandas - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integers(self, pandas, duckdb_cursor): data = [5, 0, 3] df_in = create_generic_dataframe(data, pandas) # These are float64 because pandas would force these to be float64 even if we set them to int8, int16, int32, int64 respectively - df_expected_res = pandas.DataFrame({'0': pandas.Series(data=data, dtype='int32')}) + df_expected_res = pandas.DataFrame({"0": pandas.Series(data=data, dtype="int32")}) df_out = duckdb_cursor.sql("SELECT * FROM df_in").df() print(df_out) pandas.testing.assert_frame_equal(df_expected_res, df_out) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_correct(self, pandas, duckdb_cursor): - data = [{'a': 1, 'b': 3, 'c': 3, 'd': 7}] - df = pandas.DataFrame({'0': pandas.Series(data=data)}) + data = [{"a": 1, "b": 3, "c": 3, "d": 7}] + df = pandas.DataFrame({"0": pandas.Series(data=data)}) duckdb_col = duckdb_cursor.sql("SELECT {a: 1, b: 3, c: 3, d: 7} as '0'").df() converted_col = duckdb_cursor.sql("SELECT * FROM df").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_different_keys(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'e': 7}], #'e' instead of 'd' as key - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "e": 7}], #'e' instead of 'd' as key + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() y = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'e'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "e"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_incorrect_amount_of_keys(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3}], # incorrect amount of keys - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3}], # incorrect amount of keys + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() y = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c'], 'value': [1, 3, 3]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c"], "value": [1, 3, 3]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'string'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": "string"}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'string'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': '7'}], + [{"a": 1, "b": 3, "c": 3, "d": "string"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], + [{"a": 1, "b": 3, "c": 3, "d": "7"}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_null(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ [None], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ [None], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'a': 1, 'b': 3, 'c': 3, 'd': 'test'}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], - [{'a': 1, 'b': 3, 'c': 3, 'd': 7}], + [{"a": 1, "b": 3, "c": 3, "d": "test"}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], + [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) y = pandas.DataFrame( [ - [{'a': '1', 'b': '3', 'c': '3', 'd': 'test'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], - [{'a': '1', 'b': '3', 'c': '3'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], - [{'a': '1', 'b': '3', 'c': '3', 'd': '7'}], + [{"a": "1", "b": "3", "c": "3", "d": "test"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], + [{"a": "1", "b": "3", "c": "3"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], + [{"a": "1", "b": "3", "c": "3", "d": "7"}], ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() pandas.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_correct(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x as 'a'").df() duckdb_cursor.sql( """ @@ -253,10 +253,10 @@ def test_map_correct(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) - @pytest.mark.parametrize('sample_size', [1, 10]) - @pytest.mark.parametrize('fill', [1000, 10000]) - @pytest.mark.parametrize('get_data', [create_repeated_nulls, create_trailing_non_null]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("sample_size", [1, 10]) + @pytest.mark.parametrize("fill", [1000, 10000]) + @pytest.mark.parametrize("get_data", [create_repeated_nulls, create_trailing_non_null]) def test_analyzing_nulls(self, pandas, duckdb_cursor, fill, sample_size, get_data): data = get_data(fill) df1 = pandas.DataFrame(data={"col1": data}) @@ -265,9 +265,9 @@ def test_analyzing_nulls(self, pandas, duckdb_cursor, fill, sample_size, get_dat pandas.testing.assert_frame_equal(df1, df) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_nested_map(self, pandas, duckdb_cursor): - df = pandas.DataFrame(data={'col1': [{'a': {'b': {'x': 'A', 'y': 'B'}}}, {'c': {'b': {'x': 'A'}}}]}) + df = pandas.DataFrame(data={"col1": [{"a": {"b": {"x": "A", "y": "B"}}}, {"c": {"b": {"x": "A"}}}]}) rel = duckdb_cursor.sql("select * from df") expected_rel = duckdb_cursor.sql( @@ -283,18 +283,18 @@ def test_nested_map(self, pandas, duckdb_cursor): expected_res = str(expected_rel) assert res == expected_res - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 'test']}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], - [{'key': ['a', 'b', 'c', 'd'], 'value': [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, "test"]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], + [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql( """ @@ -319,69 +319,69 @@ def test_map_value_upgrade(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_duplicate(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': ['a', 'a', 'b'], 'value': [4, 0, 4]}]]) + x = pandas.DataFrame([[{"key": ["a", "a", "b"], "value": [4, 0, 4]}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys must be unique."): duckdb_cursor.sql("select * from x").show() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': [None, 'a', 'b'], 'value': [4, 0, 4]}]]) + x = pandas.DataFrame([[{"key": [None, "a", "b"], "value": [4, 0, 4]}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_nullkeylist(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'key': None, 'value': None}]]) + x = pandas.DataFrame([[{"key": None, "value": None}]]) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_col = duckdb_cursor.sql("SELECT MAP(NULL, NULL) as '0'").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{'a': 4, None: 0, 'c': 4}], [{'a': 4, None: 0, 'd': 4}]]) + x = pandas.DataFrame([[{"a": 4, None: 0, "c": 4}], [{"a": 4, None: 0, "d": 4}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_map_fallback_nullkey_coverage(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ - [{'key': None, 'value': None}], - [{'key': None, None: 5}], + [{"key": None, "value": None}], + [{"key": None, None: 5}], ] ) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL."): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_structs_in_nested_types(self, pandas, duckdb_cursor): # This test is testing a bug that occurred when type upgrades occurred inside nested types # STRUCT(key1 varchar) + STRUCT(key1 varchar, key2 varchar) turns into MAP # But when inside a nested structure, this upgrade did not happen properly pairs = { - 'v1': ObjectPair({'key1': 21}, {'key1': 21, 'key2': 42}), - 'v2': ObjectPair({'key1': 21}, {'key2': 21}), - 'v3': ObjectPair({'key1': 21, 'key2': 42}, {'key1': 21}), - 'v4': ObjectPair({}, {'key1': 21}), + "v1": ObjectPair({"key1": 21}, {"key1": 21, "key2": 42}), + "v2": ObjectPair({"key1": 21}, {"key2": 21}), + "v3": ObjectPair({"key1": 21, "key2": 42}, {"key1": 21}), + "v4": ObjectPair({}, {"key1": 21}), } for _, pair in pairs.items(): - check_struct_upgrade('MAP(VARCHAR, INTEGER)[]', construct_list, pair, pandas, duckdb_cursor) + check_struct_upgrade("MAP(VARCHAR, INTEGER)[]", construct_list, pair, pandas, duckdb_cursor) for key, pair in pairs.items(): - if key == 'v4': - expected_type = 'MAP(VARCHAR, MAP(VARCHAR, INTEGER))' + if key == "v4": + expected_type = "MAP(VARCHAR, MAP(VARCHAR, INTEGER))" else: - expected_type = 'STRUCT(v1 MAP(VARCHAR, INTEGER))' + expected_type = "STRUCT(v1 MAP(VARCHAR, INTEGER))" check_struct_upgrade(expected_type, construct_struct, pair, pandas, duckdb_cursor) for key, pair in pairs.items(): - check_struct_upgrade('MAP(VARCHAR, MAP(VARCHAR, INTEGER))', construct_map, pair, pandas, duckdb_cursor) + check_struct_upgrade("MAP(VARCHAR, MAP(VARCHAR, INTEGER))", construct_map, pair, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_structs_of_different_sizes(self, pandas, duckdb_cursor): # This list has both a STRUCT(v1) and a STRUCT(v1, v2) member # Those can't be combined @@ -404,9 +404,9 @@ def test_structs_of_different_sizes(self, pandas, duckdb_cursor): ) res = duckdb_cursor.query("select typeof(col) from df").fetchall() # So we fall back to converting them as VARCHAR instead - assert res == [('MAP(VARCHAR, VARCHAR)[]',), ('MAP(VARCHAR, VARCHAR)[]',)] + assert res == [("MAP(VARCHAR, VARCHAR)[]",), ("MAP(VARCHAR, VARCHAR)[]",)] - malformed_struct = duckdb.Value({"v1": 1, "v2": 2}, duckdb.struct_type({'v1': int})) + malformed_struct = duckdb.Value({"v1": 1, "v2": 2}, duckdb.struct_type({"v1": int})) with pytest.raises( duckdb.InvalidInputException, match=re.escape( @@ -416,7 +416,7 @@ def test_structs_of_different_sizes(self, pandas, duckdb_cursor): res = duckdb_cursor.execute("select $1", [malformed_struct]) print(res) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_key_conversion(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ @@ -428,48 +428,48 @@ def test_struct_key_conversion(self, pandas, duckdb_cursor): duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_correct(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [[5], [34], [-245]]}]) + x = pandas.DataFrame([{"0": [[5], [34], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], [34], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_contains_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [[5], None, [-245]]}]) + x = pandas.DataFrame([{"0": [[5], None, [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], NULL, [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_starts_with_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [None, [5], [-245]]}]) + x = pandas.DataFrame([{"0": [None, [5], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [NULL, [5], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{'0': [['5'], [34], [-245]]}]) + x = pandas.DataFrame([{"0": [["5"], [34], [-245]]}]) duckdb_rel = duckdb_cursor.sql("select [['5'], ['34'], ['-245']] as '0'") duckdb_col = duckdb_rel.df() converted_col = duckdb_cursor.sql("select * from x").df() pandas.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_column_value_upgrade(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ [[1, 25, 300]], [[500, 345, 30]], - [[50, 'a', 67]], + [[50, "a", 67]], ] ) - x.rename(columns={0: 'a'}, inplace=True) + x.rename(columns={0: "a"}, inplace=True) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql( """ @@ -498,29 +498,29 @@ def test_list_column_value_upgrade(self, pandas, duckdb_cursor): print(converted_col.columns) pandas.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_ubigint_object_conversion(self, pandas, duckdb_cursor): # UBIGINT + TINYINT would result in HUGEINT, but conversion to HUGEINT is not supported yet from pandas->duckdb # So this instead becomes a DOUBLE data = [18446744073709551615, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - if pandas.backend == 'numpy_nullable': - float64 = np.dtype('float64') - assert isinstance(converted_col['0'].dtype, float64.__class__) == True + if pandas.backend == "numpy_nullable": + float64 = np.dtype("float64") + assert isinstance(converted_col["0"].dtype, float64.__class__) == True else: - uint64 = np.dtype('uint64') - assert isinstance(converted_col['0'].dtype, uint64.__class__) == True + uint64 = np.dtype("uint64") + assert isinstance(converted_col["0"].dtype, uint64.__class__) == True - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_double_object_conversion(self, pandas, duckdb_cursor): data = [18446744073709551616, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - double_dtype = np.dtype('float64') - assert isinstance(converted_col['0'].dtype, double_dtype.__class__) == True + double_dtype = np.dtype("float64") + assert isinstance(converted_col["0"].dtype, double_dtype.__class__) == True - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="older numpy raises a warning when running with Pyodide", @@ -551,51 +551,51 @@ def test_numpy_object_with_stride(self, pandas, duckdb_cursor): (9, 18, 0), ] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numpy_stringliterals(self, pandas, duckdb_cursor): df = pandas.DataFrame({"x": list(map(np.str_, range(3)))}) res = duckdb_cursor.execute("select * from df").fetchall() - assert res == [('0',), ('1',), ('2',)] + assert res == [("0",), ("1",), ("2",)] - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integer_conversion_fail(self, pandas, duckdb_cursor): data = [2**10000, 0] - x = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - print(converted_col['0']) - double_dtype = np.dtype('object') - assert isinstance(converted_col['0'].dtype, double_dtype.__class__) == True + print(converted_col["0"]) + double_dtype = np.dtype("object") + assert isinstance(converted_col["0"].dtype, double_dtype.__class__) == True # Most of the time numpy.datetime64 is just a wrapper around a datetime.datetime object # But to support arbitrary precision, it can fall back to using an `int` internally - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) # Which we don't support yet + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) # Which we don't support yet def test_numpy_datetime(self, pandas, duckdb_cursor): numpy = pytest.importorskip("numpy") data = [] - data += [numpy.datetime64('2022-12-10T21:38:24.578696')] * standard_vector_size - data += [numpy.datetime64('2022-02-21T06:59:23.324812')] * standard_vector_size - data += [numpy.datetime64('1974-06-05T13:12:01.000000')] * standard_vector_size - data += [numpy.datetime64('2049-01-13T00:24:31.999999')] * standard_vector_size - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + data += [numpy.datetime64("2022-12-10T21:38:24.578696")] * standard_vector_size + data += [numpy.datetime64("2022-02-21T06:59:23.324812")] * standard_vector_size + data += [numpy.datetime64("1974-06-05T13:12:01.000000")] * standard_vector_size + data += [numpy.datetime64("2049-01-13T00:24:31.999999")] * standard_vector_size + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() - assert len(res['dates'].__array__()) == 4 + assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_numpy_datetime_int_internally(self, pandas, duckdb_cursor): numpy = pytest.importorskip("numpy") - data = [numpy.datetime64('2022-12-10T21:38:24.0000000000001')] - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + data = [numpy.datetime64("2022-12-10T21:38:24.0000000000001")] + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) with pytest.raises( duckdb.ConversionException, match=re.escape("Conversion Error: Unimplemented type for cast (BIGINT -> TIMESTAMP)"), ): rel = duckdb.query_df(x, "x", "create table dates as select dates::TIMESTAMP WITHOUT TIME ZONE from x") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ @@ -605,10 +605,10 @@ def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): ] ) duckdb_col = duckdb_cursor.sql("select * from x").df() - df_expected_res = pandas.DataFrame({'0': pandas.Series(['4', '2', '0'])}) + df_expected_res = pandas.DataFrame({"0": pandas.Series(["4", "2", "0"])}) pandas.testing.assert_frame_equal(duckdb_col, df_expected_res) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal(self, pandas, duckdb_cursor): # DuckDB uses DECIMAL where possible, so all the 'float' types here are actually DECIMAL reference_query = """ @@ -626,12 +626,12 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): # Because of this we need to wrap these native floats as DECIMAL for this test, to avoid these decimals being "upgraded" to DOUBLE x = pandas.DataFrame( { - '0': ConvertStringToDecimal([5, '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal( - [5002340, 13, '-12.0000000005', '7453324234.0', None, '-324234234'], pandas + "0": ConvertStringToDecimal([5, "12.0", "-123.0", "-234234.0", None, "1.234"], pandas), + "1": ConvertStringToDecimal( + [5002340, 13, "-12.0000000005", "7453324234.0", None, "-324234234"], pandas ), - '2': ConvertStringToDecimal( - ['-234234234234.0', '324234234.00000005', -128, 345345, '1E5', '1324234359'], pandas + "2": ConvertStringToDecimal( + ["-234234234234.0", "324234234.00000005", -128, 345345, "1E5", "1324234359"], pandas ), } ) @@ -640,10 +640,10 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): x = pandas.DataFrame( - {'0': [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} + {"0": [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} ) conversion = duckdb_cursor.sql("select * from x").fetchall() print(conversion[0][0].__class__) @@ -655,12 +655,12 @@ def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): assert math.isinf(conversion[3][0]) assert math.isinf(conversion[4][0]) assert math.isinf(conversion[5][0]) - assert str(conversion) == '[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]' + assert str(conversion) == "[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]" # Test that the column 'offset' is actually used when converting, @pytest.mark.parametrize( - 'pandas', [NumpyPandas(), ArrowPandas()] + "pandas", [NumpyPandas(), ArrowPandas()] ) # and that the same 2048 (STANDARD_VECTOR_SIZE) values are not being scanned over and over again def test_multiple_chunks(self, pandas, duckdb_cursor): data = [] @@ -668,11 +668,11 @@ def test_multiple_chunks(self, pandas, duckdb_cursor): data += [datetime.date(2022, 9, 14) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 15) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 16) for x in range(standard_vector_size)] - x = pandas.DataFrame({'dates': pandas.Series(data=data, dtype='object')}) + x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() - assert len(res['dates'].__array__()) == 4 + assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): duckdb_cursor.execute(f"SET GLOBAL pandas_analyze_sample=4096") duckdb_cursor.execute( @@ -683,8 +683,8 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): date_df = res.copy() # Convert the dataframe to datetime - date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert str(date_df['i'].dtype) == 'object' + date_df["i"] = pandas.to_datetime(res["i"]).dt.date + assert str(date_df["i"].dtype) == "object" expected_res = [ ( @@ -707,10 +707,10 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): assert expected_res == actual_res # Now interleave nulls into the dataframe - duckdb_cursor.execute('drop table dates') - for i in range(0, len(res['i']), 2): - res.loc[i, 'i'] = None - duckdb_cursor.execute('create table dates as select * from res') + duckdb_cursor.execute("drop table dates") + for i in range(0, len(res["i"]), 2): + res.loc[i, "i"] = None + duckdb_cursor.execute("create table dates as select * from res") expected_res = [ ( @@ -721,8 +721,8 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): ] # Convert the dataframe to datetime date_df = res.copy() - date_df['i'] = pandas.to_datetime(res['i']).dt.date - assert str(date_df['i'].dtype) == 'object' + date_df["i"] = pandas.to_datetime(res["i"]).dt.date + assert str(date_df["i"].dtype) == "object" actual_res = duckdb_cursor.sql( """ @@ -736,47 +736,47 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): assert expected_res == actual_res - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_mixed_object_types(self, pandas, duckdb_cursor): x = pandas.DataFrame( { - 'nested': pandas.Series( - data=[{'a': 1, 'b': 2}, [5, 4, 3], {'key': [1, 2, 3], 'value': ['a', 'b', 'c']}], dtype='object' + "nested": pandas.Series( + data=[{"a": 1, "b": 2}, [5, 4, 3], {"key": [1, 2, 3], "value": ["a", "b", "c"]}], dtype="object" ), } ) res = duckdb_cursor.sql("select * from x").df() - assert res['nested'].dtype == np.dtype('object') + assert res["nested"].dtype == np.dtype("object") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_struct(self, pandas, duckdb_cursor): x = pandas.DataFrame( [ { # STRUCT(b STRUCT(x VARCHAR, y VARCHAR)) - 'a': {'b': {'x': 'A', 'y': 'B'}} + "a": {"b": {"x": "A", "y": "B"}} }, { # STRUCT(b STRUCT(x VARCHAR)) - 'a': {'b': {'x': 'A'}} + "a": {"b": {"x": "A"}} }, ] ) # The dataframe has incompatible struct schemas in the nested child # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) res = duckdb_cursor.sql("select * from x").fetchall() - assert res == [({'b': {'x': 'A', 'y': 'B'}},), ({'b': {'x': 'A'}},)] + assert res == [({"b": {"x": "A", "y": "B"}},), ({"b": {"x": "A"}},)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): x = pandas.DataFrame( { - 'a': [ + "a": [ [ # STRUCT(x VARCHAR, y VARCHAR)[] - {'x': 'A', 'y': 'B'}, + {"x": "A", "y": "B"}, # STRUCT(x VARCHAR)[] - {'x': 'A'}, + {"x": "A"}, ] ] } @@ -784,16 +784,16 @@ def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): # The dataframe has incompatible struct schemas in the nested child # This gets upgraded to STRUCT(b MAP(VARCHAR, VARCHAR)) res = duckdb_cursor.sql("select * from x").fetchall() - assert res == [([{'x': 'A', 'y': 'B'}, {'x': 'A'}],)] + assert res == [([{"x": "A", "y": "B"}, {"x": "A"}],)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_analyze_sample_too_small(self, pandas, duckdb_cursor): data = [1 for _ in range(9)] + [[1, 2, 3]] + [1 for _ in range(9991)] - x = pandas.DataFrame({'a': pandas.Series(data=data)}) + x = pandas.DataFrame({"a": pandas.Series(data=data)}) with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"): res = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_zero_fractional(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( data={ @@ -826,7 +826,7 @@ def test_numeric_decimal_zero_fractional(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( @@ -842,10 +842,10 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): duckdb_cursor.execute(reference_query) x = pandas.DataFrame( { - '0': ConvertStringToDecimal(['5', '12.0', '-123.0', '-234234.0', None, '1.234'], pandas), - '1': ConvertStringToDecimal([5002340, 13, '-12.0000000005', 7453324234, None, '-324234234'], pandas), - '2': ConvertStringToDecimal( - [-234234234234, '324234234.00000005', -128, 345345, 0, '1324234359'], pandas + "0": ConvertStringToDecimal(["5", "12.0", "-123.0", "-234234.0", None, "1.234"], pandas), + "1": ConvertStringToDecimal([5002340, 13, "-12.0000000005", 7453324234, None, "-324234234"], pandas), + "2": ConvertStringToDecimal( + [-234234234234, "324234234.00000005", -128, 345345, 0, "1324234359"], pandas ), } ) @@ -857,7 +857,7 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): print(conversion) @pytest.mark.parametrize( - 'pandas', [NumpyPandas(), ArrowPandas()] + "pandas", [NumpyPandas(), ArrowPandas()] ) # result: [('1E-28',), ('10000000000000000000000000.0',)] def test_numeric_decimal_combined(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( @@ -878,7 +878,7 @@ def test_numeric_decimal_combined(self, pandas, duckdb_cursor): print(conversion) # result: [('1234.0',), ('123456789.0',), ('1234567890123456789.0',), ('0.1234567890123456789',)] - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): decimals = pandas.DataFrame( data={ @@ -906,7 +906,7 @@ def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): print(reference) print(conversion) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): # The widths of these decimal values are bigger than the max supported width for DECIMAL data = [ @@ -927,7 +927,7 @@ def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): data = [ Decimal("1.234"), @@ -959,7 +959,7 @@ def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_numeric_decimal_out_of_range(self, pandas, duckdb_cursor): data = [Decimal("1.234567890123456789012345678901234567"), Decimal("123456789012345678901234567890123456.0")] decimals = pandas.DataFrame(data={"0": data}) diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index b8de512a..fb7d2ad0 100644 --- a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -12,8 +12,8 @@ def check_equal(conn, df, reference_query, data): duckdb_conn = duckdb.connect() duckdb_conn.execute(reference_query, parameters=[data]) - res = duckdb_conn.query('SELECT * FROM tbl').fetchall() - df_res = duckdb_conn.query('SELECT * FROM tbl').df() + res = duckdb_conn.query("SELECT * FROM tbl").fetchall() + df_res = duckdb_conn.query("SELECT * FROM tbl").df() out = conn.sql("SELECT * FROM df").fetchall() assert res == out @@ -24,39 +24,39 @@ def create_reference_query(): class TestDFRecursiveNested(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_structs(self, duckdb_cursor, pandas): - data = [[{'a': 5}, NULL, {'a': NULL}], NULL, [{'a': 5}, NULL, {'a': NULL}]] + data = [[{"a": 5}, NULL, {"a": NULL}], NULL, [{"a": 5}, NULL, {"a": NULL}]] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'STRUCT(a INTEGER)[]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "STRUCT(a INTEGER)[]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_map(self, duckdb_cursor, pandas): # LIST(MAP(VARCHAR, VARCHAR)) - data = [[{5: NULL}, NULL, {}], NULL, [NULL, {3: NULL, 2: 'a', 4: NULL}, {'a': 1, 'b': 2, 'c': 3}]] + data = [[{5: NULL}, NULL, {}], NULL, [NULL, {3: NULL, 2: "a", 4: NULL}, {"a": 1, "b": 2, "c": 3}]] reference_query = create_reference_query() print(reference_query) - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'MAP(VARCHAR, VARCHAR)[][]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "MAP(VARCHAR, VARCHAR)[][]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_list(self, duckdb_cursor, pandas): # LIST(LIST(LIST(LIST(INTEGER)))) data = [[[[3, NULL, 5], NULL], NULL, [[5, -20, NULL]]], NULL, [[[NULL]], [[]], NULL]] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) - check_equal(duckdb_cursor, df, reference_query, Value(data, 'INTEGER[][][][]')) + df = pandas.DataFrame([{"a": data}]) + check_equal(duckdb_cursor, df, reference_query, Value(data, "INTEGER[][][][]")) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_struct(self, duckdb_cursor, pandas): # STRUCT(STRUCT(STRUCT(LIST))) data = { - 'A': {'a': {'1': [1, 2, 3]}, 'b': NULL, 'c': {'1': NULL}}, - 'B': {'a': {'1': [1, NULL, 3]}, 'b': NULL, 'c': {'1': NULL}}, + "A": {"a": {"1": [1, 2, 3]}, "b": NULL, "c": {"1": NULL}}, + "B": {"a": {"1": [1, NULL, 3]}, "b": NULL, "c": {"1": NULL}}, } reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) check_equal( duckdb_cursor, df, @@ -92,7 +92,7 @@ def test_recursive_struct(self, duckdb_cursor, pandas): ), ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_map(self, duckdb_cursor, pandas): # MAP( # MAP( @@ -102,42 +102,42 @@ def test_recursive_map(self, duckdb_cursor, pandas): # INTEGER # ) data = { - 'key': [ - {'key': [5, 6, 7], 'value': [{'key': [8], 'value': [NULL]}, NULL, {'key': [9], 'value': ['a']}]}, - {'key': [], 'value': []}, + "key": [ + {"key": [5, 6, 7], "value": [{"key": [8], "value": [NULL]}, NULL, {"key": [9], "value": ["a"]}]}, + {"key": [], "value": []}, ], - 'value': [1, 2], + "value": [1, 2], } reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) check_equal( - duckdb_cursor, df, reference_query, Value(data, 'MAP(MAP(INTEGER, MAP(INTEGER, VARCHAR)), INTEGER)') + duckdb_cursor, df, reference_query, Value(data, "MAP(MAP(INTEGER, MAP(INTEGER, VARCHAR)), INTEGER)") ) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_recursive_stresstest(self, duckdb_cursor, pandas): data = [ { - 'a': { - 'key': [ + "a": { + "key": [ # key 1 - {'1': [5, 4, 3], '2': [8, 7, 6], '3': [1, 2, 3]}, + {"1": [5, 4, 3], "2": [8, 7, 6], "3": [1, 2, 3]}, # key 2 - {'1': [], '2': NULL, '3': [NULL, 0, NULL]}, + {"1": [], "2": NULL, "3": [NULL, 0, NULL]}, ], - 'value': [ + "value": [ # value 1 - [{'A': 'abc', 'B': 'def', 'C': NULL}], + [{"A": "abc", "B": "def", "C": NULL}], # value 2 [NULL], ], }, - 'b': NULL, - 'c': {'key': [], 'value': []}, + "b": NULL, + "c": {"key": [], "value": []}, } ] reference_query = create_reference_query() - df = pandas.DataFrame([{'a': data}]) + df = pandas.DataFrame([{"a": data}]) duckdb_type = """ STRUCT( a MAP( diff --git a/tests/fast/pandas/test_fetch_df_chunk.py b/tests/fast/pandas/test_fetch_df_chunk.py index 1973a729..1f2d4b1b 100644 --- a/tests/fast/pandas/test_fetch_df_chunk.py +++ b/tests/fast/pandas/test_fetch_df_chunk.py @@ -13,16 +13,16 @@ def test_fetch_df_chunk(self): # Fetch the first chunk cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == VECTOR_SIZE # Fetch the second chunk, can't be entirely filled cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == VECTOR_SIZE + assert cur_chunk["a"][0] == VECTOR_SIZE expected = size - VECTOR_SIZE assert len(cur_chunk) == expected - @pytest.mark.parametrize('size', [3000, 10000, 100000, VECTOR_SIZE - 1, VECTOR_SIZE + 1, VECTOR_SIZE]) + @pytest.mark.parametrize("size", [3000, 10000, 100000, VECTOR_SIZE - 1, VECTOR_SIZE + 1, VECTOR_SIZE]) def test_monahan(self, size): con = duckdb.connect() con.execute(f"CREATE table t as select range a from range({size});") @@ -52,12 +52,12 @@ def test_fetch_df_chunk_parameter(self): # Return 2 vectors cur_chunk = query.fetch_df_chunk(2) - assert cur_chunk['a'][0] == 0 + assert cur_chunk["a"][0] == 0 assert len(cur_chunk) == VECTOR_SIZE * 2 # Return Default 1 vector cur_chunk = query.fetch_df_chunk() - assert cur_chunk['a'][0] == VECTOR_SIZE * 2 + assert cur_chunk["a"][0] == VECTOR_SIZE * 2 assert len(cur_chunk) == VECTOR_SIZE # Return 0 vectors @@ -69,7 +69,7 @@ def test_fetch_df_chunk_parameter(self): # Return more vectors than we have remaining cur_chunk = query.fetch_df_chunk(3) - assert cur_chunk['a'][0] == fetched + assert cur_chunk["a"][0] == fetched assert len(cur_chunk) == expected # These shouldn't throw errors (Just emmit empty chunks) @@ -88,5 +88,5 @@ def test_fetch_df_chunk_negative_parameter(self): query = con.execute("SELECT a FROM t") # Return -1 vector should not work - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): cur_chunk = query.fetch_df_chunk(-1) diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index 5727429f..e25a44ba 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -10,10 +10,10 @@ def compare_results(con, query, expected): expected = pd.DataFrame.from_dict(expected) unsorted_res = con.query(query).df() - print(unsorted_res, unsorted_res['a'][0].__class__) + print(unsorted_res, unsorted_res["a"][0].__class__) df_duck = con.query("select * from unsorted_res order by all").df() - print(df_duck, df_duck['a'][0].__class__) - print(expected, expected['a'][0].__class__) + print(df_duck, df_duck["a"][0].__class__) + print(expected, expected["a"][0].__class__) pd.testing.assert_frame_equal(df_duck, expected) @@ -147,7 +147,7 @@ def list_test_cases(): class TestFetchNested(object): - @pytest.mark.parametrize('query, expected', list_test_cases()) + @pytest.mark.parametrize("query, expected", list_test_cases()) def test_fetch_df_list(self, duckdb_cursor, query, expected): compare_results(duckdb_cursor, query, expected) diff --git a/tests/fast/pandas/test_implicit_pandas_scan.py b/tests/fast/pandas/test_implicit_pandas_scan.py index e6f0b9f4..2d4610ff 100644 --- a/tests/fast/pandas/test_implicit_pandas_scan.py +++ b/tests/fast/pandas/test_implicit_pandas_scan.py @@ -15,7 +15,7 @@ except: pyarrow_dtypes_enabled = False -if Version(pd.__version__) >= Version('2.0.0') and pyarrow_dtypes_enabled: +if Version(pd.__version__) >= Version("2.0.0") and pyarrow_dtypes_enabled: pyarrow_df = numpy_nullable_df.convert_dtypes(dtype_backend="pyarrow") else: # dtype_backend is not supported in pandas < 2.0.0 @@ -23,20 +23,20 @@ class TestImplicitPandasScan(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_local_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) - r1 = con.execute('select * from df').fetchdf() + r1 = con.execute("select * from df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_global_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() - r1 = con.execute(f'select * from {pandas.backend}_df').fetchdf() + r1 = con.execute(f"select * from {pandas.backend}_df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val4" assert r1["CoL2"][0] == 1.05 diff --git a/tests/fast/pandas/test_import_cache.py b/tests/fast/pandas/test_import_cache.py index 32eab7b0..6ed601c5 100644 --- a/tests/fast/pandas/test_import_cache.py +++ b/tests/fast/pandas/test_import_cache.py @@ -3,26 +3,26 @@ import pytest -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_import_cache_explicit_dtype(pandas): df = pandas.DataFrame( { - 'id': [1, 2, 3], - 'value': pandas.Series(['123.123', pandas.NaT, pandas.NA], dtype=pandas.StringDtype(storage='python')), + "id": [1, 2, 3], + "value": pandas.Series(["123.123", pandas.NaT, pandas.NA], dtype=pandas.StringDtype(storage="python")), } ) con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df['value'][1] is None - assert result_df['value'][2] is None + assert result_df["value"][1] is None + assert result_df["value"][2] is None -@pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) +@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_import_cache_implicit_dtype(pandas): - df = pandas.DataFrame({'id': [1, 2, 3], 'value': pandas.Series(['123.123', pandas.NaT, pandas.NA])}) + df = pandas.DataFrame({"id": [1, 2, 3], "value": pandas.Series(["123.123", pandas.NaT, pandas.NA])}) con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df['value'][1] is None - assert result_df['value'][2] is None + assert result_df["value"][1] is None + assert result_df["value"][2] is None diff --git a/tests/fast/pandas/test_issue_1767.py b/tests/fast/pandas/test_issue_1767.py index e37f19e1..27f0c2ff 100644 --- a/tests/fast/pandas/test_issue_1767.py +++ b/tests/fast/pandas/test_issue_1767.py @@ -9,7 +9,7 @@ # Join from pandas not matching identical strings #1767 class TestIssue1767(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_unicode_join_pandas(self, duckdb_cursor, pandas): A = pandas.DataFrame({"key": ["a", "п"]}) B = pandas.DataFrame({"key": ["a", "п"]}) @@ -18,6 +18,6 @@ def test_unicode_join_pandas(self, duckdb_cursor, pandas): q = arrow.query("""SELECT key FROM "A" FULL JOIN "B" USING ("key") ORDER BY key""") result = q.df() - d = {'key': ["a", "п"]} + d = {"key": ["a", "п"]} df = pandas.DataFrame(data=d) pandas.testing.assert_frame_equal(result, df) diff --git a/tests/fast/pandas/test_limit.py b/tests/fast/pandas/test_limit.py index 4a03c24f..460716cd 100644 --- a/tests/fast/pandas/test_limit.py +++ b/tests/fast/pandas/test_limit.py @@ -4,22 +4,22 @@ class TestLimitPandas(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_limit_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) limit_df = duckdb.limit(df_in, 2) assert len(limit_df.execute().fetchall()) == 2 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_aggregate_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( { - 'numbers': [1, 2, 2, 2], + "numbers": [1, 2, 2, 2], } ) - aggregate_df = duckdb.aggregate(df_in, 'count(numbers)', 'numbers').order('all') + aggregate_df = duckdb.aggregate(df_in, "count(numbers)", "numbers").order("all") assert aggregate_df.execute().fetchall() == [(1,), (3,)] diff --git a/tests/fast/pandas/test_pandas_arrow.py b/tests/fast/pandas/test_pandas_arrow.py index 8729362d..e1661041 100644 --- a/tests/fast/pandas/test_pandas_arrow.py +++ b/tests/fast/pandas/test_pandas_arrow.py @@ -4,7 +4,7 @@ from conftest import pandas_supports_arrow_backend -pd = pytest.importorskip("pandas", '2.0.0') +pd = pytest.importorskip("pandas", "2.0.0") import numpy as np from pandas.api.types import is_integer_dtype @@ -13,7 +13,7 @@ class TestPandasArrow(object): def test_pandas_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': pd.Series([5, 4, 3])}).convert_dtypes() + df = pd.DataFrame({"a": pd.Series([5, 4, 3])}).convert_dtypes() con = duckdb.connect() res = con.sql("select * from df").fetchall() assert res == [(5,), (4,), (3,)] @@ -21,8 +21,8 @@ def test_pandas_arrow(self, duckdb_cursor): def test_mixed_columns(self): df = pd.DataFrame( { - 'strings': pd.Series(['abc', 'DuckDB', 'quack', 'quack']), - 'timestamps': pd.Series( + "strings": pd.Series(["abc", "DuckDB", "quack", "quack"]), + "timestamps": pd.Series( [ datetime.datetime(1990, 10, 21), datetime.datetime(2023, 1, 11), @@ -30,23 +30,23 @@ def test_mixed_columns(self): datetime.datetime(1990, 10, 21), ] ), - 'objects': pd.Series([[5, 4, 3], 'test', None, {'a': 42}]), - 'integers': np.ndarray((4,), buffer=np.array([1, 2, 3, 4, 5]), offset=np.int_().itemsize, dtype=int), + "objects": pd.Series([[5, 4, 3], "test", None, {"a": 42}]), + "integers": np.ndarray((4,), buffer=np.array([1, 2, 3, 4, 5]), offset=np.int_().itemsize, dtype=int), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() with pytest.raises( - duckdb.InvalidInputException, match='The dataframe could not be converted to a pyarrow.lib.Table' + duckdb.InvalidInputException, match="The dataframe could not be converted to a pyarrow.lib.Table" ): - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() numpy_df = pd.DataFrame( - {'a': np.ndarray((2,), buffer=np.array([1, 2, 3]), offset=np.int_().itemsize, dtype=int)} - ).convert_dtypes(dtype_backend='numpy_nullable') + {"a": np.ndarray((2,), buffer=np.array([1, 2, 3]), offset=np.int_().itemsize, dtype=int)} + ).convert_dtypes(dtype_backend="numpy_nullable") arrow_df = pd.DataFrame( { - 'a': pd.Series( + "a": pd.Series( [ datetime.datetime(1990, 10, 21), datetime.datetime(2023, 1, 11), @@ -55,45 +55,45 @@ def test_mixed_columns(self): ] ) } - ).convert_dtypes(dtype_backend='pyarrow') - python_df = pd.DataFrame({'a': pd.Series(['test', [5, 4, 3], {'a': 42}])}).convert_dtypes() + ).convert_dtypes(dtype_backend="pyarrow") + python_df = pd.DataFrame({"a": pd.Series(["test", [5, 4, 3], {"a": 42}])}).convert_dtypes() - df = pd.concat([numpy_df['a'], arrow_df['a'], python_df['a']], axis=1, keys=['numpy', 'arrow', 'python']) - assert is_integer_dtype(df.dtypes['numpy']) - assert isinstance(df.dtypes['arrow'], pd.ArrowDtype) - assert isinstance(df.dtypes['python'], np.dtype('O').__class__) + df = pd.concat([numpy_df["a"], arrow_df["a"], python_df["a"]], axis=1, keys=["numpy", "arrow", "python"]) + assert is_integer_dtype(df.dtypes["numpy"]) + assert isinstance(df.dtypes["arrow"], pd.ArrowDtype) + assert isinstance(df.dtypes["python"], np.dtype("O").__class__) with pytest.raises( - duckdb.InvalidInputException, match='The dataframe could not be converted to a pyarrow.lib.Table' + duckdb.InvalidInputException, match="The dataframe could not be converted to a pyarrow.lib.Table" ): - res = con.sql('select * from df').fetchall() + res = con.sql("select * from df").fetchall() def test_empty_df(self): df = pd.DataFrame( { - 'string': pd.Series(data=[], dtype='string'), - 'object': pd.Series(data=[], dtype='object'), - 'Int64': pd.Series(data=[], dtype='Int64'), - 'Float64': pd.Series(data=[], dtype='Float64'), - 'bool': pd.Series(data=[], dtype='bool'), - 'datetime64[ns]': pd.Series(data=[], dtype='datetime64[ns]'), - 'datetime64[ms]': pd.Series(data=[], dtype='datetime64[ms]'), - 'datetime64[us]': pd.Series(data=[], dtype='datetime64[us]'), - 'datetime64[s]': pd.Series(data=[], dtype='datetime64[s]'), - 'category': pd.Series(data=[], dtype='category'), - 'timedelta64[ns]': pd.Series(data=[], dtype='timedelta64[ns]'), + "string": pd.Series(data=[], dtype="string"), + "object": pd.Series(data=[], dtype="object"), + "Int64": pd.Series(data=[], dtype="Int64"), + "Float64": pd.Series(data=[], dtype="Float64"), + "bool": pd.Series(data=[], dtype="bool"), + "datetime64[ns]": pd.Series(data=[], dtype="datetime64[ns]"), + "datetime64[ms]": pd.Series(data=[], dtype="datetime64[ms]"), + "datetime64[us]": pd.Series(data=[], dtype="datetime64[us]"), + "datetime64[s]": pd.Series(data=[], dtype="datetime64[s]"), + "category": pd.Series(data=[], dtype="category"), + "timedelta64[ns]": pd.Series(data=[], dtype="timedelta64[ns]"), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() assert res == [] def test_completely_null_df(self): df = pd.DataFrame( { - 'a': pd.Series( + "a": pd.Series( data=[ None, np.nan, @@ -102,35 +102,35 @@ def test_completely_null_df(self): ) } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchall() + res = con.sql("select * from pyarrow_df").fetchall() assert res == [(None,), (None,), (None,)] def test_mixed_nulls(self): df = pd.DataFrame( { - 'float': pd.Series(data=[4.123123, None, 7.23456], dtype='Float64'), - 'int64': pd.Series(data=[-234234124, 709329413, pd.NA], dtype='Int64'), - 'bool': pd.Series(data=[np.nan, True, False], dtype='boolean'), - 'string': pd.Series(data=['NULL', None, 'quack']), - 'list[str]': pd.Series(data=[['Huey', 'Dewey', 'Louie'], [None, pd.NA, np.nan, 'DuckDB'], None]), - 'datetime64': pd.Series( + "float": pd.Series(data=[4.123123, None, 7.23456], dtype="Float64"), + "int64": pd.Series(data=[-234234124, 709329413, pd.NA], dtype="Int64"), + "bool": pd.Series(data=[np.nan, True, False], dtype="boolean"), + "string": pd.Series(data=["NULL", None, "quack"]), + "list[str]": pd.Series(data=[["Huey", "Dewey", "Louie"], [None, pd.NA, np.nan, "DuckDB"], None]), + "datetime64": pd.Series( data=[datetime.datetime(2011, 8, 16, 22, 7, 8), None, datetime.datetime(2010, 4, 26, 18, 14, 14)] ), - 'date': pd.Series(data=[datetime.date(2008, 5, 28), datetime.date(2013, 7, 14), None]), + "date": pd.Series(data=[datetime.date(2008, 5, 28), datetime.date(2013, 7, 14), None]), } ) - pyarrow_df = df.convert_dtypes(dtype_backend='pyarrow') + pyarrow_df = df.convert_dtypes(dtype_backend="pyarrow") con = duckdb.connect() - res = con.sql('select * from pyarrow_df').fetchone() + res = con.sql("select * from pyarrow_df").fetchone() assert res == ( 4.123123, -234234124, None, - 'NULL', - ['Huey', 'Dewey', 'Louie'], + "NULL", + ["Huey", "Dewey", "Louie"], datetime.datetime(2011, 8, 16, 22, 7, 8), datetime.date(2008, 5, 28), ) diff --git a/tests/fast/pandas/test_pandas_category.py b/tests/fast/pandas/test_pandas_category.py index e86a97d9..4b29b3fb 100644 --- a/tests/fast/pandas/test_pandas_category.py +++ b/tests/fast/pandas/test_pandas_category.py @@ -7,7 +7,7 @@ def check_category_equal(category): df_in = pd.DataFrame( { - 'x': pd.Categorical(category, ordered=True), + "x": pd.Categorical(category, ordered=True), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() @@ -23,7 +23,7 @@ def check_create_table(category): conn = duckdb.connect() conn.execute("PRAGMA enable_verification") - df_in = pd.DataFrame({'x': pd.Categorical(category, ordered=True), 'y': pd.Categorical(category, ordered=True)}) + df_in = pd.DataFrame({"x": pd.Categorical(category, ordered=True), "y": pd.Categorical(category, ordered=True)}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() assert df_in.equals(df_out) @@ -39,7 +39,7 @@ def check_create_table(category): conn.execute("INSERT INTO t1 VALUES ('2','2')") res = conn.execute("SELECT x FROM t1 where x = '1'").fetchall() - assert res == [('1',)] + assert res == [("1",)] res = conn.execute("SELECT t1.x FROM t1 inner join t2 on (t1.x = t2.x)").fetchall() assert res == conn.execute("SELECT x FROM t1").fetchall() @@ -56,27 +56,27 @@ def check_create_table(category): class TestCategory(object): def test_category_simple(self, duckdb_cursor): - df_in = pd.DataFrame({'float': [1.0, 2.0, 1.0], 'int': pd.Series([1, 2, 1], dtype="category")}) + df_in = pd.DataFrame({"float": [1.0, 2.0, 1.0], "int": pd.Series([1, 2, 1], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() print(duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) - print(df_out['int']) - assert numpy.all(df_out['float'] == numpy.array([1.0, 2.0, 1.0])) - assert numpy.all(df_out['int'] == numpy.array([1, 2, 1])) + print(df_out["int"]) + assert numpy.all(df_out["float"] == numpy.array([1.0, 2.0, 1.0])) + assert numpy.all(df_out["int"] == numpy.array([1, 2, 1])) def test_category_nulls(self, duckdb_cursor): - df_in = pd.DataFrame({'int': pd.Series([1, 2, None], dtype="category")}) + df_in = pd.DataFrame({"int": pd.Series([1, 2, None], dtype="category")}) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() print(duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall()) - assert df_out['int'][0] == 1 - assert df_out['int'][1] == 2 - assert pd.isna(df_out['int'][2]) + assert df_out["int"][0] == 1 + assert df_out["int"][1] == 2 + assert pd.isna(df_out["int"][2]) def test_category_string(self, duckdb_cursor): - check_category_equal(['foo', 'bla', 'zoo', 'foo', 'foo', 'bla']) + check_category_equal(["foo", "bla", "zoo", "foo", "foo", "bla"]) def test_category_string_null(self, duckdb_cursor): - check_category_equal(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla']) + check_category_equal(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"]) def test_category_string_null_bug_4747(self, duckdb_cursor): check_category_equal([str(i) for i in range(160)] + [None]) @@ -84,18 +84,18 @@ def test_category_string_null_bug_4747(self, duckdb_cursor): def test_categorical_fetchall(self, duckdb_cursor): df_in = pd.DataFrame( { - 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + "x": pd.Categorical(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"], ordered=True), } ) assert duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchall() == [ - ('foo',), - ('bla',), + ("foo",), + ("bla",), (None,), - ('zoo',), - ('foo',), - ('foo',), + ("zoo",), + ("foo",), + ("foo",), (None,), - ('bla',), + ("bla",), ] def test_category_string_uint8(self, duckdb_cursor): @@ -105,30 +105,30 @@ def test_category_string_uint8(self, duckdb_cursor): check_create_table(category) def test_empty_categorical(self, duckdb_cursor): - empty_categoric_df = pd.DataFrame({'category': pd.Series(dtype='category')}) + empty_categoric_df = pd.DataFrame({"category": pd.Series(dtype="category")}) duckdb_cursor.execute("CREATE TABLE test AS SELECT * FROM empty_categoric_df") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [] with pytest.raises(duckdb.ConversionException, match="Could not convert string 'test' to UINT8"): duckdb_cursor.execute("insert into test VALUES('test')") duckdb_cursor.execute("insert into test VALUES(NULL)") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [(None,)] def test_category_fetch_df_chunk(self, duckdb_cursor): con = duckdb.connect() - categories = ['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'] + categories = ["foo", "bla", None, "zoo", "foo", "foo", None, "bla"] result = categories * 256 categories = result * 2 df_result = pd.DataFrame( { - 'x': pd.Categorical(result, ordered=True), + "x": pd.Categorical(result, ordered=True), } ) df_in = pd.DataFrame( { - 'x': pd.Categorical(categories, ordered=True), + "x": pd.Categorical(categories, ordered=True), } ) con.register("data", df_in) @@ -146,8 +146,8 @@ def test_category_fetch_df_chunk(self, duckdb_cursor): def test_category_mix(self, duckdb_cursor): df_in = pd.DataFrame( { - 'float': [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 0.0], - 'x': pd.Categorical(['foo', 'bla', None, 'zoo', 'foo', 'foo', None, 'bla'], ordered=True), + "float": [1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 0.0], + "x": pd.Categorical(["foo", "bla", None, "zoo", "foo", "foo", None, "bla"], ordered=True), } ) diff --git a/tests/fast/pandas/test_pandas_enum.py b/tests/fast/pandas/test_pandas_enum.py index 9dc13a64..b1eb2c7f 100644 --- a/tests/fast/pandas/test_pandas_enum.py +++ b/tests/fast/pandas/test_pandas_enum.py @@ -15,7 +15,7 @@ def test_3480(self, duckdb_cursor): """ ) df = duckdb_cursor.query(f"SELECT * FROM tab LIMIT 0;").to_df() - assert df["cat"].cat.categories.equals(pd.Index(['marie', 'duchess', 'toulouse'])) + assert df["cat"].cat.categories.equals(pd.Index(["marie", "duchess", "toulouse"])) duckdb_cursor.execute("DROP TABLE tab") duckdb_cursor.execute("DROP TYPE cat") @@ -32,14 +32,14 @@ def test_3479(self, duckdb_cursor): df = pd.DataFrame( { - "cat2": pd.Series(['duchess', 'toulouse', 'marie', None, "berlioz", "o_malley"], dtype="category"), + "cat2": pd.Series(["duchess", "toulouse", "marie", None, "berlioz", "o_malley"], dtype="category"), "amt": [1, 2, 3, 4, 5, 6], } ) - duckdb_cursor.register('df', df) + duckdb_cursor.register("df", df) with pytest.raises( duckdb.ConversionException, - match='Type UINT8 with value 0 can\'t be cast because the value is out of range for the destination type UINT8', + match="Type UINT8 with value 0 can't be cast because the value is out of range for the destination type UINT8", ): duckdb_cursor.execute(f"INSERT INTO tab SELECT * FROM df;") diff --git a/tests/fast/pandas/test_pandas_limit.py b/tests/fast/pandas/test_pandas_limit.py index 506d5dd5..d551a6e4 100644 --- a/tests/fast/pandas/test_pandas_limit.py +++ b/tests/fast/pandas/test_pandas_limit.py @@ -6,9 +6,9 @@ class TestPandasLimit(object): def test_pandas_limit(self, duckdb_cursor): con = duckdb.connect() - df = con.execute('select * from range(10000000) tbl(i)').df() + df = con.execute("select * from range(10000000) tbl(i)").df() - con.execute('SET threads=8') + con.execute("SET threads=8") - limit_df = con.execute('SELECT * FROM df WHERE i=334 OR i>9967864 LIMIT 5').df() - assert list(limit_df['i']) == [334, 9967865, 9967866, 9967867, 9967868] + limit_df = con.execute("SELECT * FROM df WHERE i=334 OR i>9967864 LIMIT 5").df() + assert list(limit_df["i"]) == [334, 9967865, 9967866, 9967867, 9967868] diff --git a/tests/fast/pandas/test_pandas_na.py b/tests/fast/pandas/test_pandas_na.py index f165d180..7bc01003 100644 --- a/tests/fast/pandas/test_pandas_na.py +++ b/tests/fast/pandas/test_pandas_na.py @@ -16,20 +16,20 @@ def assert_nullness(items, null_indices): @pytest.mark.skipif(platform.system() == "Emscripten", reason="Pandas interaction is broken in Pyodide 3.11") class TestPandasNA(object): - @pytest.mark.parametrize('rows', [100, duckdb.__standard_vector_size__, 5000, 1000000]) - @pytest.mark.parametrize('pd', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("rows", [100, duckdb.__standard_vector_size__, 5000, 1000000]) + @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_pandas_string_null(self, duckdb_cursor, rows, pd): df: pd.DataFrame = pd.DataFrame(index=np.arange(rows)) df["string_column"] = pd.Series(dtype="string") e_df_rel = duckdb_cursor.from_df(df) - assert e_df_rel.types == ['VARCHAR'] + assert e_df_rel.types == ["VARCHAR"] roundtrip = e_df_rel.df() - assert roundtrip['string_column'].dtype == 'object' - expected = pd.DataFrame({'string_column': [None for _ in range(rows)]}) + assert roundtrip["string_column"].dtype == "object" + expected = pd.DataFrame({"string_column": [None for _ in range(rows)]}) pd.testing.assert_frame_equal(expected, roundtrip) def test_pandas_na(self, duckdb_cursor): - pd = pytest.importorskip('pandas', minversion='1.0.0', reason='Support for pandas.NA has not been added yet') + pd = pytest.importorskip("pandas", minversion="1.0.0", reason="Support for pandas.NA has not been added yet") # DataFrame containing a single pd.NA df = pd.DataFrame(pd.Series([pd.NA])) @@ -46,7 +46,7 @@ def test_pandas_na(self, duckdb_cursor): # Test if pd.NA behaves the same as np.nan once converted nan_df = pd.DataFrame( { - 'a': [ + "a": [ 1.123, 5.23234, np.nan, @@ -60,7 +60,7 @@ def test_pandas_na(self, duckdb_cursor): ) na_df = pd.DataFrame( { - 'a': [ + "a": [ 1.123, 5.23234, pd.NA, @@ -72,15 +72,15 @@ def test_pandas_na(self, duckdb_cursor): ] } ) - assert str(nan_df['a'].dtype) == 'float64' - assert str(na_df['a'].dtype) == 'object' # pd.NA values turn the column into 'object' + assert str(nan_df["a"].dtype) == "float64" + assert str(na_df["a"].dtype) == "object" # pd.NA values turn the column into 'object' nan_result = duckdb_cursor.execute("select * from nan_df").df() na_result = duckdb_cursor.execute("select * from na_df").df() pd.testing.assert_frame_equal(nan_result, na_result) # Mixed with stringified pd.NA values - na_string_df = pd.DataFrame({'a': [str(pd.NA), str(pd.NA), pd.NA, str(pd.NA), pd.NA, pd.NA, pd.NA, str(pd.NA)]}) + na_string_df = pd.DataFrame({"a": [str(pd.NA), str(pd.NA), pd.NA, str(pd.NA), pd.NA, pd.NA, pd.NA, str(pd.NA)]}) null_indices = [2, 4, 5, 6] res = duckdb_cursor.execute("select * from na_string_df").fetchall() items = [x[0] for x in [y for y in res]] diff --git a/tests/fast/pandas/test_pandas_object.py b/tests/fast/pandas/test_pandas_object.py index c00fcbc2..9e10681c 100644 --- a/tests/fast/pandas/test_pandas_object.py +++ b/tests/fast/pandas/test_pandas_object.py @@ -9,22 +9,22 @@ class TestPandasObject(object): def test_object_lotof_nulls(self): # Test mostly null column data = [None] + [1] + [None] * 10000 # Last element is 1, others are None - pandas_df = pd.DataFrame(data, columns=['c'], dtype=object) + pandas_df = pd.DataFrame(data, columns=["c"], dtype=object) con = duckdb.connect() - assert con.execute('FROM pandas_df where c is not null').fetchall() == [(1.0,)] + assert con.execute("FROM pandas_df where c is not null").fetchall() == [(1.0,)] # Test all nulls, should return varchar data = [None] * 10000 # Last element is 1, others are None - pandas_df_2 = pd.DataFrame(data, columns=['c'], dtype=object) - assert con.execute('FROM pandas_df_2 limit 1').fetchall() == [(None,)] - assert con.execute('select typeof(c) FROM pandas_df_2 limit 1').fetchall() == [('"NULL"',)] + pandas_df_2 = pd.DataFrame(data, columns=["c"], dtype=object) + assert con.execute("FROM pandas_df_2 limit 1").fetchall() == [(None,)] + assert con.execute("select typeof(c) FROM pandas_df_2 limit 1").fetchall() == [('"NULL"',)] def test_object_to_string(self, duckdb_cursor): - con = duckdb.connect(database=':memory:', read_only=False) - x = pd.DataFrame([[1, 'a', 2], [1, None, 2], [1, 1.1, 2], [1, 1.1, 2], [1, 1.1, 2]]) + con = duckdb.connect(database=":memory:", read_only=False) + x = pd.DataFrame([[1, "a", 2], [1, None, 2], [1, 1.1, 2], [1, 1.1, 2], [1, 1.1, 2]]) x = x.iloc[1:].copy() # middle col now entirely native float items - con.register('view2', x) - df = con.execute('select * from view2').fetchall() + con.register("view2", x) + df = con.execute("select * from view2").fetchall() assert df == [(1, None, 2), (1, 1.1, 2), (1, 1.1, 2), (1, 1.1, 2)] def test_tuple_to_list(self, duckdb_cursor): @@ -45,7 +45,7 @@ def test_tuple_to_list(self, duckdb_cursor): ) ) duckdb_cursor.execute("CREATE TABLE test as SELECT * FROM tuple_df") - res = duckdb_cursor.table('test').fetchall() + res = duckdb_cursor.table("test").fetchall() assert res == [([1, 2, 3],), ([4, 5, 6],)] def test_2273(self, duckdb_cursor): @@ -56,8 +56,8 @@ def test_object_to_string_with_stride(self, duckdb_cursor): data = np.array([["a", "b", "c"], [1, 2, 3], [1, 2, 3], [11, 22, 33]]) df = pd.DataFrame(data=data[1:,], columns=data[0]) duckdb_cursor.register("object_with_strides", df) - res = duckdb_cursor.sql('select * from object_with_strides').fetchall() - assert res == [('1', '2', '3'), ('1', '2', '3'), ('11', '22', '33')] + res = duckdb_cursor.sql("select * from object_with_strides").fetchall() + assert res == [("1", "2", "3"), ("1", "2", "3"), ("11", "22", "33")] def test_2499(self, duckdb_cursor): df = pd.DataFrame( @@ -65,11 +65,11 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.881040697801939}, - {'a': 0.9922600577751953}, - {'a': 0.1589674833259317}, - {'a': 0.8928451262745073}, - {'a': 0.07022897889168278}, + {"a": 0.881040697801939}, + {"a": 0.9922600577751953}, + {"a": 0.1589674833259317}, + {"a": 0.8928451262745073}, + {"a": 0.07022897889168278}, ], dtype=object, ) @@ -77,11 +77,11 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.8759413504156746}, - {'a': 0.055784331256246156}, - {'a': 0.8605151517439655}, - {'a': 0.40807139339337695}, - {'a': 0.8429048322459952}, + {"a": 0.8759413504156746}, + {"a": 0.055784331256246156}, + {"a": 0.8605151517439655}, + {"a": 0.40807139339337695}, + {"a": 0.8429048322459952}, ], dtype=object, ) @@ -89,19 +89,19 @@ def test_2499(self, duckdb_cursor): [ np.array( [ - {'a': 0.9697093934032401}, - {'a': 0.9529257667149468}, - {'a': 0.21398182248591713}, - {'a': 0.6328512122275955}, - {'a': 0.5146953214092728}, + {"a": 0.9697093934032401}, + {"a": 0.9529257667149468}, + {"a": 0.21398182248591713}, + {"a": 0.6328512122275955}, + {"a": 0.5146953214092728}, ], dtype=object, ) ], ], - columns=['col'], + columns=["col"], ) - con = duckdb.connect(database=':memory:', read_only=False) - con.register('df', df) - assert con.execute('select count(*) from df').fetchone() == (3,) + con = duckdb.connect(database=":memory:", read_only=False) + con.register("df", df) + assert con.execute("select count(*) from df").fetchone() == (3,) diff --git a/tests/fast/pandas/test_pandas_string.py b/tests/fast/pandas/test_pandas_string.py index 494823ad..4bd5996d 100644 --- a/tests/fast/pandas/test_pandas_string.py +++ b/tests/fast/pandas/test_pandas_string.py @@ -5,23 +5,23 @@ class TestPandasString(object): def test_pandas_string(self, duckdb_cursor): - strings = numpy.array(['foo', 'bar', 'baz']) + strings = numpy.array(["foo", "bar", "baz"]) # https://pandas.pydata.org/pandas-docs/stable/user_guide/text.html df_in = pd.DataFrame( { - 'object': pd.Series(strings, dtype='object'), + "object": pd.Series(strings, dtype="object"), } ) # Only available in pandas 1.0.0 - if hasattr(pd, 'StringDtype'): - df_in['string'] = pd.Series(strings, dtype=pd.StringDtype()) + if hasattr(pd, "StringDtype"): + df_in["string"] = pd.Series(strings, dtype=pd.StringDtype()) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert numpy.all(df_out['object'] == strings) - if hasattr(pd, 'StringDtype'): - assert numpy.all(df_out['string'] == strings) + assert numpy.all(df_out["object"] == strings) + if hasattr(pd, "StringDtype"): + assert numpy.all(df_out["string"] == strings) def test_bug_2467(self, duckdb_cursor): N = 1_000_000 @@ -35,11 +35,8 @@ def test_bug_2467(self, duckdb_cursor): CREATE TABLE t1 AS SELECT * FROM df """ ) - assert ( - con.execute( - f""" + assert con.execute( + f""" SELECT count(*) from t1 """ - ).fetchall() - == [(3000000,)] - ) + ).fetchall() == [(3000000,)] diff --git a/tests/fast/pandas/test_pandas_timestamp.py b/tests/fast/pandas/test_pandas_timestamp.py index 8e17db21..835ff3af 100644 --- a/tests/fast/pandas/test_pandas_timestamp.py +++ b/tests/fast/pandas/test_pandas_timestamp.py @@ -7,30 +7,30 @@ from conftest import pandas_2_or_higher -@pytest.mark.parametrize('timezone', ['UTC', 'CET', 'Asia/Kathmandu']) +@pytest.mark.parametrize("timezone", ["UTC", "CET", "Asia/Kathmandu"]) @pytest.mark.skipif(not pandas_2_or_higher(), reason="Pandas <2.0.0 does not support timezones in the metadata string") def test_run_pandas_with_tz(timezone): con = duckdb.connect() con.execute(f"SET TimeZone = '{timezone}'") df = pandas.DataFrame( { - 'timestamp': pandas.Series( - data=[pandas.Timestamp(year=2022, month=1, day=1, hour=10, minute=15, tz=timezone, unit='us')], - dtype=f'datetime64[us, {timezone}]', + "timestamp": pandas.Series( + data=[pandas.Timestamp(year=2022, month=1, day=1, hour=10, minute=15, tz=timezone, unit="us")], + dtype=f"datetime64[us, {timezone}]", ) } ) duck_df = con.from_df(df).df() - assert duck_df['timestamp'][0] == df['timestamp'][0] + assert duck_df["timestamp"][0] == df["timestamp"][0] def test_timestamp_conversion(duckdb_cursor): - tzinfo = pandas.Timestamp('2024-01-01 00:00:00+0100', tz='Europe/Copenhagen').tzinfo + tzinfo = pandas.Timestamp("2024-01-01 00:00:00+0100", tz="Europe/Copenhagen").tzinfo ts_df = pandas.DataFrame( { "ts": [ - pandas.Timestamp('2024-01-01 00:00:00+0100', tz=tzinfo), - pandas.Timestamp('2024-01-02 00:00:00+0100', tz=tzinfo), + pandas.Timestamp("2024-01-01 00:00:00+0100", tz=tzinfo), + pandas.Timestamp("2024-01-02 00:00:00+0100", tz=tzinfo), ] } ) diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index b21c7f14..fcc63b82 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -11,7 +11,7 @@ def round_trip(data, pandas_type): df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype=pandas_type), + "object": pd.Series(data, dtype=pandas_type), } ) @@ -23,7 +23,7 @@ def round_trip(data, pandas_type): class TestNumpyNullableTypes(object): def test_pandas_numeric(self): - base_df = pd.DataFrame({'a': range(10)}) + base_df = pd.DataFrame({"a": range(10)}) data_types = [ "uint8", @@ -46,7 +46,7 @@ def test_pandas_numeric(self): "float64", ] - if version.parse(pd.__version__) >= version.parse('1.2.0'): + if version.parse(pd.__version__) >= version.parse("1.2.0"): # These DTypes where added in 1.2.0 data_types.extend(["Float32", "Float64"]) # Generate a dataframe with all the types, in the form of: @@ -59,7 +59,7 @@ def test_pandas_numeric(self): df = pd.DataFrame.from_dict(data) conn = duckdb.connect() - out_df = conn.execute('select * from df').df() + out_df = conn.execute("select * from df").df() # Verify that the types in the out_df are correct # FIXME: we don't support outputting pandas specific types (i.e UInt64) @@ -68,14 +68,14 @@ def test_pandas_numeric(self): assert str(out_df[column_name].dtype) == item.lower() def test_pandas_unsigned(self, duckdb_cursor): - unsigned_types = ['uint8', 'uint16', 'uint32', 'uint64'] + unsigned_types = ["uint8", "uint16", "uint32", "uint64"] data = numpy.array([0, 1, 2, 3]) for u_type in unsigned_types: round_trip(data, u_type) def test_pandas_bool(self, duckdb_cursor): data = numpy.array([True, False, False, True]) - round_trip(data, 'bool') + round_trip(data, "bool") def test_pandas_masked_float64(self, duckdb_cursor, tmp_path): pa = pytest.importorskip("pyarrow") @@ -102,85 +102,85 @@ def test_pandas_boolean(self, duckdb_cursor): data = numpy.array([True, None, pd.NA, numpy.nan, True]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='boolean'), + "object": pd.Series(data, dtype="boolean"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert pd.isna(df_out['object'][1]) - assert pd.isna(df_out['object'][2]) - assert pd.isna(df_out['object'][3]) - assert df_out['object'][4] == df_in['object'][4] + assert df_out["object"][0] == df_in["object"][0] + assert pd.isna(df_out["object"][1]) + assert pd.isna(df_out["object"][2]) + assert pd.isna(df_out["object"][3]) + assert df_out["object"][4] == df_in["object"][4] def test_pandas_float32(self, duckdb_cursor): data = numpy.array([0.1, 0.32, 0.78, numpy.nan]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='float32'), + "object": pd.Series(data, dtype="float32"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert df_out['object'][1] == df_in['object'][1] - assert df_out['object'][2] == df_in['object'][2] - assert pd.isna(df_out['object'][3]) + assert df_out["object"][0] == df_in["object"][0] + assert df_out["object"][1] == df_in["object"][1] + assert df_out["object"][2] == df_in["object"][2] + assert pd.isna(df_out["object"][3]) def test_pandas_float64(self): - data = numpy.array([0.233, numpy.nan, 3456.2341231, float('-inf'), -23424.45345, float('+inf'), 0.0000000001]) + data = numpy.array([0.233, numpy.nan, 3456.2341231, float("-inf"), -23424.45345, float("+inf"), 0.0000000001]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='float64'), + "object": pd.Series(data, dtype="float64"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() for i in range(len(data)): - if pd.isna(df_out['object'][i]): + if pd.isna(df_out["object"][i]): assert i == 1 continue - assert df_out['object'][i] == df_in['object'][i] + assert df_out["object"][i] == df_in["object"][i] def test_pandas_interval(self, duckdb_cursor): - if pd.__version__ != '1.2.4': + if pd.__version__ != "1.2.4": return data = numpy.array([2069211000000000, numpy.datetime64("NaT")]) df_in = pd.DataFrame( { - 'object': pd.Series(data, dtype='timedelta64[ns]'), + "object": pd.Series(data, dtype="timedelta64[ns]"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() - assert df_out['object'][0] == df_in['object'][0] - assert pd.isnull(df_out['object'][1]) + assert df_out["object"][0] == df_in["object"][0] + assert pd.isnull(df_out["object"][1]) def test_pandas_encoded_utf8(self, duckdb_cursor): - data = u'\u00c3' # Unicode data - data = [data.encode('utf8')] + data = "\u00c3" # Unicode data + data = [data.encode("utf8")] expected_result = data[0] - df_in = pd.DataFrame({'object': pd.Series(data, dtype='object')}) + df_in = pd.DataFrame({"object": pd.Series(data, dtype="object")}) result = duckdb.query_df(df_in, "data", "SELECT * FROM data").fetchone()[0] assert result == expected_result @pytest.mark.parametrize( - 'dtype', + "dtype", [ - 'bool', - 'utinyint', - 'usmallint', - 'uinteger', - 'ubigint', - 'tinyint', - 'smallint', - 'integer', - 'bigint', - 'float', - 'double', + "bool", + "utinyint", + "usmallint", + "uinteger", + "ubigint", + "tinyint", + "smallint", + "integer", + "bigint", + "float", + "double", ], ) def test_producing_nullable_dtypes(self, duckdb_cursor, dtype): @@ -190,19 +190,19 @@ def __init__(self, value, expected_dtype) -> None: self.expected_dtype = expected_dtype inputs = { - 'bool': Input('true', 'BooleanDtype'), - 'utinyint': Input('255', 'UInt8Dtype'), - 'usmallint': Input('65535', 'UInt16Dtype'), - 'uinteger': Input('4294967295', 'UInt32Dtype'), - 'ubigint': Input('18446744073709551615', 'UInt64Dtype'), - 'tinyint': Input('-128', 'Int8Dtype'), - 'smallint': Input('-32768', 'Int16Dtype'), - 'integer': Input('-2147483648', 'Int32Dtype'), - 'bigint': Input('-9223372036854775808', 'Int64Dtype'), - 'float': Input('268043421344044473239570760152672894976.0000000000', 'float32'), - 'double': Input( - '14303088389124869511075243108389716684037132417196499782261853698893384831666205572097390431189931733040903060865714975797777061496396865611606109149583360363636503436181348332896211726552694379264498632046075093077887837955077425420408952536212326792778411457460885268567735875437456412217418386401944141824.0000000000', - 'float64', + "bool": Input("true", "BooleanDtype"), + "utinyint": Input("255", "UInt8Dtype"), + "usmallint": Input("65535", "UInt16Dtype"), + "uinteger": Input("4294967295", "UInt32Dtype"), + "ubigint": Input("18446744073709551615", "UInt64Dtype"), + "tinyint": Input("-128", "Int8Dtype"), + "smallint": Input("-32768", "Int16Dtype"), + "integer": Input("-2147483648", "Int32Dtype"), + "bigint": Input("-9223372036854775808", "Int64Dtype"), + "float": Input("268043421344044473239570760152672894976.0000000000", "float32"), + "double": Input( + "14303088389124869511075243108389716684037132417196499782261853698893384831666205572097390431189931733040903060865714975797777061496396865611606109149583360363636503436181348332896211726552694379264498632046075093077887837955077425420408952536212326792778411457460885268567735875437456412217418386401944141824.0000000000", + "float64", ), } @@ -222,7 +222,7 @@ def __init__(self, value, expected_dtype) -> None: rel = duckdb_cursor.sql(query) # Pandas <= 2.2.3 does not convert without throwing a warning - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) with suppress(TypeError): df = rel.df() warnings.resetwarnings() @@ -231,4 +231,4 @@ def __init__(self, value, expected_dtype) -> None: expected_dtype = getattr(pd, input.expected_dtype) else: expected_dtype = numpy.dtype(input.expected_dtype) - assert isinstance(df['a'].dtype, expected_dtype) + assert isinstance(df["a"].dtype, expected_dtype) diff --git a/tests/fast/pandas/test_pandas_unregister.py b/tests/fast/pandas/test_pandas_unregister.py index 794e5910..fce8f42a 100644 --- a/tests/fast/pandas/test_pandas_unregister.py +++ b/tests/fast/pandas/test_pandas_unregister.py @@ -8,7 +8,7 @@ class TestPandasUnregister(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister1(self, duckdb_cursor, pandas): df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) connection = duckdb.connect(":memory:") @@ -16,13 +16,13 @@ def test_pandas_unregister1(self, duckdb_cursor, pandas): df2 = connection.execute("SELECT * FROM dataframe;").fetchdf() connection.unregister("dataframe") - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() - with pytest.raises(duckdb.CatalogException, match='View with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="View with name dataframe does not exist"): connection.execute("DROP VIEW dataframe;") connection.execute("DROP VIEW IF EXISTS dataframe;") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister2(self, duckdb_cursor, pandas): fd, db = tempfile.mkstemp() os.close(fd) @@ -39,7 +39,7 @@ def test_pandas_unregister2(self, duckdb_cursor, pandas): connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() connection.close() @@ -50,6 +50,6 @@ def test_pandas_unregister2(self, duckdb_cursor, pandas): # Reconnecting after DataFrame freed. connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 - with pytest.raises(duckdb.CatalogException, match='Table with name dataframe does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name dataframe does not exist"): connection.execute("SELECT * FROM dataframe;").fetchdf() connection.close() diff --git a/tests/fast/pandas/test_pandas_update.py b/tests/fast/pandas/test_pandas_update.py index 663d6da2..86d17154 100644 --- a/tests/fast/pandas/test_pandas_update.py +++ b/tests/fast/pandas/test_pandas_update.py @@ -4,10 +4,10 @@ class TestPandasUpdateList(object): def test_pandas_update_list(self, duckdb_cursor): - duckdb_cursor = duckdb.connect(':memory:') - duckdb_cursor.execute('create table t (l int[])') - duckdb_cursor.execute('insert into t values ([1, 2]), ([3,4])') - duckdb_cursor.execute('update t set l = [5, 6]') - expected = pd.DataFrame({'l': [[5, 6], [5, 6]]}) - res = duckdb_cursor.execute('select * from t').fetchdf() + duckdb_cursor = duckdb.connect(":memory:") + duckdb_cursor.execute("create table t (l int[])") + duckdb_cursor.execute("insert into t values ([1, 2]), ([3,4])") + duckdb_cursor.execute("update t set l = [5, 6]") + expected = pd.DataFrame({"l": [[5, 6], [5, 6]]}) + res = duckdb_cursor.execute("select * from t").fetchdf() pd.testing.assert_frame_equal(expected, res) diff --git a/tests/fast/pandas/test_parallel_pandas_scan.py b/tests/fast/pandas/test_parallel_pandas_scan.py index a9fd99b9..d113bbca 100644 --- a/tests/fast/pandas/test_parallel_pandas_scan.py +++ b/tests/fast/pandas/test_parallel_pandas_scan.py @@ -24,8 +24,8 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera try: duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") - duckdb_conn.register('main_table', main_table) - duckdb_conn.register('left_join_table', left_join_table) + duckdb_conn.register("main_table", main_table) + duckdb_conn.register("left_join_table", left_join_table) output_df = duckdb_conn.execute(sql).fetchdf() pandas.testing.assert_frame_equal(expected_df, output_df) print(output_df) @@ -36,69 +36,69 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera class TestParallelPandasScan(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_scan(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": 3}]) left_join_table = pandas.DataFrame([{"join_column": 3, "other_column": 4}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_ascii_text(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": "text"}]) left_join_table = pandas.DataFrame([{"join_column": "text", "other_column": "more text"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"mühleisen"}]) - left_join_table = pandas.DataFrame([{"join_column": u"mühleisen", "other_column": u"höhöhö"}]) + main_table = pandas.DataFrame([{"join_column": "mühleisen"}]) + left_join_table = pandas.DataFrame([{"join_column": "mühleisen", "other_column": "höhöhö"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_complex_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"鴨"}]) - left_join_table = pandas.DataFrame([{"join_column": u"鴨", "other_column": u"數據庫"}]) + main_table = pandas.DataFrame([{"join_column": "鴨"}]) + left_join_table = pandas.DataFrame([{"join_column": "鴨", "other_column": "數據庫"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_emojis(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) - left_join_table = pandas.DataFrame([{"join_column": u"🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": u"🦆🍞🦆"}]) + main_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) + left_join_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": "🦆🍞🦆"}]) run_parallel_queries(main_table, left_join_table, left_join_table, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_object(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({'join_column': pandas.Series([3], dtype="Int8")}) + main_table = pandas.DataFrame({"join_column": pandas.Series([3], dtype="Int8")}) left_join_table = pandas.DataFrame( - {'join_column': pandas.Series([3], dtype="Int8"), 'other_column': pandas.Series([4], dtype="Int8")} + {"join_column": pandas.Series([3], dtype="Int8"), "other_column": pandas.Series([4], dtype="Int8")} ) expected_df = pandas.DataFrame( {"join_column": numpy.array([3], dtype=numpy.int8), "other_column": numpy.array([4], dtype=numpy.int8)} ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_timestamp(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({'join_column': [pandas.Timestamp('20180310T11:17:54Z')]}) + main_table = pandas.DataFrame({"join_column": [pandas.Timestamp("20180310T11:17:54Z")]}) left_join_table = pandas.DataFrame( { - 'join_column': [pandas.Timestamp('20180310T11:17:54Z')], - 'other_column': [pandas.Timestamp('20190310T11:17:54Z')], + "join_column": [pandas.Timestamp("20180310T11:17:54Z")], + "other_column": [pandas.Timestamp("20190310T11:17:54Z")], } ) expected_df = pandas.DataFrame( { - "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), - "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype='datetime64[ns]'), + "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), + "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), } ) run_parallel_queries(main_table, left_join_table, expected_df, pandas) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_empty(self, duckdb_cursor, pandas): - df_empty = pandas.DataFrame({'A': []}) + df_empty = pandas.DataFrame({"A": []}) duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") - duckdb_conn.register('main_table', df_empty) - assert duckdb_conn.execute('select * from main_table').fetchall() == [] + duckdb_conn.register("main_table", df_empty) + assert duckdb_conn.execute("select * from main_table").fetchall() == [] diff --git a/tests/fast/pandas/test_partitioned_pandas_scan.py b/tests/fast/pandas/test_partitioned_pandas_scan.py index 32c5352f..d2447ef8 100644 --- a/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -8,9 +8,9 @@ class TestPartitionedPandasScan(object): def test_parallel_pandas(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) + con.register("df", df) seq_results = con.execute("SELECT SUM(i) FROM df").fetchall() diff --git a/tests/fast/pandas/test_progress_bar.py b/tests/fast/pandas/test_progress_bar.py index 241cedd6..7c1c21e1 100644 --- a/tests/fast/pandas/test_progress_bar.py +++ b/tests/fast/pandas/test_progress_bar.py @@ -8,10 +8,10 @@ class TestProgressBarPandas(object): def test_progress_pandas_single(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) - con.register('df_2', df) + con.register("df", df) + con.register("df_2", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") result = con.execute("SELECT SUM(df.i) FROM df inner join df_2 on (df.i = df_2.i)").fetchall() @@ -19,10 +19,10 @@ def test_progress_pandas_single(self, duckdb_cursor): def test_progress_pandas_parallel(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': numpy.arange(10000000)}) + df = pd.DataFrame({"i": numpy.arange(10000000)}) - con.register('df', df) - con.register('df_2', df) + con.register("df", df) + con.register("df_2", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") con.execute("PRAGMA threads=4") @@ -31,8 +31,8 @@ def test_progress_pandas_parallel(self, duckdb_cursor): def test_progress_pandas_empty(self, duckdb_cursor): con = duckdb.connect() - df = pd.DataFrame({'i': []}) - con.register('df', df) + df = pd.DataFrame({"i": []}) + con.register("df", df) con.execute("PRAGMA progress_bar_time=1") con.execute("PRAGMA disable_print_progress_bar") result = con.execute("SELECT SUM(df.i) from df").fetchall() diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index e693e75c..b04f713a 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -6,7 +6,7 @@ pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") -_ = pytest.importorskip("pandas", '2.0.0') +_ = pytest.importorskip("pandas", "2.0.0") @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") @@ -16,6 +16,6 @@ def test_projection_pushdown_no_filter(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE test (a INTEGER, b INTEGER, c INTEGER)") duckdb_conn.execute("INSERT INTO test VALUES (1,1,1),(10,10,10),(100,10,100),(NULL,NULL,NULL)") duck_tbl = duckdb_conn.table("test") - arrow_table = duck_tbl.df().convert_dtypes(dtype_backend='pyarrow') + arrow_table = duck_tbl.df().convert_dtypes(dtype_backend="pyarrow") duckdb_conn.register("testarrowtable", arrow_table) assert duckdb_conn.execute("SELECT sum(a) FROM testarrowtable").fetchall() == [(111,)] diff --git a/tests/fast/pandas/test_same_name.py b/tests/fast/pandas/test_same_name.py index f48eb7eb..ac4f407a 100644 --- a/tests/fast/pandas/test_same_name.py +++ b/tests/fast/pandas/test_same_name.py @@ -5,76 +5,76 @@ class TestMultipleColumnsSameName(object): def test_multiple_columns_with_same_name(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) - duckdb_cursor.register('df_view', df) + duckdb_cursor.register("df_view", df) - assert duckdb_cursor.table("df_view").columns == ['a', 'a_1', 'd'] + assert duckdb_cursor.table("df_view").columns == ["a", "a_1", "d"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_multiple_columns_with_same_name_relation(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) rel = duckdb_cursor.from_df(df) assert rel.query("df_view", "DESCRIBE df_view;").fetchall() == [ - ('a', 'BIGINT', 'YES', None, None, None), - ('a_1', 'BIGINT', 'YES', None, None, None), - ('d', 'BIGINT', 'YES', None, None, None), + ("a", "BIGINT", "YES", None, None, None), + ("a_1", "BIGINT", "YES", None, None, None), + ("d", "BIGINT", "YES", None, None, None), ] assert rel.query("df_view", "select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert rel.query("df_view", "select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_multiple_columns_with_same_name_replacement_scans(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'd': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) assert duckdb_cursor.execute("select a_1 from df;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df;").fetchall() == [(1,), (2,), (3,), (4,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a', 'a', 'd']), df.columns + assert all(df.columns == ["a", "a", "d"]), df.columns def test_3669(self, duckdb_cursor): - df = pd.DataFrame([(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)], columns=['a_1', 'a', 'a']) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['a_1', 'a', 'a_2'] + df = pd.DataFrame([(1, 5, 9), (2, 6, 10), (3, 7, 11), (4, 8, 12)], columns=["a_1", "a", "a"]) + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["a_1", "a", "a_2"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a_1', 'a', 'a']), df.columns + assert all(df.columns == ["a_1", "a", "a"]), df.columns def test_minimally_rename(self, duckdb_cursor): df = pd.DataFrame( - [(1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15), (4, 8, 12, 16)], columns=['a_1', 'a', 'a', 'a_2'] + [(1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15), (4, 8, 12, 16)], columns=["a_1", "a", "a", "a_2"] ) - duckdb_cursor.register('df_view', df) + duckdb_cursor.register("df_view", df) rel = duckdb_cursor.table("df_view") res = rel.columns - assert res == ['a_1', 'a', 'a_2', 'a_2_1'] + assert res == ["a_1", "a", "a_2", "a_2_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a_2 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] assert duckdb_cursor.execute("select a_2_1 from df_view;").fetchall() == [(13,), (14,), (15,), (16,)] # Verify we are not changing original dataframe - assert all(df.columns == ['a_1', 'a', 'a', 'a_2']), df.columns + assert all(df.columns == ["a_1", "a", "a", "a_2"]), df.columns def test_multiple_columns_with_same_name_2(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8], 'a_1': [9, 10, 11, 12]}) + df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "a_1": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a_1"}) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['a', 'a_1', 'a_1_1'] + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["a", "a_1", "a_1_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(5,), (6,), (7,), (8,)] assert duckdb_cursor.execute("select a from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] def test_case_insensitive(self, duckdb_cursor): - df = pd.DataFrame({'A_1': [1, 2, 3, 4], 'a_1': [9, 10, 11, 12]}) - duckdb_cursor.register('df_view', df) - assert duckdb_cursor.table("df_view").columns == ['A_1', 'a_1_1'] + df = pd.DataFrame({"A_1": [1, 2, 3, 4], "a_1": [9, 10, 11, 12]}) + duckdb_cursor.register("df_view", df) + assert duckdb_cursor.table("df_view").columns == ["A_1", "a_1_1"] assert duckdb_cursor.execute("select a_1 from df_view;").fetchall() == [(1,), (2,), (3,), (4,)] assert duckdb_cursor.execute("select a_1_1 from df_view;").fetchall() == [(9,), (10,), (11,), (12,)] diff --git a/tests/fast/pandas/test_stride.py b/tests/fast/pandas/test_stride.py index 5efe8d56..1b2f5052 100644 --- a/tests/fast/pandas/test_stride.py +++ b/tests/fast/pandas/test_stride.py @@ -8,27 +8,27 @@ class TestPandasStride(object): def test_stride(self, duckdb_cursor): expected_df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() pd.testing.assert_frame_equal(expected_df, output_df) def test_stride_fp32(self, duckdb_cursor): - expected_df = pd.DataFrame(np.arange(20, dtype='float32').reshape(5, 4), columns=["a", "b", "c", "d"]) + expected_df = pd.DataFrame(np.arange(20, dtype="float32").reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert str(output_df[col].dtype) == 'float32' + assert str(output_df[col].dtype) == "float32" pd.testing.assert_frame_equal(expected_df, output_df) def test_stride_datetime(self, duckdb_cursor): - df = pd.DataFrame({'date': pd.Series(pd.date_range("2024-01-01", freq="D", periods=100))}) + df = pd.DataFrame({"date": pd.Series(pd.date_range("2024-01-01", freq="D", periods=100))}) df = df.loc[::23,] roundtrip = duckdb_cursor.sql("select * from df").df() expected = pd.DataFrame( { - 'date': [ + "date": [ datetime.datetime(2024, 1, 1), datetime.datetime(2024, 1, 24), datetime.datetime(2024, 2, 16), @@ -40,13 +40,13 @@ def test_stride_datetime(self, duckdb_cursor): pd.testing.assert_frame_equal(roundtrip, expected) def test_stride_timedelta(self, duckdb_cursor): - df = pd.DataFrame({'date': [datetime.timedelta(days=i) for i in range(100)]}) + df = pd.DataFrame({"date": [datetime.timedelta(days=i) for i in range(100)]}) df = df.loc[::23,] roundtrip = duckdb_cursor.sql("select * from df").df() expected = pd.DataFrame( { - 'date': [ + "date": [ datetime.timedelta(days=0), datetime.timedelta(days=23), datetime.timedelta(days=46), @@ -58,10 +58,10 @@ def test_stride_timedelta(self, duckdb_cursor): pd.testing.assert_frame_equal(roundtrip, expected) def test_stride_fp64(self, duckdb_cursor): - expected_df = pd.DataFrame(np.arange(20, dtype='float64').reshape(5, 4), columns=["a", "b", "c", "d"]) + expected_df = pd.DataFrame(np.arange(20, dtype="float64").reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() - con.register('df_view', expected_df) + con.register("df_view", expected_df) output_df = con.execute("SELECT * FROM df_view;").fetchdf() for col in output_df.columns: - assert str(output_df[col].dtype) == 'float64' + assert str(output_df[col].dtype) == "float64" pd.testing.assert_frame_equal(expected_df, output_df) diff --git a/tests/fast/pandas/test_timedelta.py b/tests/fast/pandas/test_timedelta.py index 5c6aa4b9..c0afeb74 100644 --- a/tests/fast/pandas/test_timedelta.py +++ b/tests/fast/pandas/test_timedelta.py @@ -11,7 +11,7 @@ def test_timedelta_positive(self, duckdb_cursor): "SELECT '2290-01-01 23:59:00'::TIMESTAMP - '2000-01-01 23:59:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=9151574400000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) @@ -20,7 +20,7 @@ def test_timedelta_basic(self, duckdb_cursor): "SELECT '2290-08-30 23:53:40'::TIMESTAMP - '2000-02-01 01:56:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=9169797460000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) @@ -29,24 +29,24 @@ def test_timedelta_negative(self, duckdb_cursor): "SELECT '2000-01-01 23:59:00'::TIMESTAMP - '2290-01-01 23:59:00'::TIMESTAMP AS '0'" ).df() data = [datetime.timedelta(microseconds=-9151574400000000)] - df_in = pd.DataFrame({0: pd.Series(data=data, dtype='object')}) + df_in = pd.DataFrame({0: pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df", connection=duckdb_cursor).df() pd.testing.assert_frame_equal(df_out, duckdb_interval) - @pytest.mark.parametrize('days', [1, 9999]) - @pytest.mark.parametrize('seconds', [0, 60]) + @pytest.mark.parametrize("days", [1, 9999]) + @pytest.mark.parametrize("seconds", [0, 60]) @pytest.mark.parametrize( - 'microseconds', + "microseconds", [ 0, 232493, 999_999, ], ) - @pytest.mark.parametrize('milliseconds', [0, 999]) - @pytest.mark.parametrize('minutes', [0, 60]) - @pytest.mark.parametrize('hours', [0, 24]) - @pytest.mark.parametrize('weeks', [0, 51]) + @pytest.mark.parametrize("milliseconds", [0, 999]) + @pytest.mark.parametrize("minutes", [0, 60]) + @pytest.mark.parametrize("hours", [0, 24]) + @pytest.mark.parametrize("weeks", [0, 51]) @pytest.mark.skipif(platform.system() == "Emscripten", reason="Bind parameters are broken when running on Pyodide") def test_timedelta_coverage(self, duckdb_cursor, days, seconds, microseconds, milliseconds, minutes, hours, weeks): def create_duck_interval(days, seconds, microseconds, milliseconds, minutes, hours, weeks) -> str: diff --git a/tests/fast/pandas/test_timestamp.py b/tests/fast/pandas/test_timestamp.py index 0a580025..dbb7273d 100644 --- a/tests/fast/pandas/test_timestamp.py +++ b/tests/fast/pandas/test_timestamp.py @@ -8,33 +8,33 @@ class TestPandasTimestamps(object): - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_types_roundtrip(self, unit): d = { - 'time': pd.Series( + "time": pd.Series( [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit)], - dtype=f'datetime64[{unit}]', + dtype=f"datetime64[{unit}]", ) } df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() assert df_from_duck.equals(df) - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_timezone_roundtrip(self, unit): if pandas_2_or_higher(): - dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit=unit, tz='UTC') - expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='us', tz='UTC') + dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit=unit, tz="UTC") + expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="us", tz="UTC") else: # Older versions of pandas only support 'ns' as timezone unit - expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='ns', tz='UTC') - dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit='ns', tz='UTC') + expected_dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="ns", tz="UTC") + dtype = pd.core.dtypes.dtypes.DatetimeTZDtype(unit="ns", tz="UTC") conn = duckdb.connect() conn.execute("SET TimeZone =UTC") d = { - 'time': pd.Series( - [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit, tz='UTC')], + "time": pd.Series( + [pd.Timestamp(datetime.datetime(2020, 6, 12, 14, 43, 24, 394587), unit=unit, tz="UTC")], dtype=dtype, ) } @@ -46,9 +46,9 @@ def test_timestamp_timezone_roundtrip(self, unit): df_from_duck = conn.from_df(df).df() assert df_from_duck.equals(expected) - @pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) + @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_nulls(self, unit): - d = {'time': pd.Series([pd.Timestamp(None, unit=unit)], dtype=f'datetime64[{unit}]')} + d = {"time": pd.Series([pd.Timestamp(None, unit=unit)], dtype=f"datetime64[{unit}]")} df = pd.DataFrame(data=d) df_from_duck = duckdb.from_df(df).df() assert df_from_duck.equals(df) @@ -56,10 +56,10 @@ def test_timestamp_nulls(self, unit): def test_timestamp_timedelta(self): df = pd.DataFrame( { - 'a': [pd.Timedelta(1, unit='s')], - 'b': [pd.Timedelta(None, unit='s')], - 'c': [pd.Timedelta(1, unit='us')], - 'd': [pd.Timedelta(1, unit='ms')], + "a": [pd.Timedelta(1, unit="s")], + "b": [pd.Timedelta(None, unit="s")], + "c": [pd.Timedelta(1, unit="us")], + "d": [pd.Timedelta(1, unit="ms")], } ) df_from_duck = duckdb.from_df(df).df() @@ -78,4 +78,4 @@ def test_timestamp_timezone(self, duckdb_cursor): """ ) res = rel.df() - assert res['dateTime'][0] == res['dateTime_1'][0] + assert res["dateTime"][0] == res["dateTime_1"][0] diff --git a/tests/fast/relational_api/test_groupings.py b/tests/fast/relational_api/test_groupings.py index fc81deba..b0a95410 100644 --- a/tests/fast/relational_api/test_groupings.py +++ b/tests/fast/relational_api/test_groupings.py @@ -22,7 +22,7 @@ def con(): class TestGroupings(object): def test_basic_grouping(self, con): - rel = con.table('tbl').sum("a", "b") + rel = con.table("tbl").sum("a", "b") res = rel.fetchall() assert res == [(7,), (2,), (5,)] @@ -31,7 +31,7 @@ def test_basic_grouping(self, con): assert res == res2 def test_cubed(self, con): - rel = con.table('tbl').sum("a", "CUBE (b)").order("ALL") + rel = con.table("tbl").sum("a", "CUBE (b)").order("ALL") res = rel.fetchall() assert res == [(2,), (5,), (7,), (14,)] @@ -40,7 +40,7 @@ def test_cubed(self, con): assert res == res2 def test_rollup(self, con): - rel = con.table('tbl').sum("a", "ROLLUP (b, c)").order("ALL") + rel = con.table("tbl").sum("a", "ROLLUP (b, c)").order("ALL") res = rel.fetchall() assert res == [(1,), (1,), (2,), (2,), (2,), (3,), (5,), (5,), (7,), (14,)] diff --git a/tests/fast/relational_api/test_joins.py b/tests/fast/relational_api/test_joins.py index 8eb365d5..cf3d3cf2 100644 --- a/tests/fast/relational_api/test_joins.py +++ b/tests/fast/relational_api/test_joins.py @@ -31,57 +31,57 @@ def con(): class TestRAPIJoins(object): def test_outer_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'outer') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "outer") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, None, None), (None, None, 3, 5)] def test_inner_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'inner') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "inner") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4)] def test_anti_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'anti') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "anti") res = rel.fetchall() # Only output the row(s) from A where the condition is false assert res == [(3, 2)] def test_left_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'left') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "left") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, None, None)] def test_right_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'right') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "right") res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (None, None, 3, 5)] def test_semi_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') - expr = ColumnExpression('tbl_a.b') == ColumnExpression('tbl_b.a') - rel = a.join(b, expr, 'semi') + a = con.table("tbl_a") + b = con.table("tbl_b") + expr = ColumnExpression("tbl_a.b") == ColumnExpression("tbl_b.a") + rel = a.join(b, expr, "semi") res = rel.fetchall() assert res == [(1, 1), (2, 1)] def test_cross_join(self, con): - a = con.table('tbl_a') - b = con.table('tbl_b') + a = con.table("tbl_a") + b = con.table("tbl_b") rel = a.cross(b) res = rel.fetchall() assert res == [(1, 1, 1, 4), (2, 1, 1, 4), (3, 2, 1, 4), (1, 1, 3, 5), (2, 1, 3, 5), (3, 2, 3, 5)] diff --git a/tests/fast/relational_api/test_pivot.py b/tests/fast/relational_api/test_pivot.py index d78df656..9cf91e56 100644 --- a/tests/fast/relational_api/test_pivot.py +++ b/tests/fast/relational_api/test_pivot.py @@ -27,4 +27,4 @@ def test_pivot_issue_14601(self, duckdb_cursor): export_dir = tempfile.mkdtemp() duckdb_cursor.query(f"EXPORT DATABASE '{export_dir}'") with open(os.path.join(export_dir, "schema.sql"), "r") as f: - assert 'CREATE TYPE' not in f.read() + assert "CREATE TYPE" not in f.read() diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index 29202759..3466a77a 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -269,7 +269,7 @@ def test_product(self, table): def test_string_agg(self, table): result = table.string_agg("s", sep="/").execute().fetchall() - expected = [('h/e/l/l/o/,/wor/ld',)] + expected = [("h/e/l/l/o/,/wor/ld",)] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) result = ( @@ -278,7 +278,7 @@ def test_string_agg(self, table): .execute() .fetchall() ) - expected = [(1, 'h/e/l'), (2, 'l/o'), (3, ',/wor/ld')] + expected = [(1, "h/e/l"), (2, "l/o"), (3, ",/wor/ld")] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) diff --git a/tests/fast/relational_api/test_rapi_close.py b/tests/fast/relational_api/test_rapi_close.py index 270c58f5..b6355167 100644 --- a/tests/fast/relational_api/test_rapi_close.py +++ b/tests/fast/relational_api/test_rapi_close.py @@ -11,153 +11,153 @@ def test_close_conn_rel(self, duckdb_cursor): rel = con.table("items") con.close() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): len(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.aggregate("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.any_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.apply("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.arg_max("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.arg_min("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetch_arrow_table() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.avg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_and("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_or("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bit_xor("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bitstring_agg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bool_and("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.bool_or("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.count("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.create("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.create_view("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.cume_dist("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.dense_rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.describe() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.distinct() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.execute() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.favg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchall() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchnumpy() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fetchone() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.filter("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.first("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.first_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.fsum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.geomean("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.histogram("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.insert("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.insert_into("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.lag("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.last("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.last_value("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.lead("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel.limit(1)) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.list("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - rel.map(lambda df: df['col0'].add(42).to_frame()) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): + rel.map(lambda df: df["col0"].add(42).to_frame()) + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.max("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.mean("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.median("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.min("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.mode("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.n_tile("", 1) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.nth_value("", "", 1) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.order("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.percent_rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.product("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.project("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile_cont("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.quantile_disc("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.query("", "") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.rank("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.rank_dense("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.row_number("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.std("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev_pop("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.stddev_samp("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.string_agg("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.sum("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.to_arrow_table() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.to_df() - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var_pop("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.var_samp("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.variance("") - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.write_csv("") con = duckdb.connect() @@ -166,14 +166,14 @@ def test_close_conn_rel(self, duckdb_cursor): valid_rel = con.table("items") # Test these bad boys when left relation is valid - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.union(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.except_(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): valid_rel.intersect(rel) - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): - valid_rel.join(rel.set_alias('rel'), "rel.items = valid_rel.items") + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): + valid_rel.join(rel.set_alias("rel"), "rel.items = valid_rel.items") def test_del_conn(self, duckdb_cursor): con = duckdb.connect() @@ -181,5 +181,5 @@ def test_del_conn(self, duckdb_cursor): con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") rel = con.table("items") del con - with pytest.raises(duckdb.ConnectionException, match='Connection has already been closed'): + with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): print(rel) diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 41813d94..80616132 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -4,31 +4,31 @@ class TestRAPIDescription(object): def test_rapi_description(self, duckdb_cursor): - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") desc = res.description names = [x[0] for x in desc] types = [x[1] for x in desc] - assert names == ['a', 'b'] - assert types == ['INTEGER', 'BIGINT'] + assert names == ["a", "b"] + assert types == ["INTEGER", "BIGINT"] assert all([x == duckdb.NUMBER for x in types]) def test_rapi_describe(self, duckdb_cursor): np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") duck_describe = res.describe().df() - np.testing.assert_array_equal(duck_describe['aggr'], ['count', 'mean', 'stddev', 'min', 'max', 'median']) - np.testing.assert_array_equal(duck_describe['a'], [1, 42, float('nan'), 42, 42, 42]) - np.testing.assert_array_equal(duck_describe['b'], [1, 84, float('nan'), 84, 84, 84]) + np.testing.assert_array_equal(duck_describe["aggr"], ["count", "mean", "stddev", "min", "max", "median"]) + np.testing.assert_array_equal(duck_describe["a"], [1, 42, float("nan"), 42, 42, 42]) + np.testing.assert_array_equal(duck_describe["b"], [1, 84, float("nan"), 84, 84, 84]) # now with more values res = duckdb_cursor.query( - 'select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)' + "select CASE WHEN i%2=0 THEN i ELSE NULL END AS i, i * 10 AS j, (i * 23 // 27)::DOUBLE AS k FROM range(10000) t(i)" ) duck_describe = res.describe().df() - np.testing.assert_allclose(duck_describe['i'], [5000.0, 4999.0, 2887.0400066504103, 0.0, 9998.0, 4999.0]) - np.testing.assert_allclose(duck_describe['j'], [10000.0, 49995.0, 28868.956799071675, 0.0, 99990.0, 49995.0]) - np.testing.assert_allclose(duck_describe['k'], [10000.0, 4258.3518, 2459.207430770227, 0.0, 8517.0, 4258.5]) + np.testing.assert_allclose(duck_describe["i"], [5000.0, 4999.0, 2887.0400066504103, 0.0, 9998.0, 4999.0]) + np.testing.assert_allclose(duck_describe["j"], [10000.0, 49995.0, 28868.956799071675, 0.0, 99990.0, 49995.0]) + np.testing.assert_allclose(duck_describe["k"], [10000.0, 4258.3518, 2459.207430770227, 0.0, 8517.0, 4258.5]) # describe data with other (non-numeric) types res = duckdb_cursor.query("select 'hello world' AS a, [1, 2, 3] AS b") @@ -38,8 +38,8 @@ def test_rapi_describe(self, duckdb_cursor): # describe mixed table res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b, 'hello world' AS c") duck_describe = res.describe().df() - np.testing.assert_array_equal(duck_describe['a'], [1, 42, float('nan'), 42, 42, 42]) - np.testing.assert_array_equal(duck_describe['b'], [1, 84, float('nan'), 84, 84, 84]) + np.testing.assert_array_equal(duck_describe["a"], [1, 42, float("nan"), 42, 42, 42]) + np.testing.assert_array_equal(duck_describe["b"], [1, 84, float("nan"), 84, 84, 84]) # timestamps res = duckdb_cursor.query("select timestamp '1992-01-01', date '2000-01-01'") diff --git a/tests/fast/relational_api/test_rapi_functions.py b/tests/fast/relational_api/test_rapi_functions.py index 92de4c2c..c6b1f1fa 100644 --- a/tests/fast/relational_api/test_rapi_functions.py +++ b/tests/fast/relational_api/test_rapi_functions.py @@ -3,10 +3,10 @@ class TestRAPIFunctions(object): def test_rapi_str_print(self, duckdb_cursor): - res = duckdb_cursor.query('select 42::INT AS a, 84::BIGINT AS b') + res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") assert str(res) is not None res.show() def test_rapi_relation_sql_query(self): - res = duckdb.table_function('range', [10]) + res = duckdb.table_function("range", [10]) assert res.sql_query() == 'SELECT * FROM "range"(10)' diff --git a/tests/fast/relational_api/test_rapi_query.py b/tests/fast/relational_api/test_rapi_query.py index 92f87776..16ed326c 100644 --- a/tests/fast/relational_api/test_rapi_query.py +++ b/tests/fast/relational_api/test_rapi_query.py @@ -10,12 +10,12 @@ def tbl_table(): con.execute("drop table if exists tbl CASCADE") con.execute("create table tbl (i integer)") yield - con.execute('drop table tbl CASCADE') + con.execute("drop table tbl CASCADE") @pytest.fixture() def scoped_default(duckdb_cursor): - default = duckdb.connect(':default:') + default = duckdb.connect(":default:") duckdb.set_default_connection(duckdb_cursor) # Overwrite the default connection yield @@ -24,7 +24,7 @@ def scoped_default(duckdb_cursor): class TestRAPIQuery(object): - @pytest.mark.parametrize('steps', [1, 2, 3, 4]) + @pytest.mark.parametrize("steps", [1, 2, 3, 4]) def test_query_chain(self, steps): con = duckdb.default_connection() amount = int(1000000) @@ -36,7 +36,7 @@ def test_query_chain(self, steps): result = rel.execute() assert len(result.fetchall()) == amount - @pytest.mark.parametrize('input', [[5, 4, 3], [], [1000]]) + @pytest.mark.parametrize("input", [[5, 4, 3], [], [1000]]) def test_query_table(self, tbl_table, input): con = duckdb.default_connection() rel = con.table("tbl") @@ -98,80 +98,80 @@ def test_query_table_unrelated(self, tbl_table): def test_query_non_select_result(self, duckdb_cursor): with pytest.raises(duckdb.ParserException, match="syntax error"): - duckdb_cursor.query('selec 42') + duckdb_cursor.query("selec 42") - res = duckdb_cursor.query('explain select 42').fetchall() + res = duckdb_cursor.query("explain select 42").fetchall() assert len(res) > 0 - res = duckdb_cursor.query('describe select 42::INT AS column_name').fetchall() - assert res[0][0] == 'column_name' + res = duckdb_cursor.query("describe select 42::INT AS column_name").fetchall() + assert res[0][0] == "column_name" - res = duckdb_cursor.query('create or replace table tbl_non_select_result(i integer)') + res = duckdb_cursor.query("create or replace table tbl_non_select_result(i integer)") assert res is None - res = duckdb_cursor.query('insert into tbl_non_select_result values (42)') + res = duckdb_cursor.query("insert into tbl_non_select_result values (42)") assert res is None - res = duckdb_cursor.query('insert into tbl_non_select_result values (84) returning *').fetchall() + res = duckdb_cursor.query("insert into tbl_non_select_result values (84) returning *").fetchall() assert res == [(84,)] - res = duckdb_cursor.query('select * from tbl_non_select_result').fetchall() + res = duckdb_cursor.query("select * from tbl_non_select_result").fetchall() assert res == [(42,), (84,)] - res = duckdb_cursor.query('insert into tbl_non_select_result select * from range(10000) returning *').fetchall() + res = duckdb_cursor.query("insert into tbl_non_select_result select * from range(10000) returning *").fetchall() assert len(res) == 10000 - res = duckdb_cursor.query('show tables').fetchall() + res = duckdb_cursor.query("show tables").fetchall() assert len(res) > 0 - res = duckdb_cursor.query('drop table tbl_non_select_result') + res = duckdb_cursor.query("drop table tbl_non_select_result") assert res is None def test_replacement_scan_recursion(self, duckdb_cursor): depth_limit = 1000 - if sys.platform.startswith('win') or platform.system() == "Emscripten": + if sys.platform.startswith("win") or platform.system() == "Emscripten": # With the default we reach a stack overflow in the CI for windows # and also outside of it for Pyodide depth_limit = 250 duckdb_cursor.execute(f"SET max_expression_depth TO {depth_limit}") - rel = duckdb_cursor.sql('select 42 a, 21 b') - rel = duckdb_cursor.sql('select a+a a, b+b b from rel') - other_rel = duckdb_cursor.sql('select a from rel') + rel = duckdb_cursor.sql("select 42 a, 21 b") + rel = duckdb_cursor.sql("select a+a a, b+b b from rel") + other_rel = duckdb_cursor.sql("select a from rel") res = other_rel.fetchall() assert res == [(84,)] def test_set_default_connection(self, scoped_default): duckdb.sql("create table t as select 42") - assert duckdb.table('t').fetchall() == [(42,)] - con = duckdb.connect(':default:') + assert duckdb.table("t").fetchall() == [(42,)] + con = duckdb.connect(":default:") # Uses the same db as the module - assert con.table('t').fetchall() == [(42,)] + assert con.table("t").fetchall() == [(42,)] con2 = duckdb.connect() con2.sql("create table t as select 21") - assert con2.table('t').fetchall() == [(21,)] + assert con2.table("t").fetchall() == [(21,)] # Change the db used by the module duckdb.set_default_connection(con2) - with pytest.raises(duckdb.CatalogException, match='Table with name d does not exist'): - con2.table('d').fetchall() + with pytest.raises(duckdb.CatalogException, match="Table with name d does not exist"): + con2.table("d").fetchall() - assert duckdb.table('t').fetchall() == [(21,)] + assert duckdb.table("t").fetchall() == [(21,)] duckdb.sql("create table d as select [1,2,3]") - assert duckdb.table('d').fetchall() == [([1, 2, 3],)] - assert con2.table('d').fetchall() == [([1, 2, 3],)] + assert duckdb.table("d").fetchall() == [([1, 2, 3],)] + assert con2.table("d").fetchall() == [([1, 2, 3],)] def test_set_default_connection_error(self, scoped_default): - with pytest.raises(TypeError, match='Invoked with: None'): + with pytest.raises(TypeError, match="Invoked with: None"): # set_default_connection does not allow None duckdb.set_default_connection(None) - with pytest.raises(TypeError, match='Invoked with: 5'): + with pytest.raises(TypeError, match="Invoked with: 5"): duckdb.set_default_connection(5) assert duckdb.sql("select 42").fetchall() == [(42,)] diff --git a/tests/fast/relational_api/test_rapi_windows.py b/tests/fast/relational_api/test_rapi_windows.py index 7c13debc..cc58b8f1 100644 --- a/tests/fast/relational_api/test_rapi_windows.py +++ b/tests/fast/relational_api/test_rapi_windows.py @@ -429,14 +429,14 @@ def test_bitstring_agg(self, table): .fetchall() ) expected = [ - (1, '0010000000000'), - (1, '0010000000000'), - (1, '0011000000000'), - (2, '0000000000001'), - (2, '0000000000011'), - (3, '0000001000000'), - (3, '1000001000000'), - (3, '1000001000000'), + (1, "0010000000000"), + (1, "0010000000000"), + (1, "0011000000000"), + (2, "0000000000001"), + (2, "0000000000011"), + (3, "0000001000000"), + (3, "1000001000000"), + (3, "1000001000000"), ] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) @@ -619,7 +619,7 @@ def test_string_agg(self, table): .execute() .fetchall() ) - expected = [(1, 'e'), (1, 'e/h'), (1, 'e/h/l'), (2, 'o'), (2, 'o/l'), (3, 'wor'), (3, 'wor/,'), (3, 'wor/,/ld')] + expected = [(1, "e"), (1, "e/h"), (1, "e/h/l"), (2, "o"), (2, "o/l"), (3, "wor"), (3, "wor/,"), (3, "wor/,/ld")] assert len(result) == len(expected) assert all([r == e for r, e in zip(result, expected)]) diff --git a/tests/fast/relational_api/test_table_function.py b/tests/fast/relational_api/test_table_function.py index 4f4a1016..5748f762 100644 --- a/tests/fast/relational_api/test_table_function.py +++ b/tests/fast/relational_api/test_table_function.py @@ -7,11 +7,11 @@ class TestTableFunction(object): def test_table_function(self, duckdb_cursor): - path = os.path.join(script_path, '..', 'data/integers.csv') - rel = duckdb_cursor.table_function('read_csv', [path]) + path = os.path.join(script_path, "..", "data/integers.csv") + rel = duckdb_cursor.table_function("read_csv", [path]) res = rel.fetchall() assert res == [(1, 10, 0), (2, 50, 30)] # Provide only a string as argument, should error, needs a list with pytest.raises(duckdb.InvalidInputException, match=r"'params' has to be a list of parameters"): - rel = duckdb_cursor.table_function('read_csv', path) + rel = duckdb_cursor.table_function("read_csv", path) diff --git a/tests/fast/spark/test_replace_column_value.py b/tests/fast/spark/test_replace_column_value.py index 33940616..65ab85f1 100644 --- a/tests/fast/spark/test_replace_column_value.py +++ b/tests/fast/spark/test_replace_column_value.py @@ -13,7 +13,7 @@ def test_replace_value(self, spark): # Replace part of string with another string from spark_namespace.sql.functions import regexp_replace - df2 = df.withColumn('address', regexp_replace('address', 'Rd', 'Road')) + df2 = df.withColumn("address", regexp_replace("address", "Rd", "Road")) # Replace string column value conditionally from spark_namespace.sql.functions import when @@ -21,24 +21,24 @@ def test_replace_value(self, spark): res = df2.collect() print(res) df2 = df.withColumn( - 'address', - when(df.address.endswith('Rd'), regexp_replace(df.address, 'Rd', 'Road')) - .when(df.address.endswith('St'), regexp_replace(df.address, 'St', 'Street')) - .when(df.address.endswith('Ave'), regexp_replace(df.address, 'Ave', 'Avenue')) + "address", + when(df.address.endswith("Rd"), regexp_replace(df.address, "Rd", "Road")) + .when(df.address.endswith("St"), regexp_replace(df.address, "St", "Street")) + .when(df.address.endswith("Ave"), regexp_replace(df.address, "Ave", "Avenue")) .otherwise(df.address), ) res = df2.collect() print(res) expected = [ - Row(id=1, address='14851 Jeffrey Road', state='DE'), - Row(id=2, address='43421 Margarita Street', state='NY'), - Row(id=3, address='13111 Siemon Avenue', state='CA'), + Row(id=1, address="14851 Jeffrey Road", state="DE"), + Row(id=2, address="43421 Margarita Street", state="NY"), + Row(id=3, address="13111 Siemon Avenue", state="CA"), ] print(expected) assert res == expected # Replace all substrings of the specified string value that match regexp with rep. - df3 = spark.createDataFrame([('100-200',)], ['str']) - res = df3.select(regexp_replace('str', r'(\d+)', '--').alias('d')).collect() - expected = [Row(d='-----')] + df3 = spark.createDataFrame([("100-200",)], ["str"]) + res = df3.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() + expected = [Row(d="-----")] print(expected) assert res == expected diff --git a/tests/fast/spark/test_replace_empty_value.py b/tests/fast/spark/test_replace_empty_value.py index 71a9f25f..aad6a43e 100644 --- a/tests/fast/spark/test_replace_empty_value.py +++ b/tests/fast/spark/test_replace_empty_value.py @@ -12,32 +12,32 @@ def test_replace_empty(self, spark): # Create the dataframe data = [("", "CA"), ("Julia", ""), ("Robert", ""), ("", "NJ")] df = spark.createDataFrame(data, ["name", "state"]) - res = df.select('name').collect() - assert res == [Row(name=''), Row(name='Julia'), Row(name='Robert'), Row(name='')] - res = df.select('state').collect() - assert res == [Row(state='CA'), Row(state=''), Row(state=''), Row(state='NJ')] + res = df.select("name").collect() + assert res == [Row(name=""), Row(name="Julia"), Row(name="Robert"), Row(name="")] + res = df.select("state").collect() + assert res == [Row(state="CA"), Row(state=""), Row(state=""), Row(state="NJ")] # Replace name # CASE WHEN "name" == '' THEN NULL ELSE "name" END from spark_namespace.sql.functions import col, when df2 = df.withColumn("name", when(col("name") == "", None).otherwise(col("name"))) - assert df2.columns == ['name', 'state'] - res = df2.select('name').collect() - assert res == [Row(name=None), Row(name='Julia'), Row(name='Robert'), Row(name=None)] + assert df2.columns == ["name", "state"] + res = df2.select("name").collect() + assert res == [Row(name=None), Row(name="Julia"), Row(name="Robert"), Row(name=None)] # Replace state + name from spark_namespace.sql.functions import col, when df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in df.columns]) - assert df2.columns == ['name', 'state'] + assert df2.columns == ["name", "state"] key_f = lambda x: x.name or x.state res = df2.sort("name", "state").collect() expected_res = [ - Row(name=None, state='CA'), - Row(name=None, state='NJ'), - Row(name='Julia', state=None), - Row(name='Robert', state=None), + Row(name=None, state="CA"), + Row(name=None, state="NJ"), + Row(name="Julia", state=None), + Row(name="Robert", state=None), ] assert res == expected_res @@ -46,15 +46,15 @@ def test_replace_empty(self, spark): from spark_namespace.sql.functions import col, when replaceCols = ["state"] - df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in replaceCols]).sort(col('state')) - assert df2.columns == ['state'] + df2 = df.select([when(col(c) == "", None).otherwise(col(c)).alias(c) for c in replaceCols]).sort(col("state")) + assert df2.columns == ["state"] key_f = lambda x: x.state or "" res = df2.collect() assert sorted(res, key=key_f) == sorted( [ - Row(state='CA'), - Row(state='NJ'), + Row(state="CA"), + Row(state="NJ"), Row(state=None), Row(state=None), ], diff --git a/tests/fast/spark/test_spark_catalog.py b/tests/fast/spark/test_spark_catalog.py index 7f523abd..2ecaad24 100644 --- a/tests/fast/spark/test_spark_catalog.py +++ b/tests/fast/spark/test_spark_catalog.py @@ -13,9 +13,9 @@ def test_list_databases(self, spark): assert all(isinstance(db, Database) for db in dbs) else: assert dbs == [ - Database(name='memory', description=None, locationUri=''), - Database(name='system', description=None, locationUri=''), - Database(name='temp', description=None, locationUri=''), + Database(name="memory", description=None, locationUri=""), + Database(name="system", description=None, locationUri=""), + Database(name="temp", description=None, locationUri=""), ] def test_list_tables(self, spark): @@ -26,31 +26,31 @@ def test_list_tables(self, spark): if not USE_ACTUAL_SPARK: # Skip this if we're using actual Spark because we can't create tables # with our setup. - spark.sql('create table tbl(a varchar)') + spark.sql("create table tbl(a varchar)") tbls = spark.catalog.listTables() assert tbls == [ Table( - name='tbl', - database='memory', - description='CREATE TABLE tbl(a VARCHAR);', - tableType='', + name="tbl", + database="memory", + description="CREATE TABLE tbl(a VARCHAR);", + tableType="", isTemporary=False, ) ] @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup") def test_list_columns(self, spark): - spark.sql('create table tbl(a varchar, b bool)') - columns = spark.catalog.listColumns('tbl') + spark.sql("create table tbl(a varchar, b bool)") + columns = spark.catalog.listColumns("tbl") assert columns == [ - Column(name='a', description=None, dataType='VARCHAR', nullable=True, isPartition=False, isBucket=False), - Column(name='b', description=None, dataType='BOOLEAN', nullable=True, isPartition=False, isBucket=False), + Column(name="a", description=None, dataType="VARCHAR", nullable=True, isPartition=False, isBucket=False), + Column(name="b", description=None, dataType="BOOLEAN", nullable=True, isPartition=False, isBucket=False), ] # FIXME: should this error instead? - non_existant_columns = spark.catalog.listColumns('none_existant') + non_existant_columns = spark.catalog.listColumns("none_existant") assert non_existant_columns == [] - spark.sql('create view vw as select * from tbl') - view_columns = spark.catalog.listColumns('vw') + spark.sql("create view vw as select * from tbl") + view_columns = spark.catalog.listColumns("vw") assert view_columns == columns diff --git a/tests/fast/spark/test_spark_column.py b/tests/fast/spark/test_spark_column.py index e56ba9ee..9ef17d95 100644 --- a/tests/fast/spark/test_spark_column.py +++ b/tests/fast/spark/test_spark_column.py @@ -18,26 +18,26 @@ def test_struct_column(self, spark): # FIXME: column names should be set explicitly using the Row, rather than letting duckdb assign defaults (col0, col1, etc..) if USE_ACTUAL_SPARK: - df = df.withColumn('struct', struct(df.a, df.b)) + df = df.withColumn("struct", struct(df.a, df.b)) else: - df = df.withColumn('struct', struct(df.col0, df.col1)) - assert 'struct' in df - new_col = df.schema['struct'] + df = df.withColumn("struct", struct(df.col0, df.col1)) + assert "struct" in df + new_col = df.schema["struct"] if USE_ACTUAL_SPARK: - assert 'a' in df.schema['struct'].dataType.fieldNames() - assert 'b' in df.schema['struct'].dataType.fieldNames() + assert "a" in df.schema["struct"].dataType.fieldNames() + assert "b" in df.schema["struct"].dataType.fieldNames() else: - assert 'col0' in new_col.dataType - assert 'col1' in new_col.dataType + assert "col0" in new_col.dataType + assert "col1" in new_col.dataType with pytest.raises( PySparkTypeError, match=re.escape("[NOT_COLUMN] Argument `col` should be a Column, got str.") ): - df = df.withColumn('struct', 'yes') + df = df.withColumn("struct", "yes") def test_array_column(self, spark): - df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)], ['a', 'b', 'c', 'd']) + df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)], ["a", "b", "c", "d"]) df2 = df.select( array(df["a"], df["b"]).alias("array"), diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index d88b03eb..e86995ec 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -36,9 +36,9 @@ def test_dataframe_from_list_of_tuples(self, spark): df = spark.createDataFrame(address, ["id", "address", "state"]) res = df.collect() assert res == [ - Row(id=1, address='14851 Jeffrey Rd', state='DE'), - Row(id=2, address='43421 Margarita St', state='NY'), - Row(id=3, address='13111 Siemon Ave', state='CA'), + Row(id=1, address="14851 Jeffrey Rd", state="DE"), + Row(id=2, address="43421 Margarita St", state="NY"), + Row(id=3, address="13111 Siemon Ave", state="CA"), ] # Tuples of different sizes @@ -93,9 +93,9 @@ def test_dataframe_from_list_of_tuples(self, spark): df = spark.createDataFrame(address, []) res = df.collect() assert res == [ - Row(col0=1, col1='14851 Jeffrey Rd', col2='DE'), - Row(col0=2, col1='43421 Margarita St', col2='NY'), - Row(col0=3, col1='13111 Siemon Ave', col2='DE'), + Row(col0=1, col1="14851 Jeffrey Rd", col2="DE"), + Row(col0=2, col1="43421 Margarita St", col2="NY"), + Row(col0=3, col1="13111 Siemon Ave", col2="DE"), ] # Too many column names @@ -107,17 +107,17 @@ def test_dataframe_from_list_of_tuples(self, spark): # Column names is not a list (but is iterable) if not USE_ACTUAL_SPARK: # These things do not work in Spark or throw different errors - df = spark.createDataFrame(address, {'a': 5, 'b': 6, 'c': 42}) + df = spark.createDataFrame(address, {"a": 5, "b": 6, "c": 42}) res = df.collect() assert res == [ - Row(a=1, b='14851 Jeffrey Rd', c='DE'), - Row(a=2, b='43421 Margarita St', c='NY'), - Row(a=3, b='13111 Siemon Ave', c='DE'), + Row(a=1, b="14851 Jeffrey Rd", c="DE"), + Row(a=2, b="43421 Margarita St", c="NY"), + Row(a=3, b="13111 Siemon Ave", c="DE"), ] # Column names is not a list (string, becomes a single column name) with pytest.raises(PySparkValueError, match="number of columns in the DataFrame don't match"): - df = spark.createDataFrame(address, 'a') + df = spark.createDataFrame(address, "a") with pytest.raises(TypeError, match="must be an iterable, not int"): df = spark.createDataFrame(address, 5) @@ -126,7 +126,7 @@ def test_dataframe(self, spark): # Create DataFrame df = spark.createDataFrame([("Scala", 25000), ("Spark", 35000), ("PHP", 21000)]) res = df.collect() - assert res == [Row(col0='Scala', col1=25000), Row(col0='Spark', col1=35000), Row(col0='PHP', col1=21000)] + assert res == [Row(col0="Scala", col1=25000), Row(col0="Spark", col1=35000), Row(col0="PHP", col1=21000)] @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="We can't create tables with our Spark test setup") def test_writing_to_table(self, spark): @@ -136,18 +136,18 @@ def test_writing_to_table(self, spark): create table sample_table("_1" bool, "_2" integer) """ ) - spark.sql('insert into sample_table VALUES (True, 42)') + spark.sql("insert into sample_table VALUES (True, 42)") spark.table("sample_table").write.saveAsTable("sample_hive_table") df3 = spark.sql("SELECT _1,_2 FROM sample_hive_table") res = df3.collect() assert res == [Row(_1=True, _2=42)] schema = df3.schema - assert schema == StructType([StructField('_1', BooleanType(), True), StructField('_2', IntegerType(), True)]) + assert schema == StructType([StructField("_1", BooleanType(), True), StructField("_2", IntegerType(), True)]) def test_dataframe_collect(self, spark): - df = spark.createDataFrame([(42,), (21,)]).toDF('a') + df = spark.createDataFrame([(42,), (21,)]).toDF("a") res = df.collect() - assert str(res) == '[Row(a=42), Row(a=21)]' + assert str(res) == "[Row(a=42), Row(a=21)]" def test_dataframe_from_rows(self, spark): columns = ["language", "users_count"] @@ -157,17 +157,17 @@ def test_dataframe_from_rows(self, spark): df = spark.createDataFrame(rowData, columns) res = df.collect() assert res == [ - Row(language='Java', users_count='20000'), - Row(language='Python', users_count='100000'), - Row(language='Scala', users_count='3000'), + Row(language="Java", users_count="20000"), + Row(language="Python", users_count="100000"), + Row(language="Scala", users_count="3000"), ] def test_empty_df(self, spark): schema = StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ) df = spark.createDataFrame([], schema=schema) @@ -178,18 +178,18 @@ def test_empty_df(self, spark): def test_df_from_pandas(self, spark): import pandas as pd - df = spark.createDataFrame(pd.DataFrame({'a': [42, 21], 'b': [True, False]})) + df = spark.createDataFrame(pd.DataFrame({"a": [42, 21], "b": [True, False]})) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_from_struct_type(self, spark): - schema = StructType([StructField('a', LongType()), StructField('b', BooleanType())]) + schema = StructType([StructField("a", LongType()), StructField("b", BooleanType())]) df = spark.createDataFrame([(42, True), (21, False)], schema) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_from_name_list(self, spark): - df = spark.createDataFrame([(42, True), (21, False)], ['a', 'b']) + df = spark.createDataFrame([(42, True), (21, False)], ["a", "b"]) res = df.collect() assert res == [Row(a=42, b=True), Row(a=21, b=False)] @@ -218,11 +218,11 @@ def test_df_creation_coverage(self, spark): df = spark.createDataFrame(data=data2, schema=schema) res = df.collect() assert res == [ - Row(firstname='James', middlename='', lastname='Smith', id='36636', gender='M', salary=3000), - Row(firstname='Michael', middlename='Rose', lastname='', id='40288', gender='M', salary=4000), - Row(firstname='Robert', middlename='', lastname='Williams', id='42114', gender='M', salary=4000), - Row(firstname='Maria', middlename='Anne', lastname='Jones', id='39192', gender='F', salary=4000), - Row(firstname='Jen', middlename='Mary', lastname='Brown', id='', gender='F', salary=-1), + Row(firstname="James", middlename="", lastname="Smith", id="36636", gender="M", salary=3000), + Row(firstname="Michael", middlename="Rose", lastname="", id="40288", gender="M", salary=4000), + Row(firstname="Robert", middlename="", lastname="Williams", id="42114", gender="M", salary=4000), + Row(firstname="Maria", middlename="Anne", lastname="Jones", id="39192", gender="F", salary=4000), + Row(firstname="Jen", middlename="Mary", lastname="Brown", id="", gender="F", salary=-1), ] def test_df_nested_struct(self, spark): @@ -236,18 +236,18 @@ def test_df_nested_struct(self, spark): structureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -255,24 +255,24 @@ def test_df_nested_struct(self, spark): res = df2.collect() expected_res = [ Row( - name={'firstname': 'James', 'middlename': '', 'lastname': 'Smith'}, id='36636', gender='M', salary=3100 + name={"firstname": "James", "middlename": "", "lastname": "Smith"}, id="36636", gender="M", salary=3100 ), Row( - name={'firstname': 'Michael', 'middlename': 'Rose', 'lastname': ''}, id='40288', gender='M', salary=4300 + name={"firstname": "Michael", "middlename": "Rose", "lastname": ""}, id="40288", gender="M", salary=4300 ), Row( - name={'firstname': 'Robert', 'middlename': '', 'lastname': 'Williams'}, - id='42114', - gender='M', + name={"firstname": "Robert", "middlename": "", "lastname": "Williams"}, + id="42114", + gender="M", salary=1400, ), Row( - name={'firstname': 'Maria', 'middlename': 'Anne', 'lastname': 'Jones'}, - id='39192', - gender='F', + name={"firstname": "Maria", "middlename": "Anne", "lastname": "Jones"}, + id="39192", + gender="F", salary=5500, ), - Row(name={'firstname': 'Jen', 'middlename': 'Mary', 'lastname': 'Brown'}, id='', gender='F', salary=-1), + Row(name={"firstname": "Jen", "middlename": "Mary", "lastname": "Brown"}, id="", gender="F", salary=-1), ] if USE_ACTUAL_SPARK: expected_res = [Row(name=Row(**r.name), id=r.id, gender=r.gender, salary=r.salary) for r in expected_res] @@ -281,19 +281,19 @@ def test_df_nested_struct(self, spark): assert schema == StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), True, ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -310,18 +310,18 @@ def test_df_columns(self, spark): structureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('id', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("id", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) @@ -339,7 +339,7 @@ def test_df_columns(self, spark): ), ).drop("id", "gender", "salary") - assert 'OtherInfo' in updatedDF.columns + assert "OtherInfo" in updatedDF.columns def test_array_and_map_type(self, spark): """Array & Map""" @@ -347,17 +347,17 @@ def test_array_and_map_type(self, spark): arrayStructureSchema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('hobbies', ArrayType(StringType()), True), - StructField('properties', MapType(StringType(), StringType()), True), + StructField("hobbies", ArrayType(StringType()), True), + StructField("properties", MapType(StringType(), StringType()), True), ] ) diff --git a/tests/fast/spark/test_spark_dataframe_sort.py b/tests/fast/spark/test_spark_dataframe_sort.py index 20558197..db7dce4b 100644 --- a/tests/fast/spark/test_spark_dataframe_sort.py +++ b/tests/fast/spark/test_spark_dataframe_sort.py @@ -91,11 +91,11 @@ def test_sort_with_desc(self, spark): df = df.sort(desc("name")) res = df.collect() assert res == [ - Row(age=3, name='Dave'), - Row(age=56, name='Carol'), - Row(age=1, name='Ben'), - Row(age=3, name='Anna'), - Row(age=20, name='Alice'), + Row(age=3, name="Dave"), + Row(age=56, name="Carol"), + Row(age=1, name="Ben"), + Row(age=3, name="Anna"), + Row(age=20, name="Alice"), ] def test_sort_with_asc(self, spark): @@ -103,9 +103,9 @@ def test_sort_with_asc(self, spark): df = df.sort(asc("name")) res = df.collect() assert res == [ - Row(age=20, name='Alice'), - Row(age=3, name='Anna'), - Row(age=1, name='Ben'), - Row(age=56, name='Carol'), - Row(age=3, name='Dave'), + Row(age=20, name="Alice"), + Row(age=3, name="Anna"), + Row(age=1, name="Ben"), + Row(age=56, name="Carol"), + Row(age=3, name="Dave"), ] diff --git a/tests/fast/spark/test_spark_drop_duplicates.py b/tests/fast/spark/test_spark_drop_duplicates.py index 6dc7f573..563a5e76 100644 --- a/tests/fast/spark/test_spark_drop_duplicates.py +++ b/tests/fast/spark/test_spark_drop_duplicates.py @@ -34,15 +34,15 @@ def test_spark_drop_duplicates(self, method, spark): res = distinctDF.collect() # James | Sales had a duplicate, has been removed expected = [ - Row(employee_name='James', department='Sales', salary=3000), - Row(employee_name='Jeff', department='Marketing', salary=3000), - Row(employee_name='Jen', department='Finance', salary=3900), - Row(employee_name='Kumar', department='Marketing', salary=2000), - Row(employee_name='Maria', department='Finance', salary=3000), - Row(employee_name='Michael', department='Sales', salary=4600), - Row(employee_name='Robert', department='Sales', salary=4100), - Row(employee_name='Saif', department='Sales', salary=4100), - Row(employee_name='Scott', department='Finance', salary=3300), + Row(employee_name="James", department="Sales", salary=3000), + Row(employee_name="Jeff", department="Marketing", salary=3000), + Row(employee_name="Jen", department="Finance", salary=3900), + Row(employee_name="Kumar", department="Marketing", salary=2000), + Row(employee_name="Maria", department="Finance", salary=3000), + Row(employee_name="Michael", department="Sales", salary=4600), + Row(employee_name="Robert", department="Sales", salary=4100), + Row(employee_name="Saif", department="Sales", salary=4100), + Row(employee_name="Scott", department="Finance", salary=3300), ] assert res == expected @@ -52,14 +52,14 @@ def test_spark_drop_duplicates(self, method, spark): assert res2 == res expected_subset = [ - Row(department='Finance', salary=3000), - Row(department='Finance', salary=3300), - Row(department='Finance', salary=3900), - Row(department='Marketing', salary=2000), - Row(department='Marketing', salary=3000), - Row(epartment='Sales', salary=3000), - Row(department='Sales', salary=4100), - Row(department='Sales', salary=4600), + Row(department="Finance", salary=3000), + Row(department="Finance", salary=3300), + Row(department="Finance", salary=3900), + Row(department="Marketing", salary=2000), + Row(department="Marketing", salary=3000), + Row(epartment="Sales", salary=3000), + Row(department="Sales", salary=4100), + Row(department="Sales", salary=4600), ] dropDisDF = getattr(df, method)(["department", "salary"]).sort("department", "salary") diff --git a/tests/fast/spark/test_spark_except.py b/tests/fast/spark/test_spark_except.py index 434ac613..7c28cc29 100644 --- a/tests/fast/spark/test_spark_except.py +++ b/tests/fast/spark/test_spark_except.py @@ -19,7 +19,6 @@ def df2(spark): class TestDataFrameIntersect: def test_exceptAll(self, spark, df, df2): - df3 = df.exceptAll(df2).sort(*df.columns) res = df3.collect() diff --git a/tests/fast/spark/test_spark_filter.py b/tests/fast/spark/test_spark_filter.py index fb6f0b1a..a4733a44 100644 --- a/tests/fast/spark/test_spark_filter.py +++ b/tests/fast/spark/test_spark_filter.py @@ -35,18 +35,18 @@ def test_dataframe_filter(self, spark): schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('languages', ArrayType(StringType()), True), - StructField('state', StringType(), True), - StructField('gender', StringType(), True), + StructField("languages", ArrayType(StringType()), True), + StructField("state", StringType(), True), + StructField("gender", StringType(), True), ] ) @@ -57,51 +57,51 @@ def test_dataframe_filter(self, spark): # Using equals condition df2 = df.filter(df.state == "OH") res = df2.collect() - assert res[0].state == 'OH' + assert res[0].state == "OH" # not equals condition df2 = df.filter(df.state != "OH") df2 = df.filter(~(df.state == "OH")) res = df2.collect() for item in res: - assert item.state == 'NY' or item.state == 'CA' + assert item.state == "NY" or item.state == "CA" df2 = df.filter(col("state") == "OH") res = df2.collect() - assert res[0].state == 'OH' + assert res[0].state == "OH" df2 = df.filter("gender == 'M'") res = df2.collect() - assert res[0].gender == 'M' + assert res[0].gender == "M" df2 = df.filter("gender != 'M'") res = df2.collect() - assert res[0].gender == 'F' + assert res[0].gender == "F" df2 = df.filter("gender <> 'M'") res = df2.collect() - assert res[0].gender == 'F' + assert res[0].gender == "F" # Filter multiple condition df2 = df.filter((df.state == "OH") & (df.gender == "M")) res = df2.collect() assert len(res) == 2 for item in res: - assert item.gender == 'M' and item.state == 'OH' + assert item.gender == "M" and item.state == "OH" # Filter IS IN List values li = ["OH", "NY"] df2 = df.filter(df.state.isin(li)) res = df2.collect() for item in res: - assert item.state == 'OH' or item.state == 'NY' + assert item.state == "OH" or item.state == "NY" # Filter NOT IS IN List values # These show all records with NY (NY is not part of the list) df2 = df.filter(~df.state.isin(li)) res = df2.collect() for item in res: - assert item.state != 'OH' and item.state != 'NY' + assert item.state != "OH" and item.state != "NY" df2 = df.filter(df.state.isin(li) == False) res2 = df2.collect() @@ -111,19 +111,19 @@ def test_dataframe_filter(self, spark): df2 = df.filter(df.state.startswith("N")) res = df2.collect() for item in res: - assert item.state == 'NY' + assert item.state == "NY" # using endswith df2 = df.filter(df.state.endswith("H")) res = df2.collect() for item in res: - assert item.state == 'OH' + assert item.state == "OH" # contains df2 = df.filter(df.state.contains("H")) res = df2.collect() for item in res: - assert item.state == 'OH' + assert item.state == "OH" data2 = [(2, "Michael Rose"), (3, "Robert Williams"), (4, "Rames Rose"), (5, "Rames rose")] df2 = spark.createDataFrame(data=data2, schema=["id", "name"]) @@ -131,56 +131,56 @@ def test_dataframe_filter(self, spark): # like - SQL LIKE pattern df3 = df2.filter(df2.name.like("%rose%")) res = df3.collect() - assert res == [Row(id=5, name='Rames rose')] + assert res == [Row(id=5, name="Rames rose")] # rlike - SQL RLIKE pattern (LIKE with Regex) # This check case insensitive df3 = df2.filter(df2.name.rlike("(?i)^*rose$")) res = df3.collect() - assert res == [Row(id=2, name='Michael Rose'), Row(id=4, name='Rames Rose'), Row(id=5, name='Rames rose')] + assert res == [Row(id=2, name="Michael Rose"), Row(id=4, name="Rames Rose"), Row(id=5, name="Rames rose")] df2 = df.filter(array_contains(df.languages, "Java")) res = df2.collect() - james_name = {'firstname': 'James', 'middlename': '', 'lastname': 'Smith'} - anna_name = {'firstname': 'Anna', 'middlename': 'Rose', 'lastname': ''} + james_name = {"firstname": "James", "middlename": "", "lastname": "Smith"} + anna_name = {"firstname": "Anna", "middlename": "Rose", "lastname": ""} if USE_ACTUAL_SPARK: james_name = Row(**james_name) anna_name = Row(**anna_name) assert res == [ Row( name=james_name, - languages=['Java', 'Scala', 'C++'], - state='OH', - gender='M', + languages=["Java", "Scala", "C++"], + state="OH", + gender="M", ), Row( name=anna_name, - languages=['Spark', 'Java', 'C++'], - state='CA', - gender='F', + languages=["Spark", "Java", "C++"], + state="CA", + gender="F", ), ] df2 = df.filter(df.name.lastname == "Williams") res = df2.collect() - julia_name = {'firstname': 'Julia', 'middlename': '', 'lastname': 'Williams'} - mike_name = {'firstname': 'Mike', 'middlename': 'Mary', 'lastname': 'Williams'} + julia_name = {"firstname": "Julia", "middlename": "", "lastname": "Williams"} + mike_name = {"firstname": "Mike", "middlename": "Mary", "lastname": "Williams"} if USE_ACTUAL_SPARK: julia_name = Row(**julia_name) mike_name = Row(**mike_name) assert res == [ Row( name=julia_name, - languages=['CSharp', 'VB'], - state='OH', - gender='F', + languages=["CSharp", "VB"], + state="OH", + gender="F", ), Row( name=mike_name, - languages=['Python', 'VB'], - state='OH', - gender='M', + languages=["Python", "VB"], + state="OH", + gender="M", ), ] diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py index f83e0ef2..5ecba132 100644 --- a/tests/fast/spark/test_spark_functions_array.py +++ b/tests/fast/spark/test_spark_functions_array.py @@ -75,7 +75,7 @@ def test_array_min(self, spark): ] def test_get(self, spark): - df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) res = df.select(sf.get(df.data, 1).alias("r")).collect() assert res == [Row(r="b")] @@ -87,25 +87,25 @@ def test_get(self, spark): assert res == [Row(r=None)] res = df.select(sf.get(df.data, "index").alias("r")).collect() - assert res == [Row(r='b')] + assert res == [Row(r="b")] res = df.select(sf.get(df.data, sf.col("index") - 1).alias("r")).collect() - assert res == [Row(r='a')] + assert res == [Row(r="a")] def test_flatten(self, spark): - df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) res = df.select(sf.flatten(df.data).alias("r")).collect() assert res == [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)] def test_array_compact(self, spark): - df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data']) + df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) res = df.select(sf.array_compact(df.data).alias("v")).collect() assert [Row(v=[1, 2, 3]), Row(v=[4, 5, 4])] def test_array_remove(self, spark): - df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) res = df.select(sf.array_remove(df.data, 1).alias("v")).collect() assert res == [Row(v=[2, 3]), Row(v=[])] @@ -126,101 +126,101 @@ def test_array_append(self, spark): df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")], ["c1", "c2"]) res = df.select(sf.array_append(df.c1, df.c2).alias("r")).collect() - assert res == [Row(r=['b', 'a', 'c', 'c'])] + assert res == [Row(r=["b", "a", "c", "c"])] - res = df.select(sf.array_append(df.c1, 'x')).collect() - assert res == [Row(r=['b', 'a', 'c', 'x'])] + res = df.select(sf.array_append(df.c1, "x")).collect() + assert res == [Row(r=["b", "a", "c", "x"])] def test_array_insert(self, spark): df = spark.createDataFrame( - [(['a', 'b', 'c'], 2, 'd'), (['a', 'b', 'c', 'e'], 2, 'd'), (['c', 'b', 'a'], -2, 'd')], - ['data', 'pos', 'val'], + [(["a", "b", "c"], 2, "d"), (["a", "b", "c", "e"], 2, "d"), (["c", "b", "a"], -2, "d")], + ["data", "pos", "val"], ) - res = df.select(sf.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect() + res = df.select(sf.array_insert(df.data, df.pos.cast("integer"), df.val).alias("data")).collect() assert res == [ - Row(data=['a', 'd', 'b', 'c']), - Row(data=['a', 'd', 'b', 'c', 'e']), - Row(data=['c', 'b', 'd', 'a']), + Row(data=["a", "d", "b", "c"]), + Row(data=["a", "d", "b", "c", "e"]), + Row(data=["c", "b", "d", "a"]), ] - res = df.select(sf.array_insert(df.data, 5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, 5, "hello").alias("data")).collect() assert res == [ - Row(data=['a', 'b', 'c', None, 'hello']), - Row(data=['a', 'b', 'c', 'e', 'hello']), - Row(data=['c', 'b', 'a', None, 'hello']), + Row(data=["a", "b", "c", None, "hello"]), + Row(data=["a", "b", "c", "e", "hello"]), + Row(data=["c", "b", "a", None, "hello"]), ] - res = df.select(sf.array_insert(df.data, -5, 'hello').alias('data')).collect() + res = df.select(sf.array_insert(df.data, -5, "hello").alias("data")).collect() assert res == [ - Row(data=['hello', None, 'a', 'b', 'c']), - Row(data=['hello', 'a', 'b', 'c', 'e']), - Row(data=['hello', None, 'c', 'b', 'a']), + Row(data=["hello", None, "a", "b", "c"]), + Row(data=["hello", "a", "b", "c", "e"]), + Row(data=["hello", None, "c", "b", "a"]), ] def test_slice(self, spark): - df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x']) + df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) res = df.select(sf.slice(df.x, 2, 2).alias("sliced")).collect() assert res == [Row(sliced=[2, 3]), Row(sliced=[5])] def test_sort_array(self, spark): - df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) + df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) - res = df.select(sf.sort_array(df.data).alias('r')).collect() + res = df.select(sf.sort_array(df.data).alias("r")).collect() assert res == [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] - res = df.select(sf.sort_array(df.data, asc=False).alias('r')).collect() + res = df.select(sf.sort_array(df.data, asc=False).alias("r")).collect() assert res == [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] @pytest.mark.parametrize(("null_replacement", "expected_joined_2"), [(None, "a"), ("replaced", "a,replaced")]) def test_array_join(self, spark, null_replacement, expected_joined_2): - df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) res = df.select(sf.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect() - assert res == [Row(joined='a,b,c'), Row(joined=expected_joined_2)] + assert res == [Row(joined="a,b,c"), Row(joined=expected_joined_2)] def test_array_position(self, spark): - df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) res = df.select(sf.array_position(df.data, "a").alias("pos")).collect() assert res == [Row(pos=3), Row(pos=0)] def test_array_preprend(self, spark): - df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data']) + df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) res = df.select(sf.array_prepend(df.data, 1).alias("pre")).collect() assert res == [Row(pre=[1, 2, 3, 4]), Row(pre=[1])] def test_array_repeat(self, spark): - df = spark.createDataFrame([('ab',)], ['data']) + df = spark.createDataFrame([("ab",)], ["data"]) - res = df.select(sf.array_repeat(df.data, 3).alias('r')).collect() - assert res == [Row(r=['ab', 'ab', 'ab'])] + res = df.select(sf.array_repeat(df.data, 3).alias("r")).collect() + assert res == [Row(r=["ab", "ab", "ab"])] def test_array_size(self, spark): - df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data']) + df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) - res = df.select(sf.array_size(df.data).alias('r')).collect() + res = df.select(sf.array_size(df.data).alias("r")).collect() assert res == [Row(r=3), Row(r=None)] def test_array_sort(self, spark): - df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data']) + df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) - res = df.select(sf.array_sort(df.data).alias('r')).collect() + res = df.select(sf.array_sort(df.data).alias("r")).collect() assert res == [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] def test_arrays_overlap(self, spark): df = spark.createDataFrame( - [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ['x', 'y'] + [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ["x", "y"] ) res = df.select(sf.arrays_overlap(df.x, df.y).alias("overlap")).collect() assert res == [Row(overlap=True), Row(overlap=False), Row(overlap=None), Row(overlap=None)] def test_arrays_zip(self, spark): - df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3']) + df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ["vals1", "vals2", "vals3"]) - res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect() + res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias("zipped")).collect() # FIXME: The structure of the results should be the same if USE_ACTUAL_SPARK: assert res == [ diff --git a/tests/fast/spark/test_spark_functions_base64.py b/tests/fast/spark/test_spark_functions_base64.py index 734a5275..5a179481 100644 --- a/tests/fast/spark/test_spark_functions_base64.py +++ b/tests/fast/spark/test_spark_functions_base64.py @@ -40,4 +40,4 @@ def test_unbase64(self, spark): .select("decoded_value") .collect() ) - assert res[0].decoded_value == b'quack' + assert res[0].decoded_value == b"quack" diff --git a/tests/fast/spark/test_spark_functions_date.py b/tests/fast/spark/test_spark_functions_date.py index 2a51d9b8..a298c0ff 100644 --- a/tests/fast/spark/test_spark_functions_date.py +++ b/tests/fast/spark/test_spark_functions_date.py @@ -145,43 +145,43 @@ def test_second(self, spark): assert result[0].second_num == 45 def test_unix_date(self, spark): - df = spark.createDataFrame([('1970-01-02',)], ['t']) - res = df.select(F.unix_date(df.t.cast("date")).alias('n')).collect() + df = spark.createDataFrame([("1970-01-02",)], ["t"]) + res = df.select(F.unix_date(df.t.cast("date")).alias("n")).collect() assert res == [Row(n=1)] def test_unix_micros(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_micros(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_micros(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200000000)] def test_unix_millis(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_millis(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_millis(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200000)] def test_unix_seconds(self, spark): - df = spark.createDataFrame([('2015-07-22 10:00:00+00:00',)], ['t']) - res = df.select(F.unix_seconds(df.t.cast("timestamp")).alias('n')).collect() + df = spark.createDataFrame([("2015-07-22 10:00:00+00:00",)], ["t"]) + res = df.select(F.unix_seconds(df.t.cast("timestamp")).alias("n")).collect() assert res == [Row(n=1437559200)] def test_weekday(self, spark): - df = spark.createDataFrame([('2015-04-08',)], ['dt']) - res = df.select(F.weekday(df.dt.cast("date")).alias('day')).collect() + df = spark.createDataFrame([("2015-04-08",)], ["dt"]) + res = df.select(F.weekday(df.dt.cast("date")).alias("day")).collect() assert res == [Row(day=2)] def test_to_date(self, spark): - df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - res = df.select(F.to_date(df.t).alias('date')).collect() + df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + res = df.select(F.to_date(df.t).alias("date")).collect() assert res == [Row(date=date(1997, 2, 28))] def test_to_timestamp(self, spark): - df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - res = df.select(F.to_timestamp(df.t).alias('dt')).collect() + df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) + res = df.select(F.to_timestamp(df.t).alias("dt")).collect() assert res == [Row(dt=datetime(1997, 2, 28, 10, 30))] def test_to_timestamp_ltz(self, spark): df = spark.createDataFrame([("2016-12-31",)], ["e"]) - res = df.select(F.to_timestamp_ltz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ltz(df.e).alias("r")).collect() assert res == [Row(r=datetime(2016, 12, 31, 0, 0))] @@ -194,15 +194,15 @@ def test_to_timestamp_ntz(self, spark): if USE_ACTUAL_SPARK: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - res = df.select(F.to_timestamp_ntz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ntz(df.e).alias("r")).collect() else: - res = df.select(F.to_timestamp_ntz(df.e).alias('r')).collect() + res = df.select(F.to_timestamp_ntz(df.e).alias("r")).collect() assert res == [Row(r=datetime(2016, 4, 8, 0, 0))] def test_last_day(self, spark): - df = spark.createDataFrame([('1997-02-10',)], ['d']) + df = spark.createDataFrame([("1997-02-10",)], ["d"]) - res = df.select(F.last_day(df.d.cast("date")).alias('date')).collect() + res = df.select(F.last_day(df.d.cast("date")).alias("date")).collect() assert res == [Row(date=date(1997, 2, 28))] def test_add_months(self, spark): @@ -219,12 +219,12 @@ def test_add_months(self, spark): assert result[0].with_col == date(2024, 7, 12) def test_date_diff(self, spark): - df = spark.createDataFrame([('2015-04-08', '2015-05-10')], ["d1", "d2"]) + df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) - result_data = df.select(F.date_diff(col("d2").cast('DATE'), col("d1").cast('DATE')).alias("diff")).collect() + result_data = df.select(F.date_diff(col("d2").cast("DATE"), col("d1").cast("DATE")).alias("diff")).collect() assert result_data[0]["diff"] == -32 - result_data = df.select(F.date_diff(col("d1").cast('DATE'), col("d2").cast('DATE')).alias("diff")).collect() + result_data = df.select(F.date_diff(col("d1").cast("DATE"), col("d2").cast("DATE")).alias("diff")).collect() assert result_data[0]["diff"] == 32 def test_try_to_timestamp(self, spark): @@ -239,4 +239,4 @@ def test_try_to_timestamp_with_format(self, spark): res = df.select(F.try_to_timestamp(df.t, format=F.lit("%Y-%m-%d %H:%M:%S")).alias("dt")).collect() assert res[0].dt == datetime(1997, 2, 28, 10, 30) assert res[1].dt is None - assert res[2].dt is None \ No newline at end of file + assert res[2].dt is None diff --git a/tests/fast/spark/test_spark_functions_hex.py b/tests/fast/spark/test_spark_functions_hex.py index e5cbf12f..7d5f3c6a 100644 --- a/tests/fast/spark/test_spark_functions_hex.py +++ b/tests/fast/spark/test_spark_functions_hex.py @@ -20,7 +20,7 @@ def test_hex_string_col(self, spark): def test_hex_binary_col(self, spark): data = [ - (b'quack',), + (b"quack",), ] res = ( spark.createDataFrame(data, ["firstColumn"]) @@ -65,4 +65,4 @@ def test_unhex(self, spark): .select("unhex_value") .collect() ) - assert res[0].unhex_value == b'quack' + assert res[0].unhex_value == b"quack" diff --git a/tests/fast/spark/test_spark_functions_miscellaneous.py b/tests/fast/spark/test_spark_functions_miscellaneous.py index 87b6b776..f6af47fe 100644 --- a/tests/fast/spark/test_spark_functions_miscellaneous.py +++ b/tests/fast/spark/test_spark_functions_miscellaneous.py @@ -30,38 +30,38 @@ def test_call_function(self, spark): ] def test_octet_length(self, spark): - df = spark.createDataFrame([('cat',)], ['c1']) - res = df.select(F.octet_length('c1').alias("o")).collect() + df = spark.createDataFrame([("cat",)], ["c1"]) + res = df.select(F.octet_length("c1").alias("o")).collect() assert res == [Row(o=3)] def test_positive(self, spark): - df = spark.createDataFrame([(-1,), (0,), (1,)], ['v']) + df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) res = df.select(F.positive("v").alias("p")).collect() assert res == [Row(p=-1), Row(p=0), Row(p=1)] def test_sequence(self, spark): - df1 = spark.createDataFrame([(-2, 2)], ('C1', 'C2')) - res = df1.select(F.sequence('C1', 'C2').alias('r')).collect() + df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) + res = df1.select(F.sequence("C1", "C2").alias("r")).collect() assert res == [Row(r=[-2, -1, 0, 1, 2])] - df2 = spark.createDataFrame([(4, -4, -2)], ('C1', 'C2', 'C3')) - res = df2.select(F.sequence('C1', 'C2', 'C3').alias('r')).collect() + df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) + res = df2.select(F.sequence("C1", "C2", "C3").alias("r")).collect() assert res == [Row(r=[4, 2, 0, -2, -4])] def test_like(self, spark): - df = spark.createDataFrame([("Spark", "_park")], ['a', 'b']) - res = df.select(F.like(df.a, df.b).alias('r')).collect() + df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) + res = df.select(F.like(df.a, df.b).alias("r")).collect() assert res == [Row(r=True)] - df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ['a', 'b']) - res = df.select(F.like(df.a, df.b, F.lit('/')).alias('r')).collect() + df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"]) + res = df.select(F.like(df.a, df.b, F.lit("/")).alias("r")).collect() assert res == [Row(r=True)] def test_ilike(self, spark): - df = spark.createDataFrame([("Spark", "spark")], ['a', 'b']) - res = df.select(F.ilike(df.a, df.b).alias('r')).collect() + df = spark.createDataFrame([("Spark", "spark")], ["a", "b"]) + res = df.select(F.ilike(df.a, df.b).alias("r")).collect() assert res == [Row(r=True)] - df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ['a', 'b']) - res = df.select(F.ilike(df.a, df.b, F.lit('/')).alias('r')).collect() + df = spark.createDataFrame([("%SystemDrive%/Users/John", "/%SystemDrive/%//Users%")], ["a", "b"]) + res = df.select(F.ilike(df.a, df.b, F.lit("/")).alias("r")).collect() assert res == [Row(r=True)] diff --git a/tests/fast/spark/test_spark_functions_null.py b/tests/fast/spark/test_spark_functions_null.py index 3f5ee31b..230634dc 100644 --- a/tests/fast/spark/test_spark_functions_null.py +++ b/tests/fast/spark/test_spark_functions_null.py @@ -62,7 +62,7 @@ def test_nvl2(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.nvl2(df.a, df.b, df.c).alias('r')).collect() + res = df.select(F.nvl2(df.a, df.b, df.c).alias("r")).collect() assert res == [Row(r=6), Row(r=9)] def test_ifnull(self, spark): @@ -92,7 +92,7 @@ def test_nullif(self, spark): ], ["a", "b"], ) - res = df.select(F.nullif(df.a, df.b).alias('r')).collect() + res = df.select(F.nullif(df.a, df.b).alias("r")).collect() assert res == [Row(r=None), Row(r=1)] def test_isnull(self, spark): @@ -116,4 +116,4 @@ def test_isnotnull(self, spark): def test_equal_null(self, spark): df = spark.createDataFrame([(1, 1), (None, 2), (None, None)], ("a", "b")) res = df.select(F.equal_null("a", F.col("b")).alias("r")).collect() - assert res == [Row(r=True), Row(r=False), Row(r=True)] \ No newline at end of file + assert res == [Row(r=True), Row(r=False), Row(r=True)] diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 9c4bafb9..3548d439 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -301,7 +301,7 @@ def test_corr(self, spark): # Have to use a groupby to test this as agg is not yet implemented without df = spark.createDataFrame(zip(a, b, ["group1"] * N), ["a", "b", "g"]) - res = df.groupBy("g").agg(sf.corr("a", "b").alias('c')).collect() + res = df.groupBy("g").agg(sf.corr("a", "b").alias("c")).collect() assert pytest.approx(res[0].c) == 1 def test_cot(self, spark): @@ -330,7 +330,7 @@ def test_pow(self, spark): def test_random(self, spark): df = spark.range(0, 2, 1) - res = df.withColumn('rand', sf.rand()).collect() + res = df.withColumn("rand", sf.rand()).collect() assert isinstance(res[0].rand, float) assert res[0].rand >= 0 and res[0].rand < 1 @@ -355,4 +355,4 @@ def test_negative(self, spark): res = df.collect() assert res[0].value == 0 assert res[1].value == -2 - assert res[2].value == -3 \ No newline at end of file + assert res[2].value == -3 diff --git a/tests/fast/spark/test_spark_functions_string.py b/tests/fast/spark/test_spark_functions_string.py index e90cca11..b8d7f483 100644 --- a/tests/fast/spark/test_spark_functions_string.py +++ b/tests/fast/spark/test_spark_functions_string.py @@ -152,47 +152,47 @@ def test_btrim(self, spark): "SL", ) ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.btrim(df.a, df.b).alias('r')).collect() - assert res == [Row(r='parkSQ')] + res = df.select(F.btrim(df.a, df.b).alias("r")).collect() + assert res == [Row(r="parkSQ")] - df = spark.createDataFrame([(" SparkSQL ",)], ['a']) - res = df.select(F.btrim(df.a).alias('r')).collect() - assert res == [Row(r='SparkSQL')] + df = spark.createDataFrame([(" SparkSQL ",)], ["a"]) + res = df.select(F.btrim(df.a).alias("r")).collect() + assert res == [Row(r="SparkSQL")] def test_char(self, spark): df = spark.createDataFrame( [(65,), (65 + 256,), (66 + 256,)], [ - 'a', + "a", ], ) - res = df.select(F.char(df.a).alias('ch')).collect() - assert res == [Row(ch='A'), Row(ch='A'), Row(ch='B')] + res = df.select(F.char(df.a).alias("ch")).collect() + assert res == [Row(ch="A"), Row(ch="A"), Row(ch="B")] def test_encode(self, spark): - df = spark.createDataFrame([('abcd',)], ['c']) + df = spark.createDataFrame([("abcd",)], ["c"]) res = df.select(F.encode("c", "UTF-8").alias("encoded")).collect() # FIXME: Should return the same type if USE_ACTUAL_SPARK: - assert res == [Row(encoded=bytearray(b'abcd'))] + assert res == [Row(encoded=bytearray(b"abcd"))] else: - assert res == [Row(encoded=b'abcd')] + assert res == [Row(encoded=b"abcd")] def test_split(self, spark): df = spark.createDataFrame( - [('oneAtwoBthreeC',)], + [("oneAtwoBthreeC",)], [ - 's', + "s", ], ) - res = df.select(F.split(df.s, '[ABC]').alias('s')).collect() - assert res == [Row(s=['one', 'two', 'three', ''])] + res = df.select(F.split(df.s, "[ABC]").alias("s")).collect() + assert res == [Row(s=["one", "two", "three", ""])] def test_split_part(self, spark): df = spark.createDataFrame( @@ -206,8 +206,8 @@ def test_split_part(self, spark): ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='13')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="13")] # If any input is null, should return null df = spark.createDataFrame( @@ -225,8 +225,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r=None), Row(r='11')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r=None), Row(r="11")] # If partNum is out of range, should return an empty string df = spark.createDataFrame( @@ -239,8 +239,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="")] # If partNum is negative, parts are counted backwards df = spark.createDataFrame( @@ -253,8 +253,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='13')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="13")] # If the delimiter is an empty string, the return should be empty df = spark.createDataFrame( @@ -267,8 +267,8 @@ def test_split_part(self, spark): ], ["a", "b", "c"], ) - res = df.select(F.split_part(df.a, df.b, df.c).alias('r')).collect() - assert res == [Row(r='')] + res = df.select(F.split_part(df.a, df.b, df.c).alias("r")).collect() + assert res == [Row(r="")] def test_substr(self, spark): df = spark.createDataFrame( @@ -282,7 +282,7 @@ def test_substr(self, spark): ["a", "b", "c"], ) res = df.select(F.substr("a", "b", "c").alias("s")).collect() - assert res == [Row(s='k')] + assert res == [Row(s="k")] df = spark.createDataFrame( [ @@ -295,21 +295,21 @@ def test_substr(self, spark): ["a", "b", "c"], ) res = df.select(F.substr("a", "b").alias("s")).collect() - assert res == [Row(s='k SQL')] + assert res == [Row(s="k SQL")] def test_find_in_set(self, spark): string_array = "abc,b,ab,c,def" - df = spark.createDataFrame([("ab", string_array), ("b,c", string_array), ("z", string_array)], ['a', 'b']) + df = spark.createDataFrame([("ab", string_array), ("b,c", string_array), ("z", string_array)], ["a", "b"]) - res = df.select(F.find_in_set(df.a, df.b).alias('r')).collect() + res = df.select(F.find_in_set(df.a, df.b).alias("r")).collect() assert res == [Row(r=3), Row(r=0), Row(r=0)] def test_initcap(self, spark): - df = spark.createDataFrame([('ab cd',)], ['a']) + df = spark.createDataFrame([("ab cd",)], ["a"]) - res = df.select(F.initcap("a").alias('v')).collect() - assert res == [Row(v='Ab Cd')] + res = df.select(F.initcap("a").alias("v")).collect() + assert res == [Row(v="Ab Cd")] def test_left(self, spark): df = spark.createDataFrame( @@ -327,11 +327,11 @@ def test_left(self, spark): -3, ), ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.left(df.a, df.b).alias('r')).collect() - assert res == [Row(r='Spa'), Row(r=''), Row(r='')] + res = df.select(F.left(df.a, df.b).alias("r")).collect() + assert res == [Row(r="Spa"), Row(r=""), Row(r="")] def test_right(self, spark): df = spark.createDataFrame( @@ -349,39 +349,39 @@ def test_right(self, spark): -3, ), ], - ['a', 'b'], + ["a", "b"], ) - res = df.select(F.right(df.a, df.b).alias('r')).collect() - assert res == [Row(r='SQL'), Row(r=''), Row(r='')] + res = df.select(F.right(df.a, df.b).alias("r")).collect() + assert res == [Row(r="SQL"), Row(r=""), Row(r="")] def test_levenshtein(self, spark): - df = spark.createDataFrame([("kitten", "sitting"), ("saturdays", "sunday")], ['a', 'b']) + df = spark.createDataFrame([("kitten", "sitting"), ("saturdays", "sunday")], ["a", "b"]) - res = df.select(F.levenshtein(df.a, df.b).alias('r'), F.levenshtein(df.a, df.b, 3).alias('r_th')).collect() + res = df.select(F.levenshtein(df.a, df.b).alias("r"), F.levenshtein(df.a, df.b, 3).alias("r_th")).collect() assert res == [Row(r=3, r_th=3), Row(r=4, r_th=-1)] def test_lpad(self, spark): df = spark.createDataFrame( - [('abcd',)], + [("abcd",)], [ - 's', + "s", ], ) - res = df.select(F.lpad(df.s, 6, '#').alias('s')).collect() - assert res == [Row(s='##abcd')] + res = df.select(F.lpad(df.s, 6, "#").alias("s")).collect() + assert res == [Row(s="##abcd")] def test_rpad(self, spark): df = spark.createDataFrame( - [('abcd',)], + [("abcd",)], [ - 's', + "s", ], ) - res = df.select(F.rpad(df.s, 6, '#').alias('s')).collect() - assert res == [Row(s='abcd##')] + res = df.select(F.rpad(df.s, 6, "#").alias("s")).collect() + assert res == [Row(s="abcd##")] def test_printf(self, spark): df = spark.createDataFrame( @@ -395,79 +395,79 @@ def test_printf(self, spark): ["a", "b", "c"], ) res = df.select(F.printf("a", "b", "c").alias("r")).collect() - assert res == [Row(r='aa123cc')] + assert res == [Row(r="aa123cc")] @pytest.mark.parametrize("regexp_func", [F.regexp, F.regexp_like]) def test_regexp_and_regexp_like(self, spark, regexp_func): df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.lit(r'(\d+)')).alias("m")).collect() + res = df.select(regexp_func("str", F.lit(r"(\d+)")).alias("m")).collect() assert res[0].m is True df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.lit(r'\d{2}b')).alias("m")).collect() + res = df.select(regexp_func("str", F.lit(r"\d{2}b")).alias("m")).collect() assert res[0].m is False df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - res = df.select(regexp_func('str', F.col("regexp")).alias("m")).collect() + res = df.select(regexp_func("str", F.col("regexp")).alias("m")).collect() assert res[0].m is True def test_regexp_count(self, spark): df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - res = df.select(F.regexp_count('str', F.lit(r'\d+')).alias('d')).collect() + res = df.select(F.regexp_count("str", F.lit(r"\d+")).alias("d")).collect() assert res == [Row(d=3)] - res = df.select(F.regexp_count('str', F.lit(r'mmm')).alias('d')).collect() + res = df.select(F.regexp_count("str", F.lit(r"mmm")).alias("d")).collect() assert res == [Row(d=0)] - res = df.select(F.regexp_count("str", F.col("regexp")).alias('d')).collect() + res = df.select(F.regexp_count("str", F.col("regexp")).alias("d")).collect() assert res == [Row(d=3)] def test_regexp_extract(self, spark): - df = spark.createDataFrame([('100-200',)], ['str']) - res = df.select(F.regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect() - assert res == [Row(d='100')] + df = spark.createDataFrame([("100-200",)], ["str"]) + res = df.select(F.regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() + assert res == [Row(d="100")] - df = spark.createDataFrame([('foo',)], ['str']) - res = df.select(F.regexp_extract('str', r'(\d+)', 1).alias('d')).collect() - assert res == [Row(d='')] + df = spark.createDataFrame([("foo",)], ["str"]) + res = df.select(F.regexp_extract("str", r"(\d+)", 1).alias("d")).collect() + assert res == [Row(d="")] - df = spark.createDataFrame([('aaaac',)], ['str']) - res = df.select(F.regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() - assert res == [Row(d='')] + df = spark.createDataFrame([("aaaac",)], ["str"]) + res = df.select(F.regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() + assert res == [Row(d="")] def test_regexp_extract_all(self, spark): df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)')).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)")).alias("d")).collect() + assert res == [Row(d=["100", "300"])] - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)'), 1).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)"), 1).alias("d")).collect() + assert res == [Row(d=["100", "300"])] - res = df.select(F.regexp_extract_all('str', F.lit(r'(\d+)-(\d+)'), 2).alias('d')).collect() - assert res == [Row(d=['200', '400'])] + res = df.select(F.regexp_extract_all("str", F.lit(r"(\d+)-(\d+)"), 2).alias("d")).collect() + assert res == [Row(d=["200", "400"])] - res = df.select(F.regexp_extract_all('str', F.col("regexp")).alias('d')).collect() - assert res == [Row(d=['100', '300'])] + res = df.select(F.regexp_extract_all("str", F.col("regexp")).alias("d")).collect() + assert res == [Row(d=["100", "300"])] def test_regexp_substr(self, spark): df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) - res = df.select(F.regexp_substr('str', F.lit(r'\d+')).alias('d')).collect() - assert res == [Row(d='1')] + res = df.select(F.regexp_substr("str", F.lit(r"\d+")).alias("d")).collect() + assert res == [Row(d="1")] - res = df.select(F.regexp_substr('str', F.lit(r'mmm')).alias('d')).collect() + res = df.select(F.regexp_substr("str", F.lit(r"mmm")).alias("d")).collect() assert res == [Row(d=None)] - res = df.select(F.regexp_substr("str", F.col("regexp")).alias('d')).collect() - assert res == [Row(d='1')] + res = df.select(F.regexp_substr("str", F.col("regexp")).alias("d")).collect() + assert res == [Row(d="1")] def test_repeat(self, spark): df = spark.createDataFrame( - [('ab',)], + [("ab",)], [ - 's', + "s", ], ) - res = df.select(F.repeat(df.s, 3).alias('s')).collect() - assert res == [Row(s='ababab')] + res = df.select(F.repeat(df.s, 3).alias("s")).collect() + assert res == [Row(s="ababab")] def test_reverse(self, spark): data = [ diff --git a/tests/fast/spark/test_spark_group_by.py b/tests/fast/spark/test_spark_group_by.py index 8b66901f..9e8a8ea0 100644 --- a/tests/fast/spark/test_spark_group_by.py +++ b/tests/fast/spark/test_spark_group_by.py @@ -175,7 +175,7 @@ def test_group_by_empty(self, spark): ) res = df.groupBy("name").count().columns - assert res == ['name', 'count'] + assert res == ["name", "count"] def test_group_by_first_and_last(self, spark): df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) @@ -188,7 +188,7 @@ def test_group_by_first_and_last(self, spark): .collect() ) - assert res == [Row(name='Alice', first_age=None, last_age=2), Row(name='Bob', first_age=5, last_age=5)] + assert res == [Row(name="Alice", first_age=None, last_age=2), Row(name="Bob", first_age=5, last_age=5)] def test_standard_deviations(self, spark): df = spark.createDataFrame( @@ -265,7 +265,7 @@ def test_group_by_mean(self, spark): res = df.groupBy("course").agg(median("earnings").alias("m")).collect() - assert sorted(res, key=lambda x: x.course) == [Row(course='Java', m=22000), Row(course='dotNET', m=10000)] + assert sorted(res, key=lambda x: x.course) == [Row(course="Java", m=22000), Row(course="dotNET", m=10000)] def test_group_by_mode(self, spark): df = spark.createDataFrame( @@ -282,11 +282,11 @@ def test_group_by_mode(self, spark): res = df.groupby("course").agg(mode("year").alias("mode")).collect() - assert sorted(res, key=lambda x: x.course) == [Row(course='Java', mode=2012), Row(course='dotNET', mode=2012)] + assert sorted(res, key=lambda x: x.course) == [Row(course="Java", mode=2012), Row(course="dotNET", mode=2012)] def test_group_by_product(self, spark): - df = spark.range(1, 10).toDF('x').withColumn('mod3', col('x') % 3) - res = df.groupBy('mod3').agg(product('x').alias('product')).orderBy("mod3").collect() + df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) + res = df.groupBy("mod3").agg(product("x").alias("product")).orderBy("mod3").collect() assert res == [Row(mod3=0, product=162), Row(mod3=1, product=28), Row(mod3=2, product=80)] def test_group_by_skewness(self, spark): diff --git a/tests/fast/spark/test_spark_intersect.py b/tests/fast/spark/test_spark_intersect.py index 7fd97d40..ba0afbdd 100644 --- a/tests/fast/spark/test_spark_intersect.py +++ b/tests/fast/spark/test_spark_intersect.py @@ -19,7 +19,6 @@ def df2(spark): class TestDataFrameIntersect: def test_intersect(self, spark, df, df2): - df3 = df.intersect(df2).sort(df.C1) res = df3.collect() @@ -29,7 +28,6 @@ def test_intersect(self, spark, df, df2): ] def test_intersect_all(self, spark, df, df2): - df3 = df.intersectAll(df2).sort(df.C1) res = df3.collect() diff --git a/tests/fast/spark/test_spark_join.py b/tests/fast/spark/test_spark_join.py index c7ef9878..f67c54cb 100644 --- a/tests/fast/spark/test_spark_join.py +++ b/tests/fast/spark/test_spark_join.py @@ -49,63 +49,63 @@ def test_inner_join(self, dataframe_a, dataframe_b): expected = [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), ] assert sorted(res) == sorted(expected) - @pytest.mark.parametrize('how', ['outer', 'fullouter', 'full', 'full_outer']) + @pytest.mark.parametrize("how", ["outer", "fullouter", "full", "full_outer"]) def test_outer_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -114,66 +114,66 @@ def test_outer_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), Row( emp_id=6, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='50', - gender='', + year_joined="2010", + emp_dept_id="50", + gender="", salary=-1, dept_name=None, dept_id=None, @@ -186,14 +186,14 @@ def test_outer_join(self, dataframe_a, dataframe_b, how): emp_dept_id=None, gender=None, salary=None, - dept_name='Sales', + dept_name="Sales", dept_id=30, ), ], key=lambda x: x.emp_id or 0, ) - @pytest.mark.parametrize('how', ['right', 'rightouter', 'right_outer']) + @pytest.mark.parametrize("how", ["right", "rightouter", "right_outer"]) def test_right_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -202,57 +202,57 @@ def test_right_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, - dept_name='Marketing', + dept_name="Marketing", dept_id=20, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, - dept_name='Finance', + dept_name="Finance", dept_id=10, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, - dept_name='IT', + dept_name="IT", dept_id=40, ), Row( @@ -263,14 +263,14 @@ def test_right_join(self, dataframe_a, dataframe_b, how): emp_dept_id=None, gender=None, salary=None, - dept_name='Sales', + dept_name="Sales", dept_id=30, ), ], key=lambda x: x.emp_id or 0, ) - @pytest.mark.parametrize('how', ['semi', 'leftsemi', 'left_semi']) + @pytest.mark.parametrize("how", ["semi", "leftsemi", "left_semi"]) def test_semi_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) @@ -279,59 +279,59 @@ def test_semi_join(self, dataframe_a, dataframe_b, how): [ Row( emp_id=1, - name='Smith', + name="Smith", superior_emp_id=-1, - year_joined='2018', - emp_dept_id='10', - gender='M', + year_joined="2018", + emp_dept_id="10", + gender="M", salary=3000, ), Row( emp_id=2, - name='Rose', + name="Rose", superior_emp_id=1, - year_joined='2010', - emp_dept_id='20', - gender='M', + year_joined="2010", + emp_dept_id="20", + gender="M", salary=4000, ), Row( emp_id=3, - name='Williams', + name="Williams", superior_emp_id=1, - year_joined='2010', - emp_dept_id='10', - gender='M', + year_joined="2010", + emp_dept_id="10", + gender="M", salary=1000, ), Row( emp_id=4, - name='Jones', + name="Jones", superior_emp_id=2, - year_joined='2005', - emp_dept_id='10', - gender='F', + year_joined="2005", + emp_dept_id="10", + gender="F", salary=2000, ), Row( emp_id=5, - name='Brown', + name="Brown", superior_emp_id=2, - year_joined='2010', - emp_dept_id='40', - gender='', + year_joined="2010", + emp_dept_id="40", + gender="", salary=-1, ), ] ) - @pytest.mark.parametrize('how', ['anti', 'leftanti', 'left_anti']) + @pytest.mark.parametrize("how", ["anti", "leftanti", "left_anti"]) def test_anti_join(self, dataframe_a, dataframe_b, how): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, how) df = df.sort(*df.columns) res = df.collect() assert res == [ - Row(emp_id=6, name='Brown', superior_emp_id=2, year_joined='2010', emp_dept_id='50', gender='', salary=-1) + Row(emp_id=6, name="Brown", superior_emp_id=2, year_joined="2010", emp_dept_id="50", gender="", salary=-1) ] def test_self_join(self, dataframe_a): @@ -351,11 +351,11 @@ def test_self_join(self, dataframe_a): res = df.collect() assert sorted(res, key=lambda x: x.emp_id) == sorted( [ - Row(emp_id=2, name='Rose', superior_emp_id=1, superior_emp_name='Smith'), - Row(emp_id=3, name='Williams', superior_emp_id=1, superior_emp_name='Smith'), - Row(emp_id=4, name='Jones', superior_emp_id=2, superior_emp_name='Rose'), - Row(emp_id=5, name='Brown', superior_emp_id=2, superior_emp_name='Rose'), - Row(emp_id=6, name='Brown', superior_emp_id=2, superior_emp_name='Rose'), + Row(emp_id=2, name="Rose", superior_emp_id=1, superior_emp_name="Smith"), + Row(emp_id=3, name="Williams", superior_emp_id=1, superior_emp_name="Smith"), + Row(emp_id=4, name="Jones", superior_emp_id=2, superior_emp_name="Rose"), + Row(emp_id=5, name="Brown", superior_emp_id=2, superior_emp_name="Rose"), + Row(emp_id=6, name="Brown", superior_emp_id=2, superior_emp_name="Rose"), ], key=lambda x: x.emp_id, ) @@ -382,29 +382,29 @@ def test_cross_join(self, spark): ) def test_join_with_using_clause(self, spark, dataframe_a): - dataframe_a = dataframe_a.select('name', 'year_joined') + dataframe_a = dataframe_a.select("name", "year_joined") - df = dataframe_a.alias('df1') - df2 = dataframe_a.alias('df2') - res = df.join(df2, ['name', 'year_joined']).sort('name', 'year_joined') + df = dataframe_a.alias("df1") + df2 = dataframe_a.alias("df2") + res = df.join(df2, ["name", "year_joined"]).sort("name", "year_joined") res = res.collect() assert res == [ - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Brown', year_joined='2010'), - Row(name='Jones', year_joined='2005'), - Row(name='Rose', year_joined='2010'), - Row(name='Smith', year_joined='2018'), - Row(name='Williams', year_joined='2010'), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Brown", year_joined="2010"), + Row(name="Jones", year_joined="2005"), + Row(name="Rose", year_joined="2010"), + Row(name="Smith", year_joined="2018"), + Row(name="Williams", year_joined="2010"), ] def test_join_with_common_column(self, spark, dataframe_a): - dataframe_a = dataframe_a.select('name', 'year_joined') + dataframe_a = dataframe_a.select("name", "year_joined") - df = dataframe_a.alias('df1') - df2 = dataframe_a.alias('df2') - res = df.join(df2, df.name == df2.name).sort('df1.name') + df = dataframe_a.alias("df1") + df2 = dataframe_a.alias("df2") + res = df.join(df2, df.name == df2.name).sort("df1.name") res = res.collect() assert ( str(res) diff --git a/tests/fast/spark/test_spark_order_by.py b/tests/fast/spark/test_spark_order_by.py index 92aa4d3a..cc08dd7c 100644 --- a/tests/fast/spark/test_spark_order_by.py +++ b/tests/fast/spark/test_spark_order_by.py @@ -38,15 +38,15 @@ def test_order_by(self, spark): df2 = df.sort("department", "state") res1 = df2.collect() assert res1 == [ - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), ] df2 = df.sort(col("department"), col("state")) @@ -60,15 +60,15 @@ def test_order_by(self, spark): df2 = df.sort(df.department.asc(), df.state.desc()) res1 = df2.collect() assert res1 == [ - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] df2 = df.sort(col("department").asc(), col("state").desc()) @@ -94,15 +94,15 @@ def test_order_by(self, spark): ) res = df2.collect() assert res == [ - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Raman', department='Finance', state='CA', salary=99000, age=40, bonus=24000), - Row(employee_name='Scott', department='Finance', state='NY', salary=83000, age=36, bonus=19000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Raman", department="Finance", state="CA", salary=99000, age=40, bonus=24000), + Row(employee_name="Scott", department="Finance", state="NY", salary=83000, age=36, bonus=19000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] def test_null_ordering(self, spark): @@ -130,56 +130,56 @@ def test_null_ordering(self, spark): res = df.orderBy("value1", "value2").collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy("value1", "value2", ascending=True).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy("value1", "value2", ascending=False).collect() assert res == [ - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), Row(value1=3, value2=None), - Row(value1=2, value2='A'), - Row(value1=None, value2='A'), + Row(value1=2, value2="A"), + Row(value1=None, value2="A"), ] res = df.orderBy(df.value1, df.value2).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy(df.value1.asc(), df.value2.asc()).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2=None), - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), ] res = df.orderBy(df.value1.desc(), df.value2.desc()).collect() assert res == [ - Row(value1=3, value2='A'), + Row(value1=3, value2="A"), Row(value1=3, value2=None), - Row(value1=2, value2='A'), - Row(value1=None, value2='A'), + Row(value1=2, value2="A"), + Row(value1=None, value2="A"), ] res = df.orderBy(df.value1, df.value2, ascending=[True, False]).collect() assert res == [ - Row(value1=None, value2='A'), - Row(value1=2, value2='A'), + Row(value1=None, value2="A"), + Row(value1=2, value2="A"), Row(value1=3, value2="A"), Row(value1=3, value2=None), ] diff --git a/tests/fast/spark/test_spark_pandas_dataframe.py b/tests/fast/spark/test_spark_pandas_dataframe.py index dcec77a8..6491b7a6 100644 --- a/tests/fast/spark/test_spark_pandas_dataframe.py +++ b/tests/fast/spark/test_spark_pandas_dataframe.py @@ -23,9 +23,9 @@ @pytest.fixture def pandasDF(spark): - data = [['Scott', 50], ['Jeff', 45], ['Thomas', 54], ['Ann', 34]] + data = [["Scott", 50], ["Jeff", 45], ["Thomas", 54], ["Ann", 34]] # Create the pandas DataFrame - df = pd.DataFrame(data, columns=['Name', 'Age']) + df = pd.DataFrame(data, columns=["Name", "Age"]) yield df @@ -35,10 +35,10 @@ def test_pd_conversion_basic(self, spark, pandasDF): res = sparkDF.collect() sparkDF.show() expected = [ - Row(Name='Scott', Age=50), - Row(Name='Jeff', Age=45), - Row(Name='Thomas', Age=54), - Row(Name='Ann', Age=34), + Row(Name="Scott", Age=50), + Row(Name="Jeff", Age=45), + Row(Name="Thomas", Age=54), + Row(Name="Ann", Age=34), ] assert res == expected diff --git a/tests/fast/spark/test_spark_readcsv.py b/tests/fast/spark/test_spark_readcsv.py index 8e6c0515..5ba3d199 100644 --- a/tests/fast/spark/test_spark_readcsv.py +++ b/tests/fast/spark/test_spark_readcsv.py @@ -9,8 +9,8 @@ class TestSparkReadCSV(object): def test_read_csv(self, spark, tmp_path): - file_path = tmp_path / 'basic.csv' - with open(file_path, 'w+') as f: + file_path = tmp_path / "basic.csv" + with open(file_path, "w+") as f: f.write( textwrap.dedent( """ diff --git a/tests/fast/spark/test_spark_readjson.py b/tests/fast/spark/test_spark_readjson.py index a6ad05f0..638bee2d 100644 --- a/tests/fast/spark/test_spark_readjson.py +++ b/tests/fast/spark/test_spark_readjson.py @@ -9,9 +9,9 @@ class TestSparkReadJson(object): def test_read_json(self, duckdb_cursor, spark, tmp_path): - file_path = tmp_path / 'basic.parquet' + file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() duckdb_cursor.execute(f"COPY (select 42 a, true b, 'this is a long string' c) to '{file_path}' (FORMAT JSON)") df = spark.read.json(file_path) res = df.collect() - assert res == [Row(a=42, b=True, c='this is a long string')] + assert res == [Row(a=42, b=True, c="this is a long string")] diff --git a/tests/fast/spark/test_spark_readparquet.py b/tests/fast/spark/test_spark_readparquet.py index a08ab16d..1b3ddd74 100644 --- a/tests/fast/spark/test_spark_readparquet.py +++ b/tests/fast/spark/test_spark_readparquet.py @@ -9,11 +9,11 @@ class TestSparkReadParquet(object): def test_read_parquet(self, duckdb_cursor, spark, tmp_path): - file_path = tmp_path / 'basic.parquet' + file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() duckdb_cursor.execute( f"COPY (select 42 a, true b, 'this is a long string' c) to '{file_path}' (FORMAT PARQUET)" ) df = spark.read.parquet(file_path) res = df.collect() - assert res == [Row(a=42, b=True, c='this is a long string')] + assert res == [Row(a=42, b=True, c="this is a long string")] diff --git a/tests/fast/spark/test_spark_session.py b/tests/fast/spark/test_spark_session.py index 7c338898..06c9dbcb 100644 --- a/tests/fast/spark/test_spark_session.py +++ b/tests/fast/spark/test_spark_session.py @@ -14,14 +14,14 @@ def test_spark_session_default(self): session = SparkSession.builder.getOrCreate() def test_spark_session(self): - session = SparkSession.builder.master("local[1]").appName('SparkByExamples.com').getOrCreate() + session = SparkSession.builder.master("local[1]").appName("SparkByExamples.com").getOrCreate() def test_new_session(self, spark: SparkSession): session = spark.newSession() - @pytest.mark.skip(reason='not tested yet') + @pytest.mark.skip(reason="not tested yet") def test_retrieve_same_session(self): - spark = SparkSession.builder.master('test').appName('test2').getOrCreate() + spark = SparkSession.builder.master("test").appName("test2").getOrCreate() spark2 = SparkSession.builder.getOrCreate() # Same connection should be returned assert spark == spark2 @@ -49,7 +49,7 @@ def test_hive_support(self): @pytest.mark.skipif(USE_ACTUAL_SPARK, reason="Different version numbers") def test_version(self, spark): version = spark.version - assert version == '1.0.0' + assert version == "1.0.0" def test_get_active_session(self, spark): active_session = spark.getActiveSession() @@ -58,7 +58,7 @@ def test_read(self, spark): reader = spark.read def test_write(self, spark): - df = spark.sql('select 42') + df = spark.sql("select 42") writer = df.write def test_read_stream(self, spark): @@ -68,7 +68,7 @@ def test_spark_context(self, spark): context = spark.sparkContext def test_sql(self, spark): - df = spark.sql('select 42') + df = spark.sql("select 42") def test_stop_context(self, spark): context = spark.sparkContext @@ -78,8 +78,8 @@ def test_stop_context(self, spark): USE_ACTUAL_SPARK, reason="Can't create table with the local PySpark setup in the CI/CD pipeline" ) def test_table(self, spark): - spark.sql('create table tbl(a varchar(10))') - df = spark.table('tbl') + spark.sql("create table tbl(a varchar(10))") + df = spark.table("tbl") def test_range(self, spark): res_1 = spark.range(3).collect() diff --git a/tests/fast/spark/test_spark_to_csv.py b/tests/fast/spark/test_spark_to_csv.py index 5048e579..e5387a6c 100644 --- a/tests/fast/spark/test_spark_to_csv.py +++ b/tests/fast/spark/test_spark_to_csv.py @@ -40,14 +40,14 @@ def df(spark): @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_ints(request, spark): pandas = request.param - dataframe = pandas.DataFrame({'a': [5, 3, 23, 2], 'b': [45, 234, 234, 2]}) + dataframe = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) yield dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_strings(request, spark): pandas = request.param - dataframe = pandas.DataFrame({'a': ['string1', 'string2', 'string3']}) + dataframe = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) yield dataframe @@ -68,15 +68,15 @@ def test_to_csv_sep(self, pandas_df_ints, spark, tmp_path): df = spark.createDataFrame(pandas_df_ints) - df.write.csv(temp_file_name, sep=',') + df.write.csv(temp_file_name, sep=",") - csv_rel = spark.read.csv(temp_file_name, sep=',') + csv_rel = spark.read.csv(temp_file_name, sep=",") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_na_rep(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -85,10 +85,10 @@ def test_to_csv_na_rep(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name, nullValue="test") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_header(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': [5, None, 23, 2], 'b': [45, 234, 234, 2]}) + pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -97,20 +97,20 @@ def test_to_csv_header(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name) assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_quotechar(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") - pandas_df = pandas.DataFrame({'a': ["\'a,b,c\'", None, "hello", "bye"], 'b': [45, 234, 234, 2]}) + pandas_df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, quote='\'', sep=',') + df.write.csv(temp_file_name, quote="'", sep=",") - csv_rel = spark.read.csv(temp_file_name, sep=',', quote='\'') + csv_rel = spark.read.csv(temp_file_name, sep=",", quote="'") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_escapechar(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") pandas_df = pandas.DataFrame( @@ -124,11 +124,11 @@ def test_to_csv_escapechar(self, pandas, spark, tmp_path): df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, quote='"', escape='!') - csv_rel = spark.read.csv(temp_file_name, quote='"', escape='!') + df.write.csv(temp_file_name, quote='"', escape="!") + csv_rel = spark.read.csv(temp_file_name, quote='"', escape="!") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_date_format(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") pandas_df = pandas.DataFrame(getTimeSeriesData()) @@ -143,17 +143,17 @@ def test_to_csv_date_format(self, pandas, spark, tmp_path): assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_to_csv_timestamp_format(self, pandas, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - pandas_df = pandas.DataFrame({'0': pandas.Series(data=data, dtype='object')}) + pandas_df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) df = spark.createDataFrame(pandas_df) - df.write.csv(temp_file_name, timestampFormat='%m/%d/%Y') + df.write.csv(temp_file_name, timestampFormat="%m/%d/%Y") - csv_rel = spark.read.csv(temp_file_name, timestampFormat='%m/%d/%Y') + csv_rel = spark.read.csv(temp_file_name, timestampFormat="%m/%d/%Y") assert df.collect() == csv_rel.collect() diff --git a/tests/fast/spark/test_spark_transform.py b/tests/fast/spark/test_spark_transform.py index 83e219a5..1f1186c5 100644 --- a/tests/fast/spark/test_spark_transform.py +++ b/tests/fast/spark/test_spark_transform.py @@ -62,15 +62,15 @@ def apply_discount(df): df2 = df.transform(to_upper_str_columns).transform(reduce_price, 1000).transform(apply_discount) res = df2.collect() assert res == [ - Row(CourseName='JAVA', fee=4000, discount=5, new_fee=3000, discounted_fee=2850.0), - Row(CourseName='PYTHON', fee=4600, discount=10, new_fee=3600, discounted_fee=3240.0), - Row(CourseName='SCALA', fee=4100, discount=15, new_fee=3100, discounted_fee=2635.0), - Row(CourseName='SCALA', fee=4500, discount=15, new_fee=3500, discounted_fee=2975.0), - Row(CourseName='PHP', fee=3000, discount=20, new_fee=2000, discounted_fee=1600.0), + Row(CourseName="JAVA", fee=4000, discount=5, new_fee=3000, discounted_fee=2850.0), + Row(CourseName="PYTHON", fee=4600, discount=10, new_fee=3600, discounted_fee=3240.0), + Row(CourseName="SCALA", fee=4100, discount=15, new_fee=3100, discounted_fee=2635.0), + Row(CourseName="SCALA", fee=4500, discount=15, new_fee=3500, discounted_fee=2975.0), + Row(CourseName="PHP", fee=3000, discount=20, new_fee=2000, discounted_fee=1600.0), ] # https://sparkbyexamples.com/pyspark/pyspark-transform-function/ - @pytest.mark.skip(reason='LambdaExpressions are currently under development, waiting til that is finished') + @pytest.mark.skip(reason="LambdaExpressions are currently under development, waiting til that is finished") def test_transform_function(self, spark, array_df): from spark_namespace.sql.functions import upper, transform diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index fb6e6102..6c97c2d9 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -70,65 +70,65 @@ def test_all_types_schema(self, spark): schema = df.schema assert schema == StructType( [ - StructField('bool', BooleanType(), True), - StructField('tinyint', ByteType(), True), - StructField('smallint', ShortType(), True), - StructField('int', IntegerType(), True), - StructField('bigint', LongType(), True), - StructField('hugeint', HugeIntegerType(), True), - StructField('uhugeint', UnsignedHugeIntegerType(), True), - StructField('utinyint', UnsignedByteType(), True), - StructField('usmallint', UnsignedShortType(), True), - StructField('uint', UnsignedIntegerType(), True), - StructField('ubigint', UnsignedLongType(), True), - StructField('date', DateType(), True), - StructField('time', TimeNTZType(), True), - StructField('timestamp', TimestampNTZType(), True), - StructField('timestamp_s', TimestampSecondNTZType(), True), - StructField('timestamp_ms', TimestampNanosecondNTZType(), True), - StructField('timestamp_ns', TimestampMilisecondNTZType(), True), - StructField('time_tz', TimeType(), True), - StructField('timestamp_tz', TimestampType(), True), - StructField('float', FloatType(), True), - StructField('double', DoubleType(), True), - StructField('dec_4_1', DecimalType(4, 1), True), - StructField('dec_9_4', DecimalType(9, 4), True), - StructField('dec_18_6', DecimalType(18, 6), True), - StructField('dec38_10', DecimalType(38, 10), True), - StructField('uuid', UUIDType(), True), - StructField('interval', DayTimeIntervalType(0, 3), True), - StructField('varchar', StringType(), True), - StructField('blob', BinaryType(), True), - StructField('bit', BitstringType(), True), - StructField('int_array', ArrayType(IntegerType(), True), True), - StructField('double_array', ArrayType(DoubleType(), True), True), - StructField('date_array', ArrayType(DateType(), True), True), - StructField('timestamp_array', ArrayType(TimestampNTZType(), True), True), - StructField('timestamptz_array', ArrayType(TimestampType(), True), True), - StructField('varchar_array', ArrayType(StringType(), True), True), - StructField('nested_int_array', ArrayType(ArrayType(IntegerType(), True), True), True), + StructField("bool", BooleanType(), True), + StructField("tinyint", ByteType(), True), + StructField("smallint", ShortType(), True), + StructField("int", IntegerType(), True), + StructField("bigint", LongType(), True), + StructField("hugeint", HugeIntegerType(), True), + StructField("uhugeint", UnsignedHugeIntegerType(), True), + StructField("utinyint", UnsignedByteType(), True), + StructField("usmallint", UnsignedShortType(), True), + StructField("uint", UnsignedIntegerType(), True), + StructField("ubigint", UnsignedLongType(), True), + StructField("date", DateType(), True), + StructField("time", TimeNTZType(), True), + StructField("timestamp", TimestampNTZType(), True), + StructField("timestamp_s", TimestampSecondNTZType(), True), + StructField("timestamp_ms", TimestampNanosecondNTZType(), True), + StructField("timestamp_ns", TimestampMilisecondNTZType(), True), + StructField("time_tz", TimeType(), True), + StructField("timestamp_tz", TimestampType(), True), + StructField("float", FloatType(), True), + StructField("double", DoubleType(), True), + StructField("dec_4_1", DecimalType(4, 1), True), + StructField("dec_9_4", DecimalType(9, 4), True), + StructField("dec_18_6", DecimalType(18, 6), True), + StructField("dec38_10", DecimalType(38, 10), True), + StructField("uuid", UUIDType(), True), + StructField("interval", DayTimeIntervalType(0, 3), True), + StructField("varchar", StringType(), True), + StructField("blob", BinaryType(), True), + StructField("bit", BitstringType(), True), + StructField("int_array", ArrayType(IntegerType(), True), True), + StructField("double_array", ArrayType(DoubleType(), True), True), + StructField("date_array", ArrayType(DateType(), True), True), + StructField("timestamp_array", ArrayType(TimestampNTZType(), True), True), + StructField("timestamptz_array", ArrayType(TimestampType(), True), True), + StructField("varchar_array", ArrayType(StringType(), True), True), + StructField("nested_int_array", ArrayType(ArrayType(IntegerType(), True), True), True), StructField( - 'struct', - StructType([StructField('a', IntegerType(), True), StructField('b', StringType(), True)]), + "struct", + StructType([StructField("a", IntegerType(), True), StructField("b", StringType(), True)]), True, ), StructField( - 'struct_of_arrays', + "struct_of_arrays", StructType( [ - StructField('a', ArrayType(IntegerType(), True), True), - StructField('b', ArrayType(StringType(), True), True), + StructField("a", ArrayType(IntegerType(), True), True), + StructField("b", ArrayType(StringType(), True), True), ] ), True, ), StructField( - 'array_of_structs', + "array_of_structs", ArrayType( - StructType([StructField('a', IntegerType(), True), StructField('b', StringType(), True)]), True + StructType([StructField("a", IntegerType(), True), StructField("b", StringType(), True)]), True ), True, ), - StructField('map', MapType(StringType(), StringType(), True), True), + StructField("map", MapType(StringType(), StringType(), True), True), ] ) diff --git a/tests/fast/spark/test_spark_udf.py b/tests/fast/spark/test_spark_udf.py index 3b5a5d36..eebabbb3 100644 --- a/tests/fast/spark/test_spark_udf.py +++ b/tests/fast/spark/test_spark_udf.py @@ -5,7 +5,6 @@ class TestSparkUDF(object): def test_udf_register(self, spark): - def to_upper_fn(s: str) -> str: return s.upper() diff --git a/tests/fast/spark/test_spark_union.py b/tests/fast/spark/test_spark_union.py index ea889e05..8a3ff9ce 100644 --- a/tests/fast/spark/test_spark_union.py +++ b/tests/fast/spark/test_spark_union.py @@ -40,15 +40,15 @@ def test_merge_with_union(self, df, df2): unionDF = df.union(df2) res = unionDF.collect() assert res == [ - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), ] unionDF = df.unionAll(df2) res2 = unionDF.collect() @@ -60,11 +60,11 @@ def test_merge_without_duplicates(self, df, df2): disDF = df.union(df2).distinct().sort(col("employee_name")) res = disDF.collect() assert res == [ - Row(employee_name='James', department='Sales', state='NY', salary=90000, age=34, bonus=10000), - Row(employee_name='Jeff', department='Marketing', state='CA', salary=80000, age=25, bonus=18000), - Row(employee_name='Jen', department='Finance', state='NY', salary=79000, age=53, bonus=15000), - Row(employee_name='Kumar', department='Marketing', state='NY', salary=91000, age=50, bonus=21000), - Row(employee_name='Maria', department='Finance', state='CA', salary=90000, age=24, bonus=23000), - Row(employee_name='Michael', department='Sales', state='NY', salary=86000, age=56, bonus=20000), - Row(employee_name='Robert', department='Sales', state='CA', salary=81000, age=30, bonus=23000), + Row(employee_name="James", department="Sales", state="NY", salary=90000, age=34, bonus=10000), + Row(employee_name="Jeff", department="Marketing", state="CA", salary=80000, age=25, bonus=18000), + Row(employee_name="Jen", department="Finance", state="NY", salary=79000, age=53, bonus=15000), + Row(employee_name="Kumar", department="Marketing", state="NY", salary=91000, age=50, bonus=21000), + Row(employee_name="Maria", department="Finance", state="CA", salary=90000, age=24, bonus=23000), + Row(employee_name="Michael", department="Sales", state="NY", salary=86000, age=56, bonus=20000), + Row(employee_name="Robert", department="Sales", state="CA", salary=81000, age=30, bonus=23000), ] diff --git a/tests/fast/spark/test_spark_union_by_name.py b/tests/fast/spark/test_spark_union_by_name.py index 08f3c62b..4739f0d8 100644 --- a/tests/fast/spark/test_spark_union_by_name.py +++ b/tests/fast/spark/test_spark_union_by_name.py @@ -38,14 +38,14 @@ def test_union_by_name(self, df1, df2): rel = df1.unionByName(df2) res = rel.collect() expected = [ - Row(name='James', id=34), - Row(name='Michael', id=56), - Row(name='Robert', id=30), - Row(name='Maria', id=24), - Row(name='James', id=34), - Row(name='Maria', id=45), - Row(name='Jen', id=45), - Row(name='Jeff', id=34), + Row(name="James", id=34), + Row(name="Michael", id=56), + Row(name="Robert", id=30), + Row(name="Maria", id=24), + Row(name="James", id=34), + Row(name="Maria", id=45), + Row(name="Jen", id=45), + Row(name="Jeff", id=34), ] assert res == expected @@ -53,13 +53,13 @@ def test_union_by_name_allow_missing_cols(self, df1, df2): rel = df1.unionByName(df2.drop("id"), allowMissingColumns=True) res = rel.collect() expected = [ - Row(name='James', id=34), - Row(name='Michael', id=56), - Row(name='Robert', id=30), - Row(name='Maria', id=24), - Row(name='James', id=None), - Row(name='Maria', id=None), - Row(name='Jen', id=None), - Row(name='Jeff', id=None), + Row(name="James", id=34), + Row(name="Michael", id=56), + Row(name="Robert", id=30), + Row(name="Maria", id=24), + Row(name="James", id=None), + Row(name="Maria", id=None), + Row(name="Jen", id=None), + Row(name="Jeff", id=None), ] assert res == expected diff --git a/tests/fast/spark/test_spark_with_column.py b/tests/fast/spark/test_spark_with_column.py index 80da34c3..2980e7fe 100644 --- a/tests/fast/spark/test_spark_with_column.py +++ b/tests/fast/spark/test_spark_with_column.py @@ -23,20 +23,20 @@ class TestWithColumn(object): def test_with_column(self, spark): data = [ - ('James', '', 'Smith', '1991-04-01', 'M', 3000), - ('Michael', 'Rose', '', '2000-05-19', 'M', 4000), - ('Robert', '', 'Williams', '1978-09-05', 'M', 4000), - ('Maria', 'Anne', 'Jones', '1967-12-01', 'F', 4000), - ('Jen', 'Mary', 'Brown', '1980-02-17', 'F', -1), + ("James", "", "Smith", "1991-04-01", "M", 3000), + ("Michael", "Rose", "", "2000-05-19", "M", 4000), + ("Robert", "", "Williams", "1978-09-05", "M", 4000), + ("Maria", "Anne", "Jones", "1967-12-01", "F", 4000), + ("Jen", "Mary", "Brown", "1980-02-17", "F", -1), ] columns = ["firstname", "middlename", "lastname", "dob", "gender", "salary"] df = spark.createDataFrame(data=data, schema=columns) - assert df.schema['salary'].dataType.typeName() == ('long' if USE_ACTUAL_SPARK else 'integer') + assert df.schema["salary"].dataType.typeName() == ("long" if USE_ACTUAL_SPARK else "integer") # The type of 'salary' has been cast to Bigint new_df = df.withColumn("salary", col("salary").cast("BIGINT")) - assert new_df.schema['salary'].dataType.typeName() == 'long' + assert new_df.schema["salary"].dataType.typeName() == "long" # Replace the 'salary' column with '(salary * 100)' df2 = df.withColumn("salary", col("salary") * 100) @@ -50,16 +50,16 @@ def test_with_column(self, spark): df2 = df.withColumn("Country", lit("USA")) res = df2.collect() - assert res[0].Country == 'USA' + assert res[0].Country == "USA" df2 = df.withColumn("Country", lit("USA")).withColumn("anotherColumn", lit("anotherValue")) res = df2.collect() - assert res[0].Country == 'USA' - assert res[1].anotherColumn == 'anotherValue' + assert res[0].Country == "USA" + assert res[1].anotherColumn == "anotherValue" df2 = df.withColumnRenamed("gender", "sex") - assert 'gender' not in df2.columns - assert 'sex' in df2.columns + assert "gender" not in df2.columns + assert "sex" in df2.columns df2 = df.drop("salary") - assert 'salary' not in df2.columns + assert "salary" not in df2.columns diff --git a/tests/fast/spark/test_spark_with_column_renamed.py b/tests/fast/spark/test_spark_with_column_renamed.py index 168ff23a..8534ab0b 100644 --- a/tests/fast/spark/test_spark_with_column_renamed.py +++ b/tests/fast/spark/test_spark_with_column_renamed.py @@ -22,49 +22,49 @@ class TestWithColumnRenamed(object): def test_with_column_renamed(self, spark): dataDF = [ - (('James', '', 'Smith'), '1991-04-01', 'M', 3000), - (('Michael', 'Rose', ''), '2000-05-19', 'M', 4000), - (('Robert', '', 'Williams'), '1978-09-05', 'M', 4000), - (('Maria', 'Anne', 'Jones'), '1967-12-01', 'F', 4000), - (('Jen', 'Mary', 'Brown'), '1980-02-17', 'F', -1), + (("James", "", "Smith"), "1991-04-01", "M", 3000), + (("Michael", "Rose", ""), "2000-05-19", "M", 4000), + (("Robert", "", "Williams"), "1978-09-05", "M", 4000), + (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), + (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('dob', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("dob", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) df = spark.createDataFrame(data=dataDF, schema=schema) df2 = df.withColumnRenamed("dob", "DateOfBirth").withColumnRenamed("salary", "salary_amount") - assert 'dob' not in df2.columns - assert 'salary' not in df2.columns - assert 'DateOfBirth' in df2.columns - assert 'salary_amount' in df2.columns + assert "dob" not in df2.columns + assert "salary" not in df2.columns + assert "DateOfBirth" in df2.columns + assert "salary_amount" in df2.columns schema2 = StructType( [ StructField( - 'full name', + "full name", StructType( [ - StructField('fname', StringType(), True), - StructField('mname', StringType(), True), - StructField('lname', StringType(), True), + StructField("fname", StringType(), True), + StructField("mname", StringType(), True), + StructField("lname", StringType(), True), ] ), ), @@ -72,9 +72,9 @@ def test_with_column_renamed(self, spark): ) df2 = df.withColumnRenamed("name", "full name") - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() df2 = df.select( col("name").alias("full name"), @@ -82,9 +82,9 @@ def test_with_column_renamed(self, spark): col("gender"), col("salary"), ) - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() df2 = df.select( col("name.firstname").alias("fname"), @@ -94,5 +94,5 @@ def test_with_column_renamed(self, spark): col("gender"), col("salary"), ) - assert 'firstname' not in df2.columns - assert 'fname' in df2.columns + assert "firstname" not in df2.columns + assert "fname" in df2.columns diff --git a/tests/fast/spark/test_spark_with_columns.py b/tests/fast/spark/test_spark_with_columns.py index 6e1bedea..535f357d 100644 --- a/tests/fast/spark/test_spark_with_columns.py +++ b/tests/fast/spark/test_spark_with_columns.py @@ -10,20 +10,20 @@ class TestWithColumns: def test_with_columns(self, spark): data = [ - ('James', '', 'Smith', '1991-04-01', 'M', 3000), - ('Michael', 'Rose', '', '2000-05-19', 'M', 4000), - ('Robert', '', 'Williams', '1978-09-05', 'M', 4000), - ('Maria', 'Anne', 'Jones', '1967-12-01', 'F', 4000), - ('Jen', 'Mary', 'Brown', '1980-02-17', 'F', -1), + ("James", "", "Smith", "1991-04-01", "M", 3000), + ("Michael", "Rose", "", "2000-05-19", "M", 4000), + ("Robert", "", "Williams", "1978-09-05", "M", 4000), + ("Maria", "Anne", "Jones", "1967-12-01", "F", 4000), + ("Jen", "Mary", "Brown", "1980-02-17", "F", -1), ] columns = ["firstname", "middlename", "lastname", "dob", "gender", "salary"] df = spark.createDataFrame(data=data, schema=columns) - assert df.schema['salary'].dataType.typeName() == ('long' if USE_ACTUAL_SPARK else 'integer') + assert df.schema["salary"].dataType.typeName() == ("long" if USE_ACTUAL_SPARK else "integer") # The type of 'salary' has been cast to Bigint new_df = df.withColumns({"salary": col("salary").cast("BIGINT")}) - assert new_df.schema['salary'].dataType.typeName() == 'long' + assert new_df.schema["salary"].dataType.typeName() == "long" # Replace the 'salary' column with '(salary * 100)' and add a new column # from an existing column @@ -34,12 +34,12 @@ def test_with_columns(self, spark): df2 = df.withColumns({"Country": lit("USA")}) res = df2.collect() - assert res[0].Country == 'USA' + assert res[0].Country == "USA" df2 = df.withColumns({"Country": lit("USA")}).withColumns({"anotherColumn": lit("anotherValue")}) res = df2.collect() - assert res[0].Country == 'USA' - assert res[1].anotherColumn == 'anotherValue' + assert res[0].Country == "USA" + assert res[1].anotherColumn == "anotherValue" df2 = df.drop("salary") - assert 'salary' not in df2.columns + assert "salary" not in df2.columns diff --git a/tests/fast/spark/test_spark_with_columns_renamed.py b/tests/fast/spark/test_spark_with_columns_renamed.py index 99c4ce63..80b8b9e0 100644 --- a/tests/fast/spark/test_spark_with_columns_renamed.py +++ b/tests/fast/spark/test_spark_with_columns_renamed.py @@ -9,44 +9,44 @@ class TestWithColumnsRenamed(object): def test_with_columns_renamed(self, spark): dataDF = [ - (('James', '', 'Smith'), '1991-04-01', 'M', 3000), - (('Michael', 'Rose', ''), '2000-05-19', 'M', 4000), - (('Robert', '', 'Williams'), '1978-09-05', 'M', 4000), - (('Maria', 'Anne', 'Jones'), '1967-12-01', 'F', 4000), - (('Jen', 'Mary', 'Brown'), '1980-02-17', 'F', -1), + (("James", "", "Smith"), "1991-04-01", "M", 3000), + (("Michael", "Rose", ""), "2000-05-19", "M", 4000), + (("Robert", "", "Williams"), "1978-09-05", "M", 4000), + (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), + (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType schema = StructType( [ StructField( - 'name', + "name", StructType( [ - StructField('firstname', StringType(), True), - StructField('middlename', StringType(), True), - StructField('lastname', StringType(), True), + StructField("firstname", StringType(), True), + StructField("middlename", StringType(), True), + StructField("lastname", StringType(), True), ] ), ), - StructField('dob', StringType(), True), - StructField('gender', StringType(), True), - StructField('salary', IntegerType(), True), + StructField("dob", StringType(), True), + StructField("gender", StringType(), True), + StructField("salary", IntegerType(), True), ] ) df = spark.createDataFrame(data=dataDF, schema=schema) df2 = df.withColumnsRenamed({"dob": "DateOfBirth", "salary": "salary_amount"}) - assert 'dob' not in df2.columns - assert 'salary' not in df2.columns - assert 'DateOfBirth' in df2.columns - assert 'salary_amount' in df2.columns + assert "dob" not in df2.columns + assert "salary" not in df2.columns + assert "DateOfBirth" in df2.columns + assert "salary_amount" in df2.columns df2 = df.withColumnsRenamed({"name": "full name"}) - assert 'name' not in df2.columns - assert 'full name' in df2.columns - assert 'firstname' in df2.schema['full name'].dataType.fieldNames() + assert "name" not in df2.columns + assert "full name" in df2.columns + assert "firstname" in df2.schema["full name"].dataType.fieldNames() # PySpark does not raise an error. This is a convenience we provide in DuckDB. if not USE_ACTUAL_SPARK: diff --git a/tests/fast/sqlite/test_types.py b/tests/fast/sqlite/test_types.py index d4be447a..3ffdceae 100644 --- a/tests/fast/sqlite/test_types.py +++ b/tests/fast/sqlite/test_types.py @@ -42,10 +42,10 @@ def tearDown(self): self.con.close() def test_CheckString(self): - self.cur.execute("insert into test(s) values (?)", (u"Österreich",)) + self.cur.execute("insert into test(s) values (?)", ("Österreich",)) self.cur.execute("select s from test") row = self.cur.fetchone() - self.assertEqual(row[0], u"Österreich") + self.assertEqual(row[0], "Österreich") def test_CheckSmallInt(self): self.cur.execute("insert into test(i) values (?)", (42,)) @@ -75,7 +75,7 @@ def test_CheckDecimalTooBig(self): self.assertEqual(row[0], val) def test_CheckDecimal(self): - val = '17.29' + val = "17.29" val = decimal.Decimal(val) self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") @@ -83,7 +83,7 @@ def test_CheckDecimal(self): self.assertEqual(row[0], self.cur.execute("select 17.29::DOUBLE").fetchone()[0]) def test_CheckDecimalWithExponent(self): - val = '1E5' + val = "1E5" val = decimal.Decimal(val) self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") @@ -93,14 +93,14 @@ def test_CheckDecimalWithExponent(self): def test_CheckNaN(self): import math - val = decimal.Decimal('nan') + val = decimal.Decimal("nan") self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() self.assertEqual(math.isnan(row[0]), True) def test_CheckInf(self): - val = decimal.Decimal('inf') + val = decimal.Decimal("inf") self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("select f from test") row = self.cur.fetchone() @@ -122,7 +122,7 @@ def test_CheckMemoryviewBlob(self): self.assertEqual(row[0], sample) def test_CheckMemoryviewFromhexBlob(self): - sample = bytes.fromhex('00FF0F2E3D4C5B6A798800FF00') + sample = bytes.fromhex("00FF0F2E3D4C5B6A798800FF00") val = memoryview(sample) self.cur.execute("insert into test(b) values (?)", (val,)) self.cur.execute("select b from test") @@ -137,9 +137,9 @@ def test_CheckNoneBlob(self): self.assertEqual(row[0], val) def test_CheckUnicodeExecute(self): - self.cur.execute(u"select 'Österreich'") + self.cur.execute("select 'Österreich'") row = self.cur.fetchone() - self.assertEqual(row[0], u"Österreich") + self.assertEqual(row[0], "Österreich") class CommonTableExpressionTests(unittest.TestCase): @@ -206,7 +206,7 @@ def test_CheckTimestamp(self): self.assertEqual(ts, ts2) def test_CheckSqlTimestamp(self): - now = datetime.datetime.now(datetime.UTC) if hasattr(datetime, 'UTC') else datetime.datetime.utcnow() + now = datetime.datetime.now(datetime.UTC) if hasattr(datetime, "UTC") else datetime.datetime.utcnow() self.cur.execute("insert into test(ts) values (current_timestamp)") self.cur.execute("select ts from test") ts = self.cur.fetchone()[0] diff --git a/tests/fast/test_alex_multithread.py b/tests/fast/test_alex_multithread.py index 92768ec0..bcb0181b 100644 --- a/tests/fast/test_alex_multithread.py +++ b/tests/fast/test_alex_multithread.py @@ -41,7 +41,7 @@ def test_multiple_cursors(self, duckdb_cursor): # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -50,9 +50,9 @@ def test_multiple_cursors(self, duckdb_cursor): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] def test_same_connection(self, duckdb_cursor): @@ -67,7 +67,7 @@ def test_same_connection(self, duckdb_cursor): # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): cursors.append(duckdb_con.cursor()) - threads.append(Thread(target=insert_from_same_connection, args=(cursors[i],), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_same_connection, args=(cursors[i],), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -76,9 +76,9 @@ def test_same_connection(self, duckdb_cursor): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] def test_multiple_cursors_persisted(self, tmp_database): @@ -91,7 +91,7 @@ def test_multiple_cursors_persisted(self, tmp_database): # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_cursor, args=(duckdb_con,), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -99,9 +99,9 @@ def test_multiple_cursors_persisted(self, tmp_database): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] duckdb_con.close() @@ -115,7 +115,7 @@ def test_same_connection_persisted(self, tmp_database): # Kick off multiple threads (in the same process) # Pass in the same connection as an argument, and an object to store the results for i in range(thread_count): - threads.append(Thread(target=insert_from_same_connection, args=(duckdb_con,), name='my_thread_' + str(i))) + threads.append(Thread(target=insert_from_same_connection, args=(duckdb_con,), name="my_thread_" + str(i))) for thread in threads: thread.start() @@ -123,8 +123,8 @@ def test_same_connection_persisted(self, tmp_database): thread.join() assert duckdb_con.execute("""SELECT * FROM my_inserts order by thread_name""").fetchall() == [ - ('my_thread_0',), - ('my_thread_1',), - ('my_thread_2',), + ("my_thread_0",), + ("my_thread_1",), + ("my_thread_2",), ] duckdb_con.close() diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index 2128f9f1..3e701ced 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -12,7 +12,7 @@ def replace_with_ndarray(obj): - if hasattr(obj, '__getitem__'): + if hasattr(obj, "__getitem__"): if isinstance(obj, dict): for key, value in obj.items(): obj[key] = replace_with_ndarray(value) @@ -115,69 +115,69 @@ def recursive_equality(o1, o2): class TestAllTypes(object): - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_fetchall(self, cur_type): conn = duckdb.connect() conn.execute("SET TimeZone =UTC") # We replace these values since the extreme ranges are not supported in native-python. replacement_values = { - 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", - 'timestamp_s': "'1990-01-01 00:00:00'::TIMESTAMP_S", - 'timestamp_ns': "'1990-01-01 00:00:00'::TIMESTAMP_NS", - 'timestamp_ms': "'1990-01-01 00:00:00'::TIMESTAMP_MS", - 'timestamp_tz': "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", - 'date': "'1990-01-01'::DATE", - 'date_array': "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", - 'timestamp_array': "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", - 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", + "timestamp": "'1990-01-01 00:00:00'::TIMESTAMP", + "timestamp_s": "'1990-01-01 00:00:00'::TIMESTAMP_S", + "timestamp_ns": "'1990-01-01 00:00:00'::TIMESTAMP_NS", + "timestamp_ms": "'1990-01-01 00:00:00'::TIMESTAMP_MS", + "timestamp_tz": "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", + "date": "'1990-01-01'::DATE", + "date_array": "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", + "timestamp_array": "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", + "timestamptz_array": "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", } adjusted_values = { - 'time': """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, - 'time_tz': """CASE WHEN time_tz = '24:00:00-1559'::TIMETZ THEN '23:59:59.999999-1559'::TIMETZ ELSE time_tz END AS "time_tz" """, + "time": """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, + "time_tz": """CASE WHEN time_tz = '24:00:00-1559'::TIMETZ THEN '23:59:59.999999-1559'::TIMETZ ELSE time_tz END AS "time_tz" """, } min_datetime = datetime.datetime.min min_datetime_with_utc = min_datetime.replace(tzinfo=pytz.UTC) max_datetime = datetime.datetime.max max_datetime_with_utc = max_datetime.replace(tzinfo=pytz.UTC) correct_answer_map = { - 'bool': [(False,), (True,), (None,)], - 'tinyint': [(-128,), (127,), (None,)], - 'smallint': [(-32768,), (32767,), (None,)], - 'int': [(-2147483648,), (2147483647,), (None,)], - 'bigint': [(-9223372036854775808,), (9223372036854775807,), (None,)], - 'hugeint': [ + "bool": [(False,), (True,), (None,)], + "tinyint": [(-128,), (127,), (None,)], + "smallint": [(-32768,), (32767,), (None,)], + "int": [(-2147483648,), (2147483647,), (None,)], + "bigint": [(-9223372036854775808,), (9223372036854775807,), (None,)], + "hugeint": [ (-170141183460469231731687303715884105728,), (170141183460469231731687303715884105727,), (None,), ], - 'utinyint': [(0,), (255,), (None,)], - 'usmallint': [(0,), (65535,), (None,)], - 'uint': [(0,), (4294967295,), (None,)], - 'ubigint': [(0,), (18446744073709551615,), (None,)], - 'time': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], - 'float': [(-3.4028234663852886e38,), (3.4028234663852886e38,), (None,)], - 'double': [(-1.7976931348623157e308,), (1.7976931348623157e308,), (None,)], - 'dec_4_1': [(Decimal('-999.9'),), (Decimal('999.9'),), (None,)], - 'dec_9_4': [(Decimal('-99999.9999'),), (Decimal('99999.9999'),), (None,)], - 'dec_18_6': [(Decimal('-999999999999.999999'),), (Decimal('999999999999.999999'),), (None,)], - 'dec38_10': [ - (Decimal('-9999999999999999999999999999.9999999999'),), - (Decimal('9999999999999999999999999999.9999999999'),), + "utinyint": [(0,), (255,), (None,)], + "usmallint": [(0,), (65535,), (None,)], + "uint": [(0,), (4294967295,), (None,)], + "ubigint": [(0,), (18446744073709551615,), (None,)], + "time": [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + "float": [(-3.4028234663852886e38,), (3.4028234663852886e38,), (None,)], + "double": [(-1.7976931348623157e308,), (1.7976931348623157e308,), (None,)], + "dec_4_1": [(Decimal("-999.9"),), (Decimal("999.9"),), (None,)], + "dec_9_4": [(Decimal("-99999.9999"),), (Decimal("99999.9999"),), (None,)], + "dec_18_6": [(Decimal("-999999999999.999999"),), (Decimal("999999999999.999999"),), (None,)], + "dec38_10": [ + (Decimal("-9999999999999999999999999999.9999999999"),), + (Decimal("9999999999999999999999999999.9999999999"),), (None,), ], - 'uuid': [ - (UUID('00000000-0000-0000-0000-000000000000'),), - (UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'),), + "uuid": [ + (UUID("00000000-0000-0000-0000-000000000000"),), + (UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"),), (None,), ], - 'varchar': [('🦆🦆🦆🦆🦆🦆',), ('goo\0se',), (None,)], - 'json': [('🦆🦆🦆🦆🦆🦆',), ('goose',), (None,)], - 'blob': [(b'thisisalongblob\x00withnullbytes',), (b'\x00\x00\x00a',), (None,)], - 'bit': [('0010001001011100010101011010111',), ('10101',), (None,)], - 'small_enum': [('DUCK_DUCK_ENUM',), ('GOOSE',), (None,)], - 'medium_enum': [('enum_0',), ('enum_299',), (None,)], - 'large_enum': [('enum_0',), ('enum_69999',), (None,)], - 'date_array': [ + "varchar": [("🦆🦆🦆🦆🦆🦆",), ("goo\0se",), (None,)], + "json": [("🦆🦆🦆🦆🦆🦆",), ("goose",), (None,)], + "blob": [(b"thisisalongblob\x00withnullbytes",), (b"\x00\x00\x00a",), (None,)], + "bit": [("0010001001011100010101011010111",), ("10101",), (None,)], + "small_enum": [("DUCK_DUCK_ENUM",), ("GOOSE",), (None,)], + "medium_enum": [("enum_0",), ("enum_299",), (None,)], + "large_enum": [("enum_0",), ("enum_69999",), (None,)], + "date_array": [ ( [], [datetime.date(1970, 1, 1), None, datetime.date.min, datetime.date.max], @@ -186,7 +186,7 @@ def test_fetchall(self, cur_type): ], ) ], - 'timestamp_array': [ + "timestamp_array": [ ( [], [datetime.datetime(1970, 1, 1), None, datetime.datetime.min, datetime.datetime.max], @@ -195,7 +195,7 @@ def test_fetchall(self, cur_type): ], ), ], - 'timestamptz_array': [ + "timestamptz_array": [ ( [], [ @@ -209,67 +209,67 @@ def test_fetchall(self, cur_type): ], ), ], - 'int_array': [([],), ([42, 999, None, None, -42],), (None,)], - 'varchar_array': [([],), (['🦆🦆🦆🦆🦆🦆', 'goose', None, ''],), (None,)], - 'double_array': [([],), ([42.0, float('nan'), float('inf'), float('-inf'), None, -42.0],), (None,)], - 'nested_int_array': [ + "int_array": [([],), ([42, 999, None, None, -42],), (None,)], + "varchar_array": [([],), (["🦆🦆🦆🦆🦆🦆", "goose", None, ""],), (None,)], + "double_array": [([],), ([42.0, float("nan"), float("inf"), float("-inf"), None, -42.0],), (None,)], + "nested_int_array": [ ([],), ([[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]],), (None,), ], - 'struct': [({'a': None, 'b': None},), ({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'},), (None,)], - 'struct_of_arrays': [ - ({'a': None, 'b': None},), - ({'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']},), + "struct": [({"a": None, "b": None},), ({"a": 42, "b": "🦆🦆🦆🦆🦆🦆"},), (None,)], + "struct_of_arrays": [ + ({"a": None, "b": None},), + ({"a": [42, 999, None, None, -42], "b": ["🦆🦆🦆🦆🦆🦆", "goose", None, ""]},), (None,), ], - 'array_of_structs': [([],), ([{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None],), (None,)], - 'map': [ + "array_of_structs": [([],), ([{"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None],), (None,)], + "map": [ ({},), - ({'key1': '🦆🦆🦆🦆🦆🦆', 'key2': 'goose'},), + ({"key1": "🦆🦆🦆🦆🦆🦆", "key2": "goose"},), (None,), ], - 'time_tz': [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], - 'interval': [ + "time_tz": [(datetime.time(0, 0),), (datetime.time(23, 59, 59, 999999),), (None,)], + "interval": [ (datetime.timedelta(0),), (datetime.timedelta(days=30969, seconds=999, microseconds=999999),), (None,), ], - 'timestamp': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'date': [(datetime.date(1990, 1, 1),)], - 'timestamp_s': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_ns': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_ms': [(datetime.datetime(1990, 1, 1, 0, 0),)], - 'timestamp_tz': [(datetime.datetime(1990, 1, 1, 0, 0, tzinfo=pytz.UTC),)], - 'union': [('Frank',), (5,), (None,)], - 'fixed_int_array': [((None, 2, 3),), ((4, 5, 6),), (None,)], - 'fixed_varchar_array': [(('a', None, 'c'),), (('d', 'e', 'f'),), (None,)], - 'fixed_nested_int_array': [ + "timestamp": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "date": [(datetime.date(1990, 1, 1),)], + "timestamp_s": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_ns": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_ms": [(datetime.datetime(1990, 1, 1, 0, 0),)], + "timestamp_tz": [(datetime.datetime(1990, 1, 1, 0, 0, tzinfo=pytz.UTC),)], + "union": [("Frank",), (5,), (None,)], + "fixed_int_array": [((None, 2, 3),), ((4, 5, 6),), (None,)], + "fixed_varchar_array": [(("a", None, "c"),), (("d", "e", "f"),), (None,)], + "fixed_nested_int_array": [ (((None, 2, 3), None, (None, 2, 3)),), (((4, 5, 6), (None, 2, 3), (4, 5, 6)),), (None,), ], - 'fixed_nested_varchar_array': [ - ((('a', None, 'c'), None, ('a', None, 'c')),), - ((('d', 'e', 'f'), ('a', None, 'c'), ('d', 'e', 'f')),), + "fixed_nested_varchar_array": [ + ((("a", None, "c"), None, ("a", None, "c")),), + ((("d", "e", "f"), ("a", None, "c"), ("d", "e", "f")),), (None,), ], - 'fixed_struct_array': [ - (({'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, {'a': None, 'b': None}),), - (({'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, {'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}),), + "fixed_struct_array": [ + (({"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, {"a": None, "b": None}),), + (({"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, {"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}),), (None,), ], - 'struct_of_fixed_array': [ - ({'a': (None, 2, 3), 'b': ('a', None, 'c')},), - ({'a': (4, 5, 6), 'b': ('d', 'e', 'f')},), + "struct_of_fixed_array": [ + ({"a": (None, 2, 3), "b": ("a", None, "c")},), + ({"a": (4, 5, 6), "b": ("d", "e", "f")},), (None,), ], - 'fixed_array_of_int_list': [ + "fixed_array_of_int_list": [ (([], [42, 999, None, None, -42], []),), (([42, 999, None, None, -42], [], [42, 999, None, None, -42]),), (None,), ], - 'list_of_fixed_int_array': [ + "list_of_fixed_int_array": [ ([(None, 2, 3), (4, 5, 6), (None, 2, 3)],), ([(4, 5, 6), (None, 2, 3), (4, 5, 6)],), (None,), @@ -278,14 +278,14 @@ def test_fetchall(self, cur_type): if cur_type in replacement_values: result = conn.execute("select " + replacement_values[cur_type]).fetchall() elif cur_type in adjusted_values: - result = conn.execute(f'select {adjusted_values[cur_type]} from test_all_types()').fetchall() + result = conn.execute(f"select {adjusted_values[cur_type]} from test_all_types()").fetchall() else: result = conn.execute(f'select "{cur_type}" from test_all_types()').fetchall() correct_result = correct_answer_map[cur_type] assert recursive_equality(result, correct_result) def test_bytearray_with_nulls(self): - con = duckdb.connect(database=':memory:') + con = duckdb.connect(database=":memory:") con.execute("CREATE TABLE test (content BLOB)") want = bytearray([1, 2, 0, 3, 4]) con.execute("INSERT INTO test VALUES (?)", [want]) @@ -295,90 +295,90 @@ def test_bytearray_with_nulls(self): # Don't truncate the array on the nullbyte assert want == bytearray(got) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_fetchnumpy(self, cur_type): conn = duckdb.connect() correct_answer_map = { - 'bool': np.ma.array( + "bool": np.ma.array( [False, True, False], mask=[0, 0, 1], ), - 'tinyint': np.ma.array( + "tinyint": np.ma.array( [-128, 127, -1], mask=[0, 0, 1], dtype=np.int8, ), - 'smallint': np.ma.array( + "smallint": np.ma.array( [-32768, 32767, -1], mask=[0, 0, 1], dtype=np.int16, ), - 'int': np.ma.array( + "int": np.ma.array( [-2147483648, 2147483647, -1], mask=[0, 0, 1], dtype=np.int32, ), - 'bigint': np.ma.array( + "bigint": np.ma.array( [-9223372036854775808, 9223372036854775807, -1], mask=[0, 0, 1], dtype=np.int64, ), - 'utinyint': np.ma.array( + "utinyint": np.ma.array( [0, 255, 42], mask=[0, 0, 1], dtype=np.uint8, ), - 'usmallint': np.ma.array( + "usmallint": np.ma.array( [0, 65535, 42], mask=[0, 0, 1], dtype=np.uint16, ), - 'uint': np.ma.array( + "uint": np.ma.array( [0, 4294967295, 42], mask=[0, 0, 1], dtype=np.uint32, ), - 'ubigint': np.ma.array( + "ubigint": np.ma.array( [0, 18446744073709551615, 42], mask=[0, 0, 1], dtype=np.uint64, ), - 'float': np.ma.array( + "float": np.ma.array( [-3.4028234663852886e38, 3.4028234663852886e38, 42.0], mask=[0, 0, 1], dtype=np.float32, ), - 'double': np.ma.array( + "double": np.ma.array( [-1.7976931348623157e308, 1.7976931348623157e308, 42.0], mask=[0, 0, 1], dtype=np.float64, ), - 'uuid': np.ma.array( + "uuid": np.ma.array( [ - UUID('00000000-0000-0000-0000-000000000000'), - UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), - UUID('00000000-0000-0000-0000-000000000042'), + UUID("00000000-0000-0000-0000-000000000000"), + UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), + UUID("00000000-0000-0000-0000-000000000042"), ], mask=[0, 0, 1], dtype=object, ), - 'varchar': np.ma.array( - ['🦆🦆🦆🦆🦆🦆', 'goo\0se', "42"], + "varchar": np.ma.array( + ["🦆🦆🦆🦆🦆🦆", "goo\0se", "42"], mask=[0, 0, 1], dtype=object, ), - 'json': np.ma.array( - ['🦆🦆🦆🦆🦆🦆', 'goose', "42"], + "json": np.ma.array( + ["🦆🦆🦆🦆🦆🦆", "goose", "42"], mask=[0, 0, 1], dtype=object, ), - 'blob': np.ma.array( - [b'thisisalongblob\x00withnullbytes', b'\x00\x00\x00a', b"42"], + "blob": np.ma.array( + [b"thisisalongblob\x00withnullbytes", b"\x00\x00\x00a", b"42"], mask=[0, 0, 1], dtype=object, ), - 'interval': np.ma.array( + "interval": np.ma.array( [ np.timedelta64(0), np.timedelta64(2675722599999999000), @@ -388,7 +388,7 @@ def test_fetchnumpy(self, cur_type): ), # For timestamp_ns, the lowest value is out-of-range for numpy, # such that the conversion yields "Not a Time" - 'timestamp_ns': np.ma.array( + "timestamp_ns": np.ma.array( [ np.datetime64("NaT"), np.datetime64(9223372036854775806, "ns"), @@ -397,21 +397,21 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], ), # Enums don't have a numpy equivalent and yield pandas Categorical. - 'small_enum': pd.Categorical( - ['DUCK_DUCK_ENUM', 'GOOSE', np.nan], + "small_enum": pd.Categorical( + ["DUCK_DUCK_ENUM", "GOOSE", np.nan], ordered=True, ), - 'medium_enum': pd.Categorical( - ['enum_0', 'enum_299', np.nan], + "medium_enum": pd.Categorical( + ["enum_0", "enum_299", np.nan], ordered=True, ), - 'large_enum': pd.Categorical( - ['enum_0', 'enum_69999', np.nan], + "large_enum": pd.Categorical( + ["enum_0", "enum_69999", np.nan], ordered=True, ), # The following types don't have a numpy equivalent and yield # object arrays: - 'int_array': np.ma.array( + "int_array": np.ma.array( [ [], [42, 999, None, None, -42], @@ -420,25 +420,25 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], dtype=object, ), - 'varchar_array': np.ma.array( + "varchar_array": np.ma.array( [ [], - ['🦆🦆🦆🦆🦆🦆', 'goose', None, ''], + ["🦆🦆🦆🦆🦆🦆", "goose", None, ""], None, ], mask=[0, 0, 1], dtype=object, ), - 'double_array': np.ma.array( + "double_array": np.ma.array( [ [], - [42.0, float('nan'), float('inf'), float('-inf'), None, -42.0], + [42.0, float("nan"), float("inf"), float("-inf"), None, -42.0], None, ], mask=[0, 0, 1], dtype=object, ), - 'nested_int_array': np.ma.array( + "nested_int_array": np.ma.array( [ [], [[], [42, 999, None, None, -42], None, [], [42, 999, None, None, -42]], @@ -447,53 +447,53 @@ def test_fetchnumpy(self, cur_type): mask=[0, 0, 1], dtype=object, ), - 'struct': np.ma.array( + "struct": np.ma.array( [ - {'a': None, 'b': None}, - {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, + {"a": None, "b": None}, + {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None, ], mask=[0, 0, 1], dtype=object, ), - 'struct_of_arrays': np.ma.array( + "struct_of_arrays": np.ma.array( [ - {'a': None, 'b': None}, - {'a': [42, 999, None, None, -42], 'b': ['🦆🦆🦆🦆🦆🦆', 'goose', None, '']}, + {"a": None, "b": None}, + {"a": [42, 999, None, None, -42], "b": ["🦆🦆🦆🦆🦆🦆", "goose", None, ""]}, None, ], mask=[0, 0, 1], dtype=object, ), - 'array_of_structs': np.ma.array( + "array_of_structs": np.ma.array( [ [], - [{'a': None, 'b': None}, {'a': 42, 'b': '🦆🦆🦆🦆🦆🦆'}, None], + [{"a": None, "b": None}, {"a": 42, "b": "🦆🦆🦆🦆🦆🦆"}, None], None, ], mask=[0, 0, 1], dtype=object, ), - 'map': np.ma.array( + "map": np.ma.array( [ {}, - {'key1': '🦆🦆🦆🦆🦆🦆', 'key2': 'goose'}, + {"key1": "🦆🦆🦆🦆🦆🦆", "key2": "goose"}, None, ], mask=[0, 0, 1], dtype=object, ), - 'time': np.ma.array( - ['00:00:00', '24:00:00', None], + "time": np.ma.array( + ["00:00:00", "24:00:00", None], mask=[0, 0, 1], dtype=object, ), - 'time_tz': np.ma.array( - ['00:00:00', '23:59:59.999999', None], + "time_tz": np.ma.array( + ["00:00:00", "23:59:59.999999", None], mask=[0, 0, 1], dtype=object, ), - 'union': np.ma.array(['Frank', 5, None], mask=[0, 0, 1], dtype=object), + "union": np.ma.array(["Frank", 5, None], mask=[0, 0, 1], dtype=object), } correct_answer_map = replace_with_ndarray(correct_answer_map) @@ -535,19 +535,19 @@ def test_fetchnumpy(self, cur_type): assert np.all(result.mask == correct_answer.mask) np.testing.assert_equal(result, correct_answer) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_arrow(self, cur_type): try: import pyarrow as pa except: return # We skip those since the extreme ranges are not supported in arrow. - replacement_values = {'interval': "INTERVAL '2 years'"} + replacement_values = {"interval": "INTERVAL '2 years'"} # We do not round trip enum types - enum_types = {'small_enum', 'medium_enum', 'large_enum', 'double_array'} + enum_types = {"small_enum", "medium_enum", "large_enum", "double_array"} # uhugeint currently not supported by arrow - skip_types = {'uhugeint'} + skip_types = {"uhugeint"} if cur_type in skip_types: return @@ -565,33 +565,33 @@ def test_arrow(self, cur_type): round_trip_arrow_table = conn.execute("select * from arrow_table").fetch_arrow_table() assert arrow_table.equals(round_trip_arrow_table, check_metadata=True) - @pytest.mark.parametrize('cur_type', all_types) + @pytest.mark.parametrize("cur_type", all_types) def test_pandas(self, cur_type): # We skip those since the extreme ranges are not supported in python. replacement_values = { - 'timestamp': "'1990-01-01 00:00:00'::TIMESTAMP", - 'timestamp_s': "'1990-01-01 00:00:00'::TIMESTAMP_S", - 'timestamp_ns': "'1990-01-01 00:00:00'::TIMESTAMP_NS", - 'timestamp_ms': "'1990-01-01 00:00:00'::TIMESTAMP_MS", - 'timestamp_tz': "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", - 'date': "'1990-01-01'::DATE", - 'date_array': "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", - 'timestamp_array': "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", - 'timestamptz_array': "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", + "timestamp": "'1990-01-01 00:00:00'::TIMESTAMP", + "timestamp_s": "'1990-01-01 00:00:00'::TIMESTAMP_S", + "timestamp_ns": "'1990-01-01 00:00:00'::TIMESTAMP_NS", + "timestamp_ms": "'1990-01-01 00:00:00'::TIMESTAMP_MS", + "timestamp_tz": "'1990-01-01 00:00:00Z'::TIMESTAMPTZ", + "date": "'1990-01-01'::DATE", + "date_array": "[], ['1970-01-01'::DATE, NULL, '0001-01-01'::DATE, '9999-12-31'::DATE,], [NULL::DATE,]", + "timestamp_array": "[], ['1970-01-01'::TIMESTAMP, NULL, '0001-01-01'::TIMESTAMP, '9999-12-31 23:59:59.999999'::TIMESTAMP,], [NULL::TIMESTAMP,]", + "timestamptz_array": "[], ['1970-01-01 00:00:00Z'::TIMESTAMPTZ, NULL, '0001-01-01 00:00:00Z'::TIMESTAMPTZ, '9999-12-31 23:59:59.999999Z'::TIMESTAMPTZ,], [NULL::TIMESTAMPTZ,]", } adjusted_values = { - 'time': """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, + "time": """CASE WHEN "time" = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE "time" END AS "time" """, } conn = duckdb.connect() # Pandas <= 2.2.3 does not convert without throwing a warning conn.execute("SET timezone = UTC") - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) with suppress(TypeError): if cur_type in replacement_values: dataframe = conn.execute("select " + replacement_values[cur_type]).df() elif cur_type in adjusted_values: - dataframe = conn.execute(f'select {adjusted_values[cur_type]} from test_all_types()').df() + dataframe = conn.execute(f"select {adjusted_values[cur_type]} from test_all_types()").df() else: dataframe = conn.execute(f'select "{cur_type}" from test_all_types()').df() print(cur_type) diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 4fcbd49c..2e42f0ed 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -7,35 +7,35 @@ class TestCaseAlias(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): import numpy as np import datetime import duckdb - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) - r1 = con.from_df(df).query('df', 'select * from df').df() + r1 = con.from_df(df).query("df", "select * from df").df() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - r2 = con.from_df(df).query('df', 'select COL1, COL2 from df').df() + r2 = con.from_df(df).query("df", "select COL1, COL2 from df").df() assert r2["COL1"][0] == "val1" assert r2["COL1"][1] == "val3" assert r2["CoL2"][0] == 1.05 assert r2["CoL2"][1] == 17 - r3 = con.from_df(df).query('df', 'select COL1, COL2 from df ORDER BY COL1').df() + r3 = con.from_df(df).query("df", "select COL1, COL2 from df ORDER BY COL1").df() assert r3["COL1"][0] == "val1" assert r3["COL1"][1] == "val3" assert r3["CoL2"][0] == 1.05 assert r3["CoL2"][1] == 17 - r4 = con.from_df(df).query('df', 'select COL1, COL2 from df GROUP BY COL1, COL2 ORDER BY COL1').df() + r4 = con.from_df(df).query("df", "select COL1, COL2 from df GROUP BY COL1, COL2 ORDER BY COL1").df() assert r4["COL1"][0] == "val1" assert r4["COL1"][1] == "val3" assert r4["CoL2"][0] == 1.05 diff --git a/tests/fast/test_context_manager.py b/tests/fast/test_context_manager.py index 2ac451d1..65ec1d33 100644 --- a/tests/fast/test_context_manager.py +++ b/tests/fast/test_context_manager.py @@ -3,5 +3,5 @@ class TestContextManager(object): def test_context_manager(self, duckdb_cursor): - with duckdb.connect(database=':memory:', read_only=False) as con: + with duckdb.connect(database=":memory:", read_only=False) as con: assert con.execute("select 1").fetchall() == [(1,)] diff --git a/tests/fast/test_duckdb_api.py b/tests/fast/test_duckdb_api.py index f5dcfb60..ea847d50 100644 --- a/tests/fast/test_duckdb_api.py +++ b/tests/fast/test_duckdb_api.py @@ -5,4 +5,4 @@ def test_duckdb_api(): res = duckdb.execute("SELECT name, value FROM duckdb_settings() WHERE name == 'duckdb_api'") formatted_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - assert res.fetchall() == [('duckdb_api', f'python/{formatted_python_version}')] + assert res.fetchall() == [("duckdb_api", f"python/{formatted_python_version}")] diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index e0f830c5..82753382 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -21,7 +21,7 @@ ) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def filter_rel(): con = duckdb.connect() rel = con.sql( @@ -59,7 +59,7 @@ def test_constant_expression(self): res = rel.fetchall() assert res == [(5,)] - @pytest.mark.skipif(platform.system() == 'Windows', reason="There is some weird interaction in Windows CI") + @pytest.mark.skipif(platform.system() == "Windows", reason="There is some weird interaction in Windows CI") def test_column_expression(self): con = duckdb.connect() @@ -71,12 +71,12 @@ def test_column_expression(self): 3 as c """ ) - column = ColumnExpression('a') + column = ColumnExpression("a") rel2 = rel.select(column) res = rel2.fetchall() assert res == [(1,)] - column = ColumnExpression('d') + column = ColumnExpression("d") with pytest.raises(duckdb.BinderException, match='Referenced column "d" not found'): rel2 = rel.select(column) @@ -89,9 +89,9 @@ def test_coalesce_operator(self): """ ) - rel2 = rel.select(CoalesceOperator(ConstantExpression(None), ConstantExpression('hello').cast(int))) + rel2 = rel.select(CoalesceOperator(ConstantExpression(None), ConstantExpression("hello").cast(int))) res = rel2.explain() - assert 'COALESCE' in res + assert "COALESCE" in res with pytest.raises(duckdb.ConversionException, match="Could not convert string 'hello' to INT64"): rel2.fetchall() @@ -103,7 +103,7 @@ def test_coalesce_operator(self): """ ) - with pytest.raises(duckdb.InvalidInputException, match='Please provide at least one argument'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide at least one argument"): rel3 = rel.select(CoalesceOperator()) rel4 = rel.select(CoalesceOperator(ConstantExpression(None))) @@ -112,7 +112,7 @@ def test_coalesce_operator(self): rel5 = rel.select(CoalesceOperator(ConstantExpression(42))) assert rel5.fetchone() == (42,) - exprtest = con.table('exprtest') + exprtest = con.table("exprtest") rel6 = exprtest.select(CoalesceOperator(ColumnExpression("a"))) res = rel6.fetchall() assert res == [(42,), (43,), (None,), (45,)] @@ -193,17 +193,17 @@ def test_column_expression_explain(self): """ ) rel = rel.select( - ConstantExpression("a").alias('c0'), - ConstantExpression(42).alias('c1'), - ConstantExpression(None).alias('c2'), + ConstantExpression("a").alias("c0"), + ConstantExpression(42).alias("c1"), + ConstantExpression(None).alias("c2"), ) res = rel.explain() - assert 'c0' in res - assert 'c1' in res + assert "c0" in res + assert "c1" in res # 'c2' is not in the explain result because it shows NULL instead - assert 'NULL' in res + assert "NULL" in res res = rel.fetchall() - assert res == [('a', 42, None)] + assert res == [("a", 42, None)] def test_column_expression_table(self): con = duckdb.connect() @@ -219,10 +219,10 @@ def test_column_expression_table(self): """ ) - rel = con.table('tbl') - rel2 = rel.select('c0', 'c1', 'c2') + rel = con.table("tbl") + rel2 = rel.select("c0", "c1", "c2") res = rel2.fetchall() - assert res == [('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'h', 'i')] + assert res == [("a", "b", "c"), ("d", "e", "f"), ("g", "h", "i")] def test_column_expression_view(self): con = duckdb.connect() @@ -241,18 +241,18 @@ def test_column_expression_view(self): CREATE VIEW v1 as select c0 as c3, c2 as c4 from tbl; """ ) - rel = con.view('v1') - rel2 = rel.select('c3', 'c4') + rel = con.view("v1") + rel2 = rel.select("c3", "c4") res = rel2.fetchall() - assert res == [('a', 'c'), ('d', 'f'), ('g', 'i')] + assert res == [("a", "c"), ("d", "f"), ("g", "i")] def test_column_expression_replacement_scan(self): con = duckdb.connect() pd = pytest.importorskip("pandas") - df = pd.DataFrame({'a': [42, 43, 0], 'b': [True, False, True], 'c': [23.123, 623.213, 0.30234]}) + df = pd.DataFrame({"a": [42, 43, 0], "b": [True, False, True], "c": [23.123, 623.213, 0.30234]}) rel = con.sql("select * from df") - rel2 = rel.select('a', 'b') + rel2 = rel.select("a", "b") res = rel2.fetchall() assert res == [(42, True), (43, False), (0, True)] @@ -271,7 +271,7 @@ def test_add_operator(self): ) constant = ConstantExpression(val) - col = ColumnExpression('b') + col = ColumnExpression("b") expr = col + constant rel = rel.select(expr, expr) @@ -288,7 +288,7 @@ def test_binary_function_expression(self): 5 as b """ ) - function = FunctionExpression("-", ColumnExpression('b'), ColumnExpression('a')) + function = FunctionExpression("-", ColumnExpression("b"), ColumnExpression("a")) rel2 = rel.select(function) res = rel2.fetchall() assert res == [(4,)] @@ -301,7 +301,7 @@ def test_negate_expression(self): select 5 as a """ ) - col = ColumnExpression('a') + col = ColumnExpression("a") col = -col rel = rel.select(col) res = rel.fetchall() @@ -317,8 +317,8 @@ def test_subtract_expression(self): 1 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 - col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -337,8 +337,8 @@ def test_multiply_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 * col2 rel = rel.select(expr) res = rel.fetchall() @@ -354,8 +354,8 @@ def test_division_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 / col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -376,8 +376,8 @@ def test_modulus_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1 % col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -393,8 +393,8 @@ def test_power_expression(self): 2 as b """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") expr = col1**col2 rel2 = rel.select(expr) res = rel2.fetchall() @@ -411,9 +411,9 @@ def test_between_expression(self): 3 as c """ ) - a = ColumnExpression('a') - b = ColumnExpression('b') - c = ColumnExpression('c') + a = ColumnExpression("a") + b = ColumnExpression("b") + c = ColumnExpression("c") # 5 BETWEEN 2 AND 3 -> false assert rel.select(a.between(b, c)).fetchall() == [(False,)] @@ -437,32 +437,32 @@ def test_collate_expression(self): """ ) - col1 = ColumnExpression('c0') - col2 = ColumnExpression('c1') + col1 = ColumnExpression("c0") + col2 = ColumnExpression("c1") - lower_a = ConstantExpression('a') - upper_a = ConstantExpression('A') + lower_a = ConstantExpression("a") + upper_a = ConstantExpression("A") # SELECT c0 LIKE 'a' == True - assert rel.select(FunctionExpression('~~', col1, lower_a)).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col1, lower_a)).fetchall() == [(True,)] # SELECT c0 LIKE 'A' == False - assert rel.select(FunctionExpression('~~', col1, upper_a)).fetchall() == [(False,)] + assert rel.select(FunctionExpression("~~", col1, upper_a)).fetchall() == [(False,)] # SELECT c0 LIKE 'A' COLLATE NOCASE == True - assert rel.select(FunctionExpression('~~', col1, upper_a.collate('NOCASE'))).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col1, upper_a.collate("NOCASE"))).fetchall() == [(True,)] # SELECT c1 LIKE 'a' == False - assert rel.select(FunctionExpression('~~', col2, lower_a)).fetchall() == [(False,)] + assert rel.select(FunctionExpression("~~", col2, lower_a)).fetchall() == [(False,)] # SELECT c1 LIKE 'a' COLLATE NOCASE == True - assert rel.select(FunctionExpression('~~', col2, lower_a.collate('NOCASE'))).fetchall() == [(True,)] + assert rel.select(FunctionExpression("~~", col2, lower_a.collate("NOCASE"))).fetchall() == [(True,)] - with pytest.raises(duckdb.BinderException, match='collations are only supported for type varchar'): - rel.select(FunctionExpression('~~', col2, lower_a).collate('NOCASE')) + with pytest.raises(duckdb.BinderException, match="collations are only supported for type varchar"): + rel.select(FunctionExpression("~~", col2, lower_a).collate("NOCASE")) - with pytest.raises(duckdb.CatalogException, match='Collation with name non-existant does not exist'): - rel.select(FunctionExpression('~~', col2, lower_a.collate('non-existant'))) + with pytest.raises(duckdb.CatalogException, match="Collation with name non-existant does not exist"): + rel.select(FunctionExpression("~~", col2, lower_a.collate("non-existant"))) def test_equality_expression(self): con = duckdb.connect() @@ -475,9 +475,9 @@ def test_equality_expression(self): 5 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") expr1 = col1 == col2 expr2 = col1 == col3 rel2 = rel.select(expr1, expr2) @@ -497,22 +497,22 @@ def test_lambda_expression(self): # Use a tuple of strings as 'lhs' func = FunctionExpression( "list_reduce", - ColumnExpression('a'), - LambdaExpression(('x', 'y'), ColumnExpression('x') + ColumnExpression('y')), + ColumnExpression("a"), + LambdaExpression(("x", "y"), ColumnExpression("x") + ColumnExpression("y")), ) rel2 = rel.select(func) res = rel2.fetchall() assert res == [(6,)] # Use only a string name as 'lhs' - func = FunctionExpression("list_apply", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('x') + 3)) + func = FunctionExpression("list_apply", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("x") + 3)) rel2 = rel.select(func) res = rel2.fetchall() assert res == [([4, 5, 6],)] # 'row' is not a lambda function, so it doesn't accept a lambda expression - func = FunctionExpression("row", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('x') + 3)) - with pytest.raises(duckdb.BinderException, match='This scalar function does not support lambdas'): + func = FunctionExpression("row", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("x") + 3)) + with pytest.raises(duckdb.BinderException, match="This scalar function does not support lambdas"): rel2 = rel.select(func) # lhs has to be a tuple of strings or a single string @@ -520,11 +520,11 @@ def test_lambda_expression(self): ValueError, match="Please provide 'lhs' as either a tuple containing strings, or a single string" ): func = FunctionExpression( - "list_filter", ColumnExpression('a'), LambdaExpression(42, ColumnExpression('x') + 3) + "list_filter", ColumnExpression("a"), LambdaExpression(42, ColumnExpression("x") + 3) ) func = FunctionExpression( - "list_filter", ColumnExpression('a'), LambdaExpression('x', ColumnExpression('y') != 3) + "list_filter", ColumnExpression("a"), LambdaExpression("x", ColumnExpression("y") != 3) ) with pytest.raises(duckdb.BinderException, match='Referenced column "y" not found in FROM clause'): rel2 = rel.select(func) @@ -540,9 +540,9 @@ def test_inequality_expression(self): 5 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") expr1 = col1 != col2 expr2 = col1 != col3 rel2 = rel.select(expr1, expr2) @@ -561,10 +561,10 @@ def test_comparison_expressions(self): 3 as d """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') - col4 = ColumnExpression('d') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") + col4 = ColumnExpression("d") # Greater than expr1 = col1 > col2 @@ -606,11 +606,11 @@ def test_expression_alias(self): select 1 as a """ ) - col = ColumnExpression('a') - col = col.alias('b') + col = ColumnExpression("a") + col = col.alias("b") rel2 = rel.select(col) - assert rel2.columns == ['b'] + assert rel2.columns == ["b"] def test_star_expression(self): con = duckdb.connect() @@ -628,7 +628,7 @@ def test_star_expression(self): assert res == [(1, 2)] # With exclude list - star = StarExpression(exclude=['a']) + star = StarExpression(exclude=["a"]) rel2 = rel.select(star) res = rel2.fetchall() assert res == [(2,)] @@ -644,13 +644,13 @@ def test_struct_expression(self): """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - expr = FunctionExpression('struct_pack', col1, col2).alias('struct') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + expr = FunctionExpression("struct_pack", col1, col2).alias("struct") rel = rel.select(expr) res = rel.fetchall() - assert res == [({'a': 1, 'b': 2},)] + assert res == [({"a": 1, "b": 2},)] def test_function_expression_udf(self): con = duckdb.connect() @@ -658,7 +658,7 @@ def test_function_expression_udf(self): def my_simple_func(a: int, b: int, c: int) -> int: return a + b + c - con.create_function('my_func', my_simple_func) + con.create_function("my_func", my_simple_func) rel = con.sql( """ @@ -668,10 +668,10 @@ def my_simple_func(a: int, b: int, c: int) -> int: 3 as c """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') - expr = FunctionExpression('my_func', col1, col2, col3) + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") + expr = FunctionExpression("my_func", col1, col2, col3) rel2 = rel.select(expr) res = rel2.fetchall() assert res == [(6,)] @@ -688,10 +688,10 @@ def test_function_expression_basic(self): ) tbl(text, start, "end") """ ) - expr = FunctionExpression('array_slice', "start", "text", "end") + expr = FunctionExpression("array_slice", "start", "text", "end") rel2 = rel.select(expr) res = rel2.fetchall() - assert res == [('tes',), ('his is',), ('di',)] + assert res == [("tes",), ("his is",), ("di",)] def test_column_expression_function_coverage(self): con = duckdb.connect() @@ -707,11 +707,11 @@ def test_column_expression_function_coverage(self): """ ) - rel = con.table('tbl') - expr = FunctionExpression('||', FunctionExpression('||', 'c0', 'c1'), 'c2') + rel = con.table("tbl") + expr = FunctionExpression("||", FunctionExpression("||", "c0", "c1"), "c2") rel2 = rel.select(expr) res = rel2.fetchall() - assert res == [('abc',), ('def',), ('ghi',)] + assert res == [("abc",), ("def",), ("ghi",)] def test_function_expression_aggregate(self): con = duckdb.connect() @@ -725,9 +725,9 @@ def test_function_expression_aggregate(self): ) tbl(text) """ ) - expr = FunctionExpression('first', 'text') + expr = FunctionExpression("first", "text") with pytest.raises( - duckdb.BinderException, match='Binder Error: Aggregates cannot be present in a Project relation!' + duckdb.BinderException, match="Binder Error: Aggregates cannot be present in a Project relation!" ): rel2 = rel.select(expr) @@ -743,9 +743,9 @@ def test_case_expression(self): """ ) - col1 = ColumnExpression('a') - col2 = ColumnExpression('b') - col3 = ColumnExpression('c') + col1 = ColumnExpression("a") + col2 = ColumnExpression("b") + col3 = ColumnExpression("c") const1 = ConstantExpression(IntegerValue(1)) # CASE WHEN col1 > 1 THEN 5 ELSE NULL @@ -796,7 +796,7 @@ def test_implicit_constant_conversion(self): def test_numeric_overflow(self): con = duckdb.connect() - rel = con.sql('select 3000::SHORT salary') + rel = con.sql("select 3000::SHORT salary") with pytest.raises(duckdb.OutOfRangeException, match="Overflow in multiplication of INT16"): expr = ColumnExpression("salary") * 100 rel2 = rel.select(expr) @@ -823,7 +823,7 @@ def test_filter_equality(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(1, 'a'), (1, 'b')] + assert res == [(1, "a"), (1, "b")] def test_filter_not(self, filter_rel): expr = ColumnExpression("a") == 1 @@ -832,18 +832,18 @@ def test_filter_not(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(2, 'b'), (3, 'c'), (4, 'a')] + assert res == [(2, "b"), (3, "c"), (4, "a")] def test_filter_and(self, filter_rel): expr = ColumnExpression("a") == 1 expr = ~expr # AND operator - expr = expr & ('b' != ConstantExpression('b')) + expr = expr & ("b" != ConstantExpression("b")) rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] def test_filter_or(self, filter_rel): # OR operator @@ -851,7 +851,7 @@ def test_filter_or(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(1, 'a'), (1, 'b'), (4, 'a')] + assert res == [(1, "a"), (1, "b"), (4, "a")] def test_filter_mixed(self, filter_rel): # Mixed @@ -861,7 +861,7 @@ def test_filter_mixed(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(1, 'a'), (4, 'a')] + assert res == [(1, "a"), (4, "a")] def test_empty_in(self, filter_rel): expr = ColumnExpression("a") @@ -884,7 +884,7 @@ def test_filter_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 3 - assert res == [(1, 'a'), (2, 'b'), (1, 'b')] + assert res == [(1, "a"), (2, "b"), (1, "b")] def test_filter_not_in(self, filter_rel): expr = ColumnExpression("a") @@ -894,7 +894,7 @@ def test_filter_not_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] # NOT IN expression expr = ColumnExpression("a") @@ -902,7 +902,7 @@ def test_filter_not_in(self, filter_rel): rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 - assert res == [(3, 'c'), (4, 'a')] + assert res == [(3, "c"), (4, "a")] def test_null(self): con = duckdb.connect() @@ -924,7 +924,7 @@ def test_null(self): assert res == [(False,), (False,), (True,), (False,), (False,)] res2 = rel.filter(b.isnotnull()).fetchall() - assert res2 == [(1, 'a'), (2, 'b'), (4, 'c'), (5, 'a')] + assert res2 == [(1, "a"), (2, "b"), (4, "c"), (5, "a")] def test_sort(self): con = duckdb.connect() @@ -956,12 +956,12 @@ def test_sort(self): # Nulls first rel2 = rel.sort(b.desc().nulls_first()) res = rel2.b.fetchall() - assert res == [(None,), ('c',), ('b',), ('a',), ('a',)] + assert res == [(None,), ("c",), ("b",), ("a",), ("a",)] # Nulls last rel2 = rel.sort(b.desc().nulls_last()) res = rel2.b.fetchall() - assert res == [('c',), ('b',), ('a',), ('a',), (None,)] + assert res == [("c",), ("b",), ("a",), ("a",), (None,)] def test_aggregate(self): con = duckdb.connect() @@ -983,7 +983,7 @@ def test_aggregate_error(self): # Providing something that can not be converted into an expression is an error: with pytest.raises( - duckdb.InvalidInputException, match='Invalid Input Error: Please provide arguments of type Expression!' + duckdb.InvalidInputException, match="Invalid Input Error: Please provide arguments of type Expression!" ): class MyClass: diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index 195de165..7b8fbb05 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -10,12 +10,12 @@ from duckdb import DuckDBPyConnection, InvalidInputException from pytest import raises, importorskip, fixture, MonkeyPatch, mark -importorskip('fsspec', '2022.11.0') +importorskip("fsspec", "2022.11.0") from fsspec import filesystem, AbstractFileSystem from fsspec.implementations.memory import MemoryFileSystem from fsspec.implementations.local import LocalFileOpener, LocalFileSystem -FILENAME = 'integers.csv' +FILENAME = "integers.csv" logging.basicConfig(level=logging.DEBUG) @@ -43,11 +43,11 @@ def duckdb_cursor(): @fixture() def memory(): - fs = filesystem('memory', skip_instance_cache=True) + fs = filesystem("memory", skip_instance_cache=True) # ensure each instance is independent (to work around a weird quirk in fsspec) fs.store = {} - fs.pseudo_dirs = [''] + fs.pseudo_dirs = [""] # copy csv into memory filesystem add_file(fs) @@ -55,39 +55,39 @@ def memory(): def add_file(fs, filename=FILENAME): - with (Path(__file__).parent / 'data' / filename).open('rb') as source, fs.open(filename, 'wb') as dest: + with (Path(__file__).parent / "data" / filename).open("rb") as source, fs.open(filename, "wb") as dest: copyfileobj(source, dest) class TestPythonFilesystem: def test_unregister_non_existent_filesystem(self, duckdb_cursor: DuckDBPyConnection): with raises(InvalidInputException): - duckdb_cursor.unregister_filesystem('fake') + duckdb_cursor.unregister_filesystem("fake") def test_memory_filesystem(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) - assert memory.protocol == 'memory' + assert memory.protocol == "memory" duckdb_cursor.execute(f"select * from 'memory://{FILENAME}'") assert duckdb_cursor.fetchall() == [(1, 10, 0), (2, 50, 30)] - duckdb_cursor.unregister_filesystem('memory') + duckdb_cursor.unregister_filesystem("memory") def test_reject_abstract_filesystem(self, duckdb_cursor: DuckDBPyConnection): with raises(InvalidInputException): duckdb_cursor.register_filesystem(AbstractFileSystem()) def test_unregister_builtin(self, require: Callable[[str], DuckDBPyConnection]): - duckdb_cursor = require('httpfs') - assert duckdb_cursor.filesystem_is_registered('S3FileSystem') == True - duckdb_cursor.unregister_filesystem('S3FileSystem') - assert duckdb_cursor.filesystem_is_registered('S3FileSystem') == False + duckdb_cursor = require("httpfs") + assert duckdb_cursor.filesystem_is_registered("S3FileSystem") == True + duckdb_cursor.unregister_filesystem("S3FileSystem") + assert duckdb_cursor.filesystem_is_registered("S3FileSystem") == False def test_multiple_protocol_filesystems(self, duckdb_cursor: DuckDBPyConnection): class ExtendedMemoryFileSystem(MemoryFileSystem): - protocol = ('file', 'local') + protocol = ("file", "local") # defer to the original implementation that doesn't hardcode the protocol _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) @@ -104,51 +104,51 @@ def test_write(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSyst duckdb_cursor.execute("copy (select 1) to 'memory://01.csv' (FORMAT CSV, HEADER 0)") - assert memory.open('01.csv').read() == b'1\n' + assert memory.open("01.csv").read() == b"1\n" def test_null_bytes(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): - with memory.open('test.csv', 'wb') as fh: - fh.write(b'hello\n\0world\0') + with memory.open("test.csv", "wb") as fh: + fh.write(b"hello\n\0world\0") duckdb_cursor.register_filesystem(memory) - duckdb_cursor.execute('select * from read_csv("memory://test.csv", header = 0, quote = \'"\', escape = \'"\')') + duckdb_cursor.execute("select * from read_csv(\"memory://test.csv\", header = 0, quote = '\"', escape = '\"')") - assert duckdb_cursor.fetchall() == [('hello',), ('\0world\0',)] + assert duckdb_cursor.fetchall() == [("hello",), ("\0world\0",)] def test_read_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): - filename = 'binary_string.parquet' + filename = "binary_string.parquet" add_file(memory, filename) duckdb_cursor.register_filesystem(memory) duckdb_cursor.execute(f"select * from read_parquet('memory://{filename}')") - assert duckdb_cursor.fetchall() == [(b'foo',), (b'bar',), (b'baz',)] + assert duckdb_cursor.fetchall() == [(b"foo",), (b"bar",), (b"baz",)] def test_write_parquet(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) - filename = 'output.parquet' + filename = "output.parquet" - duckdb_cursor.execute(f'''COPY (SELECT 1) TO 'memory://{filename}' (FORMAT PARQUET);''') + duckdb_cursor.execute(f"""COPY (SELECT 1) TO 'memory://{filename}' (FORMAT PARQUET);""") - assert memory.open(filename).read().startswith(b'PAR1') + assert memory.open(filename).read().startswith(b"PAR1") def test_when_fsspec_not_installed(self, duckdb_cursor: DuckDBPyConnection, monkeypatch: MonkeyPatch): - monkeypatch.setitem(sys.modules, 'fsspec', None) + monkeypatch.setitem(sys.modules, "fsspec", None) with raises(ModuleNotFoundError): duckdb_cursor.register_filesystem(None) @mark.skipif(sys.version_info < (3, 8), reason="ArrowFSWrapper requires python 3.8 or higher") def test_arrow_fs_wrapper(self, tmp_path: Path, duckdb_cursor: DuckDBPyConnection): - fs = importorskip('pyarrow.fs') + fs = importorskip("pyarrow.fs") from fsspec.implementations.arrow import ArrowFSWrapper local = fs.LocalFileSystem() local_fsspec = ArrowFSWrapper(local, skip_instance_cache=True) # posix calls here required as ArrowFSWrapper only supports url-like paths (not Windows paths) filename = str(PurePosixPath(tmp_path.as_posix()) / "test.csv") - with local_fsspec.open(filename, mode='w') as f: + with local_fsspec.open(filename, mode="w") as f: f.write("a,b,c\n") f.write("1,2,3\n") f.write("4,5,6\n") @@ -159,29 +159,29 @@ def test_arrow_fs_wrapper(self, tmp_path: Path, duckdb_cursor: DuckDBPyConnectio assert duckdb_cursor.fetchall() == [(1, 2, 3), (4, 5, 6)] def test_database_attach(self, tmp_path: Path, monkeypatch: MonkeyPatch): - db_path = str(tmp_path / 'hello.db') + db_path = str(tmp_path / "hello.db") # setup a database to attach later with duckdb.connect(db_path) as conn: conn.execute( - ''' + """ CREATE TABLE t (id int); INSERT INTO t VALUES (0) - ''' + """ ) assert exists(db_path) with duckdb.connect() as conn: - fs = filesystem('file', skip_instance_cache=True) - write_errors = intercept(monkeypatch, LocalFileOpener, 'write') + fs = filesystem("file", skip_instance_cache=True) + write_errors = intercept(monkeypatch, LocalFileOpener, "write") conn.register_filesystem(fs) db_path_posix = str(PurePosixPath(tmp_path.as_posix()) / "hello.db") conn.execute(f"ATTACH 'file://{db_path_posix}'") - conn.execute('INSERT INTO hello.t VALUES (1)') + conn.execute("INSERT INTO hello.t VALUES (1)") - conn.execute('FROM hello.t') + conn.execute("FROM hello.t") assert conn.fetchall() == [(0,), (1,)] # duckdb sometimes seems to swallow write errors, so we use this to ensure that @@ -193,7 +193,7 @@ def test_copy_partition(self, duckdb_cursor: DuckDBPyConnection, memory: Abstrac duckdb_cursor.execute("copy (select 1 as a, 2 as b) to 'memory://root' (partition_by (a), HEADER 0)") - assert memory.open('/root/a=1/data_0.csv').read() == b'2\n' + assert memory.open("/root/a=1/data_0.csv").read() == b"2\n" def test_copy_partition_with_columns_written(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) @@ -202,7 +202,7 @@ def test_copy_partition_with_columns_written(self, duckdb_cursor: DuckDBPyConnec "copy (select 1 as a) to 'memory://root' (partition_by (a), HEADER 0, WRITE_PARTITION_COLUMNS)" ) - assert memory.open('/root/a=1/data_0.csv').read() == b'1\n' + assert memory.open("/root/a=1/data_0.csv").read() == b"1\n" def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem): duckdb_cursor.register_filesystem(memory) @@ -210,25 +210,25 @@ def test_read_hive_partition(self, duckdb_cursor: DuckDBPyConnection, memory: Ab "copy (select 2 as a, 3 as b, 4 as c) to 'memory://partition' (partition_by (a), HEADER 0)" ) - path = 'memory:///partition/*/*.csv' + path = "memory:///partition/*/*.csv" query = "SELECT * FROM read_csv_auto('" + path + "'" # hive partitioning - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: auto detection - duckdb_cursor.execute(query + ');') + duckdb_cursor.execute(query + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=1" + ");") assert duckdb_cursor.fetchall() == [(3, 4, 2)] # hive partitioning: no cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=0' + ');') - assert duckdb_cursor.fetchall() == [(3, 4, '2')] + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=0" + ");") + assert duckdb_cursor.fetchall() == [(3, 4, "2")] def test_read_hive_partition_with_columns_written( self, duckdb_cursor: DuckDBPyConnection, memory: AbstractFileSystem @@ -238,34 +238,34 @@ def test_read_hive_partition_with_columns_written( "copy (select 2 as a) to 'memory://partition' (partition_by (a), HEADER 0, WRITE_PARTITION_COLUMNS)" ) - path = 'memory:///partition/*/*.csv' + path = "memory:///partition/*/*.csv" query = "SELECT * FROM read_csv_auto('" + path + "'" # hive partitioning - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: auto detection - duckdb_cursor.execute(query + ');') + duckdb_cursor.execute(query + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=1' + ');') + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=1" + ");") assert duckdb_cursor.fetchall() == [(2, 2)] # hive partitioning: no cast to int - duckdb_cursor.execute(query + ', HIVE_PARTITIONING=1' + ', HIVE_TYPES_AUTOCAST=0' + ');') - assert duckdb_cursor.fetchall() == [(2, '2')] + duckdb_cursor.execute(query + ", HIVE_PARTITIONING=1" + ", HIVE_TYPES_AUTOCAST=0" + ");") + assert duckdb_cursor.fetchall() == [(2, "2")] def test_parallel_union_by_name(self, tmp_path): - pa = importorskip('pyarrow') - pq = importorskip('pyarrow.parquet') - fsspec = importorskip('fsspec') + pa = importorskip("pyarrow") + pq = importorskip("pyarrow.parquet") + fsspec = importorskip("fsspec") table1 = pa.Table.from_pylist( [ - {'time': 1719568210134107692, 'col1': 1}, + {"time": 1719568210134107692, "col1": 1}, ] ) table1_path = tmp_path / "table1.parquet" @@ -273,7 +273,7 @@ def test_parallel_union_by_name(self, tmp_path): table2 = pa.Table.from_pylist( [ - {'time': 1719568210134107692, 'col1': 1}, + {"time": 1719568210134107692, "col1": 1}, ] ) table2_path = tmp_path / "table2.parquet" diff --git a/tests/fast/test_get_table_names.py b/tests/fast/test_get_table_names.py index c11b8a65..1f90e444 100644 --- a/tests/fast/test_get_table_names.py +++ b/tests/fast/test_get_table_names.py @@ -6,7 +6,7 @@ class TestGetTableNames(object): def test_table_success(self, duckdb_cursor): conn = duckdb.connect() table_names = conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") - assert table_names == {'my_table2', 'my_table3', 'my_table1'} + assert table_names == {"my_table2", "my_table3", "my_table1"} def test_table_fail(self, duckdb_cursor): conn = duckdb.connect() @@ -19,11 +19,11 @@ def test_qualified_parameter_basic(self): # Default (qualified=False) table_names = conn.get_table_names("SELECT * FROM test_table") - assert table_names == {'test_table'} + assert table_names == {"test_table"} # Explicit qualified=False table_names = conn.get_table_names("SELECT * FROM test_table", qualified=False) - assert table_names == {'test_table'} + assert table_names == {"test_table"} def test_qualified_parameter_schemas(self): conn = duckdb.connect() @@ -31,11 +31,11 @@ def test_qualified_parameter_schemas(self): # Default (qualified=False) query = "SELECT * FROM test_schema.schema_table, main_table" table_names = conn.get_table_names(query) - assert table_names == {'schema_table', 'main_table'} + assert table_names == {"schema_table", "main_table"} # Test with qualified names table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'test_schema.schema_table', 'main_table'} + assert table_names == {"test_schema.schema_table", "main_table"} def test_qualified_parameter_catalogs(self): conn = duckdb.connect() @@ -45,11 +45,11 @@ def test_qualified_parameter_catalogs(self): # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'catalog_table', 'regular_table'} + assert table_names == {"catalog_table", "regular_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'catalog1.test_schema.catalog_table', 'regular_table'} + assert table_names == {"catalog1.test_schema.catalog_table", "regular_table"} def test_qualified_parameter_quoted_identifiers(self): conn = duckdb.connect() @@ -59,7 +59,7 @@ def test_qualified_parameter_quoted_identifiers(self): # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'Table.With.Dots', 'Table With Spaces'} + assert table_names == {"Table.With.Dots", "Table With Spaces"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) @@ -67,45 +67,45 @@ def test_qualified_parameter_quoted_identifiers(self): def test_expanded_views(self): conn = duckdb.connect() - conn.execute('CREATE TABLE my_table(i INT)') - conn.execute('CREATE VIEW v1 AS SELECT * FROM my_table') + conn.execute("CREATE TABLE my_table(i INT)") + conn.execute("CREATE VIEW v1 AS SELECT * FROM my_table") # Test that v1 expands to my_table - query = 'SELECT col_a FROM v1' + query = "SELECT col_a FROM v1" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'my_table'} + assert table_names == {"my_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'my_table'} + assert table_names == {"my_table"} def test_expanded_views_with_schema(self): conn = duckdb.connect() - conn.execute('CREATE SCHEMA my_schema') - conn.execute('CREATE TABLE my_schema.my_table(i INT)') - conn.execute('CREATE VIEW v1 AS SELECT * FROM my_schema.my_table') + conn.execute("CREATE SCHEMA my_schema") + conn.execute("CREATE TABLE my_schema.my_table(i INT)") + conn.execute("CREATE VIEW v1 AS SELECT * FROM my_schema.my_table") # Test that v1 expands to my_table - query = 'SELECT col_a FROM v1' + query = "SELECT col_a FROM v1" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'my_table'} + assert table_names == {"my_table"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'my_schema.my_table'} + assert table_names == {"my_schema.my_table"} def test_select_function(self): conn = duckdb.connect() - query = 'SELECT EXTRACT(second FROM i) FROM timestamps;' + query = "SELECT EXTRACT(second FROM i) FROM timestamps;" # Default (qualified=False) table_names = conn.get_table_names(query) - assert table_names == {'timestamps'} + assert table_names == {"timestamps"} # With qualified=True table_names = conn.get_table_names(query, qualified=True) - assert table_names == {'timestamps'} + assert table_names == {"timestamps"} diff --git a/tests/fast/test_import_export.py b/tests/fast/test_import_export.py index 2fce1636..d98a2d73 100644 --- a/tests/fast/test_import_export.py +++ b/tests/fast/test_import_export.py @@ -33,7 +33,7 @@ def move_database(export_location, import_location): assert path.exists(export_location) assert path.exists(import_location) - for file in ['schema.sql', 'load.sql', 'tbl.csv']: + for file in ["schema.sql", "load.sql", "tbl.csv"]: shutil.move(path.join(export_location, file), import_location) @@ -56,7 +56,7 @@ def export_and_import_empty_db(db_path, _): class TestDuckDBImportExport: - @pytest.mark.parametrize('routine', [export_move_and_import, export_and_import_empty_db]) + @pytest.mark.parametrize("routine", [export_move_and_import, export_and_import_empty_db]) def test_import_and_export(self, routine, tmp_path_factory): export_path = str(tmp_path_factory.mktemp("export_dbs", numbered=True)) import_path = str(tmp_path_factory.mktemp("import_dbs", numbered=True)) @@ -66,15 +66,15 @@ def test_import_empty_db(self, tmp_path_factory): import_path = str(tmp_path_factory.mktemp("empty_db", numbered=True)) # Create an empty db folder structure - Path(Path(import_path) / 'load.sql').touch() - Path(Path(import_path) / 'schema.sql').touch() + Path(Path(import_path) / "load.sql").touch() + Path(Path(import_path) / "schema.sql").touch() con = duckdb.connect() con.execute(f"import database '{import_path}'") # Put a single comment into the 'schema.sql' file - with open(Path(import_path) / 'schema.sql', 'w') as f: - f.write('--\n') + with open(Path(import_path) / "schema.sql", "w") as f: + f.write("--\n") con.close() con = duckdb.connect() diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index 1465b68a..baae75b4 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -6,7 +6,7 @@ class TestInsert(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_insert(self, pandas): test_df = pandas.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) # connect to an in-memory temporary database @@ -15,19 +15,19 @@ def test_insert(self, pandas): cursor = conn.cursor() conn.execute("CREATE TABLE test (i INTEGER, j STRING)") rel = conn.table("test") - rel.insert([1, 'one']) - rel.insert([2, 'two']) - rel.insert([3, 'three']) - rel_a3 = cursor.table('test').project('CAST(i as BIGINT)i, j').to_df() + rel.insert([1, "one"]) + rel.insert([2, "two"]) + rel.insert([3, "three"]) + rel_a3 = cursor.table("test").project("CAST(i as BIGINT)i, j").to_df() pandas.testing.assert_frame_equal(rel_a3, test_df) def test_insert_with_schema(self, duckdb_cursor): duckdb_cursor.sql("create schema not_main") duckdb_cursor.sql("create table not_main.tbl as select * from range(10)") - res = duckdb_cursor.table('not_main.tbl').fetchall() + res = duckdb_cursor.table("not_main.tbl").fetchall() assert len(res) == 10 # FIXME: This is not currently supported - with pytest.raises(duckdb.CatalogException, match='Table with name tbl does not exist'): - duckdb_cursor.table('not_main.tbl').insert([42, 21, 1337]) + with pytest.raises(duckdb.CatalogException, match="Table with name tbl does not exist"): + duckdb_cursor.table("not_main.tbl").insert([42, 21, 1337]) diff --git a/tests/fast/test_many_con_same_file.py b/tests/fast/test_many_con_same_file.py index 6b7362a6..3cef2494 100644 --- a/tests/fast/test_many_con_same_file.py +++ b/tests/fast/test_many_con_same_file.py @@ -23,7 +23,7 @@ def test_multiple_writes(): con1.close() con3 = duckdb.connect("test.db") tbls = get_tables(con3) - assert tbls == ['bar1', 'foo1'] + assert tbls == ["bar1", "foo1"] del con1 del con2 del con3 @@ -41,9 +41,9 @@ def test_multiple_writes_memory(): con2.execute("CREATE TABLE bar1 as SELECT 2 as a, 3 as b") con3 = duckdb.connect(":memory:") tbls = get_tables(con1) - assert tbls == ['foo1'] + assert tbls == ["foo1"] tbls = get_tables(con2) - assert tbls == ['bar1'] + assert tbls == ["bar1"] tbls = get_tables(con3) assert tbls == [] del con1 @@ -58,7 +58,7 @@ def test_multiple_writes_named_memory(): con2.execute("CREATE TABLE bar1 as SELECT 2 as a, 3 as b") con3 = duckdb.connect(":memory:1") tbls = get_tables(con3) - assert tbls == ['bar1', 'foo1'] + assert tbls == ["bar1", "foo1"] del con1 del con2 del con3 @@ -76,7 +76,7 @@ def test_diff_config(): def test_diff_config_extended(): - con1 = duckdb.connect("test.db", config={'null_order': 'NULLS FIRST'}) + con1 = duckdb.connect("test.db", config={"null_order": "NULLS FIRST"}) with pytest.raises( duckdb.ConnectionException, match="Can't open a connection to same database file with a different configuration than existing connections", diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index 4dbd1a36..f86dd60b 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -9,36 +9,36 @@ # column count differs from bind def evil1(df): if len(df) == 0: - return df['col0'].to_frame() + return df["col0"].to_frame() else: return df class TestMap(object): - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_evil_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) - with pytest.raises(duckdb.InvalidInputException, match='Expected 1 columns from UDF, got 2'): - rel = testrel.map(evil1, schema={'i': str}) + with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): + rel = testrel.map(evil1, schema={"i": str}) df = rel.df() print(df) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) conn = duckdb_cursor - conn.execute('CREATE TABLE t (a integer)') - empty_rel = conn.table('t') + conn.execute("CREATE TABLE t (a integer)") + empty_rel = conn.table("t") - newdf1 = testrel.map(lambda df: df['col0'].add(42).to_frame()) - newdf2 = testrel.map(lambda df: df['col0'].astype('string').to_frame()) + newdf1 = testrel.map(lambda df: df["col0"].add(42).to_frame()) + newdf2 = testrel.map(lambda df: df["col0"].astype("string").to_frame()) newdf3 = testrel.map(lambda df: df) # column type differs from bind def evil2(df): result = df.copy(deep=True) if len(result) == 0: - result['col0'] = result['col0'].astype('double') + result["col0"] = result["col0"].astype("double") return result # column name differs from bind @@ -56,10 +56,10 @@ def evil5(df): raise TypeError def return_dataframe(df): - return pandas.DataFrame({'A': [1]}) + return pandas.DataFrame({"A": [1]}) def return_big_dataframe(df): - return pandas.DataFrame({'A': [1] * 5000}) + return pandas.DataFrame({"A": [1] * 5000}) def return_none(df): return None @@ -67,13 +67,13 @@ def return_none(df): def return_empty_df(df): return pandas.DataFrame() - with pytest.raises(duckdb.InvalidInputException, match='Expected 1 columns from UDF, got 2'): + with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): print(testrel.map(evil1).df()) - with pytest.raises(duckdb.InvalidInputException, match='UDF column type mismatch'): + with pytest.raises(duckdb.InvalidInputException, match="UDF column type mismatch"): print(testrel.map(evil2).df()) - with pytest.raises(duckdb.InvalidInputException, match='UDF column name mismatch'): + with pytest.raises(duckdb.InvalidInputException, match="UDF column name mismatch"): print(testrel.map(evil3).df()) with pytest.raises( @@ -92,19 +92,19 @@ def return_empty_df(df): with pytest.raises(TypeError): print(testrel.map().df()) - testrel.map(return_dataframe).df().equals(pandas.DataFrame({'A': [1]})) + testrel.map(return_dataframe).df().equals(pandas.DataFrame({"A": [1]})) with pytest.raises( - duckdb.InvalidInputException, match='UDF returned more than 2048 rows, which is not allowed.' + duckdb.InvalidInputException, match="UDF returned more than 2048 rows, which is not allowed." ): testrel.map(return_big_dataframe).df() - empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({'A': []})) + empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({"A": []})) - with pytest.raises(duckdb.InvalidInputException, match='No return value from Python function'): + with pytest.raises(duckdb.InvalidInputException, match="No return value from Python function"): testrel.map(return_none).df() - with pytest.raises(duckdb.InvalidInputException, match='Need a DataFrame with at least one column'): + with pytest.raises(duckdb.InvalidInputException, match="Need a DataFrame with at least one column"): testrel.map(return_empty_df).df() def test_map_with_object_column(self, duckdb_cursor): @@ -115,21 +115,21 @@ def return_with_no_modification(df): # when a dataframe with 'object' column is returned, we use the content to infer the type # when the dataframe is empty, this results in NULL, which is not desirable # in this case we assume the returned type should be the same as the input type - duckdb_cursor.values([b'1234']).map(return_with_no_modification).fetchall() + duckdb_cursor.values([b"1234"]).map(return_with_no_modification).fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_isse_3237(self, duckdb_cursor, pandas): def process(rel): def mapper(x): - dates = x['date'].to_numpy("datetime64[us]") - days = x['days_to_add'].to_numpy("int") + dates = x["date"].to_numpy("datetime64[us]") + days = x["days_to_add"].to_numpy("int") x["result1"] = pandas.Series( [pandas.to_datetime(y[0]).date() + timedelta(days=y[1].item()) for y in zip(dates, days)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) x["result2"] = pandas.Series( [pandas.to_datetime(y[0]).date() + timedelta(days=-y[1].item()) for y in zip(dates, days)], - dtype='datetime64[us]', + dtype="datetime64[us]", ) return x @@ -140,22 +140,22 @@ def mapper(x): return rel df = pandas.DataFrame( - {'date': pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), 'days_to_add': [1, 2]} + {"date": pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), "days_to_add": [1, 2]} ) rel = duckdb.from_df(df) rel = process(rel) x = rel.fetchdf() - assert x['days_to_add'].to_numpy()[0] == 1 + assert x["days_to_add"].to_numpy()[0] == 1 def test_explicit_schema(self): def cast_to_string(df): - df['i'] = df['i'].astype(str) + df["i"] = df["i"].astype(str) return df con = duckdb.connect() - rel = con.sql('select i from range (10) tbl(i)') + rel = con.sql("select i from range (10) tbl(i)") assert rel.types[0] == duckdb.NUMBER - mapped_rel = rel.map(cast_to_string, schema={'i': str}) + mapped_rel = rel.map(cast_to_string, schema={"i": str}) assert mapped_rel.types[0] == duckdb.STRING def test_explicit_schema_returntype_mismatch(self): @@ -163,45 +163,45 @@ def does_nothing(df): return df con = duckdb.connect() - rel = con.sql('select i from range(10) tbl(i)') + rel = con.sql("select i from range(10) tbl(i)") # expects the mapper to return a string column - rel = rel.map(does_nothing, schema={'i': str}) + rel = rel.map(does_nothing, schema={"i": str}) with pytest.raises( duckdb.InvalidInputException, match=re.escape("UDF column type mismatch, expected [VARCHAR], got [BIGINT]") ): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_name_mismatch(self, pandas): def renames_column(df): - return pandas.DataFrame({'a': df['i']}) + return pandas.DataFrame({"a": df["i"]}) con = duckdb.connect() - rel = con.sql('select i from range(10) tbl(i)') - rel = rel.map(renames_column, schema={'i': int}) - with pytest.raises(duckdb.InvalidInputException, match=re.escape('UDF column name mismatch')): + rel = con.sql("select i from range(10) tbl(i)") + rel = rel.map(renames_column, schema={"i": int}) + with pytest.raises(duckdb.InvalidInputException, match=re.escape("UDF column name mismatch")): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_error(self, pandas): def no_op(df): return df con = duckdb.connect() - rel = con.sql('select 42') + rel = con.sql("select 42") with pytest.raises( duckdb.InvalidInputException, match=re.escape("Invalid Input Error: 'schema' should be given as a Dict[str, DuckDBType]"), ): rel.map(no_op, schema=[int]) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_returns_non_dataframe(self, pandas): def returns_series(df): - return df.loc[:, 'i'] + return df.loc[:, "i"] con = duckdb.connect() - rel = con.sql('select i, i as j from range(10) tbl(i)') + rel = con.sql("select i, i as j from range(10) tbl(i)") with pytest.raises( duckdb.InvalidInputException, match=re.escape( @@ -210,29 +210,29 @@ def returns_series(df): ): rel = rel.map(returns_series) - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_explicit_schema_columncount_mismatch(self, pandas): def returns_subset(df): - return pandas.DataFrame({'i': df.loc[:, 'i']}) + return pandas.DataFrame({"i": df.loc[:, "i"]}) con = duckdb.connect() - rel = con.sql('select i, i as j from range(10) tbl(i)') - rel = rel.map(returns_subset, schema={'i': int, 'j': int}) + rel = con.sql("select i, i as j from range(10) tbl(i)") + rel = rel.map(returns_subset, schema={"i": int, "j": int}) with pytest.raises( - duckdb.InvalidInputException, match='Invalid Input Error: Expected 2 columns from UDF, got 1' + duckdb.InvalidInputException, match="Invalid Input Error: Expected 2 columns from UDF, got 1" ): rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_pyarrow_df(self, pandas): # PyArrow backed dataframes only exist on pandas >= 2.0.0 _ = pytest.importorskip("pandas", "2.0.0") def basic_function(df): # Create a pyarrow backed dataframe - df = pandas.DataFrame({'a': [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend='pyarrow') + df = pandas.DataFrame({"a": [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend="pyarrow") return df con = duckdb.connect() with pytest.raises(duckdb.InvalidInputException): - rel = con.sql('select 42').map(basic_function) + rel = con.sql("select 42").map(basic_function) diff --git a/tests/fast/test_metatransaction.py b/tests/fast/test_metatransaction.py index 158bb6a9..f617cba2 100644 --- a/tests/fast/test_metatransaction.py +++ b/tests/fast/test_metatransaction.py @@ -10,7 +10,7 @@ class TestMetaTransaction(object): def test_fetchmany(self, duckdb_cursor): duckdb_cursor.execute("CREATE SEQUENCE id_seq") - column_names = ',\n'.join([f'column_{i} FLOAT' for i in range(1, NUMBER_OF_COLUMNS + 1)]) + column_names = ",\n".join([f"column_{i} FLOAT" for i in range(1, NUMBER_OF_COLUMNS + 1)]) create_table_query = f""" CREATE TABLE my_table ( id INTEGER DEFAULT nextval('id_seq'), @@ -23,7 +23,7 @@ def test_fetchmany(self, duckdb_cursor): for i in range(20): # Then insert a large amount of tuples, triggering a parallel execution data = np.random.rand(NUMBER_OF_ROWS, NUMBER_OF_COLUMNS) - columns = [f'Column_{i+1}' for i in range(NUMBER_OF_COLUMNS)] + columns = [f"Column_{i + 1}" for i in range(NUMBER_OF_COLUMNS)] df = pd.DataFrame(data, columns=columns) df_columns = ", ".join(df.columns) # This gets executed in parallel, causing NextValFunction to be called in parallel diff --git a/tests/fast/test_multi_statement.py b/tests/fast/test_multi_statement.py index db82eaf3..722ab31a 100644 --- a/tests/fast/test_multi_statement.py +++ b/tests/fast/test_multi_statement.py @@ -7,36 +7,36 @@ class TestMultiStatement(object): def test_multi_statement(self, duckdb_cursor): import duckdb - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") # test empty statement - con.execute('') + con.execute("") # run multiple statements in one call to execute con.execute( - ''' + """ CREATE TABLE integers(i integer); insert into integers select * from range(10); select * from integers; - ''' + """ ) results = [x[0] for x in con.fetchall()] assert results == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] # test export/import - export_location = os.path.join(os.getcwd(), 'duckdb_pytest_dir_export') + export_location = os.path.join(os.getcwd(), "duckdb_pytest_dir_export") try: shutil.rmtree(export_location) except: pass - con.execute('CREATE TABLE integers2(i INTEGER)') - con.execute('INSERT INTO integers2 VALUES (1), (5), (7), (1928)') + con.execute("CREATE TABLE integers2(i INTEGER)") + con.execute("INSERT INTO integers2 VALUES (1), (5), (7), (1928)") con.execute("EXPORT DATABASE '%s'" % (export_location,)) # reset connection - con = duckdb.connect(':memory:') + con = duckdb.connect(":memory:") con.execute("IMPORT DATABASE '%s'" % (export_location,)) - integers = [x[0] for x in con.execute('SELECT * FROM integers').fetchall()] - integers2 = [x[0] for x in con.execute('SELECT * FROM integers2').fetchall()] + integers = [x[0] for x in con.execute("SELECT * FROM integers").fetchall()] + integers2 = [x[0] for x in con.execute("SELECT * FROM integers2").fetchall()] assert integers == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] assert integers2 == [1, 5, 7, 1928] shutil.rmtree(export_location) diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index ad2d56fd..628aacd8 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -16,7 +16,7 @@ def connect_duck(duckdb_conn): - out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() + out = duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchall() assert out == [(42,), (84,), (None,), (128,)] @@ -39,7 +39,7 @@ def multithread_test(self, result_verification=everything_succeeded): for i in range(0, self.duckdb_insert_thread_count): self.threads.append( threading.Thread( - target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name='duckdb_thread_' + str(i) + target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name="duckdb_thread_" + str(i) ) ) @@ -60,7 +60,7 @@ def multithread_test(self, result_verification=everything_succeeded): def execute_query_same_connection(duckdb_conn, queue, pandas): try: - out = duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)') + out = duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(False) except: queue.put(True) @@ -70,7 +70,7 @@ def execute_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)') + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(True) except: queue.put(False) @@ -80,7 +80,7 @@ def insert_runtime_error(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") queue.put(False) except: queue.put(True) @@ -104,9 +104,9 @@ def execute_many_query(duckdb_conn, queue, pandas): ) # Larger example that inserts many records at a time purchases = [ - ('2006-03-28', 'BUY', 'IBM', 1000, 45.00), - ('2006-04-05', 'BUY', 'MSFT', 1000, 72.00), - ('2006-04-06', 'SELL', 'IBM', 500, 53.00), + ("2006-03-28", "BUY", "IBM", 1000, 45.00), + ("2006-04-05", "BUY", "MSFT", 1000, 72.00), + ("2006-04-06", "SELL", "IBM", 500, 53.00), ] duckdb_conn.executemany( """ @@ -123,7 +123,7 @@ def fetchone_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchone() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchone() queue.put(True) except: queue.put(False) @@ -133,7 +133,7 @@ def fetchall_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchall() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchall() queue.put(True) except: queue.put(False) @@ -153,7 +153,7 @@ def fetchnp_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchnumpy() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchnumpy() queue.put(True) except: queue.put(False) @@ -163,7 +163,7 @@ def fetchdf_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetchdf() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetchdf() queue.put(True) except: queue.put(False) @@ -173,7 +173,7 @@ def fetchdf_chunk_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_df_chunk() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_df_chunk() queue.put(True) except: queue.put(False) @@ -183,7 +183,7 @@ def fetch_arrow_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_arrow_table() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_arrow_table() queue.put(True) except: queue.put(False) @@ -193,7 +193,7 @@ def fetch_record_batch_query(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute('select i from (values (42), (84), (NULL), (128)) tbl(i)').fetch_record_batch() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_record_batch() queue.put(True) except: queue.put(False) @@ -205,9 +205,9 @@ def transaction_query(duckdb_conn, queue, pandas): duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") try: duckdb_conn.begin() - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") duckdb_conn.rollback() - duckdb_conn.execute('insert into T values (42), (84), (NULL), (128)') + duckdb_conn.execute("insert into T values (42), (84), (NULL), (128)") duckdb_conn.commit() queue.put(True) except: @@ -218,9 +218,9 @@ def df_append(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.append('T', df) + duckdb_conn.append("T", df) queue.put(True) except: queue.put(False) @@ -229,9 +229,9 @@ def df_append(duckdb_conn, queue, pandas): def df_register(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.register('T', df) + duckdb_conn.register("T", df) queue.put(True) except: queue.put(False) @@ -240,10 +240,10 @@ def df_register(duckdb_conn, queue, pandas): def df_unregister(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=['A']) + df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: - duckdb_conn.register('T', df) - duckdb_conn.unregister('T') + duckdb_conn.register("T", df) + duckdb_conn.unregister("T") queue.put(True) except: queue.put(False) @@ -251,12 +251,12 @@ def df_unregister(duckdb_conn, queue, pandas): def arrow_register_unregister(duckdb_conn, queue, pandas): # Get a new connection - pa = pytest.importorskip('pyarrow') + pa = pytest.importorskip("pyarrow") duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({"my_column": pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: - duckdb_conn.register('T', arrow_tbl) - duckdb_conn.unregister('T') + duckdb_conn.register("T", arrow_tbl) + duckdb_conn.unregister("T") queue.put(True) except: queue.put(False) @@ -267,7 +267,7 @@ def table(duckdb_conn, queue, pandas): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") try: - out = duckdb_conn.table('T') + out = duckdb_conn.table("T") queue.put(True) except: queue.put(False) @@ -279,7 +279,7 @@ def view(duckdb_conn, queue, pandas): duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") duckdb_conn.execute("CREATE VIEW V as (SELECT * FROM T)") try: - out = duckdb_conn.values([5, 'five']) + out = duckdb_conn.values([5, "five"]) queue.put(True) except: queue.put(False) @@ -289,7 +289,7 @@ def values(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() try: - out = duckdb_conn.values([5, 'five']) + out = duckdb_conn.values([5, "five"]) queue.put(True) except: queue.put(False) @@ -308,7 +308,7 @@ def from_query(duckdb_conn, queue, pandas): def from_df(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(['bla', 'blabla'] * 10, columns=['A']) + df = pandas.DataFrame(["bla", "blabla"] * 10, columns=["A"]) try: out = duckdb_conn.execute("select * from df").fetchall() queue.put(True) @@ -318,9 +318,9 @@ def from_df(duckdb_conn, queue, pandas): def from_arrow(duckdb_conn, queue, pandas): # Get a new connection - pa = pytest.importorskip('pyarrow') + pa = pytest.importorskip("pyarrow") duckdb_conn = duckdb.connect() - arrow_tbl = pa.Table.from_pydict({'my_column': pa.array([1, 2, 3, 4, 5], type=pa.int64())}) + arrow_tbl = pa.Table.from_pydict({"my_column": pa.array([1, 2, 3, 4, 5], type=pa.int64())}) try: out = duckdb_conn.from_arrow(arrow_tbl) queue.put(True) @@ -331,7 +331,7 @@ def from_arrow(duckdb_conn, queue, pandas): def from_csv_auto(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "integers.csv") try: out = duckdb_conn.from_csv_auto(filename) queue.put(True) @@ -342,7 +342,7 @@ def from_csv_auto(duckdb_conn, queue, pandas): def from_parquet(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "binary_string.parquet") try: out = duckdb_conn.from_parquet(filename) queue.put(True) @@ -353,7 +353,7 @@ def from_parquet(duckdb_conn, queue, pandas): def description(duckdb_conn, queue, pandas): # Get a new connection duckdb_conn = duckdb.connect() - duckdb_conn.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + duckdb_conn.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") duckdb_conn.execute("INSERT INTO test VALUES (TRUE, '01:01:01', 'bla' )") rel = duckdb_conn.table("test") rel.execute() @@ -368,138 +368,138 @@ def cursor(duckdb_conn, queue, pandas): # Get a new connection cx = duckdb_conn.cursor() try: - cx.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + cx.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") queue.put(False) except: queue.put(True) class TestDuckMultithread(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute_many(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_many_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchone(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchone_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchall(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchall_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_close(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, conn_close, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchnp(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchnp_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchdf(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchdf_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetchdfchunk(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, fetchdf_chunk_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetcharrow(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, fetch_arrow_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_fetch_record_batch(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, fetch_record_batch_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_transaction(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, transaction_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_append(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_append, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_register(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_register, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_unregister(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, df_unregister, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_arrow_register_unregister(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, arrow_register_unregister, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_table(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, table, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_view(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, view, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_values(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, values, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_query(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_query, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_DF(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_df, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow(self, duckdb_cursor, pandas): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") duck_threads = DuckDBThreaded(10, from_arrow, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_csv_auto(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_csv_auto, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_parquet(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, from_parquet, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_description(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, description, pandas) duck_threads.multithread_test() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_cursor(self, duckdb_cursor, pandas): def only_some_succeed(results: list[bool]): if not any([result == True for result in results]): diff --git a/tests/fast/test_non_default_conn.py b/tests/fast/test_non_default_conn.py index bc9fa5f0..cb0218e3 100644 --- a/tests/fast/test_non_default_conn.py +++ b/tests/fast/test_non_default_conn.py @@ -24,7 +24,7 @@ def test_from_csv(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_csv(temp_file_name, index=False) rel = duckdb_cursor.from_csv_auto(temp_file_name) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_parquet(self, duckdb_cursor): try: @@ -37,16 +37,16 @@ def test_from_parquet(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_df.to_parquet(temp_file_name, index=False) rel = duckdb_cursor.from_parquet(temp_file_name) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.df(test_df, connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) rel = duckdb_cursor.from_df(test_df) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_from_arrow(self, duckdb_cursor): try: @@ -59,55 +59,55 @@ def test_from_arrow(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) test_arrow = pa.Table.from_pandas(test_df) rel = duckdb_cursor.from_arrow(test_arrow) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) rel = duckdb.arrow(test_arrow, connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_filter_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.filter(test_df, "i < 2", connection=duckdb_cursor) - assert rel.query('t_2', 'select count(*) from t inner join t_2 on (a = i)').fetchall()[0] == (1,) + assert rel.query("t_2", "select count(*) from t inner join t_2 on (a = i)").fetchall()[0] == (1,) def test_project_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) rel = duckdb.project(test_df, "i", connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_agg_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1), (4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": [1, 2, 3, 4]}) rel = duckdb.aggregate(test_df, "count(*) as i", connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (4, 4) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (4, 4) def test_distinct_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1)") test_df = pd.DataFrame.from_dict({"i": [1, 1, 2, 3, 4]}) rel = duckdb.distinct(test_df, connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_limit_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) rel = duckdb.limit(test_df, 1, connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) def test_query_df(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) - rel = duckdb.query_df(test_df, 't_2', 'select * from t inner join t_2 on (a = i)', connection=duckdb_cursor) + rel = duckdb.query_df(test_df, "t_2", "select * from t inner join t_2 on (a = i)", connection=duckdb_cursor) assert rel.fetchall()[0] == (1, 1) def test_query_order(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb_cursor.execute("insert into t values (1),(4)") test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) - rel = duckdb.order(test_df, 'i', connection=duckdb_cursor) - assert rel.query('t_2', 'select * from t inner join t_2 on (a = i)').fetchall()[0] == (1, 1) + rel = duckdb.order(test_df, "i", connection=duckdb_cursor) + assert rel.query("t_2", "select * from t inner join t_2 on (a = i)").fetchall()[0] == (1, 1) diff --git a/tests/fast/test_parameter_list.py b/tests/fast/test_parameter_list.py index 032b1b9c..5a85ac2f 100644 --- a/tests/fast/test_parameter_list.py +++ b/tests/fast/test_parameter_list.py @@ -11,22 +11,22 @@ def test_bool(self, duckdb_cursor): res = conn.execute("select count(*) from bool_table where a =?", [True]) assert res.fetchone()[0] == 1 - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_exception(self, duckdb_cursor, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create table bool_table (a bool)") conn.execute("insert into bool_table values (TRUE)") - with pytest.raises(duckdb.NotImplementedException, match='Unable to transform'): + with pytest.raises(duckdb.NotImplementedException, match="Unable to transform"): res = conn.execute("select count(*) from bool_table where a =?", [df_in]) def test_explicit_nan_param(self): con = duckdb.default_connection() - res = con.execute('select isnan(cast(? as double))', (float("nan"),)) + res = con.execute("select isnan(cast(? as double))", (float("nan"),)) assert res.fetchone()[0] == True def test_string_parameter(self, duckdb_cursor): diff --git a/tests/fast/test_parquet.py b/tests/fast/test_parquet.py index 51d8d276..61d74023 100644 --- a/tests/fast/test_parquet.py +++ b/tests/fast/test_parquet.py @@ -7,13 +7,13 @@ VARCHAR = duckdb.typing.VARCHAR BIGINT = duckdb.typing.BIGINT -filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') +filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "binary_string.parquet") @pytest.fixture(scope="session") def tmp_parquets(tmp_path_factory): - tmp_dir = tmp_path_factory.mktemp('parquets', numbered=True) - tmp_parquets = [str(tmp_dir / ('tmp' + str(i) + '.parquet')) for i in range(1, 4)] + tmp_dir = tmp_path_factory.mktemp("parquets", numbered=True) + tmp_parquets = [str(tmp_dir / ("tmp" + str(i) + ".parquet")) for i in range(1, 4)] return tmp_parquets @@ -21,34 +21,34 @@ class TestParquet(object): def test_scan_binary(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_from_parquet_binary(self, duckdb_cursor): rel = duckdb.from_parquet(filename) - assert rel.types == ['BLOB'] + assert rel.types == ["BLOB"] res = rel.execute().fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_scan_binary_as_string(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute( "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=True) limit 1" ).fetchall() - assert res[0] == ('VARCHAR',) + assert res[0] == ("VARCHAR",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "',binary_as_string=True)").fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_binary_as_string(self, duckdb_cursor): rel = duckdb.from_parquet(filename, True) assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_file_row_number(self, duckdb_cursor): rel = duckdb.from_parquet(filename, binary_as_string=True, file_row_number=True) @@ -56,7 +56,7 @@ def test_from_parquet_file_row_number(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", 0, ) @@ -66,7 +66,7 @@ def test_from_parquet_filename(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", filename, ) @@ -75,7 +75,7 @@ def test_from_parquet_list_binary_as_string(self, duckdb_cursor): assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_list_file_row_number(self, duckdb_cursor): rel = duckdb.from_parquet([filename], binary_as_string=True, file_row_number=True) @@ -83,7 +83,7 @@ def test_from_parquet_list_file_row_number(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", 0, ) @@ -93,41 +93,41 @@ def test_from_parquet_list_filename(self, duckdb_cursor): res = rel.execute().fetchall() assert res[0] == ( - 'foo', + "foo", filename, ) def test_parquet_binary_as_string_pragma(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) conn.execute("PRAGMA binary_as_string=1") res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('VARCHAR',) + assert res[0] == ("VARCHAR",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) res = conn.execute( "SELECT typeof(#1) FROM parquet_scan('" + filename + "',binary_as_string=False) limit 1" ).fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "',binary_as_string=False)").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) conn.execute("PRAGMA binary_as_string=0") res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() - assert res[0] == ('BLOB',) + assert res[0] == ("BLOB",) res = conn.execute("SELECT * FROM parquet_scan('" + filename + "')").fetchall() - assert res[0] == (b'foo',) + assert res[0] == (b"foo",) def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): duckdb.execute("PRAGMA binary_as_string=1") @@ -136,7 +136,7 @@ def test_from_parquet_binary_as_string_default_conn(self, duckdb_cursor): assert rel.types == [VARCHAR] res = rel.execute().fetchall() - assert res[0] == ('foo',) + assert res[0] == ("foo",) def test_from_parquet_union_by_name(self, tmp_parquets): conn = duckdb.connect() @@ -159,7 +159,7 @@ def test_from_parquet_union_by_name(self, tmp_parquets): + "' (format 'parquet');" ) - rel = duckdb.from_parquet(tmp_parquets, union_by_name=True).order('a') + rel = duckdb.from_parquet(tmp_parquets, union_by_name=True).order("a") assert rel.execute().fetchall() == [ ( 1, diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index 6e1460e2..84d4c9ff 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -15,51 +15,61 @@ duckdb_packaging = pytest.importorskip("duckdb_packaging") from duckdb_packaging.pypi_cleanup import ( - PyPICleanup, CsrfParser, PyPICleanupError, AuthenticationError, ValidationError, - setup_logging, validate_username, create_argument_parser, session_with_retries, - load_credentials, validate_arguments, main + PyPICleanup, + CsrfParser, + PyPICleanupError, + AuthenticationError, + ValidationError, + setup_logging, + validate_username, + create_argument_parser, + session_with_retries, + load_credentials, + validate_arguments, + main, ) + class TestValidation: """Test input validation functions.""" - + def test_validate_username_valid(self): """Test valid usernames.""" assert validate_username("user123") == "user123" assert validate_username(" user.name ") == "user.name" assert validate_username("test-user_name") == "test-user_name" assert validate_username("a") == "a" - + def test_validate_username_invalid(self): """Test invalid usernames.""" from argparse import ArgumentTypeError - + with pytest.raises(ArgumentTypeError, match="cannot be empty"): validate_username("") - + with pytest.raises(ArgumentTypeError, match="cannot be empty"): validate_username(" ") - + with pytest.raises(ArgumentTypeError, match="too long"): validate_username("a" * 101) - + with pytest.raises(ArgumentTypeError, match="Invalid username format"): validate_username("-invalid") - + with pytest.raises(ArgumentTypeError, match="Invalid username format"): validate_username("invalid-") - + def test_validate_arguments_dry_run(self): """Test argument validation for dry run mode.""" args = Mock(dry_run=True, username=None, max_nightlies=2) validate_arguments(args) # Should not raise - + def test_validate_arguments_live_mode_no_username(self): """Test argument validation for live mode without username.""" args = Mock(dry_run=False, username=None, max_nightlies=2) with pytest.raises(ValidationError, match="username is required"): validate_arguments(args) - + def test_validate_arguments_negative_nightlies(self): """Test argument validation with negative max nightlies.""" args = Mock(dry_run=True, username="test", max_nightlies=-1) @@ -69,27 +79,27 @@ def test_validate_arguments_negative_nightlies(self): class TestCredentials: """Test credential loading.""" - + def test_load_credentials_dry_run(self): """Test credential loading in dry run mode.""" password, otp = load_credentials(dry_run=True) assert password is None assert otp is None - - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass', 'PYPI_CLEANUP_OTP': 'test_otp'}) + + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test_pass", "PYPI_CLEANUP_OTP": "test_otp"}) def test_load_credentials_live_mode_success(self): """Test successful credential loading in live mode.""" password, otp = load_credentials(dry_run=False) - assert password == 'test_pass' - assert otp == 'test_otp' - + assert password == "test_pass" + assert otp == "test_otp" + @patch.dict(os.environ, {}, clear=True) def test_load_credentials_missing_password(self): """Test credential loading with missing password.""" with pytest.raises(ValidationError, match="PYPI_CLEANUP_PASSWORD"): load_credentials(dry_run=False) - - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test_pass'}) + + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test_pass"}) def test_load_credentials_missing_otp(self): """Test credential loading with missing OTP.""" with pytest.raises(ValidationError, match="PYPI_CLEANUP_OTP"): @@ -105,56 +115,56 @@ def test_create_session_with_retries(self): assert isinstance(session, requests.Session) # Verify retry adapter is mounted adapter = session.get_adapter("https://example.com") - assert hasattr(adapter, 'max_retries') - retries = getattr(adapter, 'max_retries') + assert hasattr(adapter, "max_retries") + retries = getattr(adapter, "max_retries") assert isinstance(retries, Retry) - @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") def test_setup_logging_normal(self, mock_basicConfig): """Test logging setup in normal mode.""" setup_logging(verbose=False) mock_basicConfig.assert_called_once() call_args = mock_basicConfig.call_args[1] - assert call_args['level'] == 20 # INFO level + assert call_args["level"] == 20 # INFO level - @patch('duckdb_packaging.pypi_cleanup.logging.basicConfig') + @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") def test_setup_logging_verbose(self, mock_basicConfig): """Test logging setup in verbose mode.""" setup_logging(verbose=True) mock_basicConfig.assert_called_once() call_args = mock_basicConfig.call_args[1] - assert call_args['level'] == 10 # DEBUG level + assert call_args["level"] == 10 # DEBUG level class TestCsrfParser: """Test CSRF token parser.""" - + def test_csrf_parser_simple_form(self): """Test parsing CSRF token from simple form.""" - html = ''' + html = """
      - ''' + """ parser = CsrfParser("/test") parser.feed(html) assert parser.csrf == "abc123" - + def test_csrf_parser_multiple_forms(self): """Test parsing CSRF token when multiple forms exist.""" - html = ''' + html = """
      - ''' + """ parser = CsrfParser("/test") parser.feed(html) assert parser.csrf == "correct" - + def test_csrf_parser_no_token(self): """Test parser when no CSRF token is found.""" html = '
      ' @@ -165,6 +175,7 @@ def test_csrf_parser_no_token(self): class TestPyPICleanup: """Test the main PyPICleanup class.""" + @pytest.fixture def cleanup_dryrun_max_2(self) -> PyPICleanup: return PyPICleanup("https://test.pypi.org/", False, 2) @@ -175,26 +186,59 @@ def cleanup_dryrun_max_0(self) -> PyPICleanup: @pytest.fixture def cleanup_max_2(self) -> PyPICleanup: - return PyPICleanup("https://test.pypi.org/", True, 2, - username="", password="", otp="") + return PyPICleanup("https://test.pypi.org/", True, 2, username="", password="", otp="") def test_determine_versions_to_delete_max_2(self, cleanup_dryrun_max_2): start_state = { "0.1.0", - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", - "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.0", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.0.1", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.0", + "1.1.0.post1", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.0", + "2.0.1.dev974", + "2.0.1.rc1", + "2.0.1.rc2", + "2.0.1.rc3", } expected_deletions = { - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", "1.1.1.dev142", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", - "2.0.1.dev974" + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.1.dev974", } versions_to_delete = cleanup_dryrun_max_2._determine_versions_to_delete(start_state) assert versions_to_delete == expected_deletions @@ -202,35 +246,82 @@ def test_determine_versions_to_delete_max_2(self, cleanup_dryrun_max_2): def test_determine_versions_to_delete_max_0(self, cleanup_dryrun_max_0): start_state = { "0.1.0", - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", "1.0.0", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", "1.0.1", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", "1.1.0", "1.1.0.post1", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", "2.0.0", - "2.0.1.dev974", "2.0.1.rc1", "2.0.1.rc2", "2.0.1.rc3", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.0", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.0.1", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.0", + "1.1.0.post1", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.0", + "2.0.1.dev974", + "2.0.1.rc1", + "2.0.1.rc2", + "2.0.1.rc3", } expected_deletions = { - "1.0.0.dev1", "1.0.0.dev2", "1.0.0.rc1", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", - "2.0.0.dev602", "2.0.0.rc1", "2.0.0.rc2", "2.0.0.rc3", "2.0.0.rc4", - "2.0.1.dev974" + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.0.rc1", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", + "2.0.0.dev602", + "2.0.0.rc1", + "2.0.0.rc2", + "2.0.0.rc3", + "2.0.0.rc4", + "2.0.1.dev974", } versions_to_delete = cleanup_dryrun_max_0._determine_versions_to_delete(start_state) assert versions_to_delete == expected_deletions def test_determine_versions_to_delete_only_devs_max_2(self, cleanup_dryrun_max_2): start_state = { - "1.0.0.dev1", "1.0.0.dev2", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", "2.0.0.dev602", "2.0.1.dev974", } expected_deletions = { - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", "1.1.0.dev34", "1.1.1.dev142", } @@ -239,19 +330,28 @@ def test_determine_versions_to_delete_only_devs_max_2(self, cleanup_dryrun_max_2 def test_determine_versions_to_delete_only_devs_max_0_fails(self, cleanup_dryrun_max_0): start_state = { - "1.0.0.dev1", "1.0.0.dev2", - "1.0.1.dev3", "1.0.1.dev5", "1.0.1.dev8", "1.0.1.dev13", "1.0.1.dev21", - "1.1.0.dev34", "1.1.0.dev54", "1.1.0.dev88", - "1.1.1.dev142", "1.1.1.dev230", "1.1.1.dev372", + "1.0.0.dev1", + "1.0.0.dev2", + "1.0.1.dev3", + "1.0.1.dev5", + "1.0.1.dev8", + "1.0.1.dev13", + "1.0.1.dev21", + "1.1.0.dev34", + "1.1.0.dev54", + "1.1.0.dev88", + "1.1.1.dev142", + "1.1.1.dev230", + "1.1.1.dev372", "2.0.0.dev602", "2.0.1.dev974", } with pytest.raises(PyPICleanupError, match="Safety check failed"): cleanup_dryrun_max_0._determine_versions_to_delete(start_state) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_versions') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._determine_versions_to_delete') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_versions") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._determine_versions_to_delete") def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, cleanup_dryrun_max_2): mock_fetch.return_value = {"1.0.0.dev1"} mock_determine.return_value = {"1.0.0.dev1"} @@ -264,14 +364,14 @@ def test_execute_cleanup_dry_run(self, mock_determine, mock_fetch, mock_delete, mock_determine.assert_called_once() mock_delete.assert_not_called() - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._fetch_released_versions") def test_execute_cleanup_no_releases(self, mock_fetch, cleanup_dryrun_max_2): mock_fetch.return_value = {} with session_with_retries() as session: result = cleanup_dryrun_max_2._execute_cleanup(session) assert result == 0 - @patch('requests.Session.get') + @patch("requests.Session.get") def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): """Test successful package release fetching.""" mock_response = Mock() @@ -288,7 +388,7 @@ def test_fetch_released_versions_success(self, mock_get, cleanup_dryrun_max_2): assert releases == {"1.0.0", "1.0.0.dev1"} - @patch('requests.Session.get') + @patch("requests.Session.get") def test_fetch_released_versions_not_found(self, mock_get, cleanup_dryrun_max_2): """Test package release fetching when package not found.""" mock_response = Mock() @@ -299,8 +399,8 @@ def test_fetch_released_versions_not_found(self, mock_get, cleanup_dryrun_max_2) with session_with_retries() as session: cleanup_dryrun_max_2._fetch_released_versions(session) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('requests.Session.post') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token") + @patch("requests.Session.post") def test_authenticate_success(self, mock_post, mock_csrf, cleanup_max_2): """Test successful authentication.""" mock_csrf.return_value = "csrf123" @@ -313,11 +413,11 @@ def test_authenticate_success(self, mock_post, mock_csrf, cleanup_max_2): mock_csrf.assert_called_once_with(session, "/account/login/") mock_post.assert_called_once() - assert mock_post.call_args.args[0].endswith('/account/login/') + assert mock_post.call_args.args[0].endswith("/account/login/") - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token') - @patch('requests.Session.post') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._get_csrf_token") + @patch("requests.Session.post") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._handle_two_factor_auth") def test_authenticate_with_2fa(self, mock_2fa, mock_post, mock_csrf, cleanup_max_2): mock_csrf.return_value = "csrf123" mock_response = Mock() @@ -332,7 +432,7 @@ def test_authenticate_missing_credentials(self, cleanup_dryrun_max_2): with pytest.raises(AuthenticationError, match="Username and password are required"): cleanup_dryrun_max_2._authenticate(None) - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version") def test_delete_versions_success(self, mock_delete, cleanup_max_2): """Test successful version deletion.""" versions = {"1.0.0.dev1", "1.0.0.dev2"} @@ -343,7 +443,7 @@ def test_delete_versions_success(self, mock_delete, cleanup_max_2): assert mock_delete.call_count == 2 - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version') + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup._delete_single_version") def test_delete_versions_partial_failure(self, mock_delete, cleanup_max_2): """Test version deletion with partial failures.""" versions = {"1.0.0.dev1", "1.0.0.dev2"} @@ -360,75 +460,75 @@ def test_delete_single_version_safety_check(self, cleanup_max_2): class TestArgumentParser: """Test command line argument parsing.""" - + def test_argument_parser_creation(self): """Test argument parser creation.""" parser = create_argument_parser() assert parser.prog is not None - + def test_parse_args_prod_dry_run(self): """Test parsing arguments for production dry run.""" parser = create_argument_parser() - args = parser.parse_args(['--prod', '--dry-run']) - + args = parser.parse_args(["--prod", "--dry-run"]) + assert args.prod is True assert args.test is False assert args.dry_run is True assert args.max_nightlies == 2 assert args.verbose is False - + def test_parse_args_test_with_username(self): """Test parsing arguments for test with username.""" parser = create_argument_parser() - args = parser.parse_args(['--test', '-u', 'testuser', '--verbose']) - + args = parser.parse_args(["--test", "-u", "testuser", "--verbose"]) + assert args.test is True assert args.prod is False - assert args.username == 'testuser' + assert args.username == "testuser" assert args.verbose is True - + def test_parse_args_missing_host(self): """Test parsing arguments with missing host selection.""" parser = create_argument_parser() - + with pytest.raises(SystemExit): - parser.parse_args(['--dry-run']) # Missing --prod or --test + parser.parse_args(["--dry-run"]) # Missing --prod or --test class TestMainFunction: """Test the main function.""" - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') - @patch('duckdb_packaging.pypi_cleanup.PyPICleanup') - @patch.dict(os.environ, {'PYPI_CLEANUP_PASSWORD': 'test', 'PYPI_CLEANUP_OTP': 'test'}) + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") + @patch("duckdb_packaging.pypi_cleanup.PyPICleanup") + @patch.dict(os.environ, {"PYPI_CLEANUP_PASSWORD": "test", "PYPI_CLEANUP_OTP": "test"}) def test_main_success(self, mock_cleanup_class, mock_setup_logging): """Test successful main function execution.""" mock_cleanup = Mock() mock_cleanup.run.return_value = 0 mock_cleanup_class.return_value = mock_cleanup - - with patch('sys.argv', ['pypi_cleanup.py', '--test', '-u', 'testuser']): + + with patch("sys.argv", ["pypi_cleanup.py", "--test", "-u", "testuser"]): result = main() - + assert result == 0 mock_setup_logging.assert_called_once() mock_cleanup.run.assert_called_once() - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") def test_main_validation_error(self, mock_setup_logging): """Test main function with validation error.""" - with patch('sys.argv', ['pypi_cleanup.py', '--test']): # Missing username for live mode + with patch("sys.argv", ["pypi_cleanup.py", "--test"]): # Missing username for live mode result = main() - + assert result == 2 # Validation error exit code - - @patch('duckdb_packaging.pypi_cleanup.setup_logging') - @patch('duckdb_packaging.pypi_cleanup.validate_arguments') + + @patch("duckdb_packaging.pypi_cleanup.setup_logging") + @patch("duckdb_packaging.pypi_cleanup.validate_arguments") def test_main_keyboard_interrupt(self, mock_validate, mock_setup_logging): """Test main function with keyboard interrupt.""" mock_validate.side_effect = KeyboardInterrupt() - - with patch('sys.argv', ['pypi_cleanup.py', '--test', '--dry-run']): + + with patch("sys.argv", ["pypi_cleanup.py", "--test", "--dry-run"]): result = main() - + assert result == 130 # Keyboard interrupt exit code diff --git a/tests/fast/test_pytorch.py b/tests/fast/test_pytorch.py index 365585cc..c5b9b4d6 100644 --- a/tests/fast/test_pytorch.py +++ b/tests/fast/test_pytorch.py @@ -2,7 +2,7 @@ import pytest -torch = pytest.importorskip('torch') +torch = pytest.importorskip("torch") @pytest.mark.skip(reason="some issues with Numpy, to be reverted") @@ -15,16 +15,16 @@ def test_pytorch(): # Test from connection duck_torch = con.execute("select * from t").torch() duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Test from relation duck_torch = con.sql("select * from t").torch() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Test all Numeric Types - numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + numeric_types = ["TINYINT", "SMALLINT", "BIGINT", "HUGEINT", "FLOAT", "DOUBLE", "DECIMAL(4,1)", "UTINYINT"] for supported_type in numeric_types: con = duckdb.connect() @@ -32,8 +32,8 @@ def test_pytorch(): con.execute("insert into t values (1,2), (3,4)") duck_torch = con.sql("select * from t").torch() duck_numpy = con.sql("select * from t").fetchnumpy() - torch.equal(duck_torch['a'], torch.tensor(duck_numpy['a'])) - torch.equal(duck_torch['b'], torch.tensor(duck_numpy['b'])) + torch.equal(duck_torch["a"], torch.tensor(duck_numpy["a"])) + torch.equal(duck_torch["b"], torch.tensor(duck_numpy["b"])) # Comment out test that might fail or not depending on pytorch versions # with pytest.raises(TypeError, match="can't convert"): diff --git a/tests/fast/test_relation.py b/tests/fast/test_relation.py index 8e68c149..31ca393c 100644 --- a/tests/fast/test_relation.py +++ b/tests/fast/test_relation.py @@ -37,10 +37,10 @@ def test_csv_auto(self): csv_rel = duckdb.from_csv_auto(temp_file_name) assert df_rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_view(self, duckdb_cursor, pandas): def create_view(duckdb_cursor): - df_in = pandas.DataFrame({'numbers': [1, 2, 3, 4, 5]}) + df_in = pandas.DataFrame({"numbers": [1, 2, 3, 4, 5]}) rel = duckdb_cursor.query("select * from df_in") rel.to_view("my_view") @@ -59,23 +59,23 @@ def create_view(duckdb_cursor): def test_filter_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.filter('i > 1').execute().fetchall() == [(2, 'two'), (3, 'three'), (4, 'four')] + assert rel.filter("i > 1").execute().fetchall() == [(2, "two"), (3, "three"), (4, "four")] def test_projection_operator_single(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.project('i').execute().fetchall() == [(1,), (2,), (3,), (4,)] + assert rel.project("i").execute().fetchall() == [(1,), (2,), (3,), (4,)] def test_projection_operator_double(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.order('j').execute().fetchall() == [(4, 'four'), (1, 'one'), (3, 'three'), (2, 'two')] + assert rel.order("j").execute().fetchall() == [(4, "four"), (1, "one"), (3, "three"), (2, "two")] def test_limit_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.limit(2).execute().fetchall() == [(1, 'one'), (2, 'two')] - assert rel.limit(2, offset=1).execute().fetchall() == [(2, 'two'), (3, 'three')] + assert rel.limit(2).execute().fetchall() == [(1, "one"), (2, "two")] + assert rel.limit(2, offset=1).execute().fetchall() == [(2, "two"), (3, "three")] def test_intersect_operator(self): conn = duckdb.connect() @@ -86,23 +86,23 @@ def test_intersect_operator(self): rel = conn.from_df(test_df) rel_2 = conn.from_df(test_df_2) - assert rel.intersect(rel_2).order('i').execute().fetchall() == [(3,), (4,)] + assert rel.intersect(rel_2).order("i").execute().fetchall() == [(3,), (4,)] def test_aggregate_operator(self): conn = duckdb.connect() rel = get_relation(conn) assert rel.aggregate("sum(i)").execute().fetchall() == [(10,)] - assert rel.aggregate("j, sum(i)").order('#2').execute().fetchall() == [ - ('one', 1), - ('two', 2), - ('three', 3), - ('four', 4), + assert rel.aggregate("j, sum(i)").order("#2").execute().fetchall() == [ + ("one", 1), + ("two", 2), + ("three", 3), + ("four", 4), ] def test_relation_fetch_df_chunk(self, duckdb_cursor): duckdb_cursor.execute(f"create table tbl as select * from range({duckdb.__standard_vector_size__ * 3})") - rel = duckdb_cursor.table('tbl') + rel = duckdb_cursor.table("tbl") # default arguments df1 = rel.fetch_df_chunk() assert len(df1) == duckdb.__standard_vector_size__ @@ -114,40 +114,40 @@ def test_relation_fetch_df_chunk(self, duckdb_cursor): f"create table dates as select (DATE '2021/02/21' + INTERVAL (i) DAYS)::DATE a from range({duckdb.__standard_vector_size__ * 4}) t(i)" ) - rel = duckdb_cursor.table('dates') + rel = duckdb_cursor.table("dates") # default arguments df1 = rel.fetch_df_chunk() assert len(df1) == duckdb.__standard_vector_size__ - assert df1['a'][0].__class__ == pd.Timestamp + assert df1["a"][0].__class__ == pd.Timestamp # date as object df1 = rel.fetch_df_chunk(date_as_object=True) assert len(df1) == duckdb.__standard_vector_size__ - assert df1['a'][0].__class__ == datetime.date + assert df1["a"][0].__class__ == datetime.date # vectors and date as object df1 = rel.fetch_df_chunk(2, date_as_object=True) assert len(df1) == duckdb.__standard_vector_size__ * 2 - assert df1['a'][0].__class__ == datetime.date + assert df1["a"][0].__class__ == datetime.date def test_distinct_operator(self): conn = duckdb.connect() rel = get_relation(conn) - assert rel.distinct().order('all').execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + assert rel.distinct().order("all").execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_union_operator(self): conn = duckdb.connect() rel = get_relation(conn) print(rel.union(rel).execute().fetchall()) assert rel.union(rel).execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_join_operator(self): @@ -156,11 +156,11 @@ def test_join_operator(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) rel = conn.from_df(test_df) rel2 = conn.from_df(test_df) - assert rel.join(rel2, 'i').execute().fetchall() == [ - (1, 'one', 'one'), - (2, 'two', 'two'), - (3, 'three', 'three'), - (4, 'four', 'four'), + assert rel.join(rel2, "i").execute().fetchall() == [ + (1, "one", "one"), + (2, "two", "two"), + (3, "three", "three"), + (4, "four", "four"), ] def test_except_operator(self): @@ -176,10 +176,10 @@ def test_create_operator(self): rel = conn.from_df(test_df) rel.create("test_df") assert conn.query("select * from test_df").execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_create_view_operator(self): @@ -188,31 +188,31 @@ def test_create_view_operator(self): rel = conn.from_df(test_df) rel.create_view("test_df") assert conn.query("select * from test_df").execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), ] def test_update_relation(self, duckdb_cursor): duckdb_cursor.sql("create table tbl (a varchar default 'test', b int)") - duckdb_cursor.table('tbl').insert(['hello', 21]) - duckdb_cursor.table('tbl').insert(['hello', 42]) + duckdb_cursor.table("tbl").insert(["hello", 21]) + duckdb_cursor.table("tbl").insert(["hello", 42]) # UPDATE tbl SET a = DEFAULT where b = 42 - duckdb_cursor.table('tbl').update( - {'a': duckdb.DefaultExpression()}, condition=duckdb.ColumnExpression('b') == 42 + duckdb_cursor.table("tbl").update( + {"a": duckdb.DefaultExpression()}, condition=duckdb.ColumnExpression("b") == 42 ) - assert duckdb_cursor.table('tbl').fetchall() == [('hello', 21), ('test', 42)] + assert duckdb_cursor.table("tbl").fetchall() == [("hello", 21), ("test", 42)] - rel = duckdb_cursor.table('tbl') - with pytest.raises(duckdb.InvalidInputException, match='Please provide at least one set expression'): + rel = duckdb_cursor.table("tbl") + with pytest.raises(duckdb.InvalidInputException, match="Please provide at least one set expression"): rel.update({}) with pytest.raises( - duckdb.InvalidInputException, match='Please provide the column name as the key of the dictionary' + duckdb.InvalidInputException, match="Please provide the column name as the key of the dictionary" ): rel.update({1: 21}) - with pytest.raises(duckdb.BinderException, match='Referenced update column c not found in table!'): - rel.update({'c': 21}) + with pytest.raises(duckdb.BinderException, match="Referenced update column c not found in table!"): + rel.update({"c": 21}) with pytest.raises( duckdb.InvalidInputException, match="Please provide 'set' as a dictionary of column name to Expression" ): @@ -221,11 +221,11 @@ def test_update_relation(self, duckdb_cursor): duckdb.InvalidInputException, match="Please provide an object of type Expression as the value, not ", ): - rel.update({'a': {21}}) + rel.update({"a": {21}}) def test_value_relation(self, duckdb_cursor): # Needs at least one input - with pytest.raises(duckdb.InvalidInputException, match='Could not create a ValueRelation without any inputs'): + with pytest.raises(duckdb.InvalidInputException, match="Could not create a ValueRelation without any inputs"): duckdb_cursor.values() # From a list of (python) values @@ -233,28 +233,28 @@ def test_value_relation(self, duckdb_cursor): assert rel.fetchall() == [(1, 2, 3)] # From an Expression - rel = duckdb_cursor.values(duckdb.ConstantExpression('test')) - assert rel.fetchall() == [('test',)] + rel = duckdb_cursor.values(duckdb.ConstantExpression("test")) + assert rel.fetchall() == [("test",)] # From multiple Expressions rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), duckdb.ConstantExpression('2'), duckdb.ConstantExpression('3') + duckdb.ConstantExpression("1"), duckdb.ConstantExpression("2"), duckdb.ConstantExpression("3") ) - assert rel.fetchall() == [('1', '2', '3')] + assert rel.fetchall() == [("1", "2", "3")] # From Expressions mixed with random values - with pytest.raises(duckdb.InvalidInputException, match='Please provide arguments of type Expression!'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide arguments of type Expression!"): rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), - {'test'}, - duckdb.ConstantExpression('3'), + duckdb.ConstantExpression("1"), + {"test"}, + duckdb.ConstantExpression("3"), ) # From Expressions mixed with values that *can* be autocast to Expression rel = duckdb_cursor.values( - duckdb.ConstantExpression('1'), + duckdb.ConstantExpression("1"), 2, - duckdb.ConstantExpression('3'), + duckdb.ConstantExpression("3"), ) const = duckdb.ConstantExpression @@ -264,21 +264,21 @@ def test_value_relation(self, duckdb_cursor): # From mismatching tuples of Expressions with pytest.raises( - duckdb.InvalidInputException, match='Mismatch between length of tuples in input, expected 3 but found 2' + duckdb.InvalidInputException, match="Mismatch between length of tuples in input, expected 3 but found 2" ): rel = duckdb_cursor.values((const(1), const(2), const(3)), (const(5), const(4))) # From an empty tuple - with pytest.raises(duckdb.InvalidInputException, match='Please provide a non-empty tuple'): + with pytest.raises(duckdb.InvalidInputException, match="Please provide a non-empty tuple"): rel = duckdb_cursor.values(()) # Mixing tuples with Expressions - with pytest.raises(duckdb.InvalidInputException, match='Expected objects of type tuple'): + with pytest.raises(duckdb.InvalidInputException, match="Expected objects of type tuple"): rel = duckdb_cursor.values((const(1), const(2), const(3)), const(4)) # Using Expressions that can't be resolved: with pytest.raises(duckdb.BinderException, match='Referenced column "a" not found in FROM clause!'): - duckdb_cursor.values(duckdb.ColumnExpression('a')) + duckdb_cursor.values(duckdb.ColumnExpression("a")) def test_insert_into_operator(self): conn = duckdb.connect() @@ -290,17 +290,17 @@ def test_insert_into_operator(self): rel.insert_into("test_table3") # Inserting elements into table_3 - print(conn.values([5, 'five']).insert_into("test_table3")) + print(conn.values([5, "five"]).insert_into("test_table3")) rel_3 = conn.table("test_table3") - rel_3.insert([6, 'six']) + rel_3.insert([6, "six"]) assert rel_3.execute().fetchall() == [ - (1, 'one'), - (2, 'two'), - (3, 'three'), - (4, 'four'), - (5, 'five'), - (6, 'six'), + (1, "one"), + (2, "two"), + (3, "three"), + (4, "four"), + (5, "five"), + (6, "six"), ] def test_write_csv_operator(self): @@ -316,8 +316,8 @@ def test_table_update_with_schema(self, duckdb_cursor): duckdb_cursor.sql("create schema not_main;") duckdb_cursor.sql("create table not_main.tbl as select * from range(10) t(a)") - duckdb_cursor.table('not_main.tbl').update({'a': 21}, condition=ColumnExpression('a') == 5) - res = duckdb_cursor.table('not_main.tbl').fetchall() + duckdb_cursor.table("not_main.tbl").update({"a": 21}, condition=ColumnExpression("a") == 5) + res = duckdb_cursor.table("not_main.tbl").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (21,), (6,), (7,), (8,), (9,)] def test_table_update_with_catalog(self, duckdb_cursor): @@ -325,8 +325,8 @@ def test_table_update_with_catalog(self, duckdb_cursor): duckdb_cursor.sql("create schema pg.not_main;") duckdb_cursor.sql("create table pg.not_main.tbl as select * from range(10) t(a)") - duckdb_cursor.table('pg.not_main.tbl').update({'a': 21}, condition=ColumnExpression('a') == 5) - res = duckdb_cursor.table('pg.not_main.tbl').fetchall() + duckdb_cursor.table("pg.not_main.tbl").update({"a": 21}, condition=ColumnExpression("a") == 5) + res = duckdb_cursor.table("pg.not_main.tbl").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (21,), (6,), (7,), (8,), (9,)] def test_get_attr_operator(self): @@ -335,50 +335,50 @@ def test_get_attr_operator(self): rel = conn.table("test") assert rel.alias == "test" assert rel.type == "TABLE_RELATION" - assert rel.columns == ['i'] - assert rel.types == ['INTEGER'] + assert rel.columns == ["i"] + assert rel.types == ["INTEGER"] def test_query_fail(self): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): rel.query("select j from test") def test_execute_fail(self): conn = duckdb.connect() conn.execute("CREATE TABLE test (i INTEGER)") rel = conn.table("test") - with pytest.raises(TypeError, match='incompatible function arguments'): + with pytest.raises(TypeError, match="incompatible function arguments"): rel.execute("select j from test") def test_df_proj(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.project(test_df, 'i') + rel = duckdb.project(test_df, "i") assert rel.execute().fetchall() == [(1,), (2,), (3,), (4,)] def test_relation_lifetime(self, duckdb_cursor): def create_relation(con): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return con.sql("select * from df") assert create_relation(duckdb_cursor).fetchall() == [(1,), (2,), (3,)] def create_simple_join(con): - df1 = pd.DataFrame({'a': ['a', 'b', 'c'], 'b': [1, 2, 3]}) - df2 = pd.DataFrame({'a': ['a', 'b', 'c'], 'b': [4, 5, 6]}) + df1 = pd.DataFrame({"a": ["a", "b", "c"], "b": [1, 2, 3]}) + df2 = pd.DataFrame({"a": ["a", "b", "c"], "b": [4, 5, 6]}) return con.sql("select * from df1 JOIN df2 USING (a, a)") - assert create_simple_join(duckdb_cursor).fetchall() == [('a', 1, 4), ('b', 2, 5), ('c', 3, 6)] + assert create_simple_join(duckdb_cursor).fetchall() == [("a", 1, 4), ("b", 2, 5), ("c", 3, 6)] def create_complex_join(con): - df1 = pd.DataFrame({'a': [1], '1': [1]}) - df2 = pd.DataFrame({'a': [1], '2': [2]}) - df3 = pd.DataFrame({'a': [1], '3': [3]}) - df4 = pd.DataFrame({'a': [1], '4': [4]}) - df5 = pd.DataFrame({'a': [1], '5': [5]}) - df6 = pd.DataFrame({'a': [1], '6': [6]}) + df1 = pd.DataFrame({"a": [1], "1": [1]}) + df2 = pd.DataFrame({"a": [1], "2": [2]}) + df3 = pd.DataFrame({"a": [1], "3": [3]}) + df4 = pd.DataFrame({"a": [1], "4": [4]}) + df5 = pd.DataFrame({"a": [1], "5": [5]}) + df6 = pd.DataFrame({"a": [1], "6": [6]}) query = "select * from df1" for i in range(5): query += f" JOIN df{i + 2} USING (a, a)" @@ -407,7 +407,7 @@ def test_project_on_types(self): assert projection.columns == ["c2", "c4"] # select bigint, tinyint and a type that isn't there - projection = rel.select_types([BIGINT, "tinyint", con.struct_type({'a': VARCHAR, 'b': TINYINT})]) + projection = rel.select_types([BIGINT, "tinyint", con.struct_type({"a": VARCHAR, "b": TINYINT})]) assert projection.columns == ["c0", "c1"] ## select with empty projection list, not possible @@ -420,30 +420,30 @@ def test_project_on_types(self): def test_df_alias(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.alias(test_df, 'dfzinho') + rel = duckdb.alias(test_df, "dfzinho") assert rel.alias == "dfzinho" def test_df_filter(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.filter(test_df, 'i > 1') - assert rel.execute().fetchall() == [(2, 'two'), (3, 'three'), (4, 'four')] + rel = duckdb.filter(test_df, "i > 1") + assert rel.execute().fetchall() == [(2, "two"), (3, "three"), (4, "four")] def test_df_order_by(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.order(test_df, 'j') - assert rel.execute().fetchall() == [(4, 'four'), (1, 'one'), (3, 'three'), (2, 'two')] + rel = duckdb.order(test_df, "j") + assert rel.execute().fetchall() == [(4, "four"), (1, "one"), (3, "three"), (2, "two")] def test_df_distinct(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) - rel = duckdb.distinct(test_df).order('i') - assert rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + rel = duckdb.distinct(test_df).order("i") + assert rel.execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_df_write_csv(self): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]}) temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) duckdb.write_csv(test_df, temp_file_name) csv_rel = duckdb.from_csv_auto(temp_file_name) - assert csv_rel.execute().fetchall() == [(1, 'one'), (2, 'two'), (3, 'three'), (4, 'four')] + assert csv_rel.execute().fetchall() == [(1, "one"), (2, "two"), (3, "three"), (4, "four")] def test_join_types(self): test_df1 = pd.DataFrame.from_dict({"i": [1, 2, 3, 4]}) @@ -452,9 +452,9 @@ def test_join_types(self): rel1 = con.from_df(test_df1) rel2 = con.from_df(test_df2) - assert rel1.join(rel2, 'i=j', 'inner').aggregate('count()').fetchone()[0] == 2 + assert rel1.join(rel2, "i=j", "inner").aggregate("count()").fetchone()[0] == 2 - assert rel1.join(rel2, 'i=j', 'left').aggregate('count()').fetchone()[0] == 4 + assert rel1.join(rel2, "i=j", "left").aggregate("count()").fetchone()[0] == 4 def test_fetchnumpy(self): start, stop = -1000, 2000 @@ -493,10 +493,10 @@ def counter(): counter.count = 0 conn = duckdb.connect() - conn.create_function('my_counter', counter, [], BIGINT) + conn.create_function("my_counter", counter, [], BIGINT) # Create a relation - rel = conn.sql('select my_counter()') + rel = conn.sql("select my_counter()") # Execute the relation once rel.fetchall() assert counter.count == 1 @@ -508,20 +508,20 @@ def counter(): assert counter.count == 2 # Verify that the query is run at least once if it's closed before it was executed. - rel = conn.sql('select my_counter()') + rel = conn.sql("select my_counter()") rel.close() assert counter.count == 3 def test_relation_print(self): con = duckdb.connect() con.execute("Create table t1 as select * from range(1000000)") - rel1 = con.table('t1') + rel1 = con.table("t1") text1 = str(rel1) - assert '? rows' in text1 - assert '>9999 rows' in text1 + assert "? rows" in text1 + assert ">9999 rows" in text1 @pytest.mark.parametrize( - 'num_rows', + "num_rows", [ 1024, 2048, @@ -563,7 +563,7 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): assert len(res) == num_rows rel = duckdb_cursor.sql(query) - projection = rel.select('column0') + projection = rel.select("column0") assert projection.fetchall() == [(42,) for _ in range(num_rows)] filtered = rel.filter("column1 != 'test'") @@ -575,58 +575,58 @@ def test_materialized_relation(self, duckdb_cursor, num_rows): ): rel.insert([1, 2, 3, 4]) - query_rel = rel.query('x', "select 42 from x where column0 != 42") + query_rel = rel.query("x", "select 42 from x where column0 != 42") assert query_rel.fetchall() == [] distinct_rel = rel.distinct() - assert distinct_rel.fetchall() == [(42, 'test', 'this is a long string', True)] + assert distinct_rel.fetchall() == [(42, "test", "this is a long string", True)] limited_rel = rel.limit(50) assert len(limited_rel.fetchall()) == 50 # Using parameters also results in a MaterializedRelation materialized_one = duckdb_cursor.sql("select * from range(?)", params=[10]).project( - ColumnExpression('range').cast(str).alias('range') + ColumnExpression("range").cast(str).alias("range") ) materialized_two = duckdb_cursor.sql("call repeat('a', 5)") - joined_rel = materialized_one.join(materialized_two, 'range != a') + joined_rel = materialized_one.join(materialized_two, "range != a") res = joined_rel.fetchall() assert len(res) == 50 relation = duckdb_cursor.sql("select a from materialized_two") - assert relation.fetchone() == ('a',) + assert relation.fetchone() == ("a",) described = materialized_one.describe() res = described.fetchall() - assert res == [('count', '10'), ('mean', None), ('stddev', None), ('min', '0'), ('max', '9'), ('median', None)] + assert res == [("count", "10"), ("mean", None), ("stddev", None), ("min", "0"), ("max", "9"), ("median", None)] unioned_rel = materialized_one.union(materialized_two) res = unioned_rel.fetchall() assert res == [ - ('0',), - ('1',), - ('2',), - ('3',), - ('4',), - ('5',), - ('6',), - ('7',), - ('8',), - ('9',), - ('a',), - ('a',), - ('a',), - ('a',), - ('a',), + ("0",), + ("1",), + ("2",), + ("3",), + ("4",), + ("5",), + ("6",), + ("7",), + ("8",), + ("9",), + ("a",), + ("a",), + ("a",), + ("a",), + ("a",), ] except_rel = unioned_rel.except_(materialized_one) res = except_rel.fetchall() - assert res == [tuple('a') for _ in range(5)] + assert res == [tuple("a") for _ in range(5)] - intersect_rel = unioned_rel.intersect(materialized_one).order('range') + intersect_rel = unioned_rel.intersect(materialized_one).order("range") res = intersect_rel.fetchall() - assert res == [('0',), ('1',), ('2',), ('3',), ('4',), ('5',), ('6',), ('7',), ('8',), ('9',)] + assert res == [("0",), ("1",), ("2",), ("3",), ("4",), ("5",), ("6",), ("7",), ("8",), ("9",)] def test_materialized_relation_view(self, duckdb_cursor): def create_view(duckdb_cursor): @@ -635,11 +635,11 @@ def create_view(duckdb_cursor): create table tbl(a varchar); insert into tbl values ('test') returning * """ - ).to_view('vw') + ).to_view("vw") create_view(duckdb_cursor) res = duckdb_cursor.sql("select * from vw").fetchone() - assert res == ('test',) + assert res == ("test",) def test_materialized_relation_view2(self, duckdb_cursor): # This creates a MaterializedRelation @@ -654,7 +654,7 @@ def test_materialized_relation_view2(self, duckdb_cursor): # The VIEW still works because the CDC that is being referenced is kept alive through the MaterializedDependency item rel = duckdb_cursor.sql("select * from test") res = rel.fetchall() - assert res == [([2], ['Alice'])] + assert res == [([2], ["Alice"])] def test_serialized_materialized_relation(self, tmp_database): con = duckdb.connect(tmp_database) @@ -663,9 +663,9 @@ def create_view(con, view_name: str): rel = con.sql("select 'this is not a small string ' || range::varchar from range(?)", params=[10]) rel.to_view(view_name) - expected = [(f'this is not a small string {i}',) for i in range(10)] + expected = [(f"this is not a small string {i}",) for i in range(10)] - create_view(con, 'vw') + create_view(con, "vw") res = con.sql("select * from vw").fetchall() assert res == expected diff --git a/tests/fast/test_relation_dependency_leak.py b/tests/fast/test_relation_dependency_leak.py index ca505704..ee98e30a 100644 --- a/tests/fast/test_relation_dependency_leak.py +++ b/tests/fast/test_relation_dependency_leak.py @@ -31,13 +31,13 @@ def from_df(pandas, duckdb_cursor): def from_arrow(pandas, duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_table) def arrow_replacement(pandas, duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) - arrow_table = pa.Table.from_arrays([data], ['a']) + arrow_table = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.query("select sum(a) from arrow_table").fetchall() @@ -47,27 +47,27 @@ def pandas_replacement(pandas, duckdb_cursor): class TestRelationDependencyMemoryLeak(object): - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow_leak(self, pandas, duckdb_cursor): if not can_run: return check_memory(from_arrow, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_df_leak(self, pandas, duckdb_cursor): check_memory(from_df, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_arrow_replacement_scan_leak(self, pandas, duckdb_cursor): if not can_run: return check_memory(arrow_replacement, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_replacement_scan_leak(self, pandas, duckdb_cursor): check_memory(pandas_replacement, pandas, duckdb_cursor) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_view_leak(self, pandas, duckdb_cursor): rel = from_df(pandas, duckdb_cursor) rel.create_view("bla") diff --git a/tests/fast/test_replacement_scan.py b/tests/fast/test_replacement_scan.py index 0cf69356..555773dc 100644 --- a/tests/fast/test_replacement_scan.py +++ b/tests/fast/test_replacement_scan.py @@ -8,13 +8,13 @@ def using_table(con, to_scan, object_name): - local_scope = {'con': con, object_name: to_scan, 'object_name': object_name} + local_scope = {"con": con, object_name: to_scan, "object_name": object_name} exec(f"result = con.table(object_name)", globals(), local_scope) return local_scope["result"] def using_sql(con, to_scan, object_name): - local_scope = {'con': con, object_name: to_scan, 'object_name': object_name} + local_scope = {"con": con, object_name: to_scan, "object_name": object_name} exec(f"result = con.sql('select * from \"{object_name}\"')", globals(), local_scope) return local_scope["result"] @@ -60,40 +60,40 @@ def fetch_relation(rel): def from_pandas(): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return df def from_arrow(): - schema = pa.schema([('field_1', pa.int64())]) + schema = pa.schema([("field_1", pa.int64())]) df = pa.RecordBatchReader.from_batches(schema, [pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], schema=schema)]) return df def create_relation(conn, query: str) -> duckdb.DuckDBPyRelation: - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) return conn.sql(query) class TestReplacementScan(object): def test_csv_replacement(self): con = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'integers.csv') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "integers.csv") res = con.execute("select count(*) from '%s'" % (filename)) assert res.fetchone()[0] == 2 def test_parquet_replacement(self): con = duckdb.connect() - filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data', 'binary_string.parquet') + filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "binary_string.parquet") res = con.execute("select count(*) from '%s'" % (filename)) assert res.fetchone()[0] == 3 - @pytest.mark.parametrize('get_relation', [using_table, using_sql]) + @pytest.mark.parametrize("get_relation", [using_table, using_sql]) @pytest.mark.parametrize( - 'fetch_method', + "fetch_method", [fetch_polars, fetch_df, fetch_arrow, fetch_arrow_table, fetch_arrow_record_batch, fetch_relation], ) - @pytest.mark.parametrize('object_name', ['tbl', 'table', 'select', 'update']) + @pytest.mark.parametrize("object_name", ["tbl", "table", "select", "update"]) def test_table_replacement_scans(self, duckdb_cursor, get_relation, fetch_method, object_name): base_rel = duckdb_cursor.values([1, 2, 3]) to_scan = fetch_method(base_rel) @@ -105,29 +105,29 @@ def test_table_replacement_scans(self, duckdb_cursor, get_relation, fetch_method def test_scan_global(self, duckdb_cursor): duckdb_cursor.execute("set python_enable_replacements=false") - with pytest.raises(duckdb.CatalogException, match='Table with name global_polars_df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name global_polars_df does not exist"): # We set the depth to look for global variables to 0 so it's never found duckdb_cursor.sql("select * from global_polars_df") duckdb_cursor.execute("set python_enable_replacements=true") # Now the depth is 1, which is enough to locate the variable rel = duckdb_cursor.sql("select * from global_polars_df") res = rel.fetchone() - assert res == (1, 'banana', 5, 'beetle') + assert res == (1, "banana", 5, "beetle") def test_scan_local(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) def inner_func(duckdb_cursor): duckdb_cursor.execute("set python_enable_replacements=false") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # We set the depth to look for local variables to 0 so it's never found duckdb_cursor.sql("select * from df") duckdb_cursor.execute("set python_enable_replacements=true") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # Here it's still not found, because it's not visible to this frame duckdb_cursor.sql("select * from df") - df = pd.DataFrame({'a': [4, 5, 6]}) + df = pd.DataFrame({"a": [4, 5, 6]}) duckdb_cursor.execute("set python_enable_replacements=true") # We can find the newly defined 'df' with depth 1 rel = duckdb_cursor.sql("select * from df") @@ -137,11 +137,11 @@ def inner_func(duckdb_cursor): inner_func(duckdb_cursor) def test_scan_local_unlimited(self, duckdb_cursor): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) def inner_func(duckdb_cursor): duckdb_cursor.execute("set python_enable_replacements=true") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): # We set the depth to look for local variables to 1 so it's still not found because it wasn't defined in this function duckdb_cursor.sql("select * from df") duckdb_cursor.execute("set python_scan_all_frames=true") @@ -155,37 +155,37 @@ def inner_func(duckdb_cursor): def test_replacement_scan_relapi(self): con = duckdb.connect() - pyrel1 = con.query('from (values (42), (84), (120)) t(i)') + pyrel1 = con.query("from (values (42), (84), (120)) t(i)") assert isinstance(pyrel1, duckdb.DuckDBPyRelation) assert pyrel1.fetchall() == [(42,), (84,), (120,)] - pyrel2 = con.query('from pyrel1 limit 2') + pyrel2 = con.query("from pyrel1 limit 2") assert isinstance(pyrel2, duckdb.DuckDBPyRelation) assert pyrel2.fetchall() == [(42,), (84,)] - pyrel3 = con.query('select i + 100 from pyrel2') + pyrel3 = con.query("select i + 100 from pyrel2") assert type(pyrel3) == duckdb.DuckDBPyRelation assert pyrel3.fetchall() == [(142,), (184,)] def test_replacement_scan_not_found(self): con = duckdb.connect() con.execute("set python_scan_all_frames=true") - with pytest.raises(duckdb.CatalogException, match='Table with name non_existant does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name non_existant does not exist"): res = con.sql("select * from non_existant").fetchall() def test_replacement_scan_alias(self): con = duckdb.connect() - pyrel1 = con.query('from (values (1, 2)) t(i, j)') - pyrel2 = con.query('from (values (1, 10)) t(i, k)') - pyrel3 = con.query('from pyrel1 join pyrel2 using(i)') + pyrel1 = con.query("from (values (1, 2)) t(i, j)") + pyrel2 = con.query("from (values (1, 10)) t(i, k)") + pyrel3 = con.query("from pyrel1 join pyrel2 using(i)") assert type(pyrel3) == duckdb.DuckDBPyRelation assert pyrel3.fetchall() == [(1, 2, 10)] def test_replacement_scan_pandas_alias(self): con = duckdb.connect() - df1 = con.query('from (values (1, 2)) t(i, j)').df() - df2 = con.query('from (values (1, 10)) t(i, k)').df() - df3 = con.query('from df1 join df2 using(i)') + df1 = con.query("from (values (1, 2)) t(i, j)").df() + df2 = con.query("from (values (1, 10)) t(i, k)").df() + df3 = con.query("from df1 join df2 using(i)") assert df3.fetchall() == [(1, 2, 10)] def test_replacement_scan_after_creation(self, duckdb_cursor): @@ -194,14 +194,14 @@ def test_replacement_scan_after_creation(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from df") duckdb_cursor.execute("drop table df") - df = pd.DataFrame({'b': [1, 2, 3]}) + df = pd.DataFrame({"b": [1, 2, 3]}) res = rel.fetchall() # FIXME: this should error instead, the 'df' table we relied on has been removed and replaced with a replacement scan assert res == [(1,), (2,), (3,)] def test_replacement_scan_caching(self, duckdb_cursor): def return_rel(conn): - df = pd.DataFrame({'a': [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) rel = conn.sql("select * from df") return rel @@ -220,7 +220,7 @@ def test_replacement_scan_fail(self): con.execute("select count(*) from random_object").fetchone() @pytest.mark.parametrize( - 'df_create', + "df_create", [ from_pandas, from_arrow, @@ -332,7 +332,7 @@ def test_same_name_cte(self, duckdb_cursor): def test_use_with_view(self, duckdb_cursor): rel = create_relation(duckdb_cursor, "select * from df") - rel.create_view('v1') + rel.create_view("v1") del rel rel = duckdb_cursor.sql("select * from v1") @@ -342,12 +342,12 @@ def test_use_with_view(self, duckdb_cursor): def create_view_in_func(con): df = pd.DataFrame({"a": [1, 2, 3]}) - con.execute('CREATE VIEW v1 AS SELECT * FROM df') + con.execute("CREATE VIEW v1 AS SELECT * FROM df") create_view_in_func(duckdb_cursor) # FIXME: this should be fixed in the future, likely by unifying the behavior of .sql and .execute - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist"): rel = duckdb_cursor.sql("select * from v1") def test_recursive_cte(self, duckdb_cursor): @@ -409,7 +409,7 @@ def test_multiple_replacements(self, duckdb_cursor): """ rel = duckdb_cursor.sql(query) res = rel.fetchall() - assert res == [(2, 'Bob', None), (3, 'Charlie', None), (4, 'David', 1.0), (5, 'Eve', 1.0)] + assert res == [(2, "Bob", None), (3, "Charlie", None), (4, "David", 1.0), (5, "Eve", 1.0)] def test_cte_at_different_levels(self, duckdb_cursor): query = """ @@ -459,17 +459,17 @@ def test_replacement_disabled(self): ## disable external access con.execute("set enable_external_access=false") - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist!'): + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist!"): rel = create_relation(con, "select * from df") res = rel.fetchall() with pytest.raises( - duckdb.InvalidInputException, match='Cannot change enable_external_access setting while database is running' + duckdb.InvalidInputException, match="Cannot change enable_external_access setting while database is running" ): con.execute("set enable_external_access=true") # Create connection with external access disabled - con = duckdb.connect(config={'enable_external_access': False}) - with pytest.raises(duckdb.CatalogException, match='Table with name df does not exist!'): + con = duckdb.connect(config={"enable_external_access": False}) + with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist!"): rel = create_relation(con, "select * from df") res = rel.fetchall() @@ -486,23 +486,23 @@ def test_replacement_disabled(self): assert res == [(1,), (2,), (3,)] def test_replacement_of_cross_connection_relation(self): - con1 = duckdb.connect(':memory:') - con2 = duckdb.connect(':memory:') - con1.query('create table integers(i int)') - con2.query('create table integers(v varchar)') - con1.query('insert into integers values (42)') - con2.query('insert into integers values (\'xxx\')') - rel1 = con1.query('select * from integers') + con1 = duckdb.connect(":memory:") + con2 = duckdb.connect(":memory:") + con1.query("create table integers(i int)") + con2.query("create table integers(v varchar)") + con1.query("insert into integers values (42)") + con2.query("insert into integers values ('xxx')") + rel1 = con1.query("select * from integers") with pytest.raises( duckdb.InvalidInputException, - match=r'The object was created by another Connection and can therefore not be used by this Connection.', + match=r"The object was created by another Connection and can therefore not be used by this Connection.", ): - con2.query('from rel1') + con2.query("from rel1") del con1 with pytest.raises( duckdb.InvalidInputException, - match=r'The object was created by another Connection and can therefore not be used by this Connection.', + match=r"The object was created by another Connection and can therefore not be used by this Connection.", ): - con2.query('from rel1') + con2.query("from rel1") diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index af68e268..906b1198 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -5,42 +5,42 @@ class TestPythonResult(object): def test_result_closed(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE integers (i integer)') - cursor.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + cursor.execute("CREATE TABLE integers (i integer)") + cursor.execute("INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)") rel = connection.table("integers") res = rel.aggregate("sum(i)").execute() res.close() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchone() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchall() - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchnumpy() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_table() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_reader(1) def test_result_describe_types(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() - cursor.execute('CREATE TABLE test (i bool, j TIME, k VARCHAR)') + cursor.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") cursor.execute("INSERT INTO test VALUES (TRUE, '01:01:01', 'bla' )") rel = connection.table("test") res = rel.execute() assert res.description == [ - ('i', 'BOOLEAN', None, None, None, None, None), - ('j', 'TIME', None, None, None, None, None), - ('k', 'VARCHAR', None, None, None, None, None), + ("i", "BOOLEAN", None, None, None, None, None), + ("j", "TIME", None, None, None, None, None), + ("k", "VARCHAR", None, None, None, None, None), ] def test_result_timestamps(self, duckdb_cursor): - connection = duckdb.connect('') + connection = duckdb.connect("") cursor = connection.cursor() cursor.execute( - 'CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );' + "CREATE TABLE IF NOT EXISTS timestamps (sec TIMESTAMP_S, milli TIMESTAMP_MS,micro TIMESTAMP_US, nano TIMESTAMP_NS );" ) cursor.execute( "INSERT INTO timestamps VALUES ('2008-01-01 00:00:11','2008-01-01 00:00:01.794','2008-01-01 00:00:01.98926','2008-01-01 00:00:01.899268321' )" @@ -59,12 +59,12 @@ def test_result_timestamps(self, duckdb_cursor): def test_result_interval(self): connection = duckdb.connect() cursor = connection.cursor() - cursor.execute('CREATE TABLE IF NOT EXISTS intervals (ivals INTERVAL)') + cursor.execute("CREATE TABLE IF NOT EXISTS intervals (ivals INTERVAL)") cursor.execute("INSERT INTO intervals VALUES ('1 day'), ('2 second'), ('1 microsecond')") rel = connection.table("intervals") res = rel.execute() - assert res.description == [('ivals', 'INTERVAL', None, None, None, None, None)] + assert res.description == [("ivals", "INTERVAL", None, None, None, None, None)] assert res.fetchall() == [ (datetime.timedelta(days=1.0),), (datetime.timedelta(seconds=2.0),), diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 29e81d1e..327be004 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -2,8 +2,8 @@ import pytest from conftest import NumpyPandas, ArrowPandas -closed = lambda: pytest.raises(duckdb.ConnectionException, match='Connection already closed') -no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match='No open result set') +closed = lambda: pytest.raises(duckdb.ConnectionException, match="Connection already closed") +no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match="No open result set") class TestRuntimeError(object): @@ -20,7 +20,7 @@ def test_df_error(self): con.execute("select i::int from tbl").df() def test_arrow_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() con.execute("create table tbl as select 'hello' i") @@ -34,83 +34,83 @@ def test_register_error(self): con.register(py_obj, "v") def test_arrow_fetch_table_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() - with pytest.raises(duckdb.InvalidInputException, match='There is no query result'): + with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): res.fetch_arrow_table() def test_arrow_record_batch_reader_error(self): - pytest.importorskip('pyarrow') + pytest.importorskip("pyarrow") con = duckdb.connect() arrow_object = con.execute("select 1").fetch_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() - with pytest.raises(duckdb.ProgrammingError, match='There is no query result'): + with pytest.raises(duckdb.ProgrammingError, match="There is no query result"): res.fetch_arrow_reader(1) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_cache_fetchall(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.ProgrammingError, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.ProgrammingError, match="Table with name df_in does not exist"): # Even when we preserve ExternalDependency objects correctly, this is not supported # Relations only save dependencies for their immediate TableRefs, # so the dependency of 'x' on 'df_in' is not registered in 'rel' rel.fetchall() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_cache_execute(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.ProgrammingError, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.ProgrammingError, match="Table with name df_in does not exist"): rel.execute() - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_relation_query_error(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") rel = conn.query("select * from x") del df_in - with pytest.raises(duckdb.CatalogException, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df_in does not exist"): rel.query("bla", "select * from bla") - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_conn_broken_statement_error(self, pandas): conn = duckdb.connect() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) conn.execute("create view x as select * from df_in") del df_in - with pytest.raises(duckdb.CatalogException, match='Table with name df_in does not exist'): + with pytest.raises(duckdb.CatalogException, match="Table with name df_in does not exist"): conn.execute("select 1; select * from x; select 3;") def test_conn_prepared_statement_error(self): @@ -118,17 +118,17 @@ def test_conn_prepared_statement_error(self): conn.execute("create table integers (a integer, b integer)") with pytest.raises( duckdb.InvalidInputException, - match='Values were not provided for the following prepared statement parameters: 2', + match="Values were not provided for the following prepared statement parameters: 2", ): conn.execute("select * from integers where a =? and b=?", [1]) - @pytest.mark.parametrize('pandas', [NumpyPandas(), ArrowPandas()]) + @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_closed_conn_exceptions(self, pandas): conn = duckdb.connect() conn.close() df_in = pandas.DataFrame( { - 'numbers': [1, 2, 3, 4, 5], + "numbers": [1, 2, 3, 4, 5], } ) diff --git a/tests/fast/test_sql_expression.py b/tests/fast/test_sql_expression.py index 771be84d..4dc4cab5 100644 --- a/tests/fast/test_sql_expression.py +++ b/tests/fast/test_sql_expression.py @@ -9,7 +9,6 @@ class TestSQLExpression(object): def test_sql_expression_basic(self, duckdb_cursor): - # Test simple constant expressions expr = SQLExpression("42") rel = duckdb_cursor.sql("SELECT 1").select(expr) @@ -17,7 +16,7 @@ def test_sql_expression_basic(self, duckdb_cursor): expr = SQLExpression("'hello'") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('hello',)] + assert rel.fetchall() == [("hello",)] expr = SQLExpression("NULL") rel = duckdb_cursor.sql("SELECT 1").select(expr) @@ -43,14 +42,13 @@ def test_sql_expression_basic(self, duckdb_cursor): # Test function calls expr = SQLExpression("UPPER('test')") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('TEST',)] + assert rel.fetchall() == [("TEST",)] expr = SQLExpression("CONCAT('hello', ' ', 'world')") rel = duckdb_cursor.sql("SELECT 1").select(expr) - assert rel.fetchall() == [('hello world',)] + assert rel.fetchall() == [("hello world",)] def test_sql_expression_with_columns(self, duckdb_cursor): - # Create a test table duckdb_cursor.execute( """ @@ -75,12 +73,12 @@ def test_sql_expression_with_columns(self, duckdb_cursor): expr = SQLExpression("UPPER(b)") rel2 = rel.select(expr) - assert rel2.fetchall() == [('ONE',), ('TWO',), ('THREE',)] + assert rel2.fetchall() == [("ONE",), ("TWO",), ("THREE",)] # Test complex expressions expr = SQLExpression("CASE WHEN a > 1 THEN b ELSE 'default' END") rel2 = rel.select(expr) - assert rel2.fetchall() == [('default',), ('two',), ('three',)] + assert rel2.fetchall() == [("default",), ("two",), ("three",)] # Test combining with other expression types expr1 = SQLExpression("a + 5") @@ -122,8 +120,8 @@ def test_sql_expression_alias(self, duckdb_cursor): rel = duckdb_cursor.table("test_alias") expr = SQLExpression("a + 10").alias("a_plus_10") rel2 = rel.select(expr, "b") - assert rel2.fetchall() == [(11, 'one'), (12, 'two')] - assert rel2.columns == ['a_plus_10', 'b'] + assert rel2.fetchall() == [(11, "one"), (12, "two")] + assert rel2.columns == ["a_plus_10", "b"] def test_sql_expression_in_filter(self, duckdb_cursor): duckdb_cursor.execute( @@ -142,18 +140,18 @@ def test_sql_expression_in_filter(self, duckdb_cursor): # Test filter with SQL expression expr = SQLExpression("a > 2") rel2 = rel.filter(expr) - assert rel2.fetchall() == [(3, 'three'), (4, 'four')] + assert rel2.fetchall() == [(3, "three"), (4, "four")] # Test complex filter expr = SQLExpression("a % 2 = 0 AND b LIKE '%o%'") rel2 = rel.filter(expr) - assert rel2.fetchall() == [(2, 'two'), (4, 'four')] + assert rel2.fetchall() == [(2, "two"), (4, "four")] # Test combining with other expression types expr1 = SQLExpression("a > 1") expr2 = ColumnExpression("b") == ConstantExpression("four") rel2 = rel.filter(expr1 & expr2) - assert rel2.fetchall() == [(4, 'four')] + assert rel2.fetchall() == [(4, "four")] def test_sql_expression_in_aggregates(self, duckdb_cursor): duckdb_cursor.execute( @@ -176,14 +174,14 @@ def test_sql_expression_in_aggregates(self, duckdb_cursor): # Test aggregation with group by expr = SQLExpression("SUM(c)") - rel2 = rel.aggregate([expr, "b"]).sort('b') + rel2 = rel.aggregate([expr, "b"]).sort("b") result = rel2.fetchall() - assert result == [(30, 'group1'), (70, 'group2')] + assert result == [(30, "group1"), (70, "group2")] # Test multiple aggregations expr1 = SQLExpression("SUM(a)").alias("sum_a") expr2 = SQLExpression("AVG(c)").alias("avg_c") - rel2 = rel.aggregate([expr1, expr2], "b").sort('sum_a', 'avg_c') + rel2 = rel.aggregate([expr1, expr2], "b").sort("sum_a", "avg_c") result = rel2.fetchall() result.sort() assert result == [(3, 15.0), (7, 35.0)] diff --git a/tests/fast/test_string_annotation.py b/tests/fast/test_string_annotation.py index c5500c66..83685bed 100644 --- a/tests/fast/test_string_annotation.py +++ b/tests/fast/test_string_annotation.py @@ -14,7 +14,7 @@ def test_base(): test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__ ) # Add the 'type' string as return_annotation - test_function.__annotations__ = {'return': type} + test_function.__annotations__ = {"return": type} return test_function @@ -33,12 +33,12 @@ class TestStringAnnotation(object): python_version_lower_than_3_10(), reason="inspect.signature(eval_str=True) only supported since 3.10 and higher" ) @pytest.mark.parametrize( - ['input', 'expected'], + ["input", "expected"], [ - ('str', 'VARCHAR'), - ('list[str]', 'VARCHAR[]'), - ('dict[str, str]', 'MAP(VARCHAR, VARCHAR)'), - ('dict[Union[str, bool], str]', 'MAP(UNION(u1 VARCHAR, u2 BOOLEAN), VARCHAR)'), + ("str", "VARCHAR"), + ("list[str]", "VARCHAR[]"), + ("dict[str, str]", "MAP(VARCHAR, VARCHAR)"), + ("dict[Union[str, bool], str]", "MAP(UNION(u1 VARCHAR, u2 BOOLEAN), VARCHAR)"), ], ) def test_string_annotations(self, duckdb_cursor, input, expected): diff --git a/tests/fast/test_tf.py b/tests/fast/test_tf.py index b65acec6..db93d0de 100644 --- a/tests/fast/test_tf.py +++ b/tests/fast/test_tf.py @@ -2,7 +2,7 @@ import pytest -tf = pytest.importorskip('tensorflow') +tf = pytest.importorskip("tensorflow") def test_tf(): @@ -14,16 +14,16 @@ def test_tf(): # Test from connection duck_tf = con.execute("select * from t").tf() duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) # Test from relation duck_tf = con.sql("select * from t").tf() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) # Test all Numeric Types - numeric_types = ['TINYINT', 'SMALLINT', 'BIGINT', 'HUGEINT', 'FLOAT', 'DOUBLE', 'DECIMAL(4,1)', 'UTINYINT'] + numeric_types = ["TINYINT", "SMALLINT", "BIGINT", "HUGEINT", "FLOAT", "DOUBLE", "DECIMAL(4,1)", "UTINYINT"] for supported_type in numeric_types: con = duckdb.connect() @@ -31,5 +31,5 @@ def test_tf(): con.execute("insert into t values (1,2), (3,4)") duck_tf = con.sql("select * from t").tf() duck_numpy = con.sql("select * from t").fetchnumpy() - tf.math.equal(duck_tf['a'], tf.convert_to_tensor(duck_numpy['a'])) - tf.math.equal(duck_tf['b'], tf.convert_to_tensor(duck_numpy['b'])) + tf.math.equal(duck_tf["a"], tf.convert_to_tensor(duck_numpy["a"])) + tf.math.equal(duck_tf["b"], tf.convert_to_tensor(duck_numpy["b"])) diff --git a/tests/fast/test_transaction.py b/tests/fast/test_transaction.py index 54deaf82..ff0ba1a7 100644 --- a/tests/fast/test_transaction.py +++ b/tests/fast/test_transaction.py @@ -5,16 +5,16 @@ class TestConnectionTransaction(object): def test_transaction(self, duckdb_cursor): con = duckdb.connect() - con.execute('create table t (i integer)') - con.execute('insert into t values (1)') + con.execute("create table t (i integer)") + con.execute("insert into t values (1)") con.begin() - con.execute('insert into t values (1)') - assert con.execute('select count (*) from t').fetchone()[0] == 2 + con.execute("insert into t values (1)") + assert con.execute("select count (*) from t").fetchone()[0] == 2 con.rollback() - assert con.execute('select count (*) from t').fetchone()[0] == 1 + assert con.execute("select count (*) from t").fetchone()[0] == 1 con.begin() - con.execute('insert into t values (1)') - assert con.execute('select count (*) from t').fetchone()[0] == 2 + con.execute("insert into t values (1)") + assert con.execute("select count (*) from t").fetchone()[0] == 2 con.commit() - assert con.execute('select count (*) from t').fetchone()[0] == 2 + assert con.execute("select count (*) from t").fetchone()[0] == 2 diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index c5a62694..1e8ebc25 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -40,83 +40,83 @@ class TestType(object): def test_sqltype(self): - assert str(duckdb.sqltype('struct(a VARCHAR, b BIGINT)')) == 'STRUCT(a VARCHAR, b BIGINT)' + assert str(duckdb.sqltype("struct(a VARCHAR, b BIGINT)")) == "STRUCT(a VARCHAR, b BIGINT)" # todo: add tests with invalid type_str def test_primitive_types(self): assert str(SQLNULL) == '"NULL"' - assert str(BOOLEAN) == 'BOOLEAN' - assert str(TINYINT) == 'TINYINT' - assert str(UTINYINT) == 'UTINYINT' - assert str(SMALLINT) == 'SMALLINT' - assert str(USMALLINT) == 'USMALLINT' - assert str(INTEGER) == 'INTEGER' - assert str(UINTEGER) == 'UINTEGER' - assert str(BIGINT) == 'BIGINT' - assert str(UBIGINT) == 'UBIGINT' - assert str(HUGEINT) == 'HUGEINT' - assert str(UHUGEINT) == 'UHUGEINT' - assert str(UUID) == 'UUID' - assert str(FLOAT) == 'FLOAT' - assert str(DOUBLE) == 'DOUBLE' - assert str(DATE) == 'DATE' - assert str(TIMESTAMP) == 'TIMESTAMP' - assert str(TIMESTAMP_MS) == 'TIMESTAMP_MS' - assert str(TIMESTAMP_NS) == 'TIMESTAMP_NS' - assert str(TIMESTAMP_S) == 'TIMESTAMP_S' - assert str(TIME) == 'TIME' - assert str(TIME_TZ) == 'TIME WITH TIME ZONE' - assert str(TIMESTAMP_TZ) == 'TIMESTAMP WITH TIME ZONE' - assert str(VARCHAR) == 'VARCHAR' - assert str(BLOB) == 'BLOB' - assert str(BIT) == 'BIT' - assert str(INTERVAL) == 'INTERVAL' + assert str(BOOLEAN) == "BOOLEAN" + assert str(TINYINT) == "TINYINT" + assert str(UTINYINT) == "UTINYINT" + assert str(SMALLINT) == "SMALLINT" + assert str(USMALLINT) == "USMALLINT" + assert str(INTEGER) == "INTEGER" + assert str(UINTEGER) == "UINTEGER" + assert str(BIGINT) == "BIGINT" + assert str(UBIGINT) == "UBIGINT" + assert str(HUGEINT) == "HUGEINT" + assert str(UHUGEINT) == "UHUGEINT" + assert str(UUID) == "UUID" + assert str(FLOAT) == "FLOAT" + assert str(DOUBLE) == "DOUBLE" + assert str(DATE) == "DATE" + assert str(TIMESTAMP) == "TIMESTAMP" + assert str(TIMESTAMP_MS) == "TIMESTAMP_MS" + assert str(TIMESTAMP_NS) == "TIMESTAMP_NS" + assert str(TIMESTAMP_S) == "TIMESTAMP_S" + assert str(TIME) == "TIME" + assert str(TIME_TZ) == "TIME WITH TIME ZONE" + assert str(TIMESTAMP_TZ) == "TIMESTAMP WITH TIME ZONE" + assert str(VARCHAR) == "VARCHAR" + assert str(BLOB) == "BLOB" + assert str(BIT) == "BIT" + assert str(INTERVAL) == "INTERVAL" def test_list_type(self): type = duckdb.list_type(BIGINT) - assert str(type) == 'BIGINT[]' + assert str(type) == "BIGINT[]" def test_array_type(self): type = duckdb.array_type(BIGINT, 3) - assert str(type) == 'BIGINT[3]' + assert str(type) == "BIGINT[3]" def test_struct_type(self): - type = duckdb.struct_type({'a': BIGINT, 'b': BOOLEAN}) - assert str(type) == 'STRUCT(a BIGINT, b BOOLEAN)' + type = duckdb.struct_type({"a": BIGINT, "b": BOOLEAN}) + assert str(type) == "STRUCT(a BIGINT, b BOOLEAN)" # FIXME: create an unnamed struct when fields are provided as a list type = duckdb.struct_type([BIGINT, BOOLEAN]) - assert str(type) == 'STRUCT(v1 BIGINT, v2 BOOLEAN)' + assert str(type) == "STRUCT(v1 BIGINT, v2 BOOLEAN)" def test_incomplete_struct_type(self): with pytest.raises( - duckdb.InvalidInputException, match='Could not convert empty dictionary to a duckdb STRUCT type' + duckdb.InvalidInputException, match="Could not convert empty dictionary to a duckdb STRUCT type" ): type = duckdb.typing.DuckDBPyType(dict()) def test_map_type(self): type = duckdb.map_type(duckdb.sqltype("BIGINT"), duckdb.sqltype("DECIMAL(10, 2)")) - assert str(type) == 'MAP(BIGINT, DECIMAL(10,2))' + assert str(type) == "MAP(BIGINT, DECIMAL(10,2))" def test_decimal_type(self): type = duckdb.decimal_type(5, 3) - assert str(type) == 'DECIMAL(5,3)' + assert str(type) == "DECIMAL(5,3)" def test_string_type(self): type = duckdb.string_type() - assert str(type) == 'VARCHAR' + assert str(type) == "VARCHAR" def test_string_type_collation(self): - type = duckdb.string_type('NOCASE') + type = duckdb.string_type("NOCASE") # collation does not show up in the string representation.. - assert str(type) == 'VARCHAR' + assert str(type) == "VARCHAR" def test_union_type(self): type = duckdb.union_type([BIGINT, VARCHAR, TINYINT]) - assert str(type) == 'UNION(v1 BIGINT, v2 VARCHAR, v3 TINYINT)' + assert str(type) == "UNION(v1 BIGINT, v2 VARCHAR, v3 TINYINT)" - type = duckdb.union_type({'a': BIGINT, 'b': VARCHAR, 'c': TINYINT}) - assert str(type) == 'UNION(a BIGINT, b VARCHAR, c TINYINT)' + type = duckdb.union_type({"a": BIGINT, "b": VARCHAR, "c": TINYINT}) + assert str(type) == "UNION(a BIGINT, b VARCHAR, c TINYINT)" import sys @@ -125,42 +125,42 @@ def test_implicit_convert_from_builtin_type(self): type = duckdb.list_type(list[str]) assert str(type.child) == "VARCHAR[]" - mapping = {str: 'VARCHAR', int: 'BIGINT', bytes: 'BLOB', bytearray: 'BLOB', bool: 'BOOLEAN', float: 'DOUBLE'} + mapping = {str: "VARCHAR", int: "BIGINT", bytes: "BLOB", bytearray: "BLOB", bool: "BOOLEAN", float: "DOUBLE"} for duckdb_type, expected in mapping.items(): res = duckdb.list_type(duckdb_type) assert str(res.child) == expected - res = duckdb.list_type({'a': str, 'b': int}) - assert str(res.child) == 'STRUCT(a VARCHAR, b BIGINT)' + res = duckdb.list_type({"a": str, "b": int}) + assert str(res.child) == "STRUCT(a VARCHAR, b BIGINT)" res = duckdb.list_type(dict[str, int]) - assert str(res.child) == 'MAP(VARCHAR, BIGINT)' + assert str(res.child) == "MAP(VARCHAR, BIGINT)" res = duckdb.list_type(list[str]) - assert str(res.child) == 'VARCHAR[]' + assert str(res.child) == "VARCHAR[]" res = duckdb.list_type(list[dict[str, dict[list[str], str]]]) - assert str(res.child) == 'MAP(VARCHAR, MAP(VARCHAR[], VARCHAR))[]' + assert str(res.child) == "MAP(VARCHAR, MAP(VARCHAR[], VARCHAR))[]" res = duckdb.list_type(list[Union[str, int]]) - assert str(res.child) == 'UNION(u1 VARCHAR, u2 BIGINT)[]' + assert str(res.child) == "UNION(u1 VARCHAR, u2 BIGINT)[]" def test_implicit_convert_from_numpy(self, duckdb_cursor): np = pytest.importorskip("numpy") type_mapping = { - 'bool': 'BOOLEAN', - 'int8': 'TINYINT', - 'uint8': 'UTINYINT', - 'int16': 'SMALLINT', - 'uint16': 'USMALLINT', - 'int32': 'INTEGER', - 'uint32': 'UINTEGER', - 'int64': 'BIGINT', - 'uint64': 'UBIGINT', - 'float16': 'FLOAT', - 'float32': 'FLOAT', - 'float64': 'DOUBLE', + "bool": "BOOLEAN", + "int8": "TINYINT", + "uint8": "UTINYINT", + "int16": "SMALLINT", + "uint16": "USMALLINT", + "int32": "INTEGER", + "uint32": "UINTEGER", + "int64": "BIGINT", + "uint64": "UBIGINT", + "float16": "FLOAT", + "float32": "FLOAT", + "float64": "DOUBLE", } builtins = [] @@ -189,30 +189,30 @@ def test_implicit_convert_from_numpy(self, duckdb_cursor): def test_attribute_accessor(self): type = duckdb.row_type([BIGINT, duckdb.list_type(duckdb.map_type(BLOB, BIT))]) - assert hasattr(type, 'a') == False - assert hasattr(type, 'v1') == True + assert hasattr(type, "a") == False + assert hasattr(type, "v1") == True - field_one = type['v1'] - assert str(field_one) == 'BIGINT' + field_one = type["v1"] + assert str(field_one) == "BIGINT" field_one = type.v1 - assert str(field_one) == 'BIGINT' + assert str(field_one) == "BIGINT" - field_two = type['v2'] - assert str(field_two) == 'MAP(BLOB, BIT)[]' + field_two = type["v2"] + assert str(field_two) == "MAP(BLOB, BIT)[]" child_type = type.v2.child - assert str(child_type) == 'MAP(BLOB, BIT)' + assert str(child_type) == "MAP(BLOB, BIT)" def test_json_type(self): - json_type = duckdb.type('JSON') + json_type = duckdb.type("JSON") val = duckdb.Value('{"duck": 42}', json_type) res = duckdb.execute("select typeof($1)", [val]).fetchone() - assert res == ('JSON',) + assert res == ("JSON",) def test_struct_from_dict(self): - res = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) - assert res == 'STRUCT(a VARCHAR, b VARCHAR)[]' + res = duckdb.list_type({"a": VARCHAR, "b": VARCHAR}) + assert res == "STRUCT(a VARCHAR, b VARCHAR)[]" def test_hash_method(self): type1 = duckdb.list_type({'a': VARCHAR, 'b': VARCHAR}) @@ -232,29 +232,29 @@ def test_hash_method(self): @pytest.mark.skipif(sys.version_info < (3, 9), reason="python3.7 does not store Optional[..] in a recognized way") def test_optional(self): type = duckdb.typing.DuckDBPyType(Optional[str]) - assert type == 'VARCHAR' + assert type == "VARCHAR" type = duckdb.typing.DuckDBPyType(Optional[Union[int, bool]]) - assert type == 'UNION(u1 BIGINT, u2 BOOLEAN)' + assert type == "UNION(u1 BIGINT, u2 BOOLEAN)" type = duckdb.typing.DuckDBPyType(Optional[list[int]]) - assert type == 'BIGINT[]' + assert type == "BIGINT[]" type = duckdb.typing.DuckDBPyType(Optional[dict[int, str]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) - assert type == 'MAP(BIGINT, VARCHAR)' + assert type == "MAP(BIGINT, VARCHAR)" type = duckdb.typing.DuckDBPyType(Optional[Union[Optional[str], Optional[bool]]]) - assert type == 'UNION(u1 VARCHAR, u2 BOOLEAN)' + assert type == "UNION(u1 VARCHAR, u2 BOOLEAN)" type = duckdb.typing.DuckDBPyType(Union[str, None]) - assert type == 'VARCHAR' + assert type == "VARCHAR" @pytest.mark.skipif(sys.version_info < (3, 10), reason="'str | None' syntax requires Python 3.10 or higher") def test_optional_310(self): type = duckdb.typing.DuckDBPyType(str | None) - assert type == 'VARCHAR' + assert type == "VARCHAR" def test_children_attribute(self): - assert DuckDBPyType('INTEGER[]').children == [('child', DuckDBPyType('INTEGER'))] - assert DuckDBPyType('INTEGER[2]').children == [('child', DuckDBPyType('INTEGER')), ('size', 2)] - assert DuckDBPyType('INTEGER[2][3]').children == [('child', DuckDBPyType('INTEGER[2]')), ('size', 3)] - assert DuckDBPyType("ENUM('a', 'b', 'c')").children == [('values', ['a', 'b', 'c'])] + assert DuckDBPyType("INTEGER[]").children == [("child", DuckDBPyType("INTEGER"))] + assert DuckDBPyType("INTEGER[2]").children == [("child", DuckDBPyType("INTEGER")), ("size", 2)] + assert DuckDBPyType("INTEGER[2][3]").children == [("child", DuckDBPyType("INTEGER[2]")), ("size", 3)] + assert DuckDBPyType("ENUM('a', 'b', 'c')").children == [("values", ["a", "b", "c"])] diff --git a/tests/fast/test_type_explicit.py b/tests/fast/test_type_explicit.py index 23dcddc3..7b0797e6 100644 --- a/tests/fast/test_type_explicit.py +++ b/tests/fast/test_type_explicit.py @@ -2,19 +2,18 @@ class TestMap(object): - def test_array_list_tuple_ambiguity(self): con = duckdb.connect() - res = con.sql("SELECT $arg", params={'arg': (1, 2)}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": (1, 2)}).fetchall()[0][0] assert res == [1, 2] # By using an explicit duckdb.Value with an array type, we should convert the input as an array # and get an array (tuple) back typ = duckdb.array_type(duckdb.typing.BIGINT, 2) val = duckdb.Value((1, 2), typ) - res = con.sql("SELECT $arg", params={'arg': val}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": val}).fetchall()[0][0] assert res == (1, 2) val = duckdb.Value([3, 4], typ) - res = con.sql("SELECT $arg", params={'arg': val}).fetchall()[0][0] + res = con.sql("SELECT $arg", params={"arg": val}).fetchall()[0][0] assert res == (3, 4) diff --git a/tests/fast/test_unicode.py b/tests/fast/test_unicode.py index b697f84a..7d08ac88 100644 --- a/tests/fast/test_unicode.py +++ b/tests/fast/test_unicode.py @@ -7,7 +7,7 @@ class TestUnicode(object): def test_unicode_pandas_scan(self, duckdb_cursor): - con = duckdb.connect(database=':memory:', read_only=False) - test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", u"ë"]}) - con.register('test_df_view', test_df) - con.execute('SELECT i, j, LENGTH(j) FROM test_df_view').fetchall() + con = duckdb.connect(database=":memory:", read_only=False) + test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", "ë"]}) + con.register("test_df_view", test_df) + con.execute("SELECT i, j, LENGTH(j) FROM test_df_view").fetchall() diff --git a/tests/fast/test_value.py b/tests/fast/test_value.py index 4f74516c..c17264fd 100644 --- a/tests/fast/test_value.py +++ b/tests/fast/test_value.py @@ -71,7 +71,7 @@ class TestValue(object): # This excludes timezone aware values, as those are a pain to test @pytest.mark.parametrize( - 'item', + "item", [ (BOOLEAN, BooleanValue(True), True), (UTINYINT, UnsignedBinaryValue(129), 129), @@ -88,17 +88,17 @@ class TestValue(object): (DOUBLE, DoubleValue(0.23234234234), 0.23234234234), ( duckdb.decimal_type(12, 8), - DecimalValue(decimal.Decimal('1234.12345678'), 12, 8), - decimal.Decimal('1234.12345678'), + DecimalValue(decimal.Decimal("1234.12345678"), 12, 8), + decimal.Decimal("1234.12345678"), ), - (VARCHAR, StringValue('this is a long string'), 'this is a long string'), + (VARCHAR, StringValue("this is a long string"), "this is a long string"), ( UUID, - UUIDValue(uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), - uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), + UUIDValue(uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")), + uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), ), - (BIT, BitValue(b'010101010101'), '010101010101'), - (BLOB, BlobValue(b'\x00\x00\x00a'), b'\x00\x00\x00a'), + (BIT, BitValue(b"010101010101"), "010101010101"), + (BLOB, BlobValue(b"\x00\x00\x00a"), b"\x00\x00\x00a"), (DATE, DateValue(datetime.date(2000, 5, 4)), datetime.date(2000, 5, 4)), (INTERVAL, IntervalValue(datetime.timedelta(days=5)), datetime.timedelta(days=5)), ( @@ -116,10 +116,10 @@ def test_value_helpers(self, item): expected_value = item[2] con = duckdb.connect() - observed_type = con.execute('select typeof(a) from (select $1) tbl(a)', [value_object]).fetchall()[0][0] + observed_type = con.execute("select typeof(a) from (select $1) tbl(a)", [value_object]).fetchall()[0][0] assert observed_type == str(expected_type) - con.execute('select $1', [value_object]) + con.execute("select $1", [value_object]) result = con.fetchone() result = result[0] assert result == expected_value @@ -129,10 +129,10 @@ def test_float_to_decimal_prevention(self): con = duckdb.connect() with pytest.raises(duckdb.ConversionException, match="Can't losslessly convert"): - con.execute('select $1', [value]).fetchall() + con.execute("select $1", [value]).fetchall() @pytest.mark.parametrize( - 'value', + "value", [ TimestampSecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), TimestampMilisecondValue(datetime.datetime(1970, 3, 21, 12, 36, 43)), @@ -144,10 +144,10 @@ def test_timestamp_sec_not_supported(self, value): with pytest.raises( duckdb.NotImplementedException, match="Conversion from 'datetime' to type .* is not implemented yet" ): - con.execute('select $1', [value]).fetchall() + con.execute("select $1", [value]).fetchall() @pytest.mark.parametrize( - 'target_type,test_value,expected_conversion_success', + "target_type,test_value,expected_conversion_success", [ (TINYINT, 0, True), (TINYINT, 255, False), @@ -187,7 +187,7 @@ def test_numeric_values(self, target_type, test_value, expected_conversion_succe value = Value(test_value, target_type) con = duckdb.connect() - work = lambda: con.execute('select typeof(a) from (select $1) tbl(a)', [value]).fetchall() + work = lambda: con.execute("select typeof(a) from (select $1) tbl(a)", [value]).fetchall() if expected_conversion_success: res = work() diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 7a3c7a68..2ec3f784 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -1,6 +1,7 @@ """ Tests for duckdb_pytooling versioning functionality. """ + import os import unittest @@ -109,26 +110,26 @@ def test_bump_version_exact_tag(self): assert _bump_version("1.2.3", 0, False) == "1.2.3" assert _bump_version("1.2.3.post1", 0, False) == "1.2.3.post1" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_bump_version_with_distance(self): """Test bump_version with distance from tag.""" assert _bump_version("1.2.3", 5, False) == "1.3.0.dev5" - + # Post-release development assert _bump_version("1.2.3.post1", 3, False) == "1.2.3.post2.dev3" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '0'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "0"}) def test_bump_version_release_branch(self): """Test bump_version on bugfix branch.""" assert _bump_version("1.2.3", 5, False) == "1.2.4.dev5" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_bump_version_dirty(self): """Test bump_version with dirty working directory.""" assert _bump_version("1.2.3", 0, True) == "1.3.0.dev0" assert _bump_version("1.2.3.post1", 0, True) == "1.2.3.post2.dev0" - @patch.dict('os.environ', {'MAIN_BRANCH_VERSIONING': '1'}) + @patch.dict("os.environ", {"MAIN_BRANCH_VERSIONING": "1"}) def test_version_scheme_function(self): """Test the version_scheme function that setuptools_scm calls.""" # Mock setuptools_scm version object @@ -136,7 +137,7 @@ def test_version_scheme_function(self): mock_version.tag = "1.2.3" mock_version.distance = 5 mock_version.dirty = False - + result = version_scheme(mock_version) assert result == "1.3.0.dev5" @@ -149,48 +150,45 @@ def test_bump_version_invalid_format(self): class TestGitOperations(unittest.TestCase): """Test git-related operations (mocked).""" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_success(self, mock_run): """Test successful current version retrieval.""" mock_run.return_value.stdout = "v1.2.3\n" mock_run.return_value.check = True - + result = get_current_version() assert result == "1.2.3" mock_run.assert_called_once_with( - ["git", "describe", "--tags", "--abbrev=0"], - capture_output=True, - text=True, - check=True + ["git", "describe", "--tags", "--abbrev=0"], capture_output=True, text=True, check=True ) - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_with_post_release(self, mock_run): """Test current version retrieval with post-release tag.""" mock_run.return_value.stdout = "v1.2.3-post1\n" mock_run.return_value.check = True - + result = get_current_version() assert result == "1.2.3.post1" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_current_version_no_tags(self, mock_run): """Test current version retrieval when no tags exist.""" mock_run.side_effect = subprocess.CalledProcessError(1, "git") - + result = get_current_version() assert result is None - @patch('subprocess.run') + @patch("subprocess.run") def test_get_git_describe_success(self, mock_run): """Test successful git describe.""" mock_run.return_value.stdout = "v1.2.3-5-g1234567\n" mock_run.return_value.check = True - + result = get_git_describe() assert result == "v1.2.3-5-g1234567" - @patch('subprocess.run') + @patch("subprocess.run") def test_get_git_describe_no_tags(self, mock_run): """Test git describe when no tags exist.""" mock_run.side_effect = subprocess.CalledProcessError(1, "git") @@ -202,21 +200,21 @@ def test_get_git_describe_no_tags(self, mock_run): class TestEnvironmentVariableHandling(unittest.TestCase): """Test environment variable handling in setuptools_scm integration.""" - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'v1.2.3-5-g1234567'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "v1.2.3-5-g1234567"}) def test_override_git_describe_basic(self): """Test OVERRIDE_GIT_DESCRIBE with basic format.""" forced_version_from_env() # Check that the environment variable was processed - assert 'SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB' in os.environ + assert "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" in os.environ - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'v1.2.3-post1-3-g1234567'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "v1.2.3-post1-3-g1234567"}) def test_override_git_describe_post_release(self): """Test OVERRIDE_GIT_DESCRIBE with post-release format.""" forced_version_from_env() # Check that post-release was converted correctly - assert 'SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB' in os.environ + assert "SETUPTOOLS_SCM_PRETEND_VERSION_FOR_DUCKDB" in os.environ - @patch.dict('os.environ', {'OVERRIDE_GIT_DESCRIBE': 'invalid-format'}) + @patch.dict("os.environ", {"OVERRIDE_GIT_DESCRIBE": "invalid-format"}) def test_override_git_describe_invalid(self): """Test OVERRIDE_GIT_DESCRIBE with invalid format.""" with pytest.raises(ValueError, match="Invalid git describe override"): diff --git a/tests/fast/test_windows_abs_path.py b/tests/fast/test_windows_abs_path.py index bc9f05ec..4ce8311b 100644 --- a/tests/fast/test_windows_abs_path.py +++ b/tests/fast/test_windows_abs_path.py @@ -6,15 +6,15 @@ class TestWindowsAbsPath(object): def test_windows_path_accent(self): - if os.name != 'nt': + if os.name != "nt": return current_directory = os.getcwd() - test_dir = os.path.join(current_directory, 'tést') + test_dir = os.path.join(current_directory, "tést") if os.path.isdir(test_dir): shutil.rmtree(test_dir) os.mkdir(test_dir) - dbname = 'test.db' + dbname = "test.db" dbpath = os.path.join(test_dir, dbname) con = duckdb.connect(dbpath) con.execute("CREATE OR REPLACE TABLE int AS SELECT * FROM range(10) t(i)") @@ -23,8 +23,8 @@ def test_windows_path_accent(self): del res del con - os.chdir('tést') - dbpath = os.path.join('..', dbpath) + os.chdir("tést") + dbpath = os.path.join("..", dbpath) con = duckdb.connect(dbpath) res = con.execute("SELECT COUNT(*) FROM int").fetchall() assert res[0][0] == 10 @@ -37,13 +37,13 @@ def test_windows_path_accent(self): del res del con - os.chdir('..') + os.chdir("..") def test_windows_abs_path(self): - if os.name != 'nt': + if os.name != "nt": return current_directory = os.getcwd() - dbpath = os.path.join(current_directory, 'test.db') + dbpath = os.path.join(current_directory, "test.db") con = duckdb.connect(dbpath) con.execute("CREATE OR REPLACE TABLE int AS SELECT * FROM range(10) t(i)") res = con.execute("SELECT COUNT(*) FROM int").fetchall() @@ -51,7 +51,7 @@ def test_windows_abs_path(self): del res del con - assert dbpath[1] == ':' + assert dbpath[1] == ":" # remove the drive separator and reconnect dbpath = dbpath[2:] con = duckdb.connect(dbpath) @@ -61,7 +61,7 @@ def test_windows_abs_path(self): del con # forward slashes work as well - dbpath = dbpath.replace('\\', '/') + dbpath = dbpath.replace("\\", "/") con = duckdb.connect(dbpath) res = con.execute("SELECT COUNT(*) FROM int").fetchall() assert res[0][0] == 10 diff --git a/tests/fast/types/test_blob.py b/tests/fast/types/test_blob.py index 162859d2..0d331f7f 100644 --- a/tests/fast/types/test_blob.py +++ b/tests/fast/types/test_blob.py @@ -6,8 +6,8 @@ class TestBlob(object): def test_blob(self, duckdb_cursor): duckdb_cursor.execute("SELECT BLOB 'hello'") results = duckdb_cursor.fetchall() - assert results[0][0] == b'hello' + assert results[0][0] == b"hello" duckdb_cursor.execute("SELECT BLOB 'hello' AS a") results = duckdb_cursor.fetchnumpy() - assert results['a'] == numpy.array([b'hello'], dtype=object) + assert results["a"] == numpy.array([b"hello"], dtype=object) diff --git a/tests/fast/types/test_datetime_datetime.py b/tests/fast/types/test_datetime_datetime.py index 08a9953d..2df14b18 100644 --- a/tests/fast/types/test_datetime_datetime.py +++ b/tests/fast/types/test_datetime_datetime.py @@ -4,29 +4,29 @@ def create_query(positive, type): - inf = 'infinity' if positive else '-infinity' + inf = "infinity" if positive else "-infinity" return f""" select '{inf}'::{type} """ class TestDateTimeDateTime(object): - @pytest.mark.parametrize('positive', [True, False]) + @pytest.mark.parametrize("positive", [True, False]) @pytest.mark.parametrize( - 'type', + "type", [ - 'TIMESTAMP', - 'TIMESTAMP_S', - 'TIMESTAMP_MS', - 'TIMESTAMP_NS', - 'TIMESTAMPTZ', - 'TIMESTAMP_US', + "TIMESTAMP", + "TIMESTAMP_S", + "TIMESTAMP_MS", + "TIMESTAMP_NS", + "TIMESTAMPTZ", + "TIMESTAMP_US", ], ) def test_timestamp_infinity(self, positive, type): con = duckdb.connect() - if type in ['TIMESTAMP_S', 'TIMESTAMP_MS', 'TIMESTAMP_NS']: + if type in ["TIMESTAMP_S", "TIMESTAMP_MS", "TIMESTAMP_NS"]: # Infinity (both positive and negative) is not supported for non-usecond timetamps return diff --git a/tests/fast/types/test_decimal.py b/tests/fast/types/test_decimal.py index 30cb13e7..b068056d 100644 --- a/tests/fast/types/test_decimal.py +++ b/tests/fast/types/test_decimal.py @@ -6,21 +6,21 @@ class TestDecimal(object): def test_decimal(self, duckdb_cursor): duckdb_cursor.execute( - 'SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL' + "SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL" ) result = duckdb_cursor.fetchall() assert result == [ - (Decimal('1.2'), Decimal('100.3'), Decimal('320938.4298'), Decimal('49082094824.904820482094'), None) + (Decimal("1.2"), Decimal("100.3"), Decimal("320938.4298"), Decimal("49082094824.904820482094"), None) ] def test_decimal_numpy(self, duckdb_cursor): duckdb_cursor.execute( - 'SELECT 1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d' + "SELECT 1.2::DECIMAL(4,1) AS a, 100.3::DECIMAL(9,1) AS b, 320938.4298::DECIMAL(18,4) AS c, 49082094824.904820482094::DECIMAL(30,12) AS d" ) result = duckdb_cursor.fetchnumpy() assert result == { - 'a': numpy.array([1.2]), - 'b': numpy.array([100.3]), - 'c': numpy.array([320938.4298]), - 'd': numpy.array([49082094824.904820482094]), + "a": numpy.array([1.2]), + "b": numpy.array([100.3]), + "c": numpy.array([320938.4298]), + "d": numpy.array([49082094824.904820482094]), } diff --git a/tests/fast/types/test_hugeint.py b/tests/fast/types/test_hugeint.py index f0254380..e9b5016a 100644 --- a/tests/fast/types/test_hugeint.py +++ b/tests/fast/types/test_hugeint.py @@ -4,11 +4,11 @@ class TestHugeint(object): def test_hugeint(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 437894723897234238947043214') + duckdb_cursor.execute("SELECT 437894723897234238947043214") result = duckdb_cursor.fetchall() assert result == [(437894723897234238947043214,)] def test_hugeint_numpy(self, duckdb_cursor): - duckdb_cursor.execute('SELECT 1::HUGEINT AS i') + duckdb_cursor.execute("SELECT 1::HUGEINT AS i") result = duckdb_cursor.fetchnumpy() - assert result == {'i': numpy.array([1.0])} + assert result == {"i": numpy.array([1.0])} diff --git a/tests/fast/types/test_nan.py b/tests/fast/types/test_nan.py index b714ae6c..fe99a990 100644 --- a/tests/fast/types/test_nan.py +++ b/tests/fast/types/test_nan.py @@ -15,34 +15,34 @@ def test_pandas_nan(self, duckdb_cursor): # now create a new column with the current time # (FIXME: we replace the microseconds with 0 for now, because we only support millisecond resolution) current_time = datetime.datetime.now().replace(microsecond=0) - df['datetest'] = current_time + df["datetest"] = current_time # introduce a NaT (Not a Time value) - df.loc[0, 'datetest'] = pandas.NaT + df.loc[0, "datetest"] = pandas.NaT # now pass the DF through duckdb: - conn = duckdb.connect(':memory:') - conn.register('testing_null_values', df) + conn = duckdb.connect(":memory:") + conn.register("testing_null_values", df) # scan the DF and fetch the results normally - results = conn.execute('select * from testing_null_values').fetchall() - assert results[0][0] == 'val1' + results = conn.execute("select * from testing_null_values").fetchall() + assert results[0][0] == "val1" assert results[0][1] == 1.05 assert results[0][2] == None assert results[0][3] == None - assert results[1][0] == 'val3' + assert results[1][0] == "val3" assert results[1][1] == None - assert results[1][2] == 'val3' + assert results[1][2] == "val3" assert results[1][3] == current_time # now fetch the results as numpy: - result_np = conn.execute('select * from testing_null_values').fetchnumpy() - assert result_np['col1'][0] == df['col1'][0] - assert result_np['col1'][1] == df['col1'][1] - assert result_np['col2'][0] == df['col2'][0] + result_np = conn.execute("select * from testing_null_values").fetchnumpy() + assert result_np["col1"][0] == df["col1"][0] + assert result_np["col1"][1] == df["col1"][1] + assert result_np["col2"][0] == df["col2"][0] - assert result_np['col2'].mask[1] - assert result_np['newcol1'].mask[0] - assert result_np['newcol1'][1] == df['newcol1'][1] + assert result_np["col2"].mask[1] + assert result_np["newcol1"].mask[0] + assert result_np["newcol1"][1] == df["newcol1"][1] - result_df = conn.execute('select * from testing_null_values').fetchdf() - assert pandas.isnull(result_df['datetest'][0]) - assert result_df['datetest'][1] == df['datetest'][1] + result_df = conn.execute("select * from testing_null_values").fetchdf() + assert pandas.isnull(result_df["datetest"][0]) + assert result_df["datetest"][1] == df["datetest"][1] diff --git a/tests/fast/types/test_nested.py b/tests/fast/types/test_nested.py index e005b3f3..7f777384 100644 --- a/tests/fast/types/test_nested.py +++ b/tests/fast/types/test_nested.py @@ -23,24 +23,24 @@ def test_nested_lists(self, duckdb_cursor): def test_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := 43)").fetchall() - assert result == [({'a': 42, 'b': 43},)] + assert result == [({"a": 42, "b": 43},)] result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := NULL)").fetchall() - assert result == [({'a': 42, 'b': None},)] + assert result == [({"a": 42, "b": None},)] def test_unnamed_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT row('aa','bb') AS x").fetchall() - assert result == [(('aa', 'bb'),)] + assert result == [(("aa", "bb"),)] result = duckdb_cursor.execute("SELECT row('aa',NULL) AS x").fetchall() - assert result == [(('aa', None),)] + assert result == [(("aa", None),)] def test_nested_struct(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, 7))").fetchall() - assert result == [({'a': 42, 'b': [10, 9, 8, 7]},)] + assert result == [({"a": 42, "b": [10, 9, 8, 7]},)] result = duckdb_cursor.execute("SELECT STRUCT_PACK(a := 42, b := LIST_VALUE(10, 9, 8, NULL))").fetchall() - assert result == [({'a': 42, 'b': [10, 9, 8, None]},)] + assert result == [({"a": 42, "b": [10, 9, 8, None]},)] def test_map(self, duckdb_cursor): result = duckdb_cursor.execute("select MAP(LIST_VALUE(1, 2, 3, 4),LIST_VALUE(10, 9, 8, 7))").fetchall() diff --git a/tests/fast/types/test_numpy.py b/tests/fast/types/test_numpy.py index 42ae33a0..40b1a5de 100644 --- a/tests/fast/types/test_numpy.py +++ b/tests/fast/types/test_numpy.py @@ -11,7 +11,7 @@ def test_numpy_datetime64(self, duckdb_cursor): duckdb_con.execute("create table tbl(col TIMESTAMP)") duckdb_con.execute( "insert into tbl VALUES (CAST(? AS TIMESTAMP WITHOUT TIME ZONE))", - parameters=[np.datetime64('2022-02-08T06:01:38.761310')], + parameters=[np.datetime64("2022-02-08T06:01:38.761310")], ) assert [(datetime.datetime(2022, 2, 8, 6, 1, 38, 761310),)] == duckdb_con.execute( "select * from tbl" @@ -24,11 +24,11 @@ def test_numpy_datetime_big(self): duckdb_con.execute("INSERT INTO TEST VALUES ('2263-02-28')") res1 = duckdb_con.execute("select * from test").fetchnumpy() - date_value = {'date': np.array(['2263-02-28'], dtype='datetime64[us]')} + date_value = {"date": np.array(["2263-02-28"], dtype="datetime64[us]")} assert res1 == date_value def test_numpy_enum_conversion(self, duckdb_cursor): - arr = np.array(['a', 'b', 'c']) + arr = np.array(["a", "b", "c"]) rel = duckdb_cursor.sql("select * from arr") - res = rel.fetchnumpy()['column0'] + res = rel.fetchnumpy()["column0"] np.testing.assert_equal(res, arr) diff --git a/tests/fast/types/test_object_int.py b/tests/fast/types/test_object_int.py index ce153d49..ed3a8d14 100644 --- a/tests/fast/types/test_object_int.py +++ b/tests/fast/types/test_object_int.py @@ -12,19 +12,19 @@ def test_object_integer(self, duckdb_cursor): pd = pytest.importorskip("pandas") df_in = pd.DataFrame( { - 'int8': pd.Series([None, 1, -1], dtype="Int8"), - 'int16': pd.Series([None, 1, -1], dtype="Int16"), - 'int32': pd.Series([None, 1, -1], dtype="Int32"), - 'int64': pd.Series([None, 1, -1], dtype="Int64"), + "int8": pd.Series([None, 1, -1], dtype="Int8"), + "int16": pd.Series([None, 1, -1], dtype="Int16"), + "int32": pd.Series([None, 1, -1], dtype="Int32"), + "int64": pd.Series([None, 1, -1], dtype="Int64"), } ) - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) df_expected_res = pd.DataFrame( { - 'int8': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int8'), - 'int16': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int16'), - 'int32': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int32'), - 'int64': pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype='Int64'), + "int8": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int8"), + "int16": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int16"), + "int32": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int32"), + "int64": pd.Series(np.ma.masked_array([0, 1, -1], mask=[True, False, False]), dtype="Int64"), } ) df_out = duckdb.query_df(df_in, "data", "SELECT * FROM data").df() @@ -37,22 +37,22 @@ def test_object_uinteger(self, duckdb_cursor): with suppress(TypeError): df_in = pd.DataFrame( { - 'uint8': pd.Series([None, 1, 255], dtype="UInt8"), - 'uint16': pd.Series([None, 1, 65535], dtype="UInt16"), - 'uint32': pd.Series([None, 1, 4294967295], dtype="UInt32"), - 'uint64': pd.Series([None, 1, 18446744073709551615], dtype="UInt64"), + "uint8": pd.Series([None, 1, 255], dtype="UInt8"), + "uint16": pd.Series([None, 1, 65535], dtype="UInt16"), + "uint32": pd.Series([None, 1, 4294967295], dtype="UInt32"), + "uint64": pd.Series([None, 1, 18446744073709551615], dtype="UInt64"), } ) - warnings.simplefilter(action='ignore', category=RuntimeWarning) + warnings.simplefilter(action="ignore", category=RuntimeWarning) df_expected_res = pd.DataFrame( { - 'uint8': pd.Series(np.ma.masked_array([0, 1, 255], mask=[True, False, False]), dtype='UInt8'), - 'uint16': pd.Series(np.ma.masked_array([0, 1, 65535], mask=[True, False, False]), dtype='UInt16'), - 'uint32': pd.Series( - np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype='UInt32' + "uint8": pd.Series(np.ma.masked_array([0, 1, 255], mask=[True, False, False]), dtype="UInt8"), + "uint16": pd.Series(np.ma.masked_array([0, 1, 65535], mask=[True, False, False]), dtype="UInt16"), + "uint32": pd.Series( + np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype="UInt32" ), - 'uint64': pd.Series( - np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype='UInt64' + "uint64": pd.Series( + np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype="UInt64" ), } ) @@ -63,20 +63,20 @@ def test_object_uinteger(self, duckdb_cursor): # Unsigned Masked float/double types def test_object_float(self, duckdb_cursor): # Require pandas 1.2.0 >= for this, because Float32|Float64 was not added before this version - pd = pytest.importorskip("pandas", '1.2.0') + pd = pytest.importorskip("pandas", "1.2.0") df_in = pd.DataFrame( { - 'float32': pd.Series([None, 1, 4294967295], dtype="Float32"), - 'float64': pd.Series([None, 1, 18446744073709551615], dtype="Float64"), + "float32": pd.Series([None, 1, 4294967295], dtype="Float32"), + "float64": pd.Series([None, 1, 18446744073709551615], dtype="Float64"), } ) df_expected_res = pd.DataFrame( { - 'float32': pd.Series( - np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype='float32' + "float32": pd.Series( + np.ma.masked_array([0, 1, 4294967295], mask=[True, False, False]), dtype="float32" ), - 'float64': pd.Series( - np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype='float64' + "float64": pd.Series( + np.ma.masked_array([0, 1, 18446744073709551615], mask=[True, False, False]), dtype="float64" ), } ) diff --git a/tests/fast/types/test_time_tz.py b/tests/fast/types/test_time_tz.py index 66475df8..eceed79a 100644 --- a/tests/fast/types/test_time_tz.py +++ b/tests/fast/types/test_time_tz.py @@ -11,7 +11,7 @@ class TestTimeTz(object): def test_time_tz(self, duckdb_cursor): df = pandas.DataFrame({"col1": [time(1, 2, 3, tzinfo=timezone.utc)]}) - sql = f'SELECT * FROM df' + sql = f"SELECT * FROM df" duckdb_cursor.execute(sql) diff --git a/tests/fast/types/test_unsigned.py b/tests/fast/types/test_unsigned.py index 6ac50727..a35a2216 100644 --- a/tests/fast/types/test_unsigned.py +++ b/tests/fast/types/test_unsigned.py @@ -1,7 +1,7 @@ class TestUnsigned(object): def test_unsigned(self, duckdb_cursor): - duckdb_cursor.execute('create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)') - duckdb_cursor.execute('insert into unsigned values (1,1,1,1), (null,null,null,null)') - duckdb_cursor.execute('select * from unsigned order by a nulls first') + duckdb_cursor.execute("create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)") + duckdb_cursor.execute("insert into unsigned values (1,1,1,1), (null,null,null,null)") + duckdb_cursor.execute("select * from unsigned order by a nulls first") result = duckdb_cursor.fetchall() assert result == [(None, None, None, None), (1, 1, 1, 1)] diff --git a/tests/fast/udf/test_null_filtering.py b/tests/fast/udf/test_null_filtering.py index 208a9246..fd5b45d0 100644 --- a/tests/fast/udf/test_null_filtering.py +++ b/tests/fast/udf/test_null_filtering.py @@ -2,7 +2,7 @@ import pytest pd = pytest.importorskip("pandas") -pa = pytest.importorskip('pyarrow', '18.0.0') +pa = pytest.importorskip("pyarrow", "18.0.0") from typing import Union import pyarrow.compute as pc import uuid @@ -22,11 +22,11 @@ class Candidate(NamedTuple): def layout(index: int): return [ - ['x', 'x', 'y'], - ['x', None, 'y'], - [None, 'y', None], - ['x', None, None], - [None, None, 'y'], + ["x", "x", "y"], + ["x", None, "y"], + [None, "y", None], + ["x", None, None], + [None, None, "y"], [None, None, None], ][index] @@ -36,14 +36,14 @@ def add_variations(data, index: int): data.extend( [ { - 'a': layout(index), - 'b': layout(0), - 'c': layout(0), + "a": layout(index), + "b": layout(0), + "c": layout(0), }, { - 'a': layout(0), - 'b': layout(0), - 'c': layout(index), + "a": layout(0), + "b": layout(0), + "c": layout(index), }, ] ) @@ -83,9 +83,9 @@ def get_types(): 2147483647, ), Candidate(UBIGINT, 18446744073709551615, 9223372036854776000), - Candidate(VARCHAR, 'long_string_test', 'smallstring'), + Candidate(VARCHAR, "long_string_test", "smallstring"), Candidate( - UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff'), uuid.UUID('ffffffff-ffff-ffff-ffff-000000000000') + UUID, uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"), uuid.UUID("ffffffff-ffff-ffff-ffff-000000000000") ), Candidate( FLOAT, @@ -106,8 +106,8 @@ def get_types(): ), Candidate( BLOB, - b'\xf6\x96\xb0\x85', - b'\x85\xb0\x96\xf6', + b"\xf6\x96\xb0\x85", + b"\x85\xb0\x96\xf6", ), Candidate( INTERVAL, @@ -120,24 +120,24 @@ def get_types(): False, ), Candidate( - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}, - {'v1': [5, 4, 3, 2, 1], 'v2': ['non-inlined-string', 'a', 'b', 'c', 'duckdb']}, + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + {"v1": [1, 2, 3], "v2": ["a", "non-inlined string", "duckdb"]}, + {"v1": [5, 4, 3, 2, 1], "v2": ["non-inlined-string", "a", "b", "c", "duckdb"]}, ), - Candidate(duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string'], ['non-inlined-string', 'test']), + Candidate(duckdb.list_type("VARCHAR"), ["the", "duck", "non-inlined string"], ["non-inlined-string", "test"]), ] def construct_query(tuples) -> str: def construct_values_list(row, start_param_idx): parameter_count = len(row) - parameters = [f'${x+start_param_idx}' for x in range(parameter_count)] - parameters = '(' + ', '.join(parameters) + ')' + parameters = [f"${x + start_param_idx}" for x in range(parameter_count)] + parameters = "(" + ", ".join(parameters) + ")" return parameters row_size = len(tuples[0]) values_list = [construct_values_list(x, 1 + (i * row_size)) for i, x in enumerate(tuples)] - values_list = ', '.join(values_list) + values_list = ", ".join(values_list) query = f""" select * from (values {values_list}) @@ -154,19 +154,19 @@ def construct_parameters(tuples, dbtype): class TestUDFNullFiltering(object): @pytest.mark.parametrize( - 'table_data', + "table_data", get_table_data(), ) @pytest.mark.parametrize( - 'test_type', + "test_type", get_types(), ) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_null_filtering(self, duckdb_cursor, table_data: dict, test_type: Candidate, udf_type): null_count = sum([1 for x in list(zip(*table_data.values())) if any([y == None for y in x])]) row_count = len(table_data) table_data = { - key: [None if not x else test_type.variant_one if x == 'x' else test_type.variant_two for x in value] + key: [None if not x else test_type.variant_one if x == "x" else test_type.variant_two for x in value] for key, value in table_data.items() } @@ -174,21 +174,21 @@ def test_null_filtering(self, duckdb_cursor, table_data: dict, test_type: Candid query = construct_query(tuples) parameters = construct_parameters(tuples, test_type.type) rel = duckdb_cursor.sql(query + " t(a, b, c)", params=parameters) - rel.to_table('tbl') + rel.to_table("tbl") rel.show() def my_func(*args): - if udf_type == 'arrow': + if udf_type == "arrow": my_func.count += len(args[0]) else: my_func.count += 1 return args[0] def create_parameters(table_data, dbtype): - return ", ".join(f'{key}::{dbtype}' for key in list(table_data.keys())) + return ", ".join(f"{key}::{dbtype}" for key in list(table_data.keys())) my_func.count = 0 - duckdb_cursor.create_function('test', my_func, None, test_type.type, type=udf_type) + duckdb_cursor.create_function("test", my_func, None, test_type.type, type=udf_type) query = f"select test({create_parameters(table_data, test_type.type)}) from tbl" result = duckdb_cursor.sql(query).fetchall() @@ -201,7 +201,7 @@ def create_parameters(table_data, dbtype): assert my_func.count == row_count - null_count @pytest.mark.parametrize( - 'table_data', + "table_data", [ [1, 2, 3, 4], [1, 2, None, 4], @@ -211,14 +211,14 @@ def test_nulls_from_default_null_handling_native(self, duckdb_cursor, table_data def returns_null(x): return None - df = pd.DataFrame({'a': table_data}) + df = pd.DataFrame({"a": table_data}) duckdb_cursor.execute("create table tbl as select * from df") - duckdb_cursor.create_function('test', returns_null, [str], int, type='native') - with pytest.raises(duckdb.InvalidInputException, match='The UDF is not expected to return NULL values'): + duckdb_cursor.create_function("test", returns_null, [str], int, type="native") + with pytest.raises(duckdb.InvalidInputException, match="The UDF is not expected to return NULL values"): result = duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() @pytest.mark.parametrize( - 'table_data', + "table_data", [ [1, 2, 3, 4], [1, 2, None, 4], @@ -229,9 +229,9 @@ def returns_null(x): l = x.to_pylist() return pa.array([None for _ in l], type=pa.int64()) - df = pd.DataFrame({'a': table_data}) + df = pd.DataFrame({"a": table_data}) duckdb_cursor.execute("create table tbl as select * from df") - duckdb_cursor.create_function('test', returns_null, [str], int, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='The UDF is not expected to return NULL values'): + duckdb_cursor.create_function("test", returns_null, [str], int, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="The UDF is not expected to return NULL values"): result = duckdb_cursor.sql("select test(a::VARCHAR) from tbl").fetchall() print(result) diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index e67045c4..d03fd7e6 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -21,37 +21,37 @@ def test_not_created(self): duckdb.InvalidInputException, match="No function by the name of 'not_a_registered_function' was found in the list of registered functions", ): - con.remove_function('not_a_registered_function') + con.remove_function("not_a_registered_function") def test_double_remove(self): def func(x: int) -> int: return x con = duckdb.connect() - con.create_function('func', func) - con.sql('select func(42)') - con.remove_function('func') + con.create_function("func", func) + con.sql("select func(42)") + con.remove_function("func") with pytest.raises( duckdb.InvalidInputException, match="No function by the name of 'func' was found in the list of registered functions", ): - con.remove_function('func') + con.remove_function("func") - with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): - con.sql('select func(42)') + with pytest.raises(duckdb.CatalogException, match="Scalar Function with name func does not exist!"): + con.sql("select func(42)") def test_use_after_remove(self): def func(x: int) -> int: return x con = duckdb.connect() - con.create_function('func', func) - rel = con.sql('select func(42)') - con.remove_function('func') + con.create_function("func", func) + rel = con.sql("select func(42)") + con.remove_function("func") """ Error: Catalog Error: Scalar Function with name func does not exist! """ - with pytest.raises(duckdb.CatalogException, match='Scalar Function with name func does not exist!'): + with pytest.raises(duckdb.CatalogException, match="Scalar Function with name func does not exist!"): res = rel.fetchall() def test_use_after_remove_and_recreation(self): @@ -59,18 +59,18 @@ def func(x: str) -> str: return x con = duckdb.connect() - con.create_function('func', func) + con.create_function("func", func) - with pytest.raises(duckdb.BinderException, match='No function matches the given name'): - rel1 = con.sql('select func(42)') + with pytest.raises(duckdb.BinderException, match="No function matches the given name"): + rel1 = con.sql("select func(42)") rel2 = con.sql("select func('test'::VARCHAR)") - con.remove_function('func') + con.remove_function("func") def also_func(x: int) -> int: return x - con.create_function('func', also_func) - with pytest.raises(duckdb.BinderException, match='No function matches the given name'): + con.create_function("func", also_func) + with pytest.raises(duckdb.BinderException, match="No function matches the given name"): res = rel2.fetchall() def test_overwrite_name(self): @@ -79,7 +79,7 @@ def func(x): con = duckdb.connect() # create first version of the function - con.create_function('func', func, [BIGINT], BIGINT) + con.create_function("func", func, [BIGINT], BIGINT) # create relation that uses the function rel1 = con.sql("select func('3')") @@ -91,17 +91,17 @@ def other_func(x): duckdb.NotImplementedException, match="A function by the name of 'func' is already created, creating multiple functions with the same name is not supported yet, please remove it first", ): - con.create_function('func', other_func, [VARCHAR], VARCHAR) + con.create_function("func", other_func, [VARCHAR], VARCHAR) - con.remove_function('func') + con.remove_function("func") with pytest.raises( - duckdb.CatalogException, match='Catalog Error: Scalar Function with name func does not exist!' + duckdb.CatalogException, match="Catalog Error: Scalar Function with name func does not exist!" ): # Attempted to execute the relation using the 'func' function, but it was deleted rel1.fetchall() - con.create_function('func', other_func, [VARCHAR], VARCHAR) + con.create_function("func", other_func, [VARCHAR], VARCHAR) # create relation that uses the new version rel2 = con.sql("select func('test')") @@ -109,5 +109,5 @@ def other_func(x): res1 = rel1.fetchall() res2 = rel2.fetchall() # This has been converted to string, because the previous version of the function no longer exists - assert res1 == [('3',)] - assert res2 == [('test',)] + assert res1 == [("3",)] + assert res2 == [("test",)] diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index 8e0eb8b1..c156f94b 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -3,7 +3,7 @@ import pytest pd = pytest.importorskip("pandas") -pa = pytest.importorskip('pyarrow', '18.0.0') +pa = pytest.importorskip("pyarrow", "18.0.0") from typing import Union, Any import pyarrow.compute as pc import uuid @@ -25,14 +25,14 @@ def test_base(x): test_base.__code__, test_base.__globals__, test_base.__name__, test_base.__defaults__, test_base.__closure__ ) # Add annotations for the return type and 'x' - test_function.__annotations__ = {'return': type, 'x': type} + test_function.__annotations__ = {"return": type, "x": type} return test_function class TestScalarUDF(object): - @pytest.mark.parametrize('function_type', ['native', 'arrow']) + @pytest.mark.parametrize("function_type", ["native", "arrow"]) @pytest.mark.parametrize( - 'test_type', + "test_type", [ (TINYINT, -42), (SMALLINT, -512), @@ -43,21 +43,21 @@ class TestScalarUDF(object): (UINTEGER, 4294967295), (UBIGINT, 18446744073709551615), (HUGEINT, 18446744073709551616), - (VARCHAR, 'long_string_test'), - (UUID, uuid.UUID('ffffffff-ffff-ffff-ffff-ffffffffffff')), + (VARCHAR, "long_string_test"), + (UUID, uuid.UUID("ffffffff-ffff-ffff-ffff-ffffffffffff")), (FLOAT, 0.12246409803628922), (DOUBLE, 123142.12312416293784721232344), (DATE, datetime.date(2005, 3, 11)), (TIMESTAMP, datetime.datetime(2009, 2, 13, 11, 5, 53)), (TIME, datetime.time(14, 1, 12)), - (BLOB, b'\xf6\x96\xb0\x85'), + (BLOB, b"\xf6\x96\xb0\x85"), (INTERVAL, datetime.timedelta(days=30969, seconds=999, microseconds=999999)), (BOOLEAN, True), ( - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - {'v1': [1, 2, 3], 'v2': ['a', 'non-inlined string', 'duckdb']}, + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + {"v1": [1, 2, 3], "v2": ["a", "non-inlined string", "duckdb"]}, ), - (duckdb.list_type('VARCHAR'), ['the', 'duck', 'non-inlined string']), + (duckdb.list_type("VARCHAR"), ["the", "duck", "non-inlined string"]), ], ) def test_type_coverage(self, test_type, function_type): @@ -67,7 +67,7 @@ def test_type_coverage(self, test_type, function_type): test_function = make_annotated_function(type) con = duckdb.connect() - con.create_function('test', test_function, type=function_type) + con.create_function("test", test_function, type=function_type) # Single value res = con.execute(f"select test(?::{str(type)})", [value]).fetchall() assert res[0][0] == value @@ -114,46 +114,46 @@ def test_type_coverage(self, test_type, function_type): # Using 'relation.project' con.execute(f"create table tbl as select ?::{str(type)} as x", [value]) - table_rel = con.table('tbl') - res = table_rel.project('test(x)').fetchall() + table_rel = con.table("tbl") + res = table_rel.project("test(x)").fetchall() assert res[0][0] == value - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_map_coverage(self, udf_type): def no_op(x): return x con = duckdb.connect() - map_type = con.map_type('VARCHAR', 'BIGINT') - con.create_function('test_map', no_op, [map_type], map_type, type=udf_type) + map_type = con.map_type("VARCHAR", "BIGINT") + con.create_function("test_map", no_op, [map_type], map_type, type=udf_type) rel = con.sql("select test_map(map(['non-inlined string', 'test', 'duckdb'], [42, 1337, 123]))") res = rel.fetchall() - assert res == [({'non-inlined string': 42, 'test': 1337, 'duckdb': 123},)] + assert res == [({"non-inlined string": 42, "test": 1337, "duckdb": 123},)] - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_exceptions(self, udf_type): def raises_exception(x): raise AttributeError("error") con = duckdb.connect() - con.create_function('raises', raises_exception, [BIGINT], BIGINT, type=udf_type) + con.create_function("raises", raises_exception, [BIGINT], BIGINT, type=udf_type) with pytest.raises( duckdb.InvalidInputException, - match=' Python exception occurred while executing the UDF: AttributeError: error', + match=" Python exception occurred while executing the UDF: AttributeError: error", ): - res = con.sql('select raises(3)').fetchall() + res = con.sql("select raises(3)").fetchall() - con.remove_function('raises') + con.remove_function("raises") con.create_function( - 'raises', raises_exception, [BIGINT], BIGINT, exception_handling='return_null', type=udf_type + "raises", raises_exception, [BIGINT], BIGINT, exception_handling="return_null", type=udf_type ) - res = con.sql('select raises(3) from range(5)').fetchall() + res = con.sql("select raises(3) from range(5)").fetchall() assert res == [(None,), (None,), (None,), (None,), (None,)] def test_non_callable(self): con = duckdb.connect() with pytest.raises(TypeError): - con.create_function('func', 5, [BIGINT], BIGINT, type='arrow') + con.create_function("func", 5, [BIGINT], BIGINT, type="arrow") class MyCallable: def __init__(self) -> None: @@ -163,22 +163,22 @@ def __call__(self, x) -> Any: return x my_callable = MyCallable() - con.create_function('func', my_callable, [BIGINT], BIGINT, type='arrow') - res = con.sql('select func(5)').fetchall() + con.create_function("func", my_callable, [BIGINT], BIGINT, type="arrow") + res = con.sql("select func(5)").fetchall() assert res == [(5,)] # pyarrow does not support creating an array filled with pd.NA values - @pytest.mark.parametrize('udf_type', ['native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_pd_nan(self, duckdb_type, udf_type): def return_pd_nan(): - if udf_type == 'native': + if udf_type == "native": return pd.NA con = duckdb.connect() - con.create_function('return_pd_nan', return_pd_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type) + con.create_function("return_pd_nan", return_pd_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type) - res = con.sql('select return_pd_nan()').fetchall() + res = con.sql("select return_pd_nan()").fetchall() assert res[0][0] == None def test_side_effects(self): @@ -190,21 +190,21 @@ def count() -> int: count.counter = 0 con = duckdb.connect() - con.create_function('my_counter', count, side_effects=False) - res = con.sql('select my_counter() from range(10)').fetchall() + con.create_function("my_counter", count, side_effects=False) + res = con.sql("select my_counter() from range(10)").fetchall() assert res == [(0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,), (0,)] count.counter = 0 - con.remove_function('my_counter') - con.create_function('my_counter', count, side_effects=True) - res = con.sql('select my_counter() from range(10)').fetchall() + con.remove_function("my_counter") + con.create_function("my_counter", count, side_effects=True) + res = con.sql("select my_counter() from range(10)").fetchall() assert res == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)] - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_np_nan(self, duckdb_type, udf_type): def return_np_nan(): - if udf_type == 'native': + if udf_type == "native": return np.nan else: import pyarrow as pa @@ -212,18 +212,18 @@ def return_np_nan(): return pa.chunked_array([[np.nan]], type=pa.float64()) con = duckdb.connect() - con.create_function('return_np_nan', return_np_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type) + con.create_function("return_np_nan", return_np_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type) - res = con.sql('select return_np_nan()').fetchall() + res = con.sql("select return_np_nan()").fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) - @pytest.mark.parametrize('duckdb_type', [FLOAT, DOUBLE]) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) + @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): import cmath - if udf_type == 'native': + if udf_type == "native": return cmath.nan else: import pyarrow as pa @@ -232,15 +232,15 @@ def return_math_nan(): con = duckdb.connect() con.create_function( - 'return_math_nan', return_math_nan, None, duckdb_type, null_handling='SPECIAL', type=udf_type + "return_math_nan", return_math_nan, None, duckdb_type, null_handling="SPECIAL", type=udf_type ) - res = con.sql('select return_math_nan()').fetchall() + res = con.sql("select return_math_nan()").fetchall() assert pd.isnull(res[0][0]) - @pytest.mark.parametrize('udf_type', ['arrow', 'native']) + @pytest.mark.parametrize("udf_type", ["arrow", "native"]) @pytest.mark.parametrize( - 'data_type', + "data_type", [ TINYINT, SMALLINT, @@ -262,13 +262,13 @@ def return_math_nan(): BLOB, INTERVAL, BOOLEAN, - duckdb.struct_type(['BIGINT[]', 'VARCHAR[]']), - duckdb.list_type('VARCHAR'), + duckdb.struct_type(["BIGINT[]", "VARCHAR[]"]), + duckdb.list_type("VARCHAR"), ], ) def test_return_null(self, data_type, udf_type): def return_null(): - if udf_type == 'native': + if udf_type == "native": return None else: import pyarrow as pa @@ -276,8 +276,8 @@ def return_null(): return pa.nulls(1) con = duckdb.connect() - con.create_function('return_null', return_null, None, data_type, null_handling='special', type=udf_type) - rel = con.sql('select return_null() as x') + con.create_function("return_null", return_null, None, data_type, null_handling="special", type=udf_type) + rel = con.sql("select return_null() as x") assert rel.types[0] == data_type assert rel.fetchall()[0][0] == None @@ -286,13 +286,13 @@ def func(x: int) -> int: return x con = duckdb.connect() - rel = con.sql('select 42') + rel = con.sql("select 42") # Using fetchone keeps the result open, with a transaction rel.fetchone() - con.create_function('func', func) + con.create_function("func", func) rel.fetchall() - res = con.sql('select func(5)').fetchall() + res = con.sql("select func(5)").fetchall() assert res == [(5,)] diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 5773c474..794ebc35 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -15,35 +15,35 @@ class TestPyArrowUDF(object): def test_basic_use(self): def plus_one(x): - table = pa.lib.Table.from_arrays([x], names=['c0']) + table = pa.lib.Table.from_arrays([x], names=["c0"]) import pandas as pd df = pd.DataFrame(x.to_pandas()) - df['c0'] = df['c0'] + 1 + df["c0"] = df["c0"] + 1 return pa.lib.Table.from_pandas(df) con = duckdb.connect() - con.create_function('plus_one', plus_one, [BIGINT], BIGINT, type='arrow') - assert [(6,)] == con.sql('select plus_one(5)').fetchall() + con.create_function("plus_one", plus_one, [BIGINT], BIGINT, type="arrow") + assert [(6,)] == con.sql("select plus_one(5)").fetchall() - range_table = con.table_function('range', [5000]) - res = con.sql('select plus_one(i) from range_table tbl(i)').fetchall() + range_table = con.table_function("range", [5000]) + res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() assert len(res) == 5000 vector_size = duckdb.__standard_vector_size__ - res = con.sql(f'select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})') + res = con.sql(f"select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})") assert len(res) == (vector_size * 11) # NOTE: This only works up to duckdb.__standard_vector_size__, # because we process up to STANDARD_VECTOR_SIZE tuples at a time def test_sort_table(self): def sort_table(x): - table = pa.lib.Table.from_arrays([x], names=['c0']) + table = pa.lib.Table.from_arrays([x], names=["c0"]) sorted_table = table.sort_by([("c0", "ascending")]) return sorted_table con = duckdb.connect() - con.create_function('sort_table', sort_table, [BIGINT], BIGINT, type='arrow') + con.create_function("sort_table", sort_table, [BIGINT], BIGINT, type="arrow") res = con.sql("select 100-i as original, sort_table(original) from range(100) tbl(i)").fetchall() assert res[0] == (100, 1) @@ -57,7 +57,7 @@ def variable_args(*args): con = duckdb.connect() # This function takes any number of arguments, returning the first column - con.create_function('varargs', variable_args, None, BIGINT, type='arrow') + con.create_function("varargs", variable_args, None, BIGINT, type="arrow") res = con.sql("""select varargs(5, '3', '2', 1, 0.12345)""").fetchall() assert res == [(5,)] @@ -70,7 +70,7 @@ def takes_string(col): con = duckdb.connect() # The return type of the function is set to BIGINT, but it takes a VARCHAR - con.create_function('pyarrow_string_to_num', takes_string, [VARCHAR], BIGINT, type='arrow') + con.create_function("pyarrow_string_to_num", takes_string, [VARCHAR], BIGINT, type="arrow") # Successful conversion res = con.sql("""select pyarrow_string_to_num('5')""").fetchall() @@ -84,14 +84,14 @@ def returns_two_columns(col): import pandas as pd # Return a pyarrow table consisting of two columns - return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5, 4, 3], 'b': ['test', 'quack', 'duckdb']})) + return pa.lib.Table.from_pandas(pd.DataFrame({"a": [5, 4, 3], "b": ["test", "quack", "duckdb"]})) con = duckdb.connect() # Scalar functions only return a single value per tuple - con.create_function('two_columns', returns_two_columns, [BIGINT], BIGINT, type='arrow') + con.create_function("two_columns", returns_two_columns, [BIGINT], BIGINT, type="arrow") with pytest.raises( duckdb.InvalidInputException, - match='The returned table from a pyarrow scalar udf should only contain one column, found 2', + match="The returned table from a pyarrow scalar udf should only contain one column, found 2", ): res = con.sql("""select two_columns(5)""").fetchall() @@ -100,35 +100,35 @@ def returns_none(col): return None con = duckdb.connect() - con.create_function('will_crash', returns_none, [BIGINT], BIGINT, type='arrow') + con.create_function("will_crash", returns_none, [BIGINT], BIGINT, type="arrow") with pytest.raises(duckdb.Error, match="""Could not convert the result into an Arrow Table"""): res = con.sql("""select will_crash(5)""").fetchall() def test_empty_result(self): def return_empty(col): # Always returns an empty table - return pa.lib.Table.from_arrays([[]], names=['c0']) + return pa.lib.Table.from_arrays([[]], names=["c0"]) con = duckdb.connect() - con.create_function('empty_result', return_empty, [BIGINT], BIGINT, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='Returned pyarrow table should have 1 tuples, found 0'): + con.create_function("empty_result", return_empty, [BIGINT], BIGINT, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="Returned pyarrow table should have 1 tuples, found 0"): res = con.sql("""select empty_result(5)""").fetchall() def test_excessive_result(self): def return_too_many(col): # Always returns a table consisting of 5 tuples - return pa.lib.Table.from_arrays([[5, 4, 3, 2, 1]], names=['c0']) + return pa.lib.Table.from_arrays([[5, 4, 3, 2, 1]], names=["c0"]) con = duckdb.connect() - con.create_function('too_many_tuples', return_too_many, [BIGINT], BIGINT, type='arrow') - with pytest.raises(duckdb.InvalidInputException, match='Returned pyarrow table should have 1 tuples, found 5'): + con.create_function("too_many_tuples", return_too_many, [BIGINT], BIGINT, type="arrow") + with pytest.raises(duckdb.InvalidInputException, match="Returned pyarrow table should have 1 tuples, found 5"): res = con.sql("""select too_many_tuples(5)""").fetchall() def test_arrow_side_effects(self, duckdb_cursor): import random as r def random_arrow(x): - if not hasattr(random_arrow, 'data'): + if not hasattr(random_arrow, "data"): random_arrow.data = 0 input = x.to_pylist() @@ -158,17 +158,17 @@ def return_struct(col): ).fetch_arrow_table() con = duckdb.connect() - struct_type = con.struct_type({'a': BIGINT, 'b': VARCHAR, 'c': con.list_type(BIGINT)}) - con.create_function('return_struct', return_struct, [BIGINT], struct_type, type='arrow') + struct_type = con.struct_type({"a": BIGINT, "b": VARCHAR, "c": con.list_type(BIGINT)}) + con.create_function("return_struct", return_struct, [BIGINT], struct_type, type="arrow") res = con.sql("""select return_struct(5)""").fetchall() - assert res == [({'a': 5, 'b': 'test', 'c': [5, 3, 2]},)] + assert res == [({"a": 5, "b": "test", "c": [5, 3, 2]},)] def test_multiple_chunks(self): def return_unmodified(col): return col con = duckdb.connect() - con.create_function('unmodified', return_unmodified, [BIGINT], BIGINT, type='arrow') + con.create_function("unmodified", return_unmodified, [BIGINT], BIGINT, type="arrow") res = con.sql( """ select unmodified(i) from range(5000) tbl(i) @@ -176,19 +176,19 @@ def return_unmodified(col): ).fetchall() assert len(res) == 5000 - assert res == con.sql('select * from range(5000)').fetchall() + assert res == con.sql("select * from range(5000)").fetchall() def test_inferred(self): def func(x: int) -> int: import pandas as pd - df = pd.DataFrame({'c0': x}) - df['c0'] = df['c0'] ** 2 + df = pd.DataFrame({"c0": x}) + df["c0"] = df["c0"] ** 2 return pa.lib.Table.from_pandas(df) con = duckdb.connect() - con.create_function('inferred', func, type='arrow') - res = con.sql('select inferred(42)').fetchall() + con.create_function("inferred", func, type="arrow") + res = con.sql("select inferred(42)").fetchall() assert res == [(1764,)] def test_nulls(self): @@ -196,27 +196,27 @@ def return_five(x): import pandas as pd length = len(x) - return pa.lib.Table.from_pandas(pd.DataFrame({'a': [5 for _ in range(length)]})) + return pa.lib.Table.from_pandas(pd.DataFrame({"a": [5 for _ in range(length)]})) con = duckdb.connect() - con.create_function('return_five', return_five, [BIGINT], BIGINT, null_handling='special', type='arrow') - res = con.sql('select return_five(NULL) from range(10)').fetchall() + con.create_function("return_five", return_five, [BIGINT], BIGINT, null_handling="special", type="arrow") + res = con.sql("select return_five(NULL) from range(10)").fetchall() # without 'special' null handling these would all be NULL assert res == [(5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,), (5,)] con = duckdb.connect() - con.create_function('return_five', return_five, [BIGINT], BIGINT, null_handling='default', type='arrow') - res = con.sql('select return_five(NULL) from range(10)').fetchall() + con.create_function("return_five", return_five, [BIGINT], BIGINT, null_handling="default", type="arrow") + res = con.sql("select return_five(NULL) from range(10)").fetchall() # Because we didn't specify 'special' null handling, these are all NULL assert res == [(None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,), (None,)] def test_struct_with_non_inlined_string(self, duckdb_cursor): def func(data): - return pa.array([{'x': 1, 'y': 'this is not an inlined string'}] * data.length()) + return pa.array([{"x": 1, "y": "this is not an inlined string"}] * data.length()) duckdb_cursor.create_function( name="func", function=func, return_type="STRUCT(x integer, y varchar)", type="arrow", side_effects=False ) res = duckdb_cursor.sql("select func(1).y").fetchone() - assert res == ('this is not an inlined string',) + assert res == ("this is not an inlined string",) diff --git a/tests/fast/udf/test_scalar_native.py b/tests/fast/udf/test_scalar_native.py index df58f6a4..0c5cf927 100644 --- a/tests/fast/udf/test_scalar_native.py +++ b/tests/fast/udf/test_scalar_native.py @@ -11,8 +11,8 @@ def test_default_conn(self): def passthrough(x): return x - duckdb.create_function('default_conn_passthrough', passthrough, [BIGINT], BIGINT) - res = duckdb.sql('select default_conn_passthrough(5)').fetchall() + duckdb.create_function("default_conn_passthrough", passthrough, [BIGINT], BIGINT) + res = duckdb.sql("select default_conn_passthrough(5)").fetchall() assert res == [(5,)] def test_basic_use(self): @@ -22,15 +22,15 @@ def plus_one(x): return x + 1 con = duckdb.connect() - con.create_function('plus_one', plus_one, [BIGINT], BIGINT) - assert [(6,)] == con.sql('select plus_one(5)').fetchall() + con.create_function("plus_one", plus_one, [BIGINT], BIGINT) + assert [(6,)] == con.sql("select plus_one(5)").fetchall() - range_table = con.table_function('range', [5000]) - res = con.sql('select plus_one(i) from range_table tbl(i)').fetchall() + range_table = con.table_function("range", [5000]) + res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() assert len(res) == 5000 vector_size = duckdb.__standard_vector_size__ - res = con.sql(f'select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})') + res = con.sql(f"select i, plus_one(i) from test_vector_types(NULL::BIGINT, false) t(i), range({vector_size})") assert len(res) == (vector_size * 11) def test_passthrough(self): @@ -38,10 +38,10 @@ def passthrough(x): return x con = duckdb.connect() - con.create_function('passthrough', passthrough, [BIGINT], BIGINT) + con.create_function("passthrough", passthrough, [BIGINT], BIGINT) assert ( - con.sql('select passthrough(i) from range(5000) tbl(i)').fetchall() - == con.sql('select * from range(5000)').fetchall() + con.sql("select passthrough(i) from range(5000) tbl(i)").fetchall() + == con.sql("select * from range(5000)").fetchall() ) def test_execute(self): @@ -49,8 +49,8 @@ def func(x): return x % 2 con = duckdb.connect() - con.create_function('modulo_op', func, [BIGINT], TINYINT) - res = con.execute('select modulo_op(?)', [5]).fetchall() + con.create_function("modulo_op", func, [BIGINT], TINYINT) + res = con.execute("select modulo_op(?)", [5]).fetchall() assert res == [(1,)] def test_cast_output(self): @@ -58,7 +58,7 @@ def takes_string(x): return x con = duckdb.connect() - con.create_function('casts_from_string', takes_string, [VARCHAR], BIGINT) + con.create_function("casts_from_string", takes_string, [VARCHAR], BIGINT) res = con.sql("select casts_from_string('42')").fetchall() assert res == [(42,)] @@ -71,13 +71,13 @@ def concatenate(a: str, b: str): return a + b con = duckdb.connect() - con.create_function('py_concatenate', concatenate, None, VARCHAR) + con.create_function("py_concatenate", concatenate, None, VARCHAR) res = con.sql( """ select py_concatenate('5','3'); """ ).fetchall() - assert res[0][0] == '53' + assert res[0][0] == "53" def test_detected_return_type(self): def add_nums(*args) -> int: @@ -87,7 +87,7 @@ def add_nums(*args) -> int: return sum con = duckdb.connect() - con.create_function('add_nums', add_nums) + con.create_function("add_nums", add_nums) res = con.sql( """ select add_nums(5,3,2,1); @@ -101,20 +101,20 @@ def variable_args(*args): return amount con = duckdb.connect() - con.create_function('varargs', variable_args, None, BIGINT) + con.create_function("varargs", variable_args, None, BIGINT) res = con.sql("""select varargs('5', '3', '2', 1, 0.12345)""").fetchall() assert res == [(5,)] def test_return_incorrectly_typed_object(self): def returns_duckdb() -> int: - return 'duckdb' + return "duckdb" con = duckdb.connect() - con.create_function('fastest_database_in_the_west', returns_duckdb) + con.create_function("fastest_database_in_the_west", returns_duckdb) with pytest.raises( duckdb.InvalidInputException, match="Failed to cast value: Could not convert string 'duckdb' to INT64" ): - res = con.sql('select fastest_database_in_the_west()').fetchall() + res = con.sql("select fastest_database_in_the_west()").fetchall() def test_nulls(self): def five_if_null(x): @@ -123,12 +123,12 @@ def five_if_null(x): return x con = duckdb.connect() - con.create_function('null_test', five_if_null, [BIGINT], BIGINT, null_handling="SPECIAL") - res = con.sql('select null_test(NULL)').fetchall() + con.create_function("null_test", five_if_null, [BIGINT], BIGINT, null_handling="SPECIAL") + res = con.sql("select null_test(NULL)").fetchall() assert res == [(5,)] @pytest.mark.parametrize( - 'pair', + "pair", [ (TINYINT, -129), (TINYINT, 128), @@ -159,26 +159,26 @@ def return_overflow(): return overflowing_value con = duckdb.connect() - con.create_function('return_overflow', return_overflow, None, duckdb_type) + con.create_function("return_overflow", return_overflow, None, duckdb_type) with pytest.raises(duckdb.InvalidInputException): - rel = con.sql('select return_overflow()') + rel = con.sql("select return_overflow()") res = rel.fetchall() print(duckdb_type) print(res) def test_structs(self): def add_extra_column(original): - original['a'] = 200 - original['c'] = 0 + original["a"] = 200 + original["c"] = 0 return original con = duckdb.connect() - range_table = con.table_function('range', [5000]) + range_table = con.table_function("range", [5000]) con.create_function( "append_field", add_extra_column, - [duckdb.struct_type({'a': BIGINT, 'b': BIGINT})], - duckdb.struct_type({'a': BIGINT, 'b': BIGINT, 'c': BIGINT}), + [duckdb.struct_type({"a": BIGINT, "b": BIGINT})], + duckdb.struct_type({"a": BIGINT, "b": BIGINT, "c": BIGINT}), ) res = con.sql( @@ -205,17 +205,17 @@ def swap_keys(dict): return result con.create_function( - 'swap_keys', + "swap_keys", swap_keys, - [con.struct_type({'a': BIGINT, 'b': VARCHAR})], - con.struct_type({'a': VARCHAR, 'b': BIGINT}), + [con.struct_type({"a": BIGINT, "b": VARCHAR})], + con.struct_type({"a": VARCHAR, "b": BIGINT}), ) res = con.sql( """ select swap_keys({'a': 42, 'b': 'answer_to_life'}) """ ).fetchall() - assert res == [({'a': 'answer_to_life', 'b': 42},)] + assert res == [({"a": "answer_to_life", "b": 42},)] def test_struct_different_field_order(self, duckdb_cursor): def example(): diff --git a/tests/fast/udf/test_transactionality.py b/tests/fast/udf/test_transactionality.py index 50286e8e..134df663 100644 --- a/tests/fast/udf/test_transactionality.py +++ b/tests/fast/udf/test_transactionality.py @@ -3,7 +3,7 @@ class TestUDFTransactionality(object): - @pytest.mark.xfail(reason='fetchone() does not realize the stream result was closed before completion') + @pytest.mark.xfail(reason="fetchone() does not realize the stream result was closed before completion") def test_type_coverage(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from range(4096)") res = rel.fetchone() @@ -12,7 +12,7 @@ def test_type_coverage(self, duckdb_cursor): def my_func(x: str) -> int: return int(x) - duckdb_cursor.create_function('test', my_func) + duckdb_cursor.create_function("test", my_func) - with pytest.raises(duckdb.InvalidInputException, match='result closed'): + with pytest.raises(duckdb.InvalidInputException, match="result closed"): res = rel.fetchone() diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py index 40bde07b..b0901ab8 100644 --- a/tests/slow/test_h2oai_arrow.py +++ b/tests/slow/test_h2oai_arrow.py @@ -3,17 +3,17 @@ import math from pytest import mark, fixture, importorskip -read_csv = importorskip('pyarrow.csv').read_csv -requests = importorskip('requests') -requests_adapters = importorskip('requests.adapters') -urllib3_util = importorskip('urllib3.util') -np = importorskip('numpy') +read_csv = importorskip("pyarrow.csv").read_csv +requests = importorskip("requests") +requests_adapters = importorskip("requests.adapters") +urllib3_util = importorskip("urllib3.util") +np = importorskip("numpy") def group_by_q1(con): con.execute("CREATE TABLE ans AS SELECT id1, sum(v1) AS v1 FROM x GROUP BY id1") res = con.execute("SELECT COUNT(*), sum(v1)::varchar AS v1 FROM ans").fetchall() - assert res == [(96, '28498857')] + assert res == [(96, "28498857")] con.execute("DROP TABLE ans") @@ -155,7 +155,7 @@ def join_by_q5(con): class TestH2OAIArrow(object): @mark.parametrize( - 'function', + "function", [ group_by_q1, group_by_q2, @@ -169,15 +169,15 @@ class TestH2OAIArrow(object): group_by_q10, ], ) - @mark.parametrize('threads', [1, 4]) - @mark.usefixtures('group_by_data') + @mark.parametrize("threads", [1, 4]) + @mark.usefixtures("group_by_data") def test_group_by(self, threads, function, group_by_data): group_by_data.execute(f"PRAGMA threads={threads}") function(group_by_data) - @mark.parametrize('threads', [1, 4]) + @mark.parametrize("threads", [1, 4]) @mark.parametrize( - 'function', + "function", [ join_by_q1, join_by_q2, @@ -186,7 +186,7 @@ def test_group_by(self, threads, function, group_by_data): join_by_q5, ], ) - @mark.usefixtures('large_data') + @mark.usefixtures("large_data") def test_join(self, threads, function, large_data): large_data.execute(f"PRAGMA threads={threads}") @@ -198,7 +198,7 @@ def arrow_dataset_register(): """Single fixture to download files and register them on the given connection""" session = requests.Session() retries = urllib3_util.Retry( - allowed_methods={'GET'}, # only retry on GETs (all we do) + allowed_methods={"GET"}, # only retry on GETs (all we do) total=None, # disable to make the below take effect redirect=10, # Don't follow more than 10 redirects in a row connect=3, # try 3 times before giving up on connection errors @@ -211,12 +211,12 @@ def arrow_dataset_register(): raise_on_status=True, # raise exception when status error retries are exhausted respect_retry_after_header=True, # respect Retry-After headers ) - session.mount('https://', requests_adapters.HTTPAdapter(max_retries=retries)) + session.mount("https://", requests_adapters.HTTPAdapter(max_retries=retries)) saved_filenames = set() def _register(url, filename, con, tablename): r = session.get(url) - with open(filename, 'wb') as f: + with open(filename, "wb") as f: f.write(r.content) con.register(tablename, read_csv(filename)) saved_filenames.add(filename) @@ -232,26 +232,26 @@ def _register(url, filename, con, tablename): def large_data(arrow_dataset_register): con = duckdb.connect() arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz', - 'J1_1e7_NA_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_NA_0_0.csv.gz", + "J1_1e7_NA_0_0.csv.gz", con, "x", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz', - 'J1_1e7_1e1_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e1_0_0.csv.gz", + "J1_1e7_1e1_0_0.csv.gz", con, "small", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz', - 'J1_1e7_1e4_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e4_0_0.csv.gz", + "J1_1e7_1e4_0_0.csv.gz", con, "medium", ) arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz', - 'J1_1e7_1e7_0_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/J1_1e7_1e7_0_0.csv.gz", + "J1_1e7_1e7_0_0.csv.gz", con, "big", ) @@ -263,8 +263,8 @@ def large_data(arrow_dataset_register): def group_by_data(arrow_dataset_register): con = duckdb.connect() arrow_dataset_register( - 'https://github.com/duckdb/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz', - 'G1_1e7_1e2_5_0.csv.gz', + "https://github.com/duckdb/duckdb-data/releases/download/v1.0/G1_1e7_1e2_5_0.csv.gz", + "G1_1e7_1e2_5_0.csv.gz", con, "x", ) diff --git a/tests/stubs/test_stubs.py b/tests/stubs/test_stubs.py index 2f178bcc..c68f7068 100644 --- a/tests/stubs/test_stubs.py +++ b/tests/stubs/test_stubs.py @@ -2,18 +2,18 @@ from mypy import stubtest -MYPY_INI_PATH = os.path.join(os.path.dirname(__file__), 'mypy.ini') +MYPY_INI_PATH = os.path.join(os.path.dirname(__file__), "mypy.ini") def test_generated_stubs(): - skip_stubs_errors = ['pybind11', 'git_revision', 'is inconsistent, metaclass differs'] + skip_stubs_errors = ["pybind11", "git_revision", "is inconsistent, metaclass differs"] - options = stubtest.parse_options(['duckdb', '--mypy-config-file', MYPY_INI_PATH]) + options = stubtest.parse_options(["duckdb", "--mypy-config-file", MYPY_INI_PATH]) stubtest.test_stubs(options) broken_stubs = [ error.get_description() - for error in stubtest.test_module('duckdb') + for error in stubtest.test_module("duckdb") if not any(skip in error.get_description() for skip in skip_stubs_errors) ] From c72f2c2a9ff2c97879f7c84bf3b7a0018f48ef9f Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:28:42 +0200 Subject: [PATCH 213/472] Ruff linter fixes --- adbc_driver_duckdb/dbapi.py | 4 +- duckdb/__init__.py | 234 ++-- duckdb/experimental/spark/__init__.py | 6 +- duckdb/experimental/spark/_globals.py | 3 +- duckdb/experimental/spark/_typing.py | 7 +- duckdb/experimental/spark/conf.py | 3 +- duckdb/experimental/spark/context.py | 4 +- duckdb/experimental/spark/errors/__init__.py | 66 +- .../spark/errors/exceptions/base.py | 91 +- duckdb/experimental/spark/errors/utils.py | 10 +- duckdb/experimental/spark/exception.py | 3 +- duckdb/experimental/spark/sql/__init__.py | 10 +- duckdb/experimental/spark/sql/_typing.py | 2 - duckdb/experimental/spark/sql/catalog.py | 5 +- duckdb/experimental/spark/sql/column.py | 36 +- duckdb/experimental/spark/sql/conf.py | 3 +- duckdb/experimental/spark/sql/dataframe.py | 131 +- duckdb/experimental/spark/sql/functions.py | 1162 ++++++++--------- duckdb/experimental/spark/sql/group.py | 30 +- duckdb/experimental/spark/sql/readwriter.py | 18 +- duckdb/experimental/spark/sql/session.py | 24 +- duckdb/experimental/spark/sql/streaming.py | 2 +- duckdb/experimental/spark/sql/type_utils.py | 48 +- duckdb/experimental/spark/sql/types.py | 150 +-- duckdb/filesystem.py | 8 +- duckdb/functional/__init__.py | 4 +- duckdb/polars_io.py | 33 +- duckdb/query_graph/__main__.py | 4 +- duckdb/typing/__init__.py | 10 +- duckdb/udf.py | 5 +- duckdb/value/constant/__init__.py | 49 +- duckdb_packaging/_versioning.py | 6 +- duckdb_packaging/build_backend.py | 34 +- duckdb_packaging/pypi_cleanup.py | 21 +- duckdb_packaging/setuptools_scm_version.py | 13 +- scripts/generate_connection_methods.py | 6 +- scripts/generate_connection_stubs.py | 6 +- .../generate_connection_wrapper_methods.py | 14 +- scripts/generate_connection_wrapper_stubs.py | 8 +- scripts/generate_import_cache_cpp.py | 3 +- scripts/generate_import_cache_json.py | 8 +- scripts/get_cpp_methods.py | 4 +- sqllogic/conftest.py | 16 +- sqllogic/test_sqllogic.py | 17 +- tests/conftest.py | 19 +- .../test_pandas_categorical_coverage.py | 8 +- tests/extensions/json/test_read_json.py | 10 +- tests/extensions/test_extensions_loading.py | 4 +- tests/extensions/test_httpfs.py | 18 +- tests/fast/adbc/test_adbc.py | 8 +- tests/fast/adbc/test_connection_get_info.py | 5 +- tests/fast/adbc/test_statement_bind.py | 2 +- tests/fast/api/test_3324.py | 3 +- tests/fast/api/test_3654.py | 7 +- tests/fast/api/test_3728.py | 2 +- tests/fast/api/test_6315.py | 2 +- tests/fast/api/test_attribute_getter.py | 13 +- tests/fast/api/test_config.py | 15 +- tests/fast/api/test_connection_close.py | 8 +- tests/fast/api/test_connection_interrupt.py | 5 +- tests/fast/api/test_cursor.py | 3 +- tests/fast/api/test_dbapi00.py | 5 +- tests/fast/api/test_dbapi01.py | 3 +- tests/fast/api/test_dbapi04.py | 2 +- tests/fast/api/test_dbapi05.py | 2 +- tests/fast/api/test_dbapi07.py | 7 +- tests/fast/api/test_dbapi08.py | 6 +- tests/fast/api/test_dbapi09.py | 5 +- tests/fast/api/test_dbapi10.py | 7 +- tests/fast/api/test_dbapi11.py | 7 +- tests/fast/api/test_dbapi12.py | 8 +- tests/fast/api/test_dbapi13.py | 5 +- tests/fast/api/test_dbapi_fetch.py | 10 +- tests/fast/api/test_duckdb_connection.py | 51 +- tests/fast/api/test_duckdb_execute.py | 5 +- tests/fast/api/test_duckdb_query.py | 7 +- tests/fast/api/test_explain.py | 3 +- tests/fast/api/test_fsspec.py | 10 +- tests/fast/api/test_insert_into.py | 7 +- tests/fast/api/test_join.py | 5 +- tests/fast/api/test_native_tz.py | 7 +- tests/fast/api/test_query_interrupt.py | 11 +- tests/fast/api/test_query_progress.py | 9 +- tests/fast/api/test_read_csv.py | 12 +- tests/fast/api/test_relation_to_view.py | 3 +- tests/fast/api/test_streaming_result.py | 3 +- tests/fast/api/test_to_csv.py | 15 +- tests/fast/api/test_to_parquet.py | 16 +- .../api/test_with_propagating_exceptions.py | 3 +- tests/fast/arrow/parquet_write_roundtrip.py | 10 +- tests/fast/arrow/test_10795.py | 3 +- tests/fast/arrow/test_12384.py | 6 +- tests/fast/arrow/test_14344.py | 7 +- tests/fast/arrow/test_2426.py | 5 +- tests/fast/arrow/test_5547.py | 5 +- tests/fast/arrow/test_6584.py | 4 +- tests/fast/arrow/test_6796.py | 5 +- tests/fast/arrow/test_7652.py | 7 +- tests/fast/arrow/test_7699.py | 6 +- tests/fast/arrow/test_8522.py | 7 +- tests/fast/arrow/test_9443.py | 3 +- tests/fast/arrow/test_arrow_batch_index.py | 5 +- tests/fast/arrow/test_arrow_binary_view.py | 5 +- tests/fast/arrow/test_arrow_case_sensitive.py | 3 +- tests/fast/arrow/test_arrow_decimal256.py | 8 +- tests/fast/arrow/test_arrow_decimal_32_64.py | 8 +- tests/fast/arrow/test_arrow_extensions.py | 12 +- tests/fast/arrow/test_arrow_fetch.py | 5 +- .../arrow/test_arrow_fetch_recordbatch.py | 5 +- tests/fast/arrow/test_arrow_fixed_binary.py | 2 +- tests/fast/arrow/test_arrow_ipc.py | 3 +- tests/fast/arrow/test_arrow_list.py | 7 +- tests/fast/arrow/test_arrow_offsets.py | 42 +- tests/fast/arrow/test_arrow_pycapsule.py | 7 +- .../arrow/test_arrow_recordbatchreader.py | 6 +- .../fast/arrow/test_arrow_replacement_scan.py | 9 +- .../fast/arrow/test_arrow_run_end_encoding.py | 43 +- tests/fast/arrow/test_arrow_scanner.py | 10 +- tests/fast/arrow/test_arrow_string_view.py | 6 +- tests/fast/arrow/test_arrow_types.py | 5 +- tests/fast/arrow/test_arrow_union.py | 3 +- tests/fast/arrow/test_arrow_version_format.py | 14 +- tests/fast/arrow/test_binary_type.py | 6 +- tests/fast/arrow/test_buffer_size_option.py | 5 +- tests/fast/arrow/test_dataset.py | 6 +- tests/fast/arrow/test_date.py | 8 +- tests/fast/arrow/test_dictionary_arrow.py | 3 +- tests/fast/arrow/test_filter_pushdown.py | 12 +- tests/fast/arrow/test_integration.py | 10 +- tests/fast/arrow/test_interval.py | 9 +- tests/fast/arrow/test_large_offsets.py | 9 +- tests/fast/arrow/test_large_string.py | 6 +- tests/fast/arrow/test_multiple_reads.py | 5 +- tests/fast/arrow/test_nested_arrow.py | 6 +- tests/fast/arrow/test_parallel.py | 7 +- tests/fast/arrow/test_polars.py | 9 +- tests/fast/arrow/test_progress.py | 7 +- tests/fast/arrow/test_projection_pushdown.py | 4 +- tests/fast/arrow/test_time.py | 8 +- tests/fast/arrow/test_timestamp_timezone.py | 8 +- tests/fast/arrow/test_timestamps.py | 8 +- tests/fast/arrow/test_tpch.py | 6 +- tests/fast/arrow/test_unregister.py | 10 +- tests/fast/arrow/test_view.py | 4 +- tests/fast/numpy/test_numpy_new_path.py | 8 +- tests/fast/pandas/test_2304.py | 7 +- tests/fast/pandas/test_append_df.py | 7 +- tests/fast/pandas/test_bug2281.py | 9 +- tests/fast/pandas/test_bug5922.py | 7 +- tests/fast/pandas/test_copy_on_write.py | 5 +- .../pandas/test_create_table_from_pandas.py | 9 +- tests/fast/pandas/test_date_as_datetime.py | 5 +- tests/fast/pandas/test_datetime_time.py | 10 +- tests/fast/pandas/test_datetime_timestamp.py | 7 +- tests/fast/pandas/test_df_analyze.py | 9 +- .../fast/pandas/test_df_object_resolution.py | 18 +- tests/fast/pandas/test_df_recursive_nested.py | 11 +- tests/fast/pandas/test_fetch_df_chunk.py | 3 +- tests/fast/pandas/test_fetch_nested.py | 7 +- .../fast/pandas/test_implicit_pandas_scan.py | 7 +- tests/fast/pandas/test_import_cache.py | 5 +- tests/fast/pandas/test_issue_1767.py | 9 +- tests/fast/pandas/test_limit.py | 7 +- tests/fast/pandas/test_pandas_arrow.py | 7 +- tests/fast/pandas/test_pandas_category.py | 7 +- tests/fast/pandas/test_pandas_df_none.py | 6 +- tests/fast/pandas/test_pandas_enum.py | 7 +- tests/fast/pandas/test_pandas_limit.py | 5 +- tests/fast/pandas/test_pandas_na.py | 11 +- tests/fast/pandas/test_pandas_object.py | 9 +- tests/fast/pandas/test_pandas_string.py | 11 +- tests/fast/pandas/test_pandas_timestamp.py | 8 +- tests/fast/pandas/test_pandas_types.py | 14 +- tests/fast/pandas/test_pandas_unregister.py | 13 +- tests/fast/pandas/test_pandas_update.py | 5 +- .../fast/pandas/test_parallel_pandas_scan.py | 13 +- .../pandas/test_partitioned_pandas_scan.py | 10 +- tests/fast/pandas/test_progress_bar.py | 10 +- .../test_pyarrow_projection_pushdown.py | 8 +- tests/fast/pandas/test_same_name.py | 4 +- tests/fast/pandas/test_stride.py | 8 +- tests/fast/pandas/test_timedelta.py | 8 +- tests/fast/pandas/test_timestamp.py | 10 +- tests/fast/relational_api/test_groupings.py | 7 +- tests/fast/relational_api/test_joins.py | 7 +- tests/fast/relational_api/test_pivot.py | 6 +- .../relational_api/test_rapi_aggregations.py | 9 +- tests/fast/relational_api/test_rapi_close.py | 5 +- .../relational_api/test_rapi_description.py | 3 +- .../relational_api/test_rapi_functions.py | 2 +- tests/fast/relational_api/test_rapi_query.py | 14 +- .../fast/relational_api/test_rapi_windows.py | 5 +- .../relational_api/test_table_function.py | 8 +- tests/fast/spark/test_replace_column_value.py | 2 +- tests/fast/spark/test_replace_empty_value.py | 3 +- tests/fast/spark/test_spark_arrow_table.py | 2 - tests/fast/spark/test_spark_catalog.py | 4 +- tests/fast/spark/test_spark_column.py | 12 +- tests/fast/spark/test_spark_dataframe.py | 28 +- tests/fast/spark/test_spark_dataframe_sort.py | 8 +- .../fast/spark/test_spark_drop_duplicates.py | 4 +- tests/fast/spark/test_spark_except.py | 2 - tests/fast/spark/test_spark_filter.py | 22 +- .../spark/test_spark_function_concat_ws.py | 4 +- .../fast/spark/test_spark_functions_array.py | 5 +- .../fast/spark/test_spark_functions_base64.py | 2 +- tests/fast/spark/test_spark_functions_date.py | 5 +- tests/fast/spark/test_spark_functions_expr.py | 2 +- tests/fast/spark/test_spark_functions_hash.py | 2 +- tests/fast/spark/test_spark_functions_hex.py | 6 +- tests/fast/spark/test_spark_functions_null.py | 2 +- .../spark/test_spark_functions_numeric.py | 3 +- .../fast/spark/test_spark_functions_string.py | 2 +- tests/fast/spark/test_spark_group_by.py | 48 +- tests/fast/spark/test_spark_intersect.py | 2 - tests/fast/spark/test_spark_join.py | 18 +- tests/fast/spark/test_spark_limit.py | 2 +- tests/fast/spark/test_spark_order_by.py | 15 +- .../fast/spark/test_spark_pandas_dataframe.py | 20 +- tests/fast/spark/test_spark_readcsv.py | 7 +- tests/fast/spark/test_spark_readjson.py | 5 +- tests/fast/spark/test_spark_readparquet.py | 5 +- tests/fast/spark/test_spark_runtime_config.py | 2 +- tests/fast/spark/test_spark_session.py | 7 +- tests/fast/spark/test_spark_to_csv.py | 24 +- tests/fast/spark/test_spark_to_parquet.py | 9 +- tests/fast/spark/test_spark_transform.py | 19 +- tests/fast/spark/test_spark_types.py | 41 +- tests/fast/spark/test_spark_udf.py | 2 +- tests/fast/spark/test_spark_union.py | 9 +- tests/fast/spark/test_spark_union_by_name.py | 17 +- tests/fast/spark/test_spark_with_column.py | 18 +- .../spark/test_spark_with_column_renamed.py | 20 +- tests/fast/spark/test_spark_with_columns.py | 2 +- .../spark/test_spark_with_columns_renamed.py | 5 +- tests/fast/sqlite/test_types.py | 2 +- tests/fast/test_alex_multithread.py | 5 +- tests/fast/test_all_types.py | 19 +- tests/fast/test_ambiguous_prepare.py | 5 +- tests/fast/test_case_alias.py | 13 +- tests/fast/test_context_manager.py | 2 +- tests/fast/test_duckdb_api.py | 3 +- tests/fast/test_expression.py | 25 +- tests/fast/test_filesystem.py | 18 +- tests/fast/test_get_table_names.py | 5 +- tests/fast/test_import_export.py | 9 +- tests/fast/test_insert.py | 10 +- tests/fast/test_json_logging.py | 3 +- tests/fast/test_many_con_same_file.py | 4 +- tests/fast/test_map.py | 13 +- tests/fast/test_metatransaction.py | 2 +- tests/fast/test_multi_statement.py | 6 +- tests/fast/test_multithread.py | 20 +- tests/fast/test_non_default_conn.py | 9 +- tests/fast/test_parameter_list.py | 7 +- tests/fast/test_parquet.py | 10 +- tests/fast/test_pypi_cleanup.py | 17 +- tests/fast/test_pytorch.py | 2 +- tests/fast/test_relation.py | 17 +- tests/fast/test_relation_dependency_leak.py | 8 +- tests/fast/test_replacement_scan.py | 8 +- tests/fast/test_result.py | 8 +- tests/fast/test_runtime_error.py | 7 +- tests/fast/test_sql_expression.py | 5 +- tests/fast/test_string_annotation.py | 8 +- tests/fast/test_tf.py | 2 +- tests/fast/test_transaction.py | 4 +- tests/fast/test_type.py | 56 +- tests/fast/test_type_explicit.py | 2 +- tests/fast/test_unicode.py | 6 +- tests/fast/test_union.py | 4 +- tests/fast/test_value.py | 110 +- tests/fast/test_version.py | 3 +- tests/fast/test_versioning.py | 15 +- tests/fast/test_windows_abs_path.py | 6 +- tests/fast/types/test_blob.py | 3 +- tests/fast/types/test_boolean.py | 4 +- tests/fast/types/test_datetime_date.py | 5 +- tests/fast/types/test_datetime_datetime.py | 6 +- tests/fast/types/test_decimal.py | 6 +- tests/fast/types/test_hugeint.py | 3 +- tests/fast/types/test_nan.py | 8 +- tests/fast/types/test_nested.py | 3 +- tests/fast/types/test_null.py | 3 +- tests/fast/types/test_numeric.py | 4 +- tests/fast/types/test_numpy.py | 9 +- tests/fast/types/test_object_int.py | 11 +- tests/fast/types/test_time_tz.py | 9 +- tests/fast/types/test_unsigned.py | 2 +- tests/fast/udf/test_null_filtering.py | 13 +- tests/fast/udf/test_remove_function.py | 13 +- tests/fast/udf/test_scalar.py | 28 +- tests/fast/udf/test_scalar_arrow.py | 14 +- tests/fast/udf/test_scalar_native.py | 9 +- tests/fast/udf/test_transactionality.py | 5 +- tests/slow/test_h2oai_arrow.py | 10 +- 296 files changed, 2147 insertions(+), 2425 deletions(-) diff --git a/adbc_driver_duckdb/dbapi.py b/adbc_driver_duckdb/dbapi.py index 793c4242..7d703713 100644 --- a/adbc_driver_duckdb/dbapi.py +++ b/adbc_driver_duckdb/dbapi.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. -""" -DBAPI 2.0-compatible facade for the ADBC DuckDB driver. +"""DBAPI 2.0-compatible facade for the ADBC DuckDB driver. """ import typing import adbc_driver_manager import adbc_driver_manager.dbapi + import adbc_driver_duckdb __all__ = [ diff --git a/duckdb/__init__.py b/duckdb/__init__.py index bf50be5b..73fcbbd2 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -1,8 +1,10 @@ # Modules +from importlib.metadata import version + +from _duckdb import __version__ as duckdb_version + import duckdb.functional as functional import duckdb.typing as typing -from _duckdb import __version__ as duckdb_version -from importlib.metadata import version # duckdb.__version__ returns the version of the distribution package, i.e. the pypi version __version__ = version("duckdb") @@ -62,25 +64,25 @@ def __repr__(self): # Classes from _duckdb import ( - DuckDBPyRelation, + CaseExpression, + CoalesceOperator, + ColumnExpression, + ConstantExpression, + CSVLineTerminator, + DefaultExpression, DuckDBPyConnection, - Statement, - ExplainType, - StatementType, + DuckDBPyRelation, ExpectedResultType, - CSVLineTerminator, - PythonExceptionHandling, - RenderMode, + ExplainType, Expression, - ConstantExpression, - ColumnExpression, - DefaultExpression, - CoalesceOperator, - LambdaExpression, - StarExpression, FunctionExpression, - CaseExpression, + LambdaExpression, + PythonExceptionHandling, + RenderMode, SQLExpression, + StarExpression, + Statement, + StatementType, ) _exported_symbols.extend( @@ -104,91 +106,85 @@ def __repr__(self): # These are overloaded twice, we define them inside of C++ so pybind can deal with it _exported_symbols.extend(["df", "arrow"]) -from _duckdb import df, arrow - # NOTE: this section is generated by tools/pythonpkg/scripts/generate_connection_wrapper_methods.py. # Do not edit this section manually, your changes will be overwritten! - # START OF CONNECTION WRAPPER - from _duckdb import ( - cursor, - register_filesystem, - unregister_filesystem, - list_filesystems, - filesystem_is_registered, - create_function, - remove_function, - sqltype, - dtype, - type, + aggregate, + alias, + append, array_type, - list_type, - union_type, - string_type, - enum_type, + arrow, + begin, + checkpoint, + close, + commit, + create_function, + cursor, decimal_type, - struct_type, - row_type, - map_type, + description, + df, + distinct, + dtype, duplicate, + enum_type, execute, executemany, - close, - interrupt, - query_progress, - fetchone, - fetchmany, - fetchall, - fetchnumpy, - fetchdf, + extract_statements, + fetch_arrow_table, fetch_df, - df, fetch_df_chunk, - pl, - fetch_arrow_table, - arrow, fetch_record_batch, - torch, - tf, - begin, - commit, - rollback, - checkpoint, - append, - register, - unregister, - table, - view, - values, - table_function, - read_json, - extract_statements, - sql, - query, - from_query, - read_csv, + fetchall, + fetchdf, + fetchmany, + fetchnumpy, + fetchone, + filesystem_is_registered, + filter, + from_arrow, from_csv_auto, from_df, - from_arrow, - from_parquet, - read_parquet, from_parquet, - read_parquet, + from_query, get_table_names, install_extension, - load_extension, - project, - distinct, - write_csv, - aggregate, - alias, - filter, + interrupt, limit, + list_filesystems, + list_type, + load_extension, + map_type, order, + pl, + project, + query, query_df, - description, + query_progress, + read_csv, + read_json, + read_parquet, + register, + register_filesystem, + remove_function, + rollback, + row_type, rowcount, + sql, + sqltype, + string_type, + struct_type, + table, + table_function, + tf, + torch, + type, + union_type, + unregister, + unregister_filesystem, + values, + view, + write_csv, ) _exported_symbols.extend( @@ -276,17 +272,17 @@ def __repr__(self): # END OF CONNECTION WRAPPER # Enums -from _duckdb import ANALYZE, DEFAULT, RETURN_NULL, STANDARD, COLUMNS, ROWS +from _duckdb import ANALYZE, COLUMNS, DEFAULT, RETURN_NULL, ROWS, STANDARD _exported_symbols.extend(["ANALYZE", "DEFAULT", "RETURN_NULL", "STANDARD"]) # read-only properties from _duckdb import ( - __standard_vector_size__, + __formatted_python_version__, __interactive__, __jupyter__, - __formatted_python_version__, + __standard_vector_size__, apilevel, comment, identifier, @@ -337,35 +333,35 @@ def __repr__(self): # Exceptions from _duckdb import ( - Error, - DataError, + BinderException, + CatalogException, + ConnectionException, + ConstraintException, ConversionException, - OutOfRangeException, - TypeMismatchException, + DataError, + Error, FatalException, + HTTPException, IntegrityError, - ConstraintException, InternalError, InternalException, InterruptException, - NotSupportedError, + InvalidInputException, + InvalidTypeException, + IOException, NotImplementedException, + NotSupportedError, OperationalError, - ConnectionException, - IOException, - HTTPException, OutOfMemoryException, - SerializationException, - TransactionException, + OutOfRangeException, + ParserException, PermissionException, ProgrammingError, - BinderException, - CatalogException, - InvalidInputException, - InvalidTypeException, - ParserException, - SyntaxException, SequenceException, + SerializationException, + SyntaxException, + TransactionException, + TypeMismatchException, Warning, ) @@ -406,34 +402,34 @@ def __repr__(self): # Value from duckdb.value.constant import ( - Value, - NullValue, - BooleanValue, - UnsignedBinaryValue, - UnsignedShortValue, - UnsignedIntegerValue, - UnsignedLongValue, BinaryValue, - ShortValue, - IntegerValue, - LongValue, - HugeIntegerValue, - FloatValue, - DoubleValue, - DecimalValue, - StringValue, - UUIDValue, BitValue, BlobValue, + BooleanValue, DateValue, + DecimalValue, + DoubleValue, + FloatValue, + HugeIntegerValue, + IntegerValue, IntervalValue, - TimestampValue, - TimestampSecondValue, + LongValue, + NullValue, + ShortValue, + StringValue, TimestampMilisecondValue, TimestampNanosecondValue, + TimestampSecondValue, TimestampTimeZoneValue, - TimeValue, + TimestampValue, TimeTimeZoneValue, + TimeValue, + UnsignedBinaryValue, + UnsignedIntegerValue, + UnsignedLongValue, + UnsignedShortValue, + UUIDValue, + Value, ) _exported_symbols.extend( diff --git a/duckdb/experimental/spark/__init__.py b/duckdb/experimental/spark/__init__.py index 66895dcb..bdde2ef8 100644 --- a/duckdb/experimental/spark/__init__.py +++ b/duckdb/experimental/spark/__init__.py @@ -1,7 +1,7 @@ -from .sql import SparkSession, DataFrame +from ._globals import _NoValue from .conf import SparkConf from .context import SparkContext -from ._globals import _NoValue from .exception import ContributionsAcceptedError +from .sql import DataFrame, SparkSession -__all__ = ["SparkSession", "DataFrame", "SparkConf", "SparkContext", "ContributionsAcceptedError"] +__all__ = ["ContributionsAcceptedError", "DataFrame", "SparkConf", "SparkContext", "SparkSession"] diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index d6a02326..4bc325f7 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -15,8 +15,7 @@ # limitations under the License. # -""" -Module defining global singleton classes. +"""Module defining global singleton classes. This module raises a RuntimeError if an attempt to reload it is made. In that way the identities of the classes defined here are fixed and will remain so diff --git a/duckdb/experimental/spark/_typing.py b/duckdb/experimental/spark/_typing.py index 251ef695..12d16ced 100644 --- a/duckdb/experimental/spark/_typing.py +++ b/duckdb/experimental/spark/_typing.py @@ -16,10 +16,11 @@ # specific language governing permissions and limitations # under the License. -from typing import Callable, Iterable, Sized, TypeVar, Union -from typing_extensions import Literal, Protocol +from collections.abc import Iterable, Sized +from typing import Callable, TypeVar, Union -from numpy import int32, int64, float32, float64, ndarray +from numpy import float32, float64, int32, int64, ndarray +from typing_extensions import Literal, Protocol F = TypeVar("F", bound=Callable) T_co = TypeVar("T_co", covariant=True) diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index 79706781..ea1153b4 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -1,4 +1,5 @@ -from typing import Optional, List, Tuple +from typing import Optional + from duckdb.experimental.spark.exception import ContributionsAcceptedError diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index dd4b016c..9f1b4155 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -1,9 +1,9 @@ from typing import Optional + import duckdb from duckdb import DuckDBPyConnection - -from duckdb.experimental.spark.exception import ContributionsAcceptedError from duckdb.experimental.spark.conf import SparkConf +from duckdb.experimental.spark.exception import ContributionsAcceptedError class SparkContext: diff --git a/duckdb/experimental/spark/errors/__init__.py b/duckdb/experimental/spark/errors/__init__.py index 6aac49d7..2f265d97 100644 --- a/duckdb/experimental/spark/errors/__init__.py +++ b/duckdb/experimental/spark/errors/__init__.py @@ -15,59 +15,57 @@ # limitations under the License. # -""" -PySpark exceptions. +"""PySpark exceptions. """ -from .exceptions.base import ( # noqa: F401 - PySparkException, +from .exceptions.base import ( AnalysisException, - TempTableAlreadyExistsException, - ParseException, - IllegalArgumentException, ArithmeticException, - UnsupportedOperationException, ArrayIndexOutOfBoundsException, DateTimeException, + IllegalArgumentException, NumberFormatException, - StreamingQueryException, - QueryExecutionException, + ParseException, + PySparkAssertionError, + PySparkAttributeError, + PySparkException, + PySparkIndexError, + PySparkNotImplementedError, + PySparkRuntimeError, + PySparkTypeError, + PySparkValueError, PythonException, - UnknownException, + QueryExecutionException, SparkRuntimeException, SparkUpgradeException, - PySparkTypeError, - PySparkValueError, - PySparkIndexError, - PySparkAttributeError, - PySparkRuntimeError, - PySparkAssertionError, - PySparkNotImplementedError, + StreamingQueryException, + TempTableAlreadyExistsException, + UnknownException, + UnsupportedOperationException, ) - __all__ = [ - "PySparkException", "AnalysisException", - "TempTableAlreadyExistsException", - "ParseException", - "IllegalArgumentException", "ArithmeticException", - "UnsupportedOperationException", "ArrayIndexOutOfBoundsException", "DateTimeException", + "IllegalArgumentException", "NumberFormatException", - "StreamingQueryException", - "QueryExecutionException", + "ParseException", + "PySparkAssertionError", + "PySparkAttributeError", + "PySparkException", + "PySparkIndexError", + "PySparkNotImplementedError", + "PySparkRuntimeError", + "PySparkTypeError", + "PySparkValueError", "PythonException", - "UnknownException", + "QueryExecutionException", "SparkRuntimeException", "SparkUpgradeException", - "PySparkTypeError", - "PySparkValueError", - "PySparkIndexError", - "PySparkAttributeError", - "PySparkRuntimeError", - "PySparkAssertionError", - "PySparkNotImplementedError", + "StreamingQueryException", + "TempTableAlreadyExistsException", + "UnknownException", + "UnsupportedOperationException", ] diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 48a3ea95..a6f1f940 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -1,11 +1,10 @@ -from typing import Dict, Optional, cast +from typing import Optional, cast from ..utils import ErrorClassesReader class PySparkException(Exception): - """ - Base Exception for handling errors generated from PySpark. + """Base Exception for handling errors generated from PySpark. """ def __init__( @@ -25,7 +24,7 @@ def __init__( if message is None: self.message = self.error_reader.get_error_message( - cast(str, error_class), cast(dict[str, str], message_parameters) + cast("str", error_class), cast("dict[str, str]", message_parameters) ) else: self.message = message @@ -34,12 +33,11 @@ def __init__( self.message_parameters = message_parameters def getErrorClass(self) -> Optional[str]: - """ - Returns an error class as a string. + """Returns an error class as a string. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getMessageParameters` :meth:`PySparkException.getSqlState` @@ -47,12 +45,11 @@ def getErrorClass(self) -> Optional[str]: return self.error_class def getMessageParameters(self) -> Optional[dict[str, str]]: - """ - Returns a message parameters as a dictionary. + """Returns a message parameters as a dictionary. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getErrorClass` :meth:`PySparkException.getSqlState` @@ -60,14 +57,13 @@ def getMessageParameters(self) -> Optional[dict[str, str]]: return self.message_parameters def getSqlState(self) -> None: - """ - Returns an SQLSTATE as a string. + """Returns an SQLSTATE as a string. Errors generated in Python have no SQLSTATE, so it always returns None. .. versionadded:: 3.4.0 - See Also + See Also: -------- :meth:`PySparkException.getErrorClass` :meth:`PySparkException.getMessageParameters` @@ -82,138 +78,115 @@ def __str__(self) -> str: class AnalysisException(PySparkException): - """ - Failed to analyze a SQL query plan. + """Failed to analyze a SQL query plan. """ class SessionNotSameException(PySparkException): - """ - Performed the same operation on different SparkSession. + """Performed the same operation on different SparkSession. """ class TempTableAlreadyExistsException(AnalysisException): - """ - Failed to create temp view since it is already exists. + """Failed to create temp view since it is already exists. """ class ParseException(AnalysisException): - """ - Failed to parse a SQL command. + """Failed to parse a SQL command. """ class IllegalArgumentException(PySparkException): - """ - Passed an illegal or inappropriate argument. + """Passed an illegal or inappropriate argument. """ class ArithmeticException(PySparkException): - """ - Arithmetic exception thrown from Spark with an error class. + """Arithmetic exception thrown from Spark with an error class. """ class UnsupportedOperationException(PySparkException): - """ - Unsupported operation exception thrown from Spark with an error class. + """Unsupported operation exception thrown from Spark with an error class. """ class ArrayIndexOutOfBoundsException(PySparkException): - """ - Array index out of bounds exception thrown from Spark with an error class. + """Array index out of bounds exception thrown from Spark with an error class. """ class DateTimeException(PySparkException): - """ - Datetime exception thrown from Spark with an error class. + """Datetime exception thrown from Spark with an error class. """ class NumberFormatException(IllegalArgumentException): - """ - Number format exception thrown from Spark with an error class. + """Number format exception thrown from Spark with an error class. """ class StreamingQueryException(PySparkException): - """ - Exception that stopped a :class:`StreamingQuery`. + """Exception that stopped a :class:`StreamingQuery`. """ class QueryExecutionException(PySparkException): - """ - Failed to execute a query. + """Failed to execute a query. """ class PythonException(PySparkException): - """ - Exceptions thrown from Python workers. + """Exceptions thrown from Python workers. """ class SparkRuntimeException(PySparkException): - """ - Runtime exception thrown from Spark with an error class. + """Runtime exception thrown from Spark with an error class. """ class SparkUpgradeException(PySparkException): - """ - Exception thrown because of Spark upgrade. + """Exception thrown because of Spark upgrade. """ class UnknownException(PySparkException): - """ - None of the above exceptions. + """None of the above exceptions. """ class PySparkValueError(PySparkException, ValueError): - """ - Wrapper class for ValueError to support error classes. + """Wrapper class for ValueError to support error classes. """ class PySparkIndexError(PySparkException, IndexError): - """ - Wrapper class for IndexError to support error classes. + """Wrapper class for IndexError to support error classes. """ class PySparkTypeError(PySparkException, TypeError): - """ - Wrapper class for TypeError to support error classes. + """Wrapper class for TypeError to support error classes. """ class PySparkAttributeError(PySparkException, AttributeError): - """ - Wrapper class for AttributeError to support error classes. + """Wrapper class for AttributeError to support error classes. """ class PySparkRuntimeError(PySparkException, RuntimeError): - """ - Wrapper class for RuntimeError to support error classes. + """Wrapper class for RuntimeError to support error classes. """ class PySparkAssertionError(PySparkException, AssertionError): - """ - Wrapper class for AssertionError to support error classes. + """Wrapper class for AssertionError to support error classes. """ class PySparkNotImplementedError(PySparkException, NotImplementedError): - """ - Wrapper class for NotImplementedError to support error classes. + """Wrapper class for NotImplementedError to support error classes. """ diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index f1b37f75..c8c66896 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -16,22 +16,19 @@ # import re -from typing import Dict from .error_classes import ERROR_CLASSES_MAP class ErrorClassesReader: - """ - A reader to load error information from error_classes.py. + """A reader to load error information from error_classes.py. """ def __init__(self) -> None: self.error_info_map = ERROR_CLASSES_MAP def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: - """ - Returns the completed error message by applying message parameters to the message template. + """Returns the completed error message by applying message parameters to the message template. """ message_template = self.get_message_template(error_class) # Verify message parameters. @@ -44,8 +41,7 @@ def get_error_message(self, error_class: str, message_parameters: dict[str, str] return message_template.translate(table).format(**message_parameters) def get_message_template(self, error_class: str) -> str: - """ - Returns the message template for corresponding error class from error_classes.py. + """Returns the message template for corresponding error class from error_classes.py. For example, when given `error_class` is "EXAMPLE_ERROR_CLASS", diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 60495d88..791f7090 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,6 +1,5 @@ class ContributionsAcceptedError(NotImplementedError): - """ - This method is not planned to be implemented, if you would like to implement this method + """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb """ diff --git a/duckdb/experimental/spark/sql/__init__.py b/duckdb/experimental/spark/sql/__init__.py index 2312ee50..9ae09308 100644 --- a/duckdb/experimental/spark/sql/__init__.py +++ b/duckdb/experimental/spark/sql/__init__.py @@ -1,7 +1,7 @@ -from .session import SparkSession -from .readwriter import DataFrameWriter -from .dataframe import DataFrame -from .conf import RuntimeConfig from .catalog import Catalog +from .conf import RuntimeConfig +from .dataframe import DataFrame +from .readwriter import DataFrameWriter +from .session import SparkSession -__all__ = ["SparkSession", "DataFrame", "RuntimeConfig", "DataFrameWriter", "Catalog"] +__all__ = ["Catalog", "DataFrame", "DataFrameWriter", "RuntimeConfig", "SparkSession"] diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index b5a8b079..caf0058c 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -19,9 +19,7 @@ from typing import ( Any, Callable, - List, Optional, - Tuple, TypeVar, Union, ) diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index 3cc96f45..8e510fdf 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -1,4 +1,5 @@ -from typing import List, NamedTuple, Optional +from typing import NamedTuple, Optional + from .session import SparkSession @@ -75,4 +76,4 @@ def setCurrentDatabase(self, dbName: str) -> None: raise NotImplementedError -__all__ = ["Catalog", "Table", "Column", "Function", "Database"] +__all__ = ["Catalog", "Column", "Database", "Function", "Table"] diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index f78b31ae..3a6f6cea 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -1,13 +1,12 @@ -from typing import Union, TYPE_CHECKING, Any, cast, Callable, Tuple -from ..exception import ContributionsAcceptedError +from typing import TYPE_CHECKING, Any, Callable, Union, cast +from ..exception import ContributionsAcceptedError from .types import DataType if TYPE_CHECKING: - from ._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral - -from duckdb import ConstantExpression, ColumnExpression, FunctionExpression, Expression + from ._typing import DateTimeLiteral, DecimalLiteral, LiteralType +from duckdb import ColumnExpression, ConstantExpression, Expression, FunctionExpression from duckdb.typing import DuckDBPyType __all__ = ["Column"] @@ -78,8 +77,7 @@ def _( class Column: - """ - A column in a DataFrame. + """A column in a DataFrame. :class:`Column` instances can be created by:: @@ -139,8 +137,7 @@ def __neg__(self) -> "Column": __rpow__ = _bin_op("__rpow__") def __getitem__(self, k: Any) -> "Column": - """ - An expression that gets an item at position ``ordinal`` out of a list, + """An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. .. versionadded:: 1.3.0 @@ -153,13 +150,13 @@ def __getitem__(self, k: Any) -> "Column": k a literal value, or a slice object without step. - Returns + Returns: ------- :class:`Column` Column representing the item got by key out of a dict, or substrings sliced by the given slice object. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) >>> df.select(df.l[slice(1, 3)], df.d["key"]).show() @@ -180,8 +177,7 @@ def __getitem__(self, k: Any) -> "Column": return Column(ColumnExpression(expr_str)) def __getattr__(self, item: Any) -> "Column": - """ - An expression that gets an item at position ``ordinal`` out of a list, + """An expression that gets an item at position ``ordinal`` out of a list, or gets an item by key out of a dict. Parameters @@ -189,12 +185,12 @@ def __getattr__(self, item: Any) -> "Column": item a literal value. - Returns + Returns: ------- :class:`Column` Column representing the item got by key out of a dict. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcedfg", {"key": "value"})], ["l", "d"]) >>> df.select(df.d.key).show() @@ -234,10 +230,10 @@ def cast(self, dataType: Union[DataType, str]) -> "Column": def isin(self, *cols: Any) -> "Column": if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list - cols = cast(tuple, cols[0]) + cols = cast("tuple", cols[0]) cols = cast( - tuple, + "tuple", [_get_expr(c) for c in cols], ) return Column(self.expr.isin(*cols)) @@ -247,14 +243,14 @@ def __eq__( # type: ignore[override] self, other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"], ) -> "Column": - """binary function""" + """Binary function""" return Column(self.expr == (_get_expr(other))) def __ne__( # type: ignore[override] self, - other: Any, + other: object, ) -> "Column": - """binary function""" + """Binary function""" return Column(self.expr != (_get_expr(other))) __lt__ = _bin_op("__lt__") diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index 8e30d7ca..8ab9fa38 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ b/duckdb/experimental/spark/sql/conf.py @@ -1,6 +1,7 @@ from typing import Optional, Union -from duckdb.experimental.spark._globals import _NoValueType, _NoValue + from duckdb import DuckDBPyConnection +from duckdb.experimental.spark._globals import _NoValue, _NoValueType class RuntimeConfig: diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 19f5576b..3f32aa32 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,25 +1,22 @@ +import uuid from functools import reduce +from keyword import iskeyword from typing import ( TYPE_CHECKING, Any, Callable, - List, - Dict, Optional, - Tuple, Union, cast, overload, ) -import uuid -from keyword import iskeyword import duckdb from duckdb import ColumnExpression, Expression, StarExpression -from ._typing import ColumnOrName -from ..errors import PySparkTypeError, PySparkValueError, PySparkIndexError +from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError from ..exception import ContributionsAcceptedError +from ._typing import ColumnOrName from .column import Column from .readwriter import DataFrameWriter from .type_utils import duckdb_to_spark_schema @@ -29,10 +26,9 @@ import pyarrow as pa from pandas.core.frame import DataFrame as PandasDataFrame - from .group import GroupedData, Grouping + from .group import GroupedData from .session import SparkSession -from ..errors import PySparkValueError from .functions import _to_column_expr, col, lit @@ -51,21 +47,20 @@ def toPandas(self) -> "PandasDataFrame": return self.relation.df() def toArrow(self) -> "pa.Table": - """ - Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. + """Returns the contents of this :class:`DataFrame` as PyArrow ``pyarrow.Table``. This is only available if PyArrow is installed and available. .. versionadded:: 4.0.0 - Notes + Notes: ----- This method should only be used if the resulting PyArrow ``pyarrow.Table`` is expected to be small, as all the data is loaded into the driver's memory. This API is a developer API. - Examples + Examples: -------- >>> df.toArrow() # doctest: +SKIP pyarrow.Table @@ -88,7 +83,7 @@ def createOrReplaceTempView(self, name: str) -> None: name : str Name of the view. - Examples + Examples: -------- Create a local temporary view named 'people'. @@ -144,8 +139,7 @@ def withColumn(self, columnName: str, col: Column) -> "DataFrame": return DataFrame(rel, self.session) def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": - """ - Returns a new :class:`DataFrame` by adding multiple columns or replacing the + """Returns a new :class:`DataFrame` by adding multiple columns or replacing the existing columns that have the same names. The colsMap is a map of column name and column, the column must only refer to attributes @@ -162,12 +156,12 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": colsMap : dict a dict of column name and :class:`Column`. Currently, only a single map is supported. - Returns + Returns: ------- :class:`DataFrame` DataFrame with new or replaced columns. - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) >>> df.withColumns({"age2": df.age + 2, "age3": df.age + 3}).show() @@ -219,8 +213,7 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": return DataFrame(rel, self.session) def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": - """ - Returns a new :class:`DataFrame` by renaming multiple columns. + """Returns a new :class:`DataFrame` by renaming multiple columns. This is a no-op if the schema doesn't contain the given column names. .. versionadded:: 3.4.0 @@ -232,20 +225,20 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": a dict of existing column names and corresponding desired column names. Currently, only a single map is supported. - Returns + Returns: ------- :class:`DataFrame` DataFrame with renamed columns. - See Also + See Also: -------- :meth:`withColumnRenamed` - Notes + Notes: ----- Support Spark Connect - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) >>> df = df.withColumns({"age2": df.age + 2, "age3": df.age + 3}) @@ -308,12 +301,12 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) .. versionadded:: 3.3.0 - Returns + Returns: ------- :class:`DataFrame` Transformed DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import col >>> df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ["int", "float"]) @@ -362,12 +355,12 @@ def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: An Sort ascending vs. descending. Specify list for multiple sort orders. If a list is specified, the length of the list must equal the length of the `cols`. - Returns + Returns: ------- :class:`DataFrame` Sorted DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import desc, asc >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) @@ -499,12 +492,12 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": a :class:`Column` of :class:`types.BooleanType` or a string of SQL expressions. - Returns + Returns: ------- :class:`DataFrame` Filtered DataFrame. - Examples + Examples: -------- >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) @@ -567,7 +560,7 @@ def select(self, *cols) -> "DataFrame": def columns(self) -> list[str]: """Returns all column names as a list. - Examples + Examples: -------- >>> df.columns ['age', 'name'] @@ -607,12 +600,12 @@ def join( ``right``, ``rightouter``, ``right_outer``, ``semi``, ``leftsemi``, ``left_semi``, ``anti``, ``leftanti`` and ``left_anti``. - Returns + Returns: ------- :class:`DataFrame` Joined DataFrame. - Examples + Examples: -------- The following performs a full outer join between ``df1`` and ``df2``. @@ -678,7 +671,6 @@ def join( | Bob| 5| +-----+---+ """ - if on is not None and not isinstance(on, list): on = [on] # type: ignore[assignment] if on is not None and not all([isinstance(x, str) for x in on]): @@ -688,7 +680,7 @@ def join( # & all the Expressions together to form one Expression assert isinstance(on[0], Expression), "on should be Column or list of Column" - on = reduce(lambda x, y: x.__and__(y), cast(list[Expression], on)) + on = reduce(lambda x, y: x.__and__(y), cast("list[Expression]", on)) if on is None and how is None: result = self.relation.join(other.relation) @@ -740,12 +732,12 @@ def crossJoin(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Right side of the cartesian product. - Returns + Returns: ------- :class:`DataFrame` Joined DataFrame. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) @@ -772,12 +764,12 @@ def alias(self, alias: str) -> "DataFrame": alias : str an alias name to be set for the :class:`DataFrame`. - Returns + Returns: ------- :class:`DataFrame` Aliased DataFrame. - Examples + Examples: -------- >>> from pyspark.sql.functions import col, desc >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) @@ -827,12 +819,12 @@ def limit(self, num: int) -> "DataFrame": Number of records to return. Will return this number of records or all records if the DataFrame contains less than this number of records. - Returns + Returns: ------- :class:`DataFrame` Subset of the records - Examples + Examples: -------- >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) >>> df.limit(1).show() @@ -851,8 +843,7 @@ def limit(self, num: int) -> "DataFrame": return DataFrame(rel, self.session) def __contains__(self, item: str) -> bool: - """ - Check if the :class:`DataFrame` contains a column by the name of `item` + """Check if the :class:`DataFrame` contains a column by the name of `item` """ return item in self.relation @@ -860,7 +851,7 @@ def __contains__(self, item: str) -> bool: def schema(self) -> StructType: """Returns the schema of this :class:`DataFrame` as a :class:`duckdb.experimental.spark.sql.types.StructType`. - Examples + Examples: -------- >>> df.schema StructType([StructField('age', IntegerType(), True), @@ -877,7 +868,7 @@ def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. - Examples + Examples: -------- >>> df.select(df["age"]).collect() [Row(age=2), Row(age=5)] @@ -902,7 +893,7 @@ def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Colum def __getattr__(self, name: str) -> Column: """Returns the :class:`Column` denoted by ``name``. - Examples + Examples: -------- >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] @@ -931,12 +922,12 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] Each element should be a column name (string) or an expression (:class:`Column`) or list of them. - Returns + Returns: ------- :class:`GroupedData` Grouped data by given columns. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"] @@ -1008,22 +999,22 @@ def union(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be unioned - Returns + Returns: ------- :class:`DataFrame` - See Also + See Also: -------- DataFrame.unionAll - Notes + Notes: ----- This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by :func:`distinct`. Also as standard in SQL, this function resolves columns by position (not by name). - Examples + Examples: -------- >>> df1 = spark.createDataFrame([[1, 2, 3]], ["col0", "col1", "col2"]) >>> df2 = spark.createDataFrame([[4, 5, 6]], ["col1", "col2", "col0"]) @@ -1067,12 +1058,12 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> .. versionadded:: 3.1.0 - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Examples + Examples: -------- The difference between this function and :func:`union` is that this function resolves columns by name (not by position): @@ -1131,16 +1122,16 @@ def intersect(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be combined. - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Notes + Notes: ----- This is equivalent to `INTERSECT` in SQL. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) @@ -1171,12 +1162,12 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` Another :class:`DataFrame` that needs to be combined. - Returns + Returns: ------- :class:`DataFrame` Combined DataFrame. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3), ("c", 4)], ["C1", "C2"]) >>> df2 = spark.createDataFrame([("a", 1), ("a", 1), ("b", 3)], ["C1", "C2"]) @@ -1208,11 +1199,11 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": other : :class:`DataFrame` The other :class:`DataFrame` to compare to. - Returns + Returns: ------- :class:`DataFrame` - Examples + Examples: -------- >>> df1 = spark.createDataFrame( ... [("a", 1), ("a", 1), ("a", 1), ("a", 2), ("b", 3), ("c", 4)], ["C1", "C2"] @@ -1248,12 +1239,12 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": subset : List of column names, optional List of columns to use for duplicate comparison (default All columns). - Returns + Returns: ------- :class:`DataFrame` DataFrame without duplicates. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame( @@ -1297,12 +1288,12 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": def distinct(self) -> "DataFrame": """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. - Returns + Returns: ------- :class:`DataFrame` DataFrame with distinct records. - Examples + Examples: -------- >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (23, "Alice")], ["age", "name"]) @@ -1317,12 +1308,12 @@ def distinct(self) -> "DataFrame": def count(self) -> int: """Returns the number of rows in this :class:`DataFrame`. - Returns + Returns: ------- int Number of rows. - Examples + Examples: -------- >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) @@ -1377,16 +1368,16 @@ def cache(self) -> "DataFrame": .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0. - Returns + Returns: ------- :class:`DataFrame` Cached DataFrame. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.cache() diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index dfcf7e2e..501c9503 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, Union, overload, Optional, List, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload from duckdb import ( CaseExpression, @@ -17,14 +17,13 @@ from ..errors import PySparkTypeError from ..exception import ContributionsAcceptedError +from . import types as _types from ._typing import ColumnOrName from .column import Column, _get_expr -from . import types as _types def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: - """ - Invokes n-ary JVM function identified by name + """Invokes n-ary JVM function identified by name and wraps the result with :class:`~pyspark.sql.Column`. """ cols = [_to_column_expr(expr) for expr in cols] @@ -36,8 +35,7 @@ def col(column: str): def upper(col: "ColumnOrName") -> Column: - """ - Converts a string expression to upper case. + """Converts a string expression to upper case. .. versionadded:: 1.5.0 @@ -49,12 +47,12 @@ def upper(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` upper case values. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(upper("value")).show() @@ -70,8 +68,7 @@ def upper(col: "ColumnOrName") -> Column: def ucase(str: "ColumnOrName") -> Column: - """ - Returns `str` with all characters changed to uppercase. + """Returns `str` with all characters changed to uppercase. .. versionadded:: 3.5.0 @@ -80,7 +77,7 @@ def ucase(str: "ColumnOrName") -> Column: str : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.ucase(sf.lit("Spark"))).show() @@ -123,12 +120,12 @@ def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["Column column names or :class:`~pyspark.sql.Column`\\s that have the same data type. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of array type. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) >>> df.select(array("age", "age").alias("arr")).collect() @@ -170,7 +167,7 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum .. versionadded:: 1.5.0 - Examples + Examples: -------- >>> df = spark.createDataFrame([("100-200",)], ["str"]) >>> df.select(regexp_replace("str", r"(\d+)", "--").alias("d")).collect() @@ -186,8 +183,7 @@ def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Colum def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]) -> Column: - """ - Collection function: returns an array containing all the elements in `x` from index `start` + """Collection function: returns an array containing all the elements in `x` from index `start` (array indices start at 1, or from the end if `start` is negative) with the specified `length`. .. versionadded:: 2.4.0 @@ -204,12 +200,12 @@ def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["C length : :class:`~pyspark.sql.Column` or str or int column name, column, or int containing the length of the slice - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of array type. Subset of array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() @@ -224,8 +220,7 @@ def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["C def asc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given column name. + """Returns a sort expression based on the ascending order of the given column name. .. versionadded:: 1.3.0 @@ -237,12 +232,12 @@ def asc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -276,8 +271,7 @@ def asc(col: "ColumnOrName") -> Column: def asc_nulls_first(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given + """Returns a sort expression based on the ascending order of the given column name, and null values return before non-null values. .. versionadded:: 2.4.0 @@ -290,12 +284,12 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(1, "Bob"), (0, None), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_first(df1.name)).show() @@ -312,8 +306,7 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: def asc_nulls_last(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given + """Returns a sort expression based on the ascending order of the given column name, and null values appear after non-null values. .. versionadded:: 2.4.0 @@ -326,12 +319,12 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(asc_nulls_last(df1.name)).show() @@ -348,8 +341,7 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: def desc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given column name. + """Returns a sort expression based on the descending order of the given column name. .. versionadded:: 1.3.0 @@ -361,12 +353,12 @@ def desc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -385,8 +377,7 @@ def desc(col: "ColumnOrName") -> Column: def desc_nulls_first(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given + """Returns a sort expression based on the descending order of the given column name, and null values appear before non-null values. .. versionadded:: 2.4.0 @@ -399,12 +390,12 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_first(df1.name)).show() @@ -421,8 +412,7 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: def desc_nulls_last(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given + """Returns a sort expression based on the descending order of the given column name, and null values appear after non-null values. .. versionadded:: 2.4.0 @@ -435,12 +425,12 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(0, None), (1, "Bob"), (2, "Alice")], ["age", "name"]) >>> df1.sort(desc_nulls_last(df1.name)).show() @@ -457,8 +447,7 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: - """ - Returns the leftmost `len`(`len` can be string type) characters from the string `str`, + """Returns the leftmost `len`(`len` can be string type) characters from the string `str`, if `len` is less or equal than 0 the result is an empty string. .. versionadded:: 3.5.0 @@ -470,7 +459,7 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: len : :class:`~pyspark.sql.Column` or str Input column or strings, the leftmost `len`. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -493,8 +482,7 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: - """ - Returns the rightmost `len`(`len` can be string type) characters from the string `str`, + """Returns the rightmost `len`(`len` can be string type) characters from the string `str`, if `len` is less or equal than 0 the result is an empty string. .. versionadded:: 3.5.0 @@ -506,7 +494,7 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: len : :class:`~pyspark.sql.Column` or str Input column or strings, the rightmost `len`. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -549,12 +537,12 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional .. versionchanged: 3.5.0 Added ``threshold`` argument. - Returns + Returns: ------- :class:`~pyspark.sql.Column` Levenshtein distance as integer value. - Examples + Examples: -------- >>> df0 = spark.createDataFrame( ... [ @@ -581,8 +569,7 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: - """ - Left-pad the string column to width `len` with `pad`. + """Left-pad the string column to width `len` with `pad`. .. versionadded:: 1.5.0 @@ -598,12 +585,12 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: pad : str chars to prepend. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left padded result. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -618,8 +605,7 @@ def lpad(col: "ColumnOrName", len: int, pad: str) -> Column: def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: - """ - Right-pad the string column to width `len` with `pad`. + """Right-pad the string column to width `len` with `pad`. .. versionadded:: 1.5.0 @@ -635,12 +621,12 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: pad : str chars to append. - Returns + Returns: ------- :class:`~pyspark.sql.Column` right padded result. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -655,8 +641,7 @@ def rpad(col: "ColumnOrName", len: int, pad: str) -> Column: def ascii(col: "ColumnOrName") -> Column: - """ - Computes the numeric value of the first character of the string column. + """Computes the numeric value of the first character of the string column. .. versionadded:: 1.5.0 @@ -668,12 +653,12 @@ def ascii(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` numeric value. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(ascii("value")).show() @@ -689,8 +674,7 @@ def ascii(col: "ColumnOrName") -> Column: def asin(col: "ColumnOrName") -> Column: - """ - Computes inverse sine of the input column. + """Computes inverse sine of the input column. .. versionadded:: 1.4.0 @@ -702,12 +686,12 @@ def asin(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse sine of `col`, as if computed by `java.lang.Math.asin()` - Examples + Examples: -------- >>> df = spark.createDataFrame([(0,), (2,)]) >>> df.select(asin(df.schema.fieldNames()[0])).show() @@ -728,8 +712,7 @@ def asin(col: "ColumnOrName") -> Column: def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """ - Returns true if str matches `pattern` with `escape`, + """Returns true if str matches `pattern` with `escape`, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -746,7 +729,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col If an escape character precedes a special symbol or another escape character, the following character is matched literally. It is invalid to escape any other character. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) >>> df.select(like(df.a, df.b).alias("r")).collect() @@ -766,8 +749,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """ - Returns true if str matches `pattern` with `escape` case-insensitively, + """Returns true if str matches `pattern` with `escape` case-insensitively, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -784,7 +766,7 @@ def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Co If an escape character precedes a special symbol or another escape character, the following character is matched literally. It is invalid to escape any other character. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Spark", "_park")], ["a", "b"]) >>> df.select(ilike(df.a, df.b).alias("r")).collect() @@ -804,8 +786,7 @@ def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Co def array_agg(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns a list of objects with duplicates. + """Aggregate function: returns a list of objects with duplicates. .. versionadded:: 3.5.0 @@ -814,12 +795,12 @@ def array_agg(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` list of objects with duplicates. - Examples + Examples: -------- >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) >>> df.agg(array_agg("c").alias("r")).collect() @@ -829,15 +810,14 @@ def array_agg(col: "ColumnOrName") -> Column: def collect_list(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns a list of objects with duplicates. + """Aggregate function: returns a list of objects with duplicates. .. versionadded:: 1.6.0 .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because the order of collected results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -847,12 +827,12 @@ def collect_list(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` list of objects with duplicates. - Examples + Examples: -------- >>> df2 = spark.createDataFrame([(2,), (5,), (5,)], ("age",)) >>> df2.agg(collect_list("age")).collect() @@ -862,8 +842,7 @@ def collect_list(col: "ColumnOrName") -> Column: def array_append(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: returns an array of the elements in col1 along + """Collection function: returns an array of the elements in col1 along with the added element in col2 at the last of the array. .. versionadded:: 3.4.0 @@ -875,16 +854,16 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values from first array along with the element. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")]) @@ -897,8 +876,7 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column: - """ - Collection function: adds an item into a given array at a specified array index. + """Collection function: adds an item into a given array at a specified array index. Array indices start at 1, or start from the end if index is negative. Index above array size appends the array, or prepends the array if index is negative, with 'null' elements. @@ -915,16 +893,16 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values, including the new specified value - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(["a", "b", "c"], 2, "d"), (["c", "b", "a"], -2, "d")], ["data", "pos", "val"] @@ -991,8 +969,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An def array_contains(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: returns null if the array is null, true if the array contains the + """Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise. Parameters @@ -1002,12 +979,12 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: value : value or column to check for in array - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of Boolean type. - Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ["data"]) >>> df.select(array_contains(df.data, "a")).collect() @@ -1020,8 +997,7 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: def array_distinct(col: "ColumnOrName") -> Column: - """ - Collection function: removes duplicate values from the array. + """Collection function: removes duplicate values from the array. .. versionadded:: 2.4.0 @@ -1033,12 +1009,12 @@ def array_distinct(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of unique values. - Examples + Examples: -------- >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ["data"]) >>> df.select(array_distinct(df.data)).collect() @@ -1048,8 +1024,7 @@ def array_distinct(col: "ColumnOrName") -> Column: def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Collection function: returns an array of the elements in the intersection of col1 and col2, + """Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates. .. versionadded:: 2.4.0 @@ -1064,12 +1039,12 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str name of column containing array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values in the intersection of two arrays. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) @@ -1080,8 +1055,7 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Collection function: returns an array of the elements in the union of col1 and col2, + """Collection function: returns an array of the elements in the union of col1 and col2, without duplicates. .. versionadded:: 2.4.0 @@ -1096,12 +1070,12 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col2 : :class:`~pyspark.sql.Column` or str name of column containing array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of values in union of two arrays. - Examples + Examples: -------- >>> from pyspark.sql import Row >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) @@ -1112,8 +1086,7 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def array_max(col: "ColumnOrName") -> Column: - """ - Collection function: returns the maximum value of the array. + """Collection function: returns the maximum value of the array. .. versionadded:: 2.4.0 @@ -1125,12 +1098,12 @@ def array_max(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` maximum value of an array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) >>> df.select(array_max(df.data).alias("max")).collect() @@ -1142,8 +1115,7 @@ def array_max(col: "ColumnOrName") -> Column: def array_min(col: "ColumnOrName") -> Column: - """ - Collection function: returns the minimum value of the array. + """Collection function: returns the minimum value of the array. .. versionadded:: 2.4.0 @@ -1155,12 +1127,12 @@ def array_min(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` minimum value of array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ["data"]) >>> df.select(array_min(df.data).alias("min")).collect() @@ -1172,8 +1144,7 @@ def array_min(col: "ColumnOrName") -> Column: def avg(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the average of the values in a group. + """Aggregate function: returns the average of the values in a group. .. versionadded:: 1.3.0 @@ -1185,12 +1156,12 @@ def avg(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(avg(col("id"))).show() @@ -1204,8 +1175,7 @@ def avg(col: "ColumnOrName") -> Column: def sum(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the sum of all values in the expression. + """Aggregate function: returns the sum of all values in the expression. .. versionadded:: 1.3.0 @@ -1217,12 +1187,12 @@ def sum(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(sum(df["id"])).show() @@ -1236,8 +1206,7 @@ def sum(col: "ColumnOrName") -> Column: def max(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the maximum value of the expression in a group. + """Aggregate function: returns the maximum value of the expression in a group. .. versionadded:: 1.3.0 @@ -1249,12 +1218,12 @@ def max(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(max(col("id"))).show() @@ -1268,8 +1237,7 @@ def max(col: "ColumnOrName") -> Column: def mean(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the average of the values in a group. + """Aggregate function: returns the average of the values in a group. An alias of :func:`avg`. .. versionadded:: 1.4.0 @@ -1282,12 +1250,12 @@ def mean(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(mean(df.id)).show() @@ -1301,8 +1269,7 @@ def mean(col: "ColumnOrName") -> Column: def median(col: "ColumnOrName") -> Column: - """ - Returns the median of the values in a group. + """Returns the median of the values in a group. .. versionadded:: 3.4.0 @@ -1311,16 +1278,16 @@ def median(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the median of the values in a group. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -1345,8 +1312,7 @@ def median(col: "ColumnOrName") -> Column: def mode(col: "ColumnOrName") -> Column: - """ - Returns the most frequent value in a group. + """Returns the most frequent value in a group. .. versionadded:: 3.4.0 @@ -1355,16 +1321,16 @@ def mode(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the most frequent value in a group. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -1389,8 +1355,7 @@ def mode(col: "ColumnOrName") -> Column: def min(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the minimum value of the expression in a group. + """Aggregate function: returns the minimum value of the expression in a group. .. versionadded:: 1.3.0 @@ -1402,12 +1367,12 @@ def min(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(10) >>> df.select(min(df.id)).show() @@ -1432,12 +1397,12 @@ def any_value(col: "ColumnOrName") -> Column: ignorenulls : :class:`~pyspark.sql.Column` or bool if first value is null then look for first non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` some value of `col` for a group of rows. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["c1", "c2"] @@ -1451,8 +1416,7 @@ def any_value(col: "ColumnOrName") -> Column: def count(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the number of items in a group. + """Aggregate function: returns the number of items in a group. .. versionadded:: 1.3.0 @@ -1464,12 +1428,12 @@ def count(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- Count by all columns (start), and by a column that does not count ``None``. @@ -1500,12 +1464,12 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C maximum relative standard deviation allowed (default = 0.05). For rsd < 0.01, it is more efficient to use :func:`count_distinct` - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column of computed results. - Examples + Examples: -------- >>> df = spark.createDataFrame([1, 2, 2, 3], "INT") >>> df.agg(approx_count_distinct("value").alias("distinct_values")).show() @@ -1521,8 +1485,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: - """ - .. versionadded:: 1.3.0 + """.. versionadded:: 1.3.0 .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -1546,8 +1509,7 @@ def transform( col: "ColumnOrName", f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], ) -> Column: - """ - Returns an array of elements after applying a transformation to each element in the input array. + """Returns an array of elements after applying a transformation to each element in the input array. .. versionadded:: 3.1.0 @@ -1571,12 +1533,12 @@ def transform( Python ``UserDefinedFunctions`` are not supported (`SPARK-27052 `__). - Returns + Returns: ------- :class:`~pyspark.sql.Column` a new array of transformed elements. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values")) >>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show() @@ -1599,8 +1561,7 @@ def transform( def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": - """ - Concatenates multiple input string columns together into a single string column, + """Concatenates multiple input string columns together into a single string column, using the given separator. .. versionadded:: 1.5.0 @@ -1615,12 +1576,12 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": cols : :class:`~pyspark.sql.Column` or str list of columns to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string of concatenated words. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) >>> df.select(concat_ws("-", df.s, df.d).alias("s")).collect() @@ -1631,8 +1592,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": def lower(col: "ColumnOrName") -> Column: - """ - Converts a string expression to lower case. + """Converts a string expression to lower case. .. versionadded:: 1.5.0 @@ -1644,12 +1604,12 @@ def lower(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` lower case values. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(lower("value")).show() @@ -1665,8 +1625,7 @@ def lower(col: "ColumnOrName") -> Column: def lcase(str: "ColumnOrName") -> Column: - """ - Returns `str` with all characters changed to lowercase. + """Returns `str` with all characters changed to lowercase. .. versionadded:: 3.5.0 @@ -1675,7 +1634,7 @@ def lcase(str: "ColumnOrName") -> Column: str : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.lcase(sf.lit("Spark"))).show() @@ -1689,8 +1648,7 @@ def lcase(str: "ColumnOrName") -> Column: def ceil(col: "ColumnOrName") -> Column: - """ - Computes the ceiling of the given value. + """Computes the ceiling of the given value. .. versionadded:: 1.4.0 @@ -1702,12 +1660,12 @@ def ceil(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(ceil(lit(-0.1))).show() @@ -1725,8 +1683,7 @@ def ceiling(col: "ColumnOrName") -> Column: def floor(col: "ColumnOrName") -> Column: - """ - Computes the floor of the given value. + """Computes the floor of the given value. .. versionadded:: 1.4.0 @@ -1738,12 +1695,12 @@ def floor(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str column to find floor for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` nearest integer that is less than or equal to given value. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(floor(lit(2.5))).show() @@ -1757,8 +1714,7 @@ def floor(col: "ColumnOrName") -> Column: def abs(col: "ColumnOrName") -> Column: - """ - Computes the absolute value. + """Computes the absolute value. .. versionadded:: 1.3.0 @@ -1770,12 +1726,12 @@ def abs(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(abs(lit(-1))).show() @@ -1801,12 +1757,12 @@ def isnan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` True if value is NaN and False otherwise. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1.0, float("nan")), (float("nan"), 2.0)], ("a", "b")) >>> df.select("a", "b", isnan("a").alias("r1"), isnan(df.b).alias("r2")).show() @@ -1833,12 +1789,12 @@ def isnull(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` True if value is null and False otherwise. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b")) >>> df.select("a", "b", isnull("a").alias("r1"), isnull(df.b).alias("r2")).show() @@ -1853,8 +1809,7 @@ def isnull(col: "ColumnOrName") -> Column: def isnotnull(col: "ColumnOrName") -> Column: - """ - Returns true if `col` is not null, or false otherwise. + """Returns true if `col` is not null, or false otherwise. .. versionadded:: 3.5.0 @@ -1862,7 +1817,7 @@ def isnotnull(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) >>> df.select(isnotnull(df.e).alias("r")).collect() @@ -1872,15 +1827,15 @@ def isnotnull(col: "ColumnOrName") -> Column: def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns same result as the EQUAL(=) operator for non-null operands, + """Returns same result as the EQUAL(=) operator for non-null operands, but returns true if both are null, false if one of the them is null. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -1908,8 +1863,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def flatten(col: "ColumnOrName") -> Column: - """ - Collection function: creates a single array from an array of arrays. + """Collection function: creates a single array from an array of arrays. If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -1923,12 +1877,12 @@ def flatten(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` flattened array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ["data"]) >>> df.show(truncate=False) @@ -1952,8 +1906,7 @@ def flatten(col: "ColumnOrName") -> Column: def array_compact(col: "ColumnOrName") -> Column: - """ - Collection function: removes null values from the array. + """Collection function: removes null values from the array. .. versionadded:: 3.4.0 @@ -1962,16 +1915,16 @@ def array_compact(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array by excluding the null values. - Notes + Notes: ----- Supports Spark Connect. - Examples + Examples: -------- >>> df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ["data"]) >>> df.select(array_compact(df.data)).collect() @@ -1983,8 +1936,7 @@ def array_compact(col: "ColumnOrName") -> Column: def array_remove(col: "ColumnOrName", element: Any) -> Column: - """ - Collection function: Remove all elements that equal to element from the given array. + """Collection function: Remove all elements that equal to element from the given array. .. versionadded:: 2.4.0 @@ -1998,12 +1950,12 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: element : element to be removed from the array - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array excluding given value. - Examples + Examples: -------- >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ["data"]) >>> df.select(array_remove(df.data, 1)).collect() @@ -2015,8 +1967,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column: def last_day(date: "ColumnOrName") -> Column: - """ - Returns the last day of the month which the given date belongs to. + """Returns the last day of the month which the given date belongs to. .. versionadded:: 1.5.0 @@ -2028,12 +1979,12 @@ def last_day(date: "ColumnOrName") -> Column: date : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` last day of the month. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-10",)], ["d"]) >>> df.select(last_day(df.d).alias("date")).collect() @@ -2043,8 +1994,7 @@ def last_day(date: "ColumnOrName") -> Column: def sqrt(col: "ColumnOrName") -> Column: - """ - Computes the square root of the specified float value. + """Computes the square root of the specified float value. .. versionadded:: 1.3.0 @@ -2056,12 +2006,12 @@ def sqrt(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(sqrt(lit(4))).show() @@ -2075,8 +2025,7 @@ def sqrt(col: "ColumnOrName") -> Column: def cbrt(col: "ColumnOrName") -> Column: - """ - Computes the cube-root of the given value. + """Computes the cube-root of the given value. .. versionadded:: 1.4.0 @@ -2088,12 +2037,12 @@ def cbrt(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(cbrt(lit(27))).show() @@ -2107,8 +2056,7 @@ def cbrt(col: "ColumnOrName") -> Column: def char(col: "ColumnOrName") -> Column: - """ - Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the + """Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the result is equivalent to char(col % 256) .. versionadded:: 3.5.0 @@ -2118,7 +2066,7 @@ def char(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str Input column or strings. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.char(sf.lit(65))).show() @@ -2148,12 +2096,12 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate correlation. - Returns + Returns: ------- :class:`~pyspark.sql.Column` Pearson Correlation Coefficient of these two column values. - Examples + Examples: -------- >>> a = range(20) >>> b = [2 * x for x in range(20)] @@ -2165,8 +2113,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def cot(col: "ColumnOrName") -> Column: - """ - Computes cotangent of the input column. + """Computes cotangent of the input column. .. versionadded:: 3.3.0 @@ -2178,12 +2125,12 @@ def cot(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians. - Returns + Returns: ------- :class:`~pyspark.sql.Column` cotangent of the angle. - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -2198,7 +2145,7 @@ def e() -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.range(1).select(e()).show() +-----------------+ @@ -2211,18 +2158,19 @@ def e() -> Column: def negative(col: "ColumnOrName") -> Column: - """ - Returns the negative value. + """Returns the negative value. .. versionadded:: 3.5.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str column to calculate negative value for. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` negative value. - Examples + + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(3).select(sf.negative("id")).show() @@ -2242,7 +2190,7 @@ def pi() -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.range(1).select(pi()).show() +-----------------+ @@ -2255,8 +2203,7 @@ def pi() -> Column: def positive(col: "ColumnOrName") -> Column: - """ - Returns the value. + """Returns the value. .. versionadded:: 3.5.0 @@ -2265,12 +2212,12 @@ def positive(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str input value column. - Returns + Returns: ------- :class:`~pyspark.sql.Column` value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(-1,), (0,), (1,)], ["v"]) >>> df.select(positive("v").alias("p")).show() @@ -2286,8 +2233,7 @@ def positive(col: "ColumnOrName") -> Column: def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """ - Returns the value of the first argument raised to the power of the second argument. + """Returns the value of the first argument raised to the power of the second argument. .. versionadded:: 1.4.0 @@ -2301,12 +2247,12 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) col2 : str, :class:`~pyspark.sql.Column` or float the exponent number. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the base rased to the power the argument. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(pow(lit(3), lit(2))).first() @@ -2316,8 +2262,7 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: - """ - Formats the arguments in printf-style and returns the result as a string column. + """Formats the arguments in printf-style and returns the result as a string column. .. versionadded:: 3.5.0 @@ -2328,7 +2273,7 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or :class:`~pyspark.sql.Column`\\s to be used in formatting - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( @@ -2351,8 +2296,7 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: def product(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the product of the values in a group. + """Aggregate function: returns the product of the values in a group. .. versionadded:: 3.2.0 @@ -2364,12 +2308,12 @@ def product(col: "ColumnOrName") -> Column: col : str, :class:`Column` column containing values to be multiplied together - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.range(1, 10).toDF("x").withColumn("mod3", col("x") % 3) >>> prods = df.groupBy("mod3").agg(product("x").alias("product")) @@ -2394,7 +2338,7 @@ def rand(seed: Optional[int] = None) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic in general case. @@ -2403,12 +2347,12 @@ def rand(seed: Optional[int] = None) -> Column: seed : int (default: None) seed value for random generator. - Returns + Returns: ------- :class:`~pyspark.sql.Column` random values. - Examples + Examples: -------- >>> from pyspark.sql import functions as sf >>> spark.range(0, 2, 1, 1).withColumn("rand", sf.rand(seed=42) * 3).show() @@ -2437,12 +2381,12 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` true if `str` matches a Java regex, or false otherwise. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( @@ -2490,12 +2434,12 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the number of times that a Java regex pattern is matched in the string. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) >>> df.select(regexp_count("str", lit(r"\d+")).alias("d")).collect() @@ -2526,12 +2470,12 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: idx : int matched group id. - Returns + Returns: ------- :class:`~pyspark.sql.Column` matched value specified by `idx` group id. - Examples + Examples: -------- >>> df = spark.createDataFrame([("100-200",)], ["str"]) >>> df.select(regexp_extract("str", r"(\d+)-(\d+)", 1).alias("d")).collect() @@ -2563,12 +2507,12 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona idx : int matched group id. - Returns + Returns: ------- :class:`~pyspark.sql.Column` all strings in the `str` that match a Java regex and corresponding to the regex group index. - Examples + Examples: -------- >>> df = spark.createDataFrame([("100-200, 300-400", r"(\d+)-(\d+)")], ["str", "regexp"]) >>> df.select(regexp_extract_all("str", lit(r"(\d+)-(\d+)")).alias("d")).collect() @@ -2599,12 +2543,12 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` true if `str` matches a Java regex, or false otherwise. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]).select( @@ -2652,12 +2596,12 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp : :class:`~pyspark.sql.Column` or str regex pattern to apply. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the substring that matches a Java regex within the string `str`. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1a 2b 14m", r"\d+")], ["str", "regexp"]) >>> df.select(regexp_substr("str", lit(r"\d+")).alias("d")).collect() @@ -2677,8 +2621,7 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def repeat(col: "ColumnOrName", n: int) -> Column: - """ - Repeats a string column n times, and returns it as a new string column. + """Repeats a string column n times, and returns it as a new string column. .. versionadded:: 1.5.0 @@ -2692,12 +2635,12 @@ def repeat(col: "ColumnOrName", n: int) -> Column: n : int number of times to repeat value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string with repeated values. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("ab",)], @@ -2712,8 +2655,7 @@ def repeat(col: "ColumnOrName", n: int) -> Column: def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["ColumnOrName"] = None) -> Column: - """ - Generate a sequence of integers from `start` to `stop`, incrementing by `step`. + """Generate a sequence of integers from `start` to `stop`, incrementing by `step`. If `step` is not set, incrementing by 1 if `start` is less than or equal to `stop`, otherwise -1. @@ -2731,12 +2673,12 @@ def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["Column step : :class:`~pyspark.sql.Column` or str, optional value to add to current to get next element (default is 1) - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of sequence values - Examples + Examples: -------- >>> df1 = spark.createDataFrame([(-2, 2)], ("C1", "C2")) >>> df1.select(sequence("C1", "C2").alias("r")).collect() @@ -2752,8 +2694,7 @@ def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["Column def sign(col: "ColumnOrName") -> Column: - """ - Computes the signum of the given value. + """Computes the signum of the given value. .. versionadded:: 1.4.0 @@ -2765,12 +2706,12 @@ def sign(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.sign(sf.lit(-5)), sf.sign(sf.lit(6))).show() @@ -2784,8 +2725,7 @@ def sign(col: "ColumnOrName") -> Column: def signum(col: "ColumnOrName") -> Column: - """ - Computes the signum of the given value. + """Computes the signum of the given value. .. versionadded:: 1.4.0 @@ -2797,12 +2737,12 @@ def signum(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.signum(sf.lit(-5)), sf.signum(sf.lit(6))).show() @@ -2816,8 +2756,7 @@ def signum(col: "ColumnOrName") -> Column: def sin(col: "ColumnOrName") -> Column: - """ - Computes sine of the input column. + """Computes sine of the input column. .. versionadded:: 1.4.0 @@ -2829,12 +2768,12 @@ def sin(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sine of the angle, as if computed by `java.lang.Math.sin()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -2845,8 +2784,7 @@ def sin(col: "ColumnOrName") -> Column: def skewness(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the skewness of the values in a group. + """Aggregate function: returns the skewness of the values in a group. .. versionadded:: 1.6.0 @@ -2858,12 +2796,12 @@ def skewness(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` skewness of given column. - Examples + Examples: -------- >>> df = spark.createDataFrame([[1], [1], [2]], ["c"]) >>> df.select(skewness(df.c)).first() @@ -2873,8 +2811,7 @@ def skewness(col: "ColumnOrName") -> Column: def encode(col: "ColumnOrName", charset: str) -> Column: - """ - Computes the first argument into a binary from a string using the provided character set + """Computes the first argument into a binary from a string using the provided character set (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). .. versionadded:: 1.5.0 @@ -2889,12 +2826,12 @@ def encode(col: "ColumnOrName", charset: str) -> Column: charset : str charset to use to encode. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.createDataFrame([("abcd",)], ["c"]) >>> df.select(encode("c", "UTF-8")).show() @@ -2910,8 +2847,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: - """ - Returns the index (1-based) of the given string (`str`) in the comma-delimited + """Returns the index (1-based) of the given string (`str`) in the comma-delimited list (`strArray`). Returns 0, if the string was not found or if the given string (`str`) contains a comma. @@ -2924,7 +2860,7 @@ def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: str_array : :class:`~pyspark.sql.Column` or str The comma-delimited list. - Examples + Examples: -------- >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ["a", "b"]) >>> df.select(find_in_set(df.a, df.b).alias("r")).collect() @@ -2956,7 +2892,7 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -2968,12 +2904,12 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: ignorenulls : :class:`~pyspark.sql.Column` or str if first value is null then look for first non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` first value of the group. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) >>> df = df.orderBy(df.age) @@ -3011,7 +2947,7 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. @@ -3023,12 +2959,12 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: ignorenulls : :class:`~pyspark.sql.Column` or str if last value is null then look for non-null value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` last value of the group. - Examples + Examples: -------- >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) >>> df = df.orderBy(df.age.desc()) @@ -3056,8 +2992,7 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: def greatest(*cols: "ColumnOrName") -> Column: - """ - Returns the greatest value of the list of column names, skipping null values. + """Returns the greatest value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 @@ -3070,18 +3005,17 @@ def greatest(*cols: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str columns to check for gratest value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` gratest value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] """ - if len(cols) < 2: raise ValueError("greatest should take at least 2 columns") @@ -3090,8 +3024,7 @@ def greatest(*cols: "ColumnOrName") -> Column: def least(*cols: "ColumnOrName") -> Column: - """ - Returns the least value of the list of column names, skipping null values. + """Returns the least value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 @@ -3104,12 +3037,12 @@ def least(*cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or columns to be compared - Returns + Returns: ------- :class:`~pyspark.sql.Column` least value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() @@ -3123,8 +3056,7 @@ def least(*cols: "ColumnOrName") -> Column: def trim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from left end for the specified string value. + """Trim the spaces from left end for the specified string value. .. versionadded:: 1.5.0 @@ -3136,12 +3068,12 @@ def trim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3157,8 +3089,7 @@ def trim(col: "ColumnOrName") -> Column: def rtrim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from right end for the specified string value. + """Trim the spaces from right end for the specified string value. .. versionadded:: 1.5.0 @@ -3170,12 +3101,12 @@ def rtrim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` right trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3191,8 +3122,7 @@ def rtrim(col: "ColumnOrName") -> Column: def ltrim(col: "ColumnOrName") -> Column: - """ - Trim the spaces from left end for the specified string value. + """Trim the spaces from left end for the specified string value. .. versionadded:: 1.5.0 @@ -3204,12 +3134,12 @@ def ltrim(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` left trimmed values. - Examples + Examples: -------- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() @@ -3225,8 +3155,7 @@ def ltrim(col: "ColumnOrName") -> Column: def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: - """ - Remove the leading and trailing `trim` characters from `str`. + """Remove the leading and trailing `trim` characters from `str`. .. versionadded:: 3.5.0 @@ -3237,7 +3166,7 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: trim : :class:`~pyspark.sql.Column` or str The trim string characters to trim, the default value is a single space - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3262,8 +3191,7 @@ def btrim(str: "ColumnOrName", trim: Optional["ColumnOrName"] = None) -> Column: def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if str ends with suffix. + """Returns a boolean. The value is True if str ends with suffix. Returns NULL if either input expression is NULL. Otherwise, returns False. Both str or suffix must be of STRING or BINARY type. @@ -3276,7 +3204,7 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: suffix : :class:`~pyspark.sql.Column` or str A column of string, the suffix. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3315,8 +3243,7 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if str starts with prefix. + """Returns a boolean. The value is True if str starts with prefix. Returns NULL if either input expression is NULL. Otherwise, returns False. Both str or prefix must be of STRING or BINARY type. @@ -3329,7 +3256,7 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: prefix : :class:`~pyspark.sql.Column` or str A column of string, the prefix. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3382,12 +3309,12 @@ def length(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` length of the value. - Examples + Examples: -------- >>> spark.createDataFrame([("ABC ",)], ["a"]).select(length("a").alias("length")).collect() [Row(length=4)] @@ -3404,11 +3331,13 @@ def coalesce(*cols: "ColumnOrName") -> Column: ---------- cols : :class:`~pyspark.sql.Column` or str list of columns to work on. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` value of the first column that is not null. - Examples + + Examples: -------- >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) >>> cDf.show() @@ -3436,20 +3365,19 @@ def coalesce(*cols: "ColumnOrName") -> Column: |NULL| 2| 0.0| +----+----+----------------+ """ - cols = [_to_column_expr(expr) for expr in cols] return Column(CoalesceOperator(*cols)) def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is null, or `col1` otherwise. + """Returns `col2` if `col1` is null, or `col1` otherwise. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3467,13 +3395,11 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] """ - return coalesce(col1, col2) def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is not null, or `col3` otherwise. + """Returns `col2` if `col1` is not null, or `col3` otherwise. .. versionadded:: 3.5.0 @@ -3483,7 +3409,7 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co col2 : :class:`~pyspark.sql.Column` or str col3 : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3510,14 +3436,14 @@ def nvl2(col1: "ColumnOrName", col2: "ColumnOrName", col3: "ColumnOrName") -> Co def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns `col2` if `col1` is null, or `col1` otherwise. + """Returns `col2` if `col1` is null, or `col1` otherwise. .. versionadded:: 3.5.0 Parameters ---------- col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + + Examples: -------- >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) @@ -3533,8 +3459,7 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: - """ - Returns null if `col1` equals to `col2`, or `col1` otherwise. + """Returns null if `col1` equals to `col2`, or `col1` otherwise. .. versionadded:: 3.5.0 @@ -3543,7 +3468,7 @@ def nullif(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str col2 : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -3577,12 +3502,12 @@ def md5(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> spark.createDataFrame([("ABC",)], ["a"]).select(md5("a").alias("hash")).collect() [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] @@ -3608,12 +3533,12 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: the desired bit length of the result, which must have a value of 224, 256, 384, 512, or 0 (which is equivalent to 256). - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column for computed results. - Examples + Examples: -------- >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) >>> df.withColumn("sha2", sha2(df.name, 256)).show(truncate=False) @@ -3624,7 +3549,6 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: |Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961| +-----+----------------------------------------------------------------+ """ - if numBits not in {224, 256, 384, 512, 0}: raise ValueError("numBits should be one of {224, 256, 384, 512, 0}") @@ -3635,18 +3559,17 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: def curdate() -> Column: - """ - Returns the current date at the start of query evaluation as a :class:`DateType` column. + """Returns the current date at the start of query evaluation as a :class:`DateType` column. All calls of current_date within the same query return the same value. .. versionadded:: 3.5.0 - Returns + Returns: ------- :class:`~pyspark.sql.Column` current date. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP @@ -3660,8 +3583,7 @@ def curdate() -> Column: def current_date() -> Column: - """ - Returns the current date at the start of query evaluation as a :class:`DateType` column. + """Returns the current date at the start of query evaluation as a :class:`DateType` column. All calls of current_date within the same query return the same value. .. versionadded:: 1.5.0 @@ -3669,12 +3591,12 @@ def current_date() -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Returns + Returns: ------- :class:`~pyspark.sql.Column` current date. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(current_date()).show() # doctest: +SKIP @@ -3688,17 +3610,16 @@ def current_date() -> Column: def now() -> Column: - """ - Returns the current timestamp at the start of query evaluation. + """Returns the current timestamp at the start of query evaluation. .. versionadded:: 3.5.0 - Returns + Returns: ------- :class:`~pyspark.sql.Column` current timestamp at the start of query evaluation. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(now()).show(truncate=False) # doctest: +SKIP @@ -3712,8 +3633,7 @@ def now() -> Column: def desc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the descending order of the given column name. + """Returns a sort expression based on the descending order of the given column name. .. versionadded:: 1.3.0 @@ -3725,12 +3645,12 @@ def desc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the descending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -3749,8 +3669,7 @@ def desc(col: "ColumnOrName") -> Column: def asc(col: "ColumnOrName") -> Column: - """ - Returns a sort expression based on the ascending order of the given column name. + """Returns a sort expression based on the ascending order of the given column name. .. versionadded:: 1.3.0 @@ -3762,12 +3681,12 @@ def asc(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to sort by in the ascending order. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the column specifying the order. - Examples + Examples: -------- Sort by the column 'id' in the descending order. @@ -3801,8 +3720,7 @@ def asc(col: "ColumnOrName") -> Column: def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: - """ - Returns timestamp truncated to the unit specified by the format. + """Returns timestamp truncated to the unit specified by the format. .. versionadded:: 2.3.0 @@ -3813,7 +3731,7 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: 'day', 'dd', 'hour', 'minute', 'second', 'week', 'quarter' timestamp : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 05:02:11",)], ["t"]) >>> df.select(date_trunc("year", df.t).alias("year")).collect() @@ -3834,8 +3752,7 @@ def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3847,12 +3764,12 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -3870,8 +3787,7 @@ def date_part(field: "ColumnOrName", source: "ColumnOrName") -> Column: def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3882,12 +3798,12 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -3905,8 +3821,7 @@ def extract(field: "ColumnOrName", source: "ColumnOrName") -> Column: def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: - """ - Extracts a part of the date/timestamp or interval source. + """Extracts a part of the date/timestamp or interval source. .. versionadded:: 3.5.0 @@ -3918,12 +3833,12 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: source : :class:`~pyspark.sql.Column` or str a date/timestamp or interval column from where `field` should be extracted. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a part of the date/timestamp or interval source. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -3941,8 +3856,7 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column: def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: - """ - Returns the number of days from `start` to `end`. + """Returns the number of days from `start` to `end`. .. versionadded:: 3.5.0 @@ -3953,12 +3867,12 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: start : :class:`~pyspark.sql.Column` or column name from date column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` difference in days between two dates. - See Also + See Also: -------- :meth:`pyspark.sql.functions.dateadd` :meth:`pyspark.sql.functions.date_add` @@ -3966,7 +3880,7 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: :meth:`pyspark.sql.functions.datediff` :meth:`pyspark.sql.functions.timestamp_diff` - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("2015-04-08", "2015-05-10")], ["d1", "d2"]) @@ -3992,8 +3906,7 @@ def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column: def year(col: "ColumnOrName") -> Column: - """ - Extract the year of a given date/timestamp as integer. + """Extract the year of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4005,12 +3918,12 @@ def year(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` year part of the date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(year("dt").alias("year")).collect() @@ -4020,8 +3933,7 @@ def year(col: "ColumnOrName") -> Column: def quarter(col: "ColumnOrName") -> Column: - """ - Extract the quarter of a given date/timestamp as integer. + """Extract the quarter of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4033,12 +3945,12 @@ def quarter(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` quarter of the date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(quarter("dt").alias("quarter")).collect() @@ -4048,8 +3960,7 @@ def quarter(col: "ColumnOrName") -> Column: def month(col: "ColumnOrName") -> Column: - """ - Extract the month of a given date/timestamp as integer. + """Extract the month of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4061,12 +3972,12 @@ def month(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` month part of the date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(month("dt").alias("month")).collect() @@ -4076,8 +3987,7 @@ def month(col: "ColumnOrName") -> Column: def dayofweek(col: "ColumnOrName") -> Column: - """ - Extract the day of the week of a given date/timestamp as integer. + """Extract the day of the week of a given date/timestamp as integer. Ranges from 1 for a Sunday through to 7 for a Saturday .. versionadded:: 2.3.0 @@ -4090,12 +4000,12 @@ def dayofweek(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the week for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofweek("dt").alias("day")).collect() @@ -4105,8 +4015,7 @@ def dayofweek(col: "ColumnOrName") -> Column: def day(col: "ColumnOrName") -> Column: - """ - Extract the day of the month of a given date/timestamp as integer. + """Extract the day of the month of a given date/timestamp as integer. .. versionadded:: 3.5.0 @@ -4115,12 +4024,12 @@ def day(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the month for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(day("dt").alias("day")).collect() @@ -4130,8 +4039,7 @@ def day(col: "ColumnOrName") -> Column: def dayofmonth(col: "ColumnOrName") -> Column: - """ - Extract the day of the month of a given date/timestamp as integer. + """Extract the day of the month of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4143,12 +4051,12 @@ def dayofmonth(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the month for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofmonth("dt").alias("day")).collect() @@ -4158,8 +4066,7 @@ def dayofmonth(col: "ColumnOrName") -> Column: def dayofyear(col: "ColumnOrName") -> Column: - """ - Extract the day of the year of a given date/timestamp as integer. + """Extract the day of the year of a given date/timestamp as integer. .. versionadded:: 1.5.0 @@ -4171,12 +4078,12 @@ def dayofyear(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` day of the year for given date/timestamp as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofyear("dt").alias("day")).collect() @@ -4186,8 +4093,7 @@ def dayofyear(col: "ColumnOrName") -> Column: def hour(col: "ColumnOrName") -> Column: - """ - Extract the hours of a given timestamp as integer. + """Extract the hours of a given timestamp as integer. .. versionadded:: 1.5.0 @@ -4199,12 +4105,12 @@ def hour(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` hour part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -4215,8 +4121,7 @@ def hour(col: "ColumnOrName") -> Column: def minute(col: "ColumnOrName") -> Column: - """ - Extract the minutes of a given timestamp as integer. + """Extract the minutes of a given timestamp as integer. .. versionadded:: 1.5.0 @@ -4228,12 +4133,12 @@ def minute(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` minutes part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -4244,8 +4149,7 @@ def minute(col: "ColumnOrName") -> Column: def second(col: "ColumnOrName") -> Column: - """ - Extract the seconds of a given date as integer. + """Extract the seconds of a given date as integer. .. versionadded:: 1.5.0 @@ -4257,12 +4161,12 @@ def second(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` `seconds` part of the timestamp as integer. - Examples + Examples: -------- >>> import datetime >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ["ts"]) @@ -4273,8 +4177,7 @@ def second(col: "ColumnOrName") -> Column: def weekofyear(col: "ColumnOrName") -> Column: - """ - Extract the week number of a given date as integer. + """Extract the week number of a given date as integer. A week is considered to start on a Monday and week 1 is the first week with more than 3 days, as defined by ISO 8601 @@ -4288,12 +4191,12 @@ def weekofyear(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` `week` of the year for given date as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(weekofyear(df.dt).alias("week")).collect() @@ -4303,8 +4206,7 @@ def weekofyear(col: "ColumnOrName") -> Column: def cos(col: "ColumnOrName") -> Column: - """ - Computes cosine of the input column. + """Computes cosine of the input column. .. versionadded:: 1.4.0 @@ -4316,12 +4218,12 @@ def cos(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` cosine of the angle, as if computed by `java.lang.Math.cos()`. - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -4332,8 +4234,7 @@ def cos(col: "ColumnOrName") -> Column: def acos(col: "ColumnOrName") -> Column: - """ - Computes inverse cosine of the input column. + """Computes inverse cosine of the input column. .. versionadded:: 1.4.0 @@ -4345,12 +4246,12 @@ def acos(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse cosine of `col`, as if computed by `java.lang.Math.acos()` - Examples + Examples: -------- >>> df = spark.range(1, 3) >>> df.select(acos(df.id)).show() @@ -4365,8 +4266,7 @@ def acos(col: "ColumnOrName") -> Column: def call_function(funcName: str, *cols: "ColumnOrName") -> Column: - """ - Call a SQL function. + """Call a SQL function. .. versionadded:: 3.5.0 @@ -4377,12 +4277,12 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or :class:`~pyspark.sql.Column`\\s to be used in the function - Returns + Returns: ------- :class:`~pyspark.sql.Column` result of executed function. - Examples + Examples: -------- >>> from pyspark.sql.functions import call_udf, col >>> from pyspark.sql.types import IntegerType, StringType @@ -4447,12 +4347,12 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate covariance. - Returns + Returns: ------- :class:`~pyspark.sql.Column` covariance of these two column values. - Examples + Examples: -------- >>> a = [1] * 10 >>> b = [1] * 10 @@ -4479,12 +4379,12 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: col1 : :class:`~pyspark.sql.Column` or str second column to calculate covariance. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sample covariance of these two column values. - Examples + Examples: -------- >>> a = [1] * 10 >>> b = [1] * 10 @@ -4496,8 +4396,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: def exp(col: "ColumnOrName") -> Column: - """ - Computes the exponential of the given value. + """Computes the exponential of the given value. .. versionadded:: 1.4.0 @@ -4509,12 +4408,12 @@ def exp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str column to calculate exponential for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` exponential of the given value. - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(exp(lit(0))).show() @@ -4528,8 +4427,7 @@ def exp(col: "ColumnOrName") -> Column: def factorial(col: "ColumnOrName") -> Column: - """ - Computes the factorial of the given value. + """Computes the factorial of the given value. .. versionadded:: 1.5.0 @@ -4541,12 +4439,12 @@ def factorial(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str a column to calculate factorial for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` factorial of given value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(5,)], ["n"]) >>> df.select(factorial(df.n).alias("f")).collect() @@ -4568,12 +4466,12 @@ def log2(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str a column to calculate logariphm for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` logariphm of given value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(4,)], ["a"]) >>> df.select(log2("a").alias("log2")).show() @@ -4596,12 +4494,12 @@ def ln(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str a column to calculate logariphm for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` natural logarithm of given value. - Examples + Examples: -------- >>> df = spark.createDataFrame([(4,)], ["a"]) >>> df.select(ln("a")).show() @@ -4615,8 +4513,7 @@ def ln(col: "ColumnOrName") -> Column: def degrees(col: "ColumnOrName") -> Column: - """ - Converts an angle measured in radians to an approximately equivalent angle + """Converts an angle measured in radians to an approximately equivalent angle measured in degrees. .. versionadded:: 2.1.0 @@ -4629,12 +4526,12 @@ def degrees(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` angle in degrees, as if computed by `java.lang.Math.toDegrees()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -4645,8 +4542,7 @@ def degrees(col: "ColumnOrName") -> Column: def radians(col: "ColumnOrName") -> Column: - """ - Converts an angle measured in degrees to an approximately equivalent angle + """Converts an angle measured in degrees to an approximately equivalent angle measured in radians. .. versionadded:: 2.1.0 @@ -4659,12 +4555,12 @@ def radians(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in degrees - Returns + Returns: ------- :class:`~pyspark.sql.Column` angle in radians, as if computed by `java.lang.Math.toRadians()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(radians(lit(180))).first() @@ -4674,8 +4570,7 @@ def radians(col: "ColumnOrName") -> Column: def atan(col: "ColumnOrName") -> Column: - """ - Compute inverse tangent of the input column. + """Compute inverse tangent of the input column. .. versionadded:: 1.4.0 @@ -4687,12 +4582,12 @@ def atan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` inverse tangent of `col`, as if computed by `java.lang.Math.atan()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(atan(df.id)).show() @@ -4706,8 +4601,7 @@ def atan(col: "ColumnOrName") -> Column: def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """ - .. versionadded:: 1.4.0 + """.. versionadded:: 1.4.0 .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -4719,7 +4613,7 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] col2 : str, :class:`~pyspark.sql.Column` or float coordinate on x-axis - Returns + Returns: ------- :class:`~pyspark.sql.Column` the `theta` component of the point @@ -4728,7 +4622,7 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] (`x`, `y`) in Cartesian coordinates, as if computed by `java.lang.Math.atan2()` - Examples + Examples: -------- >>> df = spark.range(1) >>> df.select(atan2(lit(1), lit(2))).first() @@ -4744,8 +4638,7 @@ def lit_or_column(x: Union["ColumnOrName", float]) -> Column: def tan(col: "ColumnOrName") -> Column: - """ - Computes tangent of the input column. + """Computes tangent of the input column. .. versionadded:: 1.4.0 @@ -4757,12 +4650,12 @@ def tan(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str angle in radians - Returns + Returns: ------- :class:`~pyspark.sql.Column` tangent of the given value, as if computed by `java.lang.Math.tan()` - Examples + Examples: -------- >>> import math >>> df = spark.range(1) @@ -4773,8 +4666,7 @@ def tan(col: "ColumnOrName") -> Column: def round(col: "ColumnOrName", scale: int = 0) -> Column: - """ - Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 + """Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 or at integral part when `scale` < 0. .. versionadded:: 1.5.0 @@ -4789,12 +4681,12 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: scale : int optional default 0 scale value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` rounded values. - Examples + Examples: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(round("a", 0).alias("r")).collect() [Row(r=3.0)] @@ -4803,8 +4695,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: def bround(col: "ColumnOrName", scale: int = 0) -> Column: - """ - Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 + """Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 or at integral part when `scale` < 0. .. versionadded:: 2.0.0 @@ -4819,12 +4710,12 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: scale : int optional default 0 scale value. - Returns + Returns: ------- :class:`~pyspark.sql.Column` rounded values. - Examples + Examples: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(bround("a", 0).alias("r")).collect() [Row(r=2.0)] @@ -4833,8 +4724,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: - """ - Collection function: Returns element of array at given (0-based) index. + """Collection function: Returns element of array at given (0-based) index. If the index points outside of the array boundaries, then this function returns NULL. @@ -4847,21 +4737,21 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: index : :class:`~pyspark.sql.Column` or str or int index to check for in array - Returns + Returns: ------- :class:`~pyspark.sql.Column` value at given position. - Notes + Notes: ----- The position is not 1 based, but 0 based index. Supports Spark Connect. - See Also + See Also: -------- :meth:`element_at` - Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ["data", "index"]) >>> df.select(get(df.data, 1)).show() @@ -4919,12 +4809,12 @@ def initcap(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string with all first letters are uppercase in each word. - Examples + Examples: -------- >>> spark.createDataFrame([("ab cd",)], ["a"]).select(initcap("a").alias("v")).collect() [Row(v='Ab Cd')] @@ -4953,8 +4843,7 @@ def initcap(col: "ColumnOrName") -> Column: def octet_length(col: "ColumnOrName") -> Column: - """ - Calculates the byte length for the specified string column. + """Calculates the byte length for the specified string column. .. versionadded:: 3.3.0 @@ -4966,12 +4855,12 @@ def octet_length(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str Source column or strings - Returns + Returns: ------- :class:`~pyspark.sql.Column` Byte length of the col - Examples + Examples: -------- >>> from pyspark.sql.functions import octet_length >>> spark.createDataFrame([('cat',), ( '\U0001f408',)], ['cat']) \\ @@ -4982,8 +4871,7 @@ def octet_length(col: "ColumnOrName") -> Column: def hex(col: "ColumnOrName") -> Column: - """ - Computes hex value of the given column, which could be :class:`~pyspark.sql.types.StringType`, :class:`~pyspark.sql.types.BinaryType`, :class:`~pyspark.sql.types.IntegerType` or :class:`~pyspark.sql.types.LongType`. + """Computes hex value of the given column, which could be :class:`~pyspark.sql.types.StringType`, :class:`~pyspark.sql.types.BinaryType`, :class:`~pyspark.sql.types.IntegerType` or :class:`~pyspark.sql.types.LongType`. .. versionadded:: 1.5.0 @@ -4995,12 +4883,12 @@ def hex(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` hexadecimal representation of given value as string. - Examples + Examples: -------- >>> spark.createDataFrame([("ABC", 3)], ["a", "b"]).select(hex("a"), hex("b")).collect() [Row(hex(a)='414243', hex(b)='3')] @@ -5009,8 +4897,7 @@ def hex(col: "ColumnOrName") -> Column: def unhex(col: "ColumnOrName") -> Column: - """ - Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. column and returns it as a binary column. + """Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. column and returns it as a binary column. .. versionadded:: 1.5.0 @@ -5022,12 +4909,12 @@ def unhex(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` string representation of given hexadecimal value. - Examples + Examples: -------- >>> spark.createDataFrame([("414243",)], ["a"]).select(unhex("a")).collect() [Row(unhex(a)=bytearray(b'ABC'))] @@ -5036,8 +4923,7 @@ def unhex(col: "ColumnOrName") -> Column: def base64(col: "ColumnOrName") -> Column: - """ - Computes the BASE64 encoding of a binary column and returns it as a string column. + """Computes the BASE64 encoding of a binary column and returns it as a string column. .. versionadded:: 1.5.0 @@ -5049,12 +4935,12 @@ def base64(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` BASE64 encoding of string value. - Examples + Examples: -------- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") >>> df.select(base64("value")).show() @@ -5072,8 +4958,7 @@ def base64(col: "ColumnOrName") -> Column: def unbase64(col: "ColumnOrName") -> Column: - """ - Decodes a BASE64 encoded string column and returns it as a binary column. + """Decodes a BASE64 encoded string column and returns it as a binary column. .. versionadded:: 1.5.0 @@ -5085,12 +4970,12 @@ def unbase64(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` encoded string value. - Examples + Examples: -------- >>> df = spark.createDataFrame(["U3Bhcms=", "UHlTcGFyaw==", "UGFuZGFzIEFQSQ=="], "STRING") >>> df.select(unbase64("value")).show() @@ -5106,8 +4991,7 @@ def unbase64(col: "ColumnOrName") -> Column: def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column: - """ - Returns the date that is `months` months after `start`. If `months` is a negative value + """Returns the date that is `months` months after `start`. If `months` is a negative value then these amount of months will be deducted from the `start`. .. versionadded:: 1.5.0 @@ -5123,12 +5007,12 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col how many months after the given date to calculate. Accepts negative value as well to calculate backwards. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a date after/before given number of months. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08", 2)], ["dt", "add"]) >>> df.select(add_months(df.dt, 1).alias("next_month")).collect() @@ -5143,8 +5027,7 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None) -> Column: - """ - Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + """Concatenates the elements of `column` using the `delimiter`. Null values are replaced with `null_replacement` if set, otherwise they are ignored. .. versionadded:: 2.4.0 @@ -5161,12 +5044,12 @@ def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[s null_replacement : str, optional if set then null values will be replaced by this value - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of string type. Concatenated values. - Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ["data"]) >>> df.select(array_join(df.data, ",").alias("joined")).collect() @@ -5190,8 +5073,7 @@ def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[s def array_position(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: Locates the position of the first occurrence of the given value + """Collection function: Locates the position of the first occurrence of the given value in the given array. Returns null if either of the arguments are null. .. versionadded:: 2.4.0 @@ -5199,7 +5081,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The position is not zero based, but 1 based index. Returns 0 if the given value could not be found in the array. @@ -5211,12 +5093,12 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: value : Any value to look for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` position of the value in the given array if found and 0 otherwise. - Examples + Examples: -------- >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) >>> df.select(array_position(df.data, "a")).collect() @@ -5230,8 +5112,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: def array_prepend(col: "ColumnOrName", value: Any) -> Column: - """ - Collection function: Returns an array containing element as + """Collection function: Returns an array containing element as well as all elements from array. The new element is positioned at the beginning of the array. @@ -5244,12 +5125,12 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: value : a literal value, or a :class:`~pyspark.sql.Column` expression. - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array excluding given value. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) >>> df.select(array_prepend(df.data, 1)).collect() @@ -5259,8 +5140,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: - """ - Collection function: creates an array containing a column repeated count times. + """Collection function: creates an array containing a column repeated count times. .. versionadded:: 2.4.0 @@ -5274,12 +5154,12 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu count : :class:`~pyspark.sql.Column` or str or int column name, column, or int containing the number of times to repeat the first argument - Returns + Returns: ------- :class:`~pyspark.sql.Column` an array of repeated elements. - Examples + Examples: -------- >>> df = spark.createDataFrame([("ab",)], ["data"]) >>> df.select(array_repeat(df.data, 3).alias("r")).collect() @@ -5291,8 +5171,7 @@ def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Colu def array_size(col: "ColumnOrName") -> Column: - """ - Returns the total number of elements in the array. The function returns null for null input. + """Returns the total number of elements in the array. The function returns null for null input. .. versionadded:: 3.5.0 @@ -5301,12 +5180,12 @@ def array_size(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` total number of elements in the array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, 3],), (None,)], ["data"]) >>> df.select(array_size(df.data).alias("r")).collect() @@ -5316,8 +5195,7 @@ def array_size(col: "ColumnOrName") -> Column: def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None) -> Column: - """ - Collection function: sorts the input array in ascending order. The elements of the input array + """Collection function: sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. .. versionadded:: 2.4.0 @@ -5339,12 +5217,12 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum positive integer as the first element is less than, equal to, or greater than the second element. If the comparator function returns null, the function will fail and raise an error. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sorted array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) >>> df.select(array_sort(df.data).alias("r")).collect() @@ -5369,8 +5247,7 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: - """ - Collection function: sorts the input array in ascending or descending order according + """Collection function: sorts the input array in ascending or descending order according to the natural ordering of the array elements. Null elements will be placed at the beginning of the returned array in ascending order or at the end of the returned array in descending order. @@ -5388,12 +5265,12 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: whether to sort in ascending or descending order. If `asc` is True (default) then ascending and if False then descending. - Returns + Returns: ------- :class:`~pyspark.sql.Column` sorted array. - Examples + Examples: -------- >>> df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ["data"]) >>> df.select(sort_array(df.data).alias("r")).collect() @@ -5411,8 +5288,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: - """ - Splits str around matches of the given pattern. + """Splits str around matches of the given pattern. .. versionadded:: 1.5.0 @@ -5438,12 +5314,12 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: .. versionchanged:: 3.0 `split` now takes an optional `limit` field. If not provided, default limit value is -1. - Returns + Returns: ------- :class:`~pyspark.sql.Column` array of separated strings. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("oneAtwoBthreeC",)], @@ -5464,8 +5340,7 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnOrName") -> Column: - """ - Splits `str` by delimiter and return requested part of the split (1-based). + """Splits `str` by delimiter and return requested part of the split (1-based). If any input is null, returns null. if `partNum` is out of range of split parts, returns empty string. If `partNum` is 0, throws an error. If `partNum` is negative, the parts are counted backward from the end of the string. @@ -5482,7 +5357,7 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO partNum : :class:`~pyspark.sql.Column` or str A column of string, requested part of the split (1-based). - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [ @@ -5510,8 +5385,7 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO def stddev_samp(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the unbiased sample standard deviation of + """Aggregate function: returns the unbiased sample standard deviation of the expression in a group. .. versionadded:: 1.6.0 @@ -5524,12 +5398,12 @@ def stddev_samp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev_samp("id")).show() @@ -5543,8 +5417,7 @@ def stddev_samp(col: "ColumnOrName") -> Column: def stddev(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for stddev_samp. + """Aggregate function: alias for stddev_samp. .. versionadded:: 1.6.0 @@ -5556,12 +5429,12 @@ def stddev(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev("id")).show() @@ -5575,8 +5448,7 @@ def stddev(col: "ColumnOrName") -> Column: def std(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for stddev_samp. + """Aggregate function: alias for stddev_samp. .. versionadded:: 3.5.0 @@ -5585,12 +5457,12 @@ def std(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.std("id")).show() @@ -5604,8 +5476,7 @@ def std(col: "ColumnOrName") -> Column: def stddev_pop(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns population standard deviation of + """Aggregate function: returns population standard deviation of the expression in a group. .. versionadded:: 1.6.0 @@ -5618,12 +5489,12 @@ def stddev_pop(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` standard deviation of given column. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.range(6).select(sf.stddev_pop("id")).show() @@ -5637,8 +5508,7 @@ def stddev_pop(col: "ColumnOrName") -> Column: def var_pop(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the population variance of the values in a group. + """Aggregate function: returns the population variance of the values in a group. .. versionadded:: 1.6.0 @@ -5650,12 +5520,12 @@ def var_pop(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(var_pop(df.id)).first() @@ -5665,8 +5535,7 @@ def var_pop(col: "ColumnOrName") -> Column: def var_samp(col: "ColumnOrName") -> Column: - """ - Aggregate function: returns the unbiased sample variance of + """Aggregate function: returns the unbiased sample variance of the values in a group. .. versionadded:: 1.6.0 @@ -5679,12 +5548,12 @@ def var_samp(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(var_samp(df.id)).show() @@ -5698,8 +5567,7 @@ def var_samp(col: "ColumnOrName") -> Column: def variance(col: "ColumnOrName") -> Column: - """ - Aggregate function: alias for var_samp + """Aggregate function: alias for var_samp .. versionadded:: 1.6.0 @@ -5711,12 +5579,12 @@ def variance(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target column to compute on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` variance of given column. - Examples + Examples: -------- >>> df = spark.range(6) >>> df.select(variance(df.id)).show() @@ -5730,8 +5598,7 @@ def variance(col: "ColumnOrName") -> Column: def weekday(col: "ColumnOrName") -> Column: - """ - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). + """Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). .. versionadded:: 3.5.0 @@ -5740,12 +5607,12 @@ def weekday(col: "ColumnOrName") -> Column: col : :class:`~pyspark.sql.Column` or str target date/timestamp column to work on. - Returns + Returns: ------- :class:`~pyspark.sql.Column` the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday). - Examples + Examples: -------- >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(weekday("dt").alias("day")).show() @@ -5759,8 +5626,7 @@ def weekday(col: "ColumnOrName") -> Column: def zeroifnull(col: "ColumnOrName") -> Column: - """ - Returns zero if `col` is null, or `col` otherwise. + """Returns zero if `col` is null, or `col` otherwise. .. versionadded:: 4.0.0 @@ -5768,7 +5634,7 @@ def zeroifnull(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str - Examples + Examples: -------- >>> df = spark.createDataFrame([(None,), (1,)], ["a"]) >>> df.select(zeroifnull(df.a).alias("result")).show() @@ -5811,12 +5677,12 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: format: str, optional format to use to convert date values. - Returns + Returns: ------- :class:`~pyspark.sql.Column` date value as :class:`pyspark.sql.types.DateType` type. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_date(df.t).alias("date")).collect() @@ -5849,12 +5715,12 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: format: str, optional format to use to convert timestamp values. - Returns + Returns: ------- :class:`~pyspark.sql.Column` timestamp value as :class:`pyspark.sql.types.TimestampType` type. - Examples + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_timestamp(df.t).alias("dt")).collect() @@ -5871,8 +5737,7 @@ def to_timestamp_ltz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, ) -> Column: - """ - Parses the `timestamp` with the `format` to a timestamp without time zone. + """Parses the `timestamp` with the `format` to a timestamp without time zone. Returns null with invalid input. .. versionadded:: 3.5.0 @@ -5884,7 +5749,7 @@ def to_timestamp_ltz( format : :class:`~pyspark.sql.Column` or str, optional format to use to convert type `TimestampType` timestamp values. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2016-12-31",)], ["e"]) >>> df.select(to_timestamp_ltz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() @@ -5903,8 +5768,7 @@ def to_timestamp_ntz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, ) -> Column: - """ - Parses the `timestamp` with the `format` to a timestamp without time zone. + """Parses the `timestamp` with the `format` to a timestamp without time zone. Returns null with invalid input. .. versionadded:: 3.5.0 @@ -5916,7 +5780,7 @@ def to_timestamp_ntz( format : :class:`~pyspark.sql.Column` or str, optional format to use to convert type `TimestampNTZType` timestamp values. - Examples + Examples: -------- >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) >>> df.select(to_timestamp_ntz(df.e, lit("yyyy-MM-dd")).alias("r")).collect() @@ -5932,8 +5796,7 @@ def to_timestamp_ntz( def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column: - """ - Parses the `col` with the `format` to a timestamp. The function always + """Parses the `col` with the `format` to a timestamp. The function always returns null on an invalid input with/without ANSI SQL mode enabled. The result data type is consistent with the value of configuration `spark.sql.timestampType`. .. versionadded:: 3.5.0 @@ -5943,7 +5806,8 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non column values to convert. format: str, optional format to use to convert timestamp values. - Examples + + Examples: -------- >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(try_to_timestamp(df.t).alias("dt")).collect() @@ -5958,8 +5822,7 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None) -> Column: - """ - Returns the substring of `str` that starts at `pos` and is of length `len`, + """Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of byte array that starts at `pos` and is of length `len`. .. versionadded:: 3.5.0 @@ -5973,7 +5836,7 @@ def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName len : :class:`~pyspark.sql.Column` or str, optional A column of string, the substring of `str` is of length `len`. - Examples + Examples: -------- >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( @@ -6026,7 +5889,7 @@ def unix_date(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("1970-01-02",)], ["t"]) @@ -6042,7 +5905,7 @@ def unix_micros(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) @@ -6059,7 +5922,7 @@ def unix_millis(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) @@ -6076,7 +5939,7 @@ def unix_seconds(col: "ColumnOrName") -> Column: .. versionadded:: 3.5.0 - Examples + Examples: -------- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2015-07-22 10:00:00",)], ["t"]) @@ -6088,8 +5951,7 @@ def unix_seconds(col: "ColumnOrName") -> Column: def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: - """ - Collection function: returns true if the arrays contain any common non-null element; if not, + """Collection function: returns true if the arrays contain any common non-null element; if not, returns null if both the arrays are non-empty and any of them contains a null element; returns false otherwise. @@ -6098,12 +5960,12 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Returns + Returns: ------- :class:`~pyspark.sql.Column` a column of Boolean type. - Examples + Examples: -------- >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ["x", "y"]) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() @@ -6134,8 +5996,7 @@ def _list_contains_null(c: ColumnExpression) -> Expression: def arrays_zip(*cols: "ColumnOrName") -> Column: - """ - Collection function: Returns a merged array of structs in which the N-th struct contains all + """Collection function: Returns a merged array of structs in which the N-th struct contains all N-th values of input arrays. If one of the arrays is shorter than others then resulting struct type value will be a `null` for missing elements. @@ -6149,12 +6010,12 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str columns of arrays to be merged. - Returns + Returns: ------- :class:`~pyspark.sql.Column` merged array of entries. - Examples + Examples: -------- >>> from pyspark.sql.functions import arrays_zip >>> df = spark.createDataFrame( @@ -6179,14 +6040,14 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: def substring(str: "ColumnOrName", pos: int, len: int) -> Column: - """ - Substring starts at `pos` and is of length `len` when str is String type or + """Substring starts at `pos` and is of length `len` when str is String type or returns the slice of byte array that starts at `pos` in byte and is of length `len` when str is Binary type. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + + Notes: ----- The position is not zero based, but 1 based index. Parameters @@ -6197,11 +6058,13 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: starting position in str. len : int length of chars. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` substring of given value. - Examples + + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -6221,8 +6084,7 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: - """ - Returns a boolean. The value is True if right is found inside left. + """Returns a boolean. The value is True if right is found inside left. Returns NULL if either input expression is NULL. Otherwise, returns False. Both left or right must be of STRING or BINARY type. .. versionadded:: 3.5.0 @@ -6232,7 +6094,8 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: The input column or strings to check, may be NULL. right : :class:`~pyspark.sql.Column` or str The input column or strings to find, may be NULL. - Examples + + Examples: -------- >>> df = spark.createDataFrame([("Spark SQL", "Spark")], ["a", "b"]) >>> df.select(contains(df.a, df.b).alias("r")).collect() @@ -6262,8 +6125,7 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: def reverse(col: "ColumnOrName") -> Column: - """ - Collection function: returns a reversed string or an array with reverse order of elements. + """Collection function: returns a reversed string or an array with reverse order of elements. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -6271,11 +6133,13 @@ def reverse(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str name of column or expression - Returns + + Returns: ------- :class:`~pyspark.sql.Column` array of elements in reverse order. - Examples + + Examples: -------- >>> df = spark.createDataFrame([("Spark SQL",)], ["data"]) >>> df.select(reverse(df.data).alias("s")).collect() @@ -6288,8 +6152,7 @@ def reverse(col: "ColumnOrName") -> Column: def concat(*cols: "ColumnOrName") -> Column: - """ - Concatenates multiple input columns together into a single column. + """Concatenates multiple input columns together into a single column. The function works with strings, numeric, binary and compatible array columns. .. versionadded:: 1.5.0 .. versionchanged:: 3.4.0 @@ -6298,14 +6161,17 @@ def concat(*cols: "ColumnOrName") -> Column: ---------- cols : :class:`~pyspark.sql.Column` or str target column or columns to work on. - Returns + + Returns: ------- :class:`~pyspark.sql.Column` concatenated values. Type of the `Column` depends on input columns' type. - See Also + + See Also: -------- :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter - Examples + + Examples: -------- >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) >>> df = df.select(concat(df.s, df.d).alias("s")) @@ -6326,8 +6192,7 @@ def concat(*cols: "ColumnOrName") -> Column: def instr(str: "ColumnOrName", substr: str) -> Column: - """ - Locate the position of the first occurrence of substr column in the given string. + """Locate the position of the first occurrence of substr column in the given string. Returns null if either of the arguments are null. .. versionadded:: 1.5.0 @@ -6335,7 +6200,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: .. versionchanged:: 3.4.0 Supports Spark Connect. - Notes + Notes: ----- The position is not zero based, but 1 based index. Returns 0 if substr could not be found in str. @@ -6347,12 +6212,12 @@ def instr(str: "ColumnOrName", substr: str) -> Column: substr : str substring to look for. - Returns + Returns: ------- :class:`~pyspark.sql.Column` location of the first occurrence of the substring as integer. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [("abcd",)], @@ -6379,12 +6244,12 @@ def expr(str: str) -> Column: str : str expression defined in string. - Returns + Returns: ------- :class:`~pyspark.sql.Column` column representing the expression. - Examples + Examples: -------- >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) >>> df.select("name", expr("length(name)")).show() @@ -6399,8 +6264,7 @@ def expr(str: str) -> Column: def broadcast(df: "DataFrame") -> "DataFrame": - """ - The broadcast function in Spark is used to optimize joins by broadcasting a smaller + """The broadcast function in Spark is used to optimize joins by broadcasting a smaller dataset to all the worker nodes. However, DuckDB operates on a single-node architecture . As a result, the function simply returns the input DataFrame without applying any modifications or optimizations, since broadcasting is not applicable in the DuckDB context. diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index 29210e29..7aa9eb11 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -15,19 +15,16 @@ # limitations under the License. # -from ..exception import ContributionsAcceptedError -from typing import Callable, TYPE_CHECKING, overload, Dict, Union, List +from typing import Callable, Union, overload +from ..exception import ContributionsAcceptedError +from ._typing import ColumnOrName from .column import Column -from .session import SparkSession from .dataframe import DataFrame from .functions import _to_column_expr -from ._typing import ColumnOrName +from .session import SparkSession from .types import NumericType -if TYPE_CHECKING: - from ._typing import LiteralType - __all__ = ["GroupedData", "Grouping"] @@ -35,7 +32,7 @@ def _api_internal(self: "GroupedData", name: str, *cols: str) -> DataFrame: expressions = ",".join(list(cols)) group_by = str(self._grouping) if self._grouping else "" projections = self._grouping.get_columns() - jdf = getattr(self._df.relation, "apply")( + jdf = self._df.relation.apply( function_name=name, # aggregate function function_aggr=expressions, # inputs to aggregate group_expr=group_by, # groups @@ -76,8 +73,7 @@ def __str__(self) -> str: class GroupedData: - """ - A set of methods for aggregations on a :class:`DataFrame`, + """A set of methods for aggregations on a :class:`DataFrame`, created by :func:`DataFrame.groupBy`. """ @@ -93,7 +89,7 @@ def __repr__(self) -> str: def count(self) -> DataFrame: """Counts the number of records for each group. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"] @@ -142,7 +138,7 @@ def avg(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -188,7 +184,7 @@ def avg(self, *cols: str) -> DataFrame: def max(self, *cols: str) -> DataFrame: """Computes the max value for each numeric columns for each group. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -233,7 +229,7 @@ def min(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -278,7 +274,7 @@ def sum(self, *cols: str) -> DataFrame: cols : str column names. Non-numeric columns are ignored. - Examples + Examples: -------- >>> df = spark.createDataFrame( ... [(2, "Alice", 80), (3, "Alice", 100), (5, "Bob", 120), (10, "Bob", 140)], @@ -352,12 +348,12 @@ def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: a dict mapping from column name (string) to aggregate functions (string), or a list of :class:`Column`. - Notes + Notes: ----- Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed in a single call to this function. - Examples + Examples: -------- >>> from pyspark.sql import functions as F >>> from pyspark.sql.functions import pandas_udf, PandasUDFType diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 18095ab6..607e9d36 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -1,11 +1,9 @@ -from typing import TYPE_CHECKING, List, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast +from ..errors import PySparkNotImplementedError, PySparkTypeError from ..exception import ContributionsAcceptedError from .types import StructType - -from ..errors import PySparkNotImplementedError, PySparkTypeError - PrimitiveType = Union[bool, float, int, str] OptionalPrimitiveType = Optional[PrimitiveType] @@ -123,7 +121,7 @@ def load( if schema: if not isinstance(schema, StructType): raise ContributionsAcceptedError - schema = cast(StructType, schema) + schema = cast("StructType", schema) types, names = schema.extract_types_and_names() df = df._cast_types(types) df = df.toDF(names) @@ -225,7 +223,7 @@ def csv( dtype = None names = None if schema: - schema = cast(StructType, schema) + schema = cast("StructType", schema) dtype, names = schema.extract_types_and_names() rel = self.session.conn.read_csv( @@ -289,8 +287,7 @@ def json( modifiedAfter: Optional[Union[bool, str]] = None, allowNonNumericNumbers: Optional[Union[bool, str]] = None, ) -> "DataFrame": - """ - Loads JSON files and returns the results as a :class:`DataFrame`. + """Loads JSON files and returns the results as a :class:`DataFrame`. `JSON Lines `_ (newline-delimited JSON) is supported by default. For JSON (one record per file), set the ``multiLine`` parameter to ``true``. @@ -321,7 +318,7 @@ def json( .. # noqa - Examples + Examples: -------- Write a DataFrame into a JSON file and read it back. @@ -340,7 +337,6 @@ def json( |100|Hyukjin Kwon| +---+------------+ """ - if schema is not None: raise ContributionsAcceptedError("The 'schema' option is not supported") if primitivesAsString is not None: @@ -410,4 +406,4 @@ def json( ) -__all__ = ["DataFrameWriter", "DataFrameReader"] +__all__ = ["DataFrameReader", "DataFrameWriter"] diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index c83c7e82..4b919446 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -1,24 +1,24 @@ -from typing import Optional, List, Any, Union, Iterable, TYPE_CHECKING import uuid +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: - from .catalog import Catalog from pandas.core.frame import DataFrame as PandasDataFrame -from ..exception import ContributionsAcceptedError -from .types import StructType, AtomicType, DataType + from .catalog import Catalog + + from ..conf import SparkConf -from .dataframe import DataFrame +from ..context import SparkContext +from ..errors import PySparkTypeError +from ..errors.error_classes import * +from ..exception import ContributionsAcceptedError from .conf import RuntimeConfig +from .dataframe import DataFrame from .readwriter import DataFrameReader -from ..context import SparkContext -from .udf import UDFRegistration from .streaming import DataStreamReader -import duckdb - -from ..errors import PySparkTypeError, PySparkValueError - -from ..errors.error_classes import * +from .types import StructType +from .udf import UDFRegistration # In spark: # SparkSession holds a SparkContext diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 4dcba01f..ba54db60 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, Optional, Union + from .types import StructType if TYPE_CHECKING: @@ -29,7 +30,6 @@ def load( schema: Union[StructType, str, None] = None, **options: OptionalPrimitiveType, ) -> "DataFrame": - from duckdb.experimental.spark.sql.dataframe import DataFrame raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index f8c8ce4f..446eac97 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -1,38 +1,40 @@ +from typing import cast + from duckdb.typing import DuckDBPyType -from typing import List, Tuple, cast + from .types import ( - DataType, - StringType, + ArrayType, BinaryType, BitstringType, - UUIDType, BooleanType, + ByteType, + DataType, DateType, - TimestampType, - TimestampNTZType, - TimeType, - TimeNTZType, - TimestampNanosecondNTZType, - TimestampMilisecondNTZType, - TimestampSecondNTZType, + DayTimeIntervalType, DecimalType, DoubleType, FloatType, - ByteType, - UnsignedByteType, - ShortType, - UnsignedShortType, + HugeIntegerType, IntegerType, - UnsignedIntegerType, LongType, - UnsignedLongType, - HugeIntegerType, - UnsignedHugeIntegerType, - DayTimeIntervalType, - ArrayType, MapType, + ShortType, + StringType, StructField, StructType, + TimeNTZType, + TimestampMilisecondNTZType, + TimestampNanosecondNTZType, + TimestampNTZType, + TimestampSecondNTZType, + TimestampType, + TimeType, + UnsignedByteType, + UnsignedHugeIntegerType, + UnsignedIntegerType, + UnsignedLongType, + UnsignedShortType, + UUIDType, ) _sqltype_to_spark_class = { @@ -93,8 +95,8 @@ def convert_type(dtype: DuckDBPyType) -> DataType: return convert_nested_type(dtype) if id == "decimal": children: list[tuple[str, DuckDBPyType]] = dtype.children - precision = cast(int, children[0][1]) - scale = cast(int, children[1][1]) + precision = cast("int", children[0][1]) + scale = cast("int", children[1][1]) return DecimalType(precision, scale) spark_type = _sqltype_to_spark_class[id] return spark_type() diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 81293caf..d8a04b8e 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -1,25 +1,21 @@ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'spark' folder. +import calendar +import datetime +import math +import re +import time +from builtins import tuple +from collections.abc import Iterator from typing import ( - cast, - overload, - Dict, - Optional, - List, - Tuple, Any, - Union, - Type, - TypeVar, ClassVar, - Iterator, + Optional, + TypeVar, + Union, + cast, + overload, ) -from builtins import tuple -import datetime -import calendar -import time -import math -import re import duckdb from duckdb.typing import DuckDBPyType @@ -30,40 +26,40 @@ U = TypeVar("U") __all__ = [ - "DataType", - "NullType", - "StringType", + "ArrayType", "BinaryType", - "UUIDType", "BitstringType", "BooleanType", + "ByteType", + "DataType", "DateType", - "TimestampType", - "TimestampNTZType", - "TimestampNanosecondNTZType", - "TimestampMilisecondNTZType", - "TimestampSecondNTZType", - "TimeType", - "TimeNTZType", + "DayTimeIntervalType", "DecimalType", "DoubleType", "FloatType", - "ByteType", - "UnsignedByteType", - "ShortType", - "UnsignedShortType", + "HugeIntegerType", "IntegerType", - "UnsignedIntegerType", "LongType", - "UnsignedLongType", - "HugeIntegerType", - "UnsignedHugeIntegerType", - "DayTimeIntervalType", - "Row", - "ArrayType", "MapType", + "NullType", + "Row", + "ShortType", + "StringType", "StructField", "StructType", + "TimeNTZType", + "TimeType", + "TimestampMilisecondNTZType", + "TimestampNTZType", + "TimestampNanosecondNTZType", + "TimestampSecondNTZType", + "TimestampType", + "UUIDType", + "UnsignedByteType", + "UnsignedHugeIntegerType", + "UnsignedIntegerType", + "UnsignedLongType", + "UnsignedShortType", ] @@ -79,10 +75,10 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @classmethod @@ -99,22 +95,19 @@ def json(self) -> str: raise ContributionsAcceptedError def needConversion(self) -> bool: - """ - Does this type needs conversion between Python object and internal SQL object. + """Does this type needs conversion between Python object and internal SQL object. This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. """ return False def toInternal(self, obj: Any) -> Any: - """ - Converts a Python object into an internal SQL object. + """Converts a Python object into an internal SQL object. """ return obj def fromInternal(self, obj: Any) -> Any: - """ - Converts an internal SQL object into a native Python object. + """Converts an internal SQL object into a native Python object. """ return obj @@ -148,7 +141,8 @@ def typeName(cls) -> str: class AtomicType(DataType): """An internal type used to represent everything that is not - null, UDTs, arrays, structs, and maps.""" + null, UDTs, arrays, structs, and maps. + """ class NumericType(AtomicType): @@ -538,8 +532,8 @@ def __init__(self, startField: Optional[int] = None, endField: Optional[int] = N fields = DayTimeIntervalType._fields if startField not in fields.keys() or endField not in fields.keys(): raise RuntimeError("interval %s to %s is invalid" % (startField, endField)) - self.startField = cast(int, startField) - self.endField = cast(int, endField) + self.startField = cast("int", startField) + self.endField = cast("int", endField) def _str_repr(self) -> str: fields = DayTimeIntervalType._fields @@ -577,7 +571,7 @@ class ArrayType(DataType): containsNull : bool, optional whether the array can contain null (None) values. - Examples + Examples: -------- >>> ArrayType(StringType()) == ArrayType(StringType(), True) True @@ -626,11 +620,11 @@ class MapType(DataType): valueContainsNull : bool, optional indicates whether values can contain null (None) values. - Notes + Notes: ----- Keys in a map data type are not allowed to be null (None). - Examples + Examples: -------- >>> (MapType(StringType(), IntegerType()) == MapType(StringType(), IntegerType(), True)) True @@ -693,7 +687,7 @@ class StructField(DataType): metadata : dict, optional a dict from string to simple type that can be toInternald to JSON automatically - Examples + Examples: -------- >>> (StructField("f1", StringType(), True) == StructField("f1", StringType(), True)) True @@ -750,7 +744,7 @@ class StructType(DataType): Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by its name or position. - Examples + Examples: -------- >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] @@ -805,8 +799,7 @@ def add( nullable: bool = True, metadata: Optional[dict[str, Any]] = None, ) -> "StructType": - """ - Construct a :class:`StructType` by adding new elements to it, to define the schema. + """Construct a :class:`StructType` by adding new elements to it, to define the schema. The method accepts either: a) A single parameter which is a :class:`StructField` object. @@ -825,11 +818,11 @@ def add( metadata : dict, optional Any additional metadata (default None) - Returns + Returns: ------- :class:`StructType` - Examples + Examples: -------- >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) >>> struct2 = StructType([StructField("f1", StringType(), True), \\ @@ -875,7 +868,7 @@ def __getitem__(self, key: Union[str, int]) -> StructField: for field in self: if field.name == key: return field - raise KeyError("No StructField named {0}".format(key)) + raise KeyError(f"No StructField named {key}") elif isinstance(key, int): try: return self.fields[key] @@ -904,10 +897,9 @@ def extract_types_and_names(self) -> tuple[list[str], list[str]]: return (types, names) def fieldNames(self) -> list[str]: - """ - Returns all field names in a list. + """Returns all field names in a list. - Examples + Examples: -------- >>> struct = StructType([StructField("f1", StringType(), True)]) >>> struct.fieldNames() @@ -987,22 +979,19 @@ def typeName(cls) -> str: @classmethod def sqlType(cls) -> DataType: - """ - Underlying SQL storage type for this UDT. + """Underlying SQL storage type for this UDT. """ raise NotImplementedError("UDT must implement sqlType().") @classmethod def module(cls) -> str: - """ - The Python module of the UDT. + """The Python module of the UDT. """ raise NotImplementedError("UDT must implement module().") @classmethod def scalaUDT(cls) -> str: - """ - The class name of the paired Scala UDT (could be '', if there + """The class name of the paired Scala UDT (could be '', if there is no corresponding one). """ return "" @@ -1012,8 +1001,7 @@ def needConversion(self) -> bool: @classmethod def _cachedSqlType(cls) -> DataType: - """ - Cache the sqlType() into class, because it's heavily used in `toInternal`. + """Cache the sqlType() into class, because it's heavily used in `toInternal`. """ if not hasattr(cls, "_cached_sql_type"): cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined] @@ -1029,21 +1017,19 @@ def fromInternal(self, obj: Any) -> Any: return self.deserialize(v) def serialize(self, obj: Any) -> Any: - """ - Converts a user-type object into a SQL datum. + """Converts a user-type object into a SQL datum. """ raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum: Any) -> Any: - """ - Converts a SQL datum into a user-type object. + """Converts a SQL datum into a user-type object. """ raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self) -> str: return "udt" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return type(self) == type(other) @@ -1086,8 +1072,7 @@ def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], class Row(tuple): - """ - A row in :class:`DataFrame`. + """A row in :class:`DataFrame`. The fields in it can be accessed: * like attributes (``row.key``) @@ -1104,7 +1089,7 @@ class Row(tuple): field names sorted alphabetically and will be ordered in the position as entered. - Examples + Examples: -------- >>> row = Row(name="Alice", age=11) >>> row @@ -1159,15 +1144,14 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": return tuple.__new__(cls, args) def asDict(self, recursive: bool = False) -> dict[str, Any]: - """ - Return as a dict + """Return as a dict Parameters ---------- recursive : bool, optional turns the nested Rows to dict (default: False). - Notes + Notes: ----- If a row contains duplicate field names, e.g., the rows of a join between two :class:`DataFrame` that both have the fields of same names, @@ -1175,7 +1159,7 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]: will also return one of the duplicate fields, however returned value might be different to ``asDict``. - Examples + Examples: -------- >>> Row(name="Alice", age=11).asDict() == {"name": "Alice", "age": 11} True @@ -1212,7 +1196,7 @@ def __contains__(self, item: Any) -> bool: # let object acts like class def __call__(self, *args: Any) -> "Row": - """create new Row object""" + """Create new Row object""" if len(args) > len(self): raise ValueError( "Can not create Row with fields %s, expected %d values but got %s" % (self, len(self), args) diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index ea4ba540..885c797f 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -1,8 +1,10 @@ -from fsspec import filesystem, AbstractFileSystem -from fsspec.implementations.memory import MemoryFileSystem, MemoryFile -from .bytes_io_wrapper import BytesIOWrapper from io import TextIOBase +from fsspec import AbstractFileSystem +from fsspec.implementations.memory import MemoryFile, MemoryFileSystem + +from .bytes_io_wrapper import BytesIOWrapper + def is_file_like(obj): # We only care that we can read from the file diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index 90c2a561..b1ddab19 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,3 +1,3 @@ -from _duckdb.functional import FunctionNullHandling, PythonUDFType, SPECIAL, DEFAULT, NATIVE, ARROW +from _duckdb.functional import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType -__all__ = ["FunctionNullHandling", "PythonUDFType", "SPECIAL", "DEFAULT", "NATIVE", "ARROW"] +__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index ef87f03a..b1fc244c 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -1,17 +1,18 @@ -import duckdb -import polars as pl -from typing import Iterator, Optional +import datetime +import json +from collections.abc import Iterator +from decimal import Decimal +from typing import Optional +import polars as pl from polars.io.plugins import register_io_source + +import duckdb from duckdb import SQLExpression -import json -from decimal import Decimal -import datetime def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: - """ - Convert a Polars predicate expression to a DuckDB-compatible SQL expression. + """Convert a Polars predicate expression to a DuckDB-compatible SQL expression. Parameters: predicate (pl.Expr): A Polars expression (e.g., col("foo") > 5) @@ -37,8 +38,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]: def _pl_operation_to_sql(op: str) -> str: - """ - Map Polars binary operation strings to SQL equivalents. + """Map Polars binary operation strings to SQL equivalents. Example: >>> _pl_operation_to_sql("Eq") @@ -60,8 +60,7 @@ def _pl_operation_to_sql(op: str) -> str: def _escape_sql_identifier(identifier: str) -> str: - """ - Escape SQL identifiers by doubling any double quotes and wrapping in double quotes. + """Escape SQL identifiers by doubling any double quotes and wrapping in double quotes. Example: >>> _escape_sql_identifier('column"name') @@ -72,8 +71,7 @@ def _escape_sql_identifier(identifier: str) -> str: def _pl_tree_to_sql(tree: dict) -> str: - """ - Recursively convert a Polars expression tree (as JSON) to a SQL string. + """Recursively convert a Polars expression tree (as JSON) to a SQL string. Parameters: tree (dict): JSON-deserialized expression tree from Polars @@ -158,7 +156,7 @@ def _pl_tree_to_sql(tree: dict) -> str: if dtype.startswith("{'Datetime'") or dtype == "Datetime": micros = value["Datetime"][0] dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC) - return f"'{str(dt_timestamp)}'::TIMESTAMP" + return f"'{dt_timestamp!s}'::TIMESTAMP" # Match simple numeric/boolean types if dtype in ( @@ -202,14 +200,13 @@ def _pl_tree_to_sql(tree: dict) -> str: string_val = value.get("StringOwned", value.get("String", None)) return f"'{string_val}'" - raise NotImplementedError(f"Unsupported scalar type {str(dtype)}, with value {value}") + raise NotImplementedError(f"Unsupported scalar type {dtype!s}, with value {value}") raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: - """ - A polars IO plugin for DuckDB. + """A polars IO plugin for DuckDB. """ def source_generator( diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index aa67b42f..88d96350 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -1,10 +1,10 @@ +import argparse import json import os -import sys import re +import sys import webbrowser from functools import reduce -import argparse qgraph_css = """ .styled-table { diff --git a/duckdb/typing/__init__.py b/duckdb/typing/__init__.py index 33cf4cd7..53207418 100644 --- a/duckdb/typing/__init__.py +++ b/duckdb/typing/__init__.py @@ -1,5 +1,4 @@ from _duckdb.typing import ( - DuckDBPyType, BIGINT, BIT, BLOB, @@ -8,29 +7,29 @@ DOUBLE, FLOAT, HUGEINT, - UHUGEINT, INTEGER, INTERVAL, SMALLINT, SQLNULL, TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, TIMESTAMP_TZ, - TIME_TZ, TINYINT, UBIGINT, + UHUGEINT, UINTEGER, USMALLINT, UTINYINT, UUID, VARCHAR, + DuckDBPyType, ) __all__ = [ - "DuckDBPyType", "BIGINT", "BIT", "BLOB", @@ -39,7 +38,6 @@ "DOUBLE", "FLOAT", "HUGEINT", - "UHUGEINT", "INTEGER", "INTERVAL", "SMALLINT", @@ -53,9 +51,11 @@ "TIME_TZ", "TINYINT", "UBIGINT", + "UHUGEINT", "UINTEGER", "USMALLINT", "UTINYINT", "UUID", "VARCHAR", + "DuckDBPyType", ] diff --git a/duckdb/udf.py b/duckdb/udf.py index bbf05c7d..0eb59ba9 100644 --- a/duckdb/udf.py +++ b/duckdb/udf.py @@ -1,9 +1,8 @@ def vectorized(func): + """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output """ - Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output - """ - from inspect import signature import types + from inspect import signature new_func = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) # Construct the annotations: diff --git a/duckdb/value/constant/__init__.py b/duckdb/value/constant/__init__.py index fb7d7284..9bbf2493 100644 --- a/duckdb/value/constant/__init__.py +++ b/duckdb/value/constant/__init__.py @@ -1,5 +1,5 @@ from typing import Any, Dict -from duckdb.typing import DuckDBPyType + from duckdb.typing import ( BIGINT, BIT, @@ -9,25 +9,26 @@ DOUBLE, FLOAT, HUGEINT, - UHUGEINT, INTEGER, INTERVAL, SMALLINT, SQLNULL, TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, TIMESTAMP_TZ, - TIME_TZ, TINYINT, UBIGINT, + UHUGEINT, UINTEGER, USMALLINT, UTINYINT, UUID, VARCHAR, + DuckDBPyType, ) @@ -236,33 +237,33 @@ def __init__(self, object: Any, members: dict[str, DuckDBPyType]) -> None: # TODO: add EnumValue once `duckdb.enum_type` is added __all__ = [ - "Value", - "NullValue", - "BooleanValue", - "UnsignedBinaryValue", - "UnsignedShortValue", - "UnsignedIntegerValue", - "UnsignedLongValue", "BinaryValue", - "ShortValue", - "IntegerValue", - "LongValue", - "HugeIntegerValue", - "UnsignedHugeIntegerValue", - "FloatValue", - "DoubleValue", - "DecimalValue", - "StringValue", - "UUIDValue", "BitValue", "BlobValue", + "BooleanValue", "DateValue", + "DecimalValue", + "DoubleValue", + "FloatValue", + "HugeIntegerValue", + "IntegerValue", "IntervalValue", - "TimestampValue", - "TimestampSecondValue", + "LongValue", + "NullValue", + "ShortValue", + "StringValue", + "TimeTimeZoneValue", + "TimeValue", "TimestampMilisecondValue", "TimestampNanosecondValue", + "TimestampSecondValue", "TimestampTimeZoneValue", - "TimeValue", - "TimeTimeZoneValue", + "TimestampValue", + "UUIDValue", + "UnsignedBinaryValue", + "UnsignedHugeIntegerValue", + "UnsignedIntegerValue", + "UnsignedLongValue", + "UnsignedShortValue", + "Value", ] diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index 3709dac0..57008fa3 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -7,10 +7,9 @@ """ import pathlib +import re import subprocess from typing import Optional -import re - VERSION_RE = re.compile( r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$" @@ -139,8 +138,7 @@ def create_git_tag(version: str, message: Optional[str] = None, repo_path: Optio def strip_post_from_version(version: str) -> str: - """ - Removing post-release suffixes from the given version. + """Removing post-release suffixes from the given version. DuckDB doesn't allow post-release versions, so .post* suffixes are stripped. """ diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index b9a005db..dc94eeaa 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -13,25 +13,29 @@ Also see https://peps.python.org/pep-0517/#in-tree-build-backends. """ -import sys import os import subprocess +import sys from pathlib import Path -from typing import Optional, Dict, List, Union +from typing import Optional, Union + from scikit_build_core.build import ( - build_wheel as skbuild_build_wheel, build_editable, - build_sdist as skbuild_build_sdist, - get_requires_for_build_wheel, - get_requires_for_build_sdist, get_requires_for_build_editable, - prepare_metadata_for_build_wheel, + get_requires_for_build_sdist, + get_requires_for_build_wheel, prepare_metadata_for_build_editable, + prepare_metadata_for_build_wheel, +) +from scikit_build_core.build import ( + build_sdist as skbuild_build_sdist, +) +from scikit_build_core.build import ( + build_wheel as skbuild_build_wheel, ) -from duckdb_packaging._versioning import create_git_tag, pep440_to_git_tag, get_git_describe, strip_post_from_version -from duckdb_packaging.setuptools_scm_version import forced_version_from_env, MAIN_BRANCH_VERSIONING - +from duckdb_packaging._versioning import get_git_describe, pep440_to_git_tag, strip_post_from_version +from duckdb_packaging.setuptools_scm_version import MAIN_BRANCH_VERSIONING, forced_version_from_env _DUCKDB_VERSION_FILENAME = "duckdb_version.txt" _LOGGING_FORMAT = "[duckdb_pytooling.build_backend] {}" @@ -251,12 +255,12 @@ def build_wheel( __all__ = [ - "build_wheel", - "build_sdist", "build_editable", - "get_requires_for_build_wheel", - "get_requires_for_build_sdist", + "build_sdist", + "build_wheel", "get_requires_for_build_editable", - "prepare_metadata_for_build_wheel", + "get_requires_for_build_sdist", + "get_requires_for_build_wheel", "prepare_metadata_for_build_editable", + "prepare_metadata_for_build_wheel", ] diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 80073c0e..8e91b34f 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -1,5 +1,4 @@ -""" -!!HERE BE DRAGONS!! Use this script with care! +"""!!HERE BE DRAGONS!! Use this script with care! PyPI package cleanup tool. This script will: * Never remove a stable version (including a post release version) @@ -17,8 +16,9 @@ import sys import time from collections import defaultdict +from collections.abc import Generator from html.parser import HTMLParser -from typing import Optional, Set, Generator +from typing import Optional from urllib.parse import urlparse import pyotp @@ -77,19 +77,16 @@ def create_argument_parser() -> argparse.ArgumentParser: class PyPICleanupError(Exception): """Base exception for PyPI cleanup operations.""" - pass class AuthenticationError(PyPICleanupError): """Raised when authentication fails.""" - pass class ValidationError(PyPICleanupError): """Raised when input validation fails.""" - pass def setup_logging(verbose: bool = False) -> None: @@ -236,7 +233,7 @@ def run(self) -> int: int: Exit code (0 for success, non-zero for failure) """ if self._do_delete: - logging.warning(f"NOT A DRILL: WILL DELETE PACKAGES") + logging.warning("NOT A DRILL: WILL DELETE PACKAGES") else: logging.info("Running in DRY RUN mode, nothing will be deleted") @@ -246,7 +243,7 @@ def run(self) -> int: with session_with_retries() as http_session: return self._execute_cleanup(http_session) except PyPICleanupError as e: - logging.error(f"Cleanup failed: {e}") + logging.exception(f"Cleanup failed: {e}") return 1 except Exception as e: logging.error(f"Unexpected error: {e}", exc_info=True) @@ -254,7 +251,6 @@ def run(self) -> int: def _execute_cleanup(self, http_session: Session) -> int: """Execute the main cleanup logic.""" - # Get released versions versions = self._fetch_released_versions(http_session) if not versions: @@ -418,7 +414,6 @@ def _get_csrf_token(self, http_session: Session, form_action: str) -> str: def _perform_login(self, http_session: Session) -> requests.Response: """Perform the initial login with username/password.""" - # Get login form and CSRF token csrf_token = self._get_csrf_token(http_session, "/account/login/") @@ -487,7 +482,7 @@ def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) logging.info(f"Successfully deleted {self._package} version {version}") except Exception as e: # Continue with other versions rather than failing completely - logging.error(f"Failed to delete version {version}: {e}") + logging.exception(f"Failed to delete version {version}: {e}") failed_deletions.append(version) if failed_deletions: @@ -547,13 +542,13 @@ def main() -> int: return cleanup.run() except ValidationError as e: - logging.error(f"Configuration error: {e}") + logging.exception(f"Configuration error: {e}") return 2 except KeyboardInterrupt: logging.info("Operation cancelled by user") return 130 except Exception as e: - logging.error(f"Unexpected error: {e}", exc_info=args.verbose) + logging.exception(f"Unexpected error: {e}", exc_info=args.verbose) return 1 diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 217b2ffe..2ff79f80 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -1,5 +1,4 @@ -""" -setuptools_scm integration for DuckDB Python versioning. +"""setuptools_scm integration for DuckDB Python versioning. This module provides the setuptools_scm version scheme and handles environment variable overrides to match the exact behavior of the original DuckDB Python package. @@ -10,7 +9,7 @@ from typing import Any # Import from our own versioning module to avoid duplication -from ._versioning import parse_version, format_version +from ._versioning import format_version, parse_version # MAIN_BRANCH_VERSIONING should be 'True' on main branch only MAIN_BRANCH_VERSIONING = False @@ -26,8 +25,7 @@ def _main_branch_versioning(): def version_scheme(version: Any) -> str: - """ - setuptools_scm version scheme that matches DuckDB's original behavior. + """setuptools_scm version scheme that matches DuckDB's original behavior. Args: version: setuptools_scm version object @@ -55,7 +53,7 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: # Validate the base version (this should never include anything else than X.Y.Z or X.Y.Z.[rc|post]N) try: major, minor, patch, post, rc = parse_version(base_version) - except ValueError as e: + except ValueError: raise ValueError(f"Incorrect version format: {base_version} (expected X.Y.Z or X.Y.Z.postN)") # If we're exactly on a tag (distance = 0, dirty=False) @@ -76,8 +74,7 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: def forced_version_from_env(): - """ - Handle getting versions from environment variables. + """Handle getting versions from environment variables. Only supports a single way of manually overriding the version through OVERRIDE_GIT_DESCRIBE. If SETUPTOOLS_SCM_PRETEND_VERSION* is set, it gets unset. diff --git a/scripts/generate_connection_methods.py b/scripts/generate_connection_methods.py index a48b6142..51f667f6 100644 --- a/scripts/generate_connection_methods.py +++ b/scripts/generate_connection_methods.py @@ -1,5 +1,5 @@ -import os import json +import os os.chdir(os.path.dirname(__file__)) @@ -29,7 +29,7 @@ def is_py_args(method): def generate(): # Read the PYCONNECTION_SOURCE file - with open(PYCONNECTION_SOURCE, "r") as source_file: + with open(PYCONNECTION_SOURCE) as source_file: source_code = source_file.readlines() start_index = -1 @@ -52,7 +52,7 @@ def generate(): # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) DEFAULT_ARGUMENT_MAP = { diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index e3831173..9b1be9aa 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -1,5 +1,5 @@ -import os import json +import os os.chdir(os.path.dirname(__file__)) @@ -12,7 +12,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, "r") as source_file: + with open(DUCKDB_STUBS_FILE) as source_file: source_code = source_file.readlines() start_index = -1 @@ -35,7 +35,7 @@ def generate(): # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) body = [] diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index 45ac45cc..d2ef0bba 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -1,10 +1,8 @@ -import os -import sys import json +import os # Requires `python3 -m pip install cxxheaderparser pcpp` -from get_cpp_methods import get_methods, FunctionParam, ConnectionMethod -from typing import List, Tuple +from get_cpp_methods import ConnectionMethod, get_methods os.chdir(os.path.dirname(__file__)) @@ -40,7 +38,7 @@ INIT_PY_END = "# END OF CONNECTION WRAPPER" # Read the JSON file -with open(WRAPPER_JSON_PATH, "r") as json_file: +with open(WRAPPER_JSON_PATH) as json_file: wrapper_methods = json.load(json_file) # On DuckDBPyConnection these are read_only_properties, they're basically functions without requiring () to invoke @@ -94,19 +92,19 @@ def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[s def generate(): # Read the DUCKDB_PYTHON_SOURCE file - with open(DUCKDB_PYTHON_SOURCE, "r") as source_file: + with open(DUCKDB_PYTHON_SOURCE) as source_file: source_code = source_file.readlines() start_section, end_section = remove_section(source_code, START_MARKER, END_MARKER) # Read the DUCKDB_INIT_FILE file - with open(DUCKDB_INIT_FILE, "r") as source_file: + with open(DUCKDB_INIT_FILE) as source_file: source_code = source_file.readlines() py_start, py_end = remove_section(source_code, INIT_PY_START, INIT_PY_END) # ---- Generate the definition code from the json ---- # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) # Collect the definitions from the pyconnection.hpp header diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 02e36c4e..3b3b8c93 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -1,5 +1,5 @@ -import os import json +import os os.chdir(os.path.dirname(__file__)) @@ -13,7 +13,7 @@ def generate(): # Read the DUCKDB_STUBS_FILE file - with open(DUCKDB_STUBS_FILE, "r") as source_file: + with open(DUCKDB_STUBS_FILE) as source_file: source_code = source_file.readlines() start_index = -1 @@ -38,10 +38,10 @@ def generate(): methods = [] # Read the JSON file - with open(JSON_PATH, "r") as json_file: + with open(JSON_PATH) as json_file: connection_methods = json.load(json_file) - with open(WRAPPER_JSON_PATH, "r") as json_file: + with open(WRAPPER_JSON_PATH) as json_file: wrapper_methods = json.load(json_file) methods.extend(connection_methods) diff --git a/scripts/generate_import_cache_cpp.py b/scripts/generate_import_cache_cpp.py index 8a4b0c36..036115f4 100644 --- a/scripts/generate_import_cache_cpp.py +++ b/scripts/generate_import_cache_cpp.py @@ -1,14 +1,13 @@ import os script_dir = os.path.dirname(__file__) -from typing import List, Dict import json # Load existing JSON data from a file if it exists json_data = {} json_cache_path = os.path.join(script_dir, "cache_data.json") try: - with open(json_cache_path, "r") as file: + with open(json_cache_path) as file: json_data = json.load(file) except FileNotFoundError: print("Please first use 'generate_import_cache_json.py' first to generate json") diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 099db841..34cd84b6 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -1,8 +1,8 @@ import os script_dir = os.path.dirname(__file__) -from typing import List, Dict, Union import json +from typing import Union lines: list[str] = [file for file in open(f"{script_dir}/imports.py").read().split("\n") if file != ""] @@ -13,7 +13,7 @@ def __init__(self, full_path: str) -> None: self.type = "attribute" self.name = parts[-1] self.full_path = full_path - self.children: dict[str, "ImportCacheAttribute"] = {} + self.children: dict[str, ImportCacheAttribute] = {} def has_item(self, item_name: str) -> bool: return item_name in self.children @@ -46,7 +46,7 @@ def __init__(self, full_path) -> None: self.type = "module" self.name = parts[-1] self.full_path = full_path - self.items: dict[str, Union[ImportCacheAttribute, "ImportCacheModule"]] = {} + self.items: dict[str, Union[ImportCacheAttribute, ImportCacheModule]] = {} def add_item(self, item: Union[ImportCacheAttribute, "ImportCacheModule"]): assert self.full_path != item.full_path @@ -156,7 +156,7 @@ def to_json(self): existing_json_data = {} json_cache_path = os.path.join(script_dir, "cache_data.json") try: - with open(json_cache_path, "r") as file: + with open(json_cache_path) as file: existing_json_data = json.load(file) except FileNotFoundError: pass diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index 97b28af3..25aa7c7d 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -1,10 +1,10 @@ # Requires `python3 -m pip install cxxheaderparser pcpp` import os +from typing import Callable import cxxheaderparser.parser -import cxxheaderparser.visitor import cxxheaderparser.preprocessor -from typing import List, Dict, Callable +import cxxheaderparser.visitor scripts_folder = os.path.dirname(os.path.abspath(__file__)) diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index b8d913ea..8d772111 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -1,11 +1,12 @@ import itertools import pathlib -import pytest import random import re import typing import warnings -import glob + +import pytest + from .skipped_tests import SKIPPED_TESTS SQLLOGIC_TEST_CASE_NAME = "test_sqllogic" @@ -126,11 +127,9 @@ def create_parameters_from_paths(paths, root_dir: pathlib.Path, config: pytest.C def scan_for_test_scripts(root_dir: pathlib.Path, config: pytest.Config) -> typing.Iterator[typing.Any]: - """ - Scans for .test files in the given directory and its subdirectories. + """Scans for .test files in the given directory and its subdirectories. Returns an iterator of pytest parameters (argument, id and marks). """ - # TODO: Add tests from extensions test_script_extensions = [".test", ".test_slow", ".test_coverage"] it = itertools.chain.from_iterable(root_dir.rglob(f"*{ext}") for ext in test_script_extensions) @@ -166,13 +165,11 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, int]: - """ - If start_offset and end_offset are specified, then these are used. + """If start_offset and end_offset are specified, then these are used. start_offset defaults to 0. end_offset defaults to and is capped to the last test index. start_offset_percentage and end_offset_percentage are used to calculate the start and end offsets based on the total number of tests. This is done in a way that a test run to 25% and another test run starting at 25% do not overlap by excluding the 25th percent test. """ - start_offset = config.getoption("start_offset") end_offset = config.getoption("end_offset") start_offset_percentage = config.getoption("start_offset_percentage") @@ -271,8 +268,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config def pytest_runtest_setup(item: pytest.Item): - """ - Show the test index after the test name + """Show the test index after the test name """ def get_from_tuple_list(tuples, key): diff --git a/sqllogic/test_sqllogic.py b/sqllogic/test_sqllogic.py index 6f55e931..35736015 100644 --- a/sqllogic/test_sqllogic.py +++ b/sqllogic/test_sqllogic.py @@ -1,24 +1,25 @@ import gc import os import pathlib -import pytest import signal import sys -from typing import Any, Generator, Optional +from collections.abc import Generator +from typing import Any, Optional + +import pytest sys.path.append(str(pathlib.Path(__file__).parent.parent / "external" / "duckdb" / "scripts")) from sqllogictest import ( - SQLParserException, SQLLogicParser, SQLLogicTest, + SQLParserException, ) - from sqllogictest.result import ( - TestException, - SQLLogicRunner, - SQLLogicDatabase, - SQLLogicContext, ExecuteResult, + SQLLogicContext, + SQLLogicDatabase, + SQLLogicRunner, + TestException, ) diff --git a/tests/conftest.py b/tests/conftest.py index d69cdfce..83c10f3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,14 @@ +import glob import os +import shutil +import warnings +from importlib import import_module +from os.path import abspath, dirname, join, normpath from typing import Any import pytest -import shutil -from os.path import abspath, join, dirname, normpath -import glob + import duckdb -import warnings -from importlib import import_module try: # need to ignore warnings that might be thrown deep inside pandas's import tree (from dateutil in this case) @@ -71,11 +72,12 @@ def duckdb_empty_cursor(request): def getTimeSeriesData(nper=None, freq: "Frequency" = "B"): - from pandas import DatetimeIndex, bdate_range, Series + import string from datetime import datetime - from pandas._typing import Frequency + import numpy as np - import string + from pandas import DatetimeIndex, Series, bdate_range + from pandas._typing import Frequency _N = 30 _K = 4 @@ -226,7 +228,6 @@ def _require(extension_name, db_name=""): # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture @pytest.fixture(scope="function") def spark(): - from spark_namespace import USE_ACTUAL_SPARK if not hasattr(spark, "session"): # Cache the import diff --git a/tests/coverage/test_pandas_categorical_coverage.py b/tests/coverage/test_pandas_categorical_coverage.py index 15eee10a..b0130577 100644 --- a/tests/coverage/test_pandas_categorical_coverage.py +++ b/tests/coverage/test_pandas_categorical_coverage.py @@ -1,7 +1,7 @@ -import duckdb -import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import NumpyPandas + +import duckdb def check_result_list(res): @@ -69,7 +69,7 @@ def check_create_table(category, pandas): # TODO: extend tests with ArrowPandas -class TestCategory(object): +class TestCategory: @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_category_string_uint16(self, duckdb_cursor, pandas): category = [] diff --git a/tests/extensions/json/test_read_json.py b/tests/extensions/json/test_read_json.py index f0fd809f..9ac5be88 100644 --- a/tests/extensions/json/test_read_json.py +++ b/tests/extensions/json/test_read_json.py @@ -1,10 +1,8 @@ -import numpy -import datetime -import pandas +from io import StringIO + import pytest + import duckdb -import re -from io import StringIO def TestFile(name): @@ -14,7 +12,7 @@ def TestFile(name): return filename -class TestReadJSON(object): +class TestReadJSON: def test_read_json_columns(self): rel = duckdb.read_json(TestFile("example.json"), columns={"id": "integer", "name": "varchar"}) res = rel.fetchone() diff --git a/tests/extensions/test_extensions_loading.py b/tests/extensions/test_extensions_loading.py index f35366ba..3aa5fe81 100644 --- a/tests/extensions/test_extensions_loading.py +++ b/tests/extensions/test_extensions_loading.py @@ -1,10 +1,10 @@ import os import platform -import duckdb -from pytest import raises import pytest +from pytest import raises +import duckdb pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", diff --git a/tests/extensions/test_httpfs.py b/tests/extensions/test_httpfs.py index 866491f0..bd1ec015 100644 --- a/tests/extensions/test_httpfs.py +++ b/tests/extensions/test_httpfs.py @@ -1,9 +1,11 @@ -import duckdb +import datetime import os -from pytest import raises, mark + import pytest -from conftest import NumpyPandas, ArrowPandas -import datetime +from conftest import ArrowPandas, NumpyPandas +from pytest import mark, raises + +import duckdb # We only run this test if this env var is set # FIXME: we can add a custom command line argument to pytest to provide an extension directory @@ -14,7 +16,7 @@ ) -class TestHTTPFS(object): +class TestHTTPFS: def test_read_json_httpfs(self, require): connection = require("httpfs") try: @@ -29,7 +31,7 @@ def test_read_json_httpfs(self, require): def test_s3fs(self, require): connection = require("httpfs") - rel = connection.read_csv(f"s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) + rel = connection.read_csv("s3://duckdb-blobs/data/Star_Trek-Season_1.csv", header=True) res = rel.fetchone() assert res == (1, 0, datetime.date(1965, 2, 28), 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 6, 0, 0, 0, 0) @@ -42,9 +44,7 @@ def test_httpfs(self, require, pandas): ) except RuntimeError as e: # Test will ignore result if it fails due to networking issues while running the test. - if str(e).startswith("HTTP HEAD error"): - return - elif str(e).startswith("Unable to connect"): + if str(e).startswith("HTTP HEAD error") or str(e).startswith("Unable to connect"): return else: raise e diff --git a/tests/fast/adbc/test_adbc.py b/tests/fast/adbc/test_adbc.py index 80b6b385..6f6213e6 100644 --- a/tests/fast/adbc/test_adbc.py +++ b/tests/fast/adbc/test_adbc.py @@ -1,9 +1,11 @@ -import duckdb -import pytest -import sys import datetime import os +import sys + import numpy as np +import pytest + +import duckdb if sys.version_info < (3, 9): pytest.skip( diff --git a/tests/fast/adbc/test_connection_get_info.py b/tests/fast/adbc/test_connection_get_info.py index 3744b7da..4f8163bc 100644 --- a/tests/fast/adbc/test_connection_get_info.py +++ b/tests/fast/adbc/test_connection_get_info.py @@ -1,8 +1,9 @@ import sys -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") adbc_driver_manager = pytest.importorskip("adbc_driver_manager") @@ -22,7 +23,7 @@ ) -class TestADBCConnectionGetInfo(object): +class TestADBCConnectionGetInfo: def test_connection_basic(self): con = adbc_driver_duckdb.connect() with con.cursor() as cursor: diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index d1919cb1..dc5d1f59 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -31,7 +31,7 @@ def _bind(stmt, batch): stmt.bind(array, schema) -class TestADBCStatementBind(object): +class TestADBCStatementBind: def test_bind_multiple_rows(self): data = pa.record_batch( [ diff --git a/tests/fast/api/test_3324.py b/tests/fast/api/test_3324.py index f3cd235b..fb860600 100644 --- a/tests/fast/api/test_3324.py +++ b/tests/fast/api/test_3324.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class Test3324(object): +class Test3324: def test_3324(self, duckdb_cursor): create_output = duckdb_cursor.execute( """ diff --git a/tests/fast/api/test_3654.py b/tests/fast/api/test_3654.py index 8fad47e6..2ffee855 100644 --- a/tests/fast/api/test_3654.py +++ b/tests/fast/api/test_3654.py @@ -1,16 +1,17 @@ -import duckdb import pytest +import duckdb + try: import pyarrow as pa can_run = True except: can_run = False -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas -class Test3654(object): +class Test3654: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_3654_pandas(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( diff --git a/tests/fast/api/test_3728.py b/tests/fast/api/test_3728.py index 37b50ee6..bd770bf0 100644 --- a/tests/fast/api/test_3728.py +++ b/tests/fast/api/test_3728.py @@ -1,7 +1,7 @@ import duckdb -class Test3728(object): +class Test3728: def test_3728_describe_enum(self, duckdb_cursor): # Create an in-memory database, but the problem is also present in file-backed DBs cursor = duckdb.connect(":memory:") diff --git a/tests/fast/api/test_6315.py b/tests/fast/api/test_6315.py index b9e7c0cf..3702831e 100644 --- a/tests/fast/api/test_6315.py +++ b/tests/fast/api/test_6315.py @@ -1,7 +1,7 @@ import duckdb -class Test6315(object): +class Test6315: def test_6315(self, duckdb_cursor): # segfault when accessing description after fetching rows c = duckdb.connect(":memory:") diff --git a/tests/fast/api/test_attribute_getter.py b/tests/fast/api/test_attribute_getter.py index eda6845a..3b1513d1 100644 --- a/tests/fast/api/test_attribute_getter.py +++ b/tests/fast/api/test_attribute_getter.py @@ -1,15 +1,10 @@ -import duckdb -import tempfile -import os -import pandas as pd -import tempfile -import pandas._testing as tm -import datetime -import csv + import pytest +import duckdb + -class TestGetAttribute(object): +class TestGetAttribute: def test_basic_getattr(self, duckdb_cursor): rel = duckdb_cursor.sql("select i as a, (i + 5) % 10 as b, (i + 2) % 3 as c from range(100) tbl(i)") assert rel.a.fetchmany(5) == [(0,), (1,), (2,), (3,), (4,)] diff --git a/tests/fast/api/test_config.py b/tests/fast/api/test_config.py index 4a0a0445..89620e96 100644 --- a/tests/fast/api/test_config.py +++ b/tests/fast/api/test_config.py @@ -1,14 +1,15 @@ # simple DB API testcase -import duckdb -import numpy -import pytest -import re import os -from conftest import NumpyPandas, ArrowPandas +import re + +import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestDBConfig(object): +class TestDBConfig: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_default_order(self, duckdb_cursor, pandas): df = pandas.DataFrame({"a": [1, 2, 3]}) @@ -51,7 +52,7 @@ def test_extension_setting(self): if not repository: return con = duckdb.connect(config={"TimeZone": "UTC", "autoinstall_extension_repository": repository}) - assert "UTC" == con.sql("select current_setting('TimeZone')").fetchone()[0] + assert con.sql("select current_setting('TimeZone')").fetchone()[0] == "UTC" def test_unrecognized_option(self, duckdb_cursor): success = True diff --git a/tests/fast/api/test_connection_close.py b/tests/fast/api/test_connection_close.py index f71a02bb..bbf66772 100644 --- a/tests/fast/api/test_connection_close.py +++ b/tests/fast/api/test_connection_close.py @@ -1,10 +1,12 @@ # cursor description -import duckdb -import tempfile import os +import tempfile + import pytest +import duckdb + def check_exception(f): had_exception = False @@ -15,7 +17,7 @@ def check_exception(f): assert had_exception -class TestConnectionClose(object): +class TestConnectionClose: def test_connection_close(self, duckdb_cursor): fd, db = tempfile.mkstemp() os.close(fd) diff --git a/tests/fast/api/test_connection_interrupt.py b/tests/fast/api/test_connection_interrupt.py index 4efd68b5..8a027b5a 100644 --- a/tests/fast/api/test_connection_interrupt.py +++ b/tests/fast/api/test_connection_interrupt.py @@ -2,11 +2,12 @@ import threading import time -import duckdb import pytest +import duckdb + -class TestConnectionInterrupt(object): +class TestConnectionInterrupt: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", diff --git a/tests/fast/api/test_cursor.py b/tests/fast/api/test_cursor.py index 69c3fe79..7a2c4176 100644 --- a/tests/fast/api/test_cursor.py +++ b/tests/fast/api/test_cursor.py @@ -1,10 +1,11 @@ # simple DB API testcase import pytest + import duckdb -class TestDBAPICursor(object): +class TestDBAPICursor: def test_cursor_basic(self): # Create a connection con = duckdb.connect(":memory:") diff --git a/tests/fast/api/test_dbapi00.py b/tests/fast/api/test_dbapi00.py index 38d87887..6201d569 100644 --- a/tests/fast/api/test_dbapi00.py +++ b/tests/fast/api/test_dbapi00.py @@ -2,15 +2,14 @@ import numpy import pytest -import duckdb -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas def assert_result_equal(result): assert result == [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (None,)], "Incorrect result returned" -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_regular_selection(self, duckdb_cursor, integers): duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() diff --git a/tests/fast/api/test_dbapi01.py b/tests/fast/api/test_dbapi01.py index f7f00a10..4d52fd64 100644 --- a/tests/fast/api/test_dbapi01.py +++ b/tests/fast/api/test_dbapi01.py @@ -1,10 +1,11 @@ # multiple result sets import numpy + import duckdb -class TestMultipleResultSets(object): +class TestMultipleResultSets: def test_regular_selection(self, duckdb_cursor, integers): duckdb_cursor.execute("SELECT * FROM integers") duckdb_cursor.execute("SELECT * FROM integers") diff --git a/tests/fast/api/test_dbapi04.py b/tests/fast/api/test_dbapi04.py index 1125f819..2c2259ce 100644 --- a/tests/fast/api/test_dbapi04.py +++ b/tests/fast/api/test_dbapi04.py @@ -1,7 +1,7 @@ # simple DB API testcase -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_regular_selection(self, duckdb_cursor, integers): duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchall() diff --git a/tests/fast/api/test_dbapi05.py b/tests/fast/api/test_dbapi05.py index 234fb2ec..6c6d4fa1 100644 --- a/tests/fast/api/test_dbapi05.py +++ b/tests/fast/api/test_dbapi05.py @@ -1,7 +1,7 @@ # simple DB API testcase -class TestSimpleDBAPI(object): +class TestSimpleDBAPI: def test_prepare(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT CAST(? AS INTEGER), CAST(? AS INTEGER)", ["42", "84"]).fetchall() assert result == [ diff --git a/tests/fast/api/test_dbapi07.py b/tests/fast/api/test_dbapi07.py index 238f30fc..eab581e5 100644 --- a/tests/fast/api/test_dbapi07.py +++ b/tests/fast/api/test_dbapi07.py @@ -1,16 +1,17 @@ # timestamp ms precision -import numpy from datetime import datetime +import numpy + -class TestNumpyTimestampMilliseconds(object): +class TestNumpyTimestampMilliseconds: def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchnumpy() assert res["test_time"] == numpy.datetime64("2019-11-26 21:11:42.501") -class TestTimestampMilliseconds(object): +class TestTimestampMilliseconds: def test_numpy_timestamp(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIMESTAMP '2019-11-26 21:11:42.501' as test_time").fetchone()[0] assert res == datetime.strptime("2019-11-26 21:11:42.501", "%Y-%m-%d %H:%M:%S.%f") diff --git a/tests/fast/api/test_dbapi08.py b/tests/fast/api/test_dbapi08.py index 457a9e78..def4e925 100644 --- a/tests/fast/api/test_dbapi08.py +++ b/tests/fast/api/test_dbapi08.py @@ -1,11 +1,11 @@ # test fetchdf with various types -import numpy import pytest -import duckdb from conftest import NumpyPandas +import duckdb + -class TestType(object): +class TestType: @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_fetchdf(self, pandas): con = duckdb.connect() diff --git a/tests/fast/api/test_dbapi09.py b/tests/fast/api/test_dbapi09.py index 538e7fc3..8a31e10e 100644 --- a/tests/fast/api/test_dbapi09.py +++ b/tests/fast/api/test_dbapi09.py @@ -1,11 +1,12 @@ # date type -import numpy import datetime + +import numpy import pandas -class TestNumpyDate(object): +class TestNumpyDate: def test_fetchall_date(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT DATE '2020-01-10' as test_date").fetchall() assert res == [(datetime.date(2020, 1, 10),)] diff --git a/tests/fast/api/test_dbapi10.py b/tests/fast/api/test_dbapi10.py index 0ab69e0b..8b5cb0e4 100644 --- a/tests/fast/api/test_dbapi10.py +++ b/tests/fast/api/test_dbapi10.py @@ -1,10 +1,11 @@ # cursor description -from datetime import datetime, date +from datetime import date, datetime + from pytest import mark import duckdb -class TestCursorDescription(object): +class TestCursorDescription: @mark.parametrize( "query,column_name,string_type,real_type", [ @@ -51,6 +52,6 @@ def test_none_description(self, duckdb_empty_cursor): assert duckdb_empty_cursor.description is None -class TestCursorRowcount(object): +class TestCursorRowcount: def test_rowcount(self, duckdb_cursor): assert duckdb_cursor.rowcount == -1 diff --git a/tests/fast/api/test_dbapi11.py b/tests/fast/api/test_dbapi11.py index 91237b9e..56f5724d 100644 --- a/tests/fast/api/test_dbapi11.py +++ b/tests/fast/api/test_dbapi11.py @@ -1,8 +1,9 @@ # cursor description -import duckdb -import tempfile import os +import tempfile + +import duckdb def check_exception(f): @@ -14,7 +15,7 @@ def check_exception(f): assert had_exception -class TestReadOnly(object): +class TestReadOnly: def test_readonly(self, duckdb_cursor): fd, db = tempfile.mkstemp() os.close(fd) diff --git a/tests/fast/api/test_dbapi12.py b/tests/fast/api/test_dbapi12.py index 833d231c..96b1deac 100644 --- a/tests/fast/api/test_dbapi12.py +++ b/tests/fast/api/test_dbapi12.py @@ -1,10 +1,10 @@ -import duckdb -import tempfile -import os + import pandas as pd +import duckdb + -class TestRelationApi(object): +class TestRelationApi: def test_readonly(self, duckdb_cursor): test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["one", "two", "three"]}) diff --git a/tests/fast/api/test_dbapi13.py b/tests/fast/api/test_dbapi13.py index ffdb4884..c08cefb1 100644 --- a/tests/fast/api/test_dbapi13.py +++ b/tests/fast/api/test_dbapi13.py @@ -1,11 +1,12 @@ # time type -import numpy import datetime + +import numpy import pandas -class TestNumpyTime(object): +class TestNumpyTime: def test_fetchall_time(self, duckdb_cursor): res = duckdb_cursor.execute("SELECT TIME '13:06:40' as test_time").fetchall() assert res == [(datetime.time(13, 6, 40),)] diff --git a/tests/fast/api/test_dbapi_fetch.py b/tests/fast/api/test_dbapi_fetch.py index 9c47c54c..5ec18aca 100644 --- a/tests/fast/api/test_dbapi_fetch.py +++ b/tests/fast/api/test_dbapi_fetch.py @@ -1,11 +1,13 @@ -import duckdb -import pytest -from uuid import UUID import datetime from decimal import Decimal +from uuid import UUID + +import pytest + +import duckdb -class TestDBApiFetch(object): +class TestDBApiFetch: def test_multiple_fetch_one(self, duckdb_cursor): con = duckdb.connect() c = con.execute("SELECT 42") diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 4b0dc4d6..eb241145 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -1,7 +1,8 @@ +import pytest +from conftest import ArrowPandas, NumpyPandas + import duckdb import duckdb.typing -import pytest -from conftest import NumpyPandas, ArrowPandas pa = pytest.importorskip("pyarrow") @@ -22,7 +23,7 @@ def tmp_database(tmp_path_factory): # This file contains tests for DuckDBPyConnection methods, # wrapped by the 'duckdb' module, to execute with the 'default_connection' -class TestDuckDBConnection(object): +class TestDuckDBConnection: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_append(self, pandas): duckdb.execute("Create table integers (i integer)") @@ -118,7 +119,7 @@ def test_readonly_properties(self): assert rowcount == -1 def test_execute(self): - assert [([4, 2],)] == duckdb.execute("select [4,2]").fetchall() + assert duckdb.execute("select [4,2]").fetchall() == [([4, 2],)] def test_executemany(self): # executemany does not keep an open result set @@ -231,7 +232,7 @@ def test_fetch_record_batch(self): assert len(chunk) == 3000 def test_fetchall(self): - assert [([1, 2, 3],)] == duckdb.execute("select [1,2,3]").fetchall() + assert duckdb.execute("select [1,2,3]").fetchall() == [([1, 2, 3],)] def test_fetchdf(self): ref = [([1, 2, 3],)] @@ -241,7 +242,7 @@ def test_fetchdf(self): assert res == ref def test_fetchmany(self): - assert [(0,), (1,)] == duckdb.execute("select * from range(5)").fetchmany(2) + assert duckdb.execute("select * from range(5)").fetchmany(2) == [(0,), (1,)] def test_fetchnumpy(self): numpy = pytest.importorskip("numpy") @@ -254,37 +255,37 @@ def test_fetchnumpy(self): assert results["a"] == numpy.array([b"hello"], dtype=object) def test_fetchone(self): - assert (0,) == duckdb.execute("select * from range(5)").fetchone() + assert duckdb.execute("select * from range(5)").fetchone() == (0,) def test_from_arrow(self): - assert None != duckdb.from_arrow + assert duckdb.from_arrow != None def test_from_csv_auto(self): - assert None != duckdb.from_csv_auto + assert duckdb.from_csv_auto != None def test_from_df(self): - assert None != duckdb.from_df + assert duckdb.from_df != None def test_from_parquet(self): - assert None != duckdb.from_parquet + assert duckdb.from_parquet != None def test_from_query(self): - assert None != duckdb.from_query + assert duckdb.from_query != None def test_get_table_names(self): - assert None != duckdb.get_table_names + assert duckdb.get_table_names != None def test_install_extension(self): - assert None != duckdb.install_extension + assert duckdb.install_extension != None def test_load_extension(self): - assert None != duckdb.load_extension + assert duckdb.load_extension != None def test_query(self): - assert [(3,)] == duckdb.query("select 3").fetchall() + assert duckdb.query("select 3").fetchall() == [(3,)] def test_register(self): - assert None != duckdb.register + assert duckdb.register != None def test_register_relation(self): con = duckdb.connect() @@ -334,27 +335,27 @@ def temporary_scope(): def test_table(self): con = duckdb.connect() con.execute("create table tbl as select 1") - assert [(1,)] == con.table("tbl").fetchall() + assert con.table("tbl").fetchall() == [(1,)] def test_table_function(self): - assert None != duckdb.table_function + assert duckdb.table_function != None def test_unregister(self): - assert None != duckdb.unregister + assert duckdb.unregister != None def test_values(self): - assert None != duckdb.values + assert duckdb.values != None def test_view(self): duckdb.execute("create view vw as select range(5)") - assert [([0, 1, 2, 3, 4],)] == duckdb.view("vw").fetchall() + assert duckdb.view("vw").fetchall() == [([0, 1, 2, 3, 4],)] duckdb.execute("drop view vw") def test_close(self): - assert None != duckdb.close + assert duckdb.close != None def test_interrupt(self): - assert None != duckdb.interrupt + assert duckdb.interrupt != None def test_wrap_shadowing(self): pd = NumpyPandas() @@ -393,7 +394,7 @@ def test_set_pandas_analyze_sample_size(self): # Find the cached config con2 = duckdb.connect(":memory:named", config={"pandas_analyze_sample": 0}) - con2.execute(f"SET GLOBAL pandas_analyze_sample=2") + con2.execute("SET GLOBAL pandas_analyze_sample=2") # This change is reflected in 'con' because the instance was cached res = con.sql("select current_setting('pandas_analyze_sample')").fetchone() diff --git a/tests/fast/api/test_duckdb_execute.py b/tests/fast/api/test_duckdb_execute.py index a025fc42..df8bff63 100644 --- a/tests/fast/api/test_duckdb_execute.py +++ b/tests/fast/api/test_duckdb_execute.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + -class TestDuckDBExecute(object): +class TestDuckDBExecute: def test_execute_basic(self, duckdb_cursor): duckdb_cursor.execute("create table t as select 5") res = duckdb_cursor.table("t").fetchall() diff --git a/tests/fast/api/test_duckdb_query.py b/tests/fast/api/test_duckdb_query.py index 2ecfd8f3..db807f44 100644 --- a/tests/fast/api/test_duckdb_query.py +++ b/tests/fast/api/test_duckdb_query.py @@ -1,10 +1,11 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb from duckdb import Value -class TestDuckDBQuery(object): +class TestDuckDBQuery: def test_duckdb_query(self, duckdb_cursor): # we can use duckdb_cursor.sql to run both DDL statements and select statements duckdb_cursor.sql("create view v1 as select 42 i") diff --git a/tests/fast/api/test_explain.py b/tests/fast/api/test_explain.py index feedc134..23bcfcd4 100644 --- a/tests/fast/api/test_explain.py +++ b/tests/fast/api/test_explain.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class TestExplain(object): +class TestExplain: def test_explain_basic(self, duckdb_cursor): res = duckdb_cursor.sql("select 42").explain() assert isinstance(res, str) diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index 7b797598..d7d2503d 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -1,16 +1,16 @@ -import pytest -import duckdb -import io import datetime +import io + +import pytest fsspec = pytest.importorskip("fsspec") -class TestReadParquet(object): +class TestReadParquet: def test_fsspec_deadlock(self, duckdb_cursor, tmp_path): # Create test parquet data file_path = tmp_path / "data.parquet" - duckdb_cursor.sql("COPY (FROM range(50_000)) TO '{}' (FORMAT parquet)".format(str(file_path))) + duckdb_cursor.sql(f"COPY (FROM range(50_000)) TO '{file_path!s}' (FORMAT parquet)") with open(file_path, "rb") as f: parquet_data = f.read() diff --git a/tests/fast/api/test_insert_into.py b/tests/fast/api/test_insert_into.py index 2537c182..1214203b 100644 --- a/tests/fast/api/test_insert_into.py +++ b/tests/fast/api/test_insert_into.py @@ -1,9 +1,10 @@ -import duckdb -from pandas import DataFrame import pytest +from pandas import DataFrame + +import duckdb -class TestInsertInto(object): +class TestInsertInto: def test_insert_into_schema(self, duckdb_cursor): # open connection con = duckdb.connect() diff --git a/tests/fast/api/test_join.py b/tests/fast/api/test_join.py index 5e2a148f..30ace540 100644 --- a/tests/fast/api/test_join.py +++ b/tests/fast/api/test_join.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + -class TestJoin(object): +class TestJoin: def test_alias_from_sql(self): con = duckdb.connect() rel1 = con.sql("SELECT 1 AS col1, 2 AS col2") diff --git a/tests/fast/api/test_native_tz.py b/tests/fast/api/test_native_tz.py index f4a9d716..39d301e2 100644 --- a/tests/fast/api/test_native_tz.py +++ b/tests/fast/api/test_native_tz.py @@ -1,9 +1,10 @@ -import duckdb import datetime -import pytz import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") from packaging.version import Version @@ -11,7 +12,7 @@ filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "data", "tz.parquet") -class TestNativeTimeZone(object): +class TestNativeTimeZone: def test_native_python_timestamp_timezone(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetchone() diff --git a/tests/fast/api/test_query_interrupt.py b/tests/fast/api/test_query_interrupt.py index e6d2b998..56c182f8 100644 --- a/tests/fast/api/test_query_interrupt.py +++ b/tests/fast/api/test_query_interrupt.py @@ -1,10 +1,11 @@ -import duckdb +import _thread as thread +import platform +import threading import time + import pytest -import platform -import threading -import _thread as thread +import duckdb def send_keyboard_interrupt(): @@ -14,7 +15,7 @@ def send_keyboard_interrupt(): thread.interrupt_main() -class TestQueryInterruption(object): +class TestQueryInterruption: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="Emscripten builds cannot use threads", diff --git a/tests/fast/api/test_query_progress.py b/tests/fast/api/test_query_progress.py index f885e36d..8d1d85a9 100644 --- a/tests/fast/api/test_query_progress.py +++ b/tests/fast/api/test_query_progress.py @@ -2,11 +2,12 @@ import threading import time -import duckdb import pytest +import duckdb + -class TestQueryProgress(object): +class TestQueryProgress: @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="threads not allowed on Emscripten", @@ -33,7 +34,7 @@ def thread_target(): # query never progresses. This will also fail if the query is too # quick as it will be back at -1 as soon as the query is finished. - for _ in range(0, 500): + for _ in range(500): assert thread.is_alive(), "query finished too quick" if (qp1 := conn.query_progress()) > 0: break @@ -42,7 +43,7 @@ def thread_target(): pytest.fail("query start timeout") # keep monitoring and wait for the progress to increase - for _ in range(0, 500): + for _ in range(500): assert thread.is_alive(), "query finished too quick" if (qp2 := conn.query_progress()) > qp1: break diff --git a/tests/fast/api/test_read_csv.py b/tests/fast/api/test_read_csv.py index dff90869..a4e90c44 100644 --- a/tests/fast/api/test_read_csv.py +++ b/tests/fast/api/test_read_csv.py @@ -1,11 +1,12 @@ -from multiprocessing.sharedctypes import Value import datetime -import pytest import platform +import sys +from io import BytesIO, StringIO + +import pytest + import duckdb -from io import StringIO, BytesIO from duckdb import CSVLineTerminator -import sys def TestFile(name): @@ -33,7 +34,7 @@ def create_temp_csv(tmp_path): return file1_path, file2_path -class TestReadCSV(object): +class TestReadCSV: def test_using_connection_wrapper(self): rel = duckdb.read_csv(TestFile("category.csv")) res = rel.fetchone() @@ -361,7 +362,6 @@ def test_filelike_custom(self, duckdb_cursor): class CustomIO: def __init__(self) -> None: self.loc = 0 - pass def seek(self, loc): self.loc = loc diff --git a/tests/fast/api/test_relation_to_view.py b/tests/fast/api/test_relation_to_view.py index 31a19d54..14f4cb4d 100644 --- a/tests/fast/api/test_relation_to_view.py +++ b/tests/fast/api/test_relation_to_view.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class TestRelationToView(object): +class TestRelationToView: def test_values_to_view(self, duckdb_cursor): rel = duckdb_cursor.values(["test", "this is a long string"]) res = rel.fetchall() diff --git a/tests/fast/api/test_streaming_result.py b/tests/fast/api/test_streaming_result.py index 739fd17a..700057ed 100644 --- a/tests/fast/api/test_streaming_result.py +++ b/tests/fast/api/test_streaming_result.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class TestStreamingResult(object): +class TestStreamingResult: def test_fetch_one(self, duckdb_cursor): # fetch one res = duckdb_cursor.sql("SELECT * FROM range(100000)") diff --git a/tests/fast/api/test_to_csv.py b/tests/fast/api/test_to_csv.py index 5f8000a9..ef2aef6c 100644 --- a/tests/fast/api/test_to_csv.py +++ b/tests/fast/api/test_to_csv.py @@ -1,14 +1,15 @@ -import duckdb -import tempfile -import os -import pandas._testing as tm -import datetime import csv +import datetime +import os +import tempfile + import pytest -from conftest import NumpyPandas, ArrowPandas, getTimeSeriesData +from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData + +import duckdb -class TestToCSV(object): +class TestToCSV: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_basic_to_csv(self, pandas): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) diff --git a/tests/fast/api/test_to_parquet.py b/tests/fast/api/test_to_parquet.py index c13ac011..834763bf 100644 --- a/tests/fast/api/test_to_parquet.py +++ b/tests/fast/api/test_to_parquet.py @@ -1,15 +1,13 @@ -import duckdb -import tempfile import os import tempfile -import pandas._testing as tm -import datetime -import csv + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestToParquet(object): +class TestToParquet: @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_basic_to_parquet(self, pd): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) @@ -43,12 +41,12 @@ def test_field_ids(self): rel.to_parquet(temp_file_name, field_ids=dict(i=42, my_struct={"__duckdb_field_id": 43, "j": 44})) parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() - assert [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] == duckdb.sql( + assert duckdb.sql( f""" select name,field_id from parquet_schema('{temp_file_name}') """ - ).execute().fetchall() + ).execute().fetchall() == [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("row_group_size_bytes", [122880 * 1024, "2MB"]) diff --git a/tests/fast/api/test_with_propagating_exceptions.py b/tests/fast/api/test_with_propagating_exceptions.py index 8613d6f4..41df088f 100644 --- a/tests/fast/api/test_with_propagating_exceptions.py +++ b/tests/fast/api/test_with_propagating_exceptions.py @@ -1,8 +1,9 @@ import pytest + import duckdb -class TestWithPropagatingExceptions(object): +class TestWithPropagatingExceptions: def test_with(self): # Should propagate exception raised in the 'with duckdb.connect() ..' with pytest.raises(duckdb.ParserException, match="syntax error at or near *"): diff --git a/tests/fast/arrow/parquet_write_roundtrip.py b/tests/fast/arrow/parquet_write_roundtrip.py index 5dbf3949..5c42773c 100644 --- a/tests/fast/arrow/parquet_write_roundtrip.py +++ b/tests/fast/arrow/parquet_write_roundtrip.py @@ -1,9 +1,11 @@ -import duckdb -import pytest +import datetime import tempfile + import numpy import pandas -import datetime +import pytest + +import duckdb pa = pytest.importorskip("pyarrow") @@ -37,7 +39,7 @@ def parquet_types_test(type_list): assert read_df.equals(read_from_arrow) -class TestParquetRoundtrip(object): +class TestParquetRoundtrip: def test_roundtrip_numeric(self, duckdb_cursor): type_list = [ ([-(2**7), 0, 2**7 - 1], numpy.int8, "TINYINT"), diff --git a/tests/fast/arrow/test_10795.py b/tests/fast/arrow/test_10795.py index 5503e529..5dc88402 100644 --- a/tests/fast/arrow/test_10795.py +++ b/tests/fast/arrow/test_10795.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_12384.py b/tests/fast/arrow/test_12384.py index d2d4a7fc..e91cbe8c 100644 --- a/tests/fast/arrow/test_12384.py +++ b/tests/fast/arrow/test_12384.py @@ -1,7 +1,9 @@ -import duckdb -import pytest import os +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_14344.py b/tests/fast/arrow/test_14344.py index 86f8728b..77cfaaa2 100644 --- a/tests/fast/arrow/test_14344.py +++ b/tests/fast/arrow/test_14344.py @@ -1,4 +1,3 @@ -import duckdb import pytest pa = pytest.importorskip("pyarrow") @@ -6,13 +5,13 @@ def test_14344(duckdb_cursor): - my_table = pa.Table.from_pydict({"foo": pa.array([hashlib.sha256("foo".encode()).digest()], type=pa.binary())}) + my_table = pa.Table.from_pydict({"foo": pa.array([hashlib.sha256(b"foo").digest()], type=pa.binary())}) my_table2 = pa.Table.from_pydict( - {"foo": pa.array([hashlib.sha256("foo".encode()).digest()], type=pa.binary()), "a": ["123"]} + {"foo": pa.array([hashlib.sha256(b"foo").digest()], type=pa.binary()), "a": ["123"]} ) res = duckdb_cursor.sql( - f""" + """ SELECT my_table2.* EXCLUDE (foo) FROM diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index 6d760500..5e6d42ef 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -1,15 +1,14 @@ + import duckdb -import os try: - import pyarrow as pa can_run = True except: can_run = False -class Test2426(object): +class Test2426: def test_2426(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_5547.py b/tests/fast/arrow/test_5547.py index eb77ab83..8e8b40ed 100644 --- a/tests/fast/arrow/test_5547.py +++ b/tests/fast/arrow/test_5547.py @@ -1,7 +1,8 @@ -import duckdb import pandas as pd -from pandas.testing import assert_frame_equal import pytest +from pandas.testing import assert_frame_equal + +import duckdb pa = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_6584.py b/tests/fast/arrow/test_6584.py index 6f96bf2d..feadc6d7 100644 --- a/tests/fast/arrow/test_6584.py +++ b/tests/fast/arrow/test_6584.py @@ -1,7 +1,9 @@ from concurrent.futures import ThreadPoolExecutor -import duckdb + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_6796.py b/tests/fast/arrow/test_6796.py index ef464f49..454fa005 100644 --- a/tests/fast/arrow/test_6796.py +++ b/tests/fast/arrow/test_6796.py @@ -1,6 +1,7 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb pyarrow = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_7652.py b/tests/fast/arrow/test_7652.py index 857d871d..e38174b9 100644 --- a/tests/fast/arrow/test_7652.py +++ b/tests/fast/arrow/test_7652.py @@ -1,13 +1,12 @@ -import duckdb -import os -import pytest import tempfile +import pytest + pa = pytest.importorskip("pyarrow", minversion="11") pq = pytest.importorskip("pyarrow.parquet", minversion="11") -class Test7652(object): +class Test7652: def test_7652(self, duckdb_cursor): temp_file_name = tempfile.NamedTemporaryFile(suffix=".parquet").name # Generate a list of values that aren't uniform in changes. diff --git a/tests/fast/arrow/test_7699.py b/tests/fast/arrow/test_7699.py index a4de66b9..ba2f4af3 100644 --- a/tests/fast/arrow/test_7699.py +++ b/tests/fast/arrow/test_7699.py @@ -1,13 +1,13 @@ -import duckdb -import pytest import string +import pytest + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") pl = pytest.importorskip("polars") -class Test7699(object): +class Test7699: def test_7699(self, duckdb_cursor): pl_tbl = pl.DataFrame( { diff --git a/tests/fast/arrow/test_8522.py b/tests/fast/arrow/test_8522.py index 84aa125c..681e8fdf 100644 --- a/tests/fast/arrow/test_8522.py +++ b/tests/fast/arrow/test_8522.py @@ -1,15 +1,14 @@ -import duckdb -import pytest -import string import datetime as dt +import pytest + pa = pytest.importorskip("pyarrow") # Reconstruct filters when pushing down into arrow scan # arrow supports timestamp_tz with different units than US, we only support US # so we have to convert ConstantValues back to their native unit when pushing the filter expression containing them down to pyarrow -class Test8522(object): +class Test8522: def test_8522(self, duckdb_cursor): t_us = pa.Table.from_arrays( arrays=[pa.array([dt.datetime(2022, 1, 1)])], diff --git a/tests/fast/arrow/test_9443.py b/tests/fast/arrow/test_9443.py index 7de04bde..f6627c00 100644 --- a/tests/fast/arrow/test_9443.py +++ b/tests/fast/arrow/test_9443.py @@ -1,4 +1,3 @@ -import duckdb import pytest pq = pytest.importorskip("pyarrow.parquet") @@ -8,7 +7,7 @@ from pathlib import PurePosixPath -class Test9443(object): +class Test9443: def test_9443(self, tmp_path, duckdb_cursor): arrow_table = pa.Table.from_pylist( [ diff --git a/tests/fast/arrow/test_arrow_batch_index.py b/tests/fast/arrow/test_arrow_batch_index.py index a8dc2c7f..0cd4d679 100644 --- a/tests/fast/arrow/test_arrow_batch_index.py +++ b/tests/fast/arrow/test_arrow_batch_index.py @@ -1,12 +1,11 @@ -import duckdb import pytest -import pandas as pd + import duckdb pa = pytest.importorskip("pyarrow") -class TestArrowBatchIndex(object): +class TestArrowBatchIndex: def test_arrow_batch_index(self, duckdb_cursor): con = duckdb.connect() df = con.execute("SELECT * FROM range(10000000) t(i)").df() diff --git a/tests/fast/arrow/test_arrow_binary_view.py b/tests/fast/arrow/test_arrow_binary_view.py index 31107f67..4e161ac3 100644 --- a/tests/fast/arrow/test_arrow_binary_view.py +++ b/tests/fast/arrow/test_arrow_binary_view.py @@ -1,10 +1,11 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowBinaryView(object): +class TestArrowBinaryView: def test_arrow_binary_view(self, duckdb_cursor): con = duckdb.connect() tab = pa.table({"x": pa.array([b"abc", b"thisisaverybigbinaryyaymorethanfifteen", None], pa.binary_view())}) diff --git a/tests/fast/arrow/test_arrow_case_sensitive.py b/tests/fast/arrow/test_arrow_case_sensitive.py index ef60046a..11bca339 100644 --- a/tests/fast/arrow/test_arrow_case_sensitive.py +++ b/tests/fast/arrow/test_arrow_case_sensitive.py @@ -1,10 +1,9 @@ -import duckdb import pytest pa = pytest.importorskip("pyarrow") -class TestArrowCaseSensitive(object): +class TestArrowCaseSensitive: def test_arrow_case_sensitive(self, duckdb_cursor): data = (pa.array([1], type=pa.int32()), pa.array([1000], type=pa.int32())) arrow_table = pa.Table.from_arrays([data[0], data[1]], ["A1", "a1"]) diff --git a/tests/fast/arrow/test_arrow_decimal256.py b/tests/fast/arrow/test_arrow_decimal256.py index 0ab84d3a..08612918 100644 --- a/tests/fast/arrow/test_arrow_decimal256.py +++ b/tests/fast/arrow/test_arrow_decimal256.py @@ -1,11 +1,13 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimal256(object): +class TestArrowDecimal256: def test_decimal_256_throws(self, duckdb_cursor): with duckdb.connect() as conn: pa_decimal256 = pa.Table.from_pylist( diff --git a/tests/fast/arrow/test_arrow_decimal_32_64.py b/tests/fast/arrow/test_arrow_decimal_32_64.py index 39b6e43a..301d890f 100644 --- a/tests/fast/arrow/test_arrow_decimal_32_64.py +++ b/tests/fast/arrow/test_arrow_decimal_32_64.py @@ -1,11 +1,13 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimalTypes(object): +class TestArrowDecimalTypes: def test_decimal_32(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_output_version = 1.5") diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index 43c995bb..1bc0e179 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -1,14 +1,16 @@ -import duckdb -import pytest -import uuid +import datetime import json +import uuid from uuid import UUID -import datetime + +import pytest + +import duckdb pa = pytest.importorskip("pyarrow", "18.0.0") -class TestCanonicalExtensionTypes(object): +class TestCanonicalExtensionTypes: def test_uuid(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index a969da21..62460912 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -1,8 +1,7 @@ + import duckdb -import pytest try: - import pyarrow as pa can_run = True except: @@ -18,7 +17,7 @@ def check_equal(duckdb_conn): assert arrow_result == true_result -class TestArrowFetch(object): +class TestArrowFetch: def test_empty_table(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_arrow_fetch_recordbatch.py b/tests/fast/arrow/test_arrow_fetch_recordbatch.py index 8915d886..4d7fe28a 100644 --- a/tests/fast/arrow/test_arrow_fetch_recordbatch.py +++ b/tests/fast/arrow/test_arrow_fetch_recordbatch.py @@ -1,10 +1,11 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowFetchRecordBatch(object): +class TestArrowFetchRecordBatch: # Test with basic numeric conversion (integers, floats, and others fall this code-path) def test_record_batch_next_batch_numeric(self, duckdb_cursor): duckdb_cursor = duckdb.connect() diff --git a/tests/fast/arrow/test_arrow_fixed_binary.py b/tests/fast/arrow/test_arrow_fixed_binary.py index cec8d520..754a472f 100644 --- a/tests/fast/arrow/test_arrow_fixed_binary.py +++ b/tests/fast/arrow/test_arrow_fixed_binary.py @@ -3,7 +3,7 @@ pa = pytest.importorskip("pyarrow") -class TestArrowFixedBinary(object): +class TestArrowFixedBinary: def test_arrow_fixed_binary(self, duckdb_cursor): ids = [ None, diff --git a/tests/fast/arrow/test_arrow_ipc.py b/tests/fast/arrow/test_arrow_ipc.py index 24718bbc..b3271fcd 100644 --- a/tests/fast/arrow/test_arrow_ipc.py +++ b/tests/fast/arrow/test_arrow_ipc.py @@ -1,4 +1,5 @@ import pytest + import duckdb pa = pytest.importorskip("pyarrow") @@ -11,7 +12,7 @@ def get_record_batch(): return pa.record_batch(data, names=["f0", "f1", "f2"]) -class TestArrowIPCExtension(object): +class TestArrowIPCExtension: # Only thing we can test in core is that it suggests the # instalation and loading of the extension def test_single_buffer(self, duckdb_cursor): diff --git a/tests/fast/arrow/test_arrow_list.py b/tests/fast/arrow/test_arrow_list.py index 47b8cb2a..4c2804a0 100644 --- a/tests/fast/arrow/test_arrow_list.py +++ b/tests/fast/arrow/test_arrow_list.py @@ -1,4 +1,3 @@ -import duckdb import numpy as np import pytest @@ -91,13 +90,13 @@ def generate_list(child_size) -> ListGenerationResult: return ListGenerationResult(list_arr, list_view_arr) -class TestArrowListType(object): +class TestArrowListType: def test_regular_list(self, duckdb_cursor): n = 5 # Amount of lists generated_size = 3 # Size of each list list_size = -1 # Argument passed to `pa._list()` - data = [np.random.random((generated_size)) for _ in range(n)] + data = [np.random.random(generated_size) for _ in range(n)] list_type = pa.list_(pa.float32(), list_size=list_size) create_and_register_arrow_table( @@ -120,7 +119,7 @@ def test_fixedsize_list(self, duckdb_cursor): generated_size = 3 # Size of each list list_size = 3 # Argument passed to `pa._list()` - data = [np.random.random((generated_size)) for _ in range(n)] + data = [np.random.random(generated_size) for _ in range(n)] list_type = pa.list_(pa.float32(), list_size=list_size) create_and_register_arrow_table( diff --git a/tests/fast/arrow/test_arrow_offsets.py b/tests/fast/arrow/test_arrow_offsets.py index 0ddc0f7d..32a59112 100644 --- a/tests/fast/arrow/test_arrow_offsets.py +++ b/tests/fast/arrow/test_arrow_offsets.py @@ -1,9 +1,9 @@ -import duckdb -import pytest -from pytest import mark import datetime import decimal + +import pytest import pytz +from pytest import mark pa = pytest.importorskip("pyarrow") @@ -80,10 +80,10 @@ def expected_result(col1_null, col2_null, expected): ) -class TestArrowOffsets(object): +class TestArrowOffsets: @null_test_parameters() def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -107,7 +107,7 @@ def test_struct_of_strings(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): - tuples = [False for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = [False for i in range(MAGIC_ARRAY_SIZE)] tuples[-1] = True col1 = tuples @@ -140,7 +140,7 @@ def test_struct_of_bools(self, duckdb_cursor, col1_null, col2_null): ) @null_test_parameters() def test_struct_of_dates(self, duckdb_cursor, constructor, expected, col1_null, col2_null): - tuples = [i for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = [i for i in range(MAGIC_ARRAY_SIZE)] col1 = tuples if col1_null: @@ -192,7 +192,7 @@ def test_struct_of_enum(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -230,7 +230,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n # FIXME: We limit the size because we don't support time values > 24 hours size = 86400 # The amount of seconds in a day - col1 = [i for i in range(0, size)] + col1 = [i for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -265,7 +265,7 @@ def test_struct_of_time(self, duckdb_cursor, constructor, unit, expected, col1_n def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converter, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - col1 = [converter(i) for i in range(0, size)] + col1 = [converter(i) for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -300,7 +300,7 @@ def test_struct_of_interval(self, duckdb_cursor, constructor, expected, converte def test_struct_of_duration(self, duckdb_cursor, constructor, unit, expected, col1_null, col2_null): size = MAGIC_ARRAY_SIZE - col1 = [i for i in range(0, size)] + col1 = [i for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -336,7 +336,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected size = MAGIC_ARRAY_SIZE duckdb_cursor.execute("set timezone='UTC'") - col1 = [i for i in range(0, size)] + col1 = [i for i in range(size)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -362,7 +362,7 @@ def test_struct_of_timestamp_tz(self, duckdb_cursor, constructor, unit, expected @null_test_parameters() def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -400,7 +400,7 @@ def test_struct_of_large_blobs(self, duckdb_cursor, col1_null, col2_null): ) def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_null, col2_null): precision, scale = precision_scale - col1 = [decimal_value(i, precision, scale) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [decimal_value(i, precision, scale) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -427,7 +427,7 @@ def test_struct_of_decimal(self, duckdb_cursor, precision_scale, expected, col1_ @null_test_parameters() def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -457,7 +457,7 @@ def test_struct_of_small_list(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -487,7 +487,7 @@ def test_struct_of_fixed_size_list(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -518,7 +518,7 @@ def test_struct_of_fixed_size_blob(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): - col1 = [str(i) for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [str(i) for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -549,7 +549,7 @@ def test_struct_of_list_of_blobs(self, duckdb_cursor, col1_null, col2_null): @null_test_parameters() def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): - col1 = [i for i in range(0, MAGIC_ARRAY_SIZE)] + col1 = [i for i in range(MAGIC_ARRAY_SIZE)] if col1_null: col1[-1] = None # "a" in the struct matches the value for col1 @@ -581,7 +581,7 @@ def test_struct_of_list_of_list(self, duckdb_cursor, col1_null, col2_null): @pytest.mark.parametrize("col1_null", [True, False]) def test_list_of_struct(self, duckdb_cursor, col1_null): # One single tuple containing a very big list - tuples = [{"a": i} for i in range(0, MAGIC_ARRAY_SIZE)] + tuples = [{"a": i} for i in range(MAGIC_ARRAY_SIZE)] if col1_null: tuples[-1] = None tuples = [tuples] @@ -590,7 +590,7 @@ def test_list_of_struct(self, duckdb_cursor, col1_null): schema=pa.schema([("col1", pa.list_(pa.struct({"a": pa.int32()})))]), ) res = duckdb_cursor.sql( - f""" + """ SELECT col1 FROM arrow_table diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index 6df5053f..295f0292 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -1,6 +1,7 @@ -import duckdb + import pytest -import os + +import duckdb pl = pytest.importorskip("polars") @@ -14,7 +15,7 @@ def polars_supports_capsule(): @pytest.mark.skipif( not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface" ) -class TestArrowPyCapsule(object): +class TestArrowPyCapsule: def test_polars_pycapsule_scan(self, duckdb_cursor): class MyObject: def __init__(self, obj) -> None: diff --git a/tests/fast/arrow/test_arrow_recordbatchreader.py b/tests/fast/arrow/test_arrow_recordbatchreader.py index a9523d43..80520499 100644 --- a/tests/fast/arrow/test_arrow_recordbatchreader.py +++ b/tests/fast/arrow/test_arrow_recordbatchreader.py @@ -1,14 +1,16 @@ -import duckdb import os + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") pyarrow.parquet = pytest.importorskip("pyarrow.parquet") pyarrow.dataset = pytest.importorskip("pyarrow.dataset") np = pytest.importorskip("numpy") -class TestArrowRecordBatchReader(object): +class TestArrowRecordBatchReader: def test_parallel_reader(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") diff --git a/tests/fast/arrow/test_arrow_replacement_scan.py b/tests/fast/arrow/test_arrow_replacement_scan.py index f2a9c13b..614e1e9f 100644 --- a/tests/fast/arrow/test_arrow_replacement_scan.py +++ b/tests/fast/arrow/test_arrow_replacement_scan.py @@ -1,14 +1,15 @@ -import duckdb -import pytest import os -import pandas as pd + +import pytest + +import duckdb pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") ds = pytest.importorskip("pyarrow.dataset") -class TestArrowReplacementScan(object): +class TestArrowReplacementScan: def test_arrow_table_replacement_scan(self, duckdb_cursor): parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pq.read_table(parquet_filename) diff --git a/tests/fast/arrow/test_arrow_run_end_encoding.py b/tests/fast/arrow/test_arrow_run_end_encoding.py index c6f9fad5..de841dd0 100644 --- a/tests/fast/arrow/test_arrow_run_end_encoding.py +++ b/tests/fast/arrow/test_arrow_run_end_encoding.py @@ -1,7 +1,4 @@ -import duckdb import pytest -import pandas as pd -import duckdb pa = pytest.importorskip("pyarrow", "21.0.0", reason="Needs pyarrow >= 21") pc = pytest.importorskip("pyarrow.compute") @@ -30,7 +27,7 @@ def list_constructors(): return result -class TestArrowREE(object): +class TestArrowREE: @pytest.mark.parametrize( "query", [ @@ -130,7 +127,7 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) duckdb_cursor.execute(query) rel = duckdb_cursor.query("select * from ree_tbl") - expected = duckdb_cursor.query("select {} from ree_tbl where {}".format(projection, filter)).fetchall() + expected = duckdb_cursor.query(f"select {projection} from ree_tbl where {filter}").fetchall() # Create an Arrow Table from the table arrow_conversion = rel.fetch_arrow_table() @@ -156,7 +153,7 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) tbl = pa.Table.from_arrays([encoded_arrays["ree"], encoded_arrays["a"], encoded_arrays["b"]], schema=schema) # Scan the Arrow Table and verify that the results are the same - res = duckdb_cursor.sql("select {} from tbl where {}".format(projection, filter)).fetchall() + res = duckdb_cursor.sql(f"select {projection} from tbl where {filter}").fetchall() assert res == expected def test_arrow_ree_empty_table(self, duckdb_cursor): @@ -227,26 +224,26 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): # This should be pushed down into arrow to only provide us with the necessary columns res = duckdb_cursor.query( - """ - select {} from arrow_tbl - """.format(projection) + f""" + select {projection} from arrow_tbl + """ ).fetch_arrow_table() # Verify correctness by fetching from the original table and the constructed result - expected = duckdb_cursor.query("select {} from tbl".format(projection)).fetchall() - actual = duckdb_cursor.query("select {} from res".format(projection)).fetchall() + expected = duckdb_cursor.query(f"select {projection} from tbl").fetchall() + actual = duckdb_cursor.query(f"select {projection} from res").fetchall() assert expected == actual @pytest.mark.parametrize("create_list", list_constructors()) def test_arrow_ree_list(self, duckdb_cursor, create_list): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, - FROM range({}) t(i) - """.format(size) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -325,15 +322,15 @@ def test_arrow_ree_union(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, i as a, i % 2 == 0 as b, i::VARCHAR as c - FROM range({}) t(i) - """.format(size) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -383,13 +380,13 @@ def test_arrow_ree_map(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, i as a, - FROM range({}) t(i) - """.format(size) + FROM range({size}) t(i) + """ ) # Populate the table with data @@ -433,12 +430,12 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): size = 1000 duckdb_cursor.query( - """ + f""" create table tbl as select i // 4 as ree, - FROM range({}) t(i) - """.format(size) + FROM range({size}) t(i) + """ ) # Populate the table with data diff --git a/tests/fast/arrow/test_arrow_scanner.py b/tests/fast/arrow/test_arrow_scanner.py index 2e8b1296..ccfa5676 100644 --- a/tests/fast/arrow/test_arrow_scanner.py +++ b/tests/fast/arrow/test_arrow_scanner.py @@ -1,20 +1,20 @@ -import duckdb import os +import duckdb + try: import pyarrow - import pyarrow.parquet + import pyarrow.compute as pc import pyarrow.dataset + import pyarrow.parquet from pyarrow.dataset import Scanner - import pyarrow.compute as pc - import numpy as np can_run = True except: can_run = False -class TestArrowScanner(object): +class TestArrowScanner: def test_parallel_scanner(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_arrow_string_view.py b/tests/fast/arrow/test_arrow_string_view.py index a1b46e5b..0d34bb6e 100644 --- a/tests/fast/arrow/test_arrow_string_view.py +++ b/tests/fast/arrow/test_arrow_string_view.py @@ -1,6 +1,6 @@ -import duckdb import pytest -from packaging import version + +import duckdb pa = pytest.importorskip("pyarrow") @@ -43,7 +43,7 @@ def RoundTripDuckDBInternal(query): assert res[i] == from_arrow_res[i] -class TestArrowStringView(object): +class TestArrowStringView: # Test Small Inlined String View def test_inlined_string_view(self): RoundTripStringView( diff --git a/tests/fast/arrow/test_arrow_types.py b/tests/fast/arrow/test_arrow_types.py index f2bf71c7..199874cf 100644 --- a/tests/fast/arrow/test_arrow_types.py +++ b/tests/fast/arrow/test_arrow_types.py @@ -1,11 +1,12 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") -class TestArrowTypes(object): +class TestArrowTypes: def test_null_type(self, duckdb_cursor): schema = pa.schema([("data", pa.null())]) inputs = [pa.array([None, None, None], type=pa.null())] diff --git a/tests/fast/arrow/test_arrow_union.py b/tests/fast/arrow/test_arrow_union.py index c0a5d568..04fd73b3 100644 --- a/tests/fast/arrow/test_arrow_union.py +++ b/tests/fast/arrow/test_arrow_union.py @@ -2,8 +2,7 @@ importorskip("pyarrow") -import duckdb -from pyarrow import scalar, string, large_string, list_, int32, types +from pyarrow import int32, list_, scalar, string, types def test_nested(duckdb_cursor): diff --git a/tests/fast/arrow/test_arrow_version_format.py b/tests/fast/arrow/test_arrow_version_format.py index fd169ce0..ed335f9e 100644 --- a/tests/fast/arrow/test_arrow_version_format.py +++ b/tests/fast/arrow/test_arrow_version_format.py @@ -1,14 +1,16 @@ -import duckdb -import pytest from decimal import Decimal +import pytest + +import duckdb + pa = pytest.importorskip("pyarrow") -class TestArrowDecimalTypes(object): +class TestArrowDecimalTypes: def test_decimal_v1_5(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute(f"SET arrow_output_version = 1.5") + duckdb_cursor.execute("SET arrow_output_version = 1.5") decimal_32 = pa.Table.from_pylist( [ {"data": Decimal("100.20")}, @@ -51,11 +53,11 @@ def test_decimal_v1_5(self, duckdb_cursor): def test_invalide_opt(self, duckdb_cursor): duckdb_cursor = duckdb.connect() with pytest.raises(duckdb.NotImplementedException, match="unrecognized"): - duckdb_cursor.execute(f"SET arrow_output_version = 999.9") + duckdb_cursor.execute("SET arrow_output_version = 999.9") def test_view_v1_4(self, duckdb_cursor): duckdb_cursor = duckdb.connect() - duckdb_cursor.execute(f"SET arrow_output_version = 1.5") + duckdb_cursor.execute("SET arrow_output_version = 1.5") duckdb_cursor.execute("SET produce_arrow_string_view=True") duckdb_cursor.execute("SET arrow_output_list_view=True") col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type diff --git a/tests/fast/arrow/test_binary_type.py b/tests/fast/arrow/test_binary_type.py index 489d4caf..5932fba8 100644 --- a/tests/fast/arrow/test_binary_type.py +++ b/tests/fast/arrow/test_binary_type.py @@ -1,10 +1,8 @@ + import duckdb -import os try: import pyarrow as pa - from pyarrow import parquet as pq - import numpy as np can_run = True except: @@ -17,7 +15,7 @@ def create_binary_table(type): return pa.Table.from_arrays(inputs, schema=schema) -class TestArrowBinary(object): +class TestArrowBinary: def test_binary_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_buffer_size_option.py b/tests/fast/arrow/test_buffer_size_option.py index 7d5131e5..845d0db0 100644 --- a/tests/fast/arrow/test_buffer_size_option.py +++ b/tests/fast/arrow/test_buffer_size_option.py @@ -1,11 +1,12 @@ -import duckdb import pytest +import duckdb + pa = pytest.importorskip("pyarrow") from duckdb.typing import * -class TestArrowBufferSize(object): +class TestArrowBufferSize: def test_arrow_buffer_size(self): con = duckdb.connect() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 8ec0094e..aa2a8b9b 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -1,14 +1,16 @@ -import duckdb import os + import pytest +import duckdb + pyarrow = pytest.importorskip("pyarrow") np = pytest.importorskip("numpy") pyarrow.parquet = pytest.importorskip("pyarrow.parquet") pyarrow.dataset = pytest.importorskip("pyarrow.dataset") -class TestArrowDataset(object): +class TestArrowDataset: def test_parallel_dataset(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") diff --git a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 9649ffa6..83c14932 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -1,18 +1,16 @@ + + import duckdb -import os -import datetime -import pytest try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowDate(object): +class TestArrowDate: def test_date_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index e4319f7c..5cb2d38d 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -1,4 +1,3 @@ -import duckdb import pytest @@ -12,7 +11,7 @@ Timestamp = pd.Timestamp -class TestArrowDictionary(object): +class TestArrowDictionary: def test_dictionary(self, duckdb_cursor): indices = pa.array([0, 1, 0, 1, 2, 1, 0, 2]) dictionary = pa.array([10, 100, None]) diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index 026b52f4..2238f744 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -1,12 +1,11 @@ -from re import S -import duckdb -import os +import sys + import pytest -import tempfile from conftest import pandas_supports_arrow_backend -import sys from packaging.version import Version +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") ds = pytest.importorskip("pyarrow.dataset") @@ -178,7 +177,7 @@ def string_check_or_pushdown(connection, tbl_name, create_table): assert not match -class TestArrowFilterPushdown(object): +class TestArrowFilterPushdown: @pytest.mark.parametrize( "data_type", [ @@ -532,7 +531,6 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_ ) def test_9371(self, duckdb_cursor, tmp_path): import datetime - import pathlib # connect to an in-memory database duckdb_cursor.execute("SET TimeZone='UTC';") diff --git a/tests/fast/arrow/test_integration.py b/tests/fast/arrow/test_integration.py index 6ab7350d..1c00c800 100644 --- a/tests/fast/arrow/test_integration.py +++ b/tests/fast/arrow/test_integration.py @@ -1,14 +1,16 @@ -import duckdb -import os import datetime +import os + import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") np = pytest.importorskip("numpy") -class TestArrowIntegration(object): +class TestArrowIntegration: def test_parquet_roundtrip(self, duckdb_cursor): parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") cols = "id, first_name, last_name, email, gender, ip_address, cc, country, birthdate, salary, title, comments" @@ -216,7 +218,7 @@ def test_strings_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a varchar)") # Test Small, Null and Very Big String - for i in range(0, 1000): + for i in range(1000): duckdb_cursor.execute( "INSERT INTO test VALUES ('Matt Damon'),(NULL), ('Jeffffreeeey Jeeeeef Baaaaaaazos'), ('X-Content-Type-Options')" ) diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index 32b7fa64..5cdb04bd 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -1,18 +1,17 @@ -import duckdb -import os -import datetime + import pytest +import duckdb + try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowInterval(object): +class TestArrowInterval: def test_duration_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_large_offsets.py b/tests/fast/arrow/test_large_offsets.py index dccfa101..0a2669f5 100644 --- a/tests/fast/arrow/test_large_offsets.py +++ b/tests/fast/arrow/test_large_offsets.py @@ -1,9 +1,6 @@ -from re import S -import duckdb -import os import pytest -import tempfile -from conftest import pandas_supports_arrow_backend + +import duckdb pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") @@ -11,7 +8,7 @@ np = pytest.importorskip("numpy") -class TestArrowLargeOffsets(object): +class TestArrowLargeOffsets: @pytest.mark.skip(reason="CI does not have enough memory to validate this") def test_large_lists(self, duckdb_cursor): ary = pa.array([np.arange(start=0, stop=3000, dtype=np.uint8) for i in range(1_000_000)]) diff --git a/tests/fast/arrow/test_large_string.py b/tests/fast/arrow/test_large_string.py index 308785af..bb9d1b5b 100644 --- a/tests/fast/arrow/test_large_string.py +++ b/tests/fast/arrow/test_large_string.py @@ -1,17 +1,15 @@ + import duckdb -import os try: import pyarrow as pa - from pyarrow import parquet as pq - import numpy as np can_run = True except: can_run = False -class TestArrowLargeString(object): +class TestArrowLargeString: def test_large_string_type(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_multiple_reads.py b/tests/fast/arrow/test_multiple_reads.py index 36fb8f59..30b0c02a 100644 --- a/tests/fast/arrow/test_multiple_reads.py +++ b/tests/fast/arrow/test_multiple_reads.py @@ -1,6 +1,7 @@ -import duckdb import os +import duckdb + try: import pyarrow import pyarrow.parquet @@ -10,7 +11,7 @@ can_run = False -class TestArrowReads(object): +class TestArrowReads: def test_multiple_queries_same_relation(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_nested_arrow.py b/tests/fast/arrow/test_nested_arrow.py index a906324f..42b674e3 100644 --- a/tests/fast/arrow/test_nested_arrow.py +++ b/tests/fast/arrow/test_nested_arrow.py @@ -1,7 +1,7 @@ -import duckdb - import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") np = pytest.importorskip("numpy") @@ -27,7 +27,7 @@ def get_use_list_view_options(): return result -class TestArrowNested(object): +class TestArrowNested: def test_lists_basic(self, duckdb_cursor): # Test Constant List query = ( diff --git a/tests/fast/arrow/test_parallel.py b/tests/fast/arrow/test_parallel.py index c768a1dd..3348c13e 100644 --- a/tests/fast/arrow/test_parallel.py +++ b/tests/fast/arrow/test_parallel.py @@ -1,17 +1,18 @@ -import duckdb import os +import duckdb + try: + import numpy as np import pyarrow import pyarrow.parquet - import numpy as np can_run = True except: can_run = False -class TestArrowParallel(object): +class TestArrowParallel: def test_parallel_run(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index a4e94d18..329a9758 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -1,8 +1,9 @@ -import duckdb -import pytest -import sys import datetime +import pytest + +import duckdb + pl = pytest.importorskip("polars") arrow = pytest.importorskip("pyarrow") pl_testing = pytest.importorskip("polars.testing") @@ -20,7 +21,7 @@ def invalid_filter(filter): assert sql_expression is None -class TestPolars(object): +class TestPolars: def test_polars(self, duckdb_cursor): df = pl.DataFrame( { diff --git a/tests/fast/arrow/test_progress.py b/tests/fast/arrow/test_progress.py index 6f056937..4c558784 100644 --- a/tests/fast/arrow/test_progress.py +++ b/tests/fast/arrow/test_progress.py @@ -1,12 +1,13 @@ -import duckdb import os + import pytest +import duckdb + pyarrow_parquet = pytest.importorskip("pyarrow.parquet") -import sys -class TestProgressBarArrow(object): +class TestProgressBarArrow: def test_progress_arrow(self): if os.name == "nt": return diff --git a/tests/fast/arrow/test_projection_pushdown.py b/tests/fast/arrow/test_projection_pushdown.py index 802259e1..803a2703 100644 --- a/tests/fast/arrow/test_projection_pushdown.py +++ b/tests/fast/arrow/test_projection_pushdown.py @@ -1,9 +1,7 @@ -import duckdb -import os import pytest -class TestArrowProjectionPushdown(object): +class TestArrowProjectionPushdown: def test_projection_pushdown_no_filter(self, duckdb_cursor): pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index e7c4404e..b3bab360 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -1,18 +1,16 @@ + + import duckdb -import os -import datetime -import pytest try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowTime(object): +class TestArrowTime: def test_time_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 08816be1..c056f19f 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -1,8 +1,10 @@ -import duckdb -import pytest import datetime + +import pytest import pytz +import duckdb + pa = pytest.importorskip("pyarrow") @@ -16,7 +18,7 @@ def generate_table(current_time, precision, timezone): timezones = ["UTC", "BET", "CET", "Asia/Kathmandu"] -class TestArrowTimestampsTimezone(object): +class TestArrowTimestampsTimezone: def test_timestamp_timezone(self, duckdb_cursor): precisions = ["us", "s", "ns", "ms"] current_time = datetime.datetime(2017, 11, 28, 23, 55, 59, tzinfo=pytz.UTC) diff --git a/tests/fast/arrow/test_timestamps.py b/tests/fast/arrow/test_timestamps.py index 684a333c..6efe0000 100644 --- a/tests/fast/arrow/test_timestamps.py +++ b/tests/fast/arrow/test_timestamps.py @@ -1,18 +1,16 @@ -import duckdb -import os import datetime -import pytest + +import duckdb try: import pyarrow as pa - import pandas as pd can_run = True except: can_run = False -class TestArrowTimestamps(object): +class TestArrowTimestamps: def test_timestamp_types(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_tpch.py b/tests/fast/arrow/test_tpch.py index d5d13b20..30eca05c 100644 --- a/tests/fast/arrow/test_tpch.py +++ b/tests/fast/arrow/test_tpch.py @@ -1,10 +1,10 @@ import pytest + import duckdb try: import pyarrow import pyarrow.parquet - import numpy as np can_run = True except: @@ -24,7 +24,7 @@ def check_result(result, answers): db_result = result.fetchone() cq_results = q_res.split("|") # The end of the rows, continue - if cq_results == [""] and str(db_result) == "None" or str(db_result[0]) == "None": + if (cq_results == [""] and str(db_result) == "None") or str(db_result[0]) == "None": continue ans_result = [munge(cell) for cell in cq_results] db_result = [munge(cell) for cell in db_result] @@ -34,7 +34,7 @@ def check_result(result, answers): @pytest.mark.skip(reason="Test needs to be adapted to missing TPCH extension") -class TestTPCHArrow(object): +class TestTPCHArrow: def test_tpch_arrow(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_unregister.py b/tests/fast/arrow/test_unregister.py index 8ff37b5a..de8bab9c 100644 --- a/tests/fast/arrow/test_unregister.py +++ b/tests/fast/arrow/test_unregister.py @@ -1,8 +1,10 @@ -import pytest -import tempfile import gc -import duckdb import os +import tempfile + +import pytest + +import duckdb try: import pyarrow @@ -13,7 +15,7 @@ can_run = False -class TestArrowUnregister(object): +class TestArrowUnregister: def test_arrow_unregister1(self, duckdb_cursor): if not can_run: return diff --git a/tests/fast/arrow/test_view.py b/tests/fast/arrow/test_view.py index 7f1410aa..98b0b6cc 100644 --- a/tests/fast/arrow/test_view.py +++ b/tests/fast/arrow/test_view.py @@ -1,12 +1,12 @@ -import duckdb import os + import pytest pa = pytest.importorskip("pyarrow") pq = pytest.importorskip("pyarrow.parquet") -class TestArrowView(object): +class TestArrowView: def test_arrow_view(self, duckdb_cursor): parquet_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "userdata1.parquet") userdata_parquet_table = pa.parquet.read_table(parquet_filename) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index b872d4d9..d95c93d1 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -2,13 +2,15 @@ Therefore, we only test the new codes and exec paths. """ -import numpy as np -import duckdb from datetime import timedelta + +import numpy as np import pytest +import duckdb + -class TestScanNumpy(object): +class TestScanNumpy: def test_scan_numpy(self, duckdb_cursor): z = np.array([1, 2, 3]) res = duckdb_cursor.sql("select * from z").fetchall() diff --git a/tests/fast/pandas/test_2304.py b/tests/fast/pandas/test_2304.py index 11344df8..859c5265 100644 --- a/tests/fast/pandas/test_2304.py +++ b/tests/fast/pandas/test_2304.py @@ -1,10 +1,11 @@ -import duckdb import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasMergeSameName(object): +class TestPandasMergeSameName: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_2304(self, duckdb_cursor, pandas): df1 = pandas.DataFrame( diff --git a/tests/fast/pandas/test_append_df.py b/tests/fast/pandas/test_append_df.py index e6d64776..d93cfa2d 100644 --- a/tests/fast/pandas/test_append_df.py +++ b/tests/fast/pandas/test_append_df.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestAppendDF(object): +class TestAppendDF: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_df_to_table_append(self, duckdb_cursor, pandas): conn = duckdb.connect() diff --git a/tests/fast/pandas/test_bug2281.py b/tests/fast/pandas/test_bug2281.py index 98a90937..ca80504d 100644 --- a/tests/fast/pandas/test_bug2281.py +++ b/tests/fast/pandas/test_bug2281.py @@ -1,12 +1,9 @@ -import duckdb -import os -import datetime -import pytest -import pandas as pd import io +import pandas as pd + -class TestPandasStringNull(object): +class TestPandasStringNull: def test_pandas_string_null(self, duckdb_cursor): csv = """what,is_control,is_test ,0,0 diff --git a/tests/fast/pandas/test_bug5922.py b/tests/fast/pandas/test_bug5922.py index 28daabe9..584fe710 100644 --- a/tests/fast/pandas/test_bug5922.py +++ b/tests/fast/pandas/test_bug5922.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasAcceptFloat16(object): +class TestPandasAcceptFloat16: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_accept_float16(self, duckdb_cursor, pandas): df = pandas.DataFrame({"col": [1, 2, 3]}) diff --git a/tests/fast/pandas/test_copy_on_write.py b/tests/fast/pandas/test_copy_on_write.py index ec1b8786..0fcf503f 100644 --- a/tests/fast/pandas/test_copy_on_write.py +++ b/tests/fast/pandas/test_copy_on_write.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + # https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html pandas = pytest.importorskip("pandas", "1.5", reason="copy_on_write does not exist in earlier versions") import datetime @@ -21,7 +22,7 @@ def convert_to_result(col): return [(x,) for x in col] -class TestCopyOnWrite(object): +class TestCopyOnWrite: @pytest.mark.parametrize( "col", [ diff --git a/tests/fast/pandas/test_create_table_from_pandas.py b/tests/fast/pandas/test_create_table_from_pandas.py index 2194d964..bc5792e0 100644 --- a/tests/fast/pandas/test_create_table_from_pandas.py +++ b/tests/fast/pandas/test_create_table_from_pandas.py @@ -1,8 +1,9 @@ +import sys + import pytest +from conftest import ArrowPandas, NumpyPandas + import duckdb -import numpy as np -import sys -from conftest import NumpyPandas, ArrowPandas def assert_create(internal_data, expected_result, data_type, pandas): @@ -25,7 +26,7 @@ def assert_create_register(internal_data, expected_result, data_type, pandas): assert result == expected_result -class TestCreateTableFromPandas(object): +class TestCreateTableFromPandas: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_integer_create_table(self, duckdb_cursor, pandas): if sys.version_info.major < 3: diff --git a/tests/fast/pandas/test_date_as_datetime.py b/tests/fast/pandas/test_date_as_datetime.py index b738b2e1..484674ea 100644 --- a/tests/fast/pandas/test_date_as_datetime.py +++ b/tests/fast/pandas/test_date_as_datetime.py @@ -1,7 +1,8 @@ +import datetime + import pandas as pd + import duckdb -import datetime -import pytest def run_checks(df): diff --git a/tests/fast/pandas/test_datetime_time.py b/tests/fast/pandas/test_datetime_time.py index 1a5a3f7a..0b2642b0 100644 --- a/tests/fast/pandas/test_datetime_time.py +++ b/tests/fast/pandas/test_datetime_time.py @@ -1,13 +1,15 @@ -import duckdb +from datetime import datetime, time, timezone + import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas -from datetime import datetime, timezone, time, timedelta +from conftest import ArrowPandas, NumpyPandas + +import duckdb _ = pytest.importorskip("pandas", minversion="2.0.0") -class TestDateTimeTime(object): +class TestDateTimeTime: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_time_high(self, duckdb_cursor, pandas): duckdb_time = duckdb_cursor.sql("SELECT make_time(23, 1, 34.234345) AS '0'").df() diff --git a/tests/fast/pandas/test_datetime_timestamp.py b/tests/fast/pandas/test_datetime_timestamp.py index ffc1b7d8..2649cee0 100644 --- a/tests/fast/pandas/test_datetime_timestamp.py +++ b/tests/fast/pandas/test_datetime_timestamp.py @@ -1,14 +1,13 @@ -import duckdb import datetime -import numpy as np + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas from packaging.version import Version pd = pytest.importorskip("pandas") -class TestDateTimeTimeStamp(object): +class TestDateTimeTimeStamp: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_timestamp_high(self, pandas, duckdb_cursor): duckdb_time = duckdb_cursor.sql("SELECT '2260-01-01 23:59:00'::TIMESTAMP AS '0'").df() diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 8e67da4a..92318085 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -1,15 +1,16 @@ -import duckdb -import datetime + import numpy as np import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def create_generic_dataframe(data, pandas): return pandas.DataFrame({"col0": pandas.Series(data=data, dtype="object")}) -class TestResolveObjectColumns(object): +class TestResolveObjectColumns: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_sample_low_correct(self, duckdb_cursor, pandas): print(pandas.backend) diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index 73470818..27dc2116 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ b/tests/fast/pandas/test_df_object_resolution.py @@ -1,13 +1,15 @@ -import duckdb import datetime -import numpy as np -import platform -import pytest import decimal import math -from decimal import Decimal +import platform import re -from conftest import NumpyPandas, ArrowPandas +from decimal import Decimal + +import numpy as np +import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb standard_vector_size = duckdb.__standard_vector_size__ @@ -81,7 +83,7 @@ def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, assert expected_type == rel.types[0] -class TestResolveObjectColumns(object): +class TestResolveObjectColumns: # TODO: add support for ArrowPandas @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_integers(self, pandas, duckdb_cursor): @@ -674,7 +676,7 @@ def test_multiple_chunks(self, pandas, duckdb_cursor): @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): - duckdb_cursor.execute(f"SET GLOBAL pandas_analyze_sample=4096") + duckdb_cursor.execute("SET GLOBAL pandas_analyze_sample=4096") duckdb_cursor.execute( "create table dates as select '2022-09-14'::DATE + INTERVAL (i::INTEGER) DAY as i from range(4096) tbl(i);" ) diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index fb7d2ad0..4eacf777 100644 --- a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -1,9 +1,8 @@ -import duckdb -import datetime -import numpy as np + import pytest -import copy -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb from duckdb import Value NULL = None @@ -23,7 +22,7 @@ def create_reference_query(): return query -class TestDFRecursiveNested(object): +class TestDFRecursiveNested: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_list_of_structs(self, duckdb_cursor, pandas): data = [[{"a": 5}, NULL, {"a": NULL}], NULL, [{"a": 5}, NULL, {"a": NULL}]] diff --git a/tests/fast/pandas/test_fetch_df_chunk.py b/tests/fast/pandas/test_fetch_df_chunk.py index 1f2d4b1b..90f4e428 100644 --- a/tests/fast/pandas/test_fetch_df_chunk.py +++ b/tests/fast/pandas/test_fetch_df_chunk.py @@ -1,10 +1,11 @@ import pytest + import duckdb VECTOR_SIZE = duckdb.__standard_vector_size__ -class TestType(object): +class TestType: def test_fetch_df_chunk(self): size = 3000 con = duckdb.connect() diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index e25a44ba..6e878643 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -1,6 +1,7 @@ + import pytest + import duckdb -import sys pd = pytest.importorskip("pandas") import numpy as np @@ -55,7 +56,7 @@ def list_test_cases(): }), ("SELECT a from (SELECT LIST(i) as a FROM range(10000) tbl(i)) as t", { 'a': [ - list(range(0, 10000)) + list(range(10000)) ] }), ("SELECT LIST(i) as a FROM range(5) tbl(i) group by i%2 order by all", { @@ -146,7 +147,7 @@ def list_test_cases(): return test_cases -class TestFetchNested(object): +class TestFetchNested: @pytest.mark.parametrize("query, expected", list_test_cases()) def test_fetch_df_list(self, duckdb_cursor, query, expected): compare_results(duckdb_cursor, query, expected) diff --git a/tests/fast/pandas/test_implicit_pandas_scan.py b/tests/fast/pandas/test_implicit_pandas_scan.py index 2d4610ff..3808c42a 100644 --- a/tests/fast/pandas/test_implicit_pandas_scan.py +++ b/tests/fast/pandas/test_implicit_pandas_scan.py @@ -1,11 +1,12 @@ # simple DB API testcase -import duckdb import pandas as pd import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas from packaging.version import Version +import duckdb + numpy_nullable_df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val4", "CoL2": 17}]) try: @@ -22,7 +23,7 @@ pyarrow_df = numpy_nullable_df -class TestImplicitPandasScan(object): +class TestImplicitPandasScan: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_local_pandas_scan(self, duckdb_cursor, pandas): con = duckdb.connect() diff --git a/tests/fast/pandas/test_import_cache.py b/tests/fast/pandas/test_import_cache.py index 6ed601c5..d67b50ca 100644 --- a/tests/fast/pandas/test_import_cache.py +++ b/tests/fast/pandas/test_import_cache.py @@ -1,6 +1,7 @@ -from conftest import NumpyPandas, ArrowPandas -import duckdb import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) diff --git a/tests/fast/pandas/test_issue_1767.py b/tests/fast/pandas/test_issue_1767.py index 27f0c2ff..48d3e852 100644 --- a/tests/fast/pandas/test_issue_1767.py +++ b/tests/fast/pandas/test_issue_1767.py @@ -1,14 +1,13 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb -import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb # Join from pandas not matching identical strings #1767 -class TestIssue1767(object): +class TestIssue1767: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_unicode_join_pandas(self, duckdb_cursor, pandas): A = pandas.DataFrame({"key": ["a", "п"]}) diff --git a/tests/fast/pandas/test_limit.py b/tests/fast/pandas/test_limit.py index 460716cd..51c4a382 100644 --- a/tests/fast/pandas/test_limit.py +++ b/tests/fast/pandas/test_limit.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestLimitPandas(object): +class TestLimitPandas: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_limit_df(self, duckdb_cursor, pandas): df_in = pandas.DataFrame( diff --git a/tests/fast/pandas/test_pandas_arrow.py b/tests/fast/pandas/test_pandas_arrow.py index e1661041..cd82736b 100644 --- a/tests/fast/pandas/test_pandas_arrow.py +++ b/tests/fast/pandas/test_pandas_arrow.py @@ -1,16 +1,17 @@ -import duckdb -import pytest import datetime +import pytest from conftest import pandas_supports_arrow_backend +import duckdb + pd = pytest.importorskip("pandas", "2.0.0") import numpy as np from pandas.api.types import is_integer_dtype @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") -class TestPandasArrow(object): +class TestPandasArrow: def test_pandas_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") df = pd.DataFrame({"a": pd.Series([5, 4, 3])}).convert_dtypes() diff --git a/tests/fast/pandas/test_pandas_category.py b/tests/fast/pandas/test_pandas_category.py index 4b29b3fb..b40fefb8 100644 --- a/tests/fast/pandas/test_pandas_category.py +++ b/tests/fast/pandas/test_pandas_category.py @@ -1,8 +1,9 @@ -import duckdb -import pandas as pd import numpy +import pandas as pd import pytest +import duckdb + def check_category_equal(category): df_in = pd.DataFrame( @@ -54,7 +55,7 @@ def check_create_table(category): conn.execute("DROP TABLE t1") -class TestCategory(object): +class TestCategory: def test_category_simple(self, duckdb_cursor): df_in = pd.DataFrame({"float": [1.0, 2.0, 1.0], "int": pd.Series([1, 2, 1], dtype="category")}) diff --git a/tests/fast/pandas/test_pandas_df_none.py b/tests/fast/pandas/test_pandas_df_none.py index 50e1553c..5fa76c8c 100644 --- a/tests/fast/pandas/test_pandas_df_none.py +++ b/tests/fast/pandas/test_pandas_df_none.py @@ -1,11 +1,7 @@ -import pandas as pd -import pytest import duckdb -import sys -import gc -class TestPandasDFNone(object): +class TestPandasDFNone: # This used to decrease the ref count of None def test_none_deref(self): con = duckdb.connect() diff --git a/tests/fast/pandas/test_pandas_enum.py b/tests/fast/pandas/test_pandas_enum.py index b1eb2c7f..5b246fcf 100644 --- a/tests/fast/pandas/test_pandas_enum.py +++ b/tests/fast/pandas/test_pandas_enum.py @@ -1,9 +1,10 @@ import pandas as pd import pytest + import duckdb -class TestPandasEnum(object): +class TestPandasEnum: def test_3480(self, duckdb_cursor): duckdb_cursor.execute( """ @@ -14,7 +15,7 @@ def test_3480(self, duckdb_cursor): ); """ ) - df = duckdb_cursor.query(f"SELECT * FROM tab LIMIT 0;").to_df() + df = duckdb_cursor.query("SELECT * FROM tab LIMIT 0;").to_df() assert df["cat"].cat.categories.equals(pd.Index(["marie", "duchess", "toulouse"])) duckdb_cursor.execute("DROP TABLE tab") duckdb_cursor.execute("DROP TYPE cat") @@ -41,7 +42,7 @@ def test_3479(self, duckdb_cursor): duckdb.ConversionException, match="Type UINT8 with value 0 can't be cast because the value is out of range for the destination type UINT8", ): - duckdb_cursor.execute(f"INSERT INTO tab SELECT * FROM df;") + duckdb_cursor.execute("INSERT INTO tab SELECT * FROM df;") assert duckdb_cursor.execute("select * from tab").fetchall() == [] duckdb_cursor.execute("DROP TABLE tab") diff --git a/tests/fast/pandas/test_pandas_limit.py b/tests/fast/pandas/test_pandas_limit.py index d551a6e4..89fe1583 100644 --- a/tests/fast/pandas/test_pandas_limit.py +++ b/tests/fast/pandas/test_pandas_limit.py @@ -1,9 +1,8 @@ + import duckdb -import pandas as pd -import pytest -class TestPandasLimit(object): +class TestPandasLimit: def test_pandas_limit(self, duckdb_cursor): con = duckdb.connect() df = con.execute("select * from range(10000000) tbl(i)").df() diff --git a/tests/fast/pandas/test_pandas_na.py b/tests/fast/pandas/test_pandas_na.py index 7bc01003..f83be08a 100644 --- a/tests/fast/pandas/test_pandas_na.py +++ b/tests/fast/pandas/test_pandas_na.py @@ -1,9 +1,10 @@ +import platform + import numpy as np -import datetime -import duckdb import pytest -import platform -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def assert_nullness(items, null_indices): @@ -15,7 +16,7 @@ def assert_nullness(items, null_indices): @pytest.mark.skipif(platform.system() == "Emscripten", reason="Pandas interaction is broken in Pyodide 3.11") -class TestPandasNA(object): +class TestPandasNA: @pytest.mark.parametrize("rows", [100, duckdb.__standard_vector_size__, 5000, 1000000]) @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) def test_pandas_string_null(self, duckdb_cursor, rows, pd): diff --git a/tests/fast/pandas/test_pandas_object.py b/tests/fast/pandas/test_pandas_object.py index 9e10681c..bb8e3eff 100644 --- a/tests/fast/pandas/test_pandas_object.py +++ b/tests/fast/pandas/test_pandas_object.py @@ -1,11 +1,12 @@ -import pandas as pd -import duckdb import datetime + import numpy as np -import random +import pandas as pd + +import duckdb -class TestPandasObject(object): +class TestPandasObject: def test_object_lotof_nulls(self): # Test mostly null column data = [None] + [1] + [None] * 10000 # Last element is 1, others are None diff --git a/tests/fast/pandas/test_pandas_string.py b/tests/fast/pandas/test_pandas_string.py index 4bd5996d..d1302f89 100644 --- a/tests/fast/pandas/test_pandas_string.py +++ b/tests/fast/pandas/test_pandas_string.py @@ -1,9 +1,10 @@ -import duckdb -import pandas as pd import numpy +import pandas as pd + +import duckdb -class TestPandasString(object): +class TestPandasString: def test_pandas_string(self, duckdb_cursor): strings = numpy.array(["foo", "bar", "baz"]) @@ -31,12 +32,12 @@ def test_bug_2467(self, duckdb_cursor): con = duckdb.connect() con.register("df", df) con.execute( - f""" + """ CREATE TABLE t1 AS SELECT * FROM df """ ) assert con.execute( - f""" + """ SELECT count(*) from t1 """ ).fetchall() == [(3000000,)] diff --git a/tests/fast/pandas/test_pandas_timestamp.py b/tests/fast/pandas/test_pandas_timestamp.py index 835ff3af..635cee36 100644 --- a/tests/fast/pandas/test_pandas_timestamp.py +++ b/tests/fast/pandas/test_pandas_timestamp.py @@ -1,11 +1,11 @@ -import duckdb +from datetime import datetime + import pandas import pytest - -from datetime import datetime -from pytz import timezone from conftest import pandas_2_or_higher +import duckdb + @pytest.mark.parametrize("timezone", ["UTC", "CET", "Asia/Kathmandu"]) @pytest.mark.skipif(not pandas_2_or_higher(), reason="Pandas <2.0.0 does not support timezones in the metadata string") diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index fcc63b82..f7df363d 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -1,12 +1,14 @@ -import duckdb -import pytest -import pandas as pd -import numpy import string -from packaging import version import warnings from contextlib import suppress +import numpy +import pandas as pd +import pytest +from packaging import version + +import duckdb + def round_trip(data, pandas_type): df_in = pd.DataFrame( @@ -21,7 +23,7 @@ def round_trip(data, pandas_type): assert df_out.equals(df_in) -class TestNumpyNullableTypes(object): +class TestNumpyNullableTypes: def test_pandas_numeric(self): base_df = pd.DataFrame({"a": range(10)}) diff --git a/tests/fast/pandas/test_pandas_unregister.py b/tests/fast/pandas/test_pandas_unregister.py index fce8f42a..bce93158 100644 --- a/tests/fast/pandas/test_pandas_unregister.py +++ b/tests/fast/pandas/test_pandas_unregister.py @@ -1,13 +1,14 @@ -import duckdb -import pytest -import tempfile -import os import gc +import os +import tempfile + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestPandasUnregister(object): +class TestPandasUnregister: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_pandas_unregister1(self, duckdb_cursor, pandas): df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) diff --git a/tests/fast/pandas/test_pandas_update.py b/tests/fast/pandas/test_pandas_update.py index 86d17154..bc1740d9 100644 --- a/tests/fast/pandas/test_pandas_update.py +++ b/tests/fast/pandas/test_pandas_update.py @@ -1,8 +1,9 @@ -import duckdb import pandas as pd +import duckdb + -class TestPandasUpdateList(object): +class TestPandasUpdateList: def test_pandas_update_list(self, duckdb_cursor): duckdb_cursor = duckdb.connect(":memory:") duckdb_cursor.execute("create table t (l int[])") diff --git a/tests/fast/pandas/test_parallel_pandas_scan.py b/tests/fast/pandas/test_parallel_pandas_scan.py index d113bbca..b389fce5 100644 --- a/tests/fast/pandas/test_parallel_pandas_scan.py +++ b/tests/fast/pandas/test_parallel_pandas_scan.py @@ -1,14 +1,15 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb -import numpy import datetime + +import numpy import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb def run_parallel_queries(main_table, left_join_table, expected_df, pandas, iteration_count=5): - for i in range(0, iteration_count): + for i in range(iteration_count): output_df = None sql = """ select @@ -35,7 +36,7 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera duckdb_conn.close() -class TestParallelPandasScan(object): +class TestParallelPandasScan: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_parallel_numeric_scan(self, duckdb_cursor, pandas): main_table = pandas.DataFrame([{"join_column": 3}]) diff --git a/tests/fast/pandas/test_partitioned_pandas_scan.py b/tests/fast/pandas/test_partitioned_pandas_scan.py index d2447ef8..9f580659 100644 --- a/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -1,11 +1,11 @@ -import duckdb -import pandas as pd + import numpy -import datetime -import time +import pandas as pd + +import duckdb -class TestPartitionedPandasScan(object): +class TestPartitionedPandasScan: def test_parallel_pandas(self, duckdb_cursor): con = duckdb.connect() df = pd.DataFrame({"i": numpy.arange(10000000)}) diff --git a/tests/fast/pandas/test_progress_bar.py b/tests/fast/pandas/test_progress_bar.py index 7c1c21e1..c8cfb2e0 100644 --- a/tests/fast/pandas/test_progress_bar.py +++ b/tests/fast/pandas/test_progress_bar.py @@ -1,11 +1,11 @@ -import duckdb -import pandas as pd + import numpy -import datetime -import time +import pandas as pd + +import duckdb -class TestProgressBarPandas(object): +class TestProgressBarPandas: def test_progress_pandas_single(self, duckdb_cursor): con = duckdb.connect() df = pd.DataFrame({"i": numpy.arange(10000000)}) diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index b04f713a..4191a96e 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -1,16 +1,16 @@ -import duckdb -import os -import pytest +import pytest from conftest import pandas_supports_arrow_backend +import duckdb + pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") _ = pytest.importorskip("pandas", "2.0.0") @pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") -class TestArrowDFProjectionPushdown(object): +class TestArrowDFProjectionPushdown: def test_projection_pushdown_no_filter(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (a INTEGER, b INTEGER, c INTEGER)") diff --git a/tests/fast/pandas/test_same_name.py b/tests/fast/pandas/test_same_name.py index ac4f407a..ff499ddf 100644 --- a/tests/fast/pandas/test_same_name.py +++ b/tests/fast/pandas/test_same_name.py @@ -1,9 +1,7 @@ -import pytest -import duckdb import pandas as pd -class TestMultipleColumnsSameName(object): +class TestMultipleColumnsSameName: def test_multiple_columns_with_same_name(self, duckdb_cursor): df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "d": [9, 10, 11, 12]}) df = df.rename(columns={df.columns[1]: "a"}) diff --git a/tests/fast/pandas/test_stride.py b/tests/fast/pandas/test_stride.py index 1b2f5052..cbe23cfd 100644 --- a/tests/fast/pandas/test_stride.py +++ b/tests/fast/pandas/test_stride.py @@ -1,10 +1,12 @@ +import datetime + +import numpy as np import pandas as pd + import duckdb -import numpy as np -import datetime -class TestPandasStride(object): +class TestPandasStride: def test_stride(self, duckdb_cursor): expected_df = pd.DataFrame(np.arange(20).reshape(5, 4), columns=["a", "b", "c", "d"]) con = duckdb.connect() diff --git a/tests/fast/pandas/test_timedelta.py b/tests/fast/pandas/test_timedelta.py index c0afeb74..deca62e0 100644 --- a/tests/fast/pandas/test_timedelta.py +++ b/tests/fast/pandas/test_timedelta.py @@ -1,11 +1,13 @@ +import datetime import platform + import pandas as pd -import duckdb -import datetime import pytest +import duckdb + -class TestTimedelta(object): +class TestTimedelta: def test_timedelta_positive(self, duckdb_cursor): duckdb_interval = duckdb_cursor.query( "SELECT '2290-01-01 23:59:00'::TIMESTAMP - '2000-01-01 23:59:00'::TIMESTAMP AS '0'" diff --git a/tests/fast/pandas/test_timestamp.py b/tests/fast/pandas/test_timestamp.py index dbb7273d..e14d82a6 100644 --- a/tests/fast/pandas/test_timestamp.py +++ b/tests/fast/pandas/test_timestamp.py @@ -1,13 +1,15 @@ -import duckdb import datetime import os -import pytest -import pandas as pd import platform + +import pandas as pd +import pytest from conftest import pandas_2_or_higher +import duckdb + -class TestPandasTimestamps(object): +class TestPandasTimestamps: @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) def test_timestamp_types_roundtrip(self, unit): d = { diff --git a/tests/fast/relational_api/test_groupings.py b/tests/fast/relational_api/test_groupings.py index b0a95410..250df7ad 100644 --- a/tests/fast/relational_api/test_groupings.py +++ b/tests/fast/relational_api/test_groupings.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + @pytest.fixture def con(): @@ -17,10 +18,10 @@ def con(): ) AS tbl(a, b, c)) """ ) - yield conn + return conn -class TestGroupings(object): +class TestGroupings: def test_basic_grouping(self, con): rel = con.table("tbl").sum("a", "b") res = rel.fetchall() diff --git a/tests/fast/relational_api/test_joins.py b/tests/fast/relational_api/test_joins.py index cf3d3cf2..726fdac8 100644 --- a/tests/fast/relational_api/test_joins.py +++ b/tests/fast/relational_api/test_joins.py @@ -1,5 +1,6 @@ -import duckdb import pytest + +import duckdb from duckdb import ColumnExpression @@ -26,10 +27,10 @@ def con(): ) AS t(a, b)) """ ) - yield conn + return conn -class TestRAPIJoins(object): +class TestRAPIJoins: def test_outer_join(self, con): a = con.table("tbl_a") b = con.table("tbl_b") diff --git a/tests/fast/relational_api/test_pivot.py b/tests/fast/relational_api/test_pivot.py index 9cf91e56..1cca02b4 100644 --- a/tests/fast/relational_api/test_pivot.py +++ b/tests/fast/relational_api/test_pivot.py @@ -1,10 +1,8 @@ -import duckdb -import pytest import os import tempfile -class TestPivot(object): +class TestPivot: def test_pivot_issue_14600(self, duckdb_cursor): duckdb_cursor.sql( "create table input_data as select unnest(['u','v','w']) as a, unnest(['x','y','z']) as b, unnest([1,2,3]) as c;" @@ -26,5 +24,5 @@ def test_pivot_issue_14601(self, duckdb_cursor): pivot_1.create("pivot_1") export_dir = tempfile.mkdtemp() duckdb_cursor.query(f"EXPORT DATABASE '{export_dir}'") - with open(os.path.join(export_dir, "schema.sql"), "r") as f: + with open(os.path.join(export_dir, "schema.sql")) as f: assert "CREATE TYPE" not in f.read() diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index 3466a77a..31cb21c9 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -1,7 +1,8 @@ -import duckdb -from decimal import Decimal + import pytest +import duckdb + @pytest.fixture(autouse=True) def setup_and_teardown_of_table(duckdb_cursor): @@ -23,12 +24,12 @@ def setup_and_teardown_of_table(duckdb_cursor): duckdb_cursor.execute("drop table agg") -@pytest.fixture() +@pytest.fixture def table(duckdb_cursor): return duckdb_cursor.table("agg") -class TestRAPIAggregations(object): +class TestRAPIAggregations: # General aggregate functions def test_any_value(self, table): diff --git a/tests/fast/relational_api/test_rapi_close.py b/tests/fast/relational_api/test_rapi_close.py index b6355167..969e2792 100644 --- a/tests/fast/relational_api/test_rapi_close.py +++ b/tests/fast/relational_api/test_rapi_close.py @@ -1,9 +1,10 @@ -import duckdb import pytest +import duckdb + # A closed connection should invalidate all relation's methods -class TestRAPICloseConnRel(object): +class TestRAPICloseConnRel: def test_close_conn_rel(self, duckdb_cursor): con = duckdb.connect() con.execute("CREATE TABLE items(item VARCHAR, value DECIMAL(10,2), count INTEGER)") diff --git a/tests/fast/relational_api/test_rapi_description.py b/tests/fast/relational_api/test_rapi_description.py index 80616132..2696ed2f 100644 --- a/tests/fast/relational_api/test_rapi_description.py +++ b/tests/fast/relational_api/test_rapi_description.py @@ -1,8 +1,7 @@ -import duckdb import pytest -class TestRAPIDescription(object): +class TestRAPIDescription: def test_rapi_description(self, duckdb_cursor): res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") desc = res.description diff --git a/tests/fast/relational_api/test_rapi_functions.py b/tests/fast/relational_api/test_rapi_functions.py index c6b1f1fa..143aa8df 100644 --- a/tests/fast/relational_api/test_rapi_functions.py +++ b/tests/fast/relational_api/test_rapi_functions.py @@ -1,7 +1,7 @@ import duckdb -class TestRAPIFunctions(object): +class TestRAPIFunctions: def test_rapi_str_print(self, duckdb_cursor): res = duckdb_cursor.query("select 42::INT AS a, 84::BIGINT AS b") assert str(res) is not None diff --git a/tests/fast/relational_api/test_rapi_query.py b/tests/fast/relational_api/test_rapi_query.py index 16ed326c..b9f2ef68 100644 --- a/tests/fast/relational_api/test_rapi_query.py +++ b/tests/fast/relational_api/test_rapi_query.py @@ -1,10 +1,12 @@ -import duckdb -import pytest import platform import sys +import pytest + +import duckdb + -@pytest.fixture() +@pytest.fixture def tbl_table(): con = duckdb.default_connection() con.execute("drop table if exists tbl CASCADE") @@ -13,7 +15,7 @@ def tbl_table(): con.execute("drop table tbl CASCADE") -@pytest.fixture() +@pytest.fixture def scoped_default(duckdb_cursor): default = duckdb.connect(":default:") duckdb.set_default_connection(duckdb_cursor) @@ -23,11 +25,11 @@ def scoped_default(duckdb_cursor): duckdb.set_default_connection(default) -class TestRAPIQuery(object): +class TestRAPIQuery: @pytest.mark.parametrize("steps", [1, 2, 3, 4]) def test_query_chain(self, steps): con = duckdb.default_connection() - amount = int(1000000) + amount = 1000000 rel = None for _ in range(steps): rel = con.query(f"select i from range({amount}::BIGINT) tbl(i)") diff --git a/tests/fast/relational_api/test_rapi_windows.py b/tests/fast/relational_api/test_rapi_windows.py index cc58b8f1..ce0196fc 100644 --- a/tests/fast/relational_api/test_rapi_windows.py +++ b/tests/fast/relational_api/test_rapi_windows.py @@ -1,6 +1,7 @@ -import duckdb import pytest +import duckdb + @pytest.fixture(autouse=True) def setup_and_teardown_of_table(duckdb_cursor): @@ -22,7 +23,7 @@ def setup_and_teardown_of_table(duckdb_cursor): duckdb_cursor.execute("drop table win") -@pytest.fixture() +@pytest.fixture def table(duckdb_cursor): return duckdb_cursor.table("win") diff --git a/tests/fast/relational_api/test_table_function.py b/tests/fast/relational_api/test_table_function.py index 5748f762..2a5271f9 100644 --- a/tests/fast/relational_api/test_table_function.py +++ b/tests/fast/relational_api/test_table_function.py @@ -1,11 +1,13 @@ -import duckdb -import pytest import os +import pytest + +import duckdb + script_path = os.path.dirname(__file__) -class TestTableFunction(object): +class TestTableFunction: def test_table_function(self, duckdb_cursor): path = os.path.join(script_path, "..", "data/integers.csv") rel = duckdb_cursor.table_function("read_csv", [path]) diff --git a/tests/fast/spark/test_replace_column_value.py b/tests/fast/spark/test_replace_column_value.py index 65ab85f1..17a2254e 100644 --- a/tests/fast/spark/test_replace_column_value.py +++ b/tests/fast/spark/test_replace_column_value.py @@ -4,7 +4,7 @@ from spark_namespace.sql.types import Row -class TestReplaceValue(object): +class TestReplaceValue: # https://sparkbyexamples.com/pyspark/pyspark-replace-column-values/?expand_article=1 def test_replace_value(self, spark): address = [(1, "14851 Jeffrey Rd", "DE"), (2, "43421 Margarita St", "NY"), (3, "13111 Siemon Ave", "CA")] diff --git a/tests/fast/spark/test_replace_empty_value.py b/tests/fast/spark/test_replace_empty_value.py index aad6a43e..615b15d8 100644 --- a/tests/fast/spark/test_replace_empty_value.py +++ b/tests/fast/spark/test_replace_empty_value.py @@ -2,12 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql.types import Row # https://sparkbyexamples.com/pyspark/pyspark-replace-empty-value-with-none-on-dataframe-2/?expand_article=1 -class TestReplaceEmpty(object): +class TestReplaceEmpty: def test_replace_empty(self, spark): # Create the dataframe data = [("", "CA"), ("Julia", ""), ("Robert", ""), ("", "NJ")] diff --git a/tests/fast/spark/test_spark_arrow_table.py b/tests/fast/spark/test_spark_arrow_table.py index 57c81599..fc773562 100644 --- a/tests/fast/spark/test_spark_arrow_table.py +++ b/tests/fast/spark/test_spark_arrow_table.py @@ -2,8 +2,6 @@ _ = pytest.importorskip("duckdb.experimental.spark") pa = pytest.importorskip("pyarrow") -from spark_namespace import USE_ACTUAL_SPARK - from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql.dataframe import DataFrame diff --git a/tests/fast/spark/test_spark_catalog.py b/tests/fast/spark/test_spark_catalog.py index 2ecaad24..c19ec83c 100644 --- a/tests/fast/spark/test_spark_catalog.py +++ b/tests/fast/spark/test_spark_catalog.py @@ -3,10 +3,10 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.catalog import Table, Database, Column +from spark_namespace.sql.catalog import Column, Database, Table -class TestSparkCatalog(object): +class TestSparkCatalog: def test_list_databases(self, spark): dbs = spark.catalog.listDatabases() if USE_ACTUAL_SPARK: diff --git a/tests/fast/spark/test_spark_column.py b/tests/fast/spark/test_spark_column.py index 9ef17d95..e8da1333 100644 --- a/tests/fast/spark/test_spark_column.py +++ b/tests/fast/spark/test_spark_column.py @@ -2,17 +2,15 @@ _ = pytest.importorskip("duckdb.experimental.spark") +import re + from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.column import Column -from spark_namespace.sql.functions import struct, array, col -from spark_namespace.sql.types import Row from spark_namespace.errors import PySparkTypeError - -import duckdb -import re +from spark_namespace.sql.functions import array, col, struct +from spark_namespace.sql.types import Row -class TestSparkColumn(object): +class TestSparkColumn: def test_struct_column(self, spark): df = spark.createDataFrame([Row(a=1, b=2, c=3, d=4)]) diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index e86995ec..26006952 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -2,25 +2,22 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError, PySparkValueError +from spark_namespace.sql.column import Column +from spark_namespace.sql.functions import col, struct, when from spark_namespace.sql.types import ( - LongType, - StructType, + ArrayType, BooleanType, - StructField, - StringType, IntegerType, LongType, - Row, - ArrayType, MapType, + Row, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when -from spark_namespace.sql.column import Column -import duckdb -import re - -from spark_namespace.errors import PySparkValueError, PySparkTypeError def assert_column_objects_equal(col1: Column, col2: Column): @@ -29,7 +26,7 @@ def assert_column_objects_equal(col1: Column, col2: Column): assert col1.expr == col2.expr -class TestDataFrame(object): +class TestDataFrame: def test_dataframe_from_list_of_tuples(self, spark): # Valid address = [(1, "14851 Jeffrey Rd", "DE"), (2, "43421 Margarita St", "NY"), (3, "13111 Siemon Ave", "CA")] @@ -194,7 +191,7 @@ def test_df_from_name_list(self, spark): assert res == [Row(a=42, b=True), Row(a=21, b=False)] def test_df_creation_coverage(self, spark): - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType data2 = [ ("James", "", "Smith", "36636", "M", 3000), @@ -298,7 +295,7 @@ def test_df_nested_struct(self, spark): ) def test_df_columns(self, spark): - from spark_namespace.sql.functions import col, struct, when + from spark_namespace.sql.functions import col structureData = [ (("James", "", "Smith"), "36636", "M", 3100), @@ -343,7 +340,6 @@ def test_df_columns(self, spark): def test_array_and_map_type(self, spark): """Array & Map""" - arrayStructureSchema = StructType( [ StructField( diff --git a/tests/fast/spark/test_spark_dataframe_sort.py b/tests/fast/spark/test_spark_dataframe_sort.py index db7dce4b..49631d4d 100644 --- a/tests/fast/spark/test_spark_dataframe_sort.py +++ b/tests/fast/spark/test_spark_dataframe_sort.py @@ -3,13 +3,13 @@ _ = pytest.importorskip("duckdb.experimental.spark") import spark_namespace.errors -from spark_namespace.sql.types import Row -from spark_namespace.sql.functions import desc, asc -from spark_namespace.errors import PySparkTypeError, PySparkValueError from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError, PySparkValueError +from spark_namespace.sql.functions import asc, desc +from spark_namespace.sql.types import Row -class TestDataFrameSort(object): +class TestDataFrameSort: data = [(56, "Carol"), (20, "Alice"), (3, "Dave"), (3, "Anna"), (1, "Ben")] def test_sort_ascending(self, spark): diff --git a/tests/fast/spark/test_spark_drop_duplicates.py b/tests/fast/spark/test_spark_drop_duplicates.py index 563a5e76..cd658c77 100644 --- a/tests/fast/spark/test_spark_drop_duplicates.py +++ b/tests/fast/spark/test_spark_drop_duplicates.py @@ -1,6 +1,4 @@ import pytest - - from spark_namespace.sql.types import ( Row, ) @@ -8,7 +6,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestDataFrameDropDuplicates(object): +class TestDataFrameDropDuplicates: @pytest.mark.parametrize("method", ["dropDuplicates", "drop_duplicates"]) def test_spark_drop_duplicates(self, method, spark): # Prepare Data diff --git a/tests/fast/spark/test_spark_except.py b/tests/fast/spark/test_spark_except.py index 7c28cc29..dd6c802d 100644 --- a/tests/fast/spark/test_spark_except.py +++ b/tests/fast/spark/test_spark_except.py @@ -1,10 +1,8 @@ -import platform import pytest _ = pytest.importorskip("duckdb.experimental.spark") from duckdb.experimental.spark.sql.types import Row -from duckdb.experimental.spark.sql.functions import col @pytest.fixture diff --git a/tests/fast/spark/test_spark_filter.py b/tests/fast/spark/test_spark_filter.py index a4733a44..9dbb8c94 100644 --- a/tests/fast/spark/test_spark_filter.py +++ b/tests/fast/spark/test_spark_filter.py @@ -2,26 +2,20 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.errors import PySparkTypeError +from spark_namespace.sql.functions import array_contains, col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, ArrayType, - MapType, + Row, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.errors import PySparkTypeError -import duckdb -import re -class TestDataFrameFilter(object): +class TestDataFrameFilter: def test_dataframe_filter(self, spark): data = [ (("James", "", "Smith"), ["Java", "Scala", "C++"], "OH", "M"), diff --git a/tests/fast/spark/test_spark_function_concat_ws.py b/tests/fast/spark/test_spark_function_concat_ws.py index 82f19cd1..b4268d0f 100644 --- a/tests/fast/spark/test_spark_function_concat_ws.py +++ b/tests/fast/spark/test_spark_function_concat_ws.py @@ -1,11 +1,11 @@ import pytest _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col, concat_ws from spark_namespace.sql.types import Row -from spark_namespace.sql.functions import concat_ws, col -class TestReplaceEmpty(object): +class TestReplaceEmpty: def test_replace_empty(self, spark): data = [ ("firstRowFirstColumn", "firstRowSecondColumn"), diff --git a/tests/fast/spark/test_spark_functions_array.py b/tests/fast/spark/test_spark_functions_array.py index 5ecba132..36afed54 100644 --- a/tests/fast/spark/test_spark_functions_array.py +++ b/tests/fast/spark/test_spark_functions_array.py @@ -1,10 +1,11 @@ -import pytest import platform +import pytest + _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", diff --git a/tests/fast/spark/test_spark_functions_base64.py b/tests/fast/spark/test_spark_functions_base64.py index 5a179481..44e4a7cd 100644 --- a/tests/fast/spark/test_spark_functions_base64.py +++ b/tests/fast/spark/test_spark_functions_base64.py @@ -5,7 +5,7 @@ from spark_namespace.sql import functions as F -class TestSparkFunctionsBase64(object): +class TestSparkFunctionsBase64: def test_base64(self, spark): data = [ ("quack",), diff --git a/tests/fast/spark/test_spark_functions_date.py b/tests/fast/spark/test_spark_functions_date.py index a298c0ff..914d33f6 100644 --- a/tests/fast/spark/test_spark_functions_date.py +++ b/tests/fast/spark/test_spark_functions_date.py @@ -1,4 +1,5 @@ import warnings + import pytest _ = pytest.importorskip("duckdb.experimental.spark") @@ -6,11 +7,11 @@ from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as F -from spark_namespace.sql.types import Row from spark_namespace.sql.functions import col +from spark_namespace.sql.types import Row -class TestsSparkFunctionsDate(object): +class TestsSparkFunctionsDate: def test_date_trunc(self, spark): df = spark.createDataFrame( [(datetime(2019, 1, 23, 14, 34, 9, 87539),)], diff --git a/tests/fast/spark/test_spark_functions_expr.py b/tests/fast/spark/test_spark_functions_expr.py index 7cc47735..f14dbcce 100644 --- a/tests/fast/spark/test_spark_functions_expr.py +++ b/tests/fast/spark/test_spark_functions_expr.py @@ -5,7 +5,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestSparkFunctionsExpr(object): +class TestSparkFunctionsExpr: def test_expr(self, spark): df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) res = df.select("name", F.expr("length(name)").alias("str_len")).collect() diff --git a/tests/fast/spark/test_spark_functions_hash.py b/tests/fast/spark/test_spark_functions_hash.py index 7b14f29e..d1890990 100644 --- a/tests/fast/spark/test_spark_functions_hash.py +++ b/tests/fast/spark/test_spark_functions_hash.py @@ -4,7 +4,7 @@ from spark_namespace.sql import functions as F -class TestSparkFunctionsHash(object): +class TestSparkFunctionsHash: def test_md5(self, spark): data = [ ("quack",), diff --git a/tests/fast/spark/test_spark_functions_hex.py b/tests/fast/spark/test_spark_functions_hex.py index 7d5f3c6a..c58c6d90 100644 --- a/tests/fast/spark/test_spark_functions_hex.py +++ b/tests/fast/spark/test_spark_functions_hex.py @@ -1,11 +1,11 @@ + import pytest -import sys _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql import functions as F -class TestSparkFunctionsHex(object): +class TestSparkFunctionsHex: def test_hex_string_col(self, spark): data = [ ("quack",), @@ -32,7 +32,7 @@ def test_hex_binary_col(self, spark): def test_hex_integer_col(self, spark): data = [ - (int(42),), + (42,), ] res = ( spark.createDataFrame(data, ["firstColumn"]) diff --git a/tests/fast/spark/test_spark_functions_null.py b/tests/fast/spark/test_spark_functions_null.py index 230634dc..2bcfd94a 100644 --- a/tests/fast/spark/test_spark_functions_null.py +++ b/tests/fast/spark/test_spark_functions_null.py @@ -7,7 +7,7 @@ from spark_namespace.sql.types import Row -class TestsSparkFunctionsNull(object): +class TestsSparkFunctionsNull: def test_coalesce(self, spark): data = [ (None, 2), diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 3548d439..30224735 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -3,13 +3,14 @@ _ = pytest.importorskip("duckdb.experimental.spark") import math + import numpy as np from spark_namespace import USE_ACTUAL_SPARK from spark_namespace.sql import functions as sf from spark_namespace.sql.types import Row -class TestSparkFunctionsNumeric(object): +class TestSparkFunctionsNumeric: def test_greatest(self, spark): data = [ (1, 2), diff --git a/tests/fast/spark/test_spark_functions_string.py b/tests/fast/spark/test_spark_functions_string.py index b8d7f483..0001a167 100644 --- a/tests/fast/spark/test_spark_functions_string.py +++ b/tests/fast/spark/test_spark_functions_string.py @@ -7,7 +7,7 @@ from spark_namespace.sql.types import Row -class TestSparkFunctionsString(object): +class TestSparkFunctionsString: def test_length(self, spark): data = [ ("firstRowFirstColumn",), diff --git a/tests/fast/spark/test_spark_group_by.py b/tests/fast/spark/test_spark_group_by.py index 9e8a8ea0..f3748f1d 100644 --- a/tests/fast/spark/test_spark_group_by.py +++ b/tests/fast/spark/test_spark_group_by.py @@ -3,47 +3,35 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK -from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, - ArrayType, - MapType, -) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains from spark_namespace.sql.functions import ( - sum, + any_value, + approx_count_distinct, avg, + col, + covar_pop, + covar_samp, + first, + last, max, - min, - stddev_samp, - stddev, + median, + mode, + product, + skewness, std, + stddev, stddev_pop, + stddev_samp, + sum, var_pop, var_samp, variance, - mean, - mode, - median, - product, - count, - skewness, - any_value, - approx_count_distinct, - covar_pop, - covar_samp, - first, - last, +) +from spark_namespace.sql.types import ( + Row, ) -class TestDataFrameGroupBy(object): +class TestDataFrameGroupBy: def test_group_by(self, spark): simpleData = [ ("James", "Sales", "NY", 90000, 34, 10000), diff --git a/tests/fast/spark/test_spark_intersect.py b/tests/fast/spark/test_spark_intersect.py index ba0afbdd..8ec67dd0 100644 --- a/tests/fast/spark/test_spark_intersect.py +++ b/tests/fast/spark/test_spark_intersect.py @@ -1,10 +1,8 @@ -import platform import pytest _ = pytest.importorskip("duckdb.experimental.spark") from duckdb.experimental.spark.sql.types import Row -from duckdb.experimental.spark.sql.functions import col @pytest.fixture diff --git a/tests/fast/spark/test_spark_join.py b/tests/fast/spark/test_spark_join.py index f67c54cb..842dfbc5 100644 --- a/tests/fast/spark/test_spark_join.py +++ b/tests/fast/spark/test_spark_join.py @@ -2,20 +2,10 @@ _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture @@ -30,7 +20,7 @@ def dataframe_a(spark): ] empColumns = ["emp_id", "name", "superior_emp_id", "year_joined", "emp_dept_id", "gender", "salary"] dataframe = spark.createDataFrame(data=emp, schema=empColumns) - yield dataframe + return dataframe @pytest.fixture @@ -38,10 +28,10 @@ def dataframe_b(spark): dept = [("Finance", 10), ("Marketing", 20), ("Sales", 30), ("IT", 40)] deptColumns = ["dept_name", "dept_id"] dataframe = spark.createDataFrame(data=dept, schema=deptColumns) - yield dataframe + return dataframe -class TestDataFrameJoin(object): +class TestDataFrameJoin: def test_inner_join(self, dataframe_a, dataframe_b): df = dataframe_a.join(dataframe_b, dataframe_a.emp_dept_id == dataframe_b.dept_id, "inner") df = df.sort(*df.columns) diff --git a/tests/fast/spark/test_spark_limit.py b/tests/fast/spark/test_spark_limit.py index c00496a0..eb88fc6a 100644 --- a/tests/fast/spark/test_spark_limit.py +++ b/tests/fast/spark/test_spark_limit.py @@ -7,7 +7,7 @@ ) -class TestDataFrameLimit(object): +class TestDataFrameLimit: def test_dataframe_limit(self, spark): df = spark.sql("select * from range(100000)") df2 = df.limit(10) diff --git a/tests/fast/spark/test_spark_order_by.py b/tests/fast/spark/test_spark_order_by.py index cc08dd7c..030db4b8 100644 --- a/tests/fast/spark/test_spark_order_by.py +++ b/tests/fast/spark/test_spark_order_by.py @@ -2,24 +2,13 @@ _ = pytest.importorskip("duckdb.experimental.spark") +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -import duckdb -import re -class TestDataFrameOrderBy(object): +class TestDataFrameOrderBy: def test_order_by(self, spark): simpleData = [ ("James", "Sales", "NY", 90000, 34, 10000), diff --git a/tests/fast/spark/test_spark_pandas_dataframe.py b/tests/fast/spark/test_spark_pandas_dataframe.py index 6491b7a6..ab069156 100644 --- a/tests/fast/spark/test_spark_pandas_dataframe.py +++ b/tests/fast/spark/test_spark_pandas_dataframe.py @@ -3,22 +3,14 @@ _ = pytest.importorskip("duckdb.experimental.spark") pd = pytest.importorskip("pandas") +from pandas.testing import assert_frame_equal from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, IntegerType, - LongType, Row, - ArrayType, - MapType, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when -import duckdb -import re -from pandas.testing import assert_frame_equal @pytest.fixture @@ -26,10 +18,10 @@ def pandasDF(spark): data = [["Scott", 50], ["Jeff", 45], ["Thomas", 54], ["Ann", 34]] # Create the pandas DataFrame df = pd.DataFrame(data, columns=["Name", "Age"]) - yield df + return df -class TestPandasDataFrame(object): +class TestPandasDataFrame: def test_pd_conversion_basic(self, spark, pandasDF): sparkDF = spark.createDataFrame(pandasDF) res = sparkDF.collect() diff --git a/tests/fast/spark/test_spark_readcsv.py b/tests/fast/spark/test_spark_readcsv.py index 5ba3d199..10d1a17c 100644 --- a/tests/fast/spark/test_spark_readcsv.py +++ b/tests/fast/spark/test_spark_readcsv.py @@ -2,12 +2,13 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK import textwrap +from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.sql.types import Row + -class TestSparkReadCSV(object): +class TestSparkReadCSV: def test_read_csv(self, spark, tmp_path): file_path = tmp_path / "basic.csv" with open(file_path, "w+") as f: diff --git a/tests/fast/spark/test_spark_readjson.py b/tests/fast/spark/test_spark_readjson.py index 638bee2d..aa8d8ec5 100644 --- a/tests/fast/spark/test_spark_readjson.py +++ b/tests/fast/spark/test_spark_readjson.py @@ -2,12 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace.sql.types import Row -import textwrap -import duckdb -class TestSparkReadJson(object): +class TestSparkReadJson: def test_read_json(self, duckdb_cursor, spark, tmp_path): file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() diff --git a/tests/fast/spark/test_spark_readparquet.py b/tests/fast/spark/test_spark_readparquet.py index 1b3ddd74..2f182650 100644 --- a/tests/fast/spark/test_spark_readparquet.py +++ b/tests/fast/spark/test_spark_readparquet.py @@ -2,12 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") + from spark_namespace.sql.types import Row -import textwrap -import duckdb -class TestSparkReadParquet(object): +class TestSparkReadParquet: def test_read_parquet(self, duckdb_cursor, spark, tmp_path): file_path = tmp_path / "basic.parquet" file_path = file_path.as_posix() diff --git a/tests/fast/spark/test_spark_runtime_config.py b/tests/fast/spark/test_spark_runtime_config.py index 5e93ed63..b9053899 100644 --- a/tests/fast/spark/test_spark_runtime_config.py +++ b/tests/fast/spark/test_spark_runtime_config.py @@ -5,7 +5,7 @@ from spark_namespace import USE_ACTUAL_SPARK -class TestSparkRuntimeConfig(object): +class TestSparkRuntimeConfig: def test_spark_runtime_config(self, spark): # This fetches the internal runtime config from the session spark.conf diff --git a/tests/fast/spark/test_spark_session.py b/tests/fast/spark/test_spark_session.py index 06c9dbcb..604c85f1 100644 --- a/tests/fast/spark/test_spark_session.py +++ b/tests/fast/spark/test_spark_session.py @@ -1,15 +1,16 @@ import pytest +from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.sql.types import Row + from duckdb.experimental.spark.exception import ( ContributionsAcceptedError, ) -from spark_namespace.sql.types import Row -from spark_namespace import USE_ACTUAL_SPARK _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql import SparkSession -class TestSparkSession(object): +class TestSparkSession: def test_spark_session_default(self): session = SparkSession.builder.getOrCreate() diff --git a/tests/fast/spark/test_spark_to_csv.py b/tests/fast/spark/test_spark_to_csv.py index e5387a6c..122f3223 100644 --- a/tests/fast/spark/test_spark_to_csv.py +++ b/tests/fast/spark/test_spark_to_csv.py @@ -1,8 +1,7 @@ -import pytest -import tempfile - import os +import pytest + _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace import USE_ACTUAL_SPARK @@ -15,12 +14,13 @@ allow_module_level=True, ) -from duckdb import connect, InvalidInputException, read_csv -from conftest import NumpyPandas, ArrowPandas, getTimeSeriesData -from spark_namespace import USE_ACTUAL_SPARK -import pandas._testing as tm -import datetime import csv +import datetime + +from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData +from spark_namespace import USE_ACTUAL_SPARK + +from duckdb import InvalidInputException, read_csv @pytest.fixture @@ -34,24 +34,24 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_ints(request, spark): pandas = request.param dataframe = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) - yield dataframe + return dataframe @pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) def pandas_df_strings(request, spark): pandas = request.param dataframe = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) - yield dataframe + return dataframe -class TestSparkToCSV(object): +class TestSparkToCSV: def test_basic_to_csv(self, pandas_df_ints, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") diff --git a/tests/fast/spark/test_spark_to_parquet.py b/tests/fast/spark/test_spark_to_parquet.py index 68a10f65..8dc2d386 100644 --- a/tests/fast/spark/test_spark_to_parquet.py +++ b/tests/fast/spark/test_spark_to_parquet.py @@ -1,8 +1,7 @@ -import pytest -import tempfile - import os +import pytest + _ = pytest.importorskip("duckdb.experimental.spark") @@ -17,10 +16,10 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe -class TestSparkToParquet(object): +class TestSparkToParquet: def test_basic_to_parquet(self, df, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.parquet") diff --git a/tests/fast/spark/test_spark_transform.py b/tests/fast/spark/test_spark_transform.py index 1f1186c5..bf1c7b01 100644 --- a/tests/fast/spark/test_spark_transform.py +++ b/tests/fast/spark/test_spark_transform.py @@ -3,19 +3,8 @@ _ = pytest.importorskip("duckdb.experimental.spark") from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture @@ -26,7 +15,7 @@ def array_df(spark): ("Robert,,Williams", ["CSharp", "VB"], ["Spark", "Python"]), ] dataframe = spark.createDataFrame(data=data, schema=["Name", "Languages1", "Languages2"]) - yield dataframe + return dataframe @pytest.fixture @@ -40,10 +29,10 @@ def df(spark): ) columns = ["CourseName", "fee", "discount"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_transform(self, spark, df): # Custom transformation 1 from spark_namespace.sql.functions import upper @@ -72,6 +61,6 @@ def apply_discount(df): # https://sparkbyexamples.com/pyspark/pyspark-transform-function/ @pytest.mark.skip(reason="LambdaExpressions are currently under development, waiting til that is finished") def test_transform_function(self, spark, array_df): - from spark_namespace.sql.functions import upper, transform + from spark_namespace.sql.functions import transform, upper df.select(transform("Languages1", lambda x: upper(x)).alias("languages1")).show() diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index 6c97c2d9..d19b3833 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -9,43 +9,42 @@ "Skipping these tests as they use test_all_types() which is specific to DuckDB", allow_module_level=True ) -from spark_namespace.sql.types import Row from spark_namespace.sql.types import ( - StringType, + ArrayType, BinaryType, BitstringType, - UUIDType, BooleanType, + ByteType, DateType, - TimestampType, - TimestampNTZType, - TimeType, - TimeNTZType, - TimestampNanosecondNTZType, - TimestampMilisecondNTZType, - TimestampSecondNTZType, + DayTimeIntervalType, DecimalType, DoubleType, FloatType, - ByteType, - UnsignedByteType, - ShortType, - UnsignedShortType, + HugeIntegerType, IntegerType, - UnsignedIntegerType, LongType, - UnsignedLongType, - HugeIntegerType, - UnsignedHugeIntegerType, - DayTimeIntervalType, - ArrayType, MapType, + ShortType, + StringType, StructField, StructType, + TimeNTZType, + TimestampMilisecondNTZType, + TimestampNanosecondNTZType, + TimestampNTZType, + TimestampSecondNTZType, + TimestampType, + TimeType, + UnsignedByteType, + UnsignedHugeIntegerType, + UnsignedIntegerType, + UnsignedLongType, + UnsignedShortType, + UUIDType, ) -class TestTypes(object): +class TestTypes: def test_all_types_schema(self, spark): # Create DataFrame df = spark.sql( diff --git a/tests/fast/spark/test_spark_udf.py b/tests/fast/spark/test_spark_udf.py index eebabbb3..cee0f256 100644 --- a/tests/fast/spark/test_spark_udf.py +++ b/tests/fast/spark/test_spark_udf.py @@ -3,7 +3,7 @@ _ = pytest.importorskip("duckdb.experimental.spark") -class TestSparkUDF(object): +class TestSparkUDF: def test_udf_register(self, spark): def to_upper_fn(s: str) -> str: return s.upper() diff --git a/tests/fast/spark/test_spark_union.py b/tests/fast/spark/test_spark_union.py index 8a3ff9ce..588c7ecd 100644 --- a/tests/fast/spark/test_spark_union.py +++ b/tests/fast/spark/test_spark_union.py @@ -1,10 +1,11 @@ import platform + import pytest _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import Row from spark_namespace.sql.functions import col +from spark_namespace.sql.types import Row @pytest.fixture @@ -18,7 +19,7 @@ def df(spark): columns = ["employee_name", "department", "state", "salary", "age", "bonus"] dataframe = spark.createDataFrame(data=simpleData, schema=columns) - yield dataframe + return dataframe @pytest.fixture @@ -32,10 +33,10 @@ def df2(spark): ] columns2 = ["employee_name", "department", "state", "salary", "age", "bonus"] dataframe = spark.createDataFrame(data=simpleData2, schema=columns2) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_merge_with_union(self, df, df2): unionDF = df.union(df2) res = unionDF.collect() diff --git a/tests/fast/spark/test_spark_union_by_name.py b/tests/fast/spark/test_spark_union_by_name.py index 4739f0d8..bec539a2 100644 --- a/tests/fast/spark/test_spark_union_by_name.py +++ b/tests/fast/spark/test_spark_union_by_name.py @@ -4,36 +4,25 @@ from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, Row, - ArrayType, - MapType, ) -from spark_namespace.sql.functions import col, struct, when, lit, array_contains -from spark_namespace.sql.functions import sum, avg, max, min, mean, count @pytest.fixture def df1(spark): data = [("James", 34), ("Michael", 56), ("Robert", 30), ("Maria", 24)] dataframe = spark.createDataFrame(data=data, schema=["name", "id"]) - yield dataframe + return dataframe @pytest.fixture def df2(spark): data2 = [(34, "James"), (45, "Maria"), (45, "Jen"), (34, "Jeff")] dataframe = spark.createDataFrame(data=data2, schema=["id", "name"]) - yield dataframe + return dataframe -class TestDataFrameUnion(object): +class TestDataFrameUnion: def test_union_by_name(self, df1, df2): rel = df1.unionByName(df2) res = rel.collect() diff --git a/tests/fast/spark/test_spark_with_column.py b/tests/fast/spark/test_spark_with_column.py index 2980e7fe..4ea62fe1 100644 --- a/tests/fast/spark/test_spark_with_column.py +++ b/tests/fast/spark/test_spark_with_column.py @@ -2,25 +2,11 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, - IntegerType, - LongType, - Row, - ArrayType, - MapType, -) -from spark_namespace.sql.functions import col, struct, when, lit from spark_namespace import USE_ACTUAL_SPARK -import duckdb -import re +from spark_namespace.sql.functions import col, lit -class TestWithColumn(object): +class TestWithColumn: def test_with_column(self, spark): data = [ ("James", "", "Smith", "1991-04-01", "M", 3000), diff --git a/tests/fast/spark/test_spark_with_column_renamed.py b/tests/fast/spark/test_spark_with_column_renamed.py index 8534ab0b..789bf2c1 100644 --- a/tests/fast/spark/test_spark_with_column_renamed.py +++ b/tests/fast/spark/test_spark_with_column_renamed.py @@ -2,24 +2,17 @@ _ = pytest.importorskip("duckdb.experimental.spark") + +from spark_namespace.sql.functions import col from spark_namespace.sql.types import ( - LongType, - StructType, - BooleanType, - StructField, - StringType, IntegerType, - LongType, - Row, - ArrayType, - MapType, + StringType, + StructField, + StructType, ) -from spark_namespace.sql.functions import col, struct, when, lit -import duckdb -import re -class TestWithColumnRenamed(object): +class TestWithColumnRenamed: def test_with_column_renamed(self, spark): dataDF = [ (("James", "", "Smith"), "1991-04-01", "M", 3000), @@ -28,7 +21,6 @@ def test_with_column_renamed(self, spark): (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType schema = StructType( [ diff --git a/tests/fast/spark/test_spark_with_columns.py b/tests/fast/spark/test_spark_with_columns.py index 535f357d..244d40a3 100644 --- a/tests/fast/spark/test_spark_with_columns.py +++ b/tests/fast/spark/test_spark_with_columns.py @@ -3,8 +3,8 @@ _ = pytest.importorskip("duckdb.experimental.spark") -from spark_namespace.sql.functions import col, lit from spark_namespace import USE_ACTUAL_SPARK +from spark_namespace.sql.functions import col, lit class TestWithColumns: diff --git a/tests/fast/spark/test_spark_with_columns_renamed.py b/tests/fast/spark/test_spark_with_columns_renamed.py index 80b8b9e0..8c24062b 100644 --- a/tests/fast/spark/test_spark_with_columns_renamed.py +++ b/tests/fast/spark/test_spark_with_columns_renamed.py @@ -1,4 +1,5 @@ import re + import pytest _ = pytest.importorskip("duckdb.experimental.spark") @@ -6,7 +7,7 @@ from spark_namespace import USE_ACTUAL_SPARK -class TestWithColumnsRenamed(object): +class TestWithColumnsRenamed: def test_with_columns_renamed(self, spark): dataDF = [ (("James", "", "Smith"), "1991-04-01", "M", 3000), @@ -15,7 +16,7 @@ def test_with_columns_renamed(self, spark): (("Maria", "Anne", "Jones"), "1967-12-01", "F", 4000), (("Jen", "Mary", "Brown"), "1980-02-17", "F", -1), ] - from spark_namespace.sql.types import StructType, StructField, StringType, IntegerType + from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType schema = StructType( [ diff --git a/tests/fast/sqlite/test_types.py b/tests/fast/sqlite/test_types.py index 3ffdceae..47c4b7e1 100644 --- a/tests/fast/sqlite/test_types.py +++ b/tests/fast/sqlite/test_types.py @@ -27,8 +27,8 @@ import datetime import decimal import unittest + import duckdb -import pytest class DuckDBTypeTests(unittest.TestCase): diff --git a/tests/fast/test_alex_multithread.py b/tests/fast/test_alex_multithread.py index bcb0181b..7e25b5bb 100644 --- a/tests/fast/test_alex_multithread.py +++ b/tests/fast/test_alex_multithread.py @@ -1,8 +1,9 @@ import platform -import duckdb from threading import Thread, current_thread + import pytest +import duckdb pytestmark = pytest.mark.xfail( condition=platform.system() == "Emscripten", @@ -30,7 +31,7 @@ def insert_from_same_connection(duckdb_cursor): duckdb_cursor.execute("""INSERT INTO my_inserts VALUES (?)""", (thread_name,)) -class TestPythonMultithreading(object): +class TestPythonMultithreading: def test_multiple_cursors(self, duckdb_cursor): duckdb_con = duckdb.connect() # In Memory DuckDB duckdb_con.execute("""CREATE OR REPLACE TABLE my_inserts (thread_name varchar)""") diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index 3e701ced..e74cca30 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -1,14 +1,16 @@ -import duckdb -import pandas as pd -import numpy as np import datetime import math +import warnings +from contextlib import suppress from decimal import Decimal from uuid import UUID -import pytz + +import numpy as np +import pandas as pd import pytest -import warnings -from contextlib import suppress +import pytz + +import duckdb def replace_with_ndarray(obj): @@ -25,7 +27,6 @@ def replace_with_ndarray(obj): # we need to write our own equality function that considers nan==nan for testing purposes def recursive_equality(o1, o2): - import math if type(o1) != type(o2): return False @@ -114,7 +115,7 @@ def recursive_equality(o1, o2): ] -class TestAllTypes(object): +class TestAllTypes: @pytest.mark.parametrize("cur_type", all_types) def test_fetchall(self, cur_type): conn = duckdb.connect() @@ -538,7 +539,7 @@ def test_fetchnumpy(self, cur_type): @pytest.mark.parametrize("cur_type", all_types) def test_arrow(self, cur_type): try: - import pyarrow as pa + pass except: return # We skip those since the extreme ranges are not supported in arrow. diff --git a/tests/fast/test_ambiguous_prepare.py b/tests/fast/test_ambiguous_prepare.py index 998367ec..0865b007 100644 --- a/tests/fast/test_ambiguous_prepare.py +++ b/tests/fast/test_ambiguous_prepare.py @@ -1,9 +1,8 @@ + import duckdb -import pandas as pd -import pytest -class TestAmbiguousPrepare(object): +class TestAmbiguousPrepare: def test_bool(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("select ?, ?, ?", (True, 42, [1, 2, 3])).fetchall() diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 2e42f0ed..5092f099 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -1,17 +1,12 @@ -import pandas -import numpy as np -import datetime -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestCaseAlias(object): +class TestCaseAlias: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): - import numpy as np - import datetime - import duckdb con = duckdb.connect(":memory:") diff --git a/tests/fast/test_context_manager.py b/tests/fast/test_context_manager.py index 65ec1d33..b6a9ebb2 100644 --- a/tests/fast/test_context_manager.py +++ b/tests/fast/test_context_manager.py @@ -1,7 +1,7 @@ import duckdb -class TestContextManager(object): +class TestContextManager: def test_context_manager(self, duckdb_cursor): with duckdb.connect(database=":memory:", read_only=False) as con: assert con.execute("select 1").fetchall() == [(1,)] diff --git a/tests/fast/test_duckdb_api.py b/tests/fast/test_duckdb_api.py index ea847d50..d779a368 100644 --- a/tests/fast/test_duckdb_api.py +++ b/tests/fast/test_duckdb_api.py @@ -1,6 +1,7 @@ -import duckdb import sys +import duckdb + def test_duckdb_api(): res = duckdb.execute("SELECT name, value FROM duckdb_settings() WHERE name == 'duckdb_api'") diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index 82753382..049a2a5c 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -1,19 +1,20 @@ +import datetime import platform -import duckdb + import pytest -from duckdb.typing import INTEGER, VARCHAR, TIMESTAMP + +import duckdb from duckdb import ( - Expression, - ConstantExpression, + CaseExpression, + CoalesceOperator, ColumnExpression, + ConstantExpression, + FunctionExpression, LambdaExpression, - CoalesceOperator, StarExpression, - FunctionExpression, - CaseExpression, ) -from duckdb.value.constant import Value, IntegerValue -import datetime +from duckdb.typing import INTEGER, TIMESTAMP, VARCHAR +from duckdb.value.constant import IntegerValue, Value pytestmark = pytest.mark.skipif( platform.system() == "Emscripten", @@ -35,10 +36,10 @@ def filter_rel(): ) tbl(a, b) """ ) - yield rel + return rel -class TestExpression(object): +class TestExpression: def test_constant_expression(self): con = duckdb.connect() @@ -839,7 +840,7 @@ def test_filter_and(self, filter_rel): expr = ~expr # AND operator - expr = expr & ("b" != ConstantExpression("b")) + expr = expr & (ConstantExpression("b") != "b") rel2 = filter_rel.filter(expr) res = rel2.fetchall() assert len(res) == 2 diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index 7b8fbb05..3fd6d60d 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -1,19 +1,19 @@ import logging import sys -from pathlib import Path -from shutil import copyfileobj -from typing import Callable, List from os.path import exists -from pathlib import PurePosixPath +from pathlib import Path, PurePosixPath +from shutil import copyfileobj +from typing import Callable + +from pytest import MonkeyPatch, fixture, importorskip, mark, raises import duckdb from duckdb import DuckDBPyConnection, InvalidInputException -from pytest import raises, importorskip, fixture, MonkeyPatch, mark importorskip("fsspec", "2022.11.0") -from fsspec import filesystem, AbstractFileSystem -from fsspec.implementations.memory import MemoryFileSystem +from fsspec import AbstractFileSystem, filesystem from fsspec.implementations.local import LocalFileOpener, LocalFileSystem +from fsspec.implementations.memory import MemoryFileSystem FILENAME = "integers.csv" @@ -35,13 +35,13 @@ def ceptor(*args, **kwargs): return error_occurred -@fixture() +@fixture def duckdb_cursor(): with duckdb.connect() as conn: yield conn -@fixture() +@fixture def memory(): fs = filesystem("memory", skip_instance_cache=True) diff --git a/tests/fast/test_get_table_names.py b/tests/fast/test_get_table_names.py index 1f90e444..92fa1c39 100644 --- a/tests/fast/test_get_table_names.py +++ b/tests/fast/test_get_table_names.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + -class TestGetTableNames(object): +class TestGetTableNames: def test_table_success(self, duckdb_cursor): conn = duckdb.connect() table_names = conn.get_table_names("SELECT * FROM my_table1, my_table2, my_table3") diff --git a/tests/fast/test_import_export.py b/tests/fast/test_import_export.py index d98a2d73..09b8cbda 100644 --- a/tests/fast/test_import_export.py +++ b/tests/fast/test_import_export.py @@ -1,10 +1,11 @@ -import duckdb -import pytest -from os import path import shutil -import os +from os import path from pathlib import Path +import pytest + +import duckdb + def export_database(export_location): # Create the db diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index baae75b4..34489b44 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -1,11 +1,11 @@ -import duckdb -import tempfile -import os + import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestInsert(object): +class TestInsert: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_insert(self, pandas): test_df = pandas.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) diff --git a/tests/fast/test_json_logging.py b/tests/fast/test_json_logging.py index a7f305f3..b29ea7bf 100644 --- a/tests/fast/test_json_logging.py +++ b/tests/fast/test_json_logging.py @@ -1,8 +1,9 @@ import json -import duckdb import pytest +import duckdb + def _parse_json_func(error_prefix: str): """Helper to check that the error message is indeed parsable json""" diff --git a/tests/fast/test_many_con_same_file.py b/tests/fast/test_many_con_same_file.py index 3cef2494..79b5db68 100644 --- a/tests/fast/test_many_con_same_file.py +++ b/tests/fast/test_many_con_same_file.py @@ -1,7 +1,9 @@ -import duckdb import os + import pytest +import duckdb + def get_tables(con): tbls = con.execute("SHOW TABLES").fetchall() diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index f86dd60b..1ce63110 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -1,9 +1,10 @@ -import duckdb -import numpy -import pytest -from datetime import date, timedelta import re -from conftest import NumpyPandas, ArrowPandas +from datetime import date, timedelta + +import pytest +from conftest import ArrowPandas, NumpyPandas + +import duckdb # column count differs from bind @@ -14,7 +15,7 @@ def evil1(df): return df -class TestMap(object): +class TestMap: @pytest.mark.parametrize("pandas", [NumpyPandas()]) def test_evil_map(self, duckdb_cursor, pandas): testrel = duckdb.values([1, 2]) diff --git a/tests/fast/test_metatransaction.py b/tests/fast/test_metatransaction.py index f617cba2..35d7c239 100644 --- a/tests/fast/test_metatransaction.py +++ b/tests/fast/test_metatransaction.py @@ -7,7 +7,7 @@ NUMBER_OF_COLUMNS = 1 -class TestMetaTransaction(object): +class TestMetaTransaction: def test_fetchmany(self, duckdb_cursor): duckdb_cursor.execute("CREATE SEQUENCE id_seq") column_names = ",\n".join([f"column_{i} FLOAT" for i in range(1, NUMBER_OF_COLUMNS + 1)]) diff --git a/tests/fast/test_multi_statement.py b/tests/fast/test_multi_statement.py index 722ab31a..cd3111e6 100644 --- a/tests/fast/test_multi_statement.py +++ b/tests/fast/test_multi_statement.py @@ -1,11 +1,11 @@ -import duckdb import os import shutil +import duckdb + -class TestMultiStatement(object): +class TestMultiStatement: def test_multi_statement(self, duckdb_cursor): - import duckdb con = duckdb.connect(":memory:") diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index 628aacd8..aeeeb412 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -1,13 +1,13 @@ +import os import platform -import duckdb -import pytest -import threading import queue as Queue +import threading + import numpy as np -from conftest import NumpyPandas, ArrowPandas -import os -from typing import List +import pytest +from conftest import ArrowPandas, NumpyPandas +import duckdb pytestmark = pytest.mark.xfail( condition=platform.system() == "Emscripten", @@ -36,7 +36,7 @@ def multithread_test(self, result_verification=everything_succeeded): queue = Queue.Queue() # Create all threads - for i in range(0, self.duckdb_insert_thread_count): + for i in range(self.duckdb_insert_thread_count): self.threads.append( threading.Thread( target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name="duckdb_thread_" + str(i) @@ -45,13 +45,13 @@ def multithread_test(self, result_verification=everything_succeeded): # Record for every thread if they succeeded or not thread_results = [] - for i in range(0, len(self.threads)): + for i in range(len(self.threads)): self.threads[i].start() thread_result: bool = queue.get(timeout=60) thread_results.append(thread_result) # Finish all threads - for i in range(0, len(self.threads)): + for i in range(len(self.threads)): self.threads[i].join() # Assert that the results are what we expected @@ -374,7 +374,7 @@ def cursor(duckdb_conn, queue, pandas): queue.put(True) -class TestDuckMultithread(object): +class TestDuckMultithread: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_execute(self, duckdb_cursor, pandas): duck_threads = DuckDBThreaded(10, execute_query, pandas) diff --git a/tests/fast/test_non_default_conn.py b/tests/fast/test_non_default_conn.py index cb0218e3..06cd5fe5 100644 --- a/tests/fast/test_non_default_conn.py +++ b/tests/fast/test_non_default_conn.py @@ -1,11 +1,12 @@ -import pandas as pd -import numpy as np -import duckdb import os import tempfile +import pandas as pd + +import duckdb + -class TestNonDefaultConn(object): +class TestNonDefaultConn: def test_values(self, duckdb_cursor): duckdb_cursor.execute("create table t (a integer)") duckdb.values([1], connection=duckdb_cursor).insert_into("t") diff --git a/tests/fast/test_parameter_list.py b/tests/fast/test_parameter_list.py index 5a85ac2f..a28838ba 100644 --- a/tests/fast/test_parameter_list.py +++ b/tests/fast/test_parameter_list.py @@ -1,9 +1,10 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb -class TestParameterList(object): +class TestParameterList: def test_bool(self, duckdb_cursor): conn = duckdb.connect() conn.execute("create table bool_table (a bool)") diff --git a/tests/fast/test_parquet.py b/tests/fast/test_parquet.py index 61d74023..fd506da2 100644 --- a/tests/fast/test_parquet.py +++ b/tests/fast/test_parquet.py @@ -1,8 +1,8 @@ -import duckdb -import pytest import os -import tempfile -import pandas as pd + +import pytest + +import duckdb VARCHAR = duckdb.typing.VARCHAR BIGINT = duckdb.typing.BIGINT @@ -17,7 +17,7 @@ def tmp_parquets(tmp_path_factory): return tmp_parquets -class TestParquet(object): +class TestParquet: def test_scan_binary(self, duckdb_cursor): conn = duckdb.connect() res = conn.execute("SELECT typeof(#1) FROM parquet_scan('" + filename + "') limit 1").fetchall() diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index 84d4c9ff..0e0439ce 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -""" -Unit tests for pypi_cleanup.py +"""Unit tests for pypi_cleanup.py Run with: python -m pytest test_pypi_cleanup.py -v """ @@ -15,18 +14,18 @@ duckdb_packaging = pytest.importorskip("duckdb_packaging") from duckdb_packaging.pypi_cleanup import ( - PyPICleanup, + AuthenticationError, CsrfParser, + PyPICleanup, PyPICleanupError, - AuthenticationError, ValidationError, - setup_logging, - validate_username, create_argument_parser, - session_with_retries, load_credentials, - validate_arguments, main, + session_with_retries, + setup_logging, + validate_arguments, + validate_username, ) @@ -116,7 +115,7 @@ def test_create_session_with_retries(self): # Verify retry adapter is mounted adapter = session.get_adapter("https://example.com") assert hasattr(adapter, "max_retries") - retries = getattr(adapter, "max_retries") + retries = adapter.max_retries assert isinstance(retries, Retry) @patch("duckdb_packaging.pypi_cleanup.logging.basicConfig") diff --git a/tests/fast/test_pytorch.py b/tests/fast/test_pytorch.py index c5b9b4d6..c0b9392d 100644 --- a/tests/fast/test_pytorch.py +++ b/tests/fast/test_pytorch.py @@ -1,6 +1,6 @@ -import duckdb import pytest +import duckdb torch = pytest.importorskip("torch") diff --git a/tests/fast/test_relation.py b/tests/fast/test_relation.py index 31ca393c..6628198f 100644 --- a/tests/fast/test_relation.py +++ b/tests/fast/test_relation.py @@ -1,16 +1,17 @@ -import duckdb -import numpy as np +import datetime +import gc +import os import platform import tempfile -import os + +import numpy as np import pandas as pd import pytest from conftest import ArrowPandas, NumpyPandas -import datetime -import gc -from duckdb import ColumnExpression -from duckdb.typing import BIGINT, VARCHAR, TINYINT, BOOLEAN +import duckdb +from duckdb import ColumnExpression +from duckdb.typing import BIGINT, BOOLEAN, TINYINT, VARCHAR @pytest.fixture(scope="session") @@ -25,7 +26,7 @@ def get_relation(conn): return conn.from_df(test_df) -class TestRelation(object): +class TestRelation: def test_csv_auto(self): conn = duckdb.connect() df_rel = get_relation(conn) diff --git a/tests/fast/test_relation_dependency_leak.py b/tests/fast/test_relation_dependency_leak.py index ee98e30a..73ea7df7 100644 --- a/tests/fast/test_relation_dependency_leak.py +++ b/tests/fast/test_relation_dependency_leak.py @@ -1,5 +1,6 @@ -import numpy as np import os + +import numpy as np import pytest try: @@ -8,8 +9,7 @@ can_run = True except ImportError: can_run = False -from conftest import NumpyPandas, ArrowPandas - +from conftest import ArrowPandas, NumpyPandas psutil = pytest.importorskip("psutil") @@ -46,7 +46,7 @@ def pandas_replacement(pandas, duckdb_cursor): duckdb_cursor.query("select sum(x) from df").fetchall() -class TestRelationDependencyMemoryLeak(object): +class TestRelationDependencyMemoryLeak: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_from_arrow_leak(self, pandas, duckdb_cursor): if not can_run: diff --git a/tests/fast/test_replacement_scan.py b/tests/fast/test_replacement_scan.py index 555773dc..c9d9ae3a 100644 --- a/tests/fast/test_replacement_scan.py +++ b/tests/fast/test_replacement_scan.py @@ -1,7 +1,9 @@ -import duckdb import os + import pytest +import duckdb + pa = pytest.importorskip("pyarrow") pl = pytest.importorskip("polars") pd = pytest.importorskip("pandas") @@ -9,7 +11,7 @@ def using_table(con, to_scan, object_name): local_scope = {"con": con, object_name: to_scan, "object_name": object_name} - exec(f"result = con.table(object_name)", globals(), local_scope) + exec("result = con.table(object_name)", globals(), local_scope) return local_scope["result"] @@ -75,7 +77,7 @@ def create_relation(conn, query: str) -> duckdb.DuckDBPyRelation: return conn.sql(query) -class TestReplacementScan(object): +class TestReplacementScan: def test_csv_replacement(self): con = duckdb.connect() filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "integers.csv") diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index 906b1198..38ae1de6 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -1,9 +1,11 @@ -import duckdb -import pytest import datetime +import pytest + +import duckdb + -class TestPythonResult(object): +class TestPythonResult: def test_result_closed(self, duckdb_cursor): connection = duckdb.connect("") cursor = connection.cursor() diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 327be004..7ab160bb 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -1,12 +1,13 @@ -import duckdb import pytest -from conftest import NumpyPandas, ArrowPandas +from conftest import ArrowPandas, NumpyPandas + +import duckdb closed = lambda: pytest.raises(duckdb.ConnectionException, match="Connection already closed") no_result_set = lambda: pytest.raises(duckdb.InvalidInputException, match="No open result set") -class TestRuntimeError(object): +class TestRuntimeError: def test_fetch_error(self): con = duckdb.connect() con.execute("create table tbl as select 'hello' i") diff --git a/tests/fast/test_sql_expression.py b/tests/fast/test_sql_expression.py index 4dc4cab5..f3cf41ca 100644 --- a/tests/fast/test_sql_expression.py +++ b/tests/fast/test_sql_expression.py @@ -1,5 +1,6 @@ -import duckdb import pytest + +import duckdb from duckdb import ( ColumnExpression, ConstantExpression, @@ -7,7 +8,7 @@ ) -class TestSQLExpression(object): +class TestSQLExpression: def test_sql_expression_basic(self, duckdb_cursor): # Test simple constant expressions expr = SQLExpression("42") diff --git a/tests/fast/test_string_annotation.py b/tests/fast/test_string_annotation.py index 83685bed..17c22844 100644 --- a/tests/fast/test_string_annotation.py +++ b/tests/fast/test_string_annotation.py @@ -1,7 +1,6 @@ -import duckdb -import pytest import sys -from typing import Union + +import pytest def make_annotated_function(type: str): @@ -19,7 +18,6 @@ def test_base(): def python_version_lower_than_3_10(): - import sys if sys.version_info[0] < 3: return True @@ -28,7 +26,7 @@ def python_version_lower_than_3_10(): return False -class TestStringAnnotation(object): +class TestStringAnnotation: @pytest.mark.skipif( python_version_lower_than_3_10(), reason="inspect.signature(eval_str=True) only supported since 3.10 and higher" ) diff --git a/tests/fast/test_tf.py b/tests/fast/test_tf.py index db93d0de..ceec2ee0 100644 --- a/tests/fast/test_tf.py +++ b/tests/fast/test_tf.py @@ -1,6 +1,6 @@ -import duckdb import pytest +import duckdb tf = pytest.importorskip("tensorflow") diff --git a/tests/fast/test_transaction.py b/tests/fast/test_transaction.py index ff0ba1a7..4a06c9e7 100644 --- a/tests/fast/test_transaction.py +++ b/tests/fast/test_transaction.py @@ -1,8 +1,8 @@ + import duckdb -import pandas as pd -class TestConnectionTransaction(object): +class TestConnectionTransaction: def test_transaction(self, duckdb_cursor): con = duckdb.connect() con.execute("create table t (i integer)") diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index 1e8ebc25..768b7782 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -1,47 +1,46 @@ -import duckdb -import os -import pandas as pd -import pytest -from typing import Union, Optional import sys +from typing import Optional, Union + +import pytest +import duckdb +import duckdb.typing from duckdb.typing import ( - SQLNULL, - BOOLEAN, - TINYINT, - UTINYINT, - SMALLINT, - USMALLINT, - INTEGER, - UINTEGER, BIGINT, - UBIGINT, - HUGEINT, - UHUGEINT, - UUID, - FLOAT, - DOUBLE, + BIT, + BLOB, + BOOLEAN, DATE, + DOUBLE, + FLOAT, + HUGEINT, + INTEGER, + INTERVAL, + SMALLINT, + SQLNULL, + TIME, + TIME_TZ, TIMESTAMP, TIMESTAMP_MS, TIMESTAMP_NS, TIMESTAMP_S, - DuckDBPyType, - TIME, - TIME_TZ, TIMESTAMP_TZ, + TINYINT, + UBIGINT, + UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, + UUID, VARCHAR, - BLOB, - BIT, - INTERVAL, + DuckDBPyType, ) -import duckdb.typing -class TestType(object): +class TestType: def test_sqltype(self): assert str(duckdb.sqltype("struct(a VARCHAR, b BIGINT)")) == "STRUCT(a VARCHAR, b BIGINT)" - # todo: add tests with invalid type_str + # TODO: add tests with invalid type_str def test_primitive_types(self): assert str(SQLNULL) == '"NULL"' @@ -118,7 +117,6 @@ def test_union_type(self): type = duckdb.union_type({"a": BIGINT, "b": VARCHAR, "c": TINYINT}) assert str(type) == "UNION(a BIGINT, b VARCHAR, c TINYINT)" - import sys @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires >= python3.9") def test_implicit_convert_from_builtin_type(self): diff --git a/tests/fast/test_type_explicit.py b/tests/fast/test_type_explicit.py index 7b0797e6..3b9fe334 100644 --- a/tests/fast/test_type_explicit.py +++ b/tests/fast/test_type_explicit.py @@ -1,7 +1,7 @@ import duckdb -class TestMap(object): +class TestMap: def test_array_list_tuple_ambiguity(self): con = duckdb.connect() res = con.sql("SELECT $arg", params={"arg": (1, 2)}).fetchall()[0][0] diff --git a/tests/fast/test_unicode.py b/tests/fast/test_unicode.py index 7d08ac88..f1ed8501 100644 --- a/tests/fast/test_unicode.py +++ b/tests/fast/test_unicode.py @@ -1,11 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -import duckdb import pandas as pd +import duckdb + -class TestUnicode(object): +class TestUnicode: def test_unicode_pandas_scan(self, duckdb_cursor): con = duckdb.connect(database=":memory:", read_only=False) test_df = pd.DataFrame.from_dict({"i": [1, 2, 3], "j": ["a", "c", "ë"]}) diff --git a/tests/fast/test_union.py b/tests/fast/test_union.py index 912caff9..d47a8192 100644 --- a/tests/fast/test_union.py +++ b/tests/fast/test_union.py @@ -1,8 +1,8 @@ + import duckdb -import pandas as pd -class TestUnion(object): +class TestUnion: def test_union_by_all(self): connection = duckdb.connect() diff --git a/tests/fast/test_value.py b/tests/fast/test_value.py index c17264fd..9e446fc3 100644 --- a/tests/fast/test_value.py +++ b/tests/fast/test_value.py @@ -1,74 +1,68 @@ -import duckdb -from pytest import raises -from duckdb import NotImplementedException, InvalidInputException -from duckdb.value.constant import ( - Value, - NullValue, - BooleanValue, - UnsignedBinaryValue, - UnsignedShortValue, - UnsignedIntegerValue, - UnsignedLongValue, - BinaryValue, - ShortValue, - IntegerValue, - LongValue, - HugeIntegerValue, - UnsignedHugeIntegerValue, - FloatValue, - DoubleValue, - DecimalValue, - StringValue, - UUIDValue, - BitValue, - BlobValue, - DateValue, - IntervalValue, - TimestampValue, - TimestampSecondValue, - TimestampMilisecondValue, - TimestampNanosecondValue, - TimestampTimeZoneValue, - TimeValue, - TimeTimeZoneValue, -) -import uuid import datetime -import pytest import decimal +import uuid + +import pytest +from pytest import raises +import duckdb +from duckdb import InvalidInputException, NotImplementedException from duckdb.typing import ( - SQLNULL, + BIGINT, + BIT, + BLOB, BOOLEAN, - TINYINT, - UTINYINT, - SMALLINT, - USMALLINT, + DATE, + DOUBLE, + FLOAT, + HUGEINT, INTEGER, - UINTEGER, - BIGINT, + INTERVAL, + SMALLINT, + SQLNULL, + TIME, + TIMESTAMP, + TINYINT, UBIGINT, - HUGEINT, UHUGEINT, + UINTEGER, + USMALLINT, + UTINYINT, UUID, - FLOAT, - DOUBLE, - DATE, - TIMESTAMP, - TIMESTAMP_MS, - TIMESTAMP_NS, - TIMESTAMP_S, - TIME, - TIME_TZ, - TIMESTAMP_TZ, VARCHAR, - BLOB, - BIT, - INTERVAL, +) +from duckdb.value.constant import ( + BinaryValue, + BitValue, + BlobValue, + BooleanValue, + DateValue, + DecimalValue, + DoubleValue, + FloatValue, + HugeIntegerValue, + IntegerValue, + IntervalValue, + LongValue, + NullValue, + ShortValue, + StringValue, + TimestampMilisecondValue, + TimestampNanosecondValue, + TimestampSecondValue, + TimestampValue, + TimeValue, + UnsignedBinaryValue, + UnsignedHugeIntegerValue, + UnsignedIntegerValue, + UnsignedLongValue, + UnsignedShortValue, + UUIDValue, + Value, ) -class TestValue(object): +class TestValue: # This excludes timezone aware values, as those are a pain to test @pytest.mark.parametrize( "item", diff --git a/tests/fast/test_version.py b/tests/fast/test_version.py index cdeb42b0..81f72855 100644 --- a/tests/fast/test_version.py +++ b/tests/fast/test_version.py @@ -1,6 +1,7 @@ -import duckdb import sys +import duckdb + def test_version(): assert duckdb.__version__ != "0.0.0" diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 2ec3f784..207b24fe 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -1,25 +1,24 @@ -""" -Tests for duckdb_pytooling versioning functionality. +"""Tests for duckdb_pytooling versioning functionality. """ import os +import subprocess import unittest +from unittest.mock import MagicMock, patch import pytest -import subprocess -from unittest.mock import patch, MagicMock duckdb_packaging = pytest.importorskip("duckdb_packaging") from duckdb_packaging._versioning import ( - parse_version, format_version, - git_tag_to_pep440, - pep440_to_git_tag, get_current_version, get_git_describe, + git_tag_to_pep440, + parse_version, + pep440_to_git_tag, ) -from duckdb_packaging.setuptools_scm_version import _bump_version, version_scheme, forced_version_from_env +from duckdb_packaging.setuptools_scm_version import _bump_version, forced_version_from_env, version_scheme class TestVersionParsing(unittest.TestCase): diff --git a/tests/fast/test_windows_abs_path.py b/tests/fast/test_windows_abs_path.py index 4ce8311b..7cc31d0b 100644 --- a/tests/fast/test_windows_abs_path.py +++ b/tests/fast/test_windows_abs_path.py @@ -1,10 +1,10 @@ -import duckdb -import pytest import os import shutil +import duckdb + -class TestWindowsAbsPath(object): +class TestWindowsAbsPath: def test_windows_path_accent(self): if os.name != "nt": return diff --git a/tests/fast/types/test_blob.py b/tests/fast/types/test_blob.py index 0d331f7f..74f7f0b8 100644 --- a/tests/fast/types/test_blob.py +++ b/tests/fast/types/test_blob.py @@ -1,8 +1,7 @@ -import duckdb import numpy -class TestBlob(object): +class TestBlob: def test_blob(self, duckdb_cursor): duckdb_cursor.execute("SELECT BLOB 'hello'") results = duckdb_cursor.fetchall() diff --git a/tests/fast/types/test_boolean.py b/tests/fast/types/test_boolean.py index 8e8d2147..5a519e51 100644 --- a/tests/fast/types/test_boolean.py +++ b/tests/fast/types/test_boolean.py @@ -1,8 +1,6 @@ -import duckdb -import numpy -class TestBoolean(object): +class TestBoolean: def test_bool(self, duckdb_cursor): duckdb_cursor.execute("SELECT TRUE") results = duckdb_cursor.fetchall() diff --git a/tests/fast/types/test_datetime_date.py b/tests/fast/types/test_datetime_date.py index 9efb6bd1..d1c3d30b 100644 --- a/tests/fast/types/test_datetime_date.py +++ b/tests/fast/types/test_datetime_date.py @@ -1,8 +1,9 @@ -import duckdb import datetime +import duckdb + -class TestDateTimeDate(object): +class TestDateTimeDate: def test_date_infinity(self): con = duckdb.connect() # Positive infinity diff --git a/tests/fast/types/test_datetime_datetime.py b/tests/fast/types/test_datetime_datetime.py index 2df14b18..c486f9c9 100644 --- a/tests/fast/types/test_datetime_datetime.py +++ b/tests/fast/types/test_datetime_datetime.py @@ -1,7 +1,9 @@ -import duckdb import datetime + import pytest +import duckdb + def create_query(positive, type): inf = "infinity" if positive else "-infinity" @@ -10,7 +12,7 @@ def create_query(positive, type): """ -class TestDateTimeDateTime(object): +class TestDateTimeDateTime: @pytest.mark.parametrize("positive", [True, False]) @pytest.mark.parametrize( "type", diff --git a/tests/fast/types/test_decimal.py b/tests/fast/types/test_decimal.py index b068056d..8be55e44 100644 --- a/tests/fast/types/test_decimal.py +++ b/tests/fast/types/test_decimal.py @@ -1,9 +1,9 @@ -import numpy -import pandas from decimal import * +import numpy + -class TestDecimal(object): +class TestDecimal: def test_decimal(self, duckdb_cursor): duckdb_cursor.execute( "SELECT 1.2::DECIMAL(4,1), 100.3::DECIMAL(9,1), 320938.4298::DECIMAL(18,4), 49082094824.904820482094::DECIMAL(30,12), NULL::DECIMAL" diff --git a/tests/fast/types/test_hugeint.py b/tests/fast/types/test_hugeint.py index e9b5016a..aa8c900d 100644 --- a/tests/fast/types/test_hugeint.py +++ b/tests/fast/types/test_hugeint.py @@ -1,8 +1,7 @@ import numpy -import pandas -class TestHugeint(object): +class TestHugeint: def test_hugeint(self, duckdb_cursor): duckdb_cursor.execute("SELECT 437894723897234238947043214") result = duckdb_cursor.fetchall() diff --git a/tests/fast/types/test_nan.py b/tests/fast/types/test_nan.py index fe99a990..8ffbe1bc 100644 --- a/tests/fast/types/test_nan.py +++ b/tests/fast/types/test_nan.py @@ -1,12 +1,14 @@ -import numpy as np import datetime -import duckdb + +import numpy as np import pytest +import duckdb + pandas = pytest.importorskip("pandas") -class TestPandasNaN(object): +class TestPandasNaN: def test_pandas_nan(self, duckdb_cursor): # create a DataFrame with some basic values df = pandas.DataFrame([{"col1": "val1", "col2": 1.05}, {"col1": "val3", "col2": np.nan}]) diff --git a/tests/fast/types/test_nested.py b/tests/fast/types/test_nested.py index 7f777384..e82673c7 100644 --- a/tests/fast/types/test_nested.py +++ b/tests/fast/types/test_nested.py @@ -1,7 +1,6 @@ -import duckdb -class TestNested(object): +class TestNested: def test_lists(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT LIST_VALUE(1, 2, 3, 4) ").fetchall() assert result == [([1, 2, 3, 4],)] diff --git a/tests/fast/types/test_null.py b/tests/fast/types/test_null.py index fa4105b6..e5fe2e3d 100644 --- a/tests/fast/types/test_null.py +++ b/tests/fast/types/test_null.py @@ -1,7 +1,6 @@ -import traceback -class TestNull(object): +class TestNull: def test_fetchone_null(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE atable (Value int)") duckdb_cursor.execute("INSERT INTO atable VALUES (1)") diff --git a/tests/fast/types/test_numeric.py b/tests/fast/types/test_numeric.py index f25b72b1..174700aa 100644 --- a/tests/fast/types/test_numeric.py +++ b/tests/fast/types/test_numeric.py @@ -1,5 +1,3 @@ -import duckdb -import numpy def check_result(duckdb_cursor, value, type): @@ -8,7 +6,7 @@ def check_result(duckdb_cursor, value, type): assert results[0][0] == value -class TestNumeric(object): +class TestNumeric: def test_numeric_results(self, duckdb_cursor): check_result(duckdb_cursor, 1, "TINYINT") check_result(duckdb_cursor, 1, "SMALLINT") diff --git a/tests/fast/types/test_numpy.py b/tests/fast/types/test_numpy.py index 40b1a5de..b5fe6b3c 100644 --- a/tests/fast/types/test_numpy.py +++ b/tests/fast/types/test_numpy.py @@ -1,10 +1,11 @@ -import duckdb -import numpy as np import datetime -import pytest + +import numpy as np + +import duckdb -class TestNumpyDatetime64(object): +class TestNumpyDatetime64: def test_numpy_datetime64(self, duckdb_cursor): duckdb_con = duckdb.connect() diff --git a/tests/fast/types/test_object_int.py b/tests/fast/types/test_object_int.py index ed3a8d14..f0665535 100644 --- a/tests/fast/types/test_object_int.py +++ b/tests/fast/types/test_object_int.py @@ -1,12 +1,13 @@ -import numpy as np -import datetime -import duckdb -import pytest import warnings from contextlib import suppress +import numpy as np +import pytest + +import duckdb + -class TestPandasObjectInteger(object): +class TestPandasObjectInteger: # Signed Masked Integer types def test_object_integer(self, duckdb_cursor): pd = pytest.importorskip("pandas") diff --git a/tests/fast/types/test_time_tz.py b/tests/fast/types/test_time_tz.py index eceed79a..2215a046 100644 --- a/tests/fast/types/test_time_tz.py +++ b/tests/fast/types/test_time_tz.py @@ -1,17 +1,16 @@ -import numpy as np +import datetime from datetime import time, timezone -import duckdb + import pytest -import datetime pandas = pytest.importorskip("pandas") -class TestTimeTz(object): +class TestTimeTz: def test_time_tz(self, duckdb_cursor): df = pandas.DataFrame({"col1": [time(1, 2, 3, tzinfo=timezone.utc)]}) - sql = f"SELECT * FROM df" + sql = "SELECT * FROM df" duckdb_cursor.execute(sql) diff --git a/tests/fast/types/test_unsigned.py b/tests/fast/types/test_unsigned.py index a35a2216..5639d33b 100644 --- a/tests/fast/types/test_unsigned.py +++ b/tests/fast/types/test_unsigned.py @@ -1,4 +1,4 @@ -class TestUnsigned(object): +class TestUnsigned: def test_unsigned(self, duckdb_cursor): duckdb_cursor.execute("create table unsigned (a utinyint, b usmallint, c uinteger, d ubigint)") duckdb_cursor.execute("insert into unsigned values (1,1,1,1), (null,null,null,null)") diff --git a/tests/fast/udf/test_null_filtering.py b/tests/fast/udf/test_null_filtering.py index fd5b45d0..db86168c 100644 --- a/tests/fast/udf/test_null_filtering.py +++ b/tests/fast/udf/test_null_filtering.py @@ -1,15 +1,12 @@ -import duckdb import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow", "18.0.0") -from typing import Union -import pyarrow.compute as pc -import uuid import datetime -import numpy as np -import cmath -from typing import NamedTuple, Any, List +import uuid +from typing import Any, NamedTuple from duckdb.typing import * @@ -152,7 +149,7 @@ def construct_parameters(tuples, dbtype): return parameters -class TestUDFNullFiltering(object): +class TestUDFNullFiltering: @pytest.mark.parametrize( "table_data", get_table_data(), diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index d03fd7e6..c909c61d 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -1,20 +1,15 @@ -import duckdb -import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -from typing import Union -import pyarrow.compute as pc -import uuid -import datetime -import numpy as np -import cmath from duckdb.typing import * -class TestRemoveFunction(object): +class TestRemoveFunction: def test_not_created(self): con = duckdb.connect() with pytest.raises( diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index c156f94b..e8b1e6d9 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -1,15 +1,16 @@ -import duckdb -import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow", "18.0.0") -from typing import Union, Any -import pyarrow.compute as pc -import uuid +import cmath import datetime +import uuid +from typing import Any + import numpy as np -import cmath from duckdb.typing import * @@ -29,7 +30,7 @@ def test_base(x): return test_function -class TestScalarUDF(object): +class TestScalarUDF: @pytest.mark.parametrize("function_type", ["native", "arrow"]) @pytest.mark.parametrize( "test_type", @@ -69,16 +70,16 @@ def test_type_coverage(self, test_type, function_type): con = duckdb.connect() con.create_function("test", test_function, type=function_type) # Single value - res = con.execute(f"select test(?::{str(type)})", [value]).fetchall() + res = con.execute(f"select test(?::{type!s})", [value]).fetchall() assert res[0][0] == value # NULLs - res = con.execute(f"select res from (select ?, test(NULL::{str(type)}) as res)", [value]).fetchall() + res = con.execute(f"select res from (select ?, test(NULL::{type!s}) as res)", [value]).fetchall() assert res[0][0] == None # Multiple chunks size = duckdb.__standard_vector_size__ * 3 - res = con.execute(f"select test(x) from repeat(?::{str(type)}, {size}) as tbl(x)", [value]).fetchall() + res = con.execute(f"select test(x) from repeat(?::{type!s}, {size}) as tbl(x)", [value]).fetchall() assert len(res) == size # Mixed NULL/NON-NULL @@ -88,7 +89,7 @@ def test_type_coverage(self, test_type, function_type): f""" select test( case when (x > 0.5) then - ?::{str(type)} + ?::{type!s} else NULL end @@ -102,7 +103,7 @@ def test_type_coverage(self, test_type, function_type): f""" select case when (x > 0.5) then - ?::{str(type)} + ?::{type!s} else NULL end @@ -113,7 +114,7 @@ def test_type_coverage(self, test_type, function_type): assert expected == actual # Using 'relation.project' - con.execute(f"create table tbl as select ?::{str(type)} as x", [value]) + con.execute(f"create table tbl as select ?::{type!s} as x", [value]) table_rel = con.table("tbl") res = table_rel.project("test(x)").fetchall() assert res[0][0] == value @@ -221,7 +222,6 @@ def return_np_nan(): @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): - import cmath if udf_type == "native": return cmath.nan diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 794ebc35..856c760d 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -1,18 +1,15 @@ -import duckdb -import os + import pytest +import duckdb + pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") -from typing import Union -import pyarrow.compute as pc -import uuid -import datetime from duckdb.typing import * -class TestPyArrowUDF(object): +class TestPyArrowUDF: def test_basic_use(self): def plus_one(x): table = pa.lib.Table.from_arrays([x], names=["c0"]) @@ -24,7 +21,7 @@ def plus_one(x): con = duckdb.connect() con.create_function("plus_one", plus_one, [BIGINT], BIGINT, type="arrow") - assert [(6,)] == con.sql("select plus_one(5)").fetchall() + assert con.sql("select plus_one(5)").fetchall() == [(6,)] range_table = con.table_function("range", [5000]) res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() @@ -125,7 +122,6 @@ def return_too_many(col): res = con.sql("""select too_many_tuples(5)""").fetchall() def test_arrow_side_effects(self, duckdb_cursor): - import random as r def random_arrow(x): if not hasattr(random_arrow, "data"): diff --git a/tests/fast/udf/test_scalar_native.py b/tests/fast/udf/test_scalar_native.py index 0c5cf927..94b2949e 100644 --- a/tests/fast/udf/test_scalar_native.py +++ b/tests/fast/udf/test_scalar_native.py @@ -1,12 +1,11 @@ -import duckdb -import os -import pandas as pd + import pytest +import duckdb from duckdb.typing import * -class TestNativeUDF(object): +class TestNativeUDF: def test_default_conn(self): def passthrough(x): return x @@ -23,7 +22,7 @@ def plus_one(x): con = duckdb.connect() con.create_function("plus_one", plus_one, [BIGINT], BIGINT) - assert [(6,)] == con.sql("select plus_one(5)").fetchall() + assert con.sql("select plus_one(5)").fetchall() == [(6,)] range_table = con.table_function("range", [5000]) res = con.sql("select plus_one(i) from range_table tbl(i)").fetchall() diff --git a/tests/fast/udf/test_transactionality.py b/tests/fast/udf/test_transactionality.py index 134df663..acad21ef 100644 --- a/tests/fast/udf/test_transactionality.py +++ b/tests/fast/udf/test_transactionality.py @@ -1,8 +1,9 @@ -import duckdb import pytest +import duckdb + -class TestUDFTransactionality(object): +class TestUDFTransactionality: @pytest.mark.xfail(reason="fetchone() does not realize the stream result was closed before completion") def test_type_coverage(self, duckdb_cursor): rel = duckdb_cursor.sql("select * from range(4096)") diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py index b0901ab8..d0dbc2fe 100644 --- a/tests/slow/test_h2oai_arrow.py +++ b/tests/slow/test_h2oai_arrow.py @@ -1,7 +1,9 @@ -import duckdb -import os import math -from pytest import mark, fixture, importorskip +import os + +from pytest import fixture, importorskip, mark + +import duckdb read_csv = importorskip("pyarrow.csv").read_csv requests = importorskip("requests") @@ -153,7 +155,7 @@ def join_by_q5(con): con.execute("DROP TABLE ans") -class TestH2OAIArrow(object): +class TestH2OAIArrow: @mark.parametrize( "function", [ From ac6ecfde6f325251cb9e73a150c3335b7bdc15ac Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:29:34 +0200 Subject: [PATCH 214/472] Ruff format fixes --- adbc_driver_duckdb/dbapi.py | 3 +- duckdb/experimental/spark/errors/__init__.py | 3 +- .../spark/errors/exceptions/base.py | 72 +++++++------------ duckdb/experimental/spark/errors/utils.py | 6 +- duckdb/experimental/spark/sql/dataframe.py | 3 +- duckdb/experimental/spark/sql/streaming.py | 1 - duckdb/experimental/spark/sql/types.py | 21 ++---- duckdb/polars_io.py | 3 +- duckdb/udf.py | 3 +- duckdb_packaging/pypi_cleanup.py | 3 - sqllogic/conftest.py | 3 +- tests/conftest.py | 1 - tests/fast/api/test_attribute_getter.py | 1 - tests/fast/api/test_dbapi12.py | 1 - tests/fast/arrow/test_2426.py | 2 - tests/fast/arrow/test_arrow_fetch.py | 2 - tests/fast/arrow/test_arrow_pycapsule.py | 1 - tests/fast/arrow/test_binary_type.py | 1 - tests/fast/arrow/test_date.py | 2 - tests/fast/arrow/test_dictionary_arrow.py | 1 - tests/fast/arrow/test_interval.py | 1 - tests/fast/arrow/test_large_string.py | 1 - tests/fast/arrow/test_time.py | 2 - tests/fast/pandas/test_df_analyze.py | 1 - tests/fast/pandas/test_df_recursive_nested.py | 1 - tests/fast/pandas/test_fetch_nested.py | 1 - tests/fast/pandas/test_pandas_limit.py | 1 - .../pandas/test_partitioned_pandas_scan.py | 1 - tests/fast/pandas/test_progress_bar.py | 1 - .../test_pyarrow_projection_pushdown.py | 1 - .../relational_api/test_rapi_aggregations.py | 1 - tests/fast/spark/test_spark_functions_hex.py | 1 - tests/fast/test_all_types.py | 1 - tests/fast/test_ambiguous_prepare.py | 1 - tests/fast/test_case_alias.py | 1 - tests/fast/test_insert.py | 1 - tests/fast/test_multi_statement.py | 1 - tests/fast/test_string_annotation.py | 1 - tests/fast/test_transaction.py | 1 - tests/fast/test_type.py | 1 - tests/fast/test_union.py | 1 - tests/fast/test_versioning.py | 3 +- tests/fast/types/test_boolean.py | 2 - tests/fast/types/test_nested.py | 2 - tests/fast/types/test_null.py | 2 - tests/fast/types/test_numeric.py | 2 - tests/fast/udf/test_remove_function.py | 1 - tests/fast/udf/test_scalar.py | 2 - tests/fast/udf/test_scalar_arrow.py | 2 - tests/fast/udf/test_scalar_native.py | 1 - 50 files changed, 40 insertions(+), 132 deletions(-) diff --git a/adbc_driver_duckdb/dbapi.py b/adbc_driver_duckdb/dbapi.py index 7d703713..5d0a8702 100644 --- a/adbc_driver_duckdb/dbapi.py +++ b/adbc_driver_duckdb/dbapi.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""DBAPI 2.0-compatible facade for the ADBC DuckDB driver. -""" +"""DBAPI 2.0-compatible facade for the ADBC DuckDB driver.""" import typing diff --git a/duckdb/experimental/spark/errors/__init__.py b/duckdb/experimental/spark/errors/__init__.py index 2f265d97..ee7688ea 100644 --- a/duckdb/experimental/spark/errors/__init__.py +++ b/duckdb/experimental/spark/errors/__init__.py @@ -15,8 +15,7 @@ # limitations under the License. # -"""PySpark exceptions. -""" +"""PySpark exceptions.""" from .exceptions.base import ( AnalysisException, diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index a6f1f940..0b2c6a43 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -4,8 +4,7 @@ class PySparkException(Exception): - """Base Exception for handling errors generated from PySpark. - """ + """Base Exception for handling errors generated from PySpark.""" def __init__( self, @@ -78,115 +77,92 @@ def __str__(self) -> str: class AnalysisException(PySparkException): - """Failed to analyze a SQL query plan. - """ + """Failed to analyze a SQL query plan.""" class SessionNotSameException(PySparkException): - """Performed the same operation on different SparkSession. - """ + """Performed the same operation on different SparkSession.""" class TempTableAlreadyExistsException(AnalysisException): - """Failed to create temp view since it is already exists. - """ + """Failed to create temp view since it is already exists.""" class ParseException(AnalysisException): - """Failed to parse a SQL command. - """ + """Failed to parse a SQL command.""" class IllegalArgumentException(PySparkException): - """Passed an illegal or inappropriate argument. - """ + """Passed an illegal or inappropriate argument.""" class ArithmeticException(PySparkException): - """Arithmetic exception thrown from Spark with an error class. - """ + """Arithmetic exception thrown from Spark with an error class.""" class UnsupportedOperationException(PySparkException): - """Unsupported operation exception thrown from Spark with an error class. - """ + """Unsupported operation exception thrown from Spark with an error class.""" class ArrayIndexOutOfBoundsException(PySparkException): - """Array index out of bounds exception thrown from Spark with an error class. - """ + """Array index out of bounds exception thrown from Spark with an error class.""" class DateTimeException(PySparkException): - """Datetime exception thrown from Spark with an error class. - """ + """Datetime exception thrown from Spark with an error class.""" class NumberFormatException(IllegalArgumentException): - """Number format exception thrown from Spark with an error class. - """ + """Number format exception thrown from Spark with an error class.""" class StreamingQueryException(PySparkException): - """Exception that stopped a :class:`StreamingQuery`. - """ + """Exception that stopped a :class:`StreamingQuery`.""" class QueryExecutionException(PySparkException): - """Failed to execute a query. - """ + """Failed to execute a query.""" class PythonException(PySparkException): - """Exceptions thrown from Python workers. - """ + """Exceptions thrown from Python workers.""" class SparkRuntimeException(PySparkException): - """Runtime exception thrown from Spark with an error class. - """ + """Runtime exception thrown from Spark with an error class.""" class SparkUpgradeException(PySparkException): - """Exception thrown because of Spark upgrade. - """ + """Exception thrown because of Spark upgrade.""" class UnknownException(PySparkException): - """None of the above exceptions. - """ + """None of the above exceptions.""" class PySparkValueError(PySparkException, ValueError): - """Wrapper class for ValueError to support error classes. - """ + """Wrapper class for ValueError to support error classes.""" class PySparkIndexError(PySparkException, IndexError): - """Wrapper class for IndexError to support error classes. - """ + """Wrapper class for IndexError to support error classes.""" class PySparkTypeError(PySparkException, TypeError): - """Wrapper class for TypeError to support error classes. - """ + """Wrapper class for TypeError to support error classes.""" class PySparkAttributeError(PySparkException, AttributeError): - """Wrapper class for AttributeError to support error classes. - """ + """Wrapper class for AttributeError to support error classes.""" class PySparkRuntimeError(PySparkException, RuntimeError): - """Wrapper class for RuntimeError to support error classes. - """ + """Wrapper class for RuntimeError to support error classes.""" class PySparkAssertionError(PySparkException, AssertionError): - """Wrapper class for AssertionError to support error classes. - """ + """Wrapper class for AssertionError to support error classes.""" class PySparkNotImplementedError(PySparkException, NotImplementedError): - """Wrapper class for NotImplementedError to support error classes. - """ + """Wrapper class for NotImplementedError to support error classes.""" diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index c8c66896..8b737dde 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -21,15 +21,13 @@ class ErrorClassesReader: - """A reader to load error information from error_classes.py. - """ + """A reader to load error information from error_classes.py.""" def __init__(self) -> None: self.error_info_map = ERROR_CLASSES_MAP def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: - """Returns the completed error message by applying message parameters to the message template. - """ + """Returns the completed error message by applying message parameters to the message template.""" message_template = self.get_message_template(error_class) # Verify message parameters. message_parameters_from_template = re.findall("<([a-zA-Z0-9_-]+)>", message_template) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 3f32aa32..d0d4835d 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -843,8 +843,7 @@ def limit(self, num: int) -> "DataFrame": return DataFrame(rel, self.session) def __contains__(self, item: str) -> bool: - """Check if the :class:`DataFrame` contains a column by the name of `item` - """ + """Check if the :class:`DataFrame` contains a column by the name of `item`""" return item in self.relation @property diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index ba54db60..201b889b 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -30,7 +30,6 @@ def load( schema: Union[StructType, str, None] = None, **options: OptionalPrimitiveType, ) -> "DataFrame": - raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index d8a04b8e..9d2b4b7d 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -102,13 +102,11 @@ def needConversion(self) -> bool: return False def toInternal(self, obj: Any) -> Any: - """Converts a Python object into an internal SQL object. - """ + """Converts a Python object into an internal SQL object.""" return obj def fromInternal(self, obj: Any) -> Any: - """Converts an internal SQL object into a native Python object. - """ + """Converts an internal SQL object into a native Python object.""" return obj @@ -979,14 +977,12 @@ def typeName(cls) -> str: @classmethod def sqlType(cls) -> DataType: - """Underlying SQL storage type for this UDT. - """ + """Underlying SQL storage type for this UDT.""" raise NotImplementedError("UDT must implement sqlType().") @classmethod def module(cls) -> str: - """The Python module of the UDT. - """ + """The Python module of the UDT.""" raise NotImplementedError("UDT must implement module().") @classmethod @@ -1001,8 +997,7 @@ def needConversion(self) -> bool: @classmethod def _cachedSqlType(cls) -> DataType: - """Cache the sqlType() into class, because it's heavily used in `toInternal`. - """ + """Cache the sqlType() into class, because it's heavily used in `toInternal`.""" if not hasattr(cls, "_cached_sql_type"): cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined] return cls._cached_sql_type # type: ignore[attr-defined] @@ -1017,13 +1012,11 @@ def fromInternal(self, obj: Any) -> Any: return self.deserialize(v) def serialize(self, obj: Any) -> Any: - """Converts a user-type object into a SQL datum. - """ + """Converts a user-type object into a SQL datum.""" raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum: Any) -> Any: - """Converts a SQL datum into a user-type object. - """ + """Converts a SQL datum into a user-type object.""" raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self) -> str: diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index b1fc244c..59758f19 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -206,8 +206,7 @@ def _pl_tree_to_sql(tree: dict) -> str: def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: - """A polars IO plugin for DuckDB. - """ + """A polars IO plugin for DuckDB.""" def source_generator( with_columns: Optional[list[str]], diff --git a/duckdb/udf.py b/duckdb/udf.py index 0eb59ba9..21d6d53f 100644 --- a/duckdb/udf.py +++ b/duckdb/udf.py @@ -1,6 +1,5 @@ def vectorized(func): - """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output - """ + """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output""" import types from inspect import signature diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 8e91b34f..b45cf1a1 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -78,17 +78,14 @@ class PyPICleanupError(Exception): """Base exception for PyPI cleanup operations.""" - class AuthenticationError(PyPICleanupError): """Raised when authentication fails.""" - class ValidationError(PyPICleanupError): """Raised when input validation fails.""" - def setup_logging(verbose: bool = False) -> None: """Configure logging with appropriate level and format.""" level = logging.DEBUG if verbose else logging.INFO diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 8d772111..77281d54 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -268,8 +268,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config def pytest_runtest_setup(item: pytest.Item): - """Show the test index after the test name - """ + """Show the test index after the test name""" def get_from_tuple_list(tuples, key): for t in tuples: diff --git a/tests/conftest.py b/tests/conftest.py index 83c10f3a..cc385c31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -228,7 +228,6 @@ def _require(extension_name, db_name=""): # By making the scope 'function' we ensure that a new connection gets created for every function that uses the fixture @pytest.fixture(scope="function") def spark(): - if not hasattr(spark, "session"): # Cache the import from spark_namespace.sql import SparkSession as session diff --git a/tests/fast/api/test_attribute_getter.py b/tests/fast/api/test_attribute_getter.py index 3b1513d1..208ccc40 100644 --- a/tests/fast/api/test_attribute_getter.py +++ b/tests/fast/api/test_attribute_getter.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/api/test_dbapi12.py b/tests/fast/api/test_dbapi12.py index 96b1deac..f8dcdbe6 100644 --- a/tests/fast/api/test_dbapi12.py +++ b/tests/fast/api/test_dbapi12.py @@ -1,4 +1,3 @@ - import pandas as pd import duckdb diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index 5e6d42ef..a4bdeff7 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -1,8 +1,6 @@ - import duckdb try: - can_run = True except: can_run = False diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index 62460912..11deab23 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -1,8 +1,6 @@ - import duckdb try: - can_run = True except: can_run = False diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index 295f0292..0799c206 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/arrow/test_binary_type.py b/tests/fast/arrow/test_binary_type.py index 5932fba8..0a0062f5 100644 --- a/tests/fast/arrow/test_binary_type.py +++ b/tests/fast/arrow/test_binary_type.py @@ -1,4 +1,3 @@ - import duckdb try: diff --git a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 83c14932..bebb55a0 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -1,5 +1,3 @@ - - import duckdb try: diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index 5cb2d38d..1b24c2b9 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -1,4 +1,3 @@ - import pytest pa = pytest.importorskip("pyarrow") diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index 5cdb04bd..7d3ec128 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/arrow/test_large_string.py b/tests/fast/arrow/test_large_string.py index bb9d1b5b..d6a4c76a 100644 --- a/tests/fast/arrow/test_large_string.py +++ b/tests/fast/arrow/test_large_string.py @@ -1,4 +1,3 @@ - import duckdb try: diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index b3bab360..ff16002c 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -1,5 +1,3 @@ - - import duckdb try: diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 92318085..e1e0a2a7 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -1,4 +1,3 @@ - import numpy as np import pytest from conftest import ArrowPandas, NumpyPandas diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index 4eacf777..4ef84c84 100644 --- a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -1,4 +1,3 @@ - import pytest from conftest import ArrowPandas, NumpyPandas diff --git a/tests/fast/pandas/test_fetch_nested.py b/tests/fast/pandas/test_fetch_nested.py index 6e878643..5b8cfe50 100644 --- a/tests/fast/pandas/test_fetch_nested.py +++ b/tests/fast/pandas/test_fetch_nested.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/pandas/test_pandas_limit.py b/tests/fast/pandas/test_pandas_limit.py index 89fe1583..9c63cfdc 100644 --- a/tests/fast/pandas/test_pandas_limit.py +++ b/tests/fast/pandas/test_pandas_limit.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/pandas/test_partitioned_pandas_scan.py b/tests/fast/pandas/test_partitioned_pandas_scan.py index 9f580659..c1ab7b34 100644 --- a/tests/fast/pandas/test_partitioned_pandas_scan.py +++ b/tests/fast/pandas/test_partitioned_pandas_scan.py @@ -1,4 +1,3 @@ - import numpy import pandas as pd diff --git a/tests/fast/pandas/test_progress_bar.py b/tests/fast/pandas/test_progress_bar.py index c8cfb2e0..5635edae 100644 --- a/tests/fast/pandas/test_progress_bar.py +++ b/tests/fast/pandas/test_progress_bar.py @@ -1,4 +1,3 @@ - import numpy import pandas as pd diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index 4191a96e..87f49f04 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -1,4 +1,3 @@ - import pytest from conftest import pandas_supports_arrow_backend diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index 31cb21c9..9cc0492b 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/spark/test_spark_functions_hex.py b/tests/fast/spark/test_spark_functions_hex.py index c58c6d90..54caaf28 100644 --- a/tests/fast/spark/test_spark_functions_hex.py +++ b/tests/fast/spark/test_spark_functions_hex.py @@ -1,4 +1,3 @@ - import pytest _ = pytest.importorskip("duckdb.experimental.spark") diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index e74cca30..be920cf8 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -27,7 +27,6 @@ def replace_with_ndarray(obj): # we need to write our own equality function that considers nan==nan for testing purposes def recursive_equality(o1, o2): - if type(o1) != type(o2): return False if type(o1) == float and math.isnan(o1) and math.isnan(o2): diff --git a/tests/fast/test_ambiguous_prepare.py b/tests/fast/test_ambiguous_prepare.py index 0865b007..48f217cd 100644 --- a/tests/fast/test_ambiguous_prepare.py +++ b/tests/fast/test_ambiguous_prepare.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index 5092f099..d1afb4d8 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -7,7 +7,6 @@ class TestCaseAlias: @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) def test_case_alias(self, duckdb_cursor, pandas): - con = duckdb.connect(":memory:") df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index 34489b44..030d255a 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -1,4 +1,3 @@ - import pytest from conftest import ArrowPandas, NumpyPandas diff --git a/tests/fast/test_multi_statement.py b/tests/fast/test_multi_statement.py index cd3111e6..2b255375 100644 --- a/tests/fast/test_multi_statement.py +++ b/tests/fast/test_multi_statement.py @@ -6,7 +6,6 @@ class TestMultiStatement: def test_multi_statement(self, duckdb_cursor): - con = duckdb.connect(":memory:") # test empty statement diff --git a/tests/fast/test_string_annotation.py b/tests/fast/test_string_annotation.py index 17c22844..b8014740 100644 --- a/tests/fast/test_string_annotation.py +++ b/tests/fast/test_string_annotation.py @@ -18,7 +18,6 @@ def test_base(): def python_version_lower_than_3_10(): - if sys.version_info[0] < 3: return True if sys.version_info[1] < 10: diff --git a/tests/fast/test_transaction.py b/tests/fast/test_transaction.py index 4a06c9e7..0dfabafa 100644 --- a/tests/fast/test_transaction.py +++ b/tests/fast/test_transaction.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index 768b7782..4824ce7c 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -117,7 +117,6 @@ def test_union_type(self): type = duckdb.union_type({"a": BIGINT, "b": VARCHAR, "c": TINYINT}) assert str(type) == "UNION(a BIGINT, b VARCHAR, c TINYINT)" - @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires >= python3.9") def test_implicit_convert_from_builtin_type(self): type = duckdb.list_type(list[str]) diff --git a/tests/fast/test_union.py b/tests/fast/test_union.py index d47a8192..8df17238 100644 --- a/tests/fast/test_union.py +++ b/tests/fast/test_union.py @@ -1,4 +1,3 @@ - import duckdb diff --git a/tests/fast/test_versioning.py b/tests/fast/test_versioning.py index 207b24fe..5f48c3cb 100644 --- a/tests/fast/test_versioning.py +++ b/tests/fast/test_versioning.py @@ -1,5 +1,4 @@ -"""Tests for duckdb_pytooling versioning functionality. -""" +"""Tests for duckdb_pytooling versioning functionality.""" import os import subprocess diff --git a/tests/fast/types/test_boolean.py b/tests/fast/types/test_boolean.py index 5a519e51..dfa67aaa 100644 --- a/tests/fast/types/test_boolean.py +++ b/tests/fast/types/test_boolean.py @@ -1,5 +1,3 @@ - - class TestBoolean: def test_bool(self, duckdb_cursor): duckdb_cursor.execute("SELECT TRUE") diff --git a/tests/fast/types/test_nested.py b/tests/fast/types/test_nested.py index e82673c7..824b2825 100644 --- a/tests/fast/types/test_nested.py +++ b/tests/fast/types/test_nested.py @@ -1,5 +1,3 @@ - - class TestNested: def test_lists(self, duckdb_cursor): result = duckdb_cursor.execute("SELECT LIST_VALUE(1, 2, 3, 4) ").fetchall() diff --git a/tests/fast/types/test_null.py b/tests/fast/types/test_null.py index e5fe2e3d..27f287c8 100644 --- a/tests/fast/types/test_null.py +++ b/tests/fast/types/test_null.py @@ -1,5 +1,3 @@ - - class TestNull: def test_fetchone_null(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE atable (Value int)") diff --git a/tests/fast/types/test_numeric.py b/tests/fast/types/test_numeric.py index 174700aa..6540735d 100644 --- a/tests/fast/types/test_numeric.py +++ b/tests/fast/types/test_numeric.py @@ -1,5 +1,3 @@ - - def check_result(duckdb_cursor, value, type): duckdb_cursor.execute("SELECT " + str(value) + "::" + type) results = duckdb_cursor.fetchall() diff --git a/tests/fast/udf/test_remove_function.py b/tests/fast/udf/test_remove_function.py index c909c61d..2e7cc670 100644 --- a/tests/fast/udf/test_remove_function.py +++ b/tests/fast/udf/test_remove_function.py @@ -1,4 +1,3 @@ - import pytest import duckdb diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index e8b1e6d9..b7f4e343 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -1,4 +1,3 @@ - import pytest import duckdb @@ -222,7 +221,6 @@ def return_np_nan(): @pytest.mark.parametrize("duckdb_type", [FLOAT, DOUBLE]) def test_math_nan(self, duckdb_type, udf_type): def return_math_nan(): - if udf_type == "native": return cmath.nan else: diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 856c760d..984a1f8c 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -1,4 +1,3 @@ - import pytest import duckdb @@ -122,7 +121,6 @@ def return_too_many(col): res = con.sql("""select too_many_tuples(5)""").fetchall() def test_arrow_side_effects(self, duckdb_cursor): - def random_arrow(x): if not hasattr(random_arrow, "data"): random_arrow.data = 0 diff --git a/tests/fast/udf/test_scalar_native.py b/tests/fast/udf/test_scalar_native.py index 94b2949e..76295060 100644 --- a/tests/fast/udf/test_scalar_native.py +++ b/tests/fast/udf/test_scalar_native.py @@ -1,4 +1,3 @@ - import pytest import duckdb From 8a6c28fe5c4382238bed483f0e1a98a385862566 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:32:18 +0200 Subject: [PATCH 215/472] Ruff config: dont add future annotations --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a53f9eb5..03570028 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -329,7 +329,6 @@ select = [ "E", # pycodestyle "EM", # flake8-errmsg "F", # pyflakes - "FA", # flake8-future-annotations "FBT001", # flake8-boolean-trap "I", # isort "ICN", # flake8-import-conventions From db800b14f352a949892873f31ccfb7127068b36b Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:37:19 +0200 Subject: [PATCH 216/472] Ruff config: temporarily skip import checks --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 03570028..811d9c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -331,7 +331,7 @@ select = [ "F", # pyflakes "FBT001", # flake8-boolean-trap "I", # isort - "ICN", # flake8-import-conventions + #"ICN", # flake8-import-conventions "INT", # flake8-gettext "PERF", # perflint "PIE", # flake8-pie @@ -342,7 +342,7 @@ select = [ "SIM", # flake8-simplify "TCH", # flake8-type-checking "TD", # flake8-todos - "TID", # flake8-tidy-imports + #"TID", # flake8-tidy-imports "TRY", # tryceratops "UP", # pyupgrade "W", # pycodestyle From 9044133024fa8afcd2a539d8053103bfe77e71c7 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:45:55 +0200 Subject: [PATCH 217/472] Ruff EM: Error messages assigned to a var first --- duckdb/experimental/spark/_globals.py | 3 +- duckdb/experimental/spark/errors/utils.py | 6 +- duckdb/experimental/spark/sql/column.py | 6 +- duckdb/experimental/spark/sql/dataframe.py | 9 ++- duckdb/experimental/spark/sql/functions.py | 30 ++++--- duckdb/experimental/spark/sql/readwriter.py | 78 ++++++++++++------- duckdb/experimental/spark/sql/types.py | 36 ++++++--- duckdb/filesystem.py | 3 +- duckdb/polars_io.py | 12 ++- duckdb_packaging/_versioning.py | 9 ++- duckdb_packaging/build_backend.py | 26 ++++--- duckdb_packaging/pypi_cleanup.py | 65 ++++++++++------ duckdb_packaging/setuptools_scm_version.py | 12 ++- scripts/generate_connection_methods.py | 12 ++- scripts/generate_connection_stubs.py | 12 ++- .../generate_connection_wrapper_methods.py | 9 ++- scripts/generate_connection_wrapper_stubs.py | 12 ++- scripts/generate_import_cache_json.py | 3 +- sqllogic/conftest.py | 17 ++-- tests/fast/adbc/test_statement_bind.py | 3 +- tests/fast/udf/test_scalar.py | 3 +- tests/fast/udf/test_scalar_arrow.py | 3 +- 22 files changed, 242 insertions(+), 127 deletions(-) diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index 4bc325f7..771daceb 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -39,7 +39,8 @@ def foo(arg=pyducdkb.spark._NoValue): # Disallow reloading this module so as to preserve the identities of the # classes defined here. if "_is_loaded" in globals(): - raise RuntimeError("Reloading duckdb.experimental.spark._globals is not allowed") + msg = "Reloading duckdb.experimental.spark._globals is not allowed" + raise RuntimeError(msg) _is_loaded = True diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index 8b737dde..984504a4 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -86,7 +86,8 @@ def get_message_template(self, error_class: str) -> str: if main_error_class in self.error_info_map: main_error_class_info_map = self.error_info_map[main_error_class] else: - raise ValueError(f"Cannot find main error class '{main_error_class}'") + msg = f"Cannot find main error class '{main_error_class}'" + raise ValueError(msg) main_message_template = "\n".join(main_error_class_info_map["message"]) @@ -101,7 +102,8 @@ def get_message_template(self, error_class: str) -> str: if sub_error_class in main_error_class_subclass_info_map: sub_error_class_info_map = main_error_class_subclass_info_map[sub_error_class] else: - raise ValueError(f"Cannot find sub error class '{sub_error_class}'") + msg = f"Cannot find sub error class '{sub_error_class}'" + raise ValueError(msg) sub_message_template = "\n".join(sub_error_class_info_map["message"]) message_template = main_message_template + " " + sub_message_template diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 3a6f6cea..6cc92523 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -201,7 +201,8 @@ def __getattr__(self, item: Any) -> "Column": +------+ """ if item.startswith("__"): - raise AttributeError("Can not access __ (dunder) method") + msg = "Can not access __ (dunder) method" + raise AttributeError(msg) return self[item] def alias(self, alias: str): @@ -209,7 +210,8 @@ def alias(self, alias: str): def when(self, condition: "Column", value: Any): if not isinstance(condition, Column): - raise TypeError("condition should be a Column") + msg = "condition should be a Column" + raise TypeError(msg) v = _get_expr(value) expr = self.expr.when(condition.expr, v) return Column(expr) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index d0d4835d..57c8cd03 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -108,7 +108,8 @@ def createGlobalTempView(self, name: str) -> None: def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": if columnName not in self.relation: - raise ValueError(f"DataFrame does not contain a column named {columnName}") + msg = f"DataFrame does not contain a column named {columnName}" + raise ValueError(msg) cols = [] for x in self.relation.columns: col = ColumnExpression(x) @@ -258,7 +259,8 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": unknown_columns = set(colsMap.keys()) - set(self.relation.columns) if unknown_columns: - raise ValueError(f"DataFrame does not contain column(s): {', '.join(unknown_columns)}") + msg = f"DataFrame does not contain column(s): {', '.join(unknown_columns)}" + raise ValueError(msg) # Compute this only once old_column_names = list(colsMap.keys()) @@ -887,7 +889,8 @@ def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Colum elif isinstance(item, int): return col(self._schema[item].name) else: - raise TypeError(f"Unexpected item type: {type(item)}") + msg = f"Unexpected item type: {type(item)}" + raise TypeError(msg) def __getattr__(self, name: str) -> Column: """Returns the :class:`Column` denoted by ``name``. diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 501c9503..fddcd4c5 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -92,7 +92,8 @@ def ucase(str: "ColumnOrName") -> Column: def when(condition: "Column", value: Any) -> Column: if not isinstance(condition, Column): - raise TypeError("condition should be a Column") + msg = "condition should be a Column" + raise TypeError(msg) v = _get_expr(value) expr = CaseExpression(condition.expr, v) return Column(expr) @@ -1480,7 +1481,8 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C +---------------+ """ if rsd is not None: - raise ValueError("rsd is not supported by DuckDB") + msg = "rsd is not supported by DuckDB" + raise ValueError(msg) return _invoke_function_over_columns("approx_count_distinct", col) @@ -2365,7 +2367,8 @@ def rand(seed: Optional[int] = None) -> Column: """ if seed is not None: # Maybe call setseed just before but how do we know when it is executed? - raise ContributionsAcceptedError("Seed is not yet implemented") + msg = "Seed is not yet implemented" + raise ContributionsAcceptedError(msg) return _invoke_function("random") @@ -2842,7 +2845,8 @@ def encode(col: "ColumnOrName", charset: str) -> Column: +----------------+ """ if charset != "UTF-8": - raise ContributionsAcceptedError("Only UTF-8 charset is supported right now") + msg = "Only UTF-8 charset is supported right now" + raise ContributionsAcceptedError(msg) return _invoke_function("encode", _to_column_expr(col)) @@ -3017,7 +3021,8 @@ def greatest(*cols: "ColumnOrName") -> Column: [Row(greatest=4)] """ if len(cols) < 2: - raise ValueError("greatest should take at least 2 columns") + msg = "greatest should take at least 2 columns" + raise ValueError(msg) cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("greatest", *cols) @@ -3049,7 +3054,8 @@ def least(*cols: "ColumnOrName") -> Column: [Row(least=1)] """ if len(cols) < 2: - raise ValueError("least should take at least 2 columns") + msg = "least should take at least 2 columns" + raise ValueError(msg) cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("least", *cols) @@ -3550,12 +3556,14 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: +-----+----------------------------------------------------------------+ """ if numBits not in {224, 256, 384, 512, 0}: - raise ValueError("numBits should be one of {224, 256, 384, 512, 0}") + msg = "numBits should be one of {224, 256, 384, 512, 0}" + raise ValueError(msg) if numBits == 256: return _invoke_function_over_columns("sha256", col) - raise ContributionsAcceptedError("SHA-224, SHA-384, and SHA-512 are not supported yet.") + msg = "SHA-224, SHA-384, and SHA-512 are not supported yet." + raise ContributionsAcceptedError(msg) def curdate() -> Column: @@ -5241,7 +5249,8 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] """ if comparator is not None: - raise ContributionsAcceptedError("comparator is not yet supported") + msg = "comparator is not yet supported" + raise ContributionsAcceptedError(msg) else: return _invoke_function_over_columns("list_sort", col, lit("ASC"), lit("NULLS LAST")) @@ -5335,7 +5344,8 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column: if limit > 0: # Unclear how to implement this in DuckDB as we'd need to map back from the split array # to the original array which is tricky with regular expressions. - raise ContributionsAcceptedError("limit is not yet supported") + msg = "limit is not yet supported" + raise ContributionsAcceptedError(msg) return _invoke_function_over_columns("regexp_split_to_array", str, lit(pattern)) diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 607e9d36..714ed797 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -248,10 +248,12 @@ def csv( def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": input = list(paths) if len(input) != 1: - raise NotImplementedError("Only single paths are supported for now") + msg = "Only single paths are supported for now" + raise NotImplementedError(msg) option_amount = len(options.keys()) if option_amount != 0: - raise ContributionsAcceptedError("Options are not supported") + msg = "Options are not supported" + raise ContributionsAcceptedError(msg) path = input[0] rel = self.session.conn.read_parquet(path) from ..sql.dataframe import DataFrame @@ -338,53 +340,77 @@ def json( +---+------------+ """ if schema is not None: - raise ContributionsAcceptedError("The 'schema' option is not supported") + msg = "The 'schema' option is not supported" + raise ContributionsAcceptedError(msg) if primitivesAsString is not None: - raise ContributionsAcceptedError("The 'primitivesAsString' option is not supported") + msg = "The 'primitivesAsString' option is not supported" + raise ContributionsAcceptedError(msg) if prefersDecimal is not None: - raise ContributionsAcceptedError("The 'prefersDecimal' option is not supported") + msg = "The 'prefersDecimal' option is not supported" + raise ContributionsAcceptedError(msg) if allowComments is not None: - raise ContributionsAcceptedError("The 'allowComments' option is not supported") + msg = "The 'allowComments' option is not supported" + raise ContributionsAcceptedError(msg) if allowUnquotedFieldNames is not None: - raise ContributionsAcceptedError("The 'allowUnquotedFieldNames' option is not supported") + msg = "The 'allowUnquotedFieldNames' option is not supported" + raise ContributionsAcceptedError(msg) if allowSingleQuotes is not None: - raise ContributionsAcceptedError("The 'allowSingleQuotes' option is not supported") + msg = "The 'allowSingleQuotes' option is not supported" + raise ContributionsAcceptedError(msg) if allowNumericLeadingZero is not None: - raise ContributionsAcceptedError("The 'allowNumericLeadingZero' option is not supported") + msg = "The 'allowNumericLeadingZero' option is not supported" + raise ContributionsAcceptedError(msg) if allowBackslashEscapingAnyCharacter is not None: - raise ContributionsAcceptedError("The 'allowBackslashEscapingAnyCharacter' option is not supported") + msg = "The 'allowBackslashEscapingAnyCharacter' option is not supported" + raise ContributionsAcceptedError(msg) if mode is not None: - raise ContributionsAcceptedError("The 'mode' option is not supported") + msg = "The 'mode' option is not supported" + raise ContributionsAcceptedError(msg) if columnNameOfCorruptRecord is not None: - raise ContributionsAcceptedError("The 'columnNameOfCorruptRecord' option is not supported") + msg = "The 'columnNameOfCorruptRecord' option is not supported" + raise ContributionsAcceptedError(msg) if dateFormat is not None: - raise ContributionsAcceptedError("The 'dateFormat' option is not supported") + msg = "The 'dateFormat' option is not supported" + raise ContributionsAcceptedError(msg) if timestampFormat is not None: - raise ContributionsAcceptedError("The 'timestampFormat' option is not supported") + msg = "The 'timestampFormat' option is not supported" + raise ContributionsAcceptedError(msg) if multiLine is not None: - raise ContributionsAcceptedError("The 'multiLine' option is not supported") + msg = "The 'multiLine' option is not supported" + raise ContributionsAcceptedError(msg) if allowUnquotedControlChars is not None: - raise ContributionsAcceptedError("The 'allowUnquotedControlChars' option is not supported") + msg = "The 'allowUnquotedControlChars' option is not supported" + raise ContributionsAcceptedError(msg) if lineSep is not None: - raise ContributionsAcceptedError("The 'lineSep' option is not supported") + msg = "The 'lineSep' option is not supported" + raise ContributionsAcceptedError(msg) if samplingRatio is not None: - raise ContributionsAcceptedError("The 'samplingRatio' option is not supported") + msg = "The 'samplingRatio' option is not supported" + raise ContributionsAcceptedError(msg) if dropFieldIfAllNull is not None: - raise ContributionsAcceptedError("The 'dropFieldIfAllNull' option is not supported") + msg = "The 'dropFieldIfAllNull' option is not supported" + raise ContributionsAcceptedError(msg) if encoding is not None: - raise ContributionsAcceptedError("The 'encoding' option is not supported") + msg = "The 'encoding' option is not supported" + raise ContributionsAcceptedError(msg) if locale is not None: - raise ContributionsAcceptedError("The 'locale' option is not supported") + msg = "The 'locale' option is not supported" + raise ContributionsAcceptedError(msg) if pathGlobFilter is not None: - raise ContributionsAcceptedError("The 'pathGlobFilter' option is not supported") + msg = "The 'pathGlobFilter' option is not supported" + raise ContributionsAcceptedError(msg) if recursiveFileLookup is not None: - raise ContributionsAcceptedError("The 'recursiveFileLookup' option is not supported") + msg = "The 'recursiveFileLookup' option is not supported" + raise ContributionsAcceptedError(msg) if modifiedBefore is not None: - raise ContributionsAcceptedError("The 'modifiedBefore' option is not supported") + msg = "The 'modifiedBefore' option is not supported" + raise ContributionsAcceptedError(msg) if modifiedAfter is not None: - raise ContributionsAcceptedError("The 'modifiedAfter' option is not supported") + msg = "The 'modifiedAfter' option is not supported" + raise ContributionsAcceptedError(msg) if allowNonNumericNumbers is not None: - raise ContributionsAcceptedError("The 'allowNonNumericNumbers' option is not supported") + msg = "The 'allowNonNumericNumbers' option is not supported" + raise ContributionsAcceptedError(msg) if isinstance(path, str): path = [path] diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 9d2b4b7d..55eb9855 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -731,7 +731,8 @@ def fromInternal(self, obj: T) -> T: return self.dataType.fromInternal(obj) def typeName(self) -> str: # type: ignore[override] - raise TypeError("StructField does not have typeName. Use typeName on its type explicitly instead.") + msg = "StructField does not have typeName. Use typeName on its type explicitly instead." + raise TypeError(msg) class StructType(DataType): @@ -841,7 +842,8 @@ def add( self.names.append(field.name) else: if isinstance(field, str) and data_type is None: - raise ValueError("Must specify DataType if passing name of struct_field to create.") + msg = "Must specify DataType if passing name of struct_field to create." + raise ValueError(msg) else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) @@ -866,16 +868,19 @@ def __getitem__(self, key: Union[str, int]) -> StructField: for field in self: if field.name == key: return field - raise KeyError(f"No StructField named {key}") + msg = f"No StructField named {key}" + raise KeyError(msg) elif isinstance(key, int): try: return self.fields[key] except IndexError: - raise IndexError("StructType index out of range") + msg = "StructType index out of range" + raise IndexError(msg) elif isinstance(key, slice): return StructType(self.fields[key]) else: - raise TypeError("StructType keys should be strings, integers or slices") + msg = "StructType keys should be strings, integers or slices" + raise TypeError(msg) def simpleString(self) -> str: return "struct<%s>" % (",".join(f.simpleString() for f in self)) @@ -978,12 +983,14 @@ def typeName(cls) -> str: @classmethod def sqlType(cls) -> DataType: """Underlying SQL storage type for this UDT.""" - raise NotImplementedError("UDT must implement sqlType().") + msg = "UDT must implement sqlType()." + raise NotImplementedError(msg) @classmethod def module(cls) -> str: """The Python module of the UDT.""" - raise NotImplementedError("UDT must implement module().") + msg = "UDT must implement module()." + raise NotImplementedError(msg) @classmethod def scalaUDT(cls) -> str: @@ -1013,11 +1020,13 @@ def fromInternal(self, obj: Any) -> Any: def serialize(self, obj: Any) -> Any: """Converts a user-type object into a SQL datum.""" - raise NotImplementedError("UDT must implement toInternal().") + msg = "UDT must implement toInternal()." + raise NotImplementedError(msg) def deserialize(self, datum: Any) -> Any: """Converts a SQL datum into a user-type object.""" - raise NotImplementedError("UDT must implement fromInternal().") + msg = "UDT must implement fromInternal()." + raise NotImplementedError(msg) def simpleString(self) -> str: return "udt" @@ -1126,7 +1135,8 @@ def __new__(cls, **kwargs: Any) -> "Row": ... def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": if args and kwargs: - raise ValueError("Can not use both args and kwargs to create Row") + msg = "Can not use both args and kwargs to create Row" + raise ValueError(msg) if kwargs: # create row objects row = tuple.__new__(cls, list(kwargs.values())) @@ -1163,7 +1173,8 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]: True """ if not hasattr(self, "__fields__"): - raise TypeError("Cannot convert a Row class into dict") + msg = "Cannot convert a Row class into dict" + raise TypeError(msg) if recursive: @@ -1224,7 +1235,8 @@ def __getattr__(self, item: str) -> Any: def __setattr__(self, key: Any, value: Any) -> None: if key != "__fields__": - raise RuntimeError("Row is read-only") + msg = "Row is read-only" + raise RuntimeError(msg) self.__dict__[key] = value def __reduce__( diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index 885c797f..77838103 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -18,7 +18,8 @@ class ModifiedMemoryFileSystem(MemoryFileSystem): def add_file(self, object, path): if not is_file_like(object): - raise ValueError("Can not read from a non file-like object") + msg = "Can not read from a non file-like object" + raise ValueError(msg) path = self._strip_protocol(path) if isinstance(object, TextIOBase): # Wrap this so that we can return a bytes object from 'read' diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index 59758f19..69e1e7ea 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -132,9 +132,11 @@ def _pl_tree_to_sql(tree: dict) -> str: return f"({arg_sql} IS NULL)" if func == "IsNotNull": return f"({arg_sql} IS NOT NULL)" - raise NotImplementedError(f"Boolean function not supported: {func}") + msg = f"Boolean function not supported: {func}" + raise NotImplementedError(msg) - raise NotImplementedError(f"Unsupported function type: {func_dict}") + msg = f"Unsupported function type: {func_dict}" + raise NotImplementedError(msg) if node_type == "Scalar": # Detect format: old style (dtype/value) or new style (direct type key) @@ -200,9 +202,11 @@ def _pl_tree_to_sql(tree: dict) -> str: string_val = value.get("StringOwned", value.get("String", None)) return f"'{string_val}'" - raise NotImplementedError(f"Unsupported scalar type {dtype!s}, with value {value}") + msg = f"Unsupported scalar type {dtype!s}, with value {value}" + raise NotImplementedError(msg) - raise NotImplementedError(f"Node type: {node_type} is not implemented. {subtree}") + msg = f"Node type: {node_type} is not implemented. {subtree}" + raise NotImplementedError(msg) def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame: diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index 57008fa3..b338ef6b 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -30,7 +30,8 @@ def parse_version(version: str) -> tuple[int, int, int, int, int]: """ match = VERSION_RE.match(version) if not match: - raise ValueError(f"Invalid version format: {version} (expected X.Y.Z, X.Y.Z.rcM or X.Y.Z.postN)") + msg = f"Invalid version format: {version} (expected X.Y.Z, X.Y.Z.rcM or X.Y.Z.postN)" + raise ValueError(msg) major, minor, patch, rc, post = match.groups() return int(major), int(minor), int(patch), int(post or 0), int(rc or 0) @@ -51,7 +52,8 @@ def format_version(major: int, minor: int, patch: int, post: int = 0, rc: int = """ version = f"{major}.{minor}.{patch}" if post != 0 and rc != 0: - raise ValueError("post and rc are mutually exclusive") + msg = "post and rc are mutually exclusive" + raise ValueError(msg) if post != 0: version += f".post{post}" if rc != 0: @@ -168,4 +170,5 @@ def get_git_describe(repo_path: Optional[pathlib.Path] = None, since_major=False result.check_returncode() return result.stdout.strip() except FileNotFoundError: - raise RuntimeError("git executable can't be found") + msg = "git executable can't be found" + raise RuntimeError(msg) diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index dc94eeaa..aa5e4515 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -75,7 +75,8 @@ def _in_sdist() -> bool: def _duckdb_submodule_path() -> Path: """Verify that the duckdb submodule is checked out and usable and return its path.""" if not _in_git_repository(): - raise RuntimeError("Not in a git repository, no duckdb submodule present") + msg = "Not in a git repository, no duckdb submodule present" + raise RuntimeError(msg) # search the duckdb submodule gitmodules_path = Path(".gitmodules") modules = dict() @@ -97,7 +98,8 @@ def _duckdb_submodule_path() -> Path: modules[cur_module_reponame] = cur_module_path if "duckdb" not in modules: - raise RuntimeError("DuckDB submodule missing") + msg = "DuckDB submodule missing" + raise RuntimeError(msg) duckdb_path = modules["duckdb"] # now check that the submodule is usable @@ -106,9 +108,11 @@ def _duckdb_submodule_path() -> Path: status = status.decode("ascii", "replace") for line in status.splitlines(): if line.startswith("-"): - raise RuntimeError(f"Duckdb submodule not initialized: {line}") + msg = f"Duckdb submodule not initialized: {line}" + raise RuntimeError(msg) if line.startswith("U"): - raise RuntimeError(f"Duckdb submodule has merge conflicts: {line}") + msg = f"Duckdb submodule has merge conflicts: {line}" + raise RuntimeError(msg) if line.startswith("+"): _log(f"WARNING: Duckdb submodule not clean: {line}") # all good @@ -169,7 +173,8 @@ def _skbuild_config_add( if not key_exists: config_settings[store_key] = value elif fail_if_exists: - raise RuntimeError(f"{key} already present in config and may not be overridden") + msg = f"{key} already present in config and may not be overridden" + raise RuntimeError(msg) elif key_exists_as_list and val_is_list: config_settings[store_key].extend(value) elif key_exists_as_list and val_is_str: @@ -178,9 +183,8 @@ def _skbuild_config_add( _log(f"WARNING: overriding existing value in {store_key}") config_settings[store_key] = value else: - raise RuntimeError( - f"Type mismatch: cannot set {store_key} ({type(config_settings[store_key])}) to `{value}` ({type(value)})" - ) + msg = f"Type mismatch: cannot set {store_key} ({type(config_settings[store_key])}) to `{value}` ({type(value)})" + raise RuntimeError(msg) def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str], str]]] = None) -> str: @@ -201,7 +205,8 @@ def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[ RuntimeError: If not in a git repository or DuckDB submodule issues. """ if not _in_git_repository(): - raise RuntimeError("Not in a git repository, can't create an sdist") + msg = "Not in a git repository, can't create an sdist" + raise RuntimeError(msg) submodule_path = _duckdb_submodule_path() if _FORCED_PEP440_VERSION is not None: duckdb_version = pep440_to_git_tag(strip_post_from_version(_FORCED_PEP440_VERSION)) @@ -237,7 +242,8 @@ def build_wheel( duckdb_version = None if not _in_git_repository(): if not _in_sdist(): - raise RuntimeError("Not in a git repository nor in an sdist, can't build a wheel") + msg = "Not in a git repository nor in an sdist, can't build a wheel" + raise RuntimeError(msg) _log("Building duckdb wheel from sdist. Reading duckdb version from file.") config_settings = config_settings or {} duckdb_version = _read_duckdb_long_version() diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index b45cf1a1..428e07dd 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -95,15 +95,18 @@ def setup_logging(verbose: bool = False) -> None: def validate_username(value: str) -> str: """Validate and sanitize username input.""" if not value or not value.strip(): - raise argparse.ArgumentTypeError("Username cannot be empty") + msg = "Username cannot be empty" + raise argparse.ArgumentTypeError(msg) username = value.strip() if len(username) > 100: # Reasonable limit - raise argparse.ArgumentTypeError("Username too long (max 100 characters)") + msg = "Username too long (max 100 characters)" + raise argparse.ArgumentTypeError(msg) # Basic validation - PyPI usernames are alphanumeric with limited special chars if not re.match(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$", username): - raise argparse.ArgumentTypeError("Invalid username format") + msg = "Invalid username format" + raise argparse.ArgumentTypeError(msg) return username @@ -140,9 +143,11 @@ def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: otp = os.getenv("PYPI_CLEANUP_OTP") if not password: - raise ValidationError("PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode") + msg = "PYPI_CLEANUP_PASSWORD environment variable is required when not in dry-run mode" + raise ValidationError(msg) if not otp: - raise ValidationError("PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode") + msg = "PYPI_CLEANUP_OTP environment variable is required when not in dry-run mode" + raise ValidationError(msg) return password, otp @@ -150,10 +155,12 @@ def load_credentials(dry_run: bool) -> tuple[Optional[str], Optional[str]]: def validate_arguments(args: argparse.Namespace) -> None: """Validate parsed arguments.""" if not args.dry_run and not args.username: - raise ValidationError("--username is required when not in dry-run mode") + msg = "--username is required when not in dry-run mode" + raise ValidationError(msg) if args.max_nightlies < 0: - raise ValidationError("--max-nightlies must be non-negative") + msg = "--max-nightlies must be non-negative" + raise ValidationError(msg) class CsrfParser(HTMLParser): @@ -287,7 +294,8 @@ def _fetch_released_versions(self, http_session: Session) -> set[str]: logging.debug(f"Found {len(versions)} releases with files") return versions except RequestException as e: - raise PyPICleanupError(f"Failed to fetch package information for '{self._package}': {e}") from e + msg = f"Failed to fetch package information for '{self._package}': {e}" + raise PyPICleanupError(msg) from e def _is_stable_release_version(self, version: str) -> bool: """Determine whether a version string denotes a stable release.""" @@ -305,14 +313,16 @@ def _parse_rc_version(self, version: str) -> str: """Parse a rc version string to determine the base version.""" match = self._rc_version_pattern.match(version) if not match: - raise PyPICleanupError(f"Invalid rc version '{version}'") + msg = f"Invalid rc version '{version}'" + raise PyPICleanupError(msg) return match.group("version") if match else None def _parse_dev_version(self, version: str) -> tuple[str, int]: """Parse a dev version string to determine the base version and dev version id.""" match = self._dev_version_pattern.match(version) if not match: - raise PyPICleanupError(f"Invalid dev version '{version}'") + msg = f"Invalid dev version '{version}'" + raise PyPICleanupError(msg) return match.group("version"), int(match.group("dev_id")) def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: @@ -363,15 +373,17 @@ def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: # Final safety checks if versions_to_delete == versions: - raise PyPICleanupError( + msg = ( f"Safety check failed: cleanup would delete ALL versions of '{self._package}'. " "This would make the package permanently inaccessible. Aborting." ) + raise PyPICleanupError(msg) if len(versions_to_delete.intersection(stable_versions)) > 0: - raise PyPICleanupError( + msg = ( f"Safety check failed: cleanup would delete one or more stable versions of '{self._package}'. " f"A regexp might be broken? (would delete {versions_to_delete.intersection(stable_versions)})" ) + raise PyPICleanupError(msg) unknown_versions = versions.difference(stable_versions).difference(rc_versions).difference(dev_versions) if unknown_versions: logging.warning(f"Found version string(s) in an unsupported format: {unknown_versions}") @@ -381,7 +393,8 @@ def _determine_versions_to_delete(self, versions: set[str]) -> set[str]: def _authenticate(self, http_session: Session) -> None: """Authenticate with PyPI.""" if not self._username or not self._password: - raise AuthenticationError("Username and password are required for authentication") + msg = "Username and password are required for authentication" + raise AuthenticationError(msg) logging.info(f"Authenticating user '{self._username}' with PyPI") @@ -397,7 +410,8 @@ def _authenticate(self, http_session: Session) -> None: logging.info("Authentication successful") except RequestException as e: - raise AuthenticationError(f"Network error during authentication: {e}") from e + msg = f"Network error during authentication: {e}" + raise AuthenticationError(msg) from e def _get_csrf_token(self, http_session: Session, form_action: str) -> str: """Extract CSRF token from a form page.""" @@ -406,7 +420,8 @@ def _get_csrf_token(self, http_session: Session, form_action: str) -> str: parser = CsrfParser(form_action) parser.feed(resp.text) if not parser.csrf: - raise AuthenticationError(f"No CSRF token found in {form_action}") + msg = f"No CSRF token found in {form_action}" + raise AuthenticationError(msg) return parser.csrf def _perform_login(self, http_session: Session) -> requests.Response: @@ -425,14 +440,16 @@ def _perform_login(self, http_session: Session) -> requests.Response: # Check if login failed (redirected back to login page) if response.url == f"{self._index_url}/account/login/": - raise AuthenticationError(f"Login failed for user '{self._username}' - check credentials") + msg = f"Login failed for user '{self._username}' - check credentials" + raise AuthenticationError(msg) return response def _handle_two_factor_auth(self, http_session: Session, response: requests.Response) -> None: """Handle two-factor authentication.""" if not self._otp: - raise AuthenticationError("Two-factor authentication required but no OTP secret provided") + msg = "Two-factor authentication required but no OTP secret provided" + raise AuthenticationError(msg) two_factor_url = response.url form_action = two_factor_url[len(self._index_url) :] @@ -462,11 +479,13 @@ def _handle_two_factor_auth(self, http_session: Session, response: requests.Resp except RequestException as e: if attempt == _LOGIN_RETRY_ATTEMPTS - 1: - raise AuthenticationError(f"Network error during 2FA: {e}") from e + msg = f"Network error during 2FA: {e}" + raise AuthenticationError(msg) from e logging.debug(f"Network error during 2FA attempt {attempt + 1}, retrying...") time.sleep(_LOGIN_RETRY_DELAY) - raise AuthenticationError("Two-factor authentication failed after all attempts") + msg = "Two-factor authentication failed after all attempts" + raise AuthenticationError(msg) def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) -> None: """Delete the specified package versions.""" @@ -483,15 +502,15 @@ def _delete_versions(self, http_session: Session, versions_to_delete: set[str]) failed_deletions.append(version) if failed_deletions: - raise PyPICleanupError( - f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" - ) + msg = f"Failed to delete {len(failed_deletions)}/{len(versions_to_delete)} versions: {failed_deletions}" + raise PyPICleanupError(msg) def _delete_single_version(self, http_session: Session, version: str) -> None: """Delete a single package version.""" # Safety check if not self._is_dev_version(version) or self._is_rc_version(version): - raise PyPICleanupError(f"Refusing to delete non-[dev|rc] version: {version}") + msg = f"Refusing to delete non-[dev|rc] version: {version}" + raise PyPICleanupError(msg) logging.debug(f"Deleting {self._package} version {version}") diff --git a/duckdb_packaging/setuptools_scm_version.py b/duckdb_packaging/setuptools_scm_version.py index 2ff79f80..5b0c5383 100644 --- a/duckdb_packaging/setuptools_scm_version.py +++ b/duckdb_packaging/setuptools_scm_version.py @@ -40,12 +40,14 @@ def version_scheme(version: Any) -> str: # Handle case where tag is None if version.tag is None: - raise ValueError("Need a valid version. Did you set a fallback_version in pyproject.toml?") + msg = "Need a valid version. Did you set a fallback_version in pyproject.toml?" + raise ValueError(msg) try: return _bump_version(str(version.tag), version.distance, version.dirty) except Exception as e: - raise RuntimeError(f"Failed to bump version: {e}") + msg = f"Failed to bump version: {e}" + raise RuntimeError(msg) def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: @@ -54,7 +56,8 @@ def _bump_version(base_version: str, distance: int, dirty: bool = False) -> str: try: major, minor, patch, post, rc = parse_version(base_version) except ValueError: - raise ValueError(f"Incorrect version format: {base_version} (expected X.Y.Z or X.Y.Z.postN)") + msg = f"Incorrect version format: {base_version} (expected X.Y.Z or X.Y.Z.postN)" + raise ValueError(msg) # If we're exactly on a tag (distance = 0, dirty=False) distance = int(distance or 0) @@ -110,7 +113,8 @@ def _git_describe_override_to_pep_440(override_value: str) -> str: match = describe_pattern.match(override_value) if not match: - raise ValueError(f"Invalid git describe override: {override_value}") + msg = f"Invalid git describe override: {override_value}" + raise ValueError(msg) version, distance, commit_hash = match.groups() diff --git a/scripts/generate_connection_methods.py b/scripts/generate_connection_methods.py index 51f667f6..a3bf36ad 100644 --- a/scripts/generate_connection_methods.py +++ b/scripts/generate_connection_methods.py @@ -37,15 +37,18 @@ def generate(): for i, line in enumerate(source_code): if line.startswith(INITIALIZE_METHOD): if start_index != -1: - raise ValueError("Encountered the INITIALIZE_METHOD a second time, quitting!") + msg = "Encountered the INITIALIZE_METHOD a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] @@ -128,5 +131,6 @@ def create_definition(name, method) -> str: if __name__ == "__main__": - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index 9b1be9aa..910e657a 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -20,15 +20,18 @@ def generate(): for i, line in enumerate(source_code): if line.startswith(START_MARKER): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] @@ -94,5 +97,6 @@ def create_definition(name, method, overloaded: bool) -> str: if __name__ == "__main__": - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_connection_wrapper_methods.py b/scripts/generate_connection_wrapper_methods.py index d2ef0bba..743d0224 100644 --- a/scripts/generate_connection_wrapper_methods.py +++ b/scripts/generate_connection_wrapper_methods.py @@ -75,15 +75,18 @@ def remove_section(content, start_marker, end_marker) -> tuple[list[str], list[s for i, line in enumerate(content): if line.startswith(start_marker): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(end_marker): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = content[: start_index + 1] end_section = content[end_index:] diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 3b3b8c93..4066d0ea 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -21,15 +21,18 @@ def generate(): for i, line in enumerate(source_code): if line.startswith(START_MARKER): if start_index != -1: - raise ValueError("Encountered the START_MARKER a second time, quitting!") + msg = "Encountered the START_MARKER a second time, quitting!" + raise ValueError(msg) start_index = i elif line.startswith(END_MARKER): if end_index != -1: - raise ValueError("Encountered the END_MARKER a second time, quitting!") + msg = "Encountered the END_MARKER a second time, quitting!" + raise ValueError(msg) end_index = i if start_index == -1 or end_index == -1: - raise ValueError("Couldn't find start or end marker in source file") + msg = "Couldn't find start or end marker in source file" + raise ValueError(msg) start_section = source_code[: start_index + 1] end_section = source_code[end_index:] @@ -118,5 +121,6 @@ def create_definition(name, method, overloaded: bool) -> str: if __name__ == "__main__": - raise ValueError("Please use 'generate_connection_code.py' instead of running the individual script(s)") + msg = "Please use 'generate_connection_code.py' instead of running the individual script(s)" + raise ValueError(msg) # generate() diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index 34cd84b6..7b43d175 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -107,7 +107,8 @@ def add_or_get_module(self, module_name: str) -> ImportCacheModule: def get_module(self, module_name: str) -> ImportCacheModule: if module_name not in self.modules: - raise ValueError("Import the module before registering its attributes!") + msg = "Import the module before registering its attributes!" + raise ValueError(msg) return self.modules[module_name] def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttribute]: diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 77281d54..db875566 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -179,16 +179,20 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, percentage_specified = start_offset_percentage is not None or end_offset_percentage is not None if index_specified and percentage_specified: - raise ValueError("You can only specify either start/end offsets or start/end offset percentages, not both") + msg = "You can only specify either start/end offsets or start/end offset percentages, not both" + raise ValueError(msg) if start_offset is not None and start_offset < 0: - raise ValueError("--start-offset must be a non-negative integer") + msg = "--start-offset must be a non-negative integer" + raise ValueError(msg) if start_offset_percentage is not None and (start_offset_percentage < 0 or start_offset_percentage > 100): - raise ValueError("--start-offset-percentage must be between 0 and 100") + msg = "--start-offset-percentage must be between 0 and 100" + raise ValueError(msg) if end_offset_percentage is not None and (end_offset_percentage < 0 or end_offset_percentage > 100): - raise ValueError("--end-offset-percentage must be between 0 and 100") + msg = "--end-offset-percentage must be between 0 and 100" + raise ValueError(msg) if start_offset is None: if start_offset_percentage is not None: @@ -197,9 +201,8 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, start_offset = 0 if end_offset is not None and end_offset < start_offset: - raise ValueError( - f"--end-offset ({end_offset}) must be greater than or equal to the start offset ({start_offset})" - ) + msg = f"--end-offset ({end_offset}) must be greater than or equal to the start offset ({start_offset})" + raise ValueError(msg) if end_offset is None: if end_offset_percentage is not None: diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index dc5d1f59..c8b935cb 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -21,7 +21,8 @@ def _import(handle): return pa.RecordBatchReader._import_from_c(handle.address) elif isinstance(handle, adbc_driver_manager.ArrowSchemaHandle): return pa.Schema._import_from_c(handle.address) - raise NotImplementedError(f"Importing {handle!r}") + msg = f"Importing {handle!r}" + raise NotImplementedError(msg) def _bind(stmt, batch): diff --git a/tests/fast/udf/test_scalar.py b/tests/fast/udf/test_scalar.py index b7f4e343..57160d75 100644 --- a/tests/fast/udf/test_scalar.py +++ b/tests/fast/udf/test_scalar.py @@ -133,7 +133,8 @@ def no_op(x): @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_exceptions(self, udf_type): def raises_exception(x): - raise AttributeError("error") + msg = "error" + raise AttributeError(msg) con = duckdb.connect() con.create_function("raises", raises_exception, [BIGINT], BIGINT, type=udf_type) diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index 984a1f8c..28d86455 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -47,7 +47,8 @@ def test_varargs(self): def variable_args(*args): # We return a chunked array here, but internally we convert this into a Table if len(args) == 0: - raise ValueError("Expected at least one argument") + msg = "Expected at least one argument" + raise ValueError(msg) for item in args: return item From 86d534b56aa93b0bae2fc3000735c48524ee534c Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:53:02 +0200 Subject: [PATCH 218/472] Ruff D: Docstring fixes --- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/column.py | 10 ++++---- duckdb/experimental/spark/sql/dataframe.py | 2 +- duckdb/experimental/spark/sql/functions.py | 28 +++++++++++----------- duckdb/experimental/spark/sql/types.py | 8 +++---- duckdb/udf.py | 2 +- sqllogic/conftest.py | 2 +- tests/fast/spark/test_spark_dataframe.py | 2 +- tests/fast/test_json_logging.py | 2 +- tests/fast/test_pypi_cleanup.py | 2 +- tests/slow/test_h2oai_arrow.py | 2 +- 11 files changed, 31 insertions(+), 31 deletions(-) diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 791f7090..1c2ad9a6 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,7 +1,7 @@ class ContributionsAcceptedError(NotImplementedError): """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, - feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb + feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. """ def __init__(self, message=None) -> None: diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 6cc92523..dd676846 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -29,7 +29,7 @@ def _unary_op( name: str, doc: str = "unary operator", ) -> Callable[["Column"], "Column"]: - """Create a method for given unary operator""" + """Create a method for given unary operator.""" def _(self: "Column") -> "Column": # Call the function identified by 'name' on the internal Expression object @@ -44,7 +44,7 @@ def _bin_op( name: str, doc: str = "binary operator", ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]: - """Create a method for given binary operator""" + """Create a method for given binary operator.""" def _( self: "Column", @@ -62,7 +62,7 @@ def _bin_func( name: str, doc: str = "binary function", ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"]: - """Create a function expression for the given binary function""" + """Create a function expression for the given binary function.""" def _( self: "Column", @@ -245,14 +245,14 @@ def __eq__( # type: ignore[override] self, other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"], ) -> "Column": - """Binary function""" + """Binary function.""" return Column(self.expr == (_get_expr(other))) def __ne__( # type: ignore[override] self, other: object, ) -> "Column": - """Binary function""" + """Binary function.""" return Column(self.expr != (_get_expr(other))) __lt__ = _bin_op("__lt__") diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 57c8cd03..16d54f0b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -845,7 +845,7 @@ def limit(self, num: int) -> "DataFrame": return DataFrame(rel, self.session) def __contains__(self, item: str) -> bool: - """Check if the :class:`DataFrame` contains a column by the name of `item`""" + """Check if the :class:`DataFrame` contains a column by the name of `item`.""" return item in self.relation @property diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index fddcd4c5..a319ec13 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -164,7 +164,7 @@ def _to_column_expr(col: ColumnOrName) -> Expression: def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: - r"""Replace all substrings of the specified string value that match regexp with rep. + """Replace all substrings of the specified string value that match regexp with rep. .. versionadded:: 1.5.0 @@ -1487,7 +1487,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: - """.. versionadded:: 1.3.0 + """.. versionadded:: 1.3.0. .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -2059,7 +2059,7 @@ def cbrt(col: "ColumnOrName") -> Column: def char(col: "ColumnOrName") -> Column: """Returns the ASCII character having the binary equivalent to `col`. If col is larger than 256 the - result is equivalent to char(col % 256) + result is equivalent to char(col % 256). .. versionadded:: 3.5.0 @@ -2373,7 +2373,7 @@ def rand(seed: Optional[int] = None) -> Column: def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. + """Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. versionadded:: 3.5.0 @@ -2425,7 +2425,7 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched + """Returns a count of the number of times that the Java regex pattern `regexp` is matched in the string `str`. .. versionadded:: 3.5.0 @@ -2456,7 +2456,7 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: - r"""Extract a specific group matched by the Java regex `regexp`, from the specified string column. + """Extract a specific group matched by the Java regex `regexp`, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned. .. versionadded:: 1.5.0 @@ -2496,7 +2496,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: - r"""Extract all strings in the `str` that match the Java regex `regexp` + """Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. .. versionadded:: 3.5.0 @@ -2535,7 +2535,7 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. + """Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. versionadded:: 3.5.0 @@ -2587,7 +2587,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - r"""Returns the substring that matches the Java regex `regexp` within the string `str`. + """Returns the substring that matches the Java regex `regexp` within the string `str`. If the regular expression is not found, the result is null. .. versionadded:: 3.5.0 @@ -3996,7 +3996,7 @@ def month(col: "ColumnOrName") -> Column: def dayofweek(col: "ColumnOrName") -> Column: """Extract the day of the week of a given date/timestamp as integer. - Ranges from 1 for a Sunday through to 7 for a Saturday + Ranges from 1 for a Sunday through to 7 for a Saturday. .. versionadded:: 2.3.0 @@ -4187,7 +4187,7 @@ def second(col: "ColumnOrName") -> Column: def weekofyear(col: "ColumnOrName") -> Column: """Extract the week number of a given date as integer. A week is considered to start on a Monday and week 1 is the first week with more than 3 days, - as defined by ISO 8601 + as defined by ISO 8601. .. versionadded:: 1.5.0 @@ -4609,7 +4609,7 @@ def atan(col: "ColumnOrName") -> Column: def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) -> Column: - """.. versionadded:: 1.4.0 + """.. versionadded:: 1.4.0. .. versionchanged:: 3.4.0 Supports Spark Connect. @@ -5577,7 +5577,7 @@ def var_samp(col: "ColumnOrName") -> Column: def variance(col: "ColumnOrName") -> Column: - """Aggregate function: alias for var_samp + """Aggregate function: alias for var_samp. .. versionadded:: 1.6.0 @@ -6242,7 +6242,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: def expr(str: str) -> Column: - """Parses the expression string into the column that it represents + """Parses the expression string into the column that it represents. .. versionadded:: 1.5.0 diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 55eb9855..4418f495 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -113,7 +113,7 @@ def fromInternal(self, obj: Any) -> Any: # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle class DataTypeSingleton(type): - """Metaclass for DataType""" + """Metaclass for DataType.""" _instances: ClassVar[dict[type["DataTypeSingleton"], "DataTypeSingleton"]] = {} @@ -855,7 +855,7 @@ def add( return self def __iter__(self) -> Iterator[StructField]: - """Iterate the fields""" + """Iterate the fields.""" return iter(self.fields) def __len__(self) -> int: @@ -1147,7 +1147,7 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": return tuple.__new__(cls, args) def asDict(self, recursive: bool = False) -> dict[str, Any]: - """Return as a dict + """Return as a dict. Parameters ---------- @@ -1200,7 +1200,7 @@ def __contains__(self, item: Any) -> bool: # let object acts like class def __call__(self, *args: Any) -> "Row": - """Create new Row object""" + """Create new Row object.""" if len(args) > len(self): raise ValueError( "Can not create Row with fields %s, expected %d values but got %s" % (self, len(self), args) diff --git a/duckdb/udf.py b/duckdb/udf.py index 21d6d53f..1357dee5 100644 --- a/duckdb/udf.py +++ b/duckdb/udf.py @@ -1,5 +1,5 @@ def vectorized(func): - """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output""" + """Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output.""" import types from inspect import signature diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index db875566..40759e9c 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -271,7 +271,7 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config def pytest_runtest_setup(item: pytest.Item): - """Show the test index after the test name""" + """Show the test index after the test name.""" def get_from_tuple_list(tuples, key): for t in tuples: diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py index 26006952..3fd78090 100644 --- a/tests/fast/spark/test_spark_dataframe.py +++ b/tests/fast/spark/test_spark_dataframe.py @@ -339,7 +339,7 @@ def test_df_columns(self, spark): assert "OtherInfo" in updatedDF.columns def test_array_and_map_type(self, spark): - """Array & Map""" + """Array & Map.""" arrayStructureSchema = StructType( [ StructField( diff --git a/tests/fast/test_json_logging.py b/tests/fast/test_json_logging.py index b29ea7bf..9e9908ea 100644 --- a/tests/fast/test_json_logging.py +++ b/tests/fast/test_json_logging.py @@ -6,7 +6,7 @@ def _parse_json_func(error_prefix: str): - """Helper to check that the error message is indeed parsable json""" + """Helper to check that the error message is indeed parsable json.""" def parse_func(exception): msg = exception.args[0] diff --git a/tests/fast/test_pypi_cleanup.py b/tests/fast/test_pypi_cleanup.py index 0e0439ce..74b1266f 100644 --- a/tests/fast/test_pypi_cleanup.py +++ b/tests/fast/test_pypi_cleanup.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Unit tests for pypi_cleanup.py +"""Unit tests for pypi_cleanup.py. Run with: python -m pytest test_pypi_cleanup.py -v """ diff --git a/tests/slow/test_h2oai_arrow.py b/tests/slow/test_h2oai_arrow.py index d0dbc2fe..35d8b1c7 100644 --- a/tests/slow/test_h2oai_arrow.py +++ b/tests/slow/test_h2oai_arrow.py @@ -197,7 +197,7 @@ def test_join(self, threads, function, large_data): @fixture(scope="module") def arrow_dataset_register(): - """Single fixture to download files and register them on the given connection""" + """Single fixture to download files and register them on the given connection.""" session = requests.Session() retries = urllib3_util.Retry( allowed_methods={"GET"}, # only retry on GETs (all we do) From afe4e69c574ade09d98c9f405183db06bbde28ba Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Tue, 9 Sep 2025 20:55:56 +0200 Subject: [PATCH 219/472] Ruff D301: Make docstring raw if they contain backslashes --- duckdb/experimental/spark/sql/functions.py | 26 +++++++++++----------- duckdb/experimental/spark/sql/types.py | 4 ++-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index a319ec13..92631ee8 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -108,7 +108,7 @@ def struct(*cols: Column) -> Column: def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]]) -> Column: - """Creates a new array column. + r"""Creates a new array column. .. versionadded:: 1.4.0 @@ -164,7 +164,7 @@ def _to_column_expr(col: ColumnOrName) -> Expression: def regexp_replace(str: "ColumnOrName", pattern: str, replacement: str) -> Column: - """Replace all substrings of the specified string value that match regexp with rep. + r"""Replace all substrings of the specified string value that match regexp with rep. .. versionadded:: 1.5.0 @@ -713,7 +713,7 @@ def asin(col: "ColumnOrName") -> Column: def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """Returns true if str matches `pattern` with `escape`, + r"""Returns true if str matches `pattern` with `escape`, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -750,7 +750,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Column"] = None) -> Column: - """Returns true if str matches `pattern` with `escape` case-insensitively, + r"""Returns true if str matches `pattern` with `escape` case-insensitively, null if any arguments are null, false otherwise. The default escape character is the '\'. @@ -2264,7 +2264,7 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: - """Formats the arguments in printf-style and returns the result as a string column. + r"""Formats the arguments in printf-style and returns the result as a string column. .. versionadded:: 3.5.0 @@ -2373,7 +2373,7 @@ def rand(seed: Optional[int] = None) -> Column: def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns true if `str` matches the Java regex `regexp`, or false otherwise. + r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. versionadded:: 3.5.0 @@ -2425,7 +2425,7 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns a count of the number of times that the Java regex pattern `regexp` is matched + r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched in the string `str`. .. versionadded:: 3.5.0 @@ -2456,7 +2456,7 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: - """Extract a specific group matched by the Java regex `regexp`, from the specified string column. + r"""Extract a specific group matched by the Java regex `regexp`, from the specified string column. If the regex did not match, or the specified group did not match, an empty string is returned. .. versionadded:: 1.5.0 @@ -2496,7 +2496,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: - """Extract all strings in the `str` that match the Java regex `regexp` + r"""Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. .. versionadded:: 3.5.0 @@ -2535,7 +2535,7 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns true if `str` matches the Java regex `regexp`, or false otherwise. + r"""Returns true if `str` matches the Java regex `regexp`, or false otherwise. .. versionadded:: 3.5.0 @@ -2587,7 +2587,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: - """Returns the substring that matches the Java regex `regexp` within the string `str`. + r"""Returns the substring that matches the Java regex `regexp` within the string `str`. If the regular expression is not found, the result is null. .. versionadded:: 3.5.0 @@ -4274,7 +4274,7 @@ def acos(col: "ColumnOrName") -> Column: def call_function(funcName: str, *cols: "ColumnOrName") -> Column: - """Call a SQL function. + r"""Call a SQL function. .. versionadded:: 3.5.0 @@ -4851,7 +4851,7 @@ def initcap(col: "ColumnOrName") -> Column: def octet_length(col: "ColumnOrName") -> Column: - """Calculates the byte length for the specified string column. + r"""Calculates the byte length for the specified string column. .. versionadded:: 3.3.0 diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 4418f495..fa961eb1 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -736,7 +736,7 @@ def typeName(self) -> str: # type: ignore[override] class StructType(DataType): - """Struct type, consisting of a list of :class:`StructField`. + r"""Struct type, consisting of a list of :class:`StructField`. This is the data type representing a :class:`Row`. @@ -798,7 +798,7 @@ def add( nullable: bool = True, metadata: Optional[dict[str, Any]] = None, ) -> "StructType": - """Construct a :class:`StructType` by adding new elements to it, to define the schema. + r"""Construct a :class:`StructType` by adding new elements to it, to define the schema. The method accepts either: a) A single parameter which is a :class:`StructField` object. From 42ff0876bba2464138757913671ff8dc7839d9d0 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:08:47 +0200 Subject: [PATCH 220/472] Fix testfixture yield --- tests/fast/test_expression.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/fast/test_expression.py b/tests/fast/test_expression.py index 049a2a5c..c7207338 100644 --- a/tests/fast/test_expression.py +++ b/tests/fast/test_expression.py @@ -36,7 +36,8 @@ def filter_rel(): ) tbl(a, b) """ ) - return rel + yield rel + con.close() class TestExpression: From 6396d4c131eb598675a3c3b871eef04d730c22ba Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:27:07 +0200 Subject: [PATCH 221/472] Ruff config: disable docstring checks for tests --- pyproject.toml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 811d9c1b..02e85f5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -349,9 +349,6 @@ select = [ ] ignore = [] -[tool.ruff.lint.pycodestyle] -max-doc-length = 88 - [tool.ruff.lint.pydocstyle] convention = "google" @@ -361,6 +358,12 @@ ban-relative-imports = "all" [tool.ruff.lint.flake8-type-checking] strict = true +[tool.ruff.lint.per-file-ignores] +"tests/**.py" = [ + # No need for package, module, class, function, init etc docstrings in tests + 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107' +] + [tool.ruff.format] docstring-code-format = true docstring-code-line-length = 88 From b0c9d5afbb5e6eb168c1b45f456254d177f995a6 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:28:29 +0200 Subject: [PATCH 222/472] Ruff noqa D205 on tests: no need to check newline between summary and description in tests --- tests/fast/numpy/test_numpy_new_path.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fast/numpy/test_numpy_new_path.py b/tests/fast/numpy/test_numpy_new_path.py index d95c93d1..272d5e45 100644 --- a/tests/fast/numpy/test_numpy_new_path.py +++ b/tests/fast/numpy/test_numpy_new_path.py @@ -1,6 +1,6 @@ """The support for scaning over numpy arrays reuses many codes for pandas. Therefore, we only test the new codes and exec paths. -""" +""" # noqa: D205 from datetime import timedelta From b9af3d7b01d2020090fb17043d0d358ea08f6d37 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:30:34 +0200 Subject: [PATCH 223/472] Ruff noqa D205: dont need the newline in docstrings for existing code --- duckdb/experimental/spark/exception.py | 2 +- duckdb/experimental/spark/sql/column.py | 4 +- duckdb/experimental/spark/sql/dataframe.py | 18 +-- duckdb/experimental/spark/sql/functions.py | 154 ++++++++++----------- duckdb/experimental/spark/sql/group.py | 2 +- duckdb/experimental/spark/sql/types.py | 8 +- sqllogic/conftest.py | 4 +- 7 files changed, 96 insertions(+), 96 deletions(-) diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 1c2ad9a6..24c4f291 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -2,7 +2,7 @@ class ContributionsAcceptedError(NotImplementedError): """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. - """ + """ # noqa: D205 def __init__(self, message=None) -> None: doc = self.__class__.__doc__ diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index dd676846..ea6cd8a8 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -165,7 +165,7 @@ def __getitem__(self, k: Any) -> "Column": +------------------+------+ | abc| value| +------------------+------+ - """ + """ # noqa: D205 if isinstance(k, slice): raise ContributionsAcceptedError # if k.step is not None: @@ -199,7 +199,7 @@ def __getattr__(self, item: Any) -> "Column": +------+ | value| +------+ - """ + """ # noqa: D205 if item.startswith("__"): msg = "Can not access __ (dunder) method" raise AttributeError(msg) diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 16d54f0b..99da92ec 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -172,7 +172,7 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": | 2|Alice| 4| 5| | 5| Bob| 7| 8| +---+-----+----+----+ - """ + """ # noqa: D205 # Below code is to help enable kwargs in future. assert len(colsMap) == 1 colsMap = colsMap[0] # type: ignore[assignment] @@ -250,7 +250,7 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame": | 2|Alice| 4| 5| | 5| Bob| 7| 8| +---+-----+----+----+ - """ + """ # noqa: D205 if not isinstance(colsMap, dict): raise PySparkTypeError( error_class="NOT_DICT", @@ -974,7 +974,7 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] | Bob| 2| 2| | Bob| 5| 1| +-----+---+-----+ - """ + """ # noqa: D205 from .group import GroupedData, Grouping if len(cols) == 1 and isinstance(cols[0], list): @@ -1034,7 +1034,7 @@ def union(self, other: "DataFrame") -> "DataFrame": | 1| 2| 3| | 1| 2| 3| +----+----+----+ - """ + """ # noqa: D205 return DataFrame(self.relation.union(other.relation), self.session) unionAll = union @@ -1094,7 +1094,7 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> | 1| 2| 3|NULL| |NULL| 4| 5| 6| +----+----+----+----+ - """ + """ # noqa: D205 if allowMissingColumns: cols = [] for col in self.relation.columns: @@ -1144,7 +1144,7 @@ def intersect(self, other: "DataFrame") -> "DataFrame": | b| 3| | a| 1| +---+---+ - """ + """ # noqa: D205 return self.intersectAll(other).drop_duplicates() def intersectAll(self, other: "DataFrame") -> "DataFrame": @@ -1181,7 +1181,7 @@ def intersectAll(self, other: "DataFrame") -> "DataFrame": | a| 1| | b| 3| +---+---+ - """ + """ # noqa: D205 return DataFrame(self.relation.intersect(other.relation), self.session) def exceptAll(self, other: "DataFrame") -> "DataFrame": @@ -1221,7 +1221,7 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": | c| 4| +---+---+ - """ + """ # noqa: D205 return DataFrame(self.relation.except_(other.relation), self.session) def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": @@ -1275,7 +1275,7 @@ def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": +-----+---+------+ |Alice| 5| 80| +-----+---+------+ - """ + """ # noqa: D205 if subset: rn_col = f"tmp_col_{uuid.uuid1().hex}" subset_str = ", ".join([f'"{c}"' for c in subset]) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 92631ee8..7ae923f4 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -25,7 +25,7 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: """Invokes n-ary JVM function identified by name and wraps the result with :class:`~pyspark.sql.Column`. - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return _invoke_function(name, *cols) @@ -211,7 +211,7 @@ def slice(x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["C >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["x"]) >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect() [Row(sliced=[2, 3]), Row(sliced=[5])] - """ + """ # noqa: D205 start = ConstantExpression(start) if isinstance(start, int) else _to_column_expr(start) length = ConstantExpression(length) if isinstance(length, int) else _to_column_expr(length) @@ -302,7 +302,7 @@ def asc_nulls_first(col: "ColumnOrName") -> Column: | 1| Bob| +---+-----+ - """ + """ # noqa: D205 return asc(col).nulls_first() @@ -337,7 +337,7 @@ def asc_nulls_last(col: "ColumnOrName") -> Column: | 0| NULL| +---+-----+ - """ + """ # noqa: D205 return asc(col).nulls_last() @@ -408,7 +408,7 @@ def desc_nulls_first(col: "ColumnOrName") -> Column: | 2|Alice| +---+-----+ - """ + """ # noqa: D205 return desc(col).nulls_first() @@ -443,7 +443,7 @@ def desc_nulls_last(col: "ColumnOrName") -> Column: | 0| NULL| +---+-----+ - """ + """ # noqa: D205 return desc(col).nulls_last() @@ -473,7 +473,7 @@ def left(str: "ColumnOrName", len: "ColumnOrName") -> Column: ... ) >>> df.select(left(df.a, df.b).alias("r")).collect() [Row(r='Spa')] - """ + """ # noqa: D205 len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( @@ -508,7 +508,7 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: ... ) >>> df.select(right(df.a, df.b).alias("r")).collect() [Row(r='SQL')] - """ + """ # noqa: D205 len = _to_column_expr(len) return Column( CaseExpression(len <= ConstantExpression(0), ConstantExpression("")).otherwise( @@ -741,7 +741,7 @@ def like(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Col ... ) >>> df.select(like(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] - """ + """ # noqa: D205 if escapeChar is None: escapeChar = ConstantExpression("\\") else: @@ -778,7 +778,7 @@ def ilike(str: "ColumnOrName", pattern: "ColumnOrName", escapeChar: Optional["Co ... ) >>> df.select(ilike(df.a, df.b, lit("/")).alias("r")).collect() [Row(r=True)] - """ + """ # noqa: D205 if escapeChar is None: escapeChar = ConstantExpression("\\") else: @@ -872,7 +872,7 @@ def array_append(col: "ColumnOrName", value: Any) -> Column: [Row(array_append(c1, c2)=['b', 'a', 'c', 'c'])] >>> df.select(array_append(df.c1, "x")).collect() [Row(array_append(c1, x)=['b', 'a', 'c', 'x'])] - """ + """ # noqa: D205 return _invoke_function("list_append", _to_column_expr(col), _get_expr(value)) @@ -912,7 +912,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An [Row(data=['a', 'd', 'b', 'c']), Row(data=['c', 'b', 'd', 'a'])] >>> df.select(array_insert(df.data, 5, "hello").alias("data")).collect() [Row(data=['a', 'b', 'c', None, 'hello']), Row(data=['c', 'b', 'a', None, 'hello'])] - """ + """ # noqa: D205 pos = _get_expr(pos) arr = _to_column_expr(arr) # Depending on if the position is positive or not, we need to interpret it differently. @@ -992,7 +992,7 @@ def array_contains(col: "ColumnOrName", value: Any) -> Column: [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] >>> df.select(array_contains(df.data, lit("a"))).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] - """ + """ # noqa: D205 value = _get_expr(value) return _invoke_function("array_contains", _to_column_expr(col), value) @@ -1051,7 +1051,7 @@ def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_intersect(df.c1, df.c2)).collect() [Row(array_intersect(c1, c2)=['a', 'c'])] - """ + """ # noqa: D205 return _invoke_function_over_columns("array_intersect", col1, col2) @@ -1082,7 +1082,7 @@ def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) >>> df.select(array_union(df.c1, df.c2)).collect() [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])] - """ + """ # noqa: D205 return _invoke_function_over_columns("array_distinct", _invoke_function_over_columns("array_concat", col1, col2)) @@ -1265,7 +1265,7 @@ def mean(col: "ColumnOrName") -> Column: +-------+ | 4.5| +-------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("mean", col) @@ -1479,7 +1479,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C +---------------+ | 3| +---------------+ - """ + """ # noqa: D205 if rsd is not None: msg = "rsd is not supported by DuckDB" raise ValueError(msg) @@ -1588,7 +1588,7 @@ def concat_ws(sep: str, *cols: "ColumnOrName") -> "Column": >>> df = spark.createDataFrame([("abcd", "123")], ["s", "d"]) >>> df.select(concat_ws("-", df.s, df.d).alias("s")).collect() [Row(s='abcd-123')] - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return _invoke_function("concat_ws", ConstantExpression(sep), *cols) @@ -1854,7 +1854,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(equal_null(df.a, df.b).alias("r")).collect() [Row(r=True), Row(r=False)] - """ + """ # noqa: D205 if isinstance(col1, str): col1 = col(col1) @@ -1901,7 +1901,7 @@ def flatten(col: "ColumnOrName") -> Column: |[1, 2, 3, 4, 5, 6]| | NULL| +------------------+ - """ + """ # noqa: D205 col = _to_column_expr(col) contains_null = _list_contains_null(col) return Column(CaseExpression(contains_null, None).otherwise(FunctionExpression("flatten", col))) @@ -2077,7 +2077,7 @@ def char(col: "ColumnOrName") -> Column: +--------+ | A| +--------+ - """ + """ # noqa: D205 col = _to_column_expr(col) return Column(FunctionExpression("chr", CaseExpression(col > 256, col % 256).otherwise(col))) @@ -2110,7 +2110,7 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(corr("a", "b").alias("c")).collect() [Row(c=1.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("corr", col1, col2) @@ -2183,7 +2183,7 @@ def negative(col: "ColumnOrName") -> Column: | -1| | -2| +------------+ - """ + """ # noqa: D205 return abs(col) * -1 @@ -2364,7 +2364,7 @@ def rand(seed: Optional[int] = None) -> Column: | 0|1.8575681106759028| | 1|1.5288056527339444| +---+------------------+ - """ + """ # noqa: D205 if seed is not None: # Maybe call setseed just before but how do we know when it is executed? msg = "Seed is not yet implemented" @@ -2451,7 +2451,7 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: [Row(d=0)] >>> df.select(regexp_count("str", col("regexp")).alias("d")).collect() [Row(d=3)] - """ + """ # noqa: D205 return _invoke_function_over_columns("len", _invoke_function_over_columns("regexp_extract_all", str, regexp)) @@ -2489,7 +2489,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: >>> df = spark.createDataFrame([("aaaac",)], ["str"]) >>> df.select(regexp_extract("str", "(a+)(b)?(c)", 2).alias("d")).collect() [Row(d='')] - """ + """ # noqa: D205 return _invoke_function( "regexp_extract", _to_column_expr(str), ConstantExpression(pattern), ConstantExpression(idx) ) @@ -2526,7 +2526,7 @@ def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optiona [Row(d=['200', '400'])] >>> df.select(regexp_extract_all("str", col("regexp")).alias("d")).collect() [Row(d=['100', '300'])] - """ + """ # noqa: D205 if idx is None: idx = 1 return _invoke_function( @@ -2613,7 +2613,7 @@ def regexp_substr(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: [Row(d=None)] >>> df.select(regexp_substr("str", col("regexp")).alias("d")).collect() [Row(d='1')] - """ + """ # noqa: D205 return Column( FunctionExpression( "nullif", @@ -2689,7 +2689,7 @@ def sequence(start: "ColumnOrName", stop: "ColumnOrName", step: Optional["Column >>> df2 = spark.createDataFrame([(4, -4, -2)], ("C1", "C2", "C3")) >>> df2.select(sequence("C1", "C2", "C3").alias("r")).collect() [Row(r=[4, 2, 0, -2, -4])] - """ + """ # noqa: D205 if step is None: return _invoke_function_over_columns("generate_series", start, stop) else: @@ -2843,7 +2843,7 @@ def encode(col: "ColumnOrName", charset: str) -> Column: +----------------+ | [61 62 63 64]| +----------------+ - """ + """ # noqa: D205 if charset != "UTF-8": msg = "Only UTF-8 charset is supported right now" raise ContributionsAcceptedError(msg) @@ -2869,7 +2869,7 @@ def find_in_set(str: "ColumnOrName", str_array: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([("ab", "abc,b,ab,c,def")], ["a", "b"]) >>> df.select(find_in_set(df.a, df.b).alias("r")).collect() [Row(r=3)] - """ + """ # noqa: D205 str_array = _to_column_expr(str_array) str = _to_column_expr(str) return Column( @@ -3019,7 +3019,7 @@ def greatest(*cols: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] - """ + """ # noqa: D205 if len(cols) < 2: msg = "greatest should take at least 2 columns" raise ValueError(msg) @@ -3052,7 +3052,7 @@ def least(*cols: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"]) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() [Row(least=1)] - """ + """ # noqa: D205 if len(cols) < 2: msg = "least should take at least 2 columns" raise ValueError(msg) @@ -3244,7 +3244,7 @@ def endswith(str: "ColumnOrName", suffix: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("suffix", str, suffix) @@ -3296,7 +3296,7 @@ def startswith(str: "ColumnOrName", prefix: "ColumnOrName") -> Column: +----------------+----------------+ | true| false| +----------------+----------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("starts_with", str, prefix) @@ -3324,7 +3324,7 @@ def length(col: "ColumnOrName") -> Column: -------- >>> spark.createDataFrame([("ABC ",)], ["a"]).select(length("a").alias("length")).collect() [Row(length=4)] - """ + """ # noqa: D205 return _invoke_function_over_columns("length", col) @@ -3370,7 +3370,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1|NULL| 1.0| |NULL| 2| 0.0| +----+----+----------------+ - """ + """ # noqa: D205 cols = [_to_column_expr(expr) for expr in cols] return Column(CoalesceOperator(*cols)) @@ -3400,7 +3400,7 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] - """ + """ # noqa: D205 return coalesce(col1, col2) @@ -3460,7 +3460,7 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: | 8| | 1| +------------+ - """ + """ # noqa: D205 return coalesce(col1, col2) @@ -3554,7 +3554,7 @@ def sha2(col: "ColumnOrName", numBits: int) -> Column: |Alice|3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043| |Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961| +-----+----------------------------------------------------------------+ - """ + """ # noqa: D205 if numBits not in {224, 256, 384, 512, 0}: msg = "numBits should be one of {224, 256, 384, 512, 0}" raise ValueError(msg) @@ -3586,7 +3586,7 @@ def curdate() -> Column: +--------------+ | 2022-08-26| +--------------+ - """ + """ # noqa: D205 return _invoke_function("today") @@ -3613,7 +3613,7 @@ def current_date() -> Column: +--------------+ | 2022-08-26| +--------------+ - """ + """ # noqa: D205 return curdate() @@ -4018,7 +4018,7 @@ def dayofweek(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(dayofweek("dt").alias("day")).collect() [Row(day=4)] - """ + """ # noqa: D205 return _invoke_function_over_columns("dayofweek", col) + lit(1) @@ -4209,7 +4209,7 @@ def weekofyear(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([("2015-04-08",)], ["dt"]) >>> df.select(weekofyear(df.dt).alias("week")).collect() [Row(week=15)] - """ + """ # noqa: D205 return _invoke_function_over_columns("weekofyear", col) @@ -4367,7 +4367,7 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_pop("a", "b").alias("c")).collect() [Row(c=0.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("covar_pop", col1, col2) @@ -4399,7 +4399,7 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_samp("a", "b").alias("c")).collect() [Row(c=0.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("covar_samp", col1, col2) @@ -4545,7 +4545,7 @@ def degrees(col: "ColumnOrName") -> Column: >>> df = spark.range(1) >>> df.select(degrees(lit(math.pi))).first() Row(DEGREES(3.14159...)=180.0) - """ + """ # noqa: D205 return _invoke_function_over_columns("degrees", col) @@ -4573,7 +4573,7 @@ def radians(col: "ColumnOrName") -> Column: >>> df = spark.range(1) >>> df.select(radians(lit(180))).first() Row(RADIANS(180)=3.14159...) - """ + """ # noqa: D205 return _invoke_function_over_columns("radians", col) @@ -4698,7 +4698,7 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(round("a", 0).alias("r")).collect() [Row(r=3.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("round", col, lit(scale)) @@ -4727,7 +4727,7 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: -------- >>> spark.createDataFrame([(2.5,)], ["a"]).select(bround("a", 0).alias("r")).collect() [Row(r=2.0)] - """ + """ # noqa: D205 return _invoke_function_over_columns("round_even", col, lit(scale)) @@ -4796,7 +4796,7 @@ def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: +----------------------+ | a| +----------------------+ - """ + """ # noqa: D205 index = ConstantExpression(index) if isinstance(index, int) else _to_column_expr(index) # Spark uses 0-indexing, DuckDB 1-indexing index = index + 1 @@ -5029,7 +5029,7 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col [Row(next_month=datetime.date(2015, 6, 8))] >>> df.select(add_months("dt", -2).alias("prev_month")).collect() [Row(prev_month=datetime.date(2015, 2, 8))] - """ + """ # noqa: D205 months = ConstantExpression(months) if isinstance(months, int) else _to_column_expr(months) return _invoke_function("date_add", _to_column_expr(start), FunctionExpression("to_months", months)).cast("date") @@ -5064,7 +5064,7 @@ def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[s [Row(joined='a,b,c'), Row(joined='a')] >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() [Row(joined='a,b,c'), Row(joined='a,NULL')] - """ + """ # noqa: D205 col = _to_column_expr(col) if null_replacement is not None: col = FunctionExpression( @@ -5111,7 +5111,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ["data"]) >>> df.select(array_position(df.data, "a")).collect() [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] - """ + """ # noqa: D205 return Column( CoalesceOperator( _to_column_expr(_invoke_function_over_columns("list_position", col, lit(value))), ConstantExpression(0) @@ -5143,7 +5143,7 @@ def array_prepend(col: "ColumnOrName", value: Any) -> Column: >>> df = spark.createDataFrame([([2, 3, 4],), ([],)], ["data"]) >>> df.select(array_prepend(df.data, 1)).collect() [Row(array_prepend(data, 1)=[1, 2, 3, 4]), Row(array_prepend(data, 1)=[1])] - """ + """ # noqa: D205 return _invoke_function_over_columns("list_prepend", lit(value), col) @@ -5247,7 +5247,7 @@ def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Colum ... ).alias("r") ... ).collect() [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] - """ + """ # noqa: D205 if comparator is not None: msg = "comparator is not yet supported" raise ContributionsAcceptedError(msg) @@ -5286,7 +5286,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column: [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias("r")).collect() [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] - """ + """ # noqa: D205 if asc: order = "ASC" null_order = "NULLS FIRST" @@ -5381,7 +5381,7 @@ def split_part(src: "ColumnOrName", delimiter: "ColumnOrName", partNum: "ColumnO ... ) >>> df.select(split_part(df.a, df.b, df.c).alias("r")).collect() [Row(r='13')] - """ + """ # noqa: D205 src = _to_column_expr(src) delimiter = _to_column_expr(delimiter) partNum = _to_column_expr(partNum) @@ -5422,7 +5422,7 @@ def stddev_samp(col: "ColumnOrName") -> Column: +------------------+ |1.8708286933869...| +------------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("stddev_samp", col) @@ -5513,7 +5513,7 @@ def stddev_pop(col: "ColumnOrName") -> Column: +-----------------+ |1.707825127659...| +-----------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("stddev_pop", col) @@ -5572,7 +5572,7 @@ def var_samp(col: "ColumnOrName") -> Column: +------------+ | 3.5| +------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("var_samp", col) @@ -5701,7 +5701,7 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_date(df.t, "yyyy-MM-dd HH:mm:ss").alias("date")).collect() [Row(date=datetime.date(1997, 2, 28))] - """ + """ # noqa: D205 return _to_date_or_timestamp(col, _types.DateType(), format) @@ -5739,7 +5739,7 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> df = spark.createDataFrame([("1997-02-28 10:30:00",)], ["t"]) >>> df.select(to_timestamp(df.t, "yyyy-MM-dd HH:mm:ss").alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ + """ # noqa: D205 return _to_date_or_timestamp(col, _types.TimestampNTZType(), format) @@ -5770,7 +5770,7 @@ def to_timestamp_ltz( >>> df.select(to_timestamp_ltz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 12, 31, 0, 0))] - """ + """ # noqa: D205 return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format) @@ -5801,7 +5801,7 @@ def to_timestamp_ntz( >>> df.select(to_timestamp_ntz(df.e).alias("r")).collect() ... # doctest: +SKIP [Row(r=datetime.datetime(2016, 4, 8, 0, 0))] - """ + """ # noqa: D205 return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format) @@ -5824,7 +5824,7 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] >>> df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ + """ # noqa: D205 if format is None: format = lit(["%Y-%m-%d", "%Y-%m-%d %H:%M:%S"]) @@ -5881,7 +5881,7 @@ def substr(str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName +------------------------+ | k SQL| +------------------------+ - """ + """ # noqa: D205 if len is not None: return _invoke_function_over_columns("substring", str, pos, len) else: @@ -5939,7 +5939,7 @@ def unix_millis(col: "ColumnOrName") -> Column: >>> df.select(unix_millis(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400000)] >>> spark.conf.unset("spark.sql.session.timeZone") - """ + """ # noqa: D205 return _unix_diff(col, "milliseconds") @@ -5956,7 +5956,7 @@ def unix_seconds(col: "ColumnOrName") -> Column: >>> df.select(unix_seconds(to_timestamp(df.t)).alias("n")).collect() [Row(n=1437584400)] >>> spark.conf.unset("spark.sql.session.timeZone") - """ + """ # noqa: D205 return _unix_diff(col, "seconds") @@ -5980,7 +5980,7 @@ def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ["x", "y"]) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() [Row(overlap=True), Row(overlap=False)] - """ + """ # noqa: D205 a1 = _to_column_expr(a1) a2 = _to_column_expr(a2) @@ -6045,7 +6045,7 @@ def arrays_zip(*cols: "ColumnOrName") -> Column: | | |-- vals1: long (nullable = true) | | |-- vals2: long (nullable = true) | | |-- vals3: long (nullable = true) - """ + """ # noqa: D205 return _invoke_function_over_columns("list_zip", *cols) @@ -6084,7 +6084,7 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: ... ) >>> df.select(substring(df.s, 1, 2).alias("s")).collect() [Row(s='ab')] - """ + """ # noqa: D205 return _invoke_function( "substring", _to_column_expr(str), @@ -6130,7 +6130,7 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ + """ # noqa: D205 return _invoke_function_over_columns("contains", left, right) @@ -6157,7 +6157,7 @@ def reverse(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ["data"]) >>> df.select(reverse(df.data).alias("r")).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] - """ + """ # noqa: D205 return _invoke_function("reverse", _to_column_expr(col)) @@ -6197,7 +6197,7 @@ def concat(*cols: "ColumnOrName") -> Column: [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] >>> df DataFrame[arr: array] - """ + """ # noqa: D205 return _invoke_function_over_columns("concat", *cols) @@ -6237,7 +6237,7 @@ def instr(str: "ColumnOrName", substr: str) -> Column: ... ) >>> df.select(instr(df.s, "b").alias("s")).collect() [Row(s=2)] - """ + """ # noqa: D205 return _invoke_function("instr", _to_column_expr(str), ConstantExpression(substr)) @@ -6278,5 +6278,5 @@ def broadcast(df: "DataFrame") -> "DataFrame": dataset to all the worker nodes. However, DuckDB operates on a single-node architecture . As a result, the function simply returns the input DataFrame without applying any modifications or optimizations, since broadcasting is not applicable in the DuckDB context. - """ + """ # noqa: D205 return df diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index 7aa9eb11..ab8e89cf 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -76,7 +76,7 @@ class GroupedData: """A set of methods for aggregations on a :class:`DataFrame`, created by :func:`DataFrame.groupBy`. - """ + """ # noqa: D205 def __init__(self, grouping: Grouping, df: DataFrame) -> None: self._grouping = grouping diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index fa961eb1..606f792c 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -140,7 +140,7 @@ def typeName(cls) -> str: class AtomicType(DataType): """An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. - """ + """ # noqa: D205 class NumericType(AtomicType): @@ -836,7 +836,7 @@ def add( >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - """ + """ # noqa: D205 if isinstance(field, StructField): self.fields.append(field) self.names.append(field.name) @@ -996,7 +996,7 @@ def module(cls) -> str: def scalaUDT(cls) -> str: """The class name of the paired Scala UDT (could be '', if there is no corresponding one). - """ + """ # noqa: D205 return "" def needConversion(self) -> bool: @@ -1125,7 +1125,7 @@ class Row(tuple): >>> row2 = Row(name="Alice", age=11) >>> row1 == row2 True - """ + """ # noqa: D205 @overload def __new__(cls, *args: str) -> "Row": ... diff --git a/sqllogic/conftest.py b/sqllogic/conftest.py index 40759e9c..48315109 100644 --- a/sqllogic/conftest.py +++ b/sqllogic/conftest.py @@ -129,7 +129,7 @@ def create_parameters_from_paths(paths, root_dir: pathlib.Path, config: pytest.C def scan_for_test_scripts(root_dir: pathlib.Path, config: pytest.Config) -> typing.Iterator[typing.Any]: """Scans for .test files in the given directory and its subdirectories. Returns an iterator of pytest parameters (argument, id and marks). - """ + """ # noqa: D205 # TODO: Add tests from extensions test_script_extensions = [".test", ".test_slow", ".test_coverage"] it = itertools.chain.from_iterable(root_dir.rglob(f"*{ext}") for ext in test_script_extensions) @@ -169,7 +169,7 @@ def determine_test_offsets(config: pytest.Config, num_tests: int) -> tuple[int, start_offset defaults to 0. end_offset defaults to and is capped to the last test index. start_offset_percentage and end_offset_percentage are used to calculate the start and end offsets based on the total number of tests. This is done in a way that a test run to 25% and another test run starting at 25% do not overlap by excluding the 25th percent test. - """ + """ # noqa: D205 start_offset = config.getoption("start_offset") end_offset = config.getoption("end_offset") start_offset_percentage = config.getoption("start_offset_percentage") From 353ddd51241601ab7eafacc6578ad11bb9a15431 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:31:52 +0200 Subject: [PATCH 224/472] Ruff config: will not check sqlogic --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 02e85f5e..b596bf60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -316,7 +316,7 @@ line-length = 120 indent-width = 4 target-version = "py39" fix = true -exclude = ['external/duckdb'] +exclude = ['external/duckdb', 'sqllogic'] [tool.ruff.lint] fixable = ["ALL"] From 40fcb7547582f868b28181b2f8bb459e4f202e41 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:33:29 +0200 Subject: [PATCH 225/472] Ruff config: will not check scripts for D --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b596bf60..525cfd2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -363,6 +363,10 @@ strict = true # No need for package, module, class, function, init etc docstrings in tests 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107' ] +"scripts/**.py" = [ + # No need for package, module, class, function, init etc docstrings in scripts + 'D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D107', 'D205' +] [tool.ruff.format] docstring-code-format = true From a3c5dde89305123cb4f99562c8584a2509733c98 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 10 Sep 2025 15:34:42 +0200 Subject: [PATCH 226/472] Ruff noqa D10x: ignore docstring issues in existing code --- duckdb/__init__.py | 2 +- duckdb/bytes_io_wrapper.py | 10 +- duckdb/experimental/__init__.py | 2 +- duckdb/experimental/spark/__init__.py | 2 +- duckdb/experimental/spark/conf.py | 28 +-- duckdb/experimental/spark/context.py | 60 ++--- .../spark/errors/error_classes.py | 2 +- .../spark/errors/exceptions/__init__.py | 2 +- .../spark/errors/exceptions/base.py | 6 +- duckdb/experimental/spark/errors/utils.py | 4 +- duckdb/experimental/spark/exception.py | 4 +- duckdb/experimental/spark/sql/__init__.py | 2 +- duckdb/experimental/spark/sql/catalog.py | 24 +- duckdb/experimental/spark/sql/column.py | 28 +-- duckdb/experimental/spark/sql/conf.py | 14 +- duckdb/experimental/spark/sql/dataframe.py | 36 +-- duckdb/experimental/spark/sql/functions.py | 30 +-- duckdb/experimental/spark/sql/group.py | 14 +- duckdb/experimental/spark/sql/readwriter.py | 22 +- duckdb/experimental/spark/sql/session.py | 52 ++-- duckdb/experimental/spark/sql/streaming.py | 14 +- duckdb/experimental/spark/sql/type_utils.py | 8 +- duckdb/experimental/spark/sql/types.py | 228 +++++++++--------- duckdb/experimental/spark/sql/udf.py | 12 +- duckdb/filesystem.py | 8 +- duckdb/functional/__init__.py | 2 +- duckdb/polars_io.py | 2 +- duckdb/query_graph/__main__.py | 48 ++-- duckdb/typing/__init__.py | 2 +- duckdb/udf.py | 2 +- duckdb/value/__init__.py | 1 + duckdb/value/constant/__init__.py | 120 ++++----- duckdb_packaging/pypi_cleanup.py | 8 +- 33 files changed, 400 insertions(+), 399 deletions(-) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index 73fcbbd2..8d6d68aa 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -1,4 +1,4 @@ -# Modules +# Modules # noqa: D104 from importlib.metadata import version from _duckdb import __version__ as duckdb_version diff --git a/duckdb/bytes_io_wrapper.py b/duckdb/bytes_io_wrapper.py index 763fd8b7..9851ad65 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -1,4 +1,4 @@ -from io import StringIO, TextIOBase +from io import StringIO, TextIOBase # noqa: D100 from typing import Any, Union """ @@ -36,10 +36,10 @@ """ -class BytesIOWrapper: +class BytesIOWrapper: # noqa: D101 # Wrapper that wraps a StringIO buffer and reads bytes from it # Created for compat with pyarrow read_csv - def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: + def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: # noqa: D107 self.buffer = buffer self.encoding = encoding # Because a character can be represented by more than 1 byte, @@ -48,10 +48,10 @@ def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") # overflow to the front of the bytestring the next time reading is performed self.overflow = b"" - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, attr: str) -> Any: # noqa: D105 return getattr(self.buffer, attr) - def read(self, n: Union[int, None] = -1) -> bytes: + def read(self, n: Union[int, None] = -1) -> bytes: # noqa: D102 assert self.buffer is not None bytestring = self.buffer.read(n).encode(self.encoding) # When n=-1/n greater than remaining bytes: Read entire file/rest of file diff --git a/duckdb/experimental/__init__.py b/duckdb/experimental/__init__.py index a88a6170..1b5ee51b 100644 --- a/duckdb/experimental/__init__.py +++ b/duckdb/experimental/__init__.py @@ -1,3 +1,3 @@ -from . import spark +from . import spark # noqa: D104 __all__ = spark.__all__ diff --git a/duckdb/experimental/spark/__init__.py b/duckdb/experimental/spark/__init__.py index bdde2ef8..7e56d4b1 100644 --- a/duckdb/experimental/spark/__init__.py +++ b/duckdb/experimental/spark/__init__.py @@ -1,4 +1,4 @@ -from ._globals import _NoValue +from ._globals import _NoValue # noqa: D104 from .conf import SparkConf from .context import SparkContext from .exception import ContributionsAcceptedError diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index ea1153b4..974115d6 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -1,45 +1,45 @@ -from typing import Optional +from typing import Optional # noqa: D100 from duckdb.experimental.spark.exception import ContributionsAcceptedError -class SparkConf: - def __init__(self) -> None: +class SparkConf: # noqa: D101 + def __init__(self) -> None: # noqa: D107 raise NotImplementedError - def contains(self, key: str) -> bool: + def contains(self, key: str) -> bool: # noqa: D102 raise ContributionsAcceptedError - def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: + def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError - def getAll(self) -> list[tuple[str, str]]: + def getAll(self) -> list[tuple[str, str]]: # noqa: D102 raise ContributionsAcceptedError - def set(self, key: str, value: str) -> "SparkConf": + def set(self, key: str, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": + def setAll(self, pairs: list[tuple[str, str]]) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setAppName(self, value: str) -> "SparkConf": + def setAppName(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setExecutorEnv( + def setExecutorEnv( # noqa: D102 self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[list[tuple[str, str]]] = None ) -> "SparkConf": raise ContributionsAcceptedError - def setIfMissing(self, key: str, value: str) -> "SparkConf": + def setIfMissing(self, key: str, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setMaster(self, value: str) -> "SparkConf": + def setMaster(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def setSparkHome(self, value: str) -> "SparkConf": + def setSparkHome(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError - def toDebugString(self) -> str: + def toDebugString(self) -> str: # noqa: D102 raise ContributionsAcceptedError diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index 9f1b4155..9835fcea 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional # noqa: D100 import duckdb from duckdb import DuckDBPyConnection @@ -6,37 +6,37 @@ from duckdb.experimental.spark.exception import ContributionsAcceptedError -class SparkContext: - def __init__(self, master: str) -> None: +class SparkContext: # noqa: D101 + def __init__(self, master: str) -> None: # noqa: D107 self._connection = duckdb.connect(":memory:") # This aligns the null ordering with Spark. self._connection.execute("set default_null_order='nulls_first_on_asc_last_on_desc'") @property - def connection(self) -> DuckDBPyConnection: + def connection(self) -> DuckDBPyConnection: # noqa: D102 return self._connection - def stop(self) -> None: + def stop(self) -> None: # noqa: D102 self._connection.close() @classmethod - def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": + def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": # noqa: D102 raise ContributionsAcceptedError @classmethod - def setSystemProperty(cls, key: str, value: str) -> None: + def setSystemProperty(cls, key: str, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError @property - def applicationId(self) -> str: + def applicationId(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def defaultMinPartitions(self) -> int: + def defaultMinPartitions(self) -> int: # noqa: D102 raise ContributionsAcceptedError @property - def defaultParallelism(self) -> int: + def defaultParallelism(self) -> int: # noqa: D102 raise ContributionsAcceptedError # @property @@ -44,30 +44,30 @@ def defaultParallelism(self) -> int: # raise ContributionsAcceptedError @property - def startTime(self) -> str: + def startTime(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def uiWebUrl(self) -> str: + def uiWebUrl(self) -> str: # noqa: D102 raise ContributionsAcceptedError @property - def version(self) -> str: + def version(self) -> str: # noqa: D102 raise ContributionsAcceptedError - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 raise ContributionsAcceptedError # def accumulator(self, value: ~T, accum_param: Optional[ForwardRef('AccumulatorParam[T]')] = None) -> 'Accumulator[T]': # pass - def addArchive(self, path: str) -> None: + def addArchive(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def addFile(self, path: str, recursive: bool = False) -> None: + def addFile(self, path: str, recursive: bool = False) -> None: # noqa: D102 raise ContributionsAcceptedError - def addPyFile(self, path: str) -> None: + def addPyFile(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError # def binaryFiles(self, path: str, minPartitions: Optional[int] = None) -> duckdb.experimental.spark.rdd.RDD[typing.Tuple[str, bytes]]: @@ -79,25 +79,25 @@ def addPyFile(self, path: str) -> None: # def broadcast(self, value: ~T) -> 'Broadcast[T]': # pass - def cancelAllJobs(self) -> None: + def cancelAllJobs(self) -> None: # noqa: D102 raise ContributionsAcceptedError - def cancelJobGroup(self, groupId: str) -> None: + def cancelJobGroup(self, groupId: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def dump_profiles(self, path: str) -> None: + def dump_profiles(self, path: str) -> None: # noqa: D102 raise ContributionsAcceptedError # def emptyRDD(self) -> duckdb.experimental.spark.rdd.RDD[typing.Any]: # pass - def getCheckpointDir(self) -> Optional[str]: + def getCheckpointDir(self) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError - def getConf(self) -> SparkConf: + def getConf(self) -> SparkConf: # noqa: D102 raise ContributionsAcceptedError - def getLocalProperty(self, key: str) -> Optional[str]: + def getLocalProperty(self, key: str) -> Optional[str]: # noqa: D102 raise ContributionsAcceptedError # def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, conf: Optional[Dict[str, str]] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: @@ -127,25 +127,25 @@ def getLocalProperty(self, key: str) -> Optional[str]: # def sequenceFile(self, path: str, keyClass: Optional[str] = None, valueClass: Optional[str] = None, keyConverter: Optional[str] = None, valueConverter: Optional[str] = None, minSplits: Optional[int] = None, batchSize: int = 0) -> pyspark.rdd.RDD[typing.Tuple[~T, ~U]]: # pass - def setCheckpointDir(self, dirName: str) -> None: + def setCheckpointDir(self, dirName: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setJobDescription(self, value: str) -> None: + def setJobDescription(self, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: + def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = False) -> None: # noqa: D102 raise ContributionsAcceptedError - def setLocalProperty(self, key: str, value: str) -> None: + def setLocalProperty(self, key: str, value: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def setLogLevel(self, logLevel: str) -> None: + def setLogLevel(self, logLevel: str) -> None: # noqa: D102 raise ContributionsAcceptedError - def show_profiles(self) -> None: + def show_profiles(self) -> None: # noqa: D102 raise ContributionsAcceptedError - def sparkUser(self) -> str: + def sparkUser(self) -> str: # noqa: D102 raise ContributionsAcceptedError # def statusTracker(self) -> duckdb.experimental.spark.status.StatusTracker: diff --git a/duckdb/experimental/spark/errors/error_classes.py b/duckdb/experimental/spark/errors/error_classes.py index 256fb644..55cea14d 100644 --- a/duckdb/experimental/spark/errors/error_classes.py +++ b/duckdb/experimental/spark/errors/error_classes.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. diff --git a/duckdb/experimental/spark/errors/exceptions/__init__.py b/duckdb/experimental/spark/errors/exceptions/__init__.py index cce3acad..edd0e7e1 100644 --- a/duckdb/experimental/spark/errors/exceptions/__init__.py +++ b/duckdb/experimental/spark/errors/exceptions/__init__.py @@ -1,4 +1,4 @@ -# +# # noqa: D104 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 0b2c6a43..2eae2a19 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -1,4 +1,4 @@ -from typing import Optional, cast +from typing import Optional, cast # noqa: D100 from ..utils import ErrorClassesReader @@ -6,7 +6,7 @@ class PySparkException(Exception): """Base Exception for handling errors generated from PySpark.""" - def __init__( + def __init__( # noqa: D107 self, message: Optional[str] = None, # The error class, decides the message format, must be one of the valid options listed in 'error_classes.py' @@ -69,7 +69,7 @@ def getSqlState(self) -> None: """ return None - def __str__(self) -> str: + def __str__(self) -> str: # noqa: D105 if self.getErrorClass() is not None: return f"[{self.getErrorClass()}] {self.message}" else: diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index 984504a4..8a71f3b0 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -23,7 +23,7 @@ class ErrorClassesReader: """A reader to load error information from error_classes.py.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 self.error_info_map = ERROR_CLASSES_MAP def get_error_message(self, error_class: str, message_parameters: dict[str, str]) -> str: diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index 24c4f291..3973d9c4 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,10 +1,10 @@ -class ContributionsAcceptedError(NotImplementedError): +class ContributionsAcceptedError(NotImplementedError): # noqa: D100 """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. """ # noqa: D205 - def __init__(self, message=None) -> None: + def __init__(self, message=None) -> None: # noqa: D107 doc = self.__class__.__doc__ if message: doc = message + "\n" + doc diff --git a/duckdb/experimental/spark/sql/__init__.py b/duckdb/experimental/spark/sql/__init__.py index 9ae09308..418273f0 100644 --- a/duckdb/experimental/spark/sql/__init__.py +++ b/duckdb/experimental/spark/sql/__init__.py @@ -1,4 +1,4 @@ -from .catalog import Catalog +from .catalog import Catalog # noqa: D104 from .conf import RuntimeConfig from .dataframe import DataFrame from .readwriter import DataFrameWriter diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index 8e510fdf..27e6fbb0 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -1,15 +1,15 @@ -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional # noqa: D100 from .session import SparkSession -class Database(NamedTuple): +class Database(NamedTuple): # noqa: D101 name: str description: Optional[str] locationUri: str -class Table(NamedTuple): +class Table(NamedTuple): # noqa: D101 name: str database: Optional[str] description: Optional[str] @@ -17,7 +17,7 @@ class Table(NamedTuple): isTemporary: bool -class Column(NamedTuple): +class Column(NamedTuple): # noqa: D101 name: str description: Optional[str] dataType: str @@ -26,18 +26,18 @@ class Column(NamedTuple): isBucket: bool -class Function(NamedTuple): +class Function(NamedTuple): # noqa: D101 name: str description: Optional[str] className: str isTemporary: bool -class Catalog: - def __init__(self, session: SparkSession) -> None: +class Catalog: # noqa: D101 + def __init__(self, session: SparkSession) -> None: # noqa: D107 self._session = session - def listDatabases(self) -> list[Database]: + def listDatabases(self) -> list[Database]: # noqa: D102 res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall() def transform_to_database(x) -> Database: @@ -46,7 +46,7 @@ def transform_to_database(x) -> Database: databases = [transform_to_database(x) for x in res] return databases - def listTables(self) -> list[Table]: + def listTables(self) -> list[Table]: # noqa: D102 res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall() def transform_to_table(x) -> Table: @@ -55,7 +55,7 @@ def transform_to_table(x) -> Table: tables = [transform_to_table(x) for x in res] return tables - def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: + def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: # noqa: D102 query = f""" select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}' """ @@ -69,10 +69,10 @@ def transform_to_column(x) -> Column: columns = [transform_to_column(x) for x in res] return columns - def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: + def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: # noqa: D102 raise NotImplementedError - def setCurrentDatabase(self, dbName: str) -> None: + def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102 raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index ea6cd8a8..bc84365a 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Union, cast # noqa: D100 from ..exception import ContributionsAcceptedError from .types import DataType @@ -93,11 +93,11 @@ class Column: .. versionadded:: 1.3.0 """ - def __init__(self, expr: Expression) -> None: + def __init__(self, expr: Expression) -> None: # noqa: D107 self.expr = expr # arithmetic operators - def __neg__(self) -> "Column": + def __neg__(self) -> "Column": # noqa: D105 return Column(-self.expr) # `and`, `or`, `not` cannot be overloaded in Python, @@ -205,10 +205,10 @@ def __getattr__(self, item: Any) -> "Column": raise AttributeError(msg) return self[item] - def alias(self, alias: str): + def alias(self, alias: str): # noqa: D102 return Column(self.expr.alias(alias)) - def when(self, condition: "Column", value: Any): + def when(self, condition: "Column", value: Any): # noqa: D102 if not isinstance(condition, Column): msg = "condition should be a Column" raise TypeError(msg) @@ -216,12 +216,12 @@ def when(self, condition: "Column", value: Any): expr = self.expr.when(condition.expr, v) return Column(expr) - def otherwise(self, value: Any): + def otherwise(self, value: Any): # noqa: D102 v = _get_expr(value) expr = self.expr.otherwise(v) return Column(expr) - def cast(self, dataType: Union[DataType, str]) -> "Column": + def cast(self, dataType: Union[DataType, str]) -> "Column": # noqa: D102 if isinstance(dataType, str): # Try to construct a default DuckDBPyType from it internal_type = DuckDBPyType(dataType) @@ -229,7 +229,7 @@ def cast(self, dataType: Union[DataType, str]) -> "Column": internal_type = dataType.duckdb_type return Column(self.expr.cast(internal_type)) - def isin(self, *cols: Any) -> "Column": + def isin(self, *cols: Any) -> "Column": # noqa: D102 if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list cols = cast("tuple", cols[0]) @@ -345,20 +345,20 @@ def __ne__( # type: ignore[override] nulls_first = _unary_op("nulls_first") nulls_last = _unary_op("nulls_last") - def asc_nulls_first(self) -> "Column": + def asc_nulls_first(self) -> "Column": # noqa: D102 return self.asc().nulls_first() - def asc_nulls_last(self) -> "Column": + def asc_nulls_last(self) -> "Column": # noqa: D102 return self.asc().nulls_last() - def desc_nulls_first(self) -> "Column": + def desc_nulls_first(self) -> "Column": # noqa: D102 return self.desc().nulls_first() - def desc_nulls_last(self) -> "Column": + def desc_nulls_last(self) -> "Column": # noqa: D102 return self.desc().nulls_last() - def isNull(self) -> "Column": + def isNull(self) -> "Column": # noqa: D102 return Column(self.expr.isnull()) - def isNotNull(self) -> "Column": + def isNotNull(self) -> "Column": # noqa: D102 return Column(self.expr.isnotnull()) diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index 8ab9fa38..e44f2566 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ b/duckdb/experimental/spark/sql/conf.py @@ -1,23 +1,23 @@ -from typing import Optional, Union +from typing import Optional, Union # noqa: D100 from duckdb import DuckDBPyConnection from duckdb.experimental.spark._globals import _NoValue, _NoValueType -class RuntimeConfig: - def __init__(self, connection: DuckDBPyConnection) -> None: +class RuntimeConfig: # noqa: D101 + def __init__(self, connection: DuckDBPyConnection) -> None: # noqa: D107 self._connection = connection - def set(self, key: str, value: str) -> None: + def set(self, key: str, value: str) -> None: # noqa: D102 raise NotImplementedError - def isModifiable(self, key: str) -> bool: + def isModifiable(self, key: str) -> bool: # noqa: D102 raise NotImplementedError - def unset(self, key: str) -> None: + def unset(self, key: str) -> None: # noqa: D102 raise NotImplementedError - def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: + def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: # noqa: D102 raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index 99da92ec..8e83822b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,4 +1,4 @@ -import uuid +import uuid # noqa: D100 from functools import reduce from keyword import iskeyword from typing import ( @@ -32,18 +32,18 @@ from .functions import _to_column_expr, col, lit -class DataFrame: - def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: +class DataFrame: # noqa: D101 + def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: # noqa: D107 self.relation = relation self.session = session self._schema = None if self.relation is not None: self._schema = duckdb_to_spark_schema(self.relation.columns, self.relation.types) - def show(self, **kwargs) -> None: + def show(self, **kwargs) -> None: # noqa: D102 self.relation.show() - def toPandas(self) -> "PandasDataFrame": + def toPandas(self) -> "PandasDataFrame": # noqa: D102 return self.relation.df() def toArrow(self) -> "pa.Table": @@ -103,10 +103,10 @@ def createOrReplaceTempView(self, name: str) -> None: """ self.relation.create_view(name, True) - def createGlobalTempView(self, name: str) -> None: + def createGlobalTempView(self, name: str) -> None: # noqa: D102 raise NotImplementedError - def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": + def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": # noqa: D102 if columnName not in self.relation: msg = f"DataFrame does not contain a column named {columnName}" raise ValueError(msg) @@ -119,7 +119,7 @@ def withColumnRenamed(self, columnName: str, newName: str) -> "DataFrame": rel = self.relation.select(*cols) return DataFrame(rel, self.session) - def withColumn(self, columnName: str, col: Column) -> "DataFrame": + def withColumn(self, columnName: str, col: Column) -> "DataFrame": # noqa: D102 if not isinstance(col, Column): raise PySparkTypeError( error_class="NOT_COLUMN", @@ -472,7 +472,7 @@ def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: An orderBy = sort - def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: + def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: # noqa: D102 if n is None: rs = self.head(1) return rs[0] if rs else None @@ -480,7 +480,7 @@ def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: first = head - def take(self, num: int) -> list[Row]: + def take(self, num: int) -> list[Row]: # noqa: D102 return self.limit(num).collect() def filter(self, condition: "ColumnOrName") -> "DataFrame": @@ -547,7 +547,7 @@ def filter(self, condition: "ColumnOrName") -> "DataFrame": where = filter - def select(self, *cols) -> "DataFrame": + def select(self, *cols) -> "DataFrame": # noqa: D102 cols = list(cols) if len(cols) == 1: cols = cols[0] @@ -574,7 +574,7 @@ def _ipython_key_completions_(self) -> list[str]: # when accessed in bracket notation, e.g. df['] return self.columns - def __dir__(self) -> list[str]: + def __dir__(self) -> list[str]: # noqa: D105 out = set(super().__dir__()) out.update(c for c in self.columns if c.isidentifier() and not iskeyword(c)) return sorted(out) @@ -792,7 +792,7 @@ def alias(self, alias: str) -> "DataFrame": assert isinstance(alias, str), "alias should be a string" return DataFrame(self.relation.set_alias(alias), self.session) - def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] + def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] # noqa: D102 exclude = [] for col in cols: if isinstance(col, str): @@ -809,7 +809,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] expr = StarExpression(exclude=exclude) return DataFrame(self.relation.select(expr), self.session) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return str(self.relation) def limit(self, num: int) -> "DataFrame": @@ -986,10 +986,10 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] groupby = groupBy @property - def write(self) -> DataFrameWriter: + def write(self) -> DataFrameWriter: # noqa: D102 return DataFrameWriter(self) - def printSchema(self): + def printSchema(self): # noqa: D102 raise ContributionsAcceptedError def union(self, other: "DataFrame") -> "DataFrame": @@ -1339,7 +1339,7 @@ def _cast_types(self, *types) -> "DataFrame": new_rel = self.relation.project(cast_expressions) return DataFrame(new_rel, self.session) - def toDF(self, *cols) -> "DataFrame": + def toDF(self, *cols) -> "DataFrame": # noqa: D102 existing_columns = self.relation.columns column_count = len(cols) if column_count != len(existing_columns): @@ -1350,7 +1350,7 @@ def toDF(self, *cols) -> "DataFrame": new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) - def collect(self) -> list[Row]: + def collect(self) -> list[Row]: # noqa: D102 columns = self.relation.columns result = self.relation.fetchall() diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 7ae923f4..30764fe1 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -1,4 +1,4 @@ -import warnings +import warnings # noqa: D100 from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload from duckdb import ( @@ -30,7 +30,7 @@ def _invoke_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: return _invoke_function(name, *cols) -def col(column: str): +def col(column: str): # noqa: D103 return Column(ColumnExpression(column)) @@ -90,7 +90,7 @@ def ucase(str: "ColumnOrName") -> Column: return upper(str) -def when(condition: "Column", value: Any) -> Column: +def when(condition: "Column", value: Any) -> Column: # noqa: D103 if not isinstance(condition, Column): msg = "condition should be a Column" raise TypeError(msg) @@ -103,7 +103,7 @@ def _inner_expr_or_val(val): return val.expr if isinstance(val, Column) else val -def struct(*cols: Column) -> Column: +def struct(*cols: Column) -> Column: # noqa: D103 return Column(FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols])) @@ -143,7 +143,7 @@ def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["Column return _invoke_function_over_columns("list_value", *cols) -def lit(col: Any) -> Column: +def lit(col: Any) -> Column: # noqa: D103 return col if isinstance(col, Column) else Column(ConstantExpression(col)) @@ -1680,7 +1680,7 @@ def ceil(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ceil", col) -def ceiling(col: "ColumnOrName") -> Column: +def ceiling(col: "ColumnOrName") -> Column: # noqa: D103 return ceil(col) @@ -1854,7 +1854,7 @@ def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(equal_null(df.a, df.b).alias("r")).collect() [Row(r=True), Row(r=False)] - """ # noqa: D205 + """ # noqa: D205, D415 if isinstance(col1, str): col1 = col(col1) @@ -2183,7 +2183,7 @@ def negative(col: "ColumnOrName") -> Column: | -1| | -2| +------------+ - """ # noqa: D205 + """ # noqa: D205, D415 return abs(col) * -1 @@ -3370,7 +3370,7 @@ def coalesce(*cols: "ColumnOrName") -> Column: | 1|NULL| 1.0| |NULL| 2| 0.0| +----+----+----------------+ - """ # noqa: D205 + """ # noqa: D205, D415 cols = [_to_column_expr(expr) for expr in cols] return Column(CoalesceOperator(*cols)) @@ -3400,7 +3400,7 @@ def nvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: ... ) >>> df.select(nvl(df.a, df.b).alias("r")).collect() [Row(r=8), Row(r=1)] - """ # noqa: D205 + """ # noqa: D205, D415 return coalesce(col1, col2) @@ -3460,7 +3460,7 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: | 8| | 1| +------------+ - """ # noqa: D205 + """ # noqa: D205, D415 return coalesce(col1, col2) @@ -5824,7 +5824,7 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] >>> df.select(try_to_timestamp(df.t, lit("yyyy-MM-dd HH:mm:ss")).alias("dt")).collect() [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] - """ # noqa: D205 + """ # noqa: D205, D415 if format is None: format = lit(["%Y-%m-%d", "%Y-%m-%d %H:%M:%S"]) @@ -6130,7 +6130,7 @@ def contains(left: "ColumnOrName", right: "ColumnOrName") -> Column: +--------------+--------------+ | true| false| +--------------+--------------+ - """ # noqa: D205 + """ # noqa: D205, D415 return _invoke_function_over_columns("contains", left, right) @@ -6157,7 +6157,7 @@ def reverse(col: "ColumnOrName") -> Column: >>> df = spark.createDataFrame([([2, 1, 3],), ([1],), ([],)], ["data"]) >>> df.select(reverse(df.data).alias("r")).collect() [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] - """ # noqa: D205 + """ # noqa: D205, D415 return _invoke_function("reverse", _to_column_expr(col)) @@ -6197,7 +6197,7 @@ def concat(*cols: "ColumnOrName") -> Column: [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] >>> df DataFrame[arr: array] - """ # noqa: D205 + """ # noqa: D205, D415 return _invoke_function_over_columns("concat", *cols) diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index ab8e89cf..c4222749 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -1,4 +1,4 @@ -# +# # noqa: D100 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -51,8 +51,8 @@ def _api(self: "GroupedData", *cols: str) -> DataFrame: return _api -class Grouping: - def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: +class Grouping: # noqa: D101 + def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: # noqa: D107 self._type = "" self._cols = [_to_column_expr(x) for x in cols] if "special" in kwargs: @@ -61,11 +61,11 @@ def __init__(self, *cols: "ColumnOrName", **kwargs) -> None: assert special in accepted_special self._type = special - def get_columns(self) -> str: + def get_columns(self) -> str: # noqa: D102 columns = ",".join([str(x) for x in self._cols]) return columns - def __str__(self) -> str: + def __str__(self) -> str: # noqa: D105 columns = self.get_columns() if self._type: return self._type + "(" + columns + ")" @@ -78,12 +78,12 @@ class GroupedData: """ # noqa: D205 - def __init__(self, grouping: Grouping, df: DataFrame) -> None: + def __init__(self, grouping: Grouping, df: DataFrame) -> None: # noqa: D107 self._grouping = grouping self._df = df self.session: SparkSession = df.session - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return str(self._df) def count(self) -> DataFrame: diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index 714ed797..eb714833 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100 from ..errors import PySparkNotImplementedError, PySparkTypeError from ..exception import ContributionsAcceptedError @@ -12,15 +12,15 @@ from duckdb.experimental.spark.sql.session import SparkSession -class DataFrameWriter: - def __init__(self, dataframe: "DataFrame") -> None: +class DataFrameWriter: # noqa: D101 + def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107 self.dataframe = dataframe - def saveAsTable(self, table_name: str) -> None: + def saveAsTable(self, table_name: str) -> None: # noqa: D102 relation = self.dataframe.relation relation.create(table_name) - def parquet( + def parquet( # noqa: D102 self, path: str, mode: Optional[str] = None, @@ -35,7 +35,7 @@ def parquet( relation.write_parquet(path, compression=compression) - def csv( + def csv( # noqa: D102 self, path: str, mode: Optional[str] = None, @@ -86,11 +86,11 @@ def csv( ) -class DataFrameReader: - def __init__(self, session: "SparkSession") -> None: +class DataFrameReader: # noqa: D101 + def __init__(self, session: "SparkSession") -> None: # noqa: D107 self.session = session - def load( + def load( # noqa: D102 self, path: Optional[Union[str, list[str]]] = None, format: Optional[str] = None, @@ -127,7 +127,7 @@ def load( df = df.toDF(names) raise NotImplementedError - def csv( + def csv( # noqa: D102 self, path: Union[str, list[str]], schema: Optional[Union[StructType, str]] = None, @@ -245,7 +245,7 @@ def csv( df = df.toDF(*names) return df - def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": + def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame": # noqa: D102 input = list(paths) if len(input) != 1: msg = "Only single paths are supported for now" diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index 4b919446..8bb6e910 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -1,4 +1,4 @@ -import uuid +import uuid # noqa: D100 from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Optional, Union @@ -41,8 +41,8 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType): return new_data -class SparkSession: - def __init__(self, context: SparkContext) -> None: +class SparkSession: # noqa: D101 + def __init__(self, context: SparkContext) -> None: # noqa: D107 self.conn = context.connection self._context = context self._conf = RuntimeConfig(self.conn) @@ -121,7 +121,7 @@ def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> D df = df.toDF(*names) return df - def createDataFrame( + def createDataFrame( # noqa: D102 self, data: Union["PandasDataFrame", Iterable[Any]], schema: Optional[Union[StructType, list[str]]] = None, @@ -184,10 +184,10 @@ def createDataFrame( df = df.toDF(*names) return df - def newSession(self) -> "SparkSession": + def newSession(self) -> "SparkSession": # noqa: D102 return SparkSession(self._context) - def range( + def range( # noqa: D102 self, start: int, end: Optional[int] = None, @@ -203,24 +203,24 @@ def range( return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self) - def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: + def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: # noqa: D102 if kwargs: raise NotImplementedError relation = self.conn.sql(sqlQuery) return DataFrame(relation, self) - def stop(self) -> None: + def stop(self) -> None: # noqa: D102 self._context.stop() - def table(self, tableName: str) -> DataFrame: + def table(self, tableName: str) -> DataFrame: # noqa: D102 relation = self.conn.table(tableName) return DataFrame(relation, self) - def getActiveSession(self) -> "SparkSession": + def getActiveSession(self) -> "SparkSession": # noqa: D102 return self @property - def catalog(self) -> "Catalog": + def catalog(self) -> "Catalog": # noqa: D102 if not hasattr(self, "_catalog"): from duckdb.experimental.spark.sql.catalog import Catalog @@ -228,59 +228,59 @@ def catalog(self) -> "Catalog": return self._catalog @property - def conf(self) -> RuntimeConfig: + def conf(self) -> RuntimeConfig: # noqa: D102 return self._conf @property - def read(self) -> DataFrameReader: + def read(self) -> DataFrameReader: # noqa: D102 return DataFrameReader(self) @property - def readStream(self) -> DataStreamReader: + def readStream(self) -> DataStreamReader: # noqa: D102 return DataStreamReader(self) @property - def sparkContext(self) -> SparkContext: + def sparkContext(self) -> SparkContext: # noqa: D102 return self._context @property - def streams(self) -> Any: + def streams(self) -> Any: # noqa: D102 raise ContributionsAcceptedError @property - def udf(self) -> UDFRegistration: + def udf(self) -> UDFRegistration: # noqa: D102 return UDFRegistration(self) @property - def version(self) -> str: + def version(self) -> str: # noqa: D102 return "1.0.0" - class Builder: - def __init__(self) -> None: + class Builder: # noqa: D106 + def __init__(self) -> None: # noqa: D107 pass - def master(self, name: str) -> "SparkSession.Builder": + def master(self, name: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def appName(self, name: str) -> "SparkSession.Builder": + def appName(self, name: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def remote(self, url: str) -> "SparkSession.Builder": + def remote(self, url: str) -> "SparkSession.Builder": # noqa: D102 # no-op return self - def getOrCreate(self) -> "SparkSession": + def getOrCreate(self) -> "SparkSession": # noqa: D102 context = SparkContext("__ignored__") return SparkSession(context) - def config( + def config( # noqa: D102 self, key: Optional[str] = None, value: Optional[Any] = None, conf: Optional[SparkConf] = None ) -> "SparkSession.Builder": return self - def enableHiveSupport(self) -> "SparkSession.Builder": + def enableHiveSupport(self) -> "SparkSession.Builder": # noqa: D102 # no-op return self diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 201b889b..08b7cc30 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union # noqa: D100 from .types import StructType @@ -10,20 +10,20 @@ OptionalPrimitiveType = Optional[PrimitiveType] -class DataStreamWriter: - def __init__(self, dataframe: "DataFrame") -> None: +class DataStreamWriter: # noqa: D101 + def __init__(self, dataframe: "DataFrame") -> None: # noqa: D107 self.dataframe = dataframe - def toTable(self, table_name: str) -> None: + def toTable(self, table_name: str) -> None: # noqa: D102 # Should we register the dataframe or create a table from the contents? raise NotImplementedError -class DataStreamReader: - def __init__(self, session: "SparkSession") -> None: +class DataStreamReader: # noqa: D101 + def __init__(self, session: "SparkSession") -> None: # noqa: D107 self.session = session - def load( + def load( # noqa: D102 self, path: Optional[str] = None, format: Optional[str] = None, diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index 446eac97..1773eb9e 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import cast # noqa: D100 from duckdb.typing import DuckDBPyType @@ -74,7 +74,7 @@ } -def convert_nested_type(dtype: DuckDBPyType) -> DataType: +def convert_nested_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 id = dtype.id if id == "list" or id == "array": children = dtype.children @@ -89,7 +89,7 @@ def convert_nested_type(dtype: DuckDBPyType) -> DataType: raise NotImplementedError -def convert_type(dtype: DuckDBPyType) -> DataType: +def convert_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 id = dtype.id if id in ["list", "struct", "map", "array"]: return convert_nested_type(dtype) @@ -102,6 +102,6 @@ def convert_type(dtype: DuckDBPyType) -> DataType: return spark_type() -def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: +def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: # noqa: D103 fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types])] return StructType(fields) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 606f792c..ad74cd98 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -1,4 +1,4 @@ -# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'spark' folder. +# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'spark' folder. # noqa: D100 import calendar import datetime @@ -66,32 +66,32 @@ class DataType: """Base class for data types.""" - def __init__(self, duckdb_type) -> None: + def __init__(self, duckdb_type) -> None: # noqa: D107 self.duckdb_type = duckdb_type - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return self.__class__.__name__ + "()" - def __hash__(self) -> int: + def __hash__(self) -> int: # noqa: D105 return hash(str(self)) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object) -> bool: # noqa: D105 return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - def __ne__(self, other: object) -> bool: + def __ne__(self, other: object) -> bool: # noqa: D105 return not self.__eq__(other) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return cls.__name__[:-4].lower() - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return self.typeName() - def jsonValue(self) -> Union[str, dict[str, Any]]: + def jsonValue(self) -> Union[str, dict[str, Any]]: # noqa: D102 raise ContributionsAcceptedError - def json(self) -> str: + def json(self) -> str: # noqa: D102 raise ContributionsAcceptedError def needConversion(self) -> bool: @@ -129,11 +129,11 @@ class NullType(DataType, metaclass=DataTypeSingleton): The data type representing None, used for the types that cannot be inferred. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("NULL")) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "void" @@ -158,54 +158,54 @@ class FractionalType(NumericType): class StringType(AtomicType, metaclass=DataTypeSingleton): """String data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("VARCHAR")) class BitstringType(AtomicType, metaclass=DataTypeSingleton): """Bitstring data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BIT")) class UUIDType(AtomicType, metaclass=DataTypeSingleton): """UUID data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UUID")) class BinaryType(AtomicType, metaclass=DataTypeSingleton): """Binary (byte array) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BLOB")) class BooleanType(AtomicType, metaclass=DataTypeSingleton): """Boolean data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BOOLEAN")) class DateType(AtomicType, metaclass=DataTypeSingleton): """Date (datetime.date) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("DATE")) EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, d: datetime.date) -> int: + def toInternal(self, d: datetime.date) -> int: # noqa: D102 if d is not None: return d.toordinal() - self.EPOCH_ORDINAL - def fromInternal(self, v: int) -> datetime.date: + def fromInternal(self, v: int) -> datetime.date: # noqa: D102 if v is not None: return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) @@ -213,22 +213,22 @@ def fromInternal(self, v: int) -> datetime.date: class TimestampType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMPTZ")) @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamptz" - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 if dt is not None: seconds = calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple()) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 if ts is not None: # using int to avoid precision loss in float return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) @@ -237,22 +237,22 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with microsecond precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 if dt is not None: seconds = calendar.timegm(dt.timetuple()) return int(seconds) * 1000000 + dt.microsecond - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 if ts is not None: # using int to avoid precision loss in float return datetime.datetime.utcfromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) @@ -261,60 +261,60 @@ def fromInternal(self, ts: int) -> datetime.datetime: class TimestampSecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with second precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_S")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_s" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError class TimestampMilisecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with milisecond precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_MS")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_ms" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError class TimestampNanosecondNTZType(AtomicType, metaclass=DataTypeSingleton): """Timestamp (datetime.datetime) data type without timezone information with nanosecond precision.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMESTAMP_NS")) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True @classmethod - def typeName(cls) -> str: + def typeName(cls) -> str: # noqa: D102 return "timestamp_ns" - def toInternal(self, dt: datetime.datetime) -> int: + def toInternal(self, dt: datetime.datetime) -> int: # noqa: D102 raise ContributionsAcceptedError - def fromInternal(self, ts: int) -> datetime.datetime: + def fromInternal(self, ts: int) -> datetime.datetime: # noqa: D102 raise ContributionsAcceptedError @@ -338,90 +338,90 @@ class DecimalType(FractionalType): the number of digits on right side of dot. (default: 0) """ - def __init__(self, precision: int = 10, scale: int = 0) -> None: + def __init__(self, precision: int = 10, scale: int = 0) -> None: # noqa: D107 super().__init__(duckdb.decimal_type(precision, scale)) self.precision = precision self.scale = scale self.hasPrecisionInfo = True # this is a public API - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "decimal(%d,%d)" % (self.precision, self.scale) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "DecimalType(%d,%d)" % (self.precision, self.scale) class DoubleType(FractionalType, metaclass=DataTypeSingleton): """Double data type, representing double precision floats.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("DOUBLE")) class FloatType(FractionalType, metaclass=DataTypeSingleton): """Float data type, representing single precision floats.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("FLOAT")) class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TINYINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "tinyint" class UnsignedByteType(IntegralType): """Unsigned byte data type, i.e. a unsigned integer in a single byte.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UTINYINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "utinyint" class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("SMALLINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "smallint" class UnsignedShortType(IntegralType): """Unsigned short data type, i.e. a unsigned 16-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("USMALLINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "usmallint" class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("INTEGER")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "integer" class UnsignedIntegerType(IntegralType): """Unsigned int data type, i.e. a unsigned 32-bit integer.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UINTEGER")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "uinteger" @@ -432,10 +432,10 @@ class LongType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BIGINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "bigint" @@ -446,10 +446,10 @@ class UnsignedLongType(IntegralType): please use :class:`HugeIntegerType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UBIGINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "ubigint" @@ -460,10 +460,10 @@ class HugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("HUGEINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "hugeint" @@ -474,30 +474,30 @@ class UnsignedHugeIntegerType(IntegralType): please use :class:`DecimalType`. """ - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("UHUGEINT")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "uhugeint" class TimeType(IntegralType): """Time (datetime.time) data type.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIMETZ")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "timetz" class TimeNTZType(IntegralType): """Time (datetime.time) data type without timezone information.""" - def __init__(self) -> None: + def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("TIME")) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "time" @@ -518,7 +518,7 @@ class DayTimeIntervalType(AtomicType): _inverted_fields = dict(zip(_fields.values(), _fields.keys())) - def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: + def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: # noqa: D107 super().__init__(DuckDBPyType("INTERVAL")) if startField is None and endField is None: # Default matched to scala side. @@ -544,17 +544,17 @@ def _str_repr(self) -> str: simpleString = _str_repr - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, dt: datetime.timedelta) -> Optional[int]: + def toInternal(self, dt: datetime.timedelta) -> Optional[int]: # noqa: D102 if dt is not None: return (math.floor(dt.total_seconds()) * 1000000) + dt.microseconds - def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: + def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: # noqa: D102 if micros is not None: return datetime.timedelta(microseconds=micros) @@ -577,7 +577,7 @@ class ArrayType(DataType): False """ - def __init__(self, elementType: DataType, containsNull: bool = True) -> None: + def __init__(self, elementType: DataType, containsNull: bool = True) -> None: # noqa: D107 super().__init__(duckdb.list_type(elementType.duckdb_type)) assert isinstance(elementType, DataType), "elementType %s should be an instance of %s" % ( elementType, @@ -586,21 +586,21 @@ def __init__(self, elementType: DataType, containsNull: bool = True) -> None: self.elementType = elementType self.containsNull = containsNull - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "array<%s>" % self.elementType.simpleString() - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "ArrayType(%s, %s)" % (self.elementType, str(self.containsNull)) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.elementType.needConversion() - def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: + def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.toInternal(v) for v in obj] - def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: + def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.fromInternal(v) for v in obj] @@ -630,7 +630,7 @@ class MapType(DataType): False """ - def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True) -> None: + def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True) -> None: # noqa: D107 super().__init__(duckdb.map_type(keyType.duckdb_type, valueType.duckdb_type)) assert isinstance(keyType, DataType), "keyType %s should be an instance of %s" % ( keyType, @@ -644,28 +644,28 @@ def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bo self.valueType = valueType self.valueContainsNull = valueContainsNull - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "map<%s,%s>" % ( self.keyType.simpleString(), self.valueType.simpleString(), ) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "MapType(%s, %s, %s)" % ( self.keyType, self.valueType, str(self.valueContainsNull), ) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.keyType.needConversion() or self.valueType.needConversion() - def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: + def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 if not self.needConversion(): return obj return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) for k, v in obj.items()) - def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: + def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 if not self.needConversion(): return obj return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()) @@ -693,7 +693,7 @@ class StructField(DataType): False """ - def __init__( + def __init__( # noqa: D107 self, name: str, dataType: DataType, @@ -711,26 +711,26 @@ def __init__( self.nullable = nullable self.metadata = metadata or {} - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "%s:%s" % (self.name, self.dataType.simpleString()) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "StructField('%s', %s, %s)" % ( self.name, self.dataType, str(self.nullable), ) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 return self.dataType.needConversion() - def toInternal(self, obj: T) -> T: + def toInternal(self, obj: T) -> T: # noqa: D102 return self.dataType.toInternal(obj) - def fromInternal(self, obj: T) -> T: + def fromInternal(self, obj: T) -> T: # noqa: D102 return self.dataType.fromInternal(obj) - def typeName(self) -> str: # type: ignore[override] + def typeName(self) -> str: # type: ignore[override] # noqa: D102 msg = "StructField does not have typeName. Use typeName on its type explicitly instead." raise TypeError(msg) @@ -766,7 +766,7 @@ class StructType(DataType): def _update_internal_duckdb_type(self): self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) - def __init__(self, fields: Optional[list[StructField]] = None) -> None: + def __init__(self, fields: Optional[list[StructField]] = None) -> None: # noqa: D107 if not fields: self.fields = [] self.names = [] @@ -836,7 +836,7 @@ def add( >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - """ # noqa: D205 + """ # noqa: D205, D415 if isinstance(field, StructField): self.fields.append(field) self.names.append(field.name) @@ -882,16 +882,16 @@ def __getitem__(self, key: Union[str, int]) -> StructField: msg = "StructType keys should be strings, integers or slices" raise TypeError(msg) - def simpleString(self) -> str: + def simpleString(self) -> str: # noqa: D102 return "struct<%s>" % (",".join(f.simpleString() for f in self)) - def __repr__(self) -> str: + def __repr__(self) -> str: # noqa: D105 return "StructType([%s])" % ", ".join(str(field) for field in self) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: Any) -> bool: # noqa: D105 return item in self.names - def extract_types_and_names(self) -> tuple[list[str], list[str]]: + def extract_types_and_names(self) -> tuple[list[str], list[str]]: # noqa: D102 names = [] types = [] for f in self.fields: @@ -910,11 +910,11 @@ def fieldNames(self) -> list[str]: """ return list(self.names) - def needConversion(self) -> bool: + def needConversion(self) -> bool: # noqa: D102 # We need convert Row()/namedtuple into tuple() return True - def toInternal(self, obj: tuple) -> tuple: + def toInternal(self, obj: tuple) -> tuple: # noqa: D102 if obj is None: return @@ -946,7 +946,7 @@ def toInternal(self, obj: tuple) -> tuple: else: raise ValueError("Unexpected tuple %r with StructType" % obj) - def fromInternal(self, obj: tuple) -> "Row": + def fromInternal(self, obj: tuple) -> "Row": # noqa: D102 if obj is None: return if isinstance(obj, Row): @@ -1125,7 +1125,7 @@ class Row(tuple): >>> row2 = Row(name="Alice", age=11) >>> row1 == row2 True - """ # noqa: D205 + """ # noqa: D205, D415 @overload def __new__(cls, *args: str) -> "Row": ... @@ -1133,7 +1133,7 @@ def __new__(cls, *args: str) -> "Row": ... @overload def __new__(cls, **kwargs: Any) -> "Row": ... - def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": + def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # noqa: D102 if args and kwargs: msg = "Can not use both args and kwargs to create Row" raise ValueError(msg) @@ -1192,7 +1192,7 @@ def conv(obj: Any) -> Any: else: return dict(zip(self.__fields__, self)) - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: Any) -> bool: # noqa: D105 if hasattr(self, "__fields__"): return item in self.__fields__ else: @@ -1207,7 +1207,7 @@ def __call__(self, *args: Any) -> "Row": ) return _create_row(self, args) - def __getitem__(self, item: Any) -> Any: + def __getitem__(self, item: Any) -> Any: # noqa: D105 if isinstance(item, (int, slice)): return super(Row, self).__getitem__(item) try: @@ -1220,7 +1220,7 @@ def __getitem__(self, item: Any) -> Any: except ValueError: raise ValueError(item) - def __getattr__(self, item: str) -> Any: + def __getattr__(self, item: str) -> Any: # noqa: D105 if item.startswith("__"): raise AttributeError(item) try: @@ -1233,7 +1233,7 @@ def __getattr__(self, item: str) -> Any: except ValueError: raise AttributeError(item) - def __setattr__(self, key: Any, value: Any) -> None: + def __setattr__(self, key: Any, value: Any) -> None: # noqa: D105 if key != "__fields__": msg = "Row is read-only" raise RuntimeError(msg) diff --git a/duckdb/experimental/spark/sql/udf.py b/duckdb/experimental/spark/sql/udf.py index 389d43ab..7437ed6b 100644 --- a/duckdb/experimental/spark/sql/udf.py +++ b/duckdb/experimental/spark/sql/udf.py @@ -1,4 +1,4 @@ -# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ +# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100 from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union from .types import DataType @@ -10,11 +10,11 @@ UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike") -class UDFRegistration: - def __init__(self, sparkSession: "SparkSession") -> None: +class UDFRegistration: # noqa: D101 + def __init__(self, sparkSession: "SparkSession") -> None: # noqa: D107 self.sparkSession = sparkSession - def register( + def register( # noqa: D102 self, name: str, f: Union[Callable[..., Any], "UserDefinedFunctionLike"], @@ -22,7 +22,7 @@ def register( ) -> "UserDefinedFunctionLike": self.sparkSession.conn.create_function(name, f, return_type=returnType) - def registerJavaFunction( + def registerJavaFunction( # noqa: D102 self, name: str, javaClassName: str, @@ -30,7 +30,7 @@ def registerJavaFunction( ) -> None: raise NotImplementedError - def registerJavaUDAF(self, name: str, javaClassName: str) -> None: + def registerJavaUDAF(self, name: str, javaClassName: str) -> None: # noqa: D102 raise NotImplementedError diff --git a/duckdb/filesystem.py b/duckdb/filesystem.py index 77838103..1775a9cf 100644 --- a/duckdb/filesystem.py +++ b/duckdb/filesystem.py @@ -1,4 +1,4 @@ -from io import TextIOBase +from io import TextIOBase # noqa: D100 from fsspec import AbstractFileSystem from fsspec.implementations.memory import MemoryFile, MemoryFileSystem @@ -6,17 +6,17 @@ from .bytes_io_wrapper import BytesIOWrapper -def is_file_like(obj): +def is_file_like(obj): # noqa: D103 # We only care that we can read from the file return hasattr(obj, "read") and hasattr(obj, "seek") -class ModifiedMemoryFileSystem(MemoryFileSystem): +class ModifiedMemoryFileSystem(MemoryFileSystem): # noqa: D101 protocol = ("DUCKDB_INTERNAL_OBJECTSTORE",) # defer to the original implementation that doesn't hardcode the protocol _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__) - def add_file(self, object, path): + def add_file(self, object, path): # noqa: D102 if not is_file_like(object): msg = "Can not read from a non file-like object" raise ValueError(msg) diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py index b1ddab19..a1d69d39 100644 --- a/duckdb/functional/__init__.py +++ b/duckdb/functional/__init__.py @@ -1,3 +1,3 @@ -from _duckdb.functional import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType +from _duckdb.functional import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType # noqa: D104 __all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index 69e1e7ea..a11339bb 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -1,4 +1,4 @@ -import datetime +import datetime # noqa: D100 import json from collections.abc import Iterator from decimal import Decimal diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index 88d96350..dedb30a3 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -1,4 +1,4 @@ -import argparse +import argparse # noqa: D100 import json import os import re @@ -76,63 +76,63 @@ """ -class NodeTiming: - def __init__(self, phase: str, time: float) -> object: +class NodeTiming: # noqa: D101 + def __init__(self, phase: str, time: float) -> object: # noqa: D107 self.phase = phase self.time = time # percentage is determined later. self.percentage = 0 - def calculate_percentage(self, total_time: float) -> None: + def calculate_percentage(self, total_time: float) -> None: # noqa: D102 self.percentage = self.time / total_time - def combine_timing(l: object, r: object) -> object: + def combine_timing(l: object, r: object) -> object: # noqa: D102 # TODO: can only add timings for same-phase nodes total_time = l.time + r.time return NodeTiming(l.phase, total_time) -class AllTimings: - def __init__(self) -> None: +class AllTimings: # noqa: D101 + def __init__(self) -> None: # noqa: D107 self.phase_to_timings = {} - def add_node_timing(self, node_timing: NodeTiming): + def add_node_timing(self, node_timing: NodeTiming): # noqa: D102 if node_timing.phase in self.phase_to_timings: self.phase_to_timings[node_timing.phase].append(node_timing) return self.phase_to_timings[node_timing.phase] = [node_timing] - def get_phase_timings(self, phase: str): + def get_phase_timings(self, phase: str): # noqa: D102 return self.phase_to_timings[phase] - def get_summary_phase_timings(self, phase: str): + def get_summary_phase_timings(self, phase: str): # noqa: D102 return reduce(NodeTiming.combine_timing, self.phase_to_timings[phase]) - def get_phases(self): + def get_phases(self): # noqa: D102 phases = list(self.phase_to_timings.keys()) phases.sort(key=lambda x: (self.get_summary_phase_timings(x)).time) phases.reverse() return phases - def get_sum_of_all_timings(self): + def get_sum_of_all_timings(self): # noqa: D102 total_timing_sum = 0 for phase in self.phase_to_timings.keys(): total_timing_sum += self.get_summary_phase_timings(phase).time return total_timing_sum -def open_utf8(fpath: str, flags: str) -> object: +def open_utf8(fpath: str, flags: str) -> object: # noqa: D103 return open(fpath, flags, encoding="utf8") -def get_child_timings(top_node: object, query_timings: object) -> str: +def get_child_timings(top_node: object, query_timings: object) -> str: # noqa: D103 node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"])) query_timings.add_node_timing(node_timing) for child in top_node["children"]: get_child_timings(child, query_timings) -def get_pink_shade_hex(fraction: float): +def get_pink_shade_hex(fraction: float): # noqa: D103 fraction = max(0, min(1, fraction)) # Define the RGB values for very light pink (almost white) and dark pink @@ -148,7 +148,7 @@ def get_pink_shade_hex(fraction: float): return f"#{r:02x}{g:02x}{b:02x}" -def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: +def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: # noqa: D103 node_style = f"background-color: {get_pink_shade_hex(float(result) / cpu_time)};" body = f'' @@ -167,7 +167,7 @@ def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, return body -def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: +def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: # noqa: D103 node_prefix_html = "
    • " node_suffix_html = "
    • " @@ -206,7 +206,7 @@ def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: # For generating the table in the top left. -def generate_timing_html(graph_json: object, query_timings: object) -> object: +def generate_timing_html(graph_json: object, query_timings: object) -> object: # noqa: D103 json_graph = json.loads(graph_json) gather_timing_information(json_graph, query_timings) total_time = float(json_graph.get("operator_timing") or json_graph.get("latency")) @@ -244,7 +244,7 @@ def generate_timing_html(graph_json: object, query_timings: object) -> object: return table_head + table_body -def generate_tree_html(graph_json: object) -> str: +def generate_tree_html(graph_json: object) -> str: # noqa: D103 json_graph = json.loads(graph_json) cpu_time = float(json_graph["cpu_time"]) tree_prefix = '
      \n
        ' @@ -255,7 +255,7 @@ def generate_tree_html(graph_json: object) -> str: return tree_prefix + tree_body + tree_suffix -def generate_ipython(json_input: str) -> str: +def generate_ipython(json_input: str) -> str: # noqa: D103 from IPython.core.display import HTML html_output = generate_html(json_input, False) @@ -268,7 +268,7 @@ def generate_ipython(json_input: str) -> str: ) -def generate_style_html(graph_json: str, include_meta_info: bool) -> None: +def generate_style_html(graph_json: str, include_meta_info: bool) -> None: # noqa: D103 treeflex_css = '\n' css = "