Tuesday, April 2, 2013

Bitonic Sort in Haskell

This year at GDC, one of the talks was about using compute shaders to do general-purpose work in a game. The example program they described used a parallel sorting algorithm, bitonic sort, running on a graphics card. I hadn't heard of this sorting algorithm before (the only parallel sorting algorithm I was familiar with is bead sort), so I decided to implement it in Haskell (because if something's worth doing, it's worth doing in Haskell :D)

The first thing I did was naively re-implement the recursive Python code on the Wikipedia page for bitonic sort. Ultimately, however, I wanted to make the algorithm parallel, and the picture on that page describing the sequence of conditional swaps turned out to be very useful for understanding how the algorithm proceeds. It turns out that the sequence of comparisons can be determined in closed form from which phase the algorithm is in. The phases can simply be chained together: each phase performs its conditional swaps on the vector (all of which can be done in parallel), and then the next phase starts. Because the number of work items per phase is constant, and the work per item is constant as well, the algorithm is a very good candidate for running on a GPU.
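To make that structure concrete, here is a small sequential sketch of the network (my own illustration, not the code from the talk): comparePairs enumerates the conditional swaps for one (major, minor) pass over n elements, and bitonicSortSeq chains every pass together. It assumes the input length is a power of two.

import Data.Bits (shiftL, (.&.))
import qualified Data.Vector as V

-- The comparator pairs for one (major, minor) pass of an n-element network.
comparePairs :: Int -> Int -> Int -> [(Int, Int)]
comparePairs n major minor = [(i, partner i) | i <- [0 .. n - 1], active i]
  where blocksize = 1 `shiftL` (major + 1)
        active i
          | minor == 0 = i .&. (1 `shiftL` major) == 0
          | otherwise  = i .&. (1 `shiftL` (major - minor)) == 0
        partner i
          | minor == 0 = base + blocksize - 1 - (i - base)  -- mirrored within the block
          | otherwise  = i + (1 `shiftL` (major - minor))   -- a fixed stride apart
          where base = i - (i .&. (blocksize - 1))

-- Apply every pass in sequence. The pairs within a single pass touch
-- disjoint indices, which is exactly what makes each pass parallelizable.
bitonicSortSeq :: Ord a => [a] -> [a]
bitonicSortSeq xs = V.toList (foldl pass (V.fromList xs) phases)
  where n = length xs
        k = length (takeWhile (< n) (iterate (* 2) 1))  -- log2 n, for power-of-two n
        phases = [(major, minor) | major <- [0 .. k - 1], minor <- [0 .. major]]
        pass v (major, minor) =
          v V.// concat [ [(i, min x y), (j, max x y)]
                        | (i, j) <- comparePairs n major minor
                        , let x = v V.! i, let y = v V.! j ]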

My initial attempt created n work items each phase and then waited for all of them to finish, but that turned out to be too slow. Instead, I keep a fixed set of worker threads alive for the lifetime of the algorithm and synchronize them between phases. I implemented the synchronization with a Chan, which lets one thread block on the channel until another thread writes to it, waking the first thread up. Chans can also be duplicated: when a value is written to any one of them, all of the duplicated channels can read it. I used this to give each thread its own channel to wait on while linking them all together, so a single write wakes every thread. Each thread atomically increments a counter when it finishes its work for the phase; when the counter reaches the total amount of work, the thread that got it there knows everyone is done, writes a value into the chan, and wakes everyone up for the next phase.

I also didn't want to create a thread for each element of the input array (millions of threads is too much overhead), so I query the number of "capabilities" (which represents the number of processors), spin up that many threads, and assign each thread a chunk of the input array, with the chunks as close to the same size as possible.
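Stripped of the sorting itself, the barrier looks roughly like this (a sketch with hypothetical names like barrier and nThreads; the real code below counts finished elements rather than finished threads, but the protocol is the same):

import Control.Concurrent.Chan (Chan, readChan, writeChan)
import Data.IORef (IORef, atomicModifyIORef)

-- Block until all nThreads workers have reached the barrier for this pass.
-- The last arrival resets the counter and broadcasts the pass number;
-- everyone else drains stale values from their duplicated channel until
-- the current pass shows up.
barrier :: IORef Int -> Chan Int -> Int -> Int -> IO ()
barrier counter chan nThreads pass = do
  arrived <- atomicModifyIORef counter (\x -> let y = x + 1 in y `seq` (y, y))
  if arrived == nThreads
    then atomicModifyIORef counter (\_ -> (0, ())) >> writeChan chan pass
    else waitForPass
  where waitForPass = do
          p <- readChan chan
          if p == pass then return () else waitForPass

Each worker holds its own dupChan duplicate of one master channel, so the single writeChan at the end of a pass wakes every waiting thread at once.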

As I understand it, the IORef's "atomic" modify doesn't actually lock a mutex: atomicModifyIORef uses a compare-and-swap to install a thunk of the supplied function, so the update itself is cheap and lock-free (it does mean the result should be forced, or the reference accumulates a chain of unevaluated thunks). Either way, each thread performs just one atomic update when it finishes its chunk, and because each thread gets a whole chunk of the input vector to work with, and there are only as many threads as there are processors in the system, the counter will likely not be contended.
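That is why the counter update in the code below forces its result with seq. In isolation, the pattern looks like this (bumpCounter is just an illustrative name):

import Data.IORef (IORef, atomicModifyIORef)

-- Atomically add k to the counter and return the new total, forced with seq
-- so that repeated updates don't build up a chain of (+) thunks.
bumpCounter :: IORef Int -> Int -> IO Int
bumpCounter ref k = atomicModifyIORef ref inc
  where inc x = let y = x + k in y `seq` (y, y)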

I actually wanted to implement this in a pure way, using STVector instead of IOVector (with runST instead of unsafePerformIO, and 'rpar' instead of 'forkIO'). However, STVectors are tied to their sequential context (the 's' variable in 'ST s'), so they can't be shared between parallel operations. Shucks.
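The relevant signatures show why: the 's' that runST quantifies over is the same 's' in the vector's type, so the vector can never escape its single ST computation, and forkIO only accepts IO actions anyway.

runST  :: (forall s. ST s a) -> a
new    :: (PrimMonad m, MV.Unbox a) => Int -> m (MV.MVector (PrimState m) a)
forkIO :: IO () -> IO ThreadId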

The results are pretty good: my implementation generally runs about 4 times faster than the built-in Data.List.sort. I've copied the code below.

{-# LANGUAGE DeriveDataTypeable #-}

module Bitonic (parallelBitonicSort, NPOTException) where

import           Control.Concurrent (forkIO)
import           Control.Concurrent.Chan (Chan, newChan, writeChan, readChan, dupChan)
import           Control.Exception (Exception, throw)
import           Control.Monad.Loops (iterateUntil)
import           Control.Monad.Primitive (PrimMonad, PrimState)
import           Data.Typeable (Typeable)
import           GHC.Conc (getNumCapabilities)
import           Data.Bits (shift, (.&.), Bits)
import           Data.IORef
import qualified Data.Vector.Unboxed.Mutable as MV
import           System.IO.Unsafe (unsafePerformIO)

data NPOTException = NPOTException
  deriving (Typeable, Show)

instance Exception NPOTException

-- Bit length of its argument: for a power of two 2^k this returns k+1,
-- which is why majorpasses below subtracts one.
intlog2 :: (Bits a, Integral b) => a -> b
intlog2 1 = 1
intlog2 x = 1 + intlog2 (x `shift` (-1))

-- Sort a list whose length is a power of two; any other length throws
-- NPOTException. Lengths 0 and 1 are returned as-is, since the worker
-- protocol below assumes at least one pass runs.
parallelBitonicSort :: (Ord a, MV.Unbox a) => [a] -> [a]
parallelBitonicSort l
  | len < 2 = l
  | len .&. (len-1) == 0 = unsafePerformIO $ parallelBitonicSort' l
  | otherwise = throw NPOTException
  where len = length l

parallelBitonicSort' :: (Ord a, MV.Unbox a) => [a] -> IO [a]
parallelBitonicSort' l = do
  masterChannel <- newChan
  counter <- newIORef 0
  v <- MV.new len
  sequence_ [MV.write v i x | (i, x) <- zip [0..] l]
  caps <- getNumCapabilities
  -- One worker per chunk (at most one chunk per capability), each with its
  -- own duplicate of the master channel.
  channels <- sequence $ replicate (length $ segmentedindices caps) $ dupChan masterChannel
  mapM_ forkIO [aggregateLine v counter chan x | (chan, x) <- zip channels $ segmentedindices caps]
  -- Wait until the workers broadcast completion of the final (major, minor) pass.
  _ <- iterateUntil (== (majorpasses-1, majorpasses-1)) (readChan masterChannel)
  sequence [MV.read v i | i <- [0..len-1]]
  where len = length l
        majorpasses = (intlog2 len) - 1
        workunitsize caps = len `div` caps + 1
        segmentedindices caps = segment (workunitsize caps) [0..len-1]

-- One worker thread: run every (major, minor) pass of the network over this
-- thread's chunk of indices, synchronizing with the other workers in between.
aggregateLine :: (Ord a, MV.Unbox a) => MV.IOVector a -> IORef Int -> Chan (Int, Int) -> [Int] -> IO ()
aggregateLine v counter chan indices = lineHelper 0 0
  where len = MV.length v
        swapper = swapIf v (>)
        majorpasses = (intlog2 len) - 1
        lineHelper major minor
          | major == majorpasses = return ()
          | otherwise = mapM_ dowork indices >> workdone
          where dowork index
                  -- First pass of a major phase: compare mirrored pairs within a block.
                  | minor == 0 && index .&. (1 `shift` major) == 0 = swapper index triangle
                  -- Subsequent passes: compare pairs a fixed stride apart.
                  | index .&. (1 `shift` (major - minor)) == 0 = swapper index roll
                  | otherwise = return ()
                  where base = index - (index .&. (blocksize-1))
                        triangle = base + blocksize - (index - base) - 1
                        roll = index + ((blocksize `div` (2^minor)) `div` 2)
                blocksize = 2^(major+1)
                -- Count finished elements; the thread that finishes the pass
                -- resets the counter and broadcasts (major, minor), while the
                -- others block until they see that value.
                workdone = atomicModifyIORef counter inc >>= passMightBeDone >> rec
                  where inc x = y `seq` (y, y)
                          where y = x + (length indices)
                passMightBeDone x
                  | x == len = atomicModifyIORef counter (\ _ -> (0, ())) >> writeChan chan (major, minor)
                  | otherwise = iterateUntil (== (major, minor)) (readChan chan) >> return ()
                rec
                  | minor == major = lineHelper (major + 1) 0
                  | otherwise = lineHelper major (minor + 1)

-- Compare-and-swap: when f v[a] v[b] holds, exchange the two elements.
-- With (>) as the predicate, the smaller element ends up at the lower index.
swapIf :: (PrimMonad m, MV.Unbox t) => MV.MVector (PrimState m) t -> (t -> t -> Bool) -> Int -> Int -> m ()
swapIf v f a b = do
  x <- MV.read v a
  y <- MV.read v b
  helper x y
  where helper x y
          | f x y = do
              MV.write v a y
              MV.write v b x
          | otherwise = return ()

-- Split a list into chunks of (at most) the given size.
segment :: Int -> [a] -> [[a]]
segment size = helper
  where helper [] = []
        helper l = before : helper after
          where (before, after) = splitAt size l

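If you want to try it out, a minimal driver looks something like this (a sketch; build with ghc -O2 -threaded and run with +RTS -N so that getNumCapabilities sees all of your cores, and remember that the input length must be a power of two):

import Bitonic (parallelBitonicSort)
import Data.List (sort)

main :: IO ()
main = print (parallelBitonicSort xs == sort xs)  -- should print True
  where xs = take 65536 (iterate step 42) :: [Int]  -- 2^16 pseudo-random values
        step x = (x * 1103515245 + 12345) `mod` 2147483648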
