How can I improve the following rolling sum implementation?
```haskell
import Control.Monad.State

type Buffer = State BufferState (Maybe Double)
type BufferState = ( [Double] , Int, Int )

-- circular buffer
buff :: Double -> Buffer
buff newVal = do
  ( list, ptr, len) <- get
  -- if the list is not full yet just accumulate the new value
  if length list < len
    then do
      put ( newVal : list , ptr, len)
      return Nothing
    else do
      let nptr = (ptr - 1) `mod` len
          (as,(v:bs)) = splitAt ptr list
          nlist = as ++ (newVal : bs)
      put (nlist, nptr, len)
      return $ Just v

-- create initial state for circular buffer
initBuff :: Int -> BufferState
initBuff l = ( [] , l-1 , l)

-- use the circular buffer to calculate a rolling sum
rollSum :: Double -> State (Double,BufferState) (Maybe Double)
rollSum newVal = do
  (acc,bState) <- get
  let (lv , bState' ) = runState (buff newVal) bState
      acc' = acc + newVal
  -- subtract the old value if the circular buffer is full
  case lv of
    Just x -> put ( acc' - x , bState') >> (return $ Just (acc' - x))
    Nothing -> put ( acc' , bState') >> return Nothing

-- results are accumulated newest-first
test :: (Double,BufferState) -> [Double] -> [Maybe Double] -> [Maybe Double]
test state [] acc = acc
test state (x:xs) acc =
  let (a,s) = runState (rollSum x) state
  in test s xs (a:acc)

main :: IO ()
main = print $ test (0,initBuff 3) [1,1,1,2,2,0] []
```
Buffer uses the State monad to implement a circular buffer. rollSum uses the State monad again to keep track of the rolling sum value and the state of the circular buffer.
- How could I make this more elegant?
- I'd like to implement other functions like rolling average or a difference, what could I do to make this easy?
I forgot to mention that I am using a circular buffer because I intend to use this code online, processing updates as they arrive; hence the need to record state. Something like:

```haskell
newRollingSum = update rollingSum newValue
```
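One way to get that `update`-style interface without nesting two `State` computations is to keep the window and its running total together in a single pure value. The sketch below is hypothetical: `RollState`, `emptyRoll`, `update`, and `runAll` are names made up for illustration, and it assumes `Data.Sequence` from the `containers` package, which gives cheap access to both ends of the window. Unlike the question's `buff`, it reports a sum as soon as the window first fills up.

```haskell
import Data.Sequence (Seq, ViewL (..), viewl, (|>))
import qualified Data.Sequence as Seq

-- Hypothetical online state: window size, current window
-- (oldest value at the left end), and the sum of the window.
data RollState = RollState
  { rsSize   :: Int
  , rsWindow :: Seq Double
  , rsTotal  :: Double
  }

-- Empty state for a window of n values.
emptyRoll :: Int -> RollState
emptyRoll n = RollState n Seq.empty 0

-- Push one new value. Returns the window sum once n values have been
-- seen (Nothing before that), together with the updated state.
update :: Double -> RollState -> (Maybe Double, RollState)
update x (RollState n win total)
  | Seq.length win < n =
      let win'   = win |> x
          total' = total + x
          out    = if Seq.length win' == n then Just total' else Nothing
      in (out, RollState n win' total')
  | otherwise =
      case viewl win of
        oldest :< rest ->
          -- add the new value, subtract the one leaving the window
          let total' = total + x - oldest
          in (Just total', RollState n (rest |> x) total')
        EmptyL -> (Nothing, RollState n win total) -- only reachable if n <= 0

-- Feed a whole list through update, collecting outputs in arrival order.
runAll :: Int -> [Double] -> [Maybe Double]
runAll n = go (emptyRoll n)
  where
    go _ [] = []
    go s (x:xs) = let (out, s') = update x s in out : go s' xs

main :: IO ()
main = print (runAll 3 [1, 1, 1, 2, 2, 0])
```

A rolling average or other statistic could be layered on top by changing what `update` keeps alongside the window.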
Answer:
I haven't managed to decipher all of your code, but here is the plan I would take for solving this problem. First, an English description of the plan:
1. We need windows into the list of length n starting at each index:
   - Make windows of arbitrary length.
   - Truncate long windows to length n.
   - Drop the last n-1 of these, which will be too short.
2. For each window, add up the entries.
This was the first idea I had; for windows of length three it's an okay approach because step 2 is cheap on such a short list. For longer windows, you may want an alternate approach, which I will discuss below; but this approach has the benefit that it generalizes smoothly to functions other than sum. The code might look like this:
```haskell
import Data.List (tails)

rollingSums :: Num a => Int -> [a] -> [a]
rollingSums n xs
  = map sum                              -- add up the entries
  . zipWith (flip const) (drop (n-1) xs) -- drop the last n-1
  . map (take n)                         -- truncate long windows
  . tails                                -- make arbitrarily long windows
  $ xs
```
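Since only the first stage of that pipeline mentions `sum`, the per-window function can be made a parameter. The sketch below is illustrative: `rollingWith` and `rollingAverages` are hypothetical names, not from the question or any library, and `rollingSums` is redefined here as one instance of the general function. It shows how the rolling average asked about in the question falls out of the naive approach.

```haskell
import Data.List (tails)

-- Hypothetical generalisation: apply f to every full window of length n.
rollingWith :: ([a] -> b) -> Int -> [a] -> [b]
rollingWith f n xs
  = map f                                -- summarise each window
  . zipWith (flip const) (drop (n-1) xs) -- keep only the full windows
  . map (take n)                         -- truncate long windows
  . tails                                -- make arbitrarily long windows
  $ xs

-- Rolling sum and rolling average as instances of rollingWith.
rollingSums :: Num a => Int -> [a] -> [a]
rollingSums = rollingWith sum

rollingAverages :: Fractional a => Int -> [a] -> [a]
rollingAverages n = rollingWith (\w -> sum w / fromIntegral n) n

main :: IO ()
main = do
  print (rollingSums 3 [1, 2, 3, 4, 5 :: Int])
  print (rollingAverages 2 [1, 3, 5, 7 :: Double])
  print (rollingWith maximum 2 [3, 1, 4, 1, 5 :: Int])
```

Each window is recomputed from scratch, so this costs O(n) per output; that is the price of accepting an arbitrary window function.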
If you're familiar with the "equational reasoning" approach to optimization, you might spot a first place where we can improve the performance of this function. Because zipWith (flip const) ys only keeps a prefix of its second argument, it commutes with map sum. So by swapping the first map and the zipWith, we get a function with the same behavior but with a map f . map g subterm, which can be replaced by map (f . g) to get slightly less allocation.
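Carrying out that rewrite by hand gives something like the following. This is a sketch of the transformation just described, not code from the original answer; `rollingSumsFused` is a hypothetical name chosen so it can sit alongside `rollingSums`.

```haskell
import Data.List (tails)

-- Same behaviour as rollingSums, after two equational steps:
--   1. zipWith (flip const) ys only keeps a prefix of its second argument,
--      so it commutes with map sum;
--   2. map sum . map (take n) fuses into map (sum . take n).
rollingSumsFused :: Num a => Int -> [a] -> [a]
rollingSumsFused n xs
  = zipWith (flip const) (drop (n-1) xs) -- drop the last n-1
  . map (sum . take n)                   -- truncate and add up each window
  . tails                                -- make arbitrarily long windows
  $ xs

main :: IO ()
main = print (rollingSumsFused 3 [10 ^ k | k <- [0 .. 5 :: Int]])
```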
Unfortunately, for large n, this adds n numbers together in the inner loop; we would prefer to simply add the value at the "front" of the window and subtract the one at the "back". So we need to get trickier. Here's a new idea: we'll traverse the list twice in parallel, n positions apart. Then we'll use a simple function for getting the rolling sum (of unbounded window length) of prefixes of a list, namely scanl (+), to convert this traversal into the actual sums we're interested in.
```haskell
rollingSumsEfficient n xs = scanl (+) firstSum deltas where
  firstSum = sum (take n xs)
  deltas   = zipWith (-) (drop n xs) xs -- front - back
```
There's one twist, which is that scanl never returns an empty list. So if it's important that you be able to handle short lists, you'll want another equation that checks for these. Don't use length, as that forces the entire input list into memory before starting the computation; that is a potentially lethal performance mistake. Instead, add a line like this above the previous definition:

```haskell
rollingSumsEfficient n xs | null (drop (n-1) xs) = []
```
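Putting the guard and the scanl equation together, the whole definition might look like this. The type signature, comments, and `main` are additions for illustration; when the guard fails, evaluation falls through to the second equation.

```haskell
-- Rolling sums with O(1) work per element: start from the first window's
-- sum, then add the entering value and subtract the leaving one.
rollingSumsEfficient :: Num a => Int -> [a] -> [a]
rollingSumsEfficient n xs
  | null (drop (n-1) xs) = []            -- fewer than n elements: no full window
rollingSumsEfficient n xs = scanl (+) firstSum deltas
  where
    firstSum = sum (take n xs)             -- sum of the first window
    deltas   = zipWith (-) (drop n xs) xs  -- front - back

main :: IO ()
main = do
  print (rollingSumsEfficient 3 [1, 2 :: Int])       -- too short
  print (rollingSumsEfficient 3 [1, 2, 3, 4 :: Int])
```

The guard only inspects the first n elements, so it stays compatible with infinite input.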
We can try these two out in ghci. You'll notice that they do not quite have the same behavior as yours:

```
*Main> rollingSums 3 [10^n | n <- [0..5]]
[111,1110,11100,111000]
*Main> rollingSumsEfficient 3 [10^n | n <- [0..5]]
[111,1110,11100,111000]
```
On the other hand, the implementations are considerably more concise and are fully lazy, in the sense that they work on infinite lists:

```
*Main> take 5 . rollingSums 10 $ [1..]
[55,65,75,85,95]
*Main> take 5 . rollingSumsEfficient 10 $ [1..]
[55,65,75,85,95]
```
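The same ingredients cover the other functions mentioned in the question. As a hedged sketch (the names `rollingDiffs` and `rollingAveragesEfficient` are hypothetical, and "difference" is assumed to mean each value minus the one n positions earlier), the `deltas` list is already a rolling difference, and a rolling average is the rolling sum divided by n:

```haskell
-- Assumed meaning of "difference": x[i+n] - x[i] for each i.
rollingDiffs :: Num a => Int -> [a] -> [a]
rollingDiffs n xs = zipWith (-) (drop n xs) xs

-- Rolling sums via scanl, as in rollingSumsEfficient.
rollingSumsEfficient :: Num a => Int -> [a] -> [a]
rollingSumsEfficient n xs
  | null (drop (n-1) xs) = []
rollingSumsEfficient n xs = scanl (+) (sum (take n xs)) (rollingDiffs n xs)

-- Rolling average: divide each window sum by the window length.
rollingAveragesEfficient :: Fractional a => Int -> [a] -> [a]
rollingAveragesEfficient n = map (/ fromIntegral n) . rollingSumsEfficient n

main :: IO ()
main = do
  print (rollingDiffs 1 [1, 4, 9, 16 :: Int])
  print (take 3 (rollingAveragesEfficient 4 [(1 :: Double) ..]))
```

One caveat with Double input: because each sum is derived from the previous one, rounding error can accumulate over a long stream, whereas the naive version recomputes every window from scratch.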