{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# OPTIONS_GHC -fno-warn-tabs #-}

module Data.TreeMap.Strict.Zipper where

import           Control.Monad (Monad(..), (>=>))
import           Data.Data (Data)
import           Data.Eq (Eq)
import           Data.Function (($), (.))
import qualified Data.List as List
import           Data.List.NonEmpty (NonEmpty(..))
import qualified Data.Map.Strict as Map
import           Data.Maybe (Maybe(..), maybe, maybeToList)
import           Data.Ord (Ord(..))
import           Data.Typeable (Typeable)
import           Prelude (Num(..))
import           Text.Show (Show(..))

import           Data.TreeMap.Strict (TreeMap(..))
import qualified Data.TreeMap.Strict as TreeMap

-- * Type 'Zipper'

-- | A cursor into a 'TreeMap': the subtree currently in focus, plus the
--   steps needed to climb back up to the root.
data Zipper k x = Zipper
 { zipper_path :: [Zipper_Step k x]
   -- ^ Steps from the focus up to the root, innermost step first;
   --   the empty list means the focus is the root itself.
 , zipper_curr :: TreeMap k x
   -- ^ The subtree currently in focus.
 } deriving (Eq, Show, Data, Typeable)

-- | Start a 'Zipper' focused on the root of the given 'TreeMap'
--   (no steps in the path).
zipper :: TreeMap k x -> Zipper k x
zipper tree =
	Zipper
	 { zipper_path = []
	 , zipper_curr = tree
	 }

-- | Climb with 'zipper_parent' until the path is exhausted and return
--   the whole 'TreeMap' found there.
zipper_root :: Ord k => Zipper k x -> TreeMap k x
zipper_root z =
	case zipper_parent z of
	 Nothing -> zipper_curr z
	 Just up -> zipper_root up

-- * Type 'Zipper_Step'

-- | One level of a 'Zipper' path: the focused entry of that level,
--   together with the entries of the same level split around its key.
data Zipper_Step k x = Zipper_Step
 { zipper_step_prec :: TreeMap k x
   -- ^ Entries of the level whose keys come before the focused key.
 , zipper_step_self :: (k, TreeMap.Node k x)
   -- ^ Key and node of the focused entry.
 , zipper_step_foll :: TreeMap k x
   -- ^ Entries of the level whose keys come after the focused key.
 } deriving (Eq, Show, Data, Typeable)

-- * Axis

-- | Follow a step function from a starting point: the result is the
--   starting point followed by every value reached by repeatedly
--   applying the step, stopping at the first 'Nothing'.
--   The list is produced lazily, so an unbounded step yields an
--   infinite list.
zipper_collect :: (z -> Maybe z) -> z -> [z]
zipper_collect step = go
	where
		go z = z : case step z of
			 Nothing   -> []
			 Just next -> go next

-- | Like 'zipper_collect', but without the starting point: only the
--   values reached by one or more applications of the step function,
--   stopping at the first 'Nothing'.
zipper_collect_without_self :: (z -> Maybe z) -> z -> [z]
zipper_collect_without_self step z =
	case step z of
	 Nothing   -> []
	 Just next -> next : zipper_collect_without_self step next

-- ** Axis self

-- | Key and 'TreeMap.Node' of the focused entry, as saved in the
--   innermost 'Zipper_Step'; 'Nothing' when the path is empty (root).
zipper_self :: Zipper k x -> Maybe (k, TreeMap.Node k x)
zipper_self Zipper{zipper_path} =
	case zipper_path of
	 [] -> Nothing
	 Zipper_Step{zipper_step_self = self} : _ -> Just self

-- ** Axis child

-- | All children of the focus, starting from the smallest key and
--   moving right with 'zipper_foll'.
zipper_child :: Ord k => Zipper k x -> [Zipper k x]
zipper_child z =
	case zipper_child_first z of
	 Nothing     -> []
	 Just eldest -> zipper_collect zipper_foll eldest

-- | Focus on the child stored under key @k@, splitting the current level
--   into the entries before and after it; 'Nothing' if @k@ is absent.
zipper_child_at :: Ord k => k -> Zipper k x -> Maybe (Zipper k x)
zipper_child_at k (Zipper path (TreeMap m)) =
	found >>= \self ->
	Just Zipper
	 { zipper_path = Zipper_Step (TreeMap before) (k, self) (TreeMap after) : path
	 , zipper_curr = TreeMap.node_descendants self
	 }
	where
		(before, found, after) = Map.splitLookup k m

-- | Focus on the child with the smallest key (nothing precedes it);
--   'Nothing' when the focus has no children.
zipper_child_first :: Zipper k x -> Maybe (Zipper k x)
zipper_child_first (Zipper path (TreeMap m)) = do
	((key, child), after) <- Map.minViewWithKey m
	Just Zipper
	 { zipper_path = Zipper_Step TreeMap.empty (key, child) (TreeMap after) : path
	 , zipper_curr = TreeMap.node_descendants child
	 }

-- | Focus on the child with the largest key (nothing follows it);
--   'Nothing' when the focus has no children.
zipper_child_last :: Zipper k x -> Maybe (Zipper k x)
zipper_child_last (Zipper path (TreeMap m)) = do
	((key, child), before) <- Map.maxViewWithKey m
	Just Zipper
	 { zipper_path = Zipper_Step (TreeMap before) (key, child) TreeMap.empty : path
	 , zipper_curr = TreeMap.node_descendants child
	 }

-- ** Axis ancestor

-- | Proper ancestors of the focus, from the parent up to the root.
zipper_ancestor :: Ord k => Zipper k x -> [Zipper k x]
zipper_ancestor z = maybe [] zipper_ancestor_or_self (zipper_parent z)

-- | The focus itself, then its ancestors from the parent up to the root.
zipper_ancestor_or_self :: Ord k => Zipper k x -> [Zipper k x]
zipper_ancestor_or_self z =
	z : maybe [] zipper_ancestor_or_self (zipper_parent z)

-- ** Axis descendant

-- | The focus followed by all of its descendants in pre-order: each
--   entry comes before its children, and siblings are visited from the
--   smallest key to the largest. The rest of the output is threaded
--   through as an accumulator, so no nested list appends are built.
zipper_descendant_or_self :: Ord k => Zipper k x -> [Zipper k x]
zipper_descendant_or_self = subtree []
	where
		-- The given entry, then its descendants, then @rest@.
		subtree rest z =
			z : case zipper_child_first z of
				 Nothing    -> rest
				 Just child -> siblings rest child
		-- The subtree of the given entry, then the subtrees of each of
		-- its following siblings in order, then @rest@.
		siblings rest z =
			subtree later z
			where
				later = case zipper_foll z of
					 Nothing   -> rest
					 Just next -> siblings rest next

-- | The focus, followed by the subtrees of its children taken from the
--   largest key down to the smallest. Each of those subtrees is listed
--   the same way: the entry first, then its children in reverse order.
zipper_descendant_or_self_reverse :: Ord k => Zipper k x -> [Zipper k x]
zipper_descendant_or_self_reverse z =
	z : (List.reverse (zipper_child z) >>= zipper_descendant_or_self_reverse)

-- | All descendants of the focus, excluding the focus itself, in the
--   same pre-order as 'zipper_descendant_or_self'.
zipper_descendant :: Ord k => Zipper k x -> [Zipper k x]
zipper_descendant z = List.drop 1 (zipper_descendant_or_self z)

-- | Walk down a non-empty path of keys with 'zipper_child_at', one key
--   per level; 'Nothing' as soon as one of the keys is missing.
zipper_descendant_at :: Ord k => TreeMap.Path k -> Zipper k x -> Maybe (Zipper k x)
zipper_descendant_at (k:|ks) =
	List.foldr (\key down -> zipper_child_at key >=> down) Just (k:ks)

-- ** Axis preceding

-- | Move the focus to the previous sibling, i.e. the entry with the
--   largest key among those before the focused key on the same level.
--   'Nothing' at the root or when there is no preceding entry.
--
--   NOTE(review): the entry being left is put back from the node saved
--   in the 'Zipper_Step', not rebuilt from 'zipper_curr', so changes
--   made to 'zipper_curr' are not carried over (unlike 'zipper_parent').
zipper_prec :: Ord k => Zipper k x -> Maybe (Zipper k x)
zipper_prec Zipper{zipper_path = path} =
	case path of
	 [] -> Nothing
	 Zipper_Step (TreeMap before) (key, self) (TreeMap after) : up -> do
		((key', prev), before') <- Map.maxViewWithKey before
		Just Zipper
		 { zipper_path =
			Zipper_Step (TreeMap before') (key', prev) (TreeMap (Map.insert key self after))
			: up
		 , zipper_curr = TreeMap.node_descendants prev
		 }

-- | Entries before the focus that are not its ancestors. For the focus
--   and then each ancestor (nearest first), each of its preceding
--   siblings (nearest first) is listed together with its subtree, as
--   produced by 'zipper_descendant_or_self_reverse'.
zipper_preceding :: Ord k => Zipper k x -> [Zipper k x]
zipper_preceding z =
	[ d
	| a <- zipper_ancestor_or_self z
	, s <- zipper_preceding_sibling a
	, d <- zipper_descendant_or_self_reverse s
	]

-- | Preceding siblings of the focus, nearest first (descending keys).
zipper_preceding_sibling :: Ord k => Zipper k x -> [Zipper k x]
zipper_preceding_sibling z =
	case zipper_prec z of
	 Nothing   -> []
	 Just prev -> prev : zipper_preceding_sibling prev

-- ** Axis following

-- | Move the focus to the next sibling, i.e. the entry with the smallest
--   key among those after the focused key on the same level.
--   'Nothing' at the root or when there is no following entry.
--
--   NOTE(review): the entry being left is put back from the node saved
--   in the 'Zipper_Step', not rebuilt from 'zipper_curr', so changes
--   made to 'zipper_curr' are not carried over (unlike 'zipper_parent').
zipper_foll :: Ord k => Zipper k x -> Maybe (Zipper k x)
zipper_foll Zipper{zipper_path = path} =
	case path of
	 [] -> Nothing
	 Zipper_Step (TreeMap before) (key, self) (TreeMap after) : up -> do
		((key', next), after') <- Map.minViewWithKey after
		Just Zipper
		 { zipper_path =
			Zipper_Step (TreeMap (Map.insert key self before)) (key', next) (TreeMap after')
			: up
		 , zipper_curr = TreeMap.node_descendants next
		 }

-- | Entries after the focus that are not its descendants. For the focus
--   and then each ancestor (nearest first), each of its following
--   siblings (nearest first) is listed together with its whole subtree
--   in pre-order.
zipper_following :: Ord k => Zipper k x -> [Zipper k x]
zipper_following z =
	[ d
	| a <- zipper_ancestor_or_self z
	, s <- zipper_following_sibling a
	, d <- zipper_descendant_or_self s
	]

-- | Following siblings of the focus, nearest first (ascending keys).
zipper_following_sibling :: Ord k => Zipper k x -> [Zipper k x]
zipper_following_sibling z =
	case zipper_foll z of
	 Nothing   -> []
	 Just next -> next : zipper_following_sibling next

-- ** Axis parent

-- | Move the focus up one level and return a 'Zipper' on the parent.
--   'Nothing' when the path is empty (already at the root).
--
--   The parent's level is rebuilt from three parts of the innermost
--   'Zipper_Step': the preceding entries, the following entries, and
--   the focused entry. That entry is rebuilt with the current
--   'zipper_curr' as its descendants, so changes made to 'zipper_curr'
--   are kept.
zipper_parent :: Ord k => Zipper k x -> Maybe (Zipper k x)
zipper_parent (Zipper path curr) =
	case path of
	 [] -> Nothing
	 Zipper_Step (TreeMap ps) (k, s) (TreeMap fs):steps ->
		let
			-- Part of the saved node's size not explained by its original
			-- descendants, i.e. the contribution of the node's own value
			-- when 'TreeMap.node_size' counts it. Previously the rebuilt
			-- size was only @TreeMap.size curr@, which dropped that part
			-- for nodes carrying a value. When the difference is 0 this
			-- gives exactly the old result.
			self_size =
				TreeMap.node_size s -
				TreeMap.size (TreeMap.node_descendants s)
			node = TreeMap.Node
			 { TreeMap.node_value       = TreeMap.node_value s
			 , TreeMap.node_size        = self_size + TreeMap.size curr
			 , TreeMap.node_descendants = curr
			 } in
		Just Zipper
		 { zipper_path = steps
		 , zipper_curr = TreeMap $ Map.union ps $ Map.insert k node fs
		 }