diff --git a/pep-0646.rst b/pep-0646.rst
index 965c6062d88..51e4139991f 100644
--- a/pep-0646.rst
+++ b/pep-0646.rst
@@ -610,6 +610,16 @@ manipulation mechanisms. We plan to introduce these in a future PEP.)
 
 Rationale and Rejected Ideas
 ============================
 
+Shape Arithmetic
+----------------
+
+Considering the use case of array shapes in particular, note that as of
+this PEP, it is not yet possible to describe arithmetic transformations
+of array dimensions - for example,
+``def repeat_each_element(x: Array[N]) -> Array[2*N]``. We consider
+this out of scope for the current PEP, but plan to propose additional
+mechanisms that *will* enable this in a future PEP.
+
 Supporting Variadicity Through Aliases
 --------------------------------------
@@ -743,11 +753,205 @@ is available in `cpython/23527`_. A preliminary version of the version
 using the star operator, based on an early implementation of PEP 637,
 is also available at `mrahtz/cpython/pep637+646`_.
 
+Appendix A: Shape Typing Use Cases
+==================================
+
+To give this PEP additional context for those particularly interested in
+the array typing use case, in this appendix we expand on the different
+ways this PEP can be used for specifying shape-based subtypes.
+
+Use Case 1: Specifying Shape Values
+-----------------------------------
+
+The simplest way to parameterise array types is using ``Literal``
+type parameters - e.g. ``Array[Literal[64], Literal[64]]``.
+
+We can attach names to each parameter using normal type variables:
+
+::
+
+    K = TypeVar('K')
+    N = TypeVar('N')
+
+    def matrix_vector_multiply(x: Array[K, N], y: Array[N]) -> Array[K]: ...
+
+    a: Array[Literal[64], Literal[32]]
+    b: Array[Literal[32]]
+    matrix_vector_multiply(a, b)
+    # Result is Array[Literal[64]]
+
+Note that such names have a purely local scope. That is, the name
+``K`` is bound to ``Literal[64]`` only within ``matrix_vector_multiply``.
+To put it another way, there's no relationship between the value of ``K``
+in different signatures. This is important: it would be inconvenient if
+every axis named ``K`` were constrained to have the same value throughout
+the entire program.
+
+The disadvantage of this approach is that we have no ability to enforce
+shape semantics across different calls. For example, we can't address the
+problem mentioned in `Motivation`_: if one function returns an array with
+leading dimensions 'Time × Batch', and another function takes the same
+array assuming leading dimensions 'Batch × Time', we have no way of
+detecting this.
+
+The main advantage is that in some cases, axis sizes really are what we
+care about. This is true not only for simple linear algebra operations
+such as the matrix manipulations above, but also for more complicated
+transformations such as convolutional layers in neural networks, where it
+would be of great utility to the programmer to be able to inspect the
+array size after each layer using static analysis. To aid this, in the
+future we would like to explore possibilities for additional type
+operators that enable arithmetic on array shapes - for example:
+
+::
+
+    def repeat_each_element(x: Array[N]) -> Array[Mul[2, N]]: ...
+
+Such arithmetic type operators would only make sense if names such as
+``N`` refer to axis size.
+
+Use Case 2: Specifying Shape Semantics
+--------------------------------------
+
+A second approach (the one that most of the examples in this PEP are based
+around) is to forgo annotation with actual axis size, and instead annotate
+axis *type*.
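+
+As a minimal sketch of what this looks like (``Height`` and ``Width`` here
+are hypothetical label classes of our own, used with the ``Array`` class
+from the examples above):
+
+::
+
+    class Height: pass
+    class Width: pass
+
+    # The annotation records which axis is which, but says nothing
+    # about the sizes of those axes.
+    grayscale_image: Array[Height, Width]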
+
+This would enable us to solve the problem of enforcing shape properties
+across calls. For example:
+
+::
+
+    # lib.py
+
+    class Batch: pass
+    class Time: pass
+
+    def make_array() -> Array[Batch, Time]: ...
+
+    # user.py
+
+    from lib import Batch, Time
+
+    # `Batch` and `Time` have the same identity as in `lib`,
+    # so this accepts exactly the arrays produced by `lib.make_array`
+    def use_array(x: Array[Batch, Time]): ...
+
+Note that in this case, names are *global* (to the extent that we use the
+same ``Batch`` type in different places). However, because names refer only
+to axis *types*, this doesn't constrain the *value* of those axes to be the
+same throughout (that is, this doesn't constrain all axes named ``Height``
+to have a value of, say, 480 throughout).
+
+The argument *for* this approach is that in many cases, axis *type* is the
+more important thing to verify; we care more about which axis is which than
+what the specific size of each axis is.
+
+It also does not preclude cases where we wish to describe shape
+transformations without knowing the type ahead of time. For example, we
+can still write:
+
+::
+
+    K = TypeVar('K')
+    N = TypeVar('N')
+
+    def matrix_vector_multiply(x: Array[K, N], y: Array[N]) -> Array[K]: ...
+
+We can then use this with:
+
+::
+
+    class Batch: pass
+    class Values: pass
+
+    batch_of_values: Array[Batch, Values]
+    value_weights: Array[Values]
+    matrix_vector_multiply(batch_of_values, value_weights)
+    # Result is Array[Batch]
+
+The disadvantages are the inverse of the advantages from use case 1.
+In particular, this approach does not lend itself well to arithmetic
+on axis types: ``Mul[2, Batch]`` would be as meaningless as ``2 * int``.
+
+Discussion
+----------
+
+Note that use cases 1 and 2 are mutually exclusive in user code. Users
+can verify size or semantic type but not both.
+
+As of this PEP, we are agnostic about which approach will provide the most
+benefit. Since the features introduced in this PEP are compatible with
+both approaches, however, we leave the door open.
+
+Why Not Both?
+-------------
+
+Consider the following 'normal' code:
+
+::
+
+    def f(x: int): ...
+
+Note that we have symbols for both the value of the thing (``x``) and the
+type of the thing (``int``). Why can't we do the same with axes? For
+example, with an imaginary syntax, we could write:
+
+::
+
+    def f(array: Array[TimeValue: TimeType]): ...
+
+This would allow us to access the axis size (say, 32) through the symbol
+``TimeValue`` *and* the type through the symbol ``TimeType``.
+
+This might even be possible using existing syntax, through a second level
+of parameterisation:
+
+::
+
+    def f(array: Array[TimeValue[TimeType]]): ...
+
+However, we leave exploration of this approach to the future.
+
+Appendix B: Shaped Types vs Named Axes
+======================================
+
+An issue related to those addressed by this PEP concerns axis *selection*.
+For example, if we have an image stored in an array of shape 64×64×3, we
+might wish to convert to black-and-white by computing the mean over the
+third axis, ``mean(image, axis=2)``. Unfortunately, the simple typo
+``axis=1`` is difficult to spot and will produce a result that means
+something completely different (all while likely allowing the program to
+keep on running, resulting in a bug that is serious but silent).
+
+In response, some libraries have implemented so-called 'named tensors' (in
+this context, 'tensor' is synonymous with 'array'), in which axes are
+selected not by index but by label - e.g. ``mean(image, axis='channels')``.
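+
+As a concrete illustration of this style of API (a sketch using xarray,
+one library with labelled dimensions; the dimension names here are our
+own):
+
+::
+
+    import numpy as np
+    import xarray as xr
+
+    image = xr.DataArray(np.zeros((64, 64, 3)),
+                         dims=('height', 'width', 'channels'))
+    # Select the axis to reduce over by name rather than by position.
+    grayscale = image.mean(dim='channels')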
+
+A question we are often asked about this PEP is: why not just use named
+tensors? The answer is that we consider the named tensors approach
+insufficient, for two main reasons:
+
+* **Static checking** of shape correctness is not possible. As mentioned
+  in `Motivation`_, this is a highly desirable feature in machine learning
+  code where iteration times are slow by default.
+* **Interface documentation** is still not possible with this approach.
+  If a function should *only* be willing to take array arguments that have
+  image-like shapes, this cannot be stipulated with named tensors.
+
+Additionally, there's the issue of **poor uptake**. At the time of
+writing, named tensors have only been implemented in a small number of
+numerical computing libraries. Possible explanations for this include
+difficulty of implementation (the whole API must be modified to allow
+selection by axis name instead of index), and lack of usefulness due to
+the fact that axis ordering conventions are often strong enough that axis
+names provide little benefit (e.g. when working with images, 3D tensors
+are basically *always* height × width × channels). However, ultimately we
+are still uncertain why this is the case.
+
+Can the named tensors approach be combined with the approach we advocate
+for in this PEP? We're not sure. One area of overlap is that in some
+contexts, we could do, say:
+
+::
+
+    Image = Array[Height, Width, Channels]
+    im: Image
+    mean(im, axis=Image.axes.index(Channels))
+
+Ideally, we might write something like
+``im: Array[Height=64, Width=64, Channels=3]`` - but this won't be
+possible in the short term, due to the rejection of PEP 637. In any case,
+our attitude towards this is mostly "Wait and see what happens before
+taking any further steps".
 
 Footnotes
 ==========
-
 .. [#batch] 'Batch' is machine learning parlance for 'a number of'.
 
 .. [#array] We use the term 'array' to refer to a matrix with an arbitrary