Skip to content

Commit e295d8c

Browse files
committed
Add custom_type_setup attribute
This allows for custom modifications to the PyHeapTypeObject prior to calling `PyType_Ready`. This may be used, for example, to define `tp_traverse` and `tp_clear` functions.
1 parent 614ca93 commit e295d8c

File tree

10 files changed

+243
-3
lines changed

10 files changed

+243
-3
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ set(PYBIND11_HEADERS
106106
include/pybind11/detail/class.h
107107
include/pybind11/detail/common.h
108108
include/pybind11/detail/descr.h
109+
include/pybind11/detail/function_view.h
109110
include/pybind11/detail/init.h
110111
include/pybind11/detail/internals.h
111112
include/pybind11/detail/type_caster_base.h

docs/advanced/classes.rst

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1247,7 +1247,7 @@ Accessing the type object
12471247

12481248
You can get the type object from a C++ class that has already been registered using:
12491249

1250-
.. code-block:: python
1250+
.. code-block:: cpp
12511251
12521252
py::type T_py = py::type::of<T>();
12531253
@@ -1259,3 +1259,40 @@ object, just like ``type(ob)`` in Python.
12591259
Other types, like ``py::type::of<int>()``, do not work, see :ref:`type-conversions`.
12601260

12611261
.. versionadded:: 2.6
1262+
1263+
Custom type setup
1264+
=================
1265+
1266+
For advanced use cases, such as setting certain type object slots, you may wish
1267+
to directly manipulate the `PyHeapTypeObject` corresponding to a ``py::class_``
1268+
definition.
1269+
1270+
You can do that using ``py::custom_type_setup``:
1271+
1272+
.. code-block:: cpp
1273+
1274+
struct OwnsPythonObjects {
1275+
py::object value = py::none();
1276+
};
1277+
py::class_<OwnsPythonObjects> cls(
1278+
m, "OwnsPythonObjects", py::custom_type_setup([](PyHeapTypeObject *heap_type) {
1279+
auto *type = &heap_type->ht_type;
1280+
type->tp_flags |= Py_TPFLAGS_HAVE_GC;
1281+
type->tp_traverse = +[](PyObject *self_base, visitproc visit, void *arg) {
1282+
auto &self = py::cast<OwnsPythonObjects&>(py::handle(self_base));
1283+
Py_VISIT(self.value.ptr());
1284+
return 0;
1285+
};
1286+
type->tp_clear = +[](PyObject *self_base) {
1287+
auto &self = py::cast<OwnsPythonObjects&>(py::handle(self_base));
1288+
// Use temporary object to ensure `self.value` remains valid
1289+
// while Python APIs are called.
1290+
py::object temp = py::none();
1291+
std::swap(temp, self.value);
1292+
return 0;
1293+
};
1294+
}));
1295+
cls.def(py::init([] { return OwnsPythonObjects{}; }));
1296+
cls.def_readwrite("value", &OwnsPythonObjects::value);
1297+
1298+
.. versionadded:: 2.8

include/pybind11/attr.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#pragma once
1212

1313
#include "cast.h"
14+
#include "detail/function_view.h"
1415

1516
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
1617

@@ -79,6 +80,23 @@ struct metaclass {
7980
explicit metaclass(handle value) : value(value) { }
8081
};
8182

83+
/// Specifies a custom callback with signature `void (PyHeapTypeObject*)` that
84+
/// may be used to customize the Python type.
85+
///
86+
/// The callback is invoked immediately before `PyType_Ready`.
87+
///
88+
/// Note: This is an advanced interface, and uses of it may require changes to
89+
/// work with later versions of pybind11. You may wish to consult the
90+
/// implementation of `make_new_python_type` in `detail/classes.h` to understand
91+
/// the context in which the callback will be run.
92+
struct custom_type_setup {
93+
using callback = detail::function_view<void(PyHeapTypeObject* heap_type)>;
94+
95+
explicit custom_type_setup(callback value) : value(value) {}
96+
97+
callback value;
98+
};
99+
82100
/// Annotation that marks a class as local to the module:
83101
struct module_local { const bool value;
84102
constexpr explicit module_local(bool v = true) : value(v) {}
@@ -272,6 +290,9 @@ struct type_record {
272290
/// Custom metaclass (optional)
273291
handle metaclass;
274292

293+
/// Custom type setup.
294+
::pybind11::custom_type_setup::callback custom_type_setup;
295+
275296
/// Multiple inheritance marker
276297
bool multiple_inheritance : 1;
277298

@@ -476,6 +497,11 @@ struct process_attribute<dynamic_attr> : process_attribute_default<dynamic_attr>
476497
static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; }
477498
};
478499

500+
template <>
501+
struct process_attribute<custom_type_setup> {
502+
static void init(const custom_type_setup &value, type_record *r) { r->custom_type_setup = value.value; }
503+
};
504+
479505
template <>
480506
struct process_attribute<is_final> : process_attribute_default<is_final> {
481507
static void init(const is_final &, type_record *r) { r->is_final = true; }

include/pybind11/detail/class.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,11 +683,13 @@ inline PyObject* make_new_python_type(const type_record &rec) {
683683
if (rec.buffer_protocol)
684684
enable_buffer_protocol(heap_type);
685685

686+
if (rec.custom_type_setup)
687+
rec.custom_type_setup(heap_type);
688+
686689
if (PyType_Ready(type) < 0)
687690
pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!");
688691

689-
assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)
690-
: !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
692+
assert(!rec.dynamic_attr || PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
691693

692694
/* Register type with the parent scope */
693695
if (rec.scope)

include/pybind11/detail/common.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,13 @@ template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
570570
template <typename T> using remove_reference_t = typename std::remove_reference<T>::type;
571571
#endif
572572

573+
#ifdef __cpp_lib_remove_cvref
574+
using std::remove_cvref_t;
575+
#else
576+
template <typename T>
577+
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
578+
#endif
579+
573580
/// Index sequences
574581
#if defined(PYBIND11_CPP14)
575582
using std::index_sequence;
@@ -771,6 +778,21 @@ using function_signature_t = conditional_t<
771778
template <typename T> using is_lambda = satisfies_none_of<remove_reference_t<T>,
772779
std::is_function, std::is_pointer, std::is_member_pointer>;
773780

781+
template <typename F, typename... Arg>
782+
using call_result_t = decltype(std::declval<F>()(std::declval<Arg>()...));
783+
784+
template <typename Expected>
785+
struct is_expected_return_type {
786+
template <typename Actual>
787+
using compatible_with = std::is_convertible<Actual, Expected>;
788+
};
789+
790+
template <>
791+
struct is_expected_return_type<void> {
792+
template <typename Actual>
793+
using compatible_with = std::true_type;
794+
};
795+
774796
// [workaround(intel)] Internal error on fold expression
775797
/// Apply a function over each element of a parameter pack
776798
#if defined(__cpp_fold_expressions) && !defined(__INTEL_COMPILER)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) 2021 The Pybind Development Team.
2+
// All rights reserved. Use of this source code is governed by a
3+
// BSD-style license that can be found in the LICENSE file.
4+
5+
#pragma once
6+
7+
#include "common.h"
8+
9+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
10+
PYBIND11_NAMESPACE_BEGIN(detail)
11+
12+
template <typename Signature>
13+
class function_view;
14+
15+
template <typename R, typename... Arg>
16+
class function_view<R(Arg...)> {
17+
public:
18+
function_view() = default;
19+
20+
/// Creates a view that references an existing function.
21+
///
22+
/// The constructed view is only valid for the lifetime of `f`.
23+
///
24+
/// Requires that `F` is a function object type that is callable with
25+
/// arguments `(Args...)` and has a return type convertible to `R`.
26+
template <typename F,
27+
typename = enable_if_t<
28+
// Prevent construction from the same `function_view` type, to avoid inefficient
29+
// double wrapping that also commonly leads to dangling references.
30+
(!std::is_same<remove_cvref_t<F>, function_view>::value
31+
// Ensure `F` is compatible with the signature.
32+
&& is_expected_return_type<R>::template compatible_with<
33+
call_result_t<F, Arg...>>::value)>>
34+
function_view(F &&f) noexcept // NOLINT(google-explicit-constructor)
35+
: erased_fn_(&wrapper<remove_reference_t<F>>), erased_obj_(&f) {}
36+
37+
/// Calls the referenced function.
38+
///
39+
/// Precondition: this is a non-null function view.
40+
R operator()(Arg... arg) const {
41+
return erased_fn_(const_cast<void *>(erased_obj_), static_cast<Arg &&>(arg)...);
42+
}
43+
44+
/// Returns `true` if this is a non-null function view.
45+
explicit operator bool() const { return erased_fn_ != nullptr; }
46+
47+
private:
48+
template <typename F>
49+
static R wrapper(void *obj, Arg... arg) {
50+
return (*static_cast<typename std::add_pointer<F>::type>(obj))(
51+
static_cast<Arg &&>(arg)...);
52+
}
53+
54+
R (*erased_fn_)(void *, Arg...) = nullptr;
55+
const void *erased_obj_;
56+
};
57+
58+
PYBIND11_NAMESPACE_END(detail)
59+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ set(PYBIND11_TEST_FILES
104104
test_constants_and_functions.cpp
105105
test_copy_move.cpp
106106
test_custom_type_casters.cpp
107+
test_custom_type_setup.cpp
107108
test_docstring_options.cpp
108109
test_eigen.cpp
109110
test_enum.cpp

tests/extra_python_package/test_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"include/pybind11/detail/class.h",
4141
"include/pybind11/detail/common.h",
4242
"include/pybind11/detail/descr.h",
43+
"include/pybind11/detail/function_view.h",
4344
"include/pybind11/detail/init.h",
4445
"include/pybind11/detail/internals.h",
4546
"include/pybind11/detail/type_caster_base.h",

tests/test_custom_type_setup.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
tests/test_custom_type_setup.cpp -- Tests `pybind11::custom_type_setup`
3+
4+
Copyright (c) Google LLC
5+
6+
All rights reserved. Use of this source code is governed by a
7+
BSD-style license that can be found in the LICENSE file.
8+
*/
9+
10+
#include <pybind11/pybind11.h>
11+
12+
#include "pybind11_tests.h"
13+
14+
namespace py = pybind11;
15+
16+
namespace {
17+
18+
struct OwnsPythonObjects {
19+
py::object value = py::none();
20+
};
21+
} // namespace
22+
23+
TEST_SUBMODULE(custom_type_setup, m) {
24+
py::class_<OwnsPythonObjects> cls(
25+
m, "OwnsPythonObjects", py::custom_type_setup([](PyHeapTypeObject *heap_type) {
26+
auto *type = &heap_type->ht_type;
27+
type->tp_flags |= Py_TPFLAGS_HAVE_GC;
28+
type->tp_traverse = [](PyObject *self_base, visitproc visit, void *arg) {
29+
auto &self = py::cast<OwnsPythonObjects &>(py::handle(self_base));
30+
Py_VISIT(self.value.ptr());
31+
return 0;
32+
};
33+
type->tp_clear = [](PyObject *self_base) {
34+
auto &self = py::cast<OwnsPythonObjects &>(py::handle(self_base));
35+
self.value = py::none();
36+
return 0;
37+
};
38+
}));
39+
cls.def(py::init([] { return OwnsPythonObjects{}; }));
40+
cls.def_readwrite("value", &OwnsPythonObjects::value);
41+
}

tests/test_custom_type_setup.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import gc
4+
import weakref
5+
6+
import pytest
7+
8+
import env # noqa: F401
9+
from pybind11_tests import custom_type_setup as m
10+
11+
12+
@pytest.fixture
13+
def gc_tester():
14+
"""Tests that an object is garbage collected."""
15+
16+
weak_refs = []
17+
18+
def add_ref(obj):
19+
# PyPy does not support `gc.is_tracked`.
20+
if hasattr(gc, "is_tracked"):
21+
assert gc.is_tracked(obj)
22+
weak_refs.append(weakref.ref(obj))
23+
24+
yield add_ref
25+
26+
gc.collect()
27+
if hasattr(gc, "collect_step"):
28+
# PyPy may require additional encouragement to actually collect the
29+
# garbage.
30+
for _ in range(100):
31+
gc.collect_step()
32+
for ref in weak_refs:
33+
assert ref() is None
34+
35+
36+
# PyPy does not seem to garbage collect.
37+
@pytest.mark.xfail("env.PYPY")
38+
def test_self_cycle(gc_tester):
39+
obj = m.OwnsPythonObjects()
40+
obj.value = obj
41+
gc_tester(obj)
42+
43+
44+
# PyPy does not seem to garbage collect.
45+
@pytest.mark.xfail("env.PYPY")
46+
def test_indirect_cycle(gc_tester):
47+
obj = m.OwnsPythonObjects()
48+
obj_list = [obj]
49+
obj.value = obj_list
50+
gc_tester(obj)

0 commit comments

Comments
 (0)