diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index bf375cbef6..2252af6199 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -457,6 +457,7 @@ def init_settings( "config": jupyter_app.config, "config_dir": jupyter_app.config_dir, "allow_password_change": jupyter_app.allow_password_change, + "accept_kernel_env_var": jupyter_app.accept_kernel_env_var, "server_root_dir": root_dir, "jinja2_env": env, "serverapp": jupyter_app, @@ -1427,6 +1428,12 @@ def _default_allow_remote(self) -> bool: """, ) + accept_kernel_env_vars = Bool( + False, + config=True, + help="""Allow a user to setup custom env variables while launching or selecting a kernel""", + ) + browser = Unicode( "", config=True, diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index cd8a9de71f..e0df82360e 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -224,6 +224,7 @@ async def _async_start_kernel( # type:ignore[override] The name identifying which kernel spec to launch. This is ignored if an existing kernel is returned, but it may be checked in the future. """ + # if kernel_id is None or kernel_id not in self: if path is not None: kwargs["cwd"] = self.cwd_for_path(path, env=kwargs.get("env", {})) diff --git a/jupyter_server/services/sessions/handlers.py b/jupyter_server/services/sessions/handlers.py index 3e013c0335..09fb67f9e0 100644 --- a/jupyter_server/services/sessions/handlers.py +++ b/jupyter_server/services/sessions/handlers.py @@ -49,6 +49,8 @@ async def post(self): # (unless a session already exists for the named session) sm = self.session_manager + accept_kernel_env_vars = self.settings.get("accept_kernel_env_vars", False) + model = self.get_json_body() if model is None: raise web.HTTPError(400, "No JSON data provided") @@ -77,6 +79,7 @@ async def post(self): kernel = model.get("kernel", {}) kernel_name = kernel.get("name", None) kernel_id = kernel.get("id", None) + custom_env_vars = kernel["env"] if "env" in kernel and accept_kernel_env_vars else {} if not kernel_id and not kernel_name: self.log.debug("No kernel specified, using default kernel") @@ -93,6 +96,7 @@ async def post(self): kernel_id=kernel_id, name=name, type=mtype, + custom_env_vars=custom_env_vars, ) except NoSuchKernel: msg = ( @@ -141,6 +145,8 @@ async def patch(self, session_id): # get the previous session model before = await sm.get_session(session_id=session_id) + accept_kernel_env_vars = self.settings.get("accept_kernel_env_vars", False) + changes = {} if "notebook" in model and "path" in model["notebook"]: self.log.warning("Sessions API changed, see updated swagger docs") @@ -152,6 +158,8 @@ async def patch(self, session_id): changes["name"] = model["name"] if "type" in model: changes["type"] = model["type"] + if "env" in model and accept_kernel_env_vars: + changes["custom_env_vars"] = model["env"] if "kernel" in model: # Kernel id takes precedence over name. if model["kernel"].get("id") is not None: @@ -160,6 +168,12 @@ async def patch(self, session_id): raise web.HTTPError(400, "No such kernel: %s" % kernel_id) changes["kernel_id"] = kernel_id elif model["kernel"].get("name") is not None: + custom_env_vars = ( + model["kernel"]["env"] + if "env" in model["kernel"] and accept_kernel_env_vars + else {} + ) + kernel_name = model["kernel"]["name"] kernel_id = await sm.start_kernel_for_session( session_id, @@ -167,6 +181,7 @@ async def patch(self, session_id): name=before["name"], path=before["path"], type=before["type"], + custom_env_vars=custom_env_vars, ) changes["kernel_id"] = kernel_id diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index 8b392b4e1b..f7bf015870 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -205,6 +205,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._pending_sessions = KernelSessionRecordList() + _custom_envs: Dict[str, Optional[Dict[str, Any]]] = {} # Session database initialized below _cursor = None _connection = None @@ -267,6 +268,7 @@ async def create_session( type: Optional[str] = None, kernel_name: Optional[KernelName] = None, kernel_id: Optional[str] = None, + custom_env_vars: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Creates a session and returns its model @@ -283,7 +285,7 @@ async def create_session( pass else: kernel_id = await self.start_kernel_for_session( - session_id, path, name, type, kernel_name + session_id, path, name, type, kernel_name, custom_env_vars ) record.kernel_id = kernel_id self._pending_sessions.update(record) @@ -319,6 +321,7 @@ async def start_kernel_for_session( name: Optional[ModelName], type: Optional[str], kernel_name: Optional[KernelName], + custom_env_vars: Optional[Dict[str, Any]] = None, ) -> str: """Start a new kernel for a given session. @@ -335,16 +338,25 @@ async def start_kernel_for_session( the type of the session kernel_name : str the name of the kernel specification to use. The default kernel name will be used if not provided. + custom_env_vars: dict + dictionary of custom env variables """ # allow contents manager to specify kernels cwd kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path)) - kernel_env = self.get_kernel_env(path, name) + + # if we have custom env than we have to add them to available env variables + if custom_env_vars is not None and isinstance(custom_env_vars, dict): + for key, value in custom_env_vars.items(): + kernel_env[key] = value + kernel_id = await self.kernel_manager.start_kernel( path=kernel_path, kernel_name=kernel_name, env=kernel_env, ) + + self._custom_envs[kernel_id] = custom_env_vars return cast(str, kernel_id) async def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None): @@ -445,8 +457,8 @@ async def update_session(self, session_id, **kwargs): and the value replaces the current value in the session with session_id. """ - await self.get_session(session_id=session_id) + await self.get_session(session_id=session_id) if not kwargs: # no changes return @@ -464,7 +476,15 @@ async def update_session(self, session_id, **kwargs): "SELECT path, name, kernel_id FROM session WHERE session_id=?", [session_id] ) path, name, kernel_id = self.cursor.fetchone() - self.kernel_manager.update_env(kernel_id=kernel_id, env=self.get_kernel_env(path, name)) + env = self.get_kernel_env(path, name) + + # if we have custom env than we have to add them to available env variables + if isinstance(self._custom_envs, dict): + custom_env = self._custom_envs.get(kernel_id) + if custom_env is not None and isinstance(custom_env, dict): + for key, value in custom_env.items(): + env[key] = value + self.kernel_manager.update_env(kernel_id=kernel_id, env=env) async def kernel_culled(self, kernel_id: str) -> bool: """Checks if the kernel is still considered alive and returns true if its not found.""" @@ -526,6 +546,9 @@ async def delete_session(self, session_id): record = KernelSessionRecord(session_id=session_id) self._pending_sessions.update(record) session = await self.get_session(session_id=session_id) - await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"])) + kernel_id = session["kernel"]["id"] + if kernel_id in self._custom_envs: + del self._custom_envs[kernel_id] + await ensure_async(self.kernel_manager.shutdown_kernel(kernel_id)) self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,)) self._pending_sessions.remove(record) diff --git a/tests/services/sessions/test_manager.py b/tests/services/sessions/test_manager.py index bd092259e0..fc756b0850 100644 --- a/tests/services/sessions/test_manager.py +++ b/tests/services/sessions/test_manager.py @@ -319,6 +319,20 @@ async def test_update_session(session_manager): assert model == expected +async def test_update_session_with_custom_env_vars(session_manager): + custom_env_vars = {"test_env_name": "test_env_value"} + await session_manager.create_session( + path="/path/to/test.ipynb", + kernel_name="julia", + type="notebook", + custom_env_vars=custom_env_vars, + ) + kernel_id = "A" + custom_envs = session_manager._custom_envs[kernel_id] + expected = "test_env_value" + assert custom_envs["test_env_name"] == expected + + async def test_bad_update_session(session_manager): # try to update a session with a bad keyword ~ raise error session = await session_manager.create_session(