Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,10 +673,10 @@ def _map_sql_type( # pylint: disable=too-many-arguments,too-many-positional-arg

if isinstance(param, datetime.time):
return (
ddbc_sql_const.SQL_TIME.value,
ddbc_sql_const.SQL_C_TYPE_TIME.value,
8,
0,
ddbc_sql_const.SQL_TYPE_TIME.value,
ddbc_sql_const.SQL_C_CHAR.value,
16,
6,
False,
)

Expand Down Expand Up @@ -941,6 +941,16 @@ def _create_parameter_types_list( # pylint: disable=too-many-arguments,too-many
parameter, parameters_list, i, min_val=min_val, max_val=max_val
)

# If TIME values are being bound via text C-types, normalize them to a
# textual representation expected by SQL_C_CHAR/SQL_C_WCHAR binding.
if isinstance(parameter, datetime.time) and c_type in (
ddbc_sql_const.SQL_C_CHAR.value,
ddbc_sql_const.SQL_C_WCHAR.value,
):
time_text = parameter.isoformat(timespec="microseconds")
parameters_list[i] = time_text
column_size = max(column_size, len(time_text))

paraminfo.paramCType = c_type
paraminfo.paramSQLType = sql_type
paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value
Expand Down Expand Up @@ -2250,6 +2260,12 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
for i, val in enumerate(processed_row):
if val is None:
continue
if isinstance(val, datetime.time) and parameters_type[i].paramCType in (
ddbc_sql_const.SQL_C_CHAR.value,
ddbc_sql_const.SQL_C_WCHAR.value,
):
processed_row[i] = val.isoformat(timespec="microseconds")
continue
if (
isinstance(val, decimal.Decimal)
and parameters_type[i].paramSQLType == ddbc_sql_const.SQL_VARCHAR.value
Expand Down
136 changes: 105 additions & 31 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "connection/connection_pool.h"
#include "logger_bridge.hpp"

#include <cctype>
#include <cstdint>
#include <cstring> // For std::memcpy
#include <filesystem>
Expand All @@ -28,6 +29,7 @@
#define SQL_MAX_NUMERIC_LEN 16
#define SQL_SS_XML (-152)
#define SQL_SS_UDT (-151)
#define SQL_TIME_TEXT_MAX_LEN 32

#define STRINGIFY_FOR_CASE(x) \
case x: \
Expand All @@ -53,6 +55,79 @@ inline std::string GetEffectiveCharDecoding(const std::string& userEncoding) {
#endif
}

namespace PythonObjectCache {
py::object get_time_class();
}

inline py::object ParseSqlTimeTextToPythonObject(const char* timeText, SQLLEN timeTextLen) {
if (!timeText || (timeTextLen <= 0 && timeTextLen != SQL_NO_TOTAL)) {
return py::none();
}

size_t len;
if (timeTextLen == SQL_NO_TOTAL) {
// When the driver reports SQL_NO_TOTAL, the buffer may not be null-terminated.
// Bound the scan to the maximum expected TIME/TIME2 text length.
const void* nul = std::memchr(timeText, '\0', SQL_TIME_TEXT_MAX_LEN - 1);
len = nul ? static_cast<size_t>(static_cast<const char*>(nul) - timeText)
: static_cast<size_t>(SQL_TIME_TEXT_MAX_LEN - 1);
} else {
len = static_cast<size_t>(timeTextLen);
if (len > SQL_TIME_TEXT_MAX_LEN - 1) {
len = SQL_TIME_TEXT_MAX_LEN - 1;
}
}

std::string value(timeText, len);

size_t start = value.find_first_not_of(" \t\r\n");
if (start == std::string::npos) {
return py::none();
}
size_t end = value.find_last_not_of(" \t\r\n");
value = value.substr(start, end - start + 1);

size_t firstColon = value.find(':');
size_t secondColon =
(firstColon == std::string::npos) ? std::string::npos : value.find(':', firstColon + 1);
if (firstColon == std::string::npos || secondColon == std::string::npos) {
ThrowStdException("Failed to parse TIME/TIME2 value: missing ':' separators");
}

int hour = std::stoi(value.substr(0, firstColon));
int minute = std::stoi(value.substr(firstColon + 1, secondColon - firstColon - 1));

size_t dotPos = value.find('.', secondColon + 1);
int second = 0;
int microsecond = 0;

if (dotPos == std::string::npos) {
second = std::stoi(value.substr(secondColon + 1));
} else {
second = std::stoi(value.substr(secondColon + 1, dotPos - secondColon - 1));
std::string frac = value.substr(dotPos + 1);

size_t digitCount = 0;
while (digitCount < frac.size() &&
std::isdigit(static_cast<unsigned char>(frac[digitCount]))) {
++digitCount;
}
frac = frac.substr(0, digitCount);

if (frac.size() > 6) {
frac = frac.substr(0, 6);
}
while (frac.size() < 6) {
frac.push_back('0');
}
if (!frac.empty()) {
microsecond = std::stoi(frac);
}
}

return PythonObjectCache::get_time_class()(hour, minute, second, microsecond);
}

//-------------------------------------------------------------------------------------------------
//-------------------------------------------------------------------------------------------------
// Logging Infrastructure:
Expand Down Expand Up @@ -3244,17 +3319,20 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
}
break;
}
case SQL_TIME:
case SQL_TYPE_TIME:
case SQL_SS_TIME2: {
SQL_TIME_STRUCT timeValue;
ret =
SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIME, &timeValue, sizeof(timeValue), NULL);
char timeTextBuffer[SQL_TIME_TEXT_MAX_LEN] = {0};
SQLLEN timeDataLen = 0;
ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, timeTextBuffer, sizeof(timeTextBuffer),
&timeDataLen);
if (SQL_SUCCEEDED(ret)) {
row.append(PythonObjectCache::get_time_class()(timeValue.hour, timeValue.minute,
timeValue.second));
if (timeDataLen == SQL_NULL_DATA) {
// Normal NULL value: append None without logging an error.
row.append(py::none());
} else {
row.append(ParseSqlTimeTextToPythonObject(timeTextBuffer, timeDataLen));
}
} else {
LOG("SQLGetData: Error retrieving SQL_TYPE_TIME for column "
LOG("SQLGetData: Error retrieving SQL_SS_TIME2 for column "
"%d - SQLRETURN=%d",
i, ret);
row.append(py::none());
Expand Down Expand Up @@ -3585,13 +3663,10 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_DATE, buffers.dateBuffers[col - 1].data(),
sizeof(SQL_DATE_STRUCT), buffers.indicators[col - 1].data());
break;
case SQL_TIME:
case SQL_TYPE_TIME:
case SQL_SS_TIME2:
buffers.timeBuffers[col - 1].resize(fetchSize);
ret =
SQLBindCol_ptr(hStmt, col, SQL_C_TYPE_TIME, buffers.timeBuffers[col - 1].data(),
sizeof(SQL_TIME_STRUCT), buffers.indicators[col - 1].data());
buffers.charBuffers[col - 1].resize(fetchSize * SQL_TIME_TEXT_MAX_LEN);
ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(),
SQL_TIME_TEXT_MAX_LEN, buffers.indicators[col - 1].data());
break;
case SQL_GUID:
buffers.guidBuffers[col - 1].resize(fetchSize);
Expand Down Expand Up @@ -3895,16 +3970,12 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
PyList_SET_ITEM(row, col - 1, dateObj);
break;
}
case SQL_TIME:
case SQL_TYPE_TIME:
case SQL_SS_TIME2: {
PyObject* timeObj =
PythonObjectCache::get_time_class()(buffers.timeBuffers[col - 1][i].hour,
buffers.timeBuffers[col - 1][i].minute,
buffers.timeBuffers[col - 1][i].second)
.release()
.ptr();
PyList_SET_ITEM(row, col - 1, timeObj);
const char* rawData = reinterpret_cast<const char*>(
&buffers.charBuffers[col - 1][i * SQL_TIME_TEXT_MAX_LEN]);
SQLLEN timeDataLen = buffers.indicators[col - 1][i];
py::object timeObj = ParseSqlTimeTextToPythonObject(rawData, timeDataLen);
PyList_SET_ITEM(row, col - 1, timeObj.release().ptr());
break;
}
case SQL_SS_TIMESTAMPOFFSET: {
Expand Down Expand Up @@ -4034,10 +4105,8 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
case SQL_TYPE_DATE:
rowSize += sizeof(SQL_DATE_STRUCT);
break;
case SQL_TIME:
case SQL_TYPE_TIME:
case SQL_SS_TIME2:
rowSize += sizeof(SQL_TIME_STRUCT);
rowSize += SQL_TIME_TEXT_MAX_LEN;
break;
case SQL_GUID:
rowSize += sizeof(SQLGUID);
Expand All @@ -4048,7 +4117,8 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
break;
case SQL_SS_UDT:
rowSize += (static_cast<SQLLEN>(columnSize) == SQL_NO_TOTAL || columnSize == 0)
? SQL_MAX_LOB_SIZE : columnSize;
? SQL_MAX_LOB_SIZE
: columnSize;
break;
case SQL_BINARY:
case SQL_VARBINARY:
Expand Down Expand Up @@ -4112,8 +4182,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR ||
dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML ||
dataType == SQL_SS_UDT) &&
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
lobColumns.push_back(i + 1); // 1-based
}
Expand Down Expand Up @@ -4252,8 +4321,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows,

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR ||
dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML ||
dataType == SQL_SS_UDT) &&
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
lobColumns.push_back(i + 1); // 1-based
}
Expand Down Expand Up @@ -4432,6 +4500,12 @@ PYBIND11_MODULE(ddbc_bindings, m) {
// Expose architecture-specific constants
m.attr("ARCHITECTURE") = ARCHITECTURE;

// Test helper: expose time-text parser for unit testing edge cases
m.def("_test_parse_time_text", &ParseSqlTimeTextToPythonObject,
"Parse a SQL TIME/TIME2 text buffer into a Python datetime.time object (test helper)",
py::arg("timeText"), py::arg("timeTextLen"));
m.attr("SQL_NO_TOTAL") = static_cast<int>(SQL_NO_TOTAL);

// Expose the C++ functions to Python
m.def("ThrowStdException", &ThrowStdException);
m.def("GetDriverPathCpp", &GetDriverPathCpp, "Get the path to the ODBC driver");
Expand Down
Loading
Loading