#!/usr/bin/env python
from util import *


class _SkipListNode:
    '''Holds the data and the forward pointers in the skip list nodes.'''

    def __init__(self, v, levels):
        # the value stored in this node
        self.data = v
        # next[i] is the following node on level i; a node of height h has
        # exactly h forward pointers
        self.next = [None] * levels


class SkipList:
    '''A sorted collection of unique keys stored as a skip list.

    Two sentinels bound every level: ``header`` (data None) sits before all
    keys and ``end`` (data INFINITY) sits after them, so the search loops stop
    at ``end`` without a None check -- assuming INFINITY compares greater than
    every stored key (it comes from ``util``; TODO confirm).

    Invariant: ``len(self.header.next) == len(self.end.next) == self.nlevels``.
    '''

    def __init__(self):
        '''Create a new, empty skip list.'''
        self.end = _SkipListNode(INFINITY, 1)
        self.header = _SkipListNode(None, 1)
        self.nlevels = 1
        self.header.next[0] = self.end

    def find(self, x):
        '''Return x if it is in the skip list, otherwise None.'''
        # the first level-0 node whose data is not less than x
        p, _ = self.__find_pred(x)
        return p.data if p.data == x else None

    def insert(self, x):
        '''Insert x into the skip list.

        Raises DuplicateKey if x is already present.
        '''
        # find the predecessors of x on every level
        p, pred = self.__find_pred(x)
        if p.data == x:
            raise DuplicateKey
        # pick the height of this new node; grow the list if it is taller
        h = self.__random_height()
        if h > self.nlevels:
            d = h - self.nlevels
            self.__add_levels(d)
            # on the brand-new levels the only predecessor is the header
            pred.extend([self.header] * d)
            self.nlevels = h
        # for every level of the new node, patch it in after its predecessor
        q = _SkipListNode(x, h)
        for level in range(h):
            q.next[level] = pred[level].next[level]
            pred[level].next[level] = q

    def delete(self, x):
        '''Remove x, if present, else raise KeyError(x).'''
        p, pred = self.__find_pred(x)
        if p.data != x:
            raise KeyError(x)
        # unlink p on every level where it follows its predecessor; p only
        # appears on the levels below its own height
        for level, q in enumerate(pred):
            if q.next[level] is p:
                q.next[level] = p.next[level]
        # drop top levels that no longer hold any node, so later searches do
        # not walk empty levels; header/end pointer lists stay in sync with
        # nlevels
        while self.nlevels > 1 and self.header.next[self.nlevels - 1] is self.end:
            self.header.next.pop()
            self.end.next.pop()
            self.nlevels -= 1

    def __find_pred(self, x):
        '''Find x and return its predecessors on each level.

        Returns a pair (node, pred). ``node`` is the first level-0 node whose
        data is not less than x: x itself when present, otherwise its
        successor. ``pred[i]`` is the node that precedes x, or would precede
        it, on level i.
        '''
        p = self.header
        pred = [None] * self.nlevels
        # descend from the top level, remembering the last node before x
        for level in range(self.nlevels - 1, -1, -1):
            while p.next[level].data < x:
                p = p.next[level]
            pred[level] = p
        return p.next[0], pred

    def __add_levels(self, d):
        '''Increase the height of the skip list by d levels.

        The header's new pointers lead straight to the end sentinel. The end
        sentinel's own pointers are never followed, so they are padded with
        None only to keep their count equal to nlevels.
        '''
        self.end.next.extend([None] * d)
        self.header.next.extend([self.end] * d)

    def __random_height(self):
        '''Choose a random height (>= 1) for a new node.

        Adds one level each time random_bit() returns 1; if random_bit() is a
        fair coin (defined in util -- not verified here), each extra level
        has probability 1/2.
        '''
        level = 1
        while random_bit() == 1:
            level += 1
        return level


#================================================================================
# Unit Test
#================================================================================
import random
import unittest


class TestSkipList(unittest.TestCase):
    def setUp(self):
        self.SL = SkipList()
        # a real list so random.shuffle works on Python 3 as well as 2
        self.seq = list(range(1000))

    def testSorted(self):
        '''Inserts, finds, and deletes on a sorted sequence.'''
        # insert and make sure it got in
        for i in self.seq:
            self.SL.insert(i)
            self.assertEqual(self.SL.find(i), i)
        # make sure they all got in
        for i in self.seq:
            self.assertEqual(self.SL.find(i), i)
        # delete and make sure they are gone
        for i in self.seq:
            self.SL.delete(i)
            self.assertEqual(self.SL.find(i), None)
        # an empty list shrinks back to a single level
        self.assertEqual(self.SL.nlevels, 1)

    def testRandom(self):
        '''Inserts, finds, and deletes on a random sequence.'''
        # insert and make sure they got in
        random.shuffle(self.seq)
        for a in self.seq:
            self.SL.insert(a)
            self.assertEqual(self.SL.find(a), a)
        # try to insert twice, and make sure error raised
        random.shuffle(self.seq)
        for a in self.seq:
            self.assertRaises(DuplicateKey, self.SL.insert, a)
        # find each, delete it, and make sure they are removed
        random.shuffle(self.seq)
        for a in self.seq:
            self.assertEqual(self.SL.find(a), a)
            self.SL.delete(a)
            self.assertEqual(self.SL.find(a), None)
        random.shuffle(self.seq)
        for a in self.seq:
            self.assertEqual(self.SL.find(a), None)
            self.assertRaises(KeyError, self.SL.delete, a)


if __name__ == '__main__':
    unittest.main()