diff --git a/app.py b/app.py index b895747..e325ccf 100644 --- a/app.py +++ b/app.py @@ -73,6 +73,7 @@ muted_users: set = set() banned_usernames: set = set() banned_ips: set = set() message_timestamps: dict = defaultdict(list) +pending_pm_invites: dict = {} # sid → set of room names they were invited to RATE_LIMIT = 6 RATE_WINDOW = 5 @@ -360,6 +361,7 @@ def on_disconnect(): sid = request.sid user = connected_users.pop(sid, None) message_timestamps.pop(sid, None) + pending_pm_invites.pop(sid, None) if user and user.get("username"): lower = user["username"].lower() username_to_sid.pop(lower, None) @@ -526,6 +528,8 @@ def on_pm_open(data): room = _pm_room(user["username"], target) join_room(room) + if target_sid: + pending_pm_invites.setdefault(target_sid, set()).add(room) socketio.emit("pm_invite", {"from": user["username"], "room": room}, to=target_sid) emit("pm_ready", {"with": target, "room": room}) @@ -533,7 +537,14 @@ def on_pm_open(data): @socketio.on("pm_accept") def on_pm_accept(data): - join_room(data.get("room")) + sid = request.sid + room = str(data.get("room", "")) + allowed = pending_pm_invites.get(sid, set()) + if room not in allowed: + emit("error", {"msg": "Invalid or expired PM invitation."}) + return + allowed.discard(room) + join_room(room) @socketio.on("pm_message")