Skip to content

Commit 22c4fe0

Browse files
committed
Allow keyword-only arguments to follow py::args
This removes the constraint that py::args has to be last (or second-last, with py::kwargs) and instead makes py::args imply py::kw_only for any remaining arguments, allowing you to bind a function that works the same way as a Python function such as: def f(a, *args, b): return a * b + sum(args) f(10, 1, 2, 3, b=20) # == 206 With this change, you can bind such a function using: m.def("f", [](int a, py::args args, int b) { /* ... */ }, "a"_a, "b"_a); Or, to be more explicit about the keyword-only arguments: m.def("g", [](int a, py::args args, int b) { /* ... */ }, "a"_a, py::kw_only{}, "b"_a); (The only difference between the two is that the latter will fail at binding time if the `kw_only{}` doesn't match the `py::args` position). This doesn't affect backwards compatibility at all because, currently, you can't have a py::args anywhere except the end/2nd-last.
1 parent 44f1011 commit 22c4fe0

File tree

7 files changed

+178
-40
lines changed

7 files changed

+178
-40
lines changed

docs/advanced/functions.rst

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,9 @@ The class ``py::args`` derives from ``py::tuple`` and ``py::kwargs`` derives
306306
from ``py::dict``.
307307

308308
You may also use just one or the other, and may combine these with other
309-
arguments as long as the ``py::args`` and ``py::kwargs`` arguments are the last
310-
arguments accepted by the function.
309+
arguments. Note, however, that ``py::kwargs`` must always be the last argument
310+
of the function, and ``py::args`` implies that any further arguments are
311+
keyword-only (see :ref:`keyword_only_arguments`).
311312

312313
Please refer to the other examples for details on how to iterate over these,
313314
and on how to cast their entries into C++ objects. A demonstration is also
@@ -366,6 +367,8 @@ like so:
366367
py::class_<MyClass>("MyClass")
367368
.def("myFunction", py::arg("arg") = static_cast<SomeType *>(nullptr));
368369
370+
.. _keyword_only_arguments:
371+
369372
Keyword-only arguments
370373
======================
371374

@@ -397,6 +400,15 @@ feature does *not* require Python 3 to work.
397400

398401
.. versionadded:: 2.6
399402

403+
As of pybind11 2.9, a ``py::args`` argument implies that any following arguments
404+
are keyword-only, as if ``py::kw_only()`` had been specified in the same
405+
relative location of the argument list as the ``py::args`` argument. The
406+
``py::kw_only()`` may be included to be explicit about this, but is not
407+
required. (Prior to 2.9 ``py::args`` may only occur at the end of the argument
408+
list, or immediately before a ``py::kwargs`` argument at the end).
409+
410+
.. versionadded:: 2.9
411+
400412
Positional-only arguments
401413
=========================
402414

include/pybind11/attr.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ template <> struct process_attribute<is_new_style_constructor> : process_attribu
411411

412412
inline void check_kw_only_arg(const arg &a, function_record *r) {
413413
if (r->args.size() > r->nargs_pos && (!a.name || a.name[0] == '\0'))
414-
pybind11_fail("arg(): cannot specify an unnamed argument after an kw_only() annotation");
414+
pybind11_fail("arg(): cannot specify an unnamed argument after a kw_only() annotation or args() argument");
415415
}
416416

417417
/// Process a keyword argument attribute (*without* a default value)
@@ -461,6 +461,8 @@ template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
461461
/// Process a keyword-only-arguments-follow pseudo argument
462462
template <> struct process_attribute<kw_only> : process_attribute_default<kw_only> {
463463
static void init(const kw_only &, function_record *r) {
464+
if (r->has_args && r->nargs_pos != static_cast<std::uint16_t>(r->args.size()))
465+
pybind11_fail("Mismatched args() and kw_only(): they must occur at the same relative argument location (or omit kw_only() entirely)");
464466
r->nargs_pos = static_cast<std::uint16_t>(r->args.size());
465467
}
466468
};
@@ -469,6 +471,9 @@ template <> struct process_attribute<kw_only> : process_attribute_default<kw_onl
469471
template <> struct process_attribute<pos_only> : process_attribute_default<pos_only> {
470472
static void init(const pos_only &, function_record *r) {
471473
r->nargs_pos_only = static_cast<std::uint16_t>(r->args.size());
474+
if (r->nargs_pos_only > r->nargs_pos)
475+
pybind11_fail("pos_only(): cannot follow a py::args() argument");
476+
// It also can't follow a kw_only, but a static_assert in pybind11.h checks that
472477
}
473478
};
474479

include/pybind11/cast.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,19 +1163,17 @@ template <typename... Args>
11631163
class argument_loader {
11641164
using indices = make_index_sequence<sizeof...(Args)>;
11651165

1166-
template <typename Arg> using argument_is_args = std::is_same<intrinsic_t<Arg>, args>;
1167-
template <typename Arg> using argument_is_kwargs = std::is_same<intrinsic_t<Arg>, kwargs>;
1168-
// Get args/kwargs argument positions relative to the end of the argument list:
1169-
static constexpr auto args_pos = constexpr_first<argument_is_args, Args...>() - (int) sizeof...(Args),
1170-
kwargs_pos = constexpr_first<argument_is_kwargs, Args...>() - (int) sizeof...(Args);
1166+
// Get kwargs argument position, or -1 if not present:
1167+
static constexpr auto kwargs_pos = constexpr_last(std::is_same<intrinsic_t<Args>, kwargs>::value...);
11711168

1172-
static constexpr bool args_kwargs_are_last = kwargs_pos >= - 1 && args_pos >= kwargs_pos - 1;
1173-
1174-
static_assert(args_kwargs_are_last, "py::args/py::kwargs are only permitted as the last argument(s) of a function");
1169+
static_assert(kwargs_pos == -1 || kwargs_pos == (int) sizeof...(Args) - 1, "py::kwargs is only permitted as the last argument of a function");
1170+
static_assert(constexpr_sum(std::is_same<intrinsic_t<Args>, args>::value...) <= 1, "py::args cannot be specified more than once");
11751171

11761172
public:
1177-
static constexpr bool has_kwargs = kwargs_pos < 0;
1178-
static constexpr bool has_args = args_pos < 0;
1173+
static constexpr bool has_kwargs = kwargs_pos != -1;
1174+
1175+
// py::args argument position; -1 if not present.
1176+
static constexpr int args_pos = constexpr_last(std::is_same<intrinsic_t<Args>, args>::value...);
11791177

11801178
static constexpr auto arg_names = concat(type_descr(make_caster<Args>::name)...);
11811179

@@ -1381,8 +1379,8 @@ template <return_value_policy policy, typename... Args,
13811379
unpacking_collector<policy> collect_arguments(Args &&...args) {
13821380
// Following argument order rules for generalized unpacking according to PEP 448
13831381
static_assert(
1384-
constexpr_last<is_positional, Args...>() < constexpr_first<is_keyword_or_ds, Args...>()
1385-
&& constexpr_last<is_s_unpacking, Args...>() < constexpr_first<is_ds_unpacking, Args...>(),
1382+
constexpr_last(is_positional<Args>::value...) < constexpr_first(is_keyword_or_ds<Args>::value...)
1383+
&& constexpr_last(is_s_unpacking<Args>::value...) < constexpr_first(is_ds_unpacking<Args>::value...),
13861384
"Invalid function call: positional args must precede keywords and ** unpacking; "
13871385
"* unpacking must precede ** unpacking"
13881386
);

include/pybind11/detail/common.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -684,14 +684,13 @@ template <typename T, typename... Ts>
684684
constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); }
685685
PYBIND11_NAMESPACE_END(constexpr_impl)
686686

687-
/// Return the index of the first type in Ts which satisfies Predicate<T>. Returns sizeof...(Ts) if
688-
/// none match.
689-
template <template<typename> class Predicate, typename... Ts>
690-
constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate<Ts>::value...); }
687+
/// Returns the index of the first true argument. Returns sizeof...(Args) if none match.
688+
template <typename... Args>
689+
constexpr int constexpr_first(Args... args) { return constexpr_impl::first(0, args...); }
691690

692-
/// Return the index of the last type in Ts which satisfies Predicate<T>, or -1 if none match.
693-
template <template<typename> class Predicate, typename... Ts>
694-
constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate<Ts>::value...); }
691+
/// Returns the index of the last true argument. Returns -1 if none match.
692+
template <typename... Args>
693+
constexpr int constexpr_last(Args... args) { return constexpr_impl::last(0, -1, args...); }
695694

696695
/// Return the Nth element from the parameter pack
697696
template <size_t N, typename T, typename... Ts>
@@ -706,7 +705,7 @@ struct exactly_one {
706705
static constexpr auto found = constexpr_sum(Predicate<Ts>::value...);
707706
static_assert(found <= 1, "Found more than one type matching the predicate");
708707

709-
static constexpr auto index = found ? constexpr_first<Predicate, Ts...>() : 0;
708+
static constexpr auto index = found ? constexpr_first(Predicate<Ts>::value...) : 0;
710709
using type = conditional_t<found, typename pack_element<index, Ts...>::type, Default>;
711710
};
712711
template <template<typename> class P, typename Default>

include/pybind11/pybind11.h

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class cpp_function : public function {
203203
conditional_t<std::is_void<Return>::value, void_type, Return>
204204
>;
205205

206-
static_assert(expected_num_args<Extra...>(sizeof...(Args), cast_in::has_args, cast_in::has_kwargs),
206+
static_assert(expected_num_args<Extra...>(sizeof...(Args), cast_in::args_pos >= 0, cast_in::has_kwargs),
207207
"The number of argument annotations does not match the number of function arguments");
208208

209209
/* Dispatch code which converts function arguments and performs the actual function call */
@@ -238,19 +238,27 @@ class cpp_function : public function {
238238
return result;
239239
};
240240

241-
rec->nargs_pos = sizeof...(Args) - cast_in::has_args - cast_in::has_kwargs; // Will get reduced more if we have a kw_only
241+
rec->nargs_pos = cast_in::args_pos >= 0
242+
? static_cast<std::uint16_t>(cast_in::args_pos)
243+
: sizeof...(Args) - cast_in::has_kwargs; // Will get reduced more if we have a kw_only
244+
if (cast_in::args_pos >= 0) rec->has_args = true;
245+
if (cast_in::has_kwargs) rec->has_kwargs = true;
242246

243247
/* Process any user-provided function attributes */
244248
process_attributes<Extra...>::init(extra..., rec);
245249

246250
{
247251
constexpr bool has_kw_only_args = any_of<std::is_same<kw_only, Extra>...>::value,
248252
has_pos_only_args = any_of<std::is_same<pos_only, Extra>...>::value,
249-
has_args = any_of<std::is_same<args, Args>...>::value,
250253
has_arg_annotations = any_of<is_keyword<Extra>...>::value;
251254
static_assert(has_arg_annotations || !has_kw_only_args, "py::kw_only requires the use of argument annotations");
252255
static_assert(has_arg_annotations || !has_pos_only_args, "py::pos_only requires the use of argument annotations (for docstrings and aligning the annotations to the argument)");
253-
static_assert(!(has_args && has_kw_only_args), "py::kw_only cannot be combined with a py::args argument");
256+
257+
static_assert(constexpr_sum(std::is_same<kw_only, Extra>::value...) <= 1, "py::kw_only may be specified only once");
258+
static_assert(constexpr_sum(std::is_same<pos_only, Extra>::value...) <= 1, "py::pos_only may be specified only once");
259+
constexpr auto kw_only_pos = constexpr_first(std::is_same<kw_only, Extra>::value...);
260+
constexpr auto pos_only_pos = constexpr_first(std::is_same<pos_only, Extra>::value...);
261+
static_assert(!(has_kw_only_args && has_pos_only_args) || pos_only_pos < kw_only_pos, "py::pos_only must come before py::kw_only");
254262
}
255263

256264
/* Generate a readable signature describing the function's arguments and return value types */
@@ -261,9 +269,6 @@ class cpp_function : public function {
261269
// Pass on the ownership over the `unique_rec` to `initialize_generic`. `rec` stays valid.
262270
initialize_generic(std::move(unique_rec), signature.text, types.data(), sizeof...(Args));
263271

264-
if (cast_in::has_args) rec->has_args = true;
265-
if (cast_in::has_kwargs) rec->has_kwargs = true;
266-
267272
/* Stash some additional information used by an important optimization in 'functional.h' */
268273
using FunctionType = Return (*)(Args...);
269274
constexpr bool is_function_ptr =
@@ -342,15 +347,17 @@ class cpp_function : public function {
342347
/* Generate a proper function signature */
343348
std::string signature;
344349
size_t type_index = 0, arg_index = 0;
350+
bool is_starred = false;
345351
for (auto *pc = text; *pc != '\0'; ++pc) {
346352
const auto c = *pc;
347353

348354
if (c == '{') {
349355
// Write arg name for everything except *args and **kwargs.
350-
if (*(pc + 1) == '*')
356+
is_starred = *(pc + 1) == '*';
357+
if (is_starred)
351358
continue;
352359
// Separator for keyword-only arguments, placed before the kw
353-
// arguments start
360+
// arguments start (unless we are already putting an *args)
354361
if (!rec->has_args && arg_index == rec->nargs_pos)
355362
signature += "*, ";
356363
if (arg_index < rec->args.size() && rec->args[arg_index].name) {
@@ -363,15 +370,16 @@ class cpp_function : public function {
363370
signature += ": ";
364371
} else if (c == '}') {
365372
// Write default value if available.
366-
if (arg_index < rec->args.size() && rec->args[arg_index].descr) {
373+
if (!is_starred && arg_index < rec->args.size() && rec->args[arg_index].descr) {
367374
signature += " = ";
368375
signature += rec->args[arg_index].descr;
369376
}
370377
// Separator for positional-only arguments (placed after the
371378
// argument, rather than before like *
372379
if (rec->nargs_pos_only > 0 && (arg_index + 1) == rec->nargs_pos_only)
373380
signature += ", /";
374-
arg_index++;
381+
if (!is_starred)
382+
arg_index++;
375383
} else if (c == '%') {
376384
const std::type_info *t = types[type_index++];
377385
if (!t)
@@ -397,7 +405,7 @@ class cpp_function : public function {
397405
}
398406
}
399407

400-
if (arg_index != args || types[type_index] != nullptr)
408+
if (arg_index != args - rec->has_args - rec->has_kwargs || types[type_index] != nullptr)
401409
pybind11_fail("Internal error while parsing type signature (2)");
402410

403411
#if PY_MAJOR_VERSION < 3
@@ -697,6 +705,10 @@ class cpp_function : public function {
697705
if (bad_arg)
698706
continue; // Maybe it was meant for another overload (issue #688)
699707

708+
// Keep track of how many position args we copied out in case we need to come back
709+
// to copy the rest into a py::args argument.
710+
size_t positional_args_copied = args_copied;
711+
700712
// We'll need to copy this if we steal some kwargs for defaults
701713
dict kwargs = reinterpret_borrow<dict>(kwargs_in);
702714

@@ -749,6 +761,10 @@ class cpp_function : public function {
749761
}
750762

751763
if (value) {
764+
// If we're at the py::args index then first insert a stub for it to be replaced later
765+
if (func.has_args && call.args.size() == func.nargs_pos)
766+
call.args.push_back(none());
767+
752768
call.args.push_back(value);
753769
call.args_convert.push_back(arg_rec.convert);
754770
}
@@ -771,16 +787,19 @@ class cpp_function : public function {
771787
// We didn't copy out any position arguments from the args_in tuple, so we
772788
// can reuse it directly without copying:
773789
extra_args = reinterpret_borrow<tuple>(args_in);
774-
} else if (args_copied >= n_args_in) {
790+
} else if (positional_args_copied >= n_args_in) {
775791
extra_args = tuple(0);
776792
} else {
777-
size_t args_size = n_args_in - args_copied;
793+
size_t args_size = n_args_in - positional_args_copied;
778794
extra_args = tuple(args_size);
779795
for (size_t i = 0; i < args_size; ++i) {
780-
extra_args[i] = PyTuple_GET_ITEM(args_in, args_copied + i);
796+
extra_args[i] = PyTuple_GET_ITEM(args_in, positional_args_copied + i);
781797
}
782798
}
783-
call.args.push_back(extra_args);
799+
if (call.args.size() <= func.nargs_pos)
800+
call.args.push_back(extra_args);
801+
else
802+
call.args[func.nargs_pos] = extra_args;
784803
call.args_convert.push_back(false);
785804
call.args_ref = std::move(extra_args);
786805
}

tests/test_kwargs_and_defaults.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,23 @@ TEST_SUBMODULE(kwargs_and_defaults, m) {
5656
m.def("mixed_plus_args_kwargs_defaults", mixed_plus_both,
5757
py::arg("i") = 1, py::arg("j") = 3.14159);
5858

59+
m.def("args_kwonly",
60+
[](int i, double j, py::args args, int z) { return py::make_tuple(i, j, args, z); },
61+
"i"_a, "j"_a, "z"_a);
62+
m.def("args_kwonly_kwargs",
63+
[](int i, double j, py::args args, int z, py::kwargs kwargs) {
64+
return py::make_tuple(i, j, args, z, kwargs); },
65+
"i"_a, "j"_a, py::kw_only{}, "z"_a);
66+
m.def("args_kwonly_kwargs_defaults",
67+
[](int i, double j, py::args args, int z, py::kwargs kwargs) {
68+
return py::make_tuple(i, j, args, z, kwargs); },
69+
"i"_a = 1, "j"_a = 3.14159, "z"_a = 42);
70+
m.def("args_kwonly_full_monty",
71+
[](int h, int i, double j, py::args args, int z, py::kwargs kwargs) {
72+
return py::make_tuple(h, i, j, args, z, kwargs); },
73+
py::arg() = 1, py::arg() = 2, py::pos_only{}, "j"_a = 3.14159, "z"_a = 42);
74+
75+
5976
// test_args_refcount
6077
// PyPy needs a garbage collection to get the reference count values to match CPython's behaviour
6178
#ifdef PYPY_VERSION

0 commit comments

Comments
 (0)