Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432)

Refactor shared code in train_dict and finalize_dict
(cherry picked from commit c64a214)

Co-authored-by: Emma Smith <emma@emmatyping.dev>
  • Loading branch information
emmatyping authored and miss-islington committed May 21, 2025
commit 24c6f25273246735879ca15e474340f1f8e5735e
123 changes: 55 additions & 68 deletions Modules/_zstd/_zstdmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
return (_zstd_state *)state;
}

/* Validate samples_sizes against samples_bytes and convert it to a C array.
 *
 * samples_bytes  - bytes object holding all samples concatenated.
 * samples_sizes  - tuple of ints, one size per sample.
 * chunk_sizes    - out-parameter; on success receives an array of
 *                  Py_SIZE(samples_sizes) elements allocated with PyMem_New.
 *                  The caller owns it and must release it with PyMem_Free.
 *
 * Returns the number of samples on success.  On failure returns -1 with a
 * Python exception set; the array is already released and *chunk_sizes is
 * NULL, so the caller may either return immediately or run a cleanup path
 * that calls PyMem_Free(*chunk_sizes).
 */
static Py_ssize_t
calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
                        size_t **chunk_sizes)
{
    Py_ssize_t chunks_number;
    size_t total_size;
    size_t sizes_sum;
    int sum_overflow;
    Py_ssize_t i;

    /* Give the out-parameter a defined value on every exit path. */
    *chunk_sizes = NULL;

    chunks_number = Py_SIZE(samples_sizes);
    if ((size_t) chunks_number > UINT32_MAX) {
        PyErr_Format(PyExc_ValueError,
                     "The number of samples should be <= %u.", UINT32_MAX);
        return -1;
    }

    /* Prepare chunk_sizes */
    *chunk_sizes = PyMem_New(size_t, chunks_number);
    if (*chunk_sizes == NULL) {
        PyErr_NoMemory();
        return -1;
    }

    total_size = (size_t) Py_SIZE(samples_bytes);
    sizes_sum = 0;
    sum_overflow = 0;
    for (i = 0; i < chunks_number; i++) {
        PyObject *size = PyTuple_GetItem(samples_sizes, i);
        size_t chunk_size = PyLong_AsSize_t(size);
        if (chunk_size == (size_t)-1 && PyErr_Occurred()) {
            /* SIZE_MAX is a size_t, so it must be formatted with %zu. */
            PyErr_Format(PyExc_ValueError,
                         "Items in samples_sizes should be an int "
                         "object, with a value between 0 and %zu.", SIZE_MAX);
            goto error;
        }
        (*chunk_sizes)[i] = chunk_size;

        /* Accumulate without wrapping: a wrapped sum could compare equal to
           the buffer length even though individual sizes exceed it.  Only
           record the mismatch here so that a conversion error on a later
           item is still the one reported. */
        if (sum_overflow || chunk_size > total_size - sizes_sum) {
            sum_overflow = 1;
        }
        else {
            sizes_sum += chunk_size;
        }
    }

    if (sum_overflow || sizes_sum != total_size) {
        PyErr_SetString(PyExc_ValueError,
                        "The samples size tuple doesn't match the concatenation's size.");
        goto error;
    }
    return chunks_number;

error:
    /* Release the array here; callers bail out without freeing it. */
    PyMem_Free(*chunk_sizes);
    *chunk_sizes = NULL;
    return -1;
}


/*[clinic input]
_zstd.train_dict
Expand All @@ -192,54 +235,25 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
PyObject *samples_sizes, Py_ssize_t dict_size)
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
{
// TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
// are pretty similar. We should see if we can refactor them to share that code.
Py_ssize_t chunks_number;
size_t *chunk_sizes = NULL;
PyObject *dst_dict_bytes = NULL;
size_t *chunk_sizes = NULL;
Py_ssize_t chunks_number;
size_t zstd_ret;
Py_ssize_t sizes_sum;
Py_ssize_t i;

/* Check arguments */
if (dict_size <= 0) {
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
return NULL;
}

chunks_number = Py_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
/* Check that the samples are valid and get their sizes */
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
&chunk_sizes);
if (chunks_number < 0)
{
return NULL;
}

/* Prepare chunk_sizes */
chunk_sizes = PyMem_New(size_t, chunks_number);
if (chunk_sizes == NULL) {
PyErr_NoMemory();
goto error;
}

sizes_sum = 0;
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
chunk_sizes[i] = PyLong_AsSize_t(size);
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
goto error;
}
sizes_sum += chunk_sizes[i];
}

if (sizes_sum != Py_SIZE(samples_bytes)) {
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the concatenation's size.");
goto error;
}

/* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) {
Expand Down Expand Up @@ -307,48 +321,21 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
PyObject *dst_dict_bytes = NULL;
size_t zstd_ret;
ZDICT_params_t params;
Py_ssize_t sizes_sum;
Py_ssize_t i;

/* Check arguments */
if (dict_size <= 0) {
PyErr_SetString(PyExc_ValueError, "dict_size argument should be positive number.");
return NULL;
}

chunks_number = Py_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
/* Check that the samples are valid and get their sizes */
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
&chunk_sizes);
if (chunks_number < 0)
{
return NULL;
}

/* Prepare chunk_sizes */
chunk_sizes = PyMem_New(size_t, chunks_number);
if (chunk_sizes == NULL) {
PyErr_NoMemory();
goto error;
}

sizes_sum = 0;
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
chunk_sizes[i] = PyLong_AsSize_t(size);
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
goto error;
}
sizes_sum += chunk_sizes[i];
}

if (sizes_sum != Py_SIZE(samples_bytes)) {
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the concatenation's size.");
goto error;
}

/* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) {
Expand Down
Loading