The Absolutely Simplest Consistent Hashing Example

July 07, 2012 at 12:01 PM | Code

Lately I've been studying Redis a lot. When using key/value databases like Redis, as well as caches like Memcached, if you want to spread keys across multiple nodes, you need a consistent hashing algorithm. Consistent hashing is what we use when we want to distribute a set of keys along a span of key/value servers in a... well, consistent fashion: when a server is added or removed, only a small share of the keys move to a different server, rather than nearly all of them.
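For contrast, here is a small sketch (not from Tom White's article or any library) of the naive scheme, which places each key on server md5(key) % N. The helper names md5_int and modulo_server are made up for illustration. Growing from 4 to 5 servers changes the server for roughly 80% of the keys, because a hash keeps its server only when h % 4 == h % 5, which holds for just 4 out of every 20 hash values. A consistent hashing ring would ideally hand only about a fifth of the keys to the new server and leave the rest where they were.

```python
import hashlib

def md5_int(s):
    """Hash a string to a large integer with MD5."""
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16)

def modulo_server(key, num_servers):
    """Naive placement (hypothetical helper): hash modulo the server count."""
    return md5_int(key) % num_servers

keys = ["key%d" % i for i in range(10000)]

# Grow the cluster from 4 to 5 servers and count keys that change server.
moved = sum(1 for k in keys if modulo_server(k, 4) != modulo_server(k, 5))
moved_fraction = moved / float(len(keys))
print("fraction of keys moved: %.2f" % moved_fraction)
```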

If you Google around to learn what consistent hashing means, the article that most directly tells you "the answer" without a lot of handwringing is Consistent Hashing by Tom White. Not only does it explain the concept very clearly, it even has a plain and simple code example in Java.

The recipe in Tom's post depends on Java's TreeMap, which has no equivalent in Python's standard library, but after some contemplation it became apparent that the functionality of circle.tailMap(hash) is something we already have with bisect. We have a sorted array of integers and a new number: where in the array does the new number go? bisect.bisect() will give you that, with the same O(log n) lookup cost as TreeMap. (Adding a node with bisect.insort() is O(n) rather than O(log n), but nodes are added far less often than keys are looked up.)
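A minimal sketch of that lookup on a toy ring, assuming a handful of made-up integer positions rather than real MD5 hashes (ring_positions and ring_index are hypothetical names). One small difference from Java: tailMap(hash) includes an entry equal to hash, while bisect.bisect() skips past it; bisect.bisect_left() would match tailMap exactly, but with 128-bit MD5 hashes an exact tie is practically impossible.

```python
import bisect

# Sorted hash positions of (hypothetical) node points on the ring.
ring_positions = [10, 25, 40, 70]

def ring_index(positions, hash_):
    """Index of the first position strictly greater than hash_, wrapping
    around to 0 past the end, like tailMap(hash) falling back to firstKey()."""
    i = bisect.bisect(positions, hash_)
    return 0 if i == len(positions) else i

print(ring_index(ring_positions, 30))  # prints 2 (position 40)
print(ring_index(ring_positions, 75))  # prints 0 (wraps around to 10)
```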

As a sanity check, I searched a bit more for Python implementations. I found a recipe by Amir Salihefendic, which seems to be based on the Java recipe and is pretty nice, but in the post he searches the circle for hash values using a linear search, ouch! It turns out Amir does use bisect in his Python Cheese Shop package hash_ring, but by then it was too late: I had already written my own recipe as well as tests (which hash_ring doesn't appear to have, at least in the downloaded distribution). There's also Continuum, which takes a somewhat more heavy-handed approach (three separate classes, and an expensive IndexError being caught to detect keys beyond the end of the circle). Both systems, Continuum more so, seem to encourage using hostnames directly as the node names that get hashed onto the ring. As noted by Jeremy Zawodny, with a persistent system like Redis this is a bad idea, as it means you can't move a particular key set to a new host.
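To make Zawodny's point concrete, here is a throwaway sketch; build_ring and owner are hypothetical helpers, not the API of hash_ring or Continuum, and they assume the same "name:i" MD5 scheme and bisect lookup as my recipe. The plan is to copy host1's whole data set to host3 and retire host1. If hostnames are the names on the ring, swapping host1 for host3 reshuffles placement, so many keys end up somewhere the copied data isn't. With symbolic names the ring itself never changes, only the name-to-host table does.

```python
import bisect
import hashlib

def md5_int(s):
    """MD5 of a string as an integer."""
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16)

def build_ring(names, replicas=100):
    """Hypothetical helper: sorted ring points, each name hashed once per replica."""
    points = sorted((md5_int("%s:%s" % (n, i)), n)
                    for n in names for i in range(replicas))
    return [h for h, _ in points], [n for _, n in points]

def owner(ring, key):
    """Hypothetical helper: name owning key (next point clockwise, wrapping)."""
    hashes, names = ring
    return names[bisect.bisect(hashes, md5_int(key)) % len(hashes)]

keys = ["key%d" % i for i in range(1000)]
# Plan: copy host1's whole data set to host3, then retire host1.
copied_to = {"host1": "host3", "host2": "host2"}

# Hostnames as ring names: swapping host1 for host3 changes placement.
before = build_ring(["host1", "host2"])
after = build_ring(["host3", "host2"])
misplaced_by_hostname = sum(
    1 for k in keys if owner(after, k) != copied_to[owner(before, k)])

# Symbolic names: the ring is untouched, only the name -> host table changes.
ring = build_ring(["node1", "node2"])
hosts_before = {"node1": "host1", "node2": "host2"}
hosts_after = {"node1": "host3", "node2": "host2"}
misplaced_by_name = sum(
    1 for k in keys
    if hosts_after[owner(ring, k)] != copied_to[hosts_before[owner(ring, k)]])

print(misplaced_by_hostname, misplaced_by_name)
```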

So, spending a bit of NIH capital, here's my recipe. It provides a dictionary interface, so you can store hostnames or even actual client instances keyed to symbolic names:

import bisect
import hashlib

class ConsistentHashRing(object):
    """Implement a consistent hashing ring."""

    def __init__(self, replicas=100):
        """Create a new ConsistentHashRing.

        :param replicas: number of points (replicas) each
         node is hashed to on the ring.

        """
        self.replicas = replicas
        self._keys = []
        self._nodes = {}

    def _hash(self, key):
        """Given a string key, return a hash value."""

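        # hashlib replaces the deprecated md5 module; encode so
        # this also works with Python 3 str keys.
        return int(hashlib.md5(key.encode("utf-8")).hexdigest(), 16)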

    def _repl_iterator(self, nodename):
        """Given a node name, return an iterable of replica hashes."""

        return (self._hash("%s:%s" % (nodename, i))
                for i in range(self.replicas))

    def __setitem__(self, nodename, node):
        """Add a node, given its name.

        The given nodename is hashed once
        per replica, placing that many points
        for the node on the ring.

        """
        for hash_ in self._repl_iterator(nodename):
            if hash_ in self._nodes:
                raise ValueError("Node name %r is "
                            "already present" % nodename)
            self._nodes[hash_] = node
            bisect.insort(self._keys, hash_)

    def __delitem__(self, nodename):
        """Remove a node, given its name."""

        for hash_ in self._repl_iterator(nodename):
            # will raise KeyError for nonexistent node name
            del self._nodes[hash_]
            index = bisect.bisect_left(self._keys, hash_)
            del self._keys[index]

    def __getitem__(self, key):
        """Return a node, given a key.

        The node replica with the nearest hash value
        greater than the hash of the given key is
        returned.  If the key's hash is greater than
        or equal to the greatest node hash, the lowest
        hashed node is returned (wrapping around the ring).

        """
        hash_ = self._hash(key)
        start = bisect.bisect(self._keys, hash_)
        if start == len(self._keys):
            start = 0
        return self._nodes[self._keys[start]]
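Because adding or removing a node only touches that node's own replica points, removing a node should only reassign the keys it owned. As a rough check of that property, here is a standalone sketch that rebuilds the same kind of ring with hypothetical helpers (build_ring and owner, not part of the class), assuming the same "name:i" MD5 scheme and bisect.bisect() lookup that ConsistentHashRing uses.

```python
import bisect
import hashlib

def md5_int(s):
    """MD5 of a string as an integer."""
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16)

def build_ring(names, replicas=100):
    """Hypothetical helper: sorted ring points, each name hashed once per replica."""
    points = sorted((md5_int("%s:%s" % (n, i)), n)
                    for n in names for i in range(replicas))
    return [h for h, _ in points], [n for _, n in points]

def owner(ring, key):
    """Hypothetical helper: name owning key (next point clockwise, wrapping)."""
    hashes, names = ring
    return names[bisect.bisect(hashes, md5_int(key)) % len(hashes)]

nodes = ["node%d" % i for i in range(1, 11)]
keys = [str(i) for i in range(10000)]

full = build_ring(nodes)
without_node3 = build_ring([n for n in nodes if n != "node3"])

moved = [k for k in keys if owner(full, k) != owner(without_node3, k)]
# Every key that moved was previously owned by the removed node.
assert all(owner(full, k) == "node3" for k in moved)
print("moved %d of %d keys" % (len(moved), len(keys)))
```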

The ring is used as a dictionary mapping node names to whatever you want; here we use Redis clients:

import redis
cr = ConsistentHashRing(100)

cr["node1"] = redis.StrictRedis(host="host1")
cr["node2"] = redis.StrictRedis(host="host2")

client = cr["some key"]
data = client.get("some key")

I wanted to validate that the ring does in fact produce standard deviations like those mentioned in the Java article, so I tested it as follows:

import unittest
import collections
import random
import math

class ConsistentHashRingTest(unittest.TestCase):
    def test_get_distribution(self):
        ring = ConsistentHashRing(100)

        numnodes = 10
        numhits = 1000
        numvalues = 10000

        for i in range(1, 1 + numnodes):
            ring["node%d" % i] = "node_value%d" % i

        distributions = collections.defaultdict(int)
        for i in range(numhits):
            key = str(random.randint(1, numvalues))
            node = ring[key]
            distributions[node] += 1

        # count of hits matches what is observed
        self.assertEqual(sum(distributions.values()), numhits)

        # I've observed standard deviation for 10 nodes + 100
        # replicas to be between 10 and 15.   Play around with
        # the number of nodes / replicas to see how different
        # tunings work out.
        standard_dev = self._pop_std_dev(distributions.values())
        self.assertLessEqual(standard_dev, 20)

        # if the stddev is good, it's safe to assume
        # all nodes were used
        self.assertEqual(len(distributions), numnodes)

        # just to test getting keys, see that we got the values
        # back and not keys or indexes or whatever.
        self.assertEqual(
                set(distributions.keys()),
                set("node_value%d" % i for i in range(1, 1 + numnodes))
            )

    def _pop_std_dev(self, population):
        # float() avoids Python 2 integer division truncating the mean
        mean = sum(population) / float(len(population))
        return math.sqrt(
                sum(pow(n - mean, 2) for n in population)
                / len(population)
            )

if __name__ == "__main__":
    unittest.main()
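The comment in the test suggests playing with the number of replicas. A sketch of that experiment follows, assuming the same "name:i" MD5 scheme and bisect.bisect() lookup as ConsistentHashRing; build_ring, owner and pop_std_dev are hypothetical standalone helpers, and the key set (10,000 distinct keys) is chosen for illustration, not taken from the Java article. With a single replica per node the arcs of the ring are very uneven, so the spread of keys per node should be much larger than with 100 replicas.

```python
import bisect
import hashlib
import math

def md5_int(s):
    """MD5 of a string as an integer."""
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16)

def build_ring(names, replicas):
    """Hypothetical helper: sorted ring points, each name hashed once per replica."""
    points = sorted((md5_int("%s:%s" % (n, i)), n)
                    for n in names for i in range(replicas))
    return [h for h, _ in points], [n for _, n in points]

def owner(ring, key):
    """Hypothetical helper: name owning key (next point clockwise, wrapping)."""
    hashes, names = ring
    return names[bisect.bisect(hashes, md5_int(key)) % len(hashes)]

def pop_std_dev(population):
    """Population standard deviation."""
    mean = sum(population) / float(len(population))
    return math.sqrt(sum((n - mean) ** 2 for n in population) / len(population))

nodes = ["node%d" % i for i in range(1, 11)]
keys = [str(i) for i in range(10000)]

std_devs = {}
for replicas in (1, 10, 100):
    ring = build_ring(nodes, replicas)
    counts = dict.fromkeys(nodes, 0)
    for k in keys:
        counts[owner(ring, k)] += 1
    std_devs[replicas] = pop_std_dev(list(counts.values()))
    print("replicas=%3d  std dev of keys per node: %.1f"
          % (replicas, std_devs[replicas]))
```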