diff --git a/anonstream/routes/websocket.py b/anonstream/routes/websocket.py index 5aa6b75..e97955e 100644 --- a/anonstream/routes/websocket.py +++ b/anonstream/routes/websocket.py @@ -4,26 +4,32 @@ import asyncio from quart import current_app, websocket -from anonstream.user import see +from anonstream.user import see, ensure_allowedness, AllowednessException from anonstream.websocket import websocket_outbound, websocket_inbound from anonstream.routes.wrappers import with_user_from @current_app.websocket('/live') -@with_user_from(websocket, fallback_to_token=True) +@with_user_from(websocket, fallback_to_token=True, ignore_allowedness=True) async def live(timestamp, user_or_token): match user_or_token: case str() | None: await websocket.send_json({'type': 'kick'}) await websocket.close(1001) case dict() as user: - queue = asyncio.Queue() - user['websockets'][queue] = timestamp - user['last']['reading'] = timestamp - - producer = websocket_outbound(queue, user) - consumer = websocket_inbound(queue, user) try: - await asyncio.gather(producer, consumer) - finally: - see(user) - user['websockets'].pop(queue) + ensure_allowedness(user, timestamp=timestamp) + except AllowednessException: + await websocket.send_json({'type': 'kick'}) + await websocket.close(1001) + else: + queue = asyncio.Queue() + user['websockets'][queue] = timestamp + user['last']['reading'] = timestamp + + producer = websocket_outbound(queue, user) + consumer = websocket_inbound(queue, user) + try: + await asyncio.gather(producer, consumer) + finally: + see(user) + user['websockets'].pop(queue) diff --git a/anonstream/tasks.py b/anonstream/tasks.py index fe92fee..65b0247 100644 --- a/anonstream/tasks.py +++ b/anonstream/tasks.py @@ -9,7 +9,7 @@ from quart import current_app, websocket from anonstream.broadcast import broadcast, broadcast_users_update from anonstream.stream import is_online, get_stream_title, get_stream_uptime_and_viewership -from anonstream.user import get_absent_users, get_sunsettable_users, deverify +from anonstream.user import get_absent_users, get_sunsettable_users, deverify, ensure_allowedness, AllowednessException from anonstream.wrappers import with_timestamp CONFIG = current_app.config @@ -116,6 +116,12 @@ async def t_close_websockets(timestamp, iteration): else: for user in USERS: for queue in user['websockets']: + # Check allowedness + try: + ensure_allowedness(user, timestamp=timestamp) + except AllowednessException: + queue.put_nowait({'type': 'kick'}) + # Check expiry last_pong = user['websockets'][queue] last_pong_ago = timestamp - last_pong if last_pong_ago > THRESHOLD: diff --git a/anonstream/websocket.py b/anonstream/websocket.py index 72d2875..a20d393 100644 --- a/anonstream/websocket.py +++ b/anonstream/websocket.py @@ -9,7 +9,7 @@ from quart import current_app, websocket from anonstream.stream import get_stream_title, get_stream_uptime_and_viewership from anonstream.captcha import get_random_captcha_digest_for from anonstream.chat import get_all_messages_for_websocket, add_chat_message, Rejected -from anonstream.user import get_all_users_for_websocket, see, reading, verify, deverify, BadCaptcha, try_change_appearance +from anonstream.user import get_all_users_for_websocket, see, reading, verify, deverify, BadCaptcha, try_change_appearance, ensure_allowedness, AllowednessException from anonstream.wrappers import with_timestamp, get_timestamp from anonstream.utils.chat import generate_nonce from anonstream.utils.user import identifying_string @@ -18,6 +18,9 @@ from anonstream.utils.websocket import parse_websocket_data, Malformed, WS CONFIG = current_app.config async def websocket_outbound(queue, user): + # This function does NOT check alllowedness at first, only later. + # Allowedness is assumed to be checked beforehand (by the route handler). + # These first two websocket messages are always sent. await websocket.send_json({'type': 'ping'}) await websocket.send_json({ 'type': 'init', @@ -36,14 +39,26 @@ async def websocket_outbound(queue, user): }) while True: payload = await queue.get() - if payload['type'] == 'close': + if payload['type'] == 'kick': + await websocket.send_json(payload) + await websocket.close(1001) + break + elif payload['type'] == 'close': await websocket.close(1011) break else: - await websocket.send_json(payload) + try: + ensure_allowedness(user) + except AllowednessException: + websocket.send_json({'type': 'kick'}) + await websocket.close(1001) + break + else: + await websocket.send_json(payload) async def websocket_inbound(queue, user): while True: + # Read from websocket try: receipt = await websocket.receive_json() except json.JSONDecodeError: @@ -51,26 +66,34 @@ async def websocket_inbound(queue, user): finally: timestamp = get_timestamp() see(user, timestamp=timestamp) - try: - receipt_type, parsed = parse_websocket_data(receipt) - except Malformed as e: - error , *_ = e.args - payload = { - 'type': 'error', - 'because': error, - } - else: - match receipt_type: - case WS.MESSAGE: - handle = handle_inbound_message - case WS.APPEARANCE: - handle = handle_inbound_appearance - case WS.CAPTCHA: - handle = handle_inbound_captcha - case WS.PONG: - handle = handle_inbound_pong - payload = handle(timestamp, queue, user, *parsed) + # Prepare response + try: + ensure_allowedness(user) + except AllowednessException: + payload = {'type': 'kick'} + else: + try: + receipt_type, parsed = parse_websocket_data(receipt) + except Malformed as e: + error , *_ = e.args + payload = { + 'type': 'error', + 'because': error, + } + else: + match receipt_type: + case WS.MESSAGE: + handle = handle_inbound_message + case WS.APPEARANCE: + handle = handle_inbound_appearance + case WS.CAPTCHA: + handle = handle_inbound_captcha + case WS.PONG: + handle = handle_inbound_pong + payload = handle(timestamp, queue, user, *parsed) + + # Write to websocket if payload is not None: queue.put_nowait(payload)