Skip to content
Prev Previous commit
Next Next commit
Refactor _zstd_load_c_dict and _zstd_load_d_dict.
  • Loading branch information
serhiy-storchaka committed May 31, 2025
commit cffe18b44b1491aaa18879b0be3e0930214a61eb
37 changes: 36 additions & 1 deletion Modules/_zstd/_zstdmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "Python.h"

#include "_zstdmodule.h"
#include "zstddict.h"

#include <zstd.h> // ZSTD_*()
#include <zdict.h> // ZDICT_*()
Expand All @@ -20,6 +19,42 @@ module _zstd
#include "clinic/_zstdmodule.c.h"


ZstdDict *
_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype)
{
if (state == NULL) {
return NULL;
}

/* Check ZstdDict */
if (PyObject_TypeCheck(dict, state->ZstdDict_type)) {
return (ZstdDict*)dict;
}

/* Check (ZstdDict, type) */
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2
&& PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type)
&& PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
{
int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
if (type == -1 && PyErr_Occurred()) {
return NULL;
}
if (type == DICT_TYPE_DIGESTED
|| type == DICT_TYPE_UNDIGESTED
|| type == DICT_TYPE_PREFIX)
{
*ptype = type;
return (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
}
}

/* Wrong type */
PyErr_SetString(PyExc_TypeError,
"zstd_dict argument should be ZstdDict object.");
return NULL;
}

/* Format error message and set ZstdError. */
void
set_zstd_error(const _zstd_state* const state,
Expand Down
6 changes: 6 additions & 0 deletions Modules/_zstd/_zstdmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef ZSTD_MODULE_H
#define ZSTD_MODULE_H

#include "zstddict.h"

/* Type specs */
extern PyType_Spec zstd_dict_type_spec;
extern PyType_Spec zstd_compressor_type_spec;
Expand Down Expand Up @@ -43,6 +45,10 @@ typedef enum {
DICT_TYPE_PREFIX = 2
} dictionary_type;

extern ZstdDict *
_Py_parse_zstd_dict(const _zstd_state *state,
PyObject *dict, int *type);

/* Format error message and set ZstdError. */
extern void
set_zstd_error(const _zstd_state* const state,
Expand Down
54 changes: 9 additions & 45 deletions Modules/_zstd/compressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec"

#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
#include "internal/pycore_lock.h" // PyMutex_IsLocked

#include <stddef.h> // offsetof()
Expand Down Expand Up @@ -262,52 +261,17 @@ static int
_zstd_load_c_dict(ZstdCompressor *self, PyObject *dict)
{
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
/* When compressing, use undigested dictionary by default. */
int type = DICT_TYPE_UNDIGESTED;
ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type);
if (zd == NULL) {
return -1;
}
ZstdDict *zd;
int type, ret;

/* Check ZstdDict */
if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) {
/* When compressing, use undigested dictionary by default. */
zd = (ZstdDict*)dict;
type = DICT_TYPE_UNDIGESTED;
PyMutex_Lock(&zd->lock);
ret = _zstd_load_impl(self, zd, mod_state, type);
PyMutex_Unlock(&zd->lock);
return ret;
}

/* Check (ZstdDict, type) */
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
/* Check ZstdDict */
if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0),
mod_state->ZstdDict_type)
&& PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
{
type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
if (type == -1 && PyErr_Occurred()) {
return -1;
}
if (type == DICT_TYPE_DIGESTED
|| type == DICT_TYPE_UNDIGESTED
|| type == DICT_TYPE_PREFIX)
{
assert(type >= 0);
zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
PyMutex_Lock(&zd->lock);
ret = _zstd_load_impl(self, zd, mod_state, type);
PyMutex_Unlock(&zd->lock);
return ret;
}
}
}

/* Wrong type */
PyErr_SetString(PyExc_TypeError,
"zstd_dict argument should be ZstdDict object.");
return -1;
int ret;
PyMutex_Lock(&zd->lock);
ret = _zstd_load_impl(self, zd, mod_state, type);
PyMutex_Unlock(&zd->lock);
return ret;
}

/*[clinic input]
Expand Down
54 changes: 9 additions & 45 deletions Modules/_zstd/decompressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec"

#include "_zstdmodule.h"
#include "buffer.h"
#include "zstddict.h"
#include "internal/pycore_lock.h" // PyMutex_IsLocked

#include <stdbool.h> // bool
Expand Down Expand Up @@ -177,52 +176,17 @@ static int
_zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict)
{
_zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self));
if (mod_state == NULL) {
/* When decompressing, use digested dictionary by default. */
int type = DICT_TYPE_DIGESTED;
ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type);
if (zd == NULL) {
return -1;
}
ZstdDict *zd;
int type, ret;

/* Check ZstdDict */
if (PyObject_TypeCheck(dict, mod_state->ZstdDict_type)) {
/* When decompressing, use digested dictionary by default. */
zd = (ZstdDict*)dict;
type = DICT_TYPE_DIGESTED;
PyMutex_Lock(&zd->lock);
ret = _zstd_load_impl(self, zd, mod_state, type);
PyMutex_Unlock(&zd->lock);
return ret;
}

/* Check (ZstdDict, type) */
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) {
/* Check ZstdDict */
if (PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0),
mod_state->ZstdDict_type)
&& PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
{
type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
if (type == -1 && PyErr_Occurred()) {
return -1;
}
if (type == DICT_TYPE_DIGESTED
|| type == DICT_TYPE_UNDIGESTED
|| type == DICT_TYPE_PREFIX)
{
assert(type >= 0);
zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
PyMutex_Lock(&zd->lock);
ret = _zstd_load_impl(self, zd, mod_state, type);
PyMutex_Unlock(&zd->lock);
return ret;
}
}
}

/* Wrong type */
PyErr_SetString(PyExc_TypeError,
"zstd_dict argument should be ZstdDict object.");
return -1;
int ret;
PyMutex_Lock(&zd->lock);
ret = _zstd_load_impl(self, zd, mod_state, type);
PyMutex_Unlock(&zd->lock);
return ret;
}

/*
Expand Down
1 change: 0 additions & 1 deletion Modules/_zstd/zstddict.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec"
#include "Python.h"

#include "_zstdmodule.h"
#include "zstddict.h"
#include "clinic/zstddict.c.h"
#include "internal/pycore_lock.h" // PyMutex_IsLocked

Expand Down