Skip to content

Commit 3b78409

Browse files
authored
gh-87138: convert SHA-3 object type to heap type (GH-127670)
1 parent 8fa5ece commit 3b78409

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

Modules/sha3module.c

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ typedef struct {
7171
static SHA3object *
7272
newSHA3object(PyTypeObject *type)
7373
{
74-
SHA3object *newobj;
75-
newobj = (SHA3object *)PyObject_New(SHA3object, type);
74+
SHA3object *newobj = PyObject_GC_New(SHA3object, type);
7675
if (newobj == NULL) {
7776
return NULL;
7877
}
7978
HASHLIB_INIT_MUTEX(newobj);
8079

80+
PyObject_GC_Track(newobj);
8181
return newobj;
8282
}
8383

@@ -166,15 +166,32 @@ py_sha3_new_impl(PyTypeObject *type, PyObject *data, int usedforsecurity)
166166

167167
/* Internal methods for a hash object */
168168

169+
static int
170+
SHA3_clear(SHA3object *self)
171+
{
172+
if (self->hash_state != NULL) {
173+
Hacl_Hash_SHA3_free(self->hash_state);
174+
self->hash_state = NULL;
175+
}
176+
return 0;
177+
}
178+
169179
static void
170180
SHA3_dealloc(SHA3object *self)
171181
{
172-
Hacl_Hash_SHA3_free(self->hash_state);
173182
PyTypeObject *tp = Py_TYPE(self);
174-
PyObject_Free(self);
183+
PyObject_GC_UnTrack(self);
184+
(void)SHA3_clear(self);
185+
tp->tp_free(self);
175186
Py_DECREF(tp);
176187
}
177188

189+
static int
190+
SHA3_traverse(PyObject *self, visitproc visit, void *arg)
191+
{
192+
Py_VISIT(Py_TYPE(self));
193+
return 0;
194+
}
178195

179196
/* External methods for a hash object */
180197

@@ -335,6 +352,7 @@ static PyObject *
335352
SHA3_get_capacity_bits(SHA3object *self, void *closure)
336353
{
337354
uint32_t rate = Hacl_Hash_SHA3_block_len(self->hash_state) * 8;
355+
assert(rate <= 1600);
338356
int capacity = 1600 - rate;
339357
return PyLong_FromLong(capacity);
340358
}
@@ -366,12 +384,14 @@ static PyGetSetDef SHA3_getseters[] = {
366384

367385
#define SHA3_TYPE_SLOTS(type_slots_obj, type_doc, type_methods, type_getseters) \
368386
static PyType_Slot type_slots_obj[] = { \
387+
{Py_tp_clear, SHA3_clear}, \
369388
{Py_tp_dealloc, SHA3_dealloc}, \
389+
{Py_tp_traverse, SHA3_traverse}, \
370390
{Py_tp_doc, (char*)type_doc}, \
371391
{Py_tp_methods, type_methods}, \
372392
{Py_tp_getset, type_getseters}, \
373393
{Py_tp_new, py_sha3_new}, \
374-
{0,0} \
394+
{0, NULL} \
375395
}
376396

377397
// Using _PyType_GetModuleState() on these types is safe since they
@@ -380,7 +400,8 @@ static PyGetSetDef SHA3_getseters[] = {
380400
static PyType_Spec type_spec_obj = { \
381401
.name = "_sha3." type_name, \
382402
.basicsize = sizeof(SHA3object), \
383-
.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, \
403+
.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE \
404+
| Py_TPFLAGS_HAVE_GC, \
384405
.slots = type_slots \
385406
}
386407

@@ -444,9 +465,7 @@ _SHAKE_digest(SHA3object *self, unsigned long digestlen, int hex)
444465
result = PyBytes_FromStringAndSize((const char *)digest,
445466
digestlen);
446467
}
447-
if (digest != NULL) {
448-
PyMem_Free(digest);
449-
}
468+
PyMem_Free(digest);
450469
return result;
451470
}
452471

@@ -563,7 +582,7 @@ _sha3_clear(PyObject *module)
563582
static void
564583
_sha3_free(void *module)
565584
{
566-
_sha3_clear((PyObject *)module);
585+
(void)_sha3_clear((PyObject *)module);
567586
}
568587

569588
static int

0 commit comments

Comments
 (0)