diff --git a/Lib/immutable.py b/Lib/immutable.py index 3d1c34e4a43c5a..001697b97d4c5e 100644 --- a/Lib/immutable.py +++ b/Lib/immutable.py @@ -18,6 +18,7 @@ FREEZABLE_NO = _c.FREEZABLE_NO FREEZABLE_EXPLICIT = _c.FREEZABLE_EXPLICIT FREEZABLE_PROXY = _c.FREEZABLE_PROXY +InterpreterLocal = _c.InterpreterLocal def freezable(cls): @@ -56,6 +57,7 @@ def frozen(cls): "FREEZABLE_NO", "FREEZABLE_EXPLICIT", "FREEZABLE_PROXY", + "InterpreterLocal", "freezable", "unfreezable", "explicitlyFreezable", diff --git a/Lib/test/test_freeze/test_interpreterlocal.py b/Lib/test/test_freeze/test_interpreterlocal.py new file mode 100644 index 00000000000000..ac975bc3711970 --- /dev/null +++ b/Lib/test/test_freeze/test_interpreterlocal.py @@ -0,0 +1,232 @@ +import os +import unittest +from immutable import freeze, isfrozen, InterpreterLocal +from test.support import import_helper + + +class TestInterpreterLocalBasic(unittest.TestCase): + """Test basic InterpreterLocal with immutable default value.""" + + def test_get_returns_default(self): + field = InterpreterLocal(42) + self.assertEqual(field.get(), 42) + + def test_set(self): + field = InterpreterLocal(42) + field.set(99) + self.assertEqual(field.get(), 99) + + def test_get_consistent(self): + field = InterpreterLocal("hello") + self.assertIs(field.get(), field.get()) + + def test_none_default(self): + field = InterpreterLocal(None) + self.assertIsNone(field.get()) + + def test_tuple_default(self): + t = (1, 2, 3) + field = InterpreterLocal(t) + self.assertEqual(field.get(), (1, 2, 3)) + + +class TestInterpreterLocalFactory(unittest.TestCase): + """Test InterpreterLocal with factory callable.""" + + def test_factory_returns_new_value(self): + field = InterpreterLocal(lambda: []) + result = field.get() + self.assertIsInstance(result, list) + self.assertEqual(result, []) + + def test_factory_called_once(self): + field = InterpreterLocal(lambda: []) + first = field.get() + second = field.get() + self.assertIs(first, second) + + def test_factory_value_is_mutable(self): + field = InterpreterLocal(lambda: []) + field.get().append(1) + self.assertEqual(field.get(), [1]) + + def test_factory_set_overrides(self): + field = InterpreterLocal(lambda: []) + field.get() # initialise + field.set([1, 2, 3]) + self.assertEqual(field.get(), [1, 2, 3]) + + +class TestInterpreterLocalFreeze(unittest.TestCase): + """Test that InterpreterLocal works within frozen object graphs.""" + + def test_freeze_object_with_interpreterlocal(self): + class Container: + pass + + c = Container() + c.field = InterpreterLocal(42) + freeze(c) + self.assertTrue(isfrozen(c)) + + def test_value_accessible_after_freeze(self): + class Container: + pass + + c = Container() + c.field = InterpreterLocal(42) + freeze(c) + self.assertEqual(c.field.get(), 42) + + def test_value_mutable_after_freeze(self): + class Container: + pass + + c = Container() + c.field = InterpreterLocal(42) + freeze(c) + c.field.set(99) + self.assertEqual(c.field.get(), 99) + + def test_factory_works_after_freeze(self): + class Container: + pass + + c = Container() + c.field = InterpreterLocal(lambda: {}) + freeze(c) + result = c.field.get() + self.assertIsInstance(result, dict) + + def test_interpreterlocal_itself_frozen(self): + field = InterpreterLocal(42) + freeze(field) + self.assertTrue(isfrozen(field)) + + def test_factory_result_mutable_after_freeze(self): + class Container: + pass + + c = Container() + c.field = InterpreterLocal(lambda: []) + freeze(c) + c.field.get().append("item") + self.assertEqual(c.field.get(), ["item"]) + + +class TestInterpreterLocalErrors(unittest.TestCase): + """Test error cases.""" + + def test_no_args(self): + with self.assertRaises(TypeError): + InterpreterLocal() + + def test_multiple_independent_fields(self): + f1 = InterpreterLocal(1) + f2 = InterpreterLocal(2) + self.assertEqual(f1.get(), 1) + self.assertEqual(f2.get(), 2) + f1.set(10) + self.assertEqual(f1.get(), 10) + self.assertEqual(f2.get(), 2) + + def test_non_freezable_default(self): + """Non-freezable default should raise at construction.""" + from immutable import NotFreezable + class NF(NotFreezable): + pass + with self.assertRaises(TypeError): + InterpreterLocal(NF()) + + def test_non_freezable_factory(self): + """Non-freezable factory should raise at construction.""" + from immutable import set_freezable, FREEZABLE_NO + def factory(): + return [] + set_freezable(factory, FREEZABLE_NO) + with self.assertRaises(TypeError): + InterpreterLocal(factory) + + +class TestInterpreterLocalSubinterpreters(unittest.TestCase): + """Test that InterpreterLocal provides per-interpreter isolation.""" + + def setUp(self): + self._interpreters = import_helper.import_module('_interpreters') + + def _run_in_subinterp(self, code, shared=None): + r, w = os.pipe() + wrapped = ( + "import contextlib, os\n" + f"with open({w}, 'w', encoding='utf-8') as spipe:\n" + " with contextlib.redirect_stdout(spipe):\n" + ) + for line in code.splitlines(): + wrapped += " " + line + "\n" + + interp = self._interpreters.create() + try: + with os.fdopen(r, encoding='utf-8') as rpipe: + self._interpreters.run_string( + interp, wrapped, shared=shared or {}) + return rpipe.read() + finally: + self._interpreters.destroy(interp) + + def test_shared_frozen_object_gets_default_in_subinterp(self): + """A frozen InterpreterLocal shared to a sub-interpreter + should return the default value there, not main's value.""" + field = InterpreterLocal(42) + field.set(999) + self.assertEqual(field.get(), 999) + + # Freeze so it can be shared directly (immutable sharing) + freeze(field) + + output = self._run_in_subinterp( + "print(field.get())\n", + shared={"field": field}, + ) + self.assertEqual(output.strip(), "42") + + def test_shared_frozen_object_set_independent(self): + """Setting a value in the sub-interpreter should not affect main.""" + field = InterpreterLocal(0) + freeze(field) + + self._run_in_subinterp( + "field.set(123)\n" + "print(field.get())\n", + shared={"field": field}, + ) + # Main interpreter's value should still be 0 + self.assertEqual(field.get(), 0) + + def test_shared_frozen_container_with_interpreterlocal(self): + """A frozen container with an InterpreterLocal field should + provide per-interpreter isolation when shared.""" + class Container: + pass + + c = Container() + c.counter = InterpreterLocal(lambda: []) + freeze(c) + + # Main interpreter uses the field + c.counter.get().append("main") + self.assertEqual(c.counter.get(), ["main"]) + + # Sub-interpreter gets its own fresh list from the factory + output = self._run_in_subinterp( + "c.counter.get().append('sub')\n" + "print(c.counter.get())\n", + shared={"c": c}, + ) + self.assertEqual(output.strip(), "['sub']") + + # Main should be unaffected + self.assertEqual(c.counter.get(), ["main"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/Modules/_immutablemodule.c b/Modules/_immutablemodule.c index 3578442aebe6ff..7b3869286a09dd 100644 --- a/Modules/_immutablemodule.c +++ b/Modules/_immutablemodule.c @@ -10,6 +10,7 @@ #include #include "pycore_object.h" #include "pycore_immutability.h" +#include "pycore_critical_section.h" /*[clinic input] module _immutable @@ -20,6 +21,8 @@ module _immutable typedef struct { PyObject *not_freezable_error_obj; + PyObject *interpreter_locals; // dict: InterpreterLocal -> value + PyObject *interpreterlocal_type; // heap type object } immutable_state; static struct PyModuleDef _immutablemodule; @@ -37,6 +40,8 @@ immutable_clear(PyObject *module) { immutable_state *module_state = PyModule_GetState(module); Py_CLEAR(module_state->not_freezable_error_obj); + Py_CLEAR(module_state->interpreter_locals); + Py_CLEAR(module_state->interpreterlocal_type); return 0; } @@ -45,6 +50,8 @@ immutable_traverse(PyObject *module, visitproc visit, void *arg) { immutable_state *module_state = PyModule_GetState(module); Py_VISIT(module_state->not_freezable_error_obj); + Py_VISIT(module_state->interpreter_locals); + Py_VISIT(module_state->interpreterlocal_type); return 0; } @@ -152,6 +159,176 @@ _immutable_set_freezable_impl(PyObject *module, PyObject *obj, int status) Py_RETURN_NONE; } +/* + * InterpreterLocal type + * + * An immutable indirection to per-interpreter mutable state. + * tp_reachable hides per-interpreter values from the freeze walk. + */ + +typedef struct { + PyObject_HEAD + PyObject *default_value; // Immutable default, or NULL if factory form + PyObject *factory; // Frozen callable, or NULL if value form +} PyInterpreterLocalObject; + +static PyObject * +interpreterlocal_get_locals(PyObject *self) +{ + PyObject *module = PyType_GetModuleByDef(Py_TYPE(self), &_immutablemodule); + if (module == NULL) { + return NULL; + } + return get_immutable_state(module)->interpreter_locals; +} + +static PyObject * +interpreterlocal_lookup(PyInterpreterLocalObject *self) +{ + PyObject *locals = interpreterlocal_get_locals((PyObject *)self); + if (locals == NULL) { + return NULL; + } + + PyObject *val = NULL; + // Under free-threading (--disable-gil), multiple threads in the same + // interpreter share this dict. The critical section makes the + // get-or-init compound operation atomic so the factory is called + // at most once per interpreter. On GIL builds this is a no-op. + Py_BEGIN_CRITICAL_SECTION(locals); + int ret = PyDict_GetItemRef(locals, (PyObject *)self, &val); + if (ret == 0) { + // Not found — initialise + if (self->factory != NULL) { + val = PyObject_CallNoArgs(self->factory); + } + else { + val = Py_NewRef(self->default_value); + } + if (val != NULL) { + if (PyDict_SetItem(locals, (PyObject *)self, val) < 0) { + Py_CLEAR(val); + } + } + } + else if (ret < 0) { + val = NULL; + } + Py_END_CRITICAL_SECTION(); + return val; +} + +static PyObject * +interpreterlocal_get(PyObject *self, PyObject *Py_UNUSED(ignored)) +{ + return interpreterlocal_lookup((PyInterpreterLocalObject *)self); +} + +static PyObject * +interpreterlocal_set(PyObject *self, PyObject *value) +{ + PyObject *locals = interpreterlocal_get_locals(self); + if (locals == NULL) { + return NULL; + } + if (PyDict_SetItem(locals, self, value) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +static int +interpreterlocal_init(PyObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"default", NULL}; + PyObject *default_or_factory = NULL; + PyInterpreterLocalObject *il = (PyInterpreterLocalObject *)self; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, + &default_or_factory)) { + return -1; + } + + if (PyCallable_Check(default_or_factory)) { + if (_PyImmutability_Freeze(default_or_factory) < 0) { + return -1; + } + il->factory = Py_NewRef(default_or_factory); + il->default_value = NULL; + } + else { + if (_PyImmutability_Freeze(default_or_factory) < 0) { + return -1; + } + il->default_value = Py_NewRef(default_or_factory); + il->factory = NULL; + } + return 0; +} + +static void +interpreterlocal_dealloc(PyObject *self) +{ + PyInterpreterLocalObject *il = (PyInterpreterLocalObject *)self; + PyObject_GC_UnTrack(self); + Py_CLEAR(il->default_value); + Py_CLEAR(il->factory); + PyTypeObject *tp = Py_TYPE(self); + tp->tp_free(self); + Py_DECREF(tp); +} + +static int +interpreterlocal_traverse(PyObject *self, visitproc visit, void *arg) +{ + PyInterpreterLocalObject *il = (PyInterpreterLocalObject *)self; + Py_VISIT(Py_TYPE(self)); + Py_VISIT(il->default_value); + Py_VISIT(il->factory); + return 0; +} + +static int +interpreterlocal_reachable(PyObject *self, visitproc visit, void *arg) +{ + // Visit the type and the frozen fields. + // Do NOT visit per-interpreter stored values — that's the escape hatch. + PyInterpreterLocalObject *il = (PyInterpreterLocalObject *)self; + Py_VISIT(Py_TYPE(self)); + Py_VISIT(il->default_value); + Py_VISIT(il->factory); + return 0; +} + +static PyMethodDef interpreterlocal_methods[] = { + {"get", interpreterlocal_get, METH_NOARGS, + "Return the value for the current interpreter."}, + {"set", interpreterlocal_set, METH_O, + "Set the value for the current interpreter."}, + {NULL, NULL} +}; + +static PyType_Slot interpreterlocal_slots[] = { + {Py_tp_dealloc, interpreterlocal_dealloc}, + {Py_tp_init, interpreterlocal_init}, + {Py_tp_methods, interpreterlocal_methods}, + {Py_tp_traverse, interpreterlocal_traverse}, + {Py_tp_reachable, interpreterlocal_reachable}, + {Py_tp_new, PyType_GenericNew}, + {Py_tp_alloc, PyType_GenericAlloc}, + {Py_tp_free, PyObject_GC_Del}, + {0, NULL}, +}; + +static PyType_Spec interpreterlocal_spec = { + .name = "_immutable.InterpreterLocal", + .basicsize = sizeof(PyInterpreterLocalObject), + .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE | + Py_TPFLAGS_HAVE_GC), + .slots = interpreterlocal_slots, +}; + + static PyType_Slot not_freezable_error_slots[] = { {0, NULL}, }; @@ -218,6 +395,27 @@ immutable_exec(PyObject *module) { return -1; } + /* Create InterpreterLocal heap type */ + module_state->interpreterlocal_type = PyType_FromModuleAndSpec( + module, &interpreterlocal_spec, NULL); + if (module_state->interpreterlocal_type == NULL) { + return -1; + } + if (PyModule_AddType(module, + (PyTypeObject *)module_state->interpreterlocal_type) != 0) { + return -1; + } + if (_PyImmutability_RegisterFreezable( + (PyTypeObject *)module_state->interpreterlocal_type) < 0) { + return -1; + } + + /* Create per-interpreter locals dict */ + module_state->interpreter_locals = PyDict_New(); + if (module_state->interpreter_locals == NULL) { + return -1; + } + if (PyModule_AddIntConstant(module, "FREEZABLE_YES", _Py_FREEZABLE_YES) != 0) { return -1;