diff --git a/plugins/terminals/fps_terminals/routes.py b/plugins/terminals/fps_terminals/routes.py index 337c07fb..345948fb 100644 --- a/plugins/terminals/fps_terminals/routes.py +++ b/plugins/terminals/fps_terminals/routes.py @@ -37,7 +37,12 @@ async def create_terminal( self, user: User, ): - name = str(len(TERMINALS) + 1) + name_int = 1 + while True: + if str(name_int) not in TERMINALS: + break + name_int += 1 + name = str(name_int) terminal = Terminal( name=name, last_activity=datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z"), @@ -51,9 +56,10 @@ async def delete_terminal( name: str, user: User, ): - for websocket in TERMINALS[name]["server"].websockets: - TERMINALS[name]["server"].quit(websocket) - del TERMINALS[name] + if name in TERMINALS: + for websocket in TERMINALS[name]["server"].websockets: + TERMINALS[name]["server"].quit(websocket) + del TERMINALS[name] return Response(status_code=HTTPStatus.NO_CONTENT.value) async def terminal_websocket( @@ -65,6 +71,9 @@ async def terminal_websocket( return websocket, permissions = websocket_permissions await websocket.accept() + if name not in TERMINALS: + return + await TERMINALS[name]["server"].serve(websocket, permissions) if name in TERMINALS: TERMINALS[name]["server"].quit(websocket) diff --git a/plugins/terminals/fps_terminals/server.py b/plugins/terminals/fps_terminals/server.py index bbdb55b2..69b570a9 100644 --- a/plugins/terminals/fps_terminals/server.py +++ b/plugins/terminals/fps_terminals/server.py @@ -7,7 +7,7 @@ import termios from functools import partial -from anyio import create_memory_object_stream, create_task_group, from_thread, to_thread +from anyio import Lock, create_memory_object_stream, create_task_group, from_thread, to_thread from anyio.abc import ByteReceiveStream, ByteSendStream from fastapi import WebSocketDisconnect @@ -30,78 +30,86 @@ def __init__(self): self.fd = fd self.p_out = os.fdopen(self.fd, "w+b", 0) self.websockets = [] + self.task_group = None + self.lock = Lock() async def serve(self, websocket, permissions) -> None: - self.websocket = websocket self.permissions = permissions self.websockets.append(websocket) - async with create_task_group() as self.task_group: - self.recv_stream = ReceiveStream(self.p_out, self.task_group) - self.send_stream = SendStream(self.p_out) - self.task_group.start_soon(self.backend_to_frontend) - self.task_group.start_soon(self.frontend_to_backend) - - async def stop(self) -> None: - os.write(self.recv_stream.pipeout, b"0") - self.p_out.close() - try: - self.recv_stream.sel.unregister(self.p_out) - except Exception: - pass - self.task_group.cancel_scope.cancel() + await self.lock.acquire() + if self.task_group is None: + async with create_task_group() as self.task_group: + self.lock.release() + self.recv_stream_from_backend = ReceiveStream(self.p_out, self.task_group, self.quit) + async with create_task_group() as tg: + self.send_stream_to_backend = SendStream(self.p_out) + tg.start_soon(self.backend_to_frontend) + tg.start_soon(partial(self.frontend_to_backend, websocket)) + else: + self.lock.release() + async with create_task_group() as tg: + self.send_stream_to_backend = SendStream(self.p_out) + tg.start_soon(self.backend_to_frontend) + tg.start_soon(partial(self.frontend_to_backend, websocket)) async def backend_to_frontend(self): while True: - data = (await self.recv_stream.receive(65536)).decode() + data = (await self.recv_stream_from_backend.receive(65536)).decode() for websocket in self.websockets: await websocket.send_json(["stdout", data]) - async def frontend_to_backend(self): - await self.websocket.send_json(["setup", {}]) + async def frontend_to_backend(self, websocket): + await websocket.send_json(["setup", {}]) can_execute = self.permissions is None or "execute" in self.permissions.get("terminals", []) try: while True: - msg = await self.websocket.receive_json() + msg = await websocket.receive_json() if can_execute: if msg[0] == "stdin": - await self.send_stream.send(msg[1].encode()) + await self.send_stream_to_backend.send(msg[1].encode()) elif msg[0] == "set_size": winsize = struct.pack("HH", msg[1], msg[2]) fcntl.ioctl(self.fd, termios.TIOCSWINSZ, winsize) except WebSocketDisconnect: - self.quit(self.websocket) + self.quit(websocket) self.task_group.cancel_scope.cancel() - def quit(self, websocket): + def quit(self, websocket=None): try: - os.write(self.recv_stream.pipeout, b"0") - self.p_out.close() - self.recv_stream.sel.unregister(self.p_out) - self.websockets.remove(websocket) - if not self.websockets: - os.close(self.fd) + self.recv_stream_from_backend.sel.unregister(self.p_out) except Exception: pass + if websocket is None: + self.websockets.clear() + elif websocket in self.websockets: + self.websockets.remove(websocket) + if not self.websockets: + try: + os.close(self.fd) + except Exception: + pass + try: + self.p_out.close() + except Exception: + pass + self.task_group.cancel_scope.cancel() class ReceiveStream(ByteReceiveStream): - def __init__(self, p_out, task_group): + def __init__(self, p_out, task_group, quit): self.p_out = p_out + self.task_group = task_group + self.quit = quit self.sel = selectors.DefaultSelector() self.sel.register(self.p_out, selectors.EVENT_READ, self._read) - self.pipein, self.pipeout = os.pipe() - f = os.fdopen(self.pipein, "r+b", 0) - - def cb(): - return True - - self.sel.register(f, selectors.EVENT_READ, cb) self.send_stream, self.recv_stream = create_memory_object_stream[bytes]( max_buffer_size=65536 ) def reader(): + # this runs in a thread + # the callback is self._read, which returns True when the terminal was exited while True: events = self.sel.select() for key, mask in events: @@ -115,7 +123,7 @@ def _read(self) -> bool: try: data = self.p_out.read(65536) except OSError: - self.sel.unregister(self.p_out) + from_thread.run_sync(self.quit) return True else: from_thread.run_sync(self.send_stream.send_nowait, data)