]> Git — Sourcephile - literate-phylomemy.git/blob - src/Clustering/UnionFind/ST.hs
init
[literate-phylomemy.git] / src / Clustering / UnionFind / ST.hs
1 -- SPDX-License-Identifier: BSD-3-Clause
2 -- SPDX-FileCopyrightText: 2012 Thomas Schilling <nominolo@googlemail.com>
3 --
4 {-# OPTIONS_GHC -funbox-strict-fields #-}
5
6 -- | An implementation of Tarjan's UNION-FIND algorithm.
7 -- (Robert E Tarjan. \"Efficiency of a Good But Not Linear Set Union Algorithm\", JACM 22(2), 1975)
8 --
9 -- The algorithm implements three operations efficiently (all amortised @O(1)@):
10 --
11 -- 1. Check whether two elements are in the same equivalence class.
12 --
13 -- 2. Create a union of two equivalence classes.
14 --
15 -- 3. Look up the descriptor of the equivalence class.
16 --
17 -- The implementation is based on mutable references.
18 -- Each equivalence class has exactly one member that serves
19 -- as its representative element.
20 -- Every element either is the representative element of its equivalence class
21 -- or points to another element in the same equivalence class.
22 -- Equivalence testing thus consists of following the pointers
23 -- to the representative elements and then comparing these for identity.
24 --
25 -- The algorithm performs lazy path compression.
26 -- That is, whenever we walk along a path of length greater than 1
27 -- we automatically update the pointers along the path to directly point
28 -- to the representative element.
29 -- Consequently, future lookups of those elements will have a path length of at most 1.
30 --
31 -- Adapted from Thomas Schilling's union-find package:
32 -- https://hackage.haskell.org/package/union-find
33 module Clustering.UnionFind.ST (
34 Point,
35 fresh,
36 repr,
37 union,
38 union',
39 equivalent,
40 redundant,
41 descriptor,
42 setDescriptor,
43 modifyDescriptor,
44 )
45 where
46
47 import Control.Applicative
48 import Control.Monad (Monad (..), when)
49 import Control.Monad.ST
50 import Data.Bool (Bool (..))
51 import Data.Eq (Eq (..))
52 import Data.Function (($))
53 import Data.Int (Int)
54 import Data.Ord (Ord (..))
55 import Data.STRef
56 import Prelude (error, (+))
57
-- | The abstract type of an element of the sets we work on.
-- It is parameterised over the 'ST' state thread @s@ and over the
-- type @a@ of the descriptor.
--
-- A point is a mutable cell holding a 'Link': either the class's
-- 'Info' (when the point is its representative) or a pointer to
-- another point of the same class.
newtype Point s a = MkPoint (STRef s (Link s a))
  deriving Eq -- Pointer equality on STRef
62
-- TODO: unpack Info
-- as in https://github.com/haskell/cabal/blob/8815e0a3e76e05cac91b8a88ce7d590afb07ef71/Cabal/src/Distribution/Utils/UnionFind.hs

-- | Content of a point's mutable cell.
-- 'repr' follows 'Link's until it reaches a point whose cell holds 'Info'.
data Link s a
  = -- | This point is the representative (descriptive element) of its
    -- equivalence class; holds the reference to the class's 'Info'.
    Info {-# UNPACK #-} !(STRef s (Info a))
  | -- | Pointer to some other element of the same equivalence class
    -- (not necessarily its representative).
    Link {-# UNPACK #-} !(Point s a)
  deriving (Eq) -- Compares the contained references by identity.
71
-- | Extract the 'Info' reference from the link of a representative point.
--
-- Only meaningful on the link of a point returned by 'repr'. Reaching the
-- 'Link' case means an internal invariant was broken, so it fails with a
-- message naming the function and the cause, rather than a bare tag.
unInfo :: Link s a -> STRef s (Info a)
unInfo (Info x) = x
unInfo (Link _) =
  error "Clustering.UnionFind.ST.unInfo: point is not a representative (found a Link)"
76
-- | Per-class data, reached through the 'Info' link of the class's
-- representative point.
data Info a = MkInfo
  { weight :: {-# UNPACK #-} !Int
  -- ^ The size of the equivalence class: 1 for a 'fresh' point and the sum
  -- of both sizes after a merge; used by 'union'' to decide which root
  -- becomes the child.
  , descr :: a
  -- ^ The descriptor of the equivalence class. This field is lazy.
  }
  deriving (Eq)
83
-- | /O(1)/. Allocate a new point carrying the given descriptor.
-- The point is the sole member (and hence the representative) of a
-- brand-new equivalence class of size 1.
fresh :: a -> ST s (Point s a)
fresh desc =
  newSTRef MkInfo{weight = 1, descr = desc} >>= \infoRef ->
    MkPoint <$> newSTRef (Info infoRef)
92
-- | /O(1)/ amortised. @repr point@ returns the representative point of
-- @point@'s equivalence class.
--
-- This function performs path compression. After it returns, every point
-- visited on the way (other than the representative) links directly to
-- the representative.
repr :: Point s a -> ST s (Point s a)
repr point@(MkPoint ref) = readSTRef ref >>= follow
  where
    -- A point whose cell holds 'Info' is its own representative.
    follow (Info _) = return point
    follow (Link parent@(MkPoint parentRef)) = do
      root <- repr parent
      -- If @parent@ is not the root, the recursive call has already made
      -- @parent@'s cell hold @Link root@. Copy that link into @point@ so
      -- that later lookups skip @parent@.
      when (root /= parent) $
        readSTRef parentRef >>= writeSTRef ref
      return root
113
-- | Return the reference holding the 'Info' (size and descriptor) of the
-- point's equivalence class.
--
-- Two cases are answered without calling 'repr': the point is itself the
-- representative, or it links directly to the representative. Longer
-- paths go through 'repr', which also compresses them.
descrRef :: Point s a -> ST s (STRef s (Info a))
descrRef point@(MkPoint ref) = readSTRef ref >>= atPoint
  where
    -- The point itself is the representative.
    atPoint (Info infoRef) = return infoRef
    atPoint (Link (MkPoint parentRef)) = readSTRef parentRef >>= atParent
    -- The parent is the representative (path length 1).
    atParent (Info infoRef) = return infoRef
    atParent (Link _) = repr point >>= descrRef
126
-- | /O(1)/. Read the descriptor currently attached to the equivalence
-- class of the given point.
descriptor :: Point s a -> ST s a
descriptor point = do
  infoRef <- descrRef point
  info <- readSTRef infoRef
  return (descr info)
131
-- | /O(1)/. Replace the descriptor of the point's equivalence class
-- with the second argument. The class size is left unchanged.
setDescriptor :: Point s a -> a -> ST s ()
setDescriptor point new_descr = do
  r <- descrRef point
  -- Strict modify: the updated 'Info' is evaluated to WHNF before being
  -- stored. With the lazy variant, each call would wrap the previous,
  -- still-unevaluated record in another update thunk (a space leak).
  modifySTRef' r $ \i -> i{descr = new_descr}
138
-- | /O(1)/. Apply a function to the descriptor of the point's equivalence
-- class. The class size is left unchanged.
--
-- The 'Info' record is written strictly, so repeated calls do not build
-- a chain of record-update thunks. The descriptor field itself is lazy,
-- so @f@ is applied only when the descriptor is demanded.
modifyDescriptor :: Point s a -> (a -> a) -> ST s ()
modifyDescriptor point f = do
  r <- descrRef point
  modifySTRef' r $ \i -> i{descr = f (descr i)}
143
-- | /O(1)/ amortised. Merge the equivalence classes of the two points.
-- The merged class keeps the descriptor of the second argument's class.
-- If both points are already in the same class, nothing changes.
union :: Point s a -> Point s a -> ST s ()
union p1 p2 = union' p1 p2 keepSecond
  where
    keepSecond _ d = return d
149
-- | Like 'union', but the merged class gets the descriptor returned by the
-- callback.
--
-- The callback receives the descriptor of the first argument's class,
-- then that of the second. The intention is to keep the second descriptor,
-- possibly adjusted, but the callback may also perform other 'ST' effects.
-- It runs only when the points are in different classes; otherwise
-- nothing happens.
union' :: Point s a -> Point s a -> (a -> a -> ST s a) -> ST s ()
union' p1 p2 update = do
  root1@(MkPoint linkRef1) <- repr p1
  root2@(MkPoint linkRef2) <- repr p2
  -- Linking a root below itself would create a cycle, so same-class
  -- points are left alone.
  when (root1 /= root2) $ do
    infoRef1 <- unInfo <$> readSTRef linkRef1
    infoRef2 <- unInfo <$> readSTRef linkRef2
    MkInfo size1 d1 <- readSTRef infoRef1
    MkInfo size2 d2 <- readSTRef infoRef2
    merged <- update d1 d2
    let total = size1 + size2
    -- Union by size: hang the smaller tree below the larger root. This
    -- lengthens paths by one for the fewer elements, which is cheaper if
    -- all elements are accessed equally often. On a tie, root1 stays root.
    if size2 <= size1
      then do
        writeSTRef linkRef2 (Link root1)
        writeSTRef infoRef1 (MkInfo total merged)
      else do
        writeSTRef linkRef1 (Link root2)
        writeSTRef infoRef2 (MkInfo total merged)
178
-- | /O(1)/ amortised. Return @True@ if both points belong to the same
-- equivalence class, i.e. if they have the same representative.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent p1 p2 = do
  root1 <- repr p1
  root2 <- repr p2
  return (root1 == root2)
183
-- | /O(1)/. Return @False@ for the point that currently represents its
-- equivalence class and @True@ for every other point. Hence, if
-- @ps = [p1, .., pn]@ are all in the same equivalence class, the
-- following assertion holds.
--
-- > do rs <- mapM redundant ps
-- > assert (length (filter (==False) rs) == 1)
--
-- Which element is the representative is not specified, and it may
-- change after a 'union', so be really careful when using this.
redundant :: Point s a -> ST s Bool
redundant (MkPoint link_r) = isLink <$> readSTRef link_r
  where
    isLink (Link _) = True
    isLink (Info _) = False