{-# LANGUAGE CPP, Trustworthy #-}

module Text.EditDistance.ArrayUtilities (
    unsafeReadArray, unsafeWriteArray,
    unsafeReadArray', unsafeWriteArray',
    stringToArray
  ) where

import Control.Monad (forM_)
import Control.Monad.ST

import Data.Array.ST
import Data.Array.Base (unsafeRead, unsafeWrite)

#ifdef __GLASGOW_HASKELL__
import GHC.Arr (unsafeIndex)
#else
import Data.Ix (index)

{-# INLINE unsafeIndex #-}
-- | Portable stand-in for GHC's @unsafeIndex@, used when not compiling with
-- GHC. It forwards straight to 'index' from "Data.Ix". That is the
-- range-checked operation, so on this code path indices are still validated.
unsafeIndex :: Ix i => (i, i) -> i -> Int
unsafeIndex bnds ix = index bnds ix
#endif


{-# INLINE unsafeReadArray #-}
-- | Read the element stored at index @i@ of a mutable array.
--
-- The index is translated with 'unsafeIndex' and read with 'unsafeRead', so
-- under GHC it is not range-checked: the caller must keep @i@ within the
-- array's bounds.
--
-- Every call looks the bounds up again (via 'unsafeReadArray''). When reading
-- the same array repeatedly, build a reader once with 'unsafeReadArray''
-- instead.
unsafeReadArray :: (MArray a e m, Ix i) => a i e -> i -> m e
unsafeReadArray marr i = do
    readAt <- unsafeReadArray' marr
    readAt i

{-# INLINE unsafeWriteArray #-}
-- | Store @e@ at index @i@ of a mutable array.
--
-- The index is translated with 'unsafeIndex' and written with 'unsafeWrite',
-- so under GHC it is not range-checked: the caller must keep @i@ within the
-- array's bounds.
--
-- Every call looks the bounds up again (via 'unsafeWriteArray''). When writing
-- the same array repeatedly, build a writer once with 'unsafeWriteArray''
-- instead.
unsafeWriteArray :: (MArray a e m, Ix i) => a i e -> i -> e -> m ()
unsafeWriteArray marr i e = do
  writeAt <- unsafeWriteArray' marr
  writeAt i e


{-# INLINE unsafeReadArray' #-}
-- | Look up the array's bounds once and return a reader for that array.
--
-- The returned function converts an index to a flat offset with
-- 'unsafeIndex' and reads it with 'unsafeRead'. Because the bounds are
-- captured when the reader is built, repeated reads skip the 'getBounds'
-- call.
--
-- Under GHC no range check is performed; the caller must only pass indices
-- within the array's bounds.
unsafeReadArray' :: (MArray a e m, Ix i) => a i e -> m (i -> m e)
unsafeReadArray' marr = do
    (l, u) <- getBounds marr
    return $ \i -> unsafeRead marr (unsafeIndex (l, u) i)

{-# INLINE unsafeWriteArray' #-}
-- | Look up the array's bounds once and return a writer for that array.
--
-- The returned function converts an index to a flat offset with
-- 'unsafeIndex' and stores the element with 'unsafeWrite'. Because the
-- bounds are captured when the writer is built, repeated writes skip the
-- 'getBounds' call.
--
-- Under GHC no range check is performed; the caller must only pass indices
-- within the array's bounds.
unsafeWriteArray' :: (MArray a e m, Ix i) => a i e -> m (i -> e -> m ())
unsafeWriteArray' marr = do
  (l, u) <- getBounds marr
  return $ \i e -> unsafeWrite marr (unsafeIndex (l, u) i) e

{-# INLINE stringToArray #-}
-- | Copy a 'String' into a fresh unboxed array with bounds @(1, str_len)@.
--
-- @str_len@ should normally equal @length str@. At most @str_len@ characters
-- are written, so a string longer than @str_len@ can never write past the end
-- of the array. If @str@ is shorter, the remaining elements keep whatever
-- contents 'newArray_' gave them. A non-positive @str_len@ yields an empty
-- array.
stringToArray :: String -> Int -> ST s (STUArray s Int Char)
stringToArray str str_len = do
    array <- newArray_ (1, str_len)
    write <- unsafeWriteArray' array
    -- The writer does not range-check, so the index list is bounded by the
    -- array size instead of being the open-ended @[1..]@.
    forM_ (zip [1 .. str_len] str) (uncurry write)
    return array

{-
showArray :: STUArray s (Int, Int) Int -> ST s String
showArray array = do
    ((il, jl), (iu, ju)) <- getBounds array
    flip (flip foldM "") [(i, j) | i <- [il..iu], j <- [jl.. ju]] $ \rest (i, j) -> do
        elt <- readArray array (i, j)
        return $ rest ++ show (i, j) ++ ": " ++ show elt ++ ", "
-}