blob: 5acc896f60f1f31ed32371edc10dae9d5259483b (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
{-# LANGUAGE BangPatterns #-}
module Data.ByteString.Search.DFA (strictSearcher) where
import qualified Data.ByteString as S
import Data.ByteString.Unsafe (unsafeIndex)
import Control.Monad (when)
import Data.Array.Base (unsafeRead, unsafeWrite, unsafeAt)
import Data.Array.ST (newArray, newArray_, runSTUArray)
import Data.Array.Unboxed (UArray)
import Data.Bits (Bits(..))
import Data.Word (Word8)
------------------------------------------------------------------------------
-- Searching Function --
------------------------------------------------------------------------------
-- | @strictSearcher overlap pat@ returns a function that lists, in ascending
-- order, the offsets at which @pat@ occurs in a strict 'S.ByteString'.
--
-- * An empty pattern matches at every offset @0 .. length target@.
-- * A one-byte pattern is delegated to 'S.elemIndices'.
-- * Longer patterns are searched with a DFA (see 'automaton').
--
-- For the DFA case, the automaton is forced before the searching function is
-- returned. A partially applied @strictSearcher overlap pat@ can therefore be
-- reused on many targets without rebuilding it.
--
-- When @overlap@ is 'True', overlapping occurrences are reported. Otherwise,
-- matching restarts from the initial state after each occurrence.
strictSearcher :: Bool -> S.ByteString -> S.ByteString -> [Int]
strictSearcher overlap pat =
  case S.length pat of
    0 -> enumFromTo 0 . S.length
    1 -> S.elemIndices $! S.head pat
    n -> buildSearcher overlap n
  where
    -- Build the DFA for a pattern of length n (n >= 2) and return the scanner
    -- that closes over it. The strict let-bindings run once, when this
    -- application is evaluated, not once per searched string.
    buildSearcher !overlapping !n =
      let !dfa = automaton pat
          !firstByte = unsafeIndex pat 0
          -- State entered after reporting a match. The accepting state n
          -- (whose row holds the border transitions) gives overlapping
          -- matches; state 0 restarts from scratch.
          !resume = if overlapping then n else 0
          scanFrom str = go 0 0
            where
              !strLen = S.length str
              byteAt :: Int -> Int
              byteAt !k = fromIntegral (unsafeIndex str k)
              -- go st k: the DFA is in state st before reading byte k.
              go !st !k
                | k == strLen = []
                | st == 0 =
                    -- Row 0 of the table only knows the first pattern byte,
                    -- so state 0 is handled by a direct comparison.
                    if unsafeIndex str k == firstByte
                      then go 1 (k + 1)
                      else go 0 (k + 1)
                | otherwise =
                    let !st' = unsafeAt dfa ((st `shiftL` 8) + byteAt k)
                        !k' = k + 1
                    in if st' == n
                         then (k' - n) : go resume k'
                         else go st' k'
      in scanFrom
------------------------------------------------------------------------------
-- Preprocessing --
------------------------------------------------------------------------------
{-# INLINE automaton #-}
-- | Transition table of the matching DFA for a pattern of length @m@. Byte 0
-- of the pattern is read unchecked, so the pattern must be non-empty.
--
-- The table has @m + 1@ rows of 256 entries. The successor of state @q@ on
-- byte @c@ is stored at index @q * 256 + c@, and absent transitions are 0.
-- Row 0 only gets the transition on the first pattern byte.
--
-- Each row @q@ (for @1 <= q <= m@) is filled by walking a chain of border
-- widths taken from 'kmpBorders'. The chain starts at @q@; for the accepting
-- row it starts at the widest border of the whole pattern. The widths
-- decrease along the chain, and an entry is written only while it is still
-- 0. So for each byte, the first (widest) border extended by that byte
-- determines the transition.
automaton :: S.ByteString -> UArray Int Int
automaton !pat = runSTUArray (do
    let !m = S.length pat
        byteAt :: Int -> Int
        byteAt !k = fromIntegral (unsafeIndex pat k)
        !borders = kmpBorders pat
    table <- newArray (0, (m + 1) * 256 - 1) 0
    unsafeWrite table (byteAt 0) 1
    let -- Walk the border chain from width j. For each width, record the
        -- move on byteAt j to state j + 1 in the row starting at rowBase,
        -- unless an earlier (wider) border already claimed that byte.
        fillRow !rowBase !j
          | j < 0 = return ()
          | otherwise = do
              let !slot = rowBase + byteAt j
              existing <- unsafeRead table slot
              when (existing == 0) (unsafeWrite table slot (j + 1))
              fillRow rowBase (unsafeAt borders j)
        -- The accepting state has no next pattern byte to extend, so its
        -- walk begins at the widest border of the full pattern.
        startOf q
          | q == m = unsafeAt borders q
          | otherwise = q
    mapM_ (\q -> fillRow (q `shiftL` 8) (startOf q)) [1 .. m]
    return table)
{-# INLINE kmpBorders #-}
-- | Strong (KMP-optimised) border table of a pattern @p@ of length @m@,
-- indexed @0 .. m@.
--
-- * Entry 0 is always @-1@.
-- * For @0 < i < m@, entry @i@ is the width of the widest border of the
--   prefix @take i p@ that can /not/ be extended to a border of
--   @take (i + 1) p@. That is, the widest border whose following byte
--   differs from byte @i@ of the pattern, or @-1@ if no such border exists.
-- * Entry @m@ is the width of the widest proper border of the whole pattern.
kmpBorders :: S.ByteString -> UArray Int Int
kmpBorders pat = runSTUArray (do
    let !m = S.length pat
        byteAt :: Int -> Word8
        byteAt = unsafeIndex pat
    tbl <- newArray_ (0, m)
    unsafeWrite tbl 0 (-1)
    let -- Fall back through the table until the border of width k can be
        -- extended by byte c, or no border is left. Returns the new width.
        extend c !k
          | k >= 0 && c /= byteAt k = unsafeRead tbl k >>= extend c
          | otherwise = return $! k + 1
        -- k is the widest border of the prefix of length i - 1. Extend it
        -- by byte i - 1 to get the widest border of the prefix of length i.
        -- Store that width unless it is extensible by byte i, in which case
        -- inherit the entry of the border instead.
        fill !i !k
          | i > m = return tbl
          | otherwise = do
              k' <- extend (byteAt (i - 1)) k
              let extensible = i < m && byteAt k' == byteAt i
              entry <- if extensible then unsafeRead tbl k' else return k'
              unsafeWrite tbl i entry
              fill (i + 1) k'
    fill 1 (-1))
|