Let's Build Real-time Session Invalidation

May 10, 2020


Demo repo: https://github.com/robzhu/logged-out

Some applications need to limit users to a single client or browser instance. This post covers how to build, improve, and scale this feature. We begin with a simple web app with two API endpoints:

  • Users log in by sending their user ID in the user HTTP request header to the /login route. Here's an example request/response:
curl -H "user:user123" localhost:9000/login
{"sessionId":"364rl8"}
  • The client sends the session ID in the sessionid HTTP header to the /api route. If the session ID is valid, the server returns "authenticated"; if not, it returns an error:
curl -H "sessionid: 364rl8" localhost:9000/api
authenticated

curl -H "sessionid: badSession" localhost:9000/api
error: invalid session.

Our example returns the session ID in the HTTP response body, but in practice it's more common to store the session ID in a cookie: the server sends a Set-Cookie: sessionid=364rl8 response header, and the browser then automatically includes that cookie in all subsequent requests to the same domain.
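With the cookie approach, the /api handler reads the session ID from the Cookie request header instead of a custom sessionid header. Here's a minimal sketch of both halves using only plain JavaScript; the helper names sessionCookie and readSessionId are hypothetical, and a real Express app would more likely use a cookie-parsing middleware.

```javascript
// Hypothetical helper: build the Set-Cookie value for a new session.
// HttpOnly keeps page scripts from reading the cookie; Path=/ sends it on every route.
function sessionCookie(sessionId) {
  return `sessionid=${encodeURIComponent(sessionId)}; HttpOnly; Path=/`;
}

// Hypothetical helper: extract the session ID from a Cookie request header
// such as "theme=dark; sessionid=364rl8". Returns null if it is absent.
function readSessionId(cookieHeader) {
  if (!cookieHeader) return null;
  for (const pair of cookieHeader.split(";")) {
    const [name, ...rest] = pair.trim().split("=");
    if (name === "sessionid") return decodeURIComponent(rest.join("="));
  }
  return null;
}

console.log(sessionCookie("364rl8")); // sessionid=364rl8; HttpOnly; Path=/
console.log(readSessionId("theme=dark; sessionid=364rl8")); // 364rl8
```

In a Node or Express handler, the login route would call res.setHeader("Set-Cookie", sessionCookie(sessionId)), and /api would call readSessionId(req.headers.cookie).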

1. The Simplest Solution

The simplest solution is to use a server-side session cache that generates and stores a session ID for each user ID.

const { generateSessionId } = require("./utils");
const cors = require("cors");
const app = require("express")().use(cors());
 
const PORT = 9000;
// this will totally scale, trust me
const sessions = {};
 
app.get("/login", (req, res) => {
  const { user } = req.headers;
 
  if (!user) {
    res.status(400).send("error: request must include the 'user' HTTP header");
  } else {
    const sessionId = generateSessionId();
    sessions[user] = sessionId;
 
    res.send({ sessionId });
  }
});
 
app.get("/api", (req, res) => {
  const { sessionid } = req.headers;
 
  if (!sessionid) {
    res.status(401).send("error: no sessionId. Log in at /login");
  } else {
    if (Object.values(sessions).includes(sessionid)) {
      res.send("authenticated");
    } else {
      res.status(401).send("error: invalid session.");
    }
  }
});
 
app.listen(PORT, () => {
  console.log(`server started on http://localhost:${PORT}`);
});

Whenever a user logs in, their previous session ID is overwritten, so requests that carry the outdated session ID fail validation and the server returns an error. However, if the client doesn't make an API request, the user never learns that the session was invalidated.
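The overwrite-on-login behavior is easier to see in isolation. Below is a minimal sketch of an in-memory store (the SessionStore class is hypothetical, not from the demo repo) that keeps a reverse index from session ID to user ID, so validation is a single lookup rather than the Object.values scan above, and that returns the replaced session ID so a caller could act on it.

```javascript
// Hypothetical in-memory session store: one active session per user.
class SessionStore {
  constructor() {
    this.sessionByUser = new Map(); // userId -> sessionId
    this.userBySession = new Map(); // sessionId -> userId (reverse index)
  }

  // Records a new session for userId and returns the session it replaced, if any.
  login(userId, sessionId) {
    const previous = this.sessionByUser.get(userId) ?? null;
    if (previous !== null) this.userBySession.delete(previous);
    this.sessionByUser.set(userId, sessionId);
    this.userBySession.set(sessionId, userId);
    return previous;
  }

  // A session is valid only while it is the user's most recent one.
  isValid(sessionId) {
    return this.userBySession.has(sessionId);
  }
}

const store = new SessionStore();
store.login("user123", "aaa111"); // first browser logs in
const replaced = store.login("user123", "bbb222"); // second browser logs in
console.log(replaced); // aaa111
console.log(store.isValid("aaa111")); // false: the first browser's requests now fail
console.log(store.isValid("bbb222")); // true
```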

Ideally, we want a client-side function that can tell us when the session is no longer valid:

async function logIn(userId, onSessionInvalidated)

The logIn function takes a callback function (as the second argument) that will be invoked whenever we detect that the session is no longer valid. We can implement this API in two ways: polling and server-push.

2. Polling

If the client polls the /api endpoint, we can detect an invalid session without explicit user interaction.

async function logIn(userId, onSessionInvalidated) {
  const response = await fetch("http://localhost:9000/login", {
    headers: {
      user: userId,
    },
  });
  const { sessionId } = await response.json();
 
  const POLLING_INTERVAL = 200;
  const poll = setInterval(async () => {
    const response = await fetch("http://localhost:9000/api", {
      headers: {
        sessionId,
      },
    });
 
    if (response.status !== 200) {
      // a non-200 status code means the session is no longer valid
      clearInterval(poll);
      onSessionInvalidated();
    }
  }, POLLING_INTERVAL);
 
  return sessionId;
}

However, polling forces us to make a trade-off between latency and efficiency. The shorter the polling interval, the more quickly we can detect a bad session at the cost of more wasted polls.
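A quick back-of-the-envelope calculation makes the trade-off concrete. The sketch below (the client counts are illustrative assumptions, not measurements from the demo) estimates the request rate a fleet of polling clients generates and the worst-case delay before a client notices an invalidation, ignoring network latency.

```javascript
// Rough polling cost model (illustrative; ignores request latency and jitter).
function pollingLoad(clients, intervalMs) {
  return {
    requestsPerSecond: (clients * 1000) / intervalMs,
    // an invalidation that lands just after a poll waits a full interval
    worstCaseDetectionMs: intervalMs,
  };
}

console.log(pollingLoad(10000, 200)); // { requestsPerSecond: 50000, worstCaseDetectionMs: 200 }
console.log(pollingLoad(10000, 5000)); // { requestsPerSecond: 2000, worstCaseDetectionMs: 5000 }
```

Nearly all of those requests return "authenticated", which is exactly the wasted work that server push avoids.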

3. Server-Push

If polling becomes a bottleneck, our final option is to maintain a persistent, bidirectional channel over which the server can tell connected clients when their sessions are invalidated. For this demo, we'll use WebSockets; to host a WebSocket server, we use the ws package.

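const WebSocket = require("ws");
// subscribe helper from the event-emitter module shown below
// (the module path is an assumption)
const { subscribeToSessionInvalidation } = require("./sessionEvents");

const wss = new WebSocket.Server({ port: 9001 });

wss.on("connection", (ws) => {
  ws.on("message", (data) => {
    let request;
    try {
      request = JSON.parse(data);
    } catch (err) {
      return; // ignore messages that aren't valid JSON
    }
    if (request && request.action === "subscribeToSessionInvalidation" && request.args) {
      const { sessionId } = request.args;
      subscribeToSessionInvalidation(sessionId, () => {
        ws.send(
          JSON.stringify({
            event: "sessionInvalidated",
            args: {
              sessionId,
            },
          })
        );
      });
    }
  });
});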

This code listens for incoming WebSocket connections on port 9001. For each new connection, it listens for messages, which it expects in the following format:

{
  action: "action ID",
  args: {...}
}

If the action is "subscribeToSessionInvalidation", the server notifies that client when the specified session ID is invalidated. Note that any client that knows a session ID can subscribe to its invalidation, so this solution requires session IDs that are hard to guess.
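The demo imports generateSessionId from ./utils without showing it, and a six-character base-36 ID like 364rl8 carries only about 31 bits of randomness. Here's one way such a function could look, assuming Node's built-in crypto module; this is a sketch, not the repo's implementation.

```javascript
const crypto = require("crypto");

// Hypothetical generateSessionId: 32 bytes from a cryptographically secure
// random number generator, hex-encoded as 64 characters (256 bits of entropy).
function generateSessionId() {
  return crypto.randomBytes(32).toString("hex");
}

const id = generateSessionId();
console.log(id.length); // 64
```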

We also need to update our /login route handler to detect an existing session and publish its invalidation:

app.get("/login", (req, res) => {
  const { user } = req.headers;

  if (!user) {
    res.status(400).send("error: request must include the 'user' HTTP header");
  } else {
    const existingSession = sessions[user];
    if (existingSession) {
      publishSessionInvalidation(existingSession);
    }
    const sessionId = generateSessionId();
    sessions[user] = sessionId;

    res.send({ sessionId });
  }
});

Here's the code for subscribeToSessionInvalidation and publishSessionInvalidation:

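const { EventEmitter } = require("events");
const sessionEvents = new EventEmitter();
// every connected client adds a listener, so lift the default
// warning threshold of 10 listeners per event
sessionEvents.setMaxListeners(0);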

const SESSION_INVALIDATED = "session_invalidated";

function publishSessionInvalidation(sessionId) {
  sessionEvents.emit(SESSION_INVALIDATED, sessionId);
}

function subscribeToSessionInvalidation(sessionId, callback) {
  const listener = (invalidatedSessionId) => {
    if (sessionId === invalidatedSessionId) {
      sessionEvents.removeListener(SESSION_INVALIDATED, listener);
      callback();
    }
  };

  sessionEvents.addListener(SESSION_INVALIDATED, listener);
}

module.exports = {
  publishSessionInvalidation,
  subscribeToSessionInvalidation,
};
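One design note on this module: every subscriber listens on the same event, so each publish invokes every connected client's listener even though typically only one of them matches. A variant (our own suggestion, not the repo's code) uses the session ID itself as the event name, so each publish reaches only the matching listener. It assumes session IDs can never be "error", "newListener", or "removeListener", which have special meaning for EventEmitter; hex IDs like the ones above satisfy that.

```javascript
const { EventEmitter } = require("events");

// Variant: the session ID is the event name, so emit() only invokes
// listeners for that exact session instead of every subscriber.
const sessionEvents = new EventEmitter();

function publishSessionInvalidation(sessionId) {
  // returns true if some connected client was subscribed to this session
  return sessionEvents.emit(sessionId);
}

function subscribeToSessionInvalidation(sessionId, callback) {
  // once() removes the listener after it fires, like the manual removal above
  sessionEvents.once(sessionId, callback);
}

let notified = 0;
subscribeToSessionInvalidation("aaa111", () => notified++);
publishSessionInvalidation("bbb222"); // no listener for this session
publishSessionInvalidation("aaa111"); // notifies the subscriber once
publishSessionInvalidation("aaa111"); // already removed; nothing happens
console.log(notified); // 1
```

In either version, a subscription whose WebSocket closes before the session is invalidated stays registered; removing the listener in the socket's close handler would plug that leak.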

Now we're ready to replace the client's polling logic with the browser's WebSocket API:

async function logIn(userId, onSessionInvalidated) {
  const response = await fetch("http://localhost:9000/login", {
    headers: {
      user: userId,
    },
  });
  const { sessionId } = await response.json();

  const socket = new WebSocket("ws://localhost:9001");
  socket.addEventListener("open", () => {
    console.log("connected.");
    socket.addEventListener("message", ({ data }) => {
      const { event, args } = JSON.parse(data);
      if (event === "sessionInvalidated") {
        // args.sessionId should equal sessionId
        onSessionInvalidated();
      }
    });
    socket.send(
      JSON.stringify({
        action: "subscribeToSessionInvalidation",
        args: {
          sessionId,
        },
      })
    );
  });

  socket.addEventListener("error", (error) => {
    console.error(error);
  });

  return sessionId;
}

Load /push/index.html in your browser, and try it out. You should now see some real-time session invalidation action.

4. Scaling


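I bet you noticed that this solution doesn't scale: the session cache and the event emitter both live in the memory of a single Node process, so we can't run multiple server instances. To create a more scalable version, we need to make the following changes: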

  1. Move the session cache to a scalable distributed cache
  2. Move from the EventEmitter to a scalable distributed pub/sub system
  3. Update the client to add retry logic on disconnect
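Requirement #3 doesn't get code of its own in this post, so here's a hedged sketch of what reconnect logic could look like. backoffDelay and connectWithRetry are hypothetical names; createSocket is any function returning a browser-style WebSocket, and onOpen is where the client would re-send its subscribeToSessionInvalidation message.

```javascript
// Hypothetical helper: exponential backoff (500 ms, 1 s, 2 s, ...) capped at maxMs.
function backoffDelay(attempt, baseMs = 500, maxMs = 30000) {
  return Math.min(maxMs, baseMs * 2 ** attempt);
}

// Hypothetical wrapper: open a socket, and reopen it with backoff whenever it closes.
function connectWithRetry(createSocket, onOpen, { baseMs = 500, maxMs = 30000 } = {}) {
  let attempt = 0;
  let stopped = false;
  let socket;

  function connect() {
    socket = createSocket();
    socket.addEventListener("open", () => {
      attempt = 0; // reset the backoff once a connection succeeds
      onOpen(socket); // e.g. re-send the subscription message
    });
    socket.addEventListener("close", () => {
      if (stopped) return;
      setTimeout(connect, backoffDelay(attempt++, baseMs, maxMs));
    });
  }

  connect();
  return {
    stop() {
      stopped = true;
      socket.close();
    },
  };
}

// Usage in the browser (hypothetical):
// connectWithRetry(() => new WebSocket("ws://localhost:9001"), (socket) =>
//   socket.send(JSON.stringify({ action: "subscribeToSessionInvalidation", args: { sessionId } })));
```

Redis pub/sub doesn't store messages, so an invalidation published while the client was disconnected is simply lost. After reconnecting, the client should therefore also confirm that its session is still valid, for example with a single request to /api.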

Redis satisfies requirements #1 and #2. If we need to scale Redis, we can deploy a Redis cluster or we can use a hosted version of Redis, such as Amazon ElastiCache.

Redis as a remote session cache

First, let's spin up a Redis instance. Assuming you have a Docker host available:

docker run -d -p 6379:6379 redis

Make sure that port 6379 is open if you're running this on a cloud VM (ideally only to your app server, since this Redis instance has no password). If you don't have a cloud VM, you can launch a t2.micro instance on EC2 as part of the AWS free tier. Once your VM is launched, install Docker on it.

There are many articles that discuss session caching with Redis; here's my approach, using the redis npm package:

// remoteCache.js
const redis = require("redis");

const SessionCacheKey = "sessions";

// general-purpose client, created in connect() (shown below)
let client;

async function getSession(userId) {
  return new Promise((resolve, reject) => {
    client.hmget(SessionCacheKey, userId, (err, res) => {
      if (err) return reject(err);
      // HMGET replies with one value per requested field; we asked for one
      resolve(res ? res[0] : null);
    });
  });
}

async function putSession(userId, sessionId) {
  return new Promise((resolve, reject) => {
    client.hmset(SessionCacheKey, userId, sessionId, (err, res) => {
      if (err) return reject(err);
      resolve(res); // "OK" on success
    });
  });
}

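We use the Redis commands HMGET and HMSET (the H marks a hash command, the M means "multiple" fields) to read and write [user ID, session ID] pairs in a single hash. (Since Redis 4.0, HSET also accepts multiple fields, and HMSET is considered deprecated.) That takes care of session storage; next, we need to replace the EventEmitter with Redis pub/sub. The redis npm docs state: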

When a client issues a SUBSCRIBE or PSUBSCRIBE, that connection is put into a "subscriber" mode. At that point, the only valid commands are those that modify the subscription set, and quit (also ping on some redis versions). When the subscription set is empty, the connection is put back into regular mode.

So we need to create two Redis clients: one for general commands and one dedicated to subscriber commands:

// remoteCache.js (continued)

const SessionInvalidationChannel = "sessionInvalidation";
// session ID -> callback, for sessions whose WebSocket is connected to this server
const pendingCallbacks = new Map();
let subscriber;

async function connect() {
  client = redis.createClient({
    host: process.env.REDIS_HOST,
  });
  // the redis client we're using works in two modes, "normal" and
  // "subscriber", so we duplicate the client here and use the copy
  // only for our subscriptions.
  subscriber = client.duplicate();

  // register the handler once: "ready" fires again after every reconnect,
  // so registering inside it would add duplicate handlers
  subscriber.on("message", (channel, invalidatedSession) => {
    if (channel !== SessionInvalidationChannel) return;
    const callback = pendingCallbacks.get(invalidatedSession);
    if (callback) {
      pendingCallbacks.delete(invalidatedSession);
      callback();
    }
  });

  return Promise.all([
    new Promise((resolve) => client.once("ready", resolve)),
    new Promise((resolve) => {
      subscriber.once("ready", () => {
        subscriber.subscribe(SessionInvalidationChannel, () => resolve());
      });
    }),
  ]);
}

function publishSessionInvalidation(sessionId) {
  client.publish(SessionInvalidationChannel, sessionId);
}

function subscribeToSessionInvalidation(sessionId, callback) {
  pendingCallbacks.set(sessionId, callback);
}

module.exports = {
  connect,
  getSession,
  putSession,
  publishSessionInvalidation,
  subscribeToSessionInvalidation,
};

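In the connect function, each server instance subscribes to the "sessionInvalidation" channel, and publishSessionInvalidation publishes to that channel. Because every instance receives every invalidation, each one checks pendingCallbacks to see whether it holds the WebSocket for that session.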
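To see why every instance subscribes and then filters by its own pending callbacks, here's a small in-process simulation rather than the demo's code: an EventEmitter stands in for the Redis channel, and each hypothetical makeInstance call plays the role of one server process holding its own WebSocket subscriptions.

```javascript
const { EventEmitter } = require("events");

// Stand-in for the Redis "sessionInvalidation" channel (simulation only).
const channel = new EventEmitter();

// Hypothetical model of one server instance: it only knows about the
// sessions whose WebSockets are connected to it.
function makeInstance(name) {
  const pendingCallbacks = new Map();
  // every instance receives every message, like a Redis SUBSCRIBE
  channel.on("message", (sessionId) => {
    const callback = pendingCallbacks.get(sessionId);
    if (callback) {
      pendingCallbacks.delete(sessionId);
      callback(name);
    }
  });
  return {
    subscribe: (sessionId, callback) => pendingCallbacks.set(sessionId, callback),
    publish: (sessionId) => channel.emit("message", sessionId),
  };
}

const serverA = makeInstance("A");
const serverB = makeInstance("B");
const notifications = [];

// the first browser's WebSocket happens to land on server A...
serverA.subscribe("aaa111", (server) => notifications.push(`aaa111 via ${server}`));
// ...while the second login is handled by server B, which publishes
serverB.publish("aaa111");
console.log(notifications); // [ 'aaa111 via A' ]
```

Server B never held the subscription, yet the client connected to server A is still notified, which is the property the single-process EventEmitter version couldn't provide.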

You can run the demo like so, replacing localhost with your Redis host if it runs elsewhere:

git clone https://github.com/robzhu/logged-out
cd logged-out/push-redis/server
npm i && REDIS_HOST=localhost node server.js

Next, open /push-redis/index.html in two browser tabs and you should be able to see the working demo.

5. Native Client

Let's take a moment to consider which applications need real-time session invalidation. A few come to mind: games, streaming media clients, and advanced finance applications (e.g., the Bloomberg Terminal). Since these applications are often built as native clients, let's see what a .NET client looks like:

using System;
using System.Net.Http;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Websocket.Client;

static class Program
{
  const string LoginEndpoint = "http://localhost:9000/login";
  const string UserID = "1234";
  static Uri WebSocketEndpoint = new Uri("ws://localhost:9001");

  static async Task Main(string[] args)
  {
    HttpClient client = new HttpClient();

    client.DefaultRequestHeaders.Add("user", UserID);
    dynamic response = JsonConvert.DeserializeObject(await client.GetStringAsync(LoginEndpoint));
    string sessionId = response.sessionId;
    Console.WriteLine("Obtained session ID: " + sessionId);

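    using (var socket = new WebsocketClient(WebSocketEndpoint))
    {
      // By default, Websocket.Client reconnects after one minute without incoming
      // messages. This server only speaks when a session is invalidated, and the
      // subscription isn't re-sent after a reconnect, so disable that timeout.
      socket.ReconnectTimeout = null;
      await socket.Start();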

      socket.MessageReceived.Subscribe(msg =>
      {
        dynamic payload = JsonConvert.DeserializeObject(msg.Text);
        if (payload["event"] == "sessionInvalidated")
        {
          Console.WriteLine("You have logged in elsewhere. Exiting.");
          Environment.Exit(0);
        }
      });

      socket.Send(JsonConvert.SerializeObject(new
      {
        action = "subscribeToSessionInvalidation",
        args = new
        {
          sessionId = sessionId
        }
      }));

      Console.WriteLine("Press ENTER to exit.");
      Console.ReadLine();
    }
  }
}
You can run the .NET client and the web client side by side and watch them invalidate each other.

Of the many rough edges in this demo, the lack of type safety around the API stands out to me, specifically the topic names and the schemas for the subscription request and response. Scaling this solution beyond one developer would require comprehensive documentation or a client-server type system, such as a GraphQL schema.
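Short of adopting GraphQL, one lightweight step is a single shared module that owns the action and event names and validates incoming messages, so a typo fails loudly instead of silently never matching. The module below is a hypothetical sketch; its field names mirror the messages used in this post.

```javascript
// Hypothetical shared protocol module for the session-invalidation socket.
const Actions = Object.freeze({ SUBSCRIBE: "subscribeToSessionInvalidation" });
const Events = Object.freeze({ INVALIDATED: "sessionInvalidated" });

// Parses a raw client message and throws if it doesn't match the schema.
function parseSubscribeRequest(raw) {
  const msg = JSON.parse(raw);
  if (!msg || msg.action !== Actions.SUBSCRIBE) {
    throw new Error(`unknown action: ${msg && msg.action}`);
  }
  if (!msg.args || typeof msg.args.sessionId !== "string") {
    throw new Error("args.sessionId must be a string");
  }
  return msg;
}

// Builds the server-to-client notification in one place.
function invalidatedEvent(sessionId) {
  return JSON.stringify({ event: Events.INVALIDATED, args: { sessionId } });
}

const req = parseSubscribeRequest(
  '{"action":"subscribeToSessionInvalidation","args":{"sessionId":"aaa111"}}'
);
console.log(req.args.sessionId); // aaa111
console.log(invalidatedEvent("aaa111")); // {"event":"sessionInvalidated","args":{"sessionId":"aaa111"}}
```

Both the Node server and the web client could import these constants; the .NET client would still need its own copy, which is where a shared schema language starts to pay off.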

Over the course of building this demo, people have suggested several other solutions:

I hope this article gave you some ideas for building real-time session invalidation.