On Tue, Nov 29, 2005 at 12:29:55PM -0000, Simon Marlow wrote:
> > Alternatively, it would be nice to have a new STM primitive:
> > 
> >     waitUntil :: ClockTime -> STM ()
> > 
> > so you would wait until some time-point passes, not for a number of
> > time-units (waiting for a number of time-units wouldn't work because
> > of retries). I think it could be efficiently implemented, wouldn't it?
> But you could also implement this using registerTimeout, albeit with
> some more code and an extra thread, and waitUntil requires an
> implementation in the runtime which is not entirely trivial.

It is trivial to create a very inefficient implementation, in which all
STM transactions waiting on waitUntil will be retried on (almost) every
tick of the clock, let's say every second. You just create a TVar that
is updated with current time every second. But as I say, the efficiency
would be unacceptable.

But this can be improved. I found a simple solution that reduces the
number of transaction retries to O(log (t - t0)), where t0 is
transaction start time, and t is the parameter to waitUntil. I attached
a proof-of-concept implementation to this message.

I simply use the binary representation of the time and wait only on the
leading bits (starting from the most significant one) that are enough to
tell that the waitUntil time has not come yet. To keep it simple I used
Unix epoch time in seconds, but the scheme could be made more precise. In
the thread that updates the time variable I make sure not to write the
bits that didn't change, so transactions waiting on them are not woken up
needlessly.

You can compile the Test module as a program. There are two kinds of
tests: some show how many tries are made, and one shows how the scheme
behaves with many threads.

BTW, I tried to make the library interface simpler by creating a default
top-level time variable

    {-# NOINLINE timeVar #-}
    timeVar :: TimeVar
    timeVar = unsafePerformIO initTimeVar

so I could export a waitUntil function with type

    waitUntil :: Time -> STM ()

but I tripped on something that was reported before, namely that STM
transactions can't be nested (as happens when one is started from within
unsafePerformIO or unsafeInterleaveIO). Is there a plan to support such a
scenario?

Best regards

module Main where

import Debug.Trace
import Control.Concurrent.STM
import Control.Concurrent
import Control.Monad
import TimeVar
import Random
import Control.Exception (finally)

-- | Run all the demonstrations against a single freshly created time
-- variable: short waits that show the number of transaction tries, and a
-- stress test with many concurrent waiting threads.
main :: IO ()
main =
    runTests
        [ test 1
        , testMany
        , test 5
        , test 10
        ]

-- | Create one shared time variable, build a @waitUntil@ primitive from it
-- and hand that primitive to every test in turn, printing an empty line
-- after each one.
runTests :: [(Time -> STM ()) -> IO a] -> IO ()
runTests fs = do
    tv <- initTimeVar
    let waitUntil = waitOnTimeVar tv
    forM_ fs $ \runTest -> do
        runTest waitUntil
        putStrLn ""

-- | Wait for @secs@ seconds using the supplied @waitUntil@ primitive and
-- print the measured start time, end time and their difference.
--
-- The STM transaction contains a trace instruction that prints "try" on
-- every (re)try, so the output shows how often the transaction was run.
test :: Time -> (Time -> STM a) -> IO ()
test secs waitUntil = do
    putStrLn ("Testing waiting for " ++ show secs ++ " seconds")
    time1 <- getTime
    atomically $ do
        stmTrace "try"
        waitUntil (time1 + secs)
    time2 <- getTime
    putStrLn $ concat
        [ show time2
        , " - "
        , show time1
        , " = "
        , show (time2 - time1)
        ]

-- | Spawn many threads, each of which
--
--    * takes the current time,
--    * waits for a random number of seconds (1 to 10),
--    * takes the current time again and reports how long it slept and how
--      far that is from the requested number of seconds.
--
-- The calling thread prints the reports as they arrive. It returns once
-- every worker has finished and all reports have been printed.
testMany :: (Time -> STM a) -> IO ()
testMany waitUntil = do
    putStrLn "Testing many threads"
    messages <- atomically newTChan
    threadCount <- atomically (newTVar (0 :: Int))
    -- Adjust the live-worker counter. This is written out by hand because
    -- newer versions of stm also export a 'modifyTVar', which would make
    -- the unqualified name ambiguous here.
    let adjustCount f =
            atomically (readTVar threadCount >>= writeTVar threadCount . f)
        -- Count the worker *before* forking it, so the print loop below
        -- cannot see a zero count while workers are still starting up.
        forkIO' t = do
            adjustCount succ
            forkIO (t `finally` adjustCount pred)
    replicateM_ 100 $ forkIO' $ do
        secs <- liftM fromIntegral (randomRIO (1 :: Int, 10))
        time1 <- getTime
        atomically $ waitUntil (time1 + secs)
        time2 <- getTime
        atomically $ writeTChan messages $ concat
            [ "waited for "
            , show (time2 - time1)
            , " seconds, difference from requested: "
            , show ((time2 - time1) - secs)
            ]
    -- Print reports until the channel is empty AND no worker is alive.
    -- 'orElse' tries the channel first. A worker writes its report before
    -- it decrements the counter, so no report is lost.
    let loop = join $ atomically $
            (do msg <- readTChan messages
                return (putStrLn msg >> loop))
            `orElse`
            (do n <- readTVar threadCount
                guard (n == 0)
                return (return ()))
    loop


-- | Print @s@ (via 'trace') each time this STM action is run, including
-- each time the enclosing transaction is retried.
--
-- The message is emitted when @trace s v@ is forced by 'seq'. It depends on
-- the TVar @v@ freshly allocated in this run of the transaction, which
-- presumably keeps GHC from floating the 'trace' call out and sharing it,
-- so that it fires on every run rather than once -- TODO confirm.
stmTrace s = do
    v <- newTVar ()
    trace s v `seq` return ()

-- | Replace the contents of a 'TVar' with @f@ applied to its current value.
modifyTVar :: TVar a -> (a -> a) -> STM ()
modifyTVar var f = do
    current <- readTVar var
    writeTVar var (f current)

module TimeVar
    ( TimeVar
    , Time
    , getTime
    , initTimeVar
    , waitOnTimeVar
    ) where

import Control.Concurrent.STM
import Control.Concurrent
import Control.Monad
import System.Time (getClockTime, ClockTime(..))
import Data.Bits
import Data.Int

-- | A clock shared through STM: one 'TVar' per bit of the current 'Time',
-- ordered from the most significant bit to the least significant one.
newtype TimeVar = TimeVar [TVar Bool]

-- | Time in whole seconds since the Unix epoch.
--
-- A 64-bit integer is used because a signed 32-bit count of seconds
-- overflows on 2038-01-19 (the "year 2038 problem").
type Time = Int64

-- | The current wall-clock time, truncated to whole seconds since the epoch
-- (the picosecond part of the clock reading is discarded).
getTime :: IO Time
getTime =
    getClockTime >>= \(TOD wholeSecs _picos) ->
        return (fromIntegral wholeSecs)

-- | Create a 'TimeVar' holding the current time and fork a background
-- thread that refreshes it every 0.1 seconds.
--
-- All bits are refreshed in a single transaction, so readers never see a
-- half-updated time. Only bits whose value changed are written: a
-- transaction blocked in 'retry' is woken when a TVar it read is written,
-- so skipping unchanged bits avoids needlessly re-running waiters.
--
-- NOTE(review): the refresher thread loops forever and its 'ThreadId' is
-- discarded, so it can never be stopped; each call leaks one such thread.
initTimeVar :: IO TimeVar
initTimeVar = do
    bits0 <- fmap toBitsFromMSB getTime
    -- one TVar per bit, most significant bit first
    vars <- atomically (mapM newTVar bits0)
    _ <- forkIO $ forever $ do
        threadDelay 100000 -- 0.1 s
        bits <- fmap toBitsFromMSB getTime
        atomically (zipWithM_ updateBit vars bits)
    return (TimeVar vars)
  where
    -- Write a bit only if it differs from the stored value.
    updateBit var new = do
        old <- readTVar var
        when (old /= new) (writeTVar var new)

-- | Block (via 'retry') until the time held in the 'TimeVar' is at least @t@.
--
-- The stored bits and the bits of @t@ are compared from the most
-- significant bit down, stopping at the first bit where they differ. Only
-- that prefix of TVars is read, so the transaction is re-run only when one
-- of those bits changes, not on every clock tick.
--
-- 'Time' is a signed integer, so the first bit is the sign bit. For the sign
-- bit, a set bit means a *smaller* value, so its comparison is reversed.
-- Without this, waiting for a negative time would block forever instead of
-- returning immediately.
waitOnTimeVar :: TimeVar -> Time -> STM ()
waitOnTimeVar (TimeVar vars) t = go True vars (toBitsFromMSB t)
  where
    go isSignBit (v:vs) (b:bs) = do
        cur <- readTVar v
        let order | isSignBit = compare b cur
                  | otherwise = compare cur b
        case order of
            LT -> retry             -- current time < t: keep waiting
            EQ -> go False vs bs    -- equal so far: inspect the next bit
            GT -> return ()         -- current time > t: done
    -- All bits equal (current time == t): done. Both lists always have the
    -- same length; the wildcard only makes the match total.
    go _ _ _ = return ()

-- | The bits of a fixed-width value, most significant bit first.
-- For signed types the first element is the sign bit.
--
-- 'FiniteBits' replaces the deprecated 'bitSize', which fails at runtime
-- on unbounded types such as 'Integer'; such misuse is now a type error.
toBitsFromMSB :: FiniteBits b => b -> [Bool]
toBitsFromMSB x = [ testBit x i | i <- [width - 1, width - 2 .. 0] ]
  where
    width = finiteBitSize x

