## Tuesday, April 2, 2013

### Bitonic Sort in Haskell

This year at GDC, one of the talks covered how to use compute shaders to do general-purpose work in a game. The example program used a parallel sorting algorithm, bitonic sort, running on a graphics card. I hadn't heard of this sorting algorithm before (the only parallel sorting algorithm I'm familiar with is bead sort), so I decided to try to implement it in Haskell (because if something's worth doing, it's worth doing in Haskell :D).

The first thing I did was naively re-implement the recursive Python code that's on the Wikipedia page for bitonic sort. Ultimately, however, I wanted to make the algorithm parallel, and the picture describing the sequence of conditional swaps turned out to be very useful for understanding the runtime of the algorithm. It turns out that the pairs of elements to compare can be determined in closed form from an element's index and the phase the algorithm is in. The phases are simply chained together: each phase conditionally swaps the relevant values in the vector (which can be done in parallel), and only then can the next phase start. Because the number of work items per phase is constant, and the work per work item is also constant, the algorithm is a very good candidate to be run on a GPU.
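To make the closed-form phase structure concrete, here is a minimal, purely sequential sketch. It mirrors the index arithmetic of `dowork` in the full listing, but the names `phases`, `partner`, `runPhase`, and `bitonicSort` are hypothetical helpers for illustration only, and it uses an immutable `Data.Array` rather than a mutable vector. It assumes the input length is a power of two.

```haskell
import Data.Array (Array, elems, indices, listArray, (!), (//))
import Data.Bits (shiftL, (.&.))
import Data.List (foldl')

-- | All (major, minor) phases for an input of 2^k elements, in execution order.
phases :: Int -> [(Int, Int)]
phases k = [(major, minor) | major <- [0 .. k - 1], minor <- [0 .. major]]

-- | If index i is the lower element of a compared pair in this phase,
-- return the index of its partner.
partner :: Int -> Int -> Int -> Maybe Int
partner major minor i
  | minor == 0 && i .&. (1 `shiftL` major) == 0           = Just triangle
  | minor > 0  && i .&. (1 `shiftL` (major - minor)) == 0 = Just roll
  | otherwise                                             = Nothing
  where
    blocksize = 2 ^ (major + 1)
    base      = i - (i .&. (blocksize - 1))
    triangle  = base + blocksize - (i - base) - 1      -- mirror image within the block
    roll      = i + (blocksize `div` (2 ^ minor)) `div` 2

-- | Run one phase. The pairs are disjoint, so every swap is independent
-- of the others (which is what makes a phase parallelisable).
runPhase :: Ord a => Array Int a -> (Int, Int) -> Array Int a
runPhase arr (major, minor) = arr // concat
  [ [(i, y), (j, x)]
  | i <- indices arr
  , Just j <- [partner major minor i]
  , let (x, y) = (arr ! i, arr ! j)
  , x > y
  ]

-- | Sequential bitonic sort; assumes the length is a power of two.
bitonicSort :: Ord a => [a] -> [a]
bitonicSort xs = elems (foldl' runPhase arr (phases k))
  where
    n   = length xs
    k   = length (takeWhile (< n) (iterate (* 2) 1))  -- log2 n for powers of two
    arr = listArray (0, n - 1) xs
```

Loaded in GHCi, `bitonicSort [4,3,2,1]` goes through the phases `(0,0)`, `(1,0)`, `(1,1)` and yields `[1,2,3,4]`. For an input of 2^k elements there are k(k+1)/2 phases in total.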

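The worker threads coordinate through a shared IORef counter: each thread adds the size of its chunk to the counter when it finishes a phase, and the thread that brings the counter up to the length of the vector announces that the phase is done. I originally assumed that IORef's "atomic" functions lock a mutex, since the function they apply can be arbitrarily complex. As far as I can tell, though, GHC implements atomicModifyIORef with a compare-and-swap that installs an unevaluated thunk, so no lock is held while the function is being computed. Either way, each thread touches the counter only once per phase, when it is done with its computation. Because each thread gets a whole chunk of the input vector to work with, and there are only as many threads as there are processors in the system, the counter will likely not be contended.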
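The counter idea can be shown in isolation. The sketch below is not the code from the listing: `reportDone` and `runPhaseBarrier` are hypothetical names, it uses the strict `atomicModifyIORef'` from `Data.IORef`, and it signals completion through an `MVar` instead of a duplicated `Chan`. It assumes a non-empty list of positive chunk sizes, so that exactly one worker sees the counter reach the total.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Monad (forM_, when)
import Data.IORef (IORef, atomicModifyIORef', newIORef)

-- | Atomically add a chunk's size to the shared counter and return the new total.
reportDone :: IORef Int -> Int -> IO Int
reportDone counter chunkSize =
  atomicModifyIORef' counter $ \done ->
    let done' = done + chunkSize in (done', done')

-- | Fork one worker per chunk. The worker that brings the counter up to the
-- total number of elements signals that the phase is complete.
runPhaseBarrier :: [Int] -> IO Int
runPhaseBarrier chunkSizes = do
  counter  <- newIORef 0
  finished <- newEmptyMVar
  let total = sum chunkSizes
  forM_ chunkSizes $ \size -> forkIO $ do
    -- (the compare-and-swap work for this chunk would happen here)
    done <- reportDone counter size
    when (done == total) $ putMVar finished done
  takeMVar finished  -- blocks until the last worker has reported in
```

Each worker performs a single atomic update per phase, which is why contention on the counter stays low when there are only as many workers as processors.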

I actually wanted to implement this in a pure way, using STVector instead of IOVector (using runST instead of unsafePerformIO and 'rpar' instead of 'forkIO'). However, it looks like STVectors have knowledge about their sequential context (the 's' variable in 'ST s'), so they can't perform any parallel operations. Shucks.
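For comparison, this is roughly what the pure but sequential variant looks like. It is only a sketch under a few assumptions: it uses `STArray` from the `array` package rather than `STVector`, and `swapIfST` and `applySwaps` are hypothetical names. The `s` parameter appears in every type, which is what ties all reads and writes to the single `ST s` computation that created the array.

```haskell
import Control.Monad (forM_, when)
import Control.Monad.ST (ST)
import Data.Array (elems)
import Data.Array.ST (STArray, newListArray, readArray, runSTArray, writeArray)

-- | Swap elements i and j if they are out of order.
-- The array and the action share the same state thread 's'.
swapIfST :: Ord a => STArray s Int a -> Int -> Int -> ST s ()
swapIfST arr i j = do
  x <- readArray arr i
  y <- readArray arr j
  when (x > y) $ do
    writeArray arr i y
    writeArray arr j x

-- | Apply a list of compare-and-swap pairs, one after another, and return
-- the result as an ordinary list. No unsafePerformIO is needed, but nothing
-- runs in parallel either.
applySwaps :: Ord a => [(Int, Int)] -> [a] -> [a]
applySwaps pairs xs = elems $ runSTArray
  (do arr <- newListArray (0, length xs - 1) xs
      forM_ pairs (uncurry (swapIfST arr))
      return arr)
```

For example, `applySwaps [(0,1),(2,3),(0,3),(1,2),(0,1),(2,3)] [4,3,2,1]` applies the three bitonic phases for four elements in order and returns `[1,2,3,4]`.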

The results are pretty good: my implementation generally runs about 4 times faster than the built-in Data.List.sort function. Note that it only accepts inputs whose length is a power of two, and throws NPOTException otherwise. I've copied the code below.

```haskell
{-# LANGUAGE DeriveDataTypeable #-}

module Bitonic (parallelBitonicSort, NPOTException) where

import           Control.Concurrent (forkIO)
import           Control.Concurrent.Chan (Chan, newChan, writeChan, readChan, dupChan)
import           Control.Exception (Exception, throw)
import           Control.Monad.Loops (iterateUntil)             -- monad-loops package
import           Control.Monad.Primitive (PrimMonad, PrimState) -- primitive package
import           Data.Typeable (Typeable)
import           GHC.Conc (getNumCapabilities)
import           Data.Bits (shift, (.&.), Bits)
import           Data.IORef
import qualified Data.Vector.Unboxed.Mutable as MV
import           System.IO.Unsafe (unsafePerformIO)

-- | Thrown when the input length is not a power of two.
data NPOTException = NPOTException
    deriving (Typeable, Show)

instance Exception NPOTException

-- | Number of bits needed to represent x, i.e. floor(log2 x) + 1 for x >= 1.
intlog2 :: (Bits a, Num a, Integral b) => a -> b
intlog2 1 = 1
intlog2 x = 1 + intlog2 (x `shift` (-1))

parallelBitonicSort :: (Ord a, MV.Unbox a) => [a] -> [a]
parallelBitonicSort l
    | len < 2              = l  -- nothing to sort; no pass would ever be announced
    | len .&. (len-1) == 0 = unsafePerformIO $ parallelBitonicSort' l
    | otherwise            = throw NPOTException
  where len = length l

parallelBitonicSort' :: (Ord a, MV.Unbox a) => [a] -> IO [a]
parallelBitonicSort' l = do
    masterChannel <- newChan
    counter <- newIORef 0
    v <- MV.new len
    sequence_ [MV.write v i x | (i, x) <- zip [0..] l]
    caps <- getNumCapabilities
    -- One duplicated channel per worker, so every worker sees every "pass done" message.
    channels <- sequence $ replicate (length $ segmentedindices caps) $ dupChan masterChannel
    mapM_ forkIO [aggregateLine v counter chan x | (chan, x) <- zip channels $ segmentedindices caps]
    -- Block until the final (major, minor) pass has been announced.
    _ <- iterateUntil (== (majorpasses-1, majorpasses-1)) (readChan masterChannel)
    sequence [MV.read v i | i <- [0..len-1]]
  where
    len = length l
    majorpasses = (intlog2 len) - 1
    workunitsize caps = len `div` caps + 1
    segmentedindices caps = segment (workunitsize caps) [0..len-1]

-- | Worker: runs every pass over its own chunk of indices, synchronizing
-- with the other workers after each pass.
aggregateLine :: (Ord a, MV.Unbox a) => MV.IOVector a -> IORef Int -> Chan (Int, Int) -> [Int] -> IO ()
aggregateLine v counter chan indices = lineHelper 0 0
  where
    len = MV.length v
    swapper = swapIf v (>)
    majorpasses = (intlog2 len) - 1
    lineHelper major minor
        | major == majorpasses = return ()
        | otherwise = mapM_ dowork indices >> workdone
      where
        -- Only the lower index of each pair performs the compare-and-swap.
        dowork index
            | minor == 0 && index .&. (1 `shift` major) == 0 = swapper index triangle
            | index .&. (1 `shift` (major - minor)) == 0 = swapper index roll
            | otherwise = return ()
          where
            base = index - (index .&. (blocksize-1))
            triangle = base + blocksize - (index - base) - 1
            roll = index + ((blocksize `div` (2^minor)) `div` 2)
            blocksize = 2^(major+1)
        workdone = atomicModifyIORef counter inc >>= passMightBeDone >> rec
          where
            inc x = y `seq` (y, y)
              where y = x + (length indices)
            -- The last worker to finish resets the counter and announces the pass;
            -- everyone else waits for that announcement.
            passMightBeDone x
                | x == len = atomicModifyIORef counter (\ _ -> (0, ())) >> writeChan chan (major, minor)
                | otherwise = iterateUntil (== (major, minor)) (readChan chan) >> return ()
            rec
                | minor == major = lineHelper (major + 1) 0
                | otherwise = lineHelper major (minor + 1)

-- | Swap the elements at a and b if the predicate holds for them.
swapIf :: (PrimMonad m, MV.Unbox t) => MV.MVector (PrimState m) t -> (t -> t -> Bool) -> Int -> Int -> m ()
swapIf v f a b = do
    x <- MV.read v a
    y <- MV.read v b
    helper x y
  where
    helper x y
        | f x y = do
            MV.write v a y
            MV.write v b x
        | otherwise = return ()

-- | Split a list into consecutive chunks of at most the given size.
segment :: Int -> [a] -> [[a]]
segment size = helper
  where
    helper [] = []
    helper l = before : helper after
      where (before, after) = splitAt size l
```