{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
module Text.EditDistance.STUArray (
levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
) where
import Text.EditDistance.EditCosts
import Text.EditDistance.MonadUtilities
import Text.EditDistance.ArrayUtilities
import Control.Monad hiding (foldM)
import Control.Monad.ST
import Data.Array.ST
-- | Edit distance between two strings under the given 'EditCosts'.
--
-- Measures both strings with 'length' and delegates to
-- 'levenshteinDistanceWithLengths'.
levenshteinDistance :: EditCosts -> String -> String -> Int
levenshteinDistance !costs str1 str2 =
    levenshteinDistanceWithLengths costs str1_len str2_len str1 str2
  where
    str1_len :: Int
    str1_len = length str1

    str2_len :: Int
    str2_len = length str2
-- | Variant of the Levenshtein distance for callers that already know the
-- lengths of both strings.
--
-- Arguments, in order: the edit costs, the length of the first string, the
-- length of the second string, then the two strings themselves. The work is
-- done by 'levenshteinDistanceST' inside 'runST'.
--
-- NOTE(review): the lengths are used as array and loop bounds by the ST
-- worker, which reads with unchecked accessors; presumably they must equal
-- @length str1@ and @length str2@ -- confirm against 'stringToArray'.
levenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 =
    runST (levenshteinDistanceST costs str1_len str2_len str1 str2)
-- | ST worker behind 'levenshteinDistanceWithLengths'.
--
-- Costs are indexed by @(i, j)@: @i@ is the column, ranging over characters
-- of @str1@, and @j@ is the row, ranging over characters of @str2@. Only two
-- rows are kept; they swap roles after every row.
--
-- The explicit @forall s.@ (with @ScopedTypeVariables@) makes the @ST s@
-- annotations in the body refer to this signature's @s@.
levenshteinDistanceST :: forall s. EditCosts -> Int -> Int -> String -> String -> ST s Int
levenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Copy both strings into arrays so characters can be fetched by index.
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Two row buffers, each with one slot per column 0..str1_len.
    start_cost_row  <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
    start_cost_row' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array

    -- Cell (0, 0): nothing has been deleted or inserted yet. The fold below
    -- starts writing at index 1 and 'newArray_' does not promise any
    -- particular initial contents, so this cell is written explicitly.
    unsafeWriteArray start_cost_row 0 0

    -- First row (j = 0): running total of deletion costs over str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char ->
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        in unsafeWriteArray start_cost_row i deletion_cost' >> return (i + 1, deletion_cost')

    -- Remaining rows (j >= 1). The accumulator carries the running insertion
    -- cost for column 0, the previous row and the row to be filled next.
    (_, final_row, _) <- (\f -> foldM f (0, start_cost_row, start_cost_row') [1..str2_len]) $ \(!insertion_cost, !cost_row, !cost_row') !j -> do
        row_char <- read_str2 j

        -- Column 0 of this row: running total of insertion costs over str2.
        let insertion_cost' = insertion_cost + insertionCost costs row_char
        unsafeWriteArray cost_row' 0 insertion_cost'

        -- Columns 1..str1_len of this row.
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i

            left_up <- unsafeReadArray cost_row  (i - 1)
            left    <- unsafeReadArray cost_row' (i - 1)
            here_up <- unsafeReadArray cost_row  i
            let here = standardCosts costs row_char col_char left left_up here_up
            unsafeWriteArray cost_row' i here

        -- Swap buffers: the row just written becomes the previous row.
        return (insertion_cost', cost_row', cost_row)

    -- The answer is the last column of the last row that was written.
    unsafeReadArray final_row str1_len
-- | Edit distance between two strings under the given 'EditCosts', using the
-- variant that also prices a swap of two adjacent characters via
-- 'transpositionCost'.
--
-- Measures both strings with 'length' and delegates to
-- 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: EditCosts -> String -> String -> Int
restrictedDamerauLevenshteinDistance !costs str1 str2 =
    restrictedDamerauLevenshteinDistanceWithLengths costs str1_len str2_len str1 str2
  where
    str1_len :: Int
    str1_len = length str1

    str2_len :: Int
    str2_len = length str2
-- | Variant of 'restrictedDamerauLevenshteinDistance' for callers that
-- already know the lengths of both strings.
--
-- Arguments, in order: the edit costs, the length of the first string, the
-- length of the second string, then the two strings themselves. The work is
-- done by 'restrictedDamerauLevenshteinDistanceST' inside 'runST'.
--
-- NOTE(review): the lengths are used as array and loop bounds by the ST
-- worker, which reads with unchecked accessors; presumably they must equal
-- @length str1@ and @length str2@ -- confirm against 'stringToArray'.
restrictedDamerauLevenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 =
    runST (restrictedDamerauLevenshteinDistanceST costs str1_len str2_len str1 str2)
-- | ST worker behind 'restrictedDamerauLevenshteinDistanceWithLengths'.
--
-- Costs are indexed by @(i, j)@: @i@ is the column, ranging over characters
-- of @str1@, and @j@ is the row, ranging over characters of @str2@. Three
-- row buffers are rotated, because the transposition case looks two rows
-- back.
--
-- The explicit @forall s.@ (with @ScopedTypeVariables@) makes the @ST s@
-- annotations in the body and in the local helper refer to this signature's
-- @s@.
restrictedDamerauLevenshteinDistanceST :: forall s. EditCosts -> Int -> Int -> String -> String -> ST s Int
restrictedDamerauLevenshteinDistanceST !costs str1_len str2_len str1 str2 = do
    -- Copy both strings into arrays so characters can be fetched by index.
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Buffer for the first row, with one slot per column 0..str1_len.
    cost_row <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array

    -- Cell (0, 0): nothing has been deleted or inserted yet. The fold below
    -- starts writing at index 1 and 'newArray_' does not promise any
    -- particular initial contents, so this cell is written explicitly.
    unsafeWriteArray cost_row 0 0

    -- First row (j = 0): running total of deletion costs over str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char ->
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        in unsafeWriteArray cost_row i deletion_cost' >> return (i + 1, deletion_cost')

    if str2_len == 0
      then unsafeReadArray cost_row str1_len
      else do
        -- The two further buffers are only needed on this branch.
        cost_row'  <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
        cost_row'' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

        -- Second row (j = 1). No transposition is considered here; only the
        -- standard costs are used.
        row_char <- read_str2 1

        -- Column 0 of the second row: cost of inserting the first character
        -- of str2.
        let zero = insertionCost costs row_char
        unsafeWriteArray cost_row' 0 zero

        -- Columns 1..str1_len of the second row.
        loopM_ 1 str1_len (firstRowColWorker read_str1 row_char cost_row cost_row')

        -- Remaining rows (j >= 2); the row worker rotates the three buffers.
        (_, _, final_row, _, _) <-
            foldM
                (restrictedDamerauLevenshteinDistanceSTRowWorker costs str1_len read_str1 read_str2)
                (zero, cost_row, cost_row', cost_row'', row_char)
                [2..str2_len]

        -- The answer is the last column of the most recently written row.
        unsafeReadArray final_row str1_len
  where
    -- Fills one cell of the second row from its left, upper-left and upper
    -- neighbours using 'standardCosts'.
    {-# INLINE firstRowColWorker #-}
    firstRowColWorker :: (Int -> ST s Char) -> Char -> STUArray s Int Int -> STUArray s Int Int -> Int -> ST s ()
    firstRowColWorker read_str1 !row_char !cost_row !cost_row' !i = do
        col_char <- read_str1 i

        left_up <- unsafeReadArray cost_row  (i - 1)
        left    <- unsafeReadArray cost_row' (i - 1)
        here_up <- unsafeReadArray cost_row  i
        let here = standardCosts costs row_char col_char left left_up here_up
        unsafeWriteArray cost_row' i here
-- | Fills one row @j@ (for @j >= 2@) of the restricted Damerau-Levenshtein
-- cost matrix and rotates the row buffers.
--
-- Arguments, in order: the edit costs; the length of @str1@; accessors for
-- characters of @str1@ and @str2@; the accumulator; and the row index @j@.
--
-- The accumulator is @(insertion_cost, cost_row, cost_row', cost_row'',
-- prev_row_char)@:
--
-- * @insertion_cost@ is the value held in column 0 of the previous row;
-- * @cost_row@ is the row two steps back and @cost_row'@ the previous row;
-- * @cost_row''@ is the buffer this call writes into;
-- * @prev_row_char@ is the @str2@ character of the previous row.
--
-- The result has the same layout with the buffers rotated one step, so the
-- row written here becomes the third component.
--
-- The explicit @forall s.@ (with @ScopedTypeVariables@) lets the local
-- @colWorker@ signature refer to the same @s@.
{-# INLINE restrictedDamerauLevenshteinDistanceSTRowWorker #-}
restrictedDamerauLevenshteinDistanceSTRowWorker
    :: forall s.
       EditCosts
    -> Int
    -> (Int -> ST s Char)
    -> (Int -> ST s Char)
    -> (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char)
    -> Int
    -> ST s (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char)
restrictedDamerauLevenshteinDistanceSTRowWorker !costs !str1_len read_str1 read_str2 (!insertion_cost, !cost_row, !cost_row', !cost_row'', !prev_row_char) !j = do
    row_char <- read_str2 j

    -- Column 0 of this row: running total of insertion costs over str2.
    zero_up <- unsafeReadArray cost_row' 0
    let insertion_cost' = insertion_cost + insertionCost costs row_char
    unsafeWriteArray cost_row'' 0 insertion_cost'

    -- Column 1 has no column two to its left, so no transposition can apply
    -- there; it is filled with the standard costs only.
    when (str1_len > 0) $ do
        col_char <- read_str1 1
        one_up <- unsafeReadArray cost_row' 1
        let one = standardCosts costs row_char col_char insertion_cost' zero_up one_up
        unsafeWriteArray cost_row'' 1 one

    -- Columns 2..str1_len, where a transposition may apply.
    loopM_ 2 str1_len (colWorker row_char)

    -- Rotate: previous -> two back, current -> previous, oldest -> scratch.
    return (insertion_cost', cost_row', cost_row'', cost_row, row_char)
  where
    -- Fills cell i (for i >= 2) of the current row.
    colWorker :: Char -> Int -> ST s ()
    colWorker !row_char !i = do
        prev_col_char <- read_str1 (i - 1)
        col_char <- read_str1 i

        left_left_up_up <- unsafeReadArray cost_row   (i - 2)
        left_up         <- unsafeReadArray cost_row'  (i - 1)
        left            <- unsafeReadArray cost_row'' (i - 1)
        here_up         <- unsafeReadArray cost_row'  i
        let here_standard_only = standardCosts costs row_char col_char left left_up here_up
            -- A transposition is a candidate only when the last two
            -- characters of each string are the same pair in swapped order.
            here = if prev_row_char == col_char && prev_col_char == row_char
                     then here_standard_only `min` (left_left_up_up + transpositionCost costs col_char row_char)
                     else here_standard_only
        unsafeWriteArray cost_row'' i here
-- | Cheapest way to reach a cell using only deletion, insertion or
-- substitution.
--
-- Arguments, in order: the edit costs; the character for the cell's row; the
-- character for the cell's column; then the costs already computed for the
-- left, upper-left and upper neighbours.
--
-- Returns the minimum of:
--
-- * left neighbour plus the cost of deleting the column character;
-- * upper neighbour plus the cost of inserting the row character;
-- * upper-left neighbour plus the substitution cost, which is 0 when the two
--   characters are equal.
{-# INLINE standardCosts #-}
standardCosts :: EditCosts -> Char -> Char -> Int -> Int -> Int -> Int
standardCosts !costs !row_char !col_char !cost_left !cost_left_up !cost_up =
    deletion_cost `min` insertion_cost `min` subst_cost
  where
    deletion_cost :: Int
    deletion_cost = cost_left + deletionCost costs col_char

    insertion_cost :: Int
    insertion_cost = cost_up + insertionCost costs row_char

    subst_cost :: Int
    subst_cost = cost_left_up + if row_char == col_char then 0 else substitutionCost costs col_char row_char