Hello,
I'm writing some matrix multiplication and inversion functions for
small matrices (3x3 and 4x4 mostly, for 3d graphics, modeling,
simulation, etc.) I noticed that the matrix multiplication was a
bottleneck so I set out to optimize and found that using unsafeRead
instead of (!) (or readArray in stateful code) helped a lot. So then I
went to optimize my Gaussian elimination function and found just the
opposite. unsafeRead is slower than readArray. This struck me as very
odd considering that readArray calls unsafeRead.
If there is a "good" reason why the compiler optimized readArray
better than unsafeRead, I'd like to know what it is so that I can make
all my array code safe as well as fast. (By "good" reason I mean
something deterministic and repeatable, not just luck.)
On the other hand, if this is a fluke, I'm inclined to think that it's
not the safe code which is freakishly fast, but the unsafe code which
is needlessly slow. That is, something about my program is hindering
optimization of the unsafe code. What is it?
Attached are the profiling results and a test program with a handful of
matrix multiplication and Gaussian elimination functions to illustrate
what I've seen. This happens on both AMD64 and Intel Core
architectures.
Thanks for any insight,
Scott
{-
total time = 1196.70 secs (23934 ticks @ 50 ms)
total alloc = 419,350,893,280 bytes (excludes profiling overheads)
COST CENTRE MODULE %time %alloc ticks bytes
matrixMultSafe Main 26.9 31.4 6449 16450164500
gaussElimUnsafe' Main 21.4 20.4 5126 10687510443
gaussElim2Unsafe Main 17.4 17.9 4173 9366179268
gaussElimSafe' Main 16.2 16.0 3881 8377362484
gaussElim2Safe Main 13.4 13.1 3207 6869088980
matrixMultUnsafe Main 3.5 0.8 829 400004000
-}
{-# OPTIONS_GHC -O2 -optc-O2 -fglasgow-exts -fbang-patterns #-}
import Control.Monad
import Data.List
import Data.Array.IO
import Data.Array.Unboxed
import System
import System.IO.Unsafe
import System.Random
import Data.Array.Base
-- | Multiply two n-by-n matrices without any bounds checking.
--
-- Both inputs are read through 'unsafeAt' using flat row-major offsets
-- (@n*row + col@), so the caller must supply arrays holding at least @n*n@
-- elements. The product is returned in an array with bounds
-- @((1,1),(n,n))@.
--
-- 'unsafePerformIO' / 'unsafeFreeze' are used because the scratch array is
-- created here, filled here, and never touched again after it is frozen.
matrixMultUnsafe n a b = unsafePerformIO $ do
  out <- newArray_ ((1, 1), (n, n)) :: IO (IOUArray (Int, Int) Double)
  let -- Strict dot product of row `row` of a with column `col` of b.
      dot !row !col !k !acc
        | k == n = acc
        | otherwise =
            dot row col (k + 1)
              (acc + (a `unsafeAt` (n * row + k)) * (b `unsafeAt` (n * k + col)))
      -- Fill one output row, left to right.
      fillCols !row !col
        | col == n = return ()
        | otherwise = do
            unsafeWrite out (row * n + col) (dot row col 0 0)
            fillCols row (col + 1)
      -- Fill every output row, top to bottom.
      fillRows !row
        | row == n = return ()
        | otherwise = fillCols row 0 >> fillRows (row + 1)
  fillRows 0
  unsafeFreeze out
-- | Multiply two n-by-n matrices using bounds-checked indexing.
--
-- Inputs are indexed with @(row, col)@ pairs running from 1 to @n@ via
-- '(!)', and results are stored with 'writeArray'. The product is returned
-- in an array with bounds @((1,1),(n,n))@.
--
-- 'unsafePerformIO' / 'unsafeFreeze' are used because the scratch array is
-- created here, filled here, and never touched again after it is frozen.
matrixMultSafe n a b = unsafePerformIO $ do
  out <- newArray_ ((1, 1), (n, n)) :: IO (IOUArray (Int, Int) Double)
  let -- Strict dot product of row `row` of a with column `col` of b.
      dot !row !col !k !acc
        | k > n = acc
        | otherwise = dot row col (k + 1) (acc + (a ! (row, k)) * (b ! (k, col)))
      -- Fill one output row, left to right.
      fillCols !row !col
        | col > n = return ()
        | otherwise = do
            writeArray out (row, col) (dot row col 1 0)
            fillCols row (col + 1)
      -- Fill every output row, top to bottom.
      fillRows !row
        | row > n = return ()
        | otherwise = fillCols row 1 >> fillRows (row + 1)
  fillRows 1
  unsafeFreeze out
-- | Run the unchecked elimination on a mutable 2-D matrix.
--
-- The row and column counts are derived from the array's bounds and passed
-- on to gaussElimUnsafe', which addresses the array by flat offsets.
gaussElimUnsafe matrix = do
  ((rowLo, colLo), (rowHi, colHi)) <- getBounds matrix
  let rowCount = rowHi - rowLo + 1
      colCount = colHi - colLo + 1
  gaussElimUnsafe' matrix rowCount colCount
-- | In-place elimination with partial pivoting on a matrix addressed by
-- flat row-major offsets (@row*n + col@) through 'unsafeRead' /
-- 'unsafeWrite'; no bounds checks are performed.
--
-- @m@ is the number of rows and @n@ the number of columns. For each pivot
-- column the row with the largest absolute entry (at or below the working
-- row) is swapped into the working row and divided by the pivot value, then
-- that row is subtracted from every other row so the column becomes zero
-- elsewhere. Columns whose best pivot is 'nearZero' are skipped.
gaussElimUnsafe' matrix m n = doColumn 0 0
  where
    -- Advance along (working row i, column j) until either runs out.
    doColumn !i !j | (i==m||j==n) = return ()
    doColumn !i !j =
      do
        (pivotRow,pivotVal) <- findPivot i j
        if nearZero pivotVal
          then doColumn i (j+1)
          else
            do
              swapRowsAndDivideByPivot i pivotRow pivotVal
              subtractRows i j
              doColumn (i+1) (j+1)

    -- Scan rows i..m-1 of column j for the entry of largest magnitude;
    -- returns (row, value), defaulting to (i, 0) when nothing exceeds 0.
    findPivot !i !j = f i (i,0)
      where
        f !i (!maxi,!maxe) | (i==m) = return (maxi,maxe)
        f !i (!maxi,!maxe) =
          do
            e <- unsafeRead matrix (i*n+j)
            f (i+1) $ if abs e > abs maxe then (i,e) else (maxi,maxe)

    -- Exchange row i with pivot row pr and leave row i divided by pv.
    swapRowsAndDivideByPivot !i !pr !pv = f 0
      where
        f !j | (j==n) = return ()
        f !j =
          do
            ei <- unsafeRead matrix (i *n+j)
            ep <- unsafeRead matrix (pr*n+j)
            -- The old row i goes to the pivot row unchanged; the normalized
            -- pivot row goes to row i. Row i is written last so that the
            -- normalized value is the one that survives when pr == i.
            unsafeWrite matrix (pr*n+j) ei
            unsafeWrite matrix (i *n+j) (ep/pv)
            f (j+1)

    -- For every row u /= i, subtract (entry of u in column j) * row i.
    subtractRows !i !j = f 0
      where
        f !u | (u==m) = return ()
        f !u | (u==i) = f (u+1)
        f !u =
          do
            s <- unsafeRead matrix (u*n+j)
            g s u 0
            f (u+1)
        g _ _ !j | (j==n) = return ()
        g !s !u !j =
          do
            ei <- unsafeRead matrix (i*n+j)
            eu <- unsafeRead matrix (u*n+j)
            unsafeWrite matrix (u*n+j) (eu - s*ei)
            g s u (j+1)
--------------------------------------------------
-- | Run the bounds-checked elimination on a mutable 2-D matrix, handing the
-- array's own bounds to gaussElimSafe'.
gaussElimSafe matrix = getBounds matrix >>= gaussElimSafe' matrix
-- | In-place elimination with partial pivoting using bounds-checked
-- 'readArray' / 'writeArray' on @(row, col)@ indices.
--
-- The second argument is the array's bounds @((i1,j1),(m,n))@. For each
-- pivot column the row with the largest absolute entry (at or below the
-- working row) is swapped into the working row and divided by the pivot
-- value, then that row is subtracted from every other row so the column
-- becomes zero elsewhere. Columns whose best pivot is 'nearZero' are
-- skipped.
gaussElimSafe' matrix ((i1,j1),(m,n)) = doColumn i1 j1
  where
    -- Advance along (working row i, column j) until either runs out.
    doColumn !i !j | (i>m||j>n) = return ()
    doColumn !i !j =
      do
        (pivotRow,pivotVal) <- findPivot i j
        if nearZero pivotVal
          then doColumn i (j+1)
          else
            do
              swapRowsAndDivideByPivot i pivotRow pivotVal
              subtractRows i j
              doColumn (i+1) (j+1)

    -- Scan rows i..m of column j for the entry of largest magnitude;
    -- returns (row, value), defaulting to (i, 0) when nothing exceeds 0.
    findPivot !i !j = f i (i,0)
      where
        f !i (!maxi,!maxe) | i>m = return (maxi,maxe)
        f !i (!maxi,!maxe) =
          do
            e <- readArray matrix (i,j)
            f (i+1) $ if abs e > abs maxe then (i,e) else (maxi,maxe)

    -- Exchange row i with pivot row pr and leave row i divided by pv.
    swapRowsAndDivideByPivot !i !pr !pv = f j1
      where
        f !j | j>n = return ()
        f !j =
          do
            ei <- readArray matrix (i ,j)
            ep <- readArray matrix (pr,j)
            -- Row i must be written last: when pr == i both writes hit the
            -- same cell, and the normalized value has to be the survivor.
            writeArray matrix (pr,j) ei
            writeArray matrix (i ,j) (ep/pv)
            f (j+1)

    -- For every row u /= i, subtract (entry of u in column j) * row i.
    subtractRows !i !j = f i1
      where
        f !u | u>m = return ()
        f !u | u==i = f (u+1)
        f !u =
          do
            s <- readArray matrix (u,j)
            g s u j1
            f (u+1)
        g _ _ !j | j>n = return ()
        g !s !u !j =
          do
            ei <- readArray matrix (i,j)
            eu <- readArray matrix (u,j)
            writeArray matrix (u,j) (eu - s*ei)
            g s u (j+1)
------------------------------------------------------
-- | List-driven variant of the unchecked in-place elimination.
--
-- @m@ and @n@ are the row and column counts; @matrix@ is addressed by flat
-- row-major offsets through 'unsafeRead' / 'unsafeWrite'. Columns
-- @0 .. n-1@ are folded over with the current working row as the
-- accumulator (fold1M seeds it with the first column index, 0).
gaussElim2Unsafe m n matrix = fold1M eliminateColumn [0 .. n - 1] >> return ()
  where
    -- Flat offset of (row, col).
    at row col = row * n + col

    -- Process one column; returns the working row for the next column.
    eliminateColumn row col
      | row == m = return row
      | otherwise = do
          (pivotRow, pivotVal) <- findPivot row col
          if nearZero pivotVal
            then return row
            else do
              swapRows row pivotRow
              divideRow row pivotVal
              forM_ [0 .. m - 1] $ \other -> do
                factor <- unsafeRead matrix (at other col)
                subtractRow row (factor, other)
              return (row + 1)

    -- Row in row..m-1 whose entry in this column has the largest magnitude
    -- (later rows win ties), paired with that entry.
    findPivot row col = do
      let keepLarger ra rb = do
            ea <- unsafeRead matrix (at ra col)
            eb <- unsafeRead matrix (at rb col)
            return (if abs ea > abs eb then ra else rb)
      best <- fold1M keepLarger [row .. m - 1]
      bestVal <- unsafeRead matrix (at best col)
      return (best, bestVal)

    -- Exchange two distinct rows element by element.
    swapRows ra rb = when (ra /= rb) $ forM_ [0 .. n - 1] $ \c -> do
      ea <- unsafeRead matrix (at ra c)
      eb <- unsafeRead matrix (at rb c)
      unsafeWrite matrix (at ra c) eb
      unsafeWrite matrix (at rb c) ea

    -- Subtract s * row ra from row rb (no-op when they are the same row).
    subtractRow ra (s, rb) = when (ra /= rb) $ forM_ [0 .. n - 1] $ \c -> do
      ea <- unsafeRead matrix (at ra c)
      eb <- unsafeRead matrix (at rb c)
      unsafeWrite matrix (at rb c) (eb - s * ea)

    -- Divide every entry of a row by s.
    divideRow row s = forM_ [0 .. n - 1] $ \c -> do
      e <- unsafeRead matrix (at row c)
      unsafeWrite matrix (at row c) (e / s)
----------------------------------------------------------------------
-- | List-driven variant of the bounds-checked in-place elimination.
--
-- The second argument is the array's bounds @((i1,j1),(m,n))@. Columns
-- @j1 .. n@ are folded over with the current working row as the
-- accumulator, starting at the first row @i1@.
gaussElim2Safe matrix ((i1,j1),(m,n)) =
  do
    -- Seed the working row with the row lower bound. Seeding it from the
    -- column list (as fold1M would) only works when i1 == j1.
    _ <- foldM doColumn i1 [j1..n]
    return () --matrix
  where
    -- Process one column; returns the working row for the next column.
    doColumn i j | i > m = return i
    doColumn i j =
      do (pivotRow,pivotVal) <- findPivot i j
         if nearZero pivotVal
           then return i
           else do swapRows i pivotRow
                   divideRow i pivotVal
                   mapM_ (\i' -> do e <- readArray matrix (i',j); subtractRow i (e,i')) [i1..m]
                   return (i+1)
    -- Row in i..m whose entry in column j has the largest magnitude
    -- (later rows win ties), paired with that entry.
    findPivot i j =
      do pivotRow <- fold1M
           (\ ia ib ->
              do ea <- readArray matrix (ia,j)
                 eb <- readArray matrix (ib,j)
                 if abs ea > abs eb
                   then return ia
                   else return ib
           ) [i..m]
         pivotVal <- readArray matrix (pivotRow,j)
         return (pivotRow,pivotVal)
    -- exchange row(ia) and row(ib)
    swapRows ia ib = unless (ia == ib) $ mapM_ f [j1..n]
      where f j = do ea <- readArray matrix (ia,j)
                     eb <- readArray matrix (ib,j)
                     writeArray matrix (ia,j) eb
                     writeArray matrix (ib,j) ea
    -- subtract s*row(ia) from row(ib)
    subtractRow ia (s,ib) = unless (ia == ib) $ mapM_ f [j1..n]
      where f j = do ea <- readArray matrix (ia,j)
                     eb <- readArray matrix (ib,j)
                     writeArray matrix (ib,j) (eb - s*ea)
    -- divide row(i) by s
    divideRow i s = mapM_ f [j1..n]
      where f j = do e <- readArray matrix (i,j)
                     writeArray matrix (i,j) (e/s)
---------------------------------------------------------------------
-- | Monadic left fold seeded with the first list element.
--
-- Note that the whole list, including that first element, is then folded
-- over, so the step function's first call receives the head twice.
fold1M step items = foldM step seed items
  where
    seed = head items
-- | Like fold1M, but throws the final accumulator away.
fold1M_ step items = fold1M step items >>= \_ -> return ()
-- | True when the magnitude of @x@ is strictly below 1e-5.
nearZero x
  | abs x < 1e-5 = True
  | otherwise    = False
-- Number of random generators created by main; every benchmark section
-- runs once per generator.
numItrs = 100
-- | Benchmark driver.
--
-- Creates numItrs random generators and reuses the same generators for
-- every section. Multiplication sections build a 4x4 matrix per generator,
-- print it, print its square, and print the result of folding 100000
-- multiplications. Elimination sections build a matrix with makeMatrix,
-- run the elimination on it 10000 times, and print it.
main = do
  rngs <- replicateM numItrs newStdGen
  let
    -- One multiplication section: header line, then one run per generator.
    multSection title mult = do
      putStrLn title
      forM_ rngs $ \rng -> do
        let a = listArray ((1,1),(4,4)) (take 16 (randoms rng)) :: UArray (Int,Int) Double
            b = foldl' (mult 4) a (replicate 100000 a)
        print a
        print (mult 4 a a :: UArray (Int,Int) Double)
        print b
    -- One elimination section: header line, then one run per generator.
    elimSection title eliminate = do
      putStrLn title
      forM_ rngs $ \rng -> do
        a <- makeMatrix rng
        replicateM_ 10000 (eliminate a)
        printMatrix a
  multSection "mulMatrixUnsafe==================" matrixMultUnsafe
  multSection "mulMatrixSafe==================" matrixMultSafe
  elimSection "gaussElimSafe==================" gaussElimSafe
  elimSection "gaussElimUnsafe==================" gaussElimUnsafe
  elimSection "gaussElim2Safe==================" (\a -> gaussElim2Safe a ((1,1),(4,8)))
  elimSection "gaussElim2Unsafe==================" (gaussElim2Unsafe 4 8)
-- | Build a mutable matrix with bounds ((1,1),(4,8)) from the generator.
--
-- All 32 cells are first filled from @randoms rng@; columns 5..8 are then
-- overwritten so that cell (i, i+4) holds 1 and the other cells in those
-- columns hold 0.
makeMatrix rng = do
  mat <- newListArray ((1,1),(4,8)) (take 32 (randoms rng)) :: IO (IOUArray (Int,Int) Double)
  sequence_
    [ writeArray mat (r, c) (if c == r + 4 then 1 else 0)
    | r <- [1..4]
    , c <- [5..8]
    ]
  return mat
-- | Print the matrix to stdout one row per line, each entry followed by a
-- single space, and finish with one extra newline.
printMatrix :: IOUArray (Int,Int) Double -> IO ()
printMatrix m = do
  ((rowLo, colLo), (rowHi, colHi)) <- getBounds m
  forM_ [rowLo .. rowHi] $ \r ->
    forM_ [colLo .. colHi] $ \c -> do
      x <- readArray m (r, c)
      -- The last column carries the line break for its row.
      let terminator = if c == colHi then "\n" else ""
      putStr (show x ++ " " ++ terminator)
  putStr "\n"
_______________________________________________
Glasgow-haskell-users mailing list
Glasgow-haskell-users@haskell.org
http://www.haskell.org/mailman/listinfo/glasgow-haskell-users