Skip to content

Commit 9fd28d9

Browse files
anandoleecopybara-github
authored andcommitted
Check with fallback descriptorDB for FindExtensionByNumber()/FindAllExtensions in UPB python pool.
PiperOrigin-RevId: 740001906
1 parent c5b35fa commit 9fd28d9

File tree

4 files changed

+183
-10
lines changed

4 files changed

+183
-10
lines changed

python/descriptor_pool.c

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,61 @@ static bool PyUpb_DescriptorPool_TryLoadFilename(PyUpb_DescriptorPool* self,
146146
return ret;
147147
}
148148

149+
static bool PyUpb_DescriptorPool_TryLoadExtension(PyUpb_DescriptorPool* self,
150+
const upb_MessageDef* m,
151+
int32_t field_number) {
152+
if (!self->db) return false;
153+
const char* full_name = upb_MessageDef_FullName(m);
154+
PyObject* py_name = PyUnicode_FromStringAndSize(full_name, strlen(full_name));
155+
PyObject* py_descriptor = PyObject_CallMethod(
156+
self->db, "FindFileContainingExtension", "Oi", py_name, field_number);
157+
Py_DECREF(py_name);
158+
if (!py_descriptor) {
159+
PyErr_Clear();
160+
return false;
161+
}
162+
bool ret = PyUpb_DescriptorPool_TryLoadFileProto(self, py_descriptor);
163+
Py_DECREF(py_descriptor);
164+
return ret;
165+
}
166+
167+
static void PyUpb_DescriptorPool_TryLoadAllExtensions(
168+
PyUpb_DescriptorPool* self, const upb_MessageDef* m) {
169+
if (!self->db) return;
170+
const char* full_name = upb_MessageDef_FullName(m);
171+
PyObject* py_name = PyUnicode_FromStringAndSize(full_name, strlen(full_name));
172+
PyObject* py_list =
173+
PyObject_CallMethod(self->db, "FindAllExtensionNumbers", "O", py_name);
174+
Py_DECREF(py_name);
175+
if (!py_list) {
176+
PyErr_Clear();
177+
return;
178+
}
179+
Py_ssize_t size = PyList_Size(py_list);
180+
if (size == -1) {
181+
PyErr_Format(
182+
PyExc_RuntimeError,
183+
"FindAllExtensionNumbers() on fall back DB must return a list, not %S",
184+
py_list);
185+
PyErr_Print();
186+
Py_DECREF(py_list);
187+
return;
188+
}
189+
int64_t field_number;
190+
const upb_ExtensionRegistry* reg =
191+
upb_DefPool_ExtensionRegistry(self->symtab);
192+
const upb_MiniTable* t = upb_MessageDef_MiniTable(m);
193+
for (Py_ssize_t i = 0; i < size; ++i) {
194+
PyObject* item = PySequence_GetItem(py_list, i);
195+
field_number = PyLong_AsLong(item);
196+
Py_DECREF(item);
197+
if (!upb_ExtensionRegistry_Lookup(reg, t, field_number)) {
198+
PyUpb_DescriptorPool_TryLoadExtension(self, m, field_number);
199+
}
200+
}
201+
Py_DECREF(py_list);
202+
}
203+
149204
bool PyUpb_DescriptorPool_CheckNoDatabase(PyObject* _self) { return true; }
150205

151206
static bool PyUpb_DescriptorPool_LoadDependentFiles(
@@ -582,8 +637,15 @@ static PyObject* PyUpb_DescriptorPool_FindExtensionByNumber(PyObject* _self,
582637
return NULL;
583638
}
584639

585-
const upb_FieldDef* f = upb_DefPool_FindExtensionByNumber(
586-
self->symtab, PyUpb_Descriptor_GetDef(message_descriptor), number);
640+
const upb_MessageDef* message_def =
641+
PyUpb_Descriptor_GetDef(message_descriptor);
642+
const upb_FieldDef* f =
643+
upb_DefPool_FindExtensionByNumber(self->symtab, message_def, number);
644+
if (f == NULL && self->db) {
645+
if (PyUpb_DescriptorPool_TryLoadExtension(self, message_def, number)) {
646+
f = upb_DefPool_FindExtensionByNumber(self->symtab, message_def, number);
647+
}
648+
}
587649
if (f == NULL) {
588650
return PyErr_Format(PyExc_KeyError, "Couldn't find Extension %d", number);
589651
}
@@ -595,6 +657,9 @@ static PyObject* PyUpb_DescriptorPool_FindAllExtensions(PyObject* _self,
595657
PyObject* msg_desc) {
596658
PyUpb_DescriptorPool* self = (PyUpb_DescriptorPool*)_self;
597659
const upb_MessageDef* m = PyUpb_Descriptor_GetDef(msg_desc);
660+
if (self->db) {
661+
PyUpb_DescriptorPool_TryLoadAllExtensions(self, m);
662+
}
598663
size_t n;
599664
const upb_FieldDef** ext = upb_DefPool_GetAllExtensions(self->symtab, m, &n);
600665
PyObject* ret = PyList_New(n);

python/google/protobuf/descriptor_pool.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -580,11 +580,21 @@ def FindAllExtensions(self, message_descriptor):
580580
if self._descriptor_db and hasattr(
581581
self._descriptor_db, 'FindAllExtensionNumbers'):
582582
full_name = message_descriptor.full_name
583-
all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
584-
for number in all_numbers:
585-
if number in self._extensions_by_number[message_descriptor]:
586-
continue
587-
self._TryLoadExtensionFromDB(message_descriptor, number)
583+
try:
584+
all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
585+
except:
586+
pass
587+
else:
588+
if isinstance(all_numbers, list):
589+
for number in all_numbers:
590+
if number in self._extensions_by_number[message_descriptor]:
591+
continue
592+
self._TryLoadExtensionFromDB(message_descriptor, number)
593+
else:
594+
warnings.warn(
595+
'FindAllExtensionNumbers() on fall back DB must return a list,'
596+
' not {0}'.format(type(all_numbers))
597+
)
588598

589599
return list(self._extensions_by_number[message_descriptor].values())
590600

@@ -603,8 +613,13 @@ def _TryLoadExtensionFromDB(self, message_descriptor, number):
603613
return
604614

605615
full_name = message_descriptor.full_name
606-
file_proto = self._descriptor_db.FindFileContainingExtension(
607-
full_name, number)
616+
file_proto = None
617+
try:
618+
file_proto = self._descriptor_db.FindFileContainingExtension(
619+
full_name, number
620+
)
621+
except:
622+
return
608623

609624
if file_proto is None:
610625
return
@@ -668,7 +683,6 @@ def SetFeatureSetDefaults(self, defaults):
668683
if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
669684
raise TypeError('SetFeatureSetDefaults called with invalid type')
670685

671-
672686
if defaults.minimum_edition > defaults.maximum_edition:
673687
raise ValueError(
674688
'Invalid edition range %s to %s'

python/google/protobuf/internal/descriptor_pool_test.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,5 +1531,91 @@ def testAboveMaximum(self):
15311531
'google/protobuf/internal/more_messages.proto'])
15321532

15331533

1534+
class LocalFakeDB(descriptor_database.DescriptorDatabase):
1535+
1536+
def FindFileContainingExtension(self, extendee_name, extension_number):
1537+
return descriptor_pb2.FileDescriptorProto.FromString(
1538+
factory_test2_pb2.DESCRIPTOR.serialized_pb
1539+
)
1540+
1541+
def FindAllExtensionNumbers(self, extendee_name):
1542+
return [1001, 1002]
1543+
1544+
1545+
class BadDB(descriptor_database.DescriptorDatabase):
1546+
1547+
def FindFileContainingExtension(self, extendee_name, extension_number):
1548+
raise RuntimeError('just ignore it')
1549+
1550+
def FindAllExtensionNumbers(self, extendee_name):
1551+
raise RuntimeError('just ignore it')
1552+
1553+
1554+
class BadDB2(descriptor_database.DescriptorDatabase):
1555+
1556+
# Returns a none list value.
1557+
def FindAllExtensionNumbers(self, extendee_name):
1558+
return 1.2
1559+
1560+
1561+
@testing_refleaks.TestCase
1562+
class FallBackDBTest(unittest.TestCase):
1563+
1564+
def setUp(self):
1565+
self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString(
1566+
factory_test1_pb2.DESCRIPTOR.serialized_pb
1567+
)
1568+
factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString(
1569+
factory_test2_pb2.DESCRIPTOR.serialized_pb
1570+
)
1571+
db = LocalFakeDB()
1572+
db.Add(self.factory_test1_fd)
1573+
db.Add(factory_test2_fd)
1574+
self.pool = descriptor_pool.DescriptorPool(db)
1575+
file_desc = self.pool.FindFileByName(
1576+
'google/protobuf/internal/factory_test1.proto'
1577+
)
1578+
self.message_desc = file_desc.message_types_by_name['Factory1Message']
1579+
1580+
bad_db = BadDB()
1581+
bad_db.Add(self.factory_test1_fd)
1582+
self.bad_pool = descriptor_pool.DescriptorPool(bad_db)
1583+
1584+
def testFindExtensionByNumber(self):
1585+
ext = self.pool.FindExtensionByNumber(self.message_desc, 1001)
1586+
self.assertEqual(ext.name, 'one_more_field')
1587+
1588+
def testFindAllExtensions(self):
1589+
extensions = self.pool.FindAllExtensions(self.message_desc)
1590+
self.assertEqual(len(extensions), 2)
1591+
1592+
def testIgnoreBadFindExtensionByNumber(self):
1593+
file_desc = self.bad_pool.FindFileByName(
1594+
'google/protobuf/internal/factory_test1.proto'
1595+
)
1596+
message_desc = file_desc.message_types_by_name['Factory1Message']
1597+
with self.assertRaises(KeyError):
1598+
ext = self.bad_pool.FindExtensionByNumber(message_desc, 1001)
1599+
1600+
def testIgnoreBadFindAllExtensions(self):
1601+
file_desc = self.bad_pool.FindFileByName(
1602+
'google/protobuf/internal/factory_test1.proto'
1603+
)
1604+
message_desc = file_desc.message_types_by_name['Factory1Message']
1605+
extensions = self.bad_pool.FindAllExtensions(message_desc)
1606+
self.assertEqual(len(extensions), 0)
1607+
1608+
def testFindAllExtensionsReturnsNoneList(self):
1609+
db = BadDB2()
1610+
db.Add(self.factory_test1_fd)
1611+
pool = descriptor_pool.DescriptorPool(db)
1612+
file_desc = pool.FindFileByName(
1613+
'google/protobuf/internal/factory_test1.proto'
1614+
)
1615+
message_desc = file_desc.message_types_by_name['Factory1Message']
1616+
extensions = self.bad_pool.FindAllExtensions(message_desc)
1617+
self.assertEqual(len(extensions), 0)
1618+
1619+
15341620
if __name__ == '__main__':
15351621
unittest.main()

python/google/protobuf/pyext/descriptor_database.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ bool PyDescriptorDatabase::FindAllExtensionNumbers(
146146
return false;
147147
}
148148
Py_ssize_t size = PyList_Size(py_list.get());
149+
if (size == -1) {
150+
PyErr_Format(
151+
PyExc_RuntimeError,
152+
"FindAllExtensionNumbers() on fall back DB must return a list, not %S",
153+
py_list.get());
154+
PyErr_Print();
155+
return false;
156+
}
149157
int64_t item_value;
150158
for (Py_ssize_t i = 0 ; i < size; ++i) {
151159
ScopedPyObjectPtr item(PySequence_GetItem(py_list.get(), i));

0 commit comments

Comments
 (0)