@@ -688,10 +688,14 @@ def visitModule(self, mod):
688
688
static int
689
689
ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
690
690
{
691
+ astmodulestate *state = get_global_ast_state();
692
+ if (state == NULL) {
693
+ return -1;
694
+ }
695
+
691
696
Py_ssize_t i, numfields = 0;
692
697
int res = -1;
693
698
PyObject *key, *value, *fields;
694
- astmodulestate *state = get_global_ast_state();
695
699
if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
696
700
goto cleanup;
697
701
}
@@ -761,6 +765,10 @@ def visitModule(self, mod):
761
765
ast_type_reduce(PyObject *self, PyObject *unused)
762
766
{
763
767
astmodulestate *state = get_global_ast_state();
768
+ if (state == NULL) {
769
+ return NULL;
770
+ }
771
+
764
772
PyObject *dict;
765
773
if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
766
774
return NULL;
@@ -969,9 +977,8 @@ def visitModule(self, mod):
969
977
970
978
""" , 0 , reflow = False )
971
979
972
- self .emit ("static int init_types(void )" ,0 )
980
+ self .emit ("static int init_types(astmodulestate *state )" ,0 )
973
981
self .emit ("{" , 0 )
974
- self .emit ("astmodulestate *state = get_global_ast_state();" , 1 )
975
982
self .emit ("if (state->initialized) return 1;" , 1 )
976
983
self .emit ("if (init_identifiers(state) < 0) return 0;" , 1 )
977
984
self .emit ("state->AST_type = PyType_FromSpec(&AST_type_spec);" , 1 )
@@ -1046,40 +1053,55 @@ def emit_defaults(self, name, fields, depth):
1046
1053
class ASTModuleVisitor (PickleVisitor ):
1047
1054
1048
1055
def visitModule (self , mod ):
1049
- self .emit ("PyMODINIT_FUNC " , 0 )
1050
- self .emit ("PyInit__ast(void )" , 0 )
1056
+ self .emit ("static int " , 0 )
1057
+ self .emit ("astmodule_exec(PyObject *m )" , 0 )
1051
1058
self .emit ("{" , 0 )
1052
- self .emit ("PyObject *m = PyModule_Create(&_astmodule);" , 1 )
1053
- self .emit ("if (!m) {" , 1 )
1054
- self .emit ("return NULL;" , 2 )
1055
- self .emit ("}" , 1 )
1056
1059
self .emit ('astmodulestate *state = get_ast_state(m);' , 1 )
1057
- self .emit ('' , 1 )
1060
+ self .emit ("" , 0 )
1058
1061
1059
- self .emit ("if (!init_types()) {" , 1 )
1060
- self .emit ("goto error ;" , 2 )
1062
+ self .emit ("if (!init_types(state )) {" , 1 )
1063
+ self .emit ("return -1 ;" , 2 )
1061
1064
self .emit ("}" , 1 )
1062
1065
self .emit ('if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {' , 1 )
1063
- self .emit ('goto error ;' , 2 )
1066
+ self .emit ('return -1 ;' , 2 )
1064
1067
self .emit ('}' , 1 )
1065
1068
self .emit ('Py_INCREF(state->AST_type);' , 1 )
1066
1069
self .emit ('if (PyModule_AddIntMacro(m, PyCF_ALLOW_TOP_LEVEL_AWAIT) < 0) {' , 1 )
1067
- self .emit ("goto error ;" , 2 )
1070
+ self .emit ("return -1 ;" , 2 )
1068
1071
self .emit ('}' , 1 )
1069
1072
self .emit ('if (PyModule_AddIntMacro(m, PyCF_ONLY_AST) < 0) {' , 1 )
1070
- self .emit ("goto error ;" , 2 )
1073
+ self .emit ("return -1 ;" , 2 )
1071
1074
self .emit ('}' , 1 )
1072
1075
self .emit ('if (PyModule_AddIntMacro(m, PyCF_TYPE_COMMENTS) < 0) {' , 1 )
1073
- self .emit ("goto error ;" , 2 )
1076
+ self .emit ("return -1 ;" , 2 )
1074
1077
self .emit ('}' , 1 )
1075
1078
for dfn in mod .dfns :
1076
1079
self .visit (dfn )
1077
- self .emit ("return m;" , 1 )
1078
- self .emit ("" , 0 )
1079
- self .emit ("error:" , 0 )
1080
- self .emit ("Py_DECREF(m);" , 1 )
1081
- self .emit ("return NULL;" , 1 )
1080
+ self .emit ("return 0;" , 1 )
1082
1081
self .emit ("}" , 0 )
1082
+ self .emit ("" , 0 )
1083
+ self .emit ("""
1084
+ static PyModuleDef_Slot astmodule_slots[] = {
1085
+ {Py_mod_exec, astmodule_exec},
1086
+ {0, NULL}
1087
+ };
1088
+
1089
+ static struct PyModuleDef _astmodule = {
1090
+ PyModuleDef_HEAD_INIT,
1091
+ .m_name = "_ast",
1092
+ .m_size = sizeof(astmodulestate),
1093
+ .m_slots = astmodule_slots,
1094
+ .m_traverse = astmodule_traverse,
1095
+ .m_clear = astmodule_clear,
1096
+ .m_free = astmodule_free,
1097
+ };
1098
+
1099
+ PyMODINIT_FUNC
1100
+ PyInit__ast(void)
1101
+ {
1102
+ return PyModuleDef_Init(&_astmodule);
1103
+ }
1104
+ """ .strip (), 0 , reflow = False )
1083
1105
1084
1106
def visitProduct (self , prod , name ):
1085
1107
self .addObj (name )
@@ -1095,7 +1117,7 @@ def visitConstructor(self, cons, name):
1095
1117
def addObj (self , name ):
1096
1118
self .emit ("if (PyModule_AddObject(m, \" %s\" , "
1097
1119
"state->%s_type) < 0) {" % (name , name ), 1 )
1098
- self .emit ("goto error ;" , 2 )
1120
+ self .emit ("return -1 ;" , 2 )
1099
1121
self .emit ('}' , 1 )
1100
1122
self .emit ("Py_INCREF(state->%s_type);" % name , 1 )
1101
1123
@@ -1255,11 +1277,10 @@ class PartingShots(StaticVisitor):
1255
1277
CODE = """
1256
1278
PyObject* PyAST_mod2obj(mod_ty t)
1257
1279
{
1258
- if (!init_types()) {
1280
+ astmodulestate *state = get_global_ast_state();
1281
+ if (state == NULL) {
1259
1282
return NULL;
1260
1283
}
1261
-
1262
- astmodulestate *state = get_global_ast_state();
1263
1284
return ast2obj_mod(state, t);
1264
1285
}
1265
1286
@@ -1281,10 +1302,6 @@ class PartingShots(StaticVisitor):
1281
1302
1282
1303
assert(0 <= mode && mode <= 2);
1283
1304
1284
- if (!init_types()) {
1285
- return NULL;
1286
- }
1287
-
1288
1305
isinstance = PyObject_IsInstance(ast, req_type[mode]);
1289
1306
if (isinstance == -1)
1290
1307
return NULL;
@@ -1303,11 +1320,10 @@ class PartingShots(StaticVisitor):
1303
1320
1304
1321
int PyAST_Check(PyObject* obj)
1305
1322
{
1306
- if (!init_types()) {
1323
+ astmodulestate *state = get_global_ast_state();
1324
+ if (state == NULL) {
1307
1325
return -1;
1308
1326
}
1309
-
1310
- astmodulestate *state = get_global_ast_state();
1311
1327
return PyObject_IsInstance(obj, state->AST_type);
1312
1328
}
1313
1329
"""
@@ -1358,12 +1374,35 @@ def generate_module_def(f, mod):
1358
1374
f .write (' PyObject *' + s + ';\n ' )
1359
1375
f .write ('} astmodulestate;\n \n ' )
1360
1376
f .write ("""
1361
- static astmodulestate global_ast_state;
1377
+ static astmodulestate*
1378
+ get_ast_state(PyObject *module)
1379
+ {
1380
+ void *state = PyModule_GetState(module);
1381
+ assert(state != NULL);
1382
+ return (astmodulestate*)state;
1383
+ }
1362
1384
1363
- static astmodulestate *
1364
- get_ast_state(PyObject *Py_UNUSED(module) )
1385
+ static astmodulestate*
1386
+ get_global_ast_state(void )
1365
1387
{
1366
- return &global_ast_state;
1388
+ _Py_IDENTIFIER(_ast);
1389
+ PyObject *name = _PyUnicode_FromId(&PyId__ast); // borrowed reference
1390
+ if (name == NULL) {
1391
+ return NULL;
1392
+ }
1393
+ PyObject *module = PyImport_GetModule(name);
1394
+ if (module == NULL) {
1395
+ if (PyErr_Occurred()) {
1396
+ return NULL;
1397
+ }
1398
+ module = PyImport_Import(name);
1399
+ if (module == NULL) {
1400
+ return NULL;
1401
+ }
1402
+ }
1403
+ astmodulestate *state = get_ast_state(module);
1404
+ Py_DECREF(module);
1405
+ return state;
1367
1406
}
1368
1407
1369
1408
static int astmodule_clear(PyObject *module)
@@ -1390,17 +1429,6 @@ def generate_module_def(f, mod):
1390
1429
astmodule_clear((PyObject*)module);
1391
1430
}
1392
1431
1393
- static struct PyModuleDef _astmodule = {
1394
- PyModuleDef_HEAD_INIT,
1395
- .m_name = "_ast",
1396
- .m_size = -1,
1397
- .m_traverse = astmodule_traverse,
1398
- .m_clear = astmodule_clear,
1399
- .m_free = astmodule_free,
1400
- };
1401
-
1402
- #define get_global_ast_state() (&global_ast_state)
1403
-
1404
1432
""" )
1405
1433
f .write ('static int init_identifiers(astmodulestate *state)\n ' )
1406
1434
f .write ('{\n ' )
0 commit comments