Hello all,

 HaskellDB is a nice combinator library, but it has
two main disadvantages:

    * HaskellDB uses the 'Trex' module, which is Hugs-specific;
neither GHC nor NHC supports 'Trex'.
    * HaskellDB can't execute stored SQL procedures
and doesn't allow executing plain SQL statements.

 I use HSQL (see the mail attachment) for data access.
HSQL works with ODBC, but its user interface isn't
ODBC-specific. The module can be rewritten to use native
drivers for specific databases
(MySQL, PostgreSQL, Oracle, Sybase, ...). If anybody
is interested in using HSQL, I will put it
in CVS.

Krasimir Angelov

__________________________________________________
Do You Yahoo!?
Yahoo! - Official partner of 2002 FIFA World Cup
http://fifaworldcup.yahoo.com
-- | HSQL: a small database access library built on ODBC.
--
-- The comment after each export gives that function's type; the
-- definitions below carry the authoritative signatures.
module HSQL
        ( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
        , catchSql          -- :: IO a -> (SqlError -> IO a) -> IO a
        , connect           -- :: String -> String -> String -> IO Connection
        , disconnect        -- :: Connection -> IO ()
        , execute           -- :: Connection -> String -> IO Statement
        , closeStatement    -- :: Statement -> IO ()
        , fetch             -- :: Statement -> IO Bool
        , inTransaction     -- :: Connection -> (Connection -> IO a) -> IO a
        , getFieldValue     -- :: SqlBind a => Statement -> String -> IO a
        , getFieldValueType -- :: Statement -> String -> (SqlType, Bool)
        , getFieldsTypes    -- :: Statement -> [(String, SqlType, Bool)]
        , forEachRow        -- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
        , forEachRow'       -- :: (Statement -> IO ()) -> Statement -> IO ()
        , collectRows       -- :: (Statement -> IO a) -> Statement -> IO [a]
        ) where

import Word(Word32, Word16)
import Int(Int32, Int16)
import Foreign
import CString
import IORef
import Monad(when)
import Exception (throwDyn, catchDyn, Exception(..))
import Dynamic

#include <HSQLStructs.h>

-- Opaque ODBC handles; all handle kinds share the C type SQLHANDLE.
type SQLHANDLE = Ptr ()
type HENV      = SQLHANDLE   -- environment handle
type HDBC      = SQLHANDLE   -- connection handle
type HSTMT     = SQLHANDLE   -- statement handle

-- Environment handle wrapped in a ForeignPtr so that a finalizer can free it.
type HENVRef = ForeignPtr ()

-- Fixed-width ODBC integer types.
-- NOTE(review): these match a 32-bit ODBC build; 64-bit driver managers
-- widen SQLLEN/SQLULEN to 64 bits — confirm against the target platform.
type SQLSMALLINT  = Int16
type SQLUSMALLINT = Word16
type SQLINTEGER   = Int32
type SQLUINTEGER  = Word32
type SQLRETURN    = SQLSMALLINT
type SQLLEN       = SQLINTEGER
-- SQLULEN is ODBC's *unsigned* length type (column sizes, option values),
-- so it aliases the unsigned 32-bit integer, not SQLINTEGER.
type SQLULEN      = SQLUINTEGER

-- Raw ODBC entry points, all declared in sqlext.h.  Each returns an
-- SQLRETURN status code.
-- NOTE(review): ODBC declares the string-length arguments of SQLConnect as
-- SQLSMALLINT and of SQLExecDirect as SQLINTEGER; they are declared as Int
-- here — confirm this matches the calling convention on the target.

-- environment handles
foreign import stdcall "sqlext.h SQLAllocEnv"
        sqlAllocEnv :: Ptr HENV -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeEnv"
        sqlFreeEnv :: HENV -> IO SQLRETURN

-- connection handles
foreign import stdcall "sqlext.h SQLAllocConnect"
        sqlAllocConnect :: HENV -> Ptr HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeConnect"
        sqlFreeConnect :: HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLConnect"
        sqlConnect
                :: HDBC
                -> CString -> Int     -- server name and length
                -> CString -> Int     -- user name and length
                -> CString -> Int     -- authentication string and length
                -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDisconnect"
        sqlDisconnect :: HDBC -> IO SQLRETURN

-- statement handles and result sets
foreign import stdcall "sqlext.h SQLAllocStmt"
        sqlAllocStmt :: HDBC -> Ptr HSTMT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeStmt"
        sqlFreeStmt :: HSTMT -> SQLUSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLNumResultCols"
        sqlNumResultCols :: HSTMT -> Ptr SQLUSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDescribeCol"
        sqlDescribeCol
                :: HSTMT
                -> SQLUSMALLINT         -- column number
                -> CString              -- column name buffer
                -> SQLSMALLINT          -- name buffer length
                -> Ptr SQLSMALLINT      -- name length
                -> Ptr SQLSMALLINT      -- data type
                -> Ptr SQLULEN          -- column size
                -> Ptr SQLSMALLINT      -- decimal digits
                -> Ptr SQLSMALLINT      -- nullable
                -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLBindCol"
        sqlBindCol
                :: HSTMT
                -> SQLUSMALLINT         -- column number
                -> SQLSMALLINT          -- C target type
                -> Ptr a                -- value buffer
                -> SQLLEN               -- value buffer length
                -> Ptr SQLINTEGER       -- length/indicator word
                -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFetch"
        sqlFetch :: HSTMT -> IO SQLRETURN

-- diagnostics
foreign import stdcall "sqlext.h SQLGetDiagRec"
        sqlGetDiagRec
                :: SQLSMALLINT          -- handle type
                -> SQLHANDLE
                -> SQLSMALLINT          -- record number
                -> CString              -- SQLSTATE buffer
                -> Ptr SQLINTEGER       -- native error
                -> CString              -- message buffer
                -> SQLSMALLINT          -- message buffer length
                -> Ptr SQLSMALLINT      -- message length
                -> IO SQLRETURN

-- execution and transactions
foreign import stdcall "sqlext.h SQLExecDirect"
        sqlExecDirect :: HSTMT -> CString -> Int -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLSetConnectOption"
        sqlSetConnectOption :: HDBC -> SQLUSMALLINT -> SQLULEN -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLTransact"
        sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN

-- | An open ODBC connection.
data Connection = Connection
        { hDBC        :: HDBC      -- ^ ODBC connection handle
        , environment :: HENVRef   -- ^ environment the connection belongs to
        }

-- | One result column: (name, SQL type, may be NULL?, byte offset of the
-- column's slot in the statement's fetch buffer).
type FieldDef = (String, SqlType, Bool, Int)

-- | An executed statement whose result columns are bound to 'fetchBuffer'.
data Statement = Statement
        { hSTMT       :: HSTMT       -- ^ ODBC statement handle
        , connection  :: Connection  -- ^ connection the statement runs on
        , fields      :: [FieldDef]  -- ^ result columns, in column order
        , fetchBuffer :: Ptr ()      -- ^ buffer the current row is fetched into
        }

-- | SQL data type of a result column.  The 'Int' arguments are the column
-- size (and, for 'SqlDecimal' / 'SqlNumeric', the decimal digits) that
-- SQLDescribeCol reported.
data SqlType
        -- character strings
        = SqlChar          Int
        | SqlVarChar       Int
        | SqlLongVarChar   Int
        -- exact numerics with size and decimal digits
        | SqlDecimal       Int Int
        | SqlNumeric       Int Int
        -- integers and approximate numerics
        | SqlSmallInt
        | SqlInteger
        | SqlReal
        | SqlDouble
        | SqlBit
        | SqlTinyInt
        | SqlBigInt
        -- binary strings
        | SqlBinary        Int
        | SqlVarBinary     Int
        | SqlLongVarBinary Int
        -- date and time
        | SqlDate
        | SqlTime
        | SqlTimeStamp
        deriving (Eq, Show)
        
-- | Errors raised by this module as dynamic exceptions (see 'catchSql').
data SqlError
        -- an ODBC call returned SQL_ERROR
        = SqlError
                { seState       :: String  -- ^ SQLSTATE code
                , seNativeError :: Int     -- ^ driver-specific native error code
                , seErrorMsg    :: String  -- ^ diagnostic message text
                }
        -- SQL_NO_DATA ('sqlSuccess' treats this code as success)
        | SqlNoData
        -- SQL_INVALID_HANDLE
        | SqlInvalidHandle
        -- SQL_STILL_EXECUTING
        | SqlStillExecuting
        -- SQL_NEED_DATA
        | SqlNeedData
        deriving Show
   
-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

{-# NOINLINE sqlErrorTy #-}
-- | Type representation behind the hand-written 'Typeable' instance, which
-- lets 'SqlError' values travel as dynamic exceptions ('throwDyn' /
-- 'catchDyn').
--
-- The type-constructor name is module-qualified because the Dynamic
-- library identifies 'mkTyCon' constructors by their name string.  A bare
-- "SqlError" could therefore match an unrelated type of the same name
-- elsewhere in the program.
sqlErrorTy :: TypeRep
sqlErrorTy = mkAppTy (mkTyCon "HSQL.SqlError") []

instance Typeable SqlError where
        typeOf _ = sqlErrorTy

-- | Run an action and pass any 'SqlError' it throws (as a dynamic
-- exception) to the handler.  Other exceptions propagate unchanged.
catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql action handler = catchDyn action handler

-- | True for the return codes this module treats as non-failures:
-- SQL_SUCCESS, SQL_SUCCESS_WITH_INFO and SQL_NO_DATA.  SQL_NO_DATA is
-- included so that 'fetch' can report end-of-data without an exception.
sqlSuccess :: SQLRETURN -> Bool
sqlSuccess res = res `elem` okCodes
        where
                okCodes = [(#const SQL_SUCCESS), (#const SQL_SUCCESS_WITH_INFO), (#const SQL_NO_DATA)]

-- | Convert an ODBC return code into an exception.
--
-- * Success codes (see 'sqlSuccess') return normally.
-- * SQL_INVALID_HANDLE, SQL_STILL_EXECUTING and SQL_NEED_DATA throw the
--   matching 'SqlError' constructor.
-- * SQL_ERROR throws an 'SqlError' record filled from diagnostic record 1
--   of the given handle.
-- * Any other code calls 'error'.
--
-- @handleType@ and @handle@ name the ODBC handle whose diagnostics describe
-- the failure, for example SQL_HANDLE_STMT with a statement handle.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
        | sqlSuccess res                      = return ()
        | res == (#const SQL_INVALID_HANDLE)  = throwDyn SqlInvalidHandle
        | res == (#const SQL_STILL_EXECUTING) = throwDyn SqlStillExecuting
        | res == (#const SQL_NEED_DATA)       = throwDyn SqlNeedData
        | res == (#const SQL_ERROR)           = readDiagnostic >>= throwDyn
        | otherwise                           = error (show res)
        where
                -- Read diagnostic record 1 into scoped buffers, which are
                -- released even if marshalling fails.  The buffers are only
                -- read when SQLGetDiagRec reports that it filled them;
                -- otherwise (no record, unusable handle) their contents are
                -- uninitialised, and peekCString could run past the buffer.
                readDiagnostic :: IO SqlError
                readDiagnostic =
                        allocaBytes diagBufLen $ \pState ->
                        alloca $ \pNative ->
                        allocaBytes diagBufLen $ \pMsg ->
                        alloca $ \pTextLen -> do
                                rc <- sqlGetDiagRec handleType handle 1 pState pNative pMsg (fromIntegral diagBufLen) pTextLen
                                if rc == (#const SQL_SUCCESS) || rc == (#const SQL_SUCCESS_WITH_INFO)
                                        then do
                                                state  <- peekCString pState
                                                native <- peek pNative
                                                msg    <- peekCString pMsg
                                                return (SqlError {seState=state, seNativeError=fromIntegral native, seErrorMsg=msg})
                                        else return (SqlError {seState="", seNativeError=0, seErrorMsg="no diagnostic record available"})

                -- Size in bytes of the SQLSTATE and message buffers.
                diagBufLen :: Int
                diagBufLen = 256

-----------------------------------------------------------------------------------------
-- keeper of HENV
-----------------------------------------------------------------------------------------

{-# NOINLINE myEnvironment #-}
-- | The process-wide ODBC environment handle, allocated on first use.
-- NOINLINE keeps the 'unsafePerformIO' from being duplicated, which would
-- allocate more than one environment.  The handle is freed by a finalizer
-- on the 'ForeignPtr'.
--
-- NOTE(review): the finalizer throws an 'SqlError' if SQLFreeEnv fails.
-- Exceptions raised in finalizers reach no caller — confirm this is
-- acceptable.
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO $ do
        (status, hEnv) <- alloca $ \phEnv -> do
                rc <- sqlAllocEnv phEnv
                h  <- peek phEnv
                return (rc, h)
        handleSqlResult 0 nullPtr status
        newForeignPtr hEnv (releaseEnv hEnv)
    where
        releaseEnv :: HENV -> IO ()
        releaseEnv h = sqlFreeEnv h >>= handleSqlResult (#const SQL_HANDLE_ENV) h

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Open a connection to an ODBC data source.
--
-- @server@ is the data source name; @user@ and @authentication@ are the
-- login credentials.  All three are passed to SQLConnect as NUL-terminated
-- strings (SQL_NTS), so the driver measures them in encoded bytes, not in
-- Haskell characters.
--
-- Throws 'SqlError' if the connection handle cannot be allocated or the
-- login fails.  On a login failure the diagnostics are read from the
-- connection handle, and the handle is then freed.
connect :: String -> String -> String -> IO Connection
connect server user authentication = withForeignPtr myEnvironment $ \hEnv -> do
        (allocRes, hDBC) <- alloca $ \phDBC -> do
                r <- sqlAllocConnect hEnv phDBC
                h <- peek phDBC
                return (r, h)
        -- Allocation failures are reported on the environment handle.
        handleSqlResult (#const SQL_HANDLE_ENV) hEnv allocRes
        connRes <- withCString server $ \pServer ->
                   withCString user $ \pUser ->
                   withCString authentication $ \pAuthentication ->
                   sqlConnect hDBC pServer (#const SQL_NTS) pUser (#const SQL_NTS) pAuthentication (#const SQL_NTS)
        -- SQLConnect failures are reported on the connection handle.  Read
        -- them before releasing the handle so that it does not leak.
        handleSqlResult (#const SQL_HANDLE_DBC) hDBC connRes
                `catchSql` \err -> do
                        sqlFreeConnect hDBC
                        throwDyn err
        return (Connection {hDBC=hDBC, environment=myEnvironment})
        
-- | Close a connection.  Disconnect from the data source, then free the
-- connection handle; throws 'SqlError' if either step fails.
disconnect :: Connection -> IO ()
disconnect conn = do
        check =<< sqlDisconnect dbc
        check =<< sqlFreeConnect dbc
    where
        dbc   = hDBC conn
        check = handleSqlResult (#const SQL_HANDLE_DBC) dbc
        
-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------

-- | Execute a plain SQL statement and prepare its result set for fetching.
--
-- The statement text goes to SQLExecDirect as a NUL-terminated string
-- (SQL_NTS).  Each result column is described with SQLDescribeCol and bound
-- with SQLBindCol to its own slot in a single fetch buffer.  A slot is a
-- SQLINTEGER length/indicator word followed by the value buffer.  'FieldDef'
-- records each slot's offset, and 'getFieldValue' reads the value at that
-- offset plus sizeof(SQLINTEGER).
--
-- A statement that produces no result columns yields a 'Statement' with no
-- fields.  The caller must release the result with 'closeStatement'.
--
-- Throws 'SqlError' when an ODBC call fails or a column has a data type
-- this module cannot bind.  In each of those cases the statement handle and
-- any buffer allocated so far are released before the error is re-thrown.
execute :: Connection -> String -> IO Statement
execute conn@(Connection {hDBC=hDBC}) query =
        -- FIELD (from HSQLStructs.h) is scratch space for the out-parameters
        -- of SQLAllocStmt, SQLNumResultCols and SQLDescribeCol.  It is only
        -- needed while the statement is being prepared.
        allocaBytes (#const sizeof(FIELD)) $ \pFIELD -> do
                res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD)
                handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
                hSTMT <- (#peek FIELD, hSTMT) pFIELD
                let dropStmtAndRethrow err = do
                        sqlFreeStmt hSTMT (#const SQL_DROP)
                        throwDyn err
                (fields, bufSize) <- flip catchSql dropStmtAndRethrow $ do
                        execRes <- withCString query $ \pQuery ->
                                sqlExecDirect hSTMT pQuery (#const SQL_NTS)
                        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT execRes
                        colsRes <- sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD)
                        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT colsRes
                        count <- (#peek FIELD, fieldsCount) pFIELD
                        createBindState hSTMT pFIELD 0 1 count
                buffer <- mallocBytes bufSize
                let releaseAndRethrow err = do
                        sqlFreeStmt hSTMT (#const SQL_DROP)
                        free buffer
                        throwDyn err
                catchSql (bindFields hSTMT buffer 1 fields) releaseAndRethrow
                return (Statement {hSTMT=hSTMT, connection=conn, fields=fields, fetchBuffer=buffer})
        where
                -- Describe columns n..count, placing the first one's slot at
                -- byte offset @offs@.  Returns the column descriptions and the
                -- total fetch-buffer size in bytes.
                createBindState :: HSTMT -> Ptr a -> Int -> SQLUSMALLINT -> SQLUSMALLINT -> IO ([FieldDef], Int)
                createBindState hSTMT pFIELD offs n count
                        | n > count = return ([], offs)
                        | otherwise = do
                                res <- sqlDescribeCol hSTMT n
                                        ((#ptr FIELD, fieldName) pFIELD) (#const FIELD_NAME_LENGTH)
                                        ((#ptr FIELD, NameLength) pFIELD)
                                        ((#ptr FIELD, DataType) pFIELD)
                                        ((#ptr FIELD, ColumnSize) pFIELD)
                                        ((#ptr FIELD, DecimalDigits) pFIELD)
                                        ((#ptr FIELD, Nullable) pFIELD)
                                handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
                                name <- peekCString ((#ptr FIELD, fieldName) pFIELD)
                                dataType <- (#peek FIELD, DataType) pFIELD
                                columnSize <- (#peek FIELD, ColumnSize) pFIELD
                                decimalDigits <- (#peek FIELD, DecimalDigits) pFIELD
                                nullable <- (#peek FIELD, Nullable) pFIELD
                                sqlType <- case mkSqlType dataType columnSize decimalDigits of
                                        Just t  -> return t
                                        Nothing -> throwDyn (SqlError
                                                { seState       = ""
                                                , seNativeError = 0
                                                , seErrorMsg    = "HSQL.execute: unsupported SQL data type "
                                                                  ++ show dataType ++ " in column " ++ name
                                                })
                                -- A slot is the indicator word plus the value
                                -- buffer, rounded up so the next slot stays aligned.
                                let offs' = alignSlot (offs + (#const sizeof(SQLINTEGER)) + valueSize sqlType)
                                (rest, total) <- createBindState hSTMT pFIELD offs' (n+1) count
                                return ((name, sqlType, toBool (nullable :: SQLSMALLINT), offs) : rest, total)

                -- Map an ODBC SQL data type code (plus column size and
                -- decimal digits) to an 'SqlType'.  Returns Nothing for codes
                -- this module cannot bind.  Both the ODBC 2 date/time codes
                -- and their ODBC 3 SQL_TYPE_* equivalents are accepted.
                mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> Maybe SqlType
                mkSqlType dataType size prec = case dataType of
                        (#const SQL_CHAR)           -> Just (SqlChar len)
                        (#const SQL_VARCHAR)        -> Just (SqlVarChar len)
                        (#const SQL_LONGVARCHAR)    -> Just (SqlLongVarChar len)
                        (#const SQL_DECIMAL)        -> Just (SqlDecimal len digits)
                        (#const SQL_NUMERIC)        -> Just (SqlNumeric len digits)
                        (#const SQL_SMALLINT)       -> Just SqlSmallInt
                        (#const SQL_INTEGER)        -> Just SqlInteger
                        (#const SQL_REAL)           -> Just SqlReal
                        (#const SQL_FLOAT)          -> Just SqlDouble
                        (#const SQL_DOUBLE)         -> Just SqlDouble
                        (#const SQL_BIT)            -> Just SqlBit
                        (#const SQL_TINYINT)        -> Just SqlTinyInt
                        (#const SQL_BIGINT)         -> Just SqlBigInt
                        (#const SQL_BINARY)         -> Just (SqlBinary len)
                        (#const SQL_VARBINARY)      -> Just (SqlVarBinary len)
                        (#const SQL_LONGVARBINARY)  -> Just (SqlLongVarBinary len)
                        (#const SQL_DATE)           -> Just SqlDate
                        (#const SQL_TIME)           -> Just SqlTime
                        (#const SQL_TIMESTAMP)      -> Just SqlTimeStamp
                        (#const SQL_TYPE_DATE)      -> Just SqlDate
                        (#const SQL_TYPE_TIME)      -> Just SqlTime
                        (#const SQL_TYPE_TIMESTAMP) -> Just SqlTimeStamp
                        _                           -> Nothing
                    where
                            len    = fromIntegral size
                            digits = fromIntegral prec

                -- ODBC C target type and value-buffer length used to bind a
                -- column of the given SQL type.
                cBinding :: SqlType -> (SQLSMALLINT, SQLLEN)
                cBinding (SqlChar size)          = charBinding size
                cBinding (SqlVarChar size)       = charBinding size
                cBinding (SqlLongVarChar size)   = charBinding size
                cBinding (SqlDecimal _ _)        = doubleBinding
                cBinding (SqlNumeric _ _)        = doubleBinding
                cBinding SqlSmallInt             = shortBinding
                cBinding SqlInteger              = longBinding
                cBinding SqlReal                 = doubleBinding
                cBinding SqlDouble               = doubleBinding
                cBinding SqlBit                  = longBinding
                cBinding SqlTinyInt              = shortBinding
                -- NOTE(review): BIGINT is bound as a 32-bit SQL_C_LONG, so
                -- values outside the Int32 range do not fit.
                cBinding SqlBigInt               = longBinding
                cBinding (SqlBinary size)        = binaryBinding size
                cBinding (SqlVarBinary size)     = binaryBinding size
                cBinding (SqlLongVarBinary size) = binaryBinding size
                cBinding SqlDate                 = ((#const SQL_C_TYPE_DATE), (#const sizeof(SQL_DATE_STRUCT)))
                cBinding SqlTime                 = ((#const SQL_C_TYPE_TIME), (#const sizeof(SQL_TIME_STRUCT)))
                cBinding SqlTimeStamp            = ((#const SQL_C_TYPE_TIMESTAMP), (#const sizeof(SQL_TIMESTAMP_STRUCT)))

                -- Character data: @size@ characters plus the terminating NUL.
                charBinding :: Int -> (SQLSMALLINT, SQLLEN)
                charBinding size = ((#const SQL_C_CHAR), fromIntegral (size + 1) * (#const sizeof(SQLCHAR)))

                -- Binary data is fetched as raw bytes.  Fetching it as
                -- SQL_C_CHAR would make the driver turn each byte into two hex
                -- digits, which does not fit in a @size@-byte buffer.
                binaryBinding :: Int -> (SQLSMALLINT, SQLLEN)
                binaryBinding size = ((#const SQL_C_BINARY), fromIntegral size)

                doubleBinding, shortBinding, longBinding :: (SQLSMALLINT, SQLLEN)
                doubleBinding = ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
                shortBinding  = ((#const SQL_C_SHORT),  (#const sizeof(SQLSMALLINT)))
                longBinding   = ((#const SQL_C_LONG),   (#const sizeof(SQLINTEGER)))

                -- Bytes reserved for a column's value, excluding the indicator.
                valueSize :: SqlType -> Int
                valueSize t = fromIntegral (snd (cBinding t))

                -- Round a byte offset up to a multiple of 8 so that every
                -- slot's indicator word starts on an aligned address.
                alignSlot :: Int -> Int
                alignSlot o = ((o + 7) `div` 8) * 8

                -- Bind each column (1-based index @n@) to its slot.  The value
                -- buffer follows the slot's length/indicator word.
                bindFields :: HSTMT -> Ptr () -> SQLUSMALLINT -> [FieldDef] -> IO ()
                bindFields _ _ _ [] = return ()
                bindFields hSTMT fetchBuffer n ((_, sqlType, _, offs) : rest) = do
                        let slot         = fetchBuffer `plusPtr` offs
                            (cType, len) = cBinding sqlType
                        res <- sqlBindCol hSTMT n cType (slot `plusPtr` (#const sizeof(SQLINTEGER))) len (castPtr slot)
                        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
                        bindFields hSTMT fetchBuffer (n+1) rest


{-# NOINLINE fetch #-}
-- | Advance the statement to its next row, filling the bound fetch buffer.
-- Returns False once the result set is exhausted (SQL_NO_DATA) and throws
-- 'SqlError' on failure.
-- NOTE(review): SQL_SUCCESS_WITH_INFO (e.g. truncated column data) is
-- treated as plain success.
fetch :: Statement -> IO Bool
fetch stmt = do
        rc <- sqlFetch h
        handleSqlResult (#const SQL_HANDLE_STMT) h rc
        return (not (rc == (#const SQL_NO_DATA)))
    where
        h = hSTMT stmt
        

-- | Release a statement.  Free the ODBC statement handle, then the fetch
-- buffer its columns are bound to.  The statement must not be used
-- afterwards.
--
-- SQL_DROP is required here.  SQL_CLOSE only closes the open cursor and
-- leaves the statement handle allocated.  If freeing the handle fails, the
-- 'SqlError' is thrown before the buffer is released, because the
-- still-allocated handle remains bound to it.
closeStatement :: Statement -> IO ()
closeStatement stmt = do
        sqlFreeStmt (hSTMT stmt) (#const SQL_DROP) >>= handleSqlResult (#const SQL_HANDLE_STMT) (hSTMT stmt)
        free (fetchBuffer stmt)

-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------

-- | Run @action@ as a single transaction on the connection.
--
-- Auto-commit is switched off, the action runs, and the transaction is
-- committed; auto-commit is then switched back on.  If the action or the
-- COMMIT raises an 'SqlError', the transaction is rolled back, auto-commit
-- is restored, and the error is re-thrown.  Failures of the COMMIT and of
-- the auto-commit switches are reported rather than ignored.  The rollback
-- itself is best effort, because an error is already being propagated.
--
-- Only 'SqlError' exceptions trigger the rollback; other exceptions
-- propagate with auto-commit still switched off.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn@(Connection {hDBC=hDBC, environment=envRef}) action =
  withForeignPtr envRef $ \hEnv -> do
        let check = handleSqlResult (#const SQL_HANDLE_DBC) hDBC
            setAutoCommit mode = sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) mode >>= check
            rollbackAndRethrow err = do
                    sqlTransact hEnv hDBC (#const SQL_ROLLBACK)
                    sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_ON)
                    throwDyn err
        setAutoCommit (#const SQL_AUTOCOMMIT_OFF)
        result <- flip catchSql rollbackAndRethrow $ do
                r <- action conn
                sqlTransact hEnv hDBC (#const SQL_COMMIT) >>= check
                return r
        setAutoCommit (#const SQL_AUTOCOMMIT_ON)
        return result
-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------

-- | Types that can be read from a column slot of a statement's fetch
-- buffer.  @getValue t slot@ decodes the current row's value for a column
-- of SQL type @t@.  @slot@ points at the column's SQLINTEGER
-- length/indicator word; the value starts sizeof(SQLINTEGER) bytes later.
--
-- NOTE(review): the indicator is not checked for SQL NULL, so a NULL
-- numeric column yields whatever its value buffer holds.
class SqlBind a where
        getValue :: SqlType -> Ptr () -> IO a

-- Address of the value that follows a slot's SQLINTEGER indicator word.
valuePtr :: Ptr () -> Ptr b
valuePtr ptr = ptr `plusPtr` (#const sizeof(SQLINTEGER))

-- Read a slot's value at the width it was bound with, then widen it to
-- Int.  Reading it directly as 'Int' would read too many bytes wherever Int
-- is wider than the bound value.
readInt16, readInt32 :: Ptr () -> IO Int
readInt16 ptr = fmap fromIntegral (peek (valuePtr ptr) :: IO Int16)
readInt32 ptr = fmap fromIntegral (peek (valuePtr ptr) :: IO Int32)

-- Read a character slot holding at most @size@ characters.  The indicator
-- word carries the length the driver reported.  On truncation that length
-- can exceed what the slot holds, so it is clamped to @size@.  Negative
-- indicators (NULL, unknown length) give "".
readString :: Int -> Ptr () -> IO String
readString size ptr = do
        reported <- peek (castPtr ptr) :: IO Int32
        let len = max 0 (min size (fromIntegral reported))
        peekCStringLen (valuePtr ptr, len)

-- Integer columns: SMALLINT/TINYINT are bound as 16-bit values and
-- INTEGER/BIT/BIGINT as 32-bit values.
instance SqlBind Int where
        getValue SqlInteger  ptr = readInt32 ptr
        getValue SqlBit      ptr = readInt32 ptr
        getValue SqlBigInt   ptr = readInt32 ptr
        getValue SqlSmallInt ptr = readInt16 ptr
        getValue SqlTinyInt  ptr = readInt16 ptr
        getValue t           _   = error ("HSQL.getValue: cannot read a " ++ show t ++ " column as Int")

-- Character columns.
instance SqlBind String where
        getValue (SqlChar size)        ptr = readString size ptr
        getValue (SqlVarChar size)     ptr = readString size ptr
        getValue (SqlLongVarChar size) ptr = readString size ptr
        getValue t                     _   = error ("HSQL.getValue: cannot read a " ++ show t ++ " column as String")

-- Exact and approximate numeric columns, all bound as SQL_C_DOUBLE.
instance SqlBind Double where
        getValue (SqlDecimal _ _) ptr = peek (valuePtr ptr)
        getValue (SqlNumeric _ _) ptr = peek (valuePtr ptr)
        getValue SqlDouble        ptr = peek (valuePtr ptr)
        getValue SqlReal          ptr = peek (valuePtr ptr)
        getValue t                _   = error ("HSQL.getValue: cannot read a " ++ show t ++ " column as Double")
        

-- | Read the named column of the current row, decoded as type @a@.
-- Calls 'error' (via 'findField') if the statement has no such column.
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name =
        let (_, columnType, _, offset) = findField name (fields stmt)
            slot                       = fetchBuffer stmt `plusPtr` offset
        in  getValue columnType slot

-- | SQL type of the named column and whether it may contain NULL.  The
-- lookup is lazy: 'findField' raises its error only when a component of
-- the result is inspected.
getFieldValueType :: Statement -> String -> (SqlType, Bool)
getFieldValueType stmt name =
        let (_, columnType, mayBeNull, _) = findField name (fields stmt)
        in  (columnType, mayBeNull)
                
-- | Name, SQL type and nullability of every result column, in column order.
getFieldsTypes :: Statement -> [(String, SqlType, Bool)]
getFieldsTypes stmt = map withoutOffset (fields stmt)
        where
                withoutOffset (columnName, columnType, mayBeNull, _) = (columnName, columnType, mayBeNull)

-- | Look up a column description by exact, case-sensitive name.  The first
-- matching column wins; 'error' is called when none matches.
findField :: String -> [FieldDef] -> FieldDef
findField name defs =
        case [def | def@(name', _, _, _) <- defs, name == name'] of
                (def : _) -> def
                []        -> error (name ++ "??")

-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------

-- Run a row-processing loop over a statement, then close the statement.
-- If the loop raises an 'SqlError' (from 'fetch' or from the per-row
-- callback), the statement is closed before the error is re-thrown, so its
-- handle and buffer are not leaked.  Other exceptions are not intercepted.
closingStatement :: Statement -> IO a -> IO a
closingStatement stmt body = do
        result <- catchSql body (\err -> closeStatement stmt >> throwDyn err)
        closeStatement stmt
        return result

-- | Fold @f@ over the remaining rows of a statement, threading an
-- accumulator through it, then close the statement.
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt s = closingStatement stmt (loop s)
        where
                loop acc = do
                        more <- fetch stmt
                        if more then f stmt acc >>= loop else return acc

-- | Run @f@ on each remaining row of a statement, then close the statement.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = closingStatement stmt loop
        where
                loop = do
                        more <- fetch stmt
                        if more then f stmt >> loop else return ()

-- | Collect @f@'s result for each remaining row of a statement, in row
-- order, then close the statement.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = closingStatement stmt loop
        where
                loop = do
                        more <- fetch stmt
                        if more
                                then do
                                        x  <- f stmt
                                        xs <- loop
                                        return (x : xs)
                                else return []

Reply via email to