{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.TreeSeq.Strict where

import Control.Applicative (Applicative(..))
import Control.Monad (Monad(..), ap)
import Data.Bool
import Data.Eq (Eq(..))
import Data.Foldable (Foldable(..))
import Data.Foldable (foldr)
import Data.Function (($), (.))
import Data.Functor (Functor, (<$>))
import Data.Maybe (Maybe(..))
import Data.Ord (Ord(..))
import Data.Semigroup (Semigroup(..))
import Data.Sequence (Seq, ViewL(..))
import Data.Text (Text)
import Data.Traversable (Traversable(..))
import Text.Show (Show(..))
import qualified Data.List as List
import qualified Data.Sequence as Seq
import qualified Data.Text as Text

-- * Type 'Tree'
-- | A rose tree: inner nodes carry a key of type @k@ and a (possibly
-- empty) sequence of children; leaves carry a value of type @a@.
--
-- Every field is strict, so it is forced to weak head normal form when
-- the constructor is applied.  Constructor order is significant for the
-- derived 'Ord' instance ('TreeN' sorts before 'Tree0').
data Tree k a
  = TreeN !k !(Trees k a)
    -- ^ Inner node: its key and its children.
  | Tree0 !a
    -- ^ Leaf holding a single value.
  deriving (Eq, Ord, Show, Functor)

-- | Traverses the leaf values from left to right; the keys of inner nodes
-- are carried over unchanged and contribute no effects.
instance Traversable (Tree k) where
  traverse f t =
    case t of
      Tree0 a      -> Tree0 <$> f a
      TreeN k kids -> TreeN k <$> traverse (traverse f) kids
  sequenceA t =
    case t of
      Tree0 fa     -> Tree0 <$> fa
      TreeN k kids -> TreeN k <$> traverse sequenceA kids
-- | Folds over the leaf values only, left to right; inner-node keys are
-- not visited.
instance Foldable (Tree k) where
  foldMap f t =
    case t of
      Tree0 a      -> f a
      TreeN _ kids -> foldMap (foldMap f) kids
-- | 'pure' wraps a value in a leaf.  '<*>' is obtained from the 'Monad'
-- instance: it is 'Control.Monad.ap' written out, binding the function
-- first and the argument second.
instance Applicative (Tree k) where
  pure x = Tree0 x
  tf <*> tx =
    tf >>= \g ->
    tx >>= \x ->
    pure (g x)
-- | Substitution: @t >>= f@ replaces every leaf value @v@ of @t@ with the
-- tree @f v@, keeping inner nodes and their keys in place.
--
-- 'return' is deliberately left undefined here so that it defaults to
-- 'pure' (which is 'Tree0'). Defining it separately is non-canonical and
-- is reported by GHC's @-Wnoncanonical-monad-instances@ warning.
instance Monad (Tree k) where
  t >>= f =
    case t of
      Tree0 v      -> f v
      TreeN k kids -> TreeN k ((>>= f) <$> kids)

-- | 'True' exactly when the tree is a leaf ('Tree0').
isTree0 :: Tree k a -> Bool
isTree0 t =
  case t of
    Tree0{} -> True
    TreeN{} -> False

-- | 'True' exactly when the tree is an inner node ('TreeN').
isTreeN :: Tree k a -> Bool
isTreeN t =
  case t of
    TreeN{} -> True
    Tree0{} -> False

-- | Label at the root of a tree whose keys and values share a type:
-- the key of an inner node, or the value of a leaf.
unTree :: Tree a a -> a
unTree t =
  case t of
    Tree0 v   -> v
    TreeN k _ -> k

-- | Map over the leaf values, passing the function the key of the leaf's
-- nearest enclosing inner node ('Nothing' when the leaf is the root).
-- Keys are left unchanged.
mapWithNode :: (Maybe k -> a -> b) -> Tree k a -> Tree k b
mapWithNode f = walk Nothing
  where
    walk parent node =
      case node of
        Tree0 a      -> Tree0 (f parent a)
        TreeN k kids -> TreeN k (walk (Just k) <$> kids)

-- | Map both keys (with @fk@) and leaf values (with @fv@).  @fv@ receives
-- the /original/ (unmapped) key of the leaf's nearest enclosing inner
-- node, or 'Nothing' when the leaf is the root.
mapAlsoNode :: (k -> l) -> (Maybe k -> a -> b) -> Tree k a -> Tree l b
mapAlsoNode fk fv = walk Nothing
  where
    walk parent node =
      case node of
        Tree0 a      -> Tree0 (fv parent a)
        TreeN k kids -> TreeN (fk k) (walk (Just k) <$> kids)

-- | Effectful counterpart of 'mapWithNode'.  The action for each leaf
-- receives the key of its nearest enclosing inner node ('Nothing' at the
-- root).  Actions run left to right over the leaves.
traverseWithNode :: Applicative f => (Maybe k -> a -> f b) -> Tree k a -> f (Tree k b)
traverseWithNode f = visit Nothing
  where
    visit parent node =
      case node of
        Tree0 a      -> Tree0 <$> f parent a
        TreeN k kids -> TreeN k <$> traverse (visit (Just k)) kids

-- | Left fold over every subtree, in pre-order.  @f@ is applied to a node
-- (the whole subtree rooted there) before any of its descendants, and
-- siblings are visited left to right.  Children are folded with 'foldl''.
foldlWithTree :: (b -> Tree k a -> b) -> b -> Tree k a -> b
foldlWithTree f = visit
  where
    visit acc node =
      let acc' = f acc node in
      case node of
        TreeN _ kids -> foldl' visit acc' kids
        Tree0{}      -> acc'

-- | Rewrite a tree by applying @f@ to the whole tree @t@.
--
-- * For a leaf, the result is just @f t@.
-- * For an inner node, if @f t@ is a leaf it is returned as is (the
--   original children are dropped); otherwise the original children,
--   each rewritten recursively with @f@, are appended after the children
--   of @f t@.
bindTree :: Tree k a -> (Tree k a -> Tree l b) -> Tree l b
bindTree t f =
  case t of
    Tree0{}      -> f t
    TreeN _ kids ->
      case f t of
        TreeN l ls -> TreeN l (ls <> ((\kid -> bindTree kid f) <$> kids))
        leaf       -> leaf

-- | Like 'bindTree', but @f@ produces a sequence of trees.
--
-- * For a leaf, the result is just @f t@.
-- * For an inner node, every leaf in @f t@ is kept as is.  Every inner
--   node in @f t@ gets the original children, each rewritten recursively
--   with 'bindTrees', appended after its own children.
bindTrees :: Tree k a -> (Tree k a -> Trees l b) -> Trees l b
bindTrees t f =
  case t of
    Tree0{} -> f t
    TreeN _k ks ->
      -- The rewritten children do not depend on the element of @f t@ being
      -- processed.  Compute them once and share them, instead of redoing
      -- the whole recursive rewrite for every inner node @f t@ yields.
      let kids = ks >>= (`bindTrees` f) in
      f t >>= \fs ->
        case fs of
          Tree0{}    -> Seq.singleton fs
          TreeN l ls -> Seq.singleton (TreeN l (ls <> kids))

-- | Flatten a forest whose leaves hold forests.  Each leaf is replaced,
-- in place, by the trees it holds.  Inner nodes are kept with their
-- children flattened recursively.
joinTrees :: Trees k (Trees k a) -> Trees k a
joinTrees ts = ts >>= splice
  where
    -- A named helper is used instead of a \case lambda.  This keeps the
    -- module from depending on the LambdaCase extension, which its
    -- LANGUAGE pragmas do not enable.
    splice (Tree0 s)    = s
    splice (TreeN k ks) = Seq.singleton (TreeN k (joinTrees ks))

-- * Type 'Trees'
-- | A forest: an ordered sequence of trees.
type Trees k a = Seq.Seq (Tree k a)

-- * Type 'Pretty'
-- | Wrapper around a forest whose 'Show' instance renders the trees as
-- the text produced by 'prettyTrees'.
newtype Pretty k a = Pretty (Trees k a)

instance (Show k, Show a) => Show (Pretty k a) where
  show (Pretty forest) = Text.unpack (prettyTrees forest)

-- | Render one tree as text, one line per entry of 'pretty'.  Each line,
-- including the last, is terminated by a newline.
prettyTree :: (Show k, Show a) => Tree k a -> Text
prettyTree t = Text.unlines (pretty t)

-- | Render a forest: each tree is rendered with 'prettyTree' and followed
-- by an extra newline, so consecutive trees are separated by a blank line.
-- An empty forest renders as the empty text.
prettyTrees :: (Show k, Show a) => Trees k a -> Text
prettyTrees ts =
  -- Collect the pieces into a list and concatenate them once.  Folding
  -- with strict 'Text' '<>' would copy the growing accumulator at every
  -- tree, which is quadratic in the total output size.
  Text.concat (foldr (\t chunks -> prettyTree t : "\n" : chunks) [] ts)

-- | Render a tree as a list of lines.  A leaf is its 'show'n value.  An
-- inner node is its 'show'n key followed by its children, drawn with
-- ASCII connectors:
--
-- * every child is preceded by a @"|"@ line;
-- * a non-last child's first line is prefixed with @"+- "@ and its
--   following lines with @"|  "@;
-- * the last child's first line is prefixed with @"`- "@ and its
--   following lines with three spaces.
pretty :: (Show k, Show a) => Tree k a -> [Text]
pretty tree =
  case tree of
    Tree0 a      -> [Text.pack (show a)]
    TreeN k kids -> Text.pack (show k) : drawChildren kids
  where
    drawChildren s =
      case Seq.viewl s of
        EmptyL -> []
        kid :< rest
          | Seq.null rest -> "|" : prefixLines "`- " "   " (pretty kid)
          | otherwise     -> "|" : prefixLines "+- " "|  " (pretty kid) <> drawChildren rest
    -- Prefix the first line with firstPrefix and every later line with
    -- otherPrefix.
    prefixLines firstPrefix otherPrefix =
      List.zipWith (<>) (firstPrefix : List.repeat otherPrefix)