-- For ShowLetName
{-# LANGUAGE AllowAmbiguousTypes #-}
-- For makeSharingName
{-# LANGUAGE BangPatterns #-}
-- For ShowLetName
{-# LANGUAGE DataKinds #-}
-- For SharingName
{-# LANGUAGE ExistentialQuantification #-}

-- {-# LANGUAGE MagicHash #-} -- For unsafeCoerce#

-- | This module provides the 'SharingObserver' semantic
-- which interprets combinators to observe @let@ definitions
-- at the host language level (Haskell),
-- effectively turning infinite values into finite ones,
-- which is useful for example to inspect
-- and optimize recursive grammars.
--
-- Inspired by Andy Gill's [Type-safe observable sharing in Haskell](https://doi.org/10.1145/1596638.1596653).
-- For an example, see [symantic-parser](https://hackage.haskell.org/package/symantic-parser).
module Symantic.Semantics.SharingObserver where

import Control.Applicative (Applicative (..))
import Control.Monad (Monad (..))
import Control.Monad.Trans.Class qualified as MT
import Control.Monad.Trans.Reader qualified as MT
import Control.Monad.Trans.State qualified as MT
import Control.Monad.Trans.Writer qualified as MT
import Data.Bool
import Data.Eq (Eq (..))
import Data.Function (($), (.))
import Data.Functor (Functor, (<$>))
import Data.Functor.Compose (Compose (..))
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HM
import Data.HashSet (HashSet)
import Data.HashSet qualified as HS
import Data.Hashable (Hashable, hash, hashWithSalt)
import Data.Int (Int)
import Data.Maybe (Maybe (..), isNothing)
import Data.Monoid (Monoid (..))
import Data.Ord (Ord (..))
import GHC.StableName (StableName (..), eqStableName, hashStableName, makeStableName)
import Prelude (error, (+))
import System.IO (IO)
import System.IO.Unsafe (unsafePerformIO)
import Text.Show (Show (..))

import Symantic.Syntaxes.Derive

-- * Class 'Referenceable'

-- | This class is not meant for end-users like the usual symantic operators,
-- though it has to be defined on end-users' interpreters.
class Referenceable letName sem where
  -- | @('ref' isRec letName)@ is a reference to @(letName)@.
  -- It is introduced by 'sharingObserver'.
  -- @(isRec)@ is 'True' iff this 'ref'erence is recursive,
  -- i.e. it appears within its own 'define'.
  --
  -- TODO: index 'letName' with 'a' to enable dependent-map
  ref :: Bool -> letName -> sem a
  default ref ::
    FromDerived (Referenceable letName) sem =>
    Bool ->
    letName ->
    sem a
  -- By default, build the 'ref' in the 'Derived' interpreter
  -- and lift it back into @(sem)@.
  ref isRec = liftDerived . ref isRec

-- * Class 'Definable'

-- | This class is not meant for end-users like the usual symantic operators.
-- There should be no need to use it outside this module,
-- because the 'define's which end up referenced are gathered
-- into 'LetBindings', to be given to 'lets'.
class Definable letName sem where
  -- | @('define' letName sub)@ let-binds @(letName)@ to be equal to @(sub)@.
  -- This is a temporary node, either replaced
  -- by a 'ref' plus an entry in the 'LetBindings' given to 'lets',
  -- or removed when no 'ref'erence is made to it.
  define :: letName -> sem a -> sem a
  default define ::
    FromDerived1 (Definable letName) sem =>
    letName ->
    sem a ->
    sem a
  -- By default, apply the 'Derived' interpreter's 'define'
  -- underneath 'liftDerived1'.
  define = liftDerived1 . define

-- * Class 'MakeLetName'

-- | How to build a @(letName)@ from the 'SharingName' of an observed node.
-- 'sharingObserverNode' runs this 'IO' action through 'unsafePerformIO',
-- only for a node which has not been observed before.
class MakeLetName letName where
  makeLetName :: SharingName -> IO letName

-- * Type 'SharingName'

-- | Identity of an observed node: its 'StableName',
-- with the node's type hidden existentially
-- so that nodes of different types can share the same 'HashMap'/'HashSet'.
--
-- Note that the observable sharing enabled by 'StableName'
-- is not perfect, as it will not observe all the sharing explicitly done.
--
-- Note also that the observed sharing could be different between ghc and ghci.
data SharingName = forall a. SharingName (StableName a)

-- | @('makeSharingName' x)@ is like @('makeStableName' x)@ but it also forces
-- evaluation of @(x)@ (to weak head normal form, via the bang pattern)
-- so that the 'StableName' is correct the first time,
-- which avoids producing a tree bigger than needed.
--
-- Note that this function uses 'unsafePerformIO' instead of returning in 'IO';
-- this is apparently required to avoid infinite loops due to unstable 'StableName's
-- in compiled code, and sometimes also in ghci.
--
-- Note that maybe [pseq should be used here](https://gitlab.haskell.org/ghc/ghc/-/issues/2916).
--
-- NOTE(review): keep the bang pattern; observed sharing depends on it
-- (and it is why BangPatterns is enabled at the top of this module).
makeSharingName :: a -> SharingName
makeSharingName !x = SharingName $ unsafePerformIO $ makeStableName x

-- | Two 'SharingName's are equal when their wrapped 'StableName's are,
-- whatever the (hidden) types of the named values.
instance Eq SharingName where
  (==) (SharingName lhs) (SharingName rhs) = lhs `eqStableName` rhs

-- | Hashing delegates to the wrapped 'StableName'.
instance Hashable SharingName where
  hashWithSalt salt (SharingName sn) = hashWithSalt salt sn
  hash (SharingName sn) = hashStableName sn

{-
instance Show SharingName where
  showsPrec _ (SharingName n) = showHex (I# (unsafeCoerce# n))
-}

-- * Type 'SharingObserver'

-- | First pass of 'sharingObserver'.
-- The environment holds the 'SharingName's of the enclosing nodes
-- (used by 'sharingObserverNode' to detect recursive references),
-- the state records every observed node (see 'SharingObserverState'),
-- and the result is a 'SharingFinalizer' term, run by the second pass.
newtype SharingObserver letName sem a = SharingObserver
  { unSharingObserver ::
      MT.ReaderT
        (HashSet SharingName)
        (MT.State (SharingObserverState letName))
        (SharingFinalizer letName sem a)
  }

-- | Interpreter detecting some (Haskell embedded) @let@-like definitions
-- used at least once and/or recursively, in order to replace them
-- with the 'lets' and 'ref' combinators.
-- See [Type-safe observable sharing in Haskell](https://doi.org/10.1145/1596638.1596653)
--
-- It works in two passes:
--
-- 1. the 'SharingObserver' pass starts with no enclosing node
--    and an empty 'SharingObserverState', and counts the observations of each node;
-- 2. the resulting 'SharingFinalizer' is then given the set of @letName@s
--    of the nodes observed more than once, and collects their bindings.
--
-- Beware not to apply 'sharingObserver' more than once on the same term
-- otherwise some 'define' introduced by the first call
-- would be removed by the second call.
sharingObserver ::
  Eq letName =>
  Hashable letName =>
  Show letName =>
  SharingObserver letName sem a ->
  WithSharing letName sem a
sharingObserver (SharingObserver observe) =
  -- Debugging hint: @trace (show sharedNames)@ can be inserted here.
  MT.runWriter (MT.runReaderT (unSharingFinalizer finalizer) sharedNames)
  where
    initialState =
      SharingObserverState
        { sharingObserverStateRefs = HM.empty
        , sharingObserverStateRecs = HS.empty
        }
    (finalizer, finalState) =
      MT.runState (MT.runReaderT observe mempty) initialState
    -- A count above 0 means the node was observed more than once,
    -- hence its 'define' must be kept.
    sharedNames =
      HS.fromList
        [ name
        | (name, count) <- HM.elems (sharingObserverStateRefs finalState)
        , count > 0
        ]

-- ** Type 'WithSharing'

-- | Result of 'sharingObserver': the finalized term,
-- along with the bindings of its referenced 'define's, indexed by @letName@.
type WithSharing letName sem a =
  (sem a, HM.HashMap letName (SomeLet sem))

{-
-- * Type 'WithSharing'
data WithSharing letName sem a = WithSharing
  { lets :: HM.HashMap letName (SomeLet sem)
  , body :: sem a
  }
mapWithSharing ::
  (forall v. sem v -> sem v) ->
  WithSharing letName sem a ->
  WithSharing letName sem a
mapWithSharing f ws = WithSharing
  { lets = (\(SomeLet sem) -> SomeLet (f sem)) <$> lets ws
  , body = f (body ws)
  }
-}

-- ** Type 'SharingObserverState'

-- | State threaded through the 'SharingObserver' pass.
data SharingObserverState letName = SharingObserverState
  { sharingObserverStateRefs :: HashMap SharingName (letName, Int)
  -- ^ For each observed node: its @letName@, and how many times it has been
  -- observed again after its first observation (@0@ when seen only once).
  , sharingObserverStateRecs :: HashSet SharingName
  -- ^ Nodes observed as a recursive reference (within their own 'define').
  -- NOTE(review): filled by 'sharingObserverNode' but not read in the code visible here.
  }

-- | Observe one node of the term.
--
-- The node is identified by the 'SharingName' of its computation @(m)@.
-- Its entry in 'sharingObserverStateRefs' is created on its first observation
-- (allocating a @letName@ with 'makeLetName') and incremented afterwards. Then:
--
-- * if the node is one of its own enclosing nodes, it becomes @('ref' True letName)@
--   and is recorded in 'sharingObserverStateRecs';
-- * else if it is observed for the first time, @(m)@ is run with the node
--   added to the enclosing nodes, and its result is wrapped in a 'define';
-- * else it becomes @('ref' False letName)@.
sharingObserverNode ::
  Eq letName =>
  Hashable letName =>
  Show letName =>
  Referenceable letName sem =>
  MakeLetName letName =>
  SharingObserver letName sem a ->
  SharingObserver letName sem a
sharingObserverNode (SharingObserver m) = SharingObserver $ do
  -- The 'StableName' of @(m)@ identifies this node.
  let nodeName = makeSharingName m
  st <- MT.lift MT.get
  ((letName, seenBefore), seen) <-
    getCompose $
      HM.alterF
        ( \seenBefore ->
            -- Compose is used to return (letName, seenBefore) along seen
            -- in the same HashMap lookup.
            Compose $
              return $ case seenBefore of
                -- First observation: allocate a letName with a count of 0.
                Nothing ->
                  ((letName, seenBefore), Just (letName, 0))
                  where
                    letName = unsafePerformIO $ makeLetName nodeName
                -- Later observation: reuse the letName and increment the count.
                Just (letName, refCount) ->
                  ((letName, seenBefore), Just (letName, refCount + 1))
        )
        nodeName
        (sharingObserverStateRefs st)
  parentNames <- MT.ask
  if nodeName `HS.member` parentNames
    then do
      -- recursive reference to nodeName:
      -- update seen references
      -- and mark nodeName as recursive
      MT.lift $
        MT.put
          st
            { sharingObserverStateRefs = seen
            , sharingObserverStateRecs = HS.insert nodeName (sharingObserverStateRecs st)
            }
      return $ ref True letName
    else do
      -- non-recursive reference to nodeName:
      -- update seen references
      -- and recurse if the nodeName hasn't been seen before
      -- (would be in a preceding sibling branch, not in parentNames).
      MT.lift $ MT.put st{sharingObserverStateRefs = seen}
      if isNothing seenBefore
        then MT.local (HS.insert nodeName) (define letName <$> m)
        else return $ ref False letName

-- | The 'Derived' interpreter of 'SharingObserver' is 'SharingFinalizer',
-- which 'sharingObserver' runs as its second pass.
type instance Derived (SharingObserver letName sem) = SharingFinalizer letName sem

-- | Every lifted term becomes a node observed by 'sharingObserverNode'.
instance
  ( Referenceable letName sem
  , MakeLetName letName
  , Eq letName
  , Hashable letName
  , Show letName
  ) =>
  LiftDerived (SharingObserver letName sem)
  where
  liftDerived = sharingObserverNode . SharingObserver . return
-- | The operand is observed inside the new node's computation,
-- which 'sharingObserverNode' only runs the first time it observes that node.
instance
  ( Referenceable letName sem
  , MakeLetName letName
  , Eq letName
  , Hashable letName
  , Show letName
  ) =>
  LiftDerived1 (SharingObserver letName sem)
  where
  liftDerived1 f a =
    sharingObserverNode $
      SharingObserver $
        f <$> unSharingObserver a
-- | The operands are observed in order inside the new node's computation,
-- which 'sharingObserverNode' only runs the first time it observes that node.
instance
  ( Referenceable letName sem
  , MakeLetName letName
  , Eq letName
  , Hashable letName
  , Show letName
  ) =>
  LiftDerived2 (SharingObserver letName sem)
  where
  liftDerived2 f a b =
    sharingObserverNode $
      SharingObserver $
        f
          <$> unSharingObserver a
          <*> unSharingObserver b
-- | The operands are observed in order inside the new node's computation,
-- which 'sharingObserverNode' only runs the first time it observes that node.
instance
  ( Referenceable letName sem
  , MakeLetName letName
  , Eq letName
  , Hashable letName
  , Show letName
  ) =>
  LiftDerived3 (SharingObserver letName sem)
  where
  liftDerived3 f a b c =
    sharingObserverNode $
      SharingObserver $
        f
          <$> unSharingObserver a
          <*> unSharingObserver b
          <*> unSharingObserver c
-- | The operands are observed in order inside the new node's computation,
-- which 'sharingObserverNode' only runs the first time it observes that node.
instance
  ( Referenceable letName sem
  , MakeLetName letName
  , Eq letName
  , Hashable letName
  , Show letName
  ) =>
  LiftDerived4 (SharingObserver letName sem)
  where
  liftDerived4 f a b c d =
    sharingObserverNode $
      SharingObserver $
        f
          <$> unSharingObserver a
          <*> unSharingObserver b
          <*> unSharingObserver c
          <*> unSharingObserver d
-- | 'SharingObserver' does not interpret 'ref' itself;
-- its error message attributes reaching it to 'sharingObserver' being applied twice.
instance Referenceable letName (SharingObserver letName sem) where
  ref = error "[BUG]: sharingObserver MUST NOT be applied twice"
-- | 'SharingObserver' does not interpret 'define' itself;
-- its error message attributes reaching it to 'sharingObserver' being applied twice.
instance Definable letName (SharingObserver letName sem) where
  define = error "[BUG]: sharingObserver MUST NOT be applied twice"
-- | 'SharingObserver' does not interpret 'lets' itself;
-- its error message attributes reaching it to 'sharingObserver' being applied twice.
instance Letsable letName (SharingObserver letName sem) where
  lets = error "[BUG]: sharingObserver MUST NOT be applied twice"

-- * Type 'SharingFinalizer'

-- | Second pass of 'sharingObserver'.
-- It removes each 'define' whose @letName@ is not in the environment
-- (i.e. whose node was observed only once, so neither shared nor recursive),
-- and replaces every other 'define' by a 'ref', moving its binding
-- into the output 'LetBindings'.
--
-- The environment is the set of @letName@s to keep,
-- computed by 'sharingObserver' from the observation counts.
newtype SharingFinalizer letName sem a = SharingFinalizer
  { unSharingFinalizer ::
      MT.ReaderT
        (HS.HashSet letName)
        (MT.Writer (LetBindings letName sem))
        (sem a)
  }

-- | The 'Derived' interpreter of 'SharingFinalizer' is the underlying @(sem)@.
type instance Derived (SharingFinalizer _letName sem) = sem

-- | A plain @(sem)@ term has no 'define' to finalize: it is returned as is.
instance
  (Eq letName, Hashable letName) =>
  LiftDerived (SharingFinalizer letName sem)
  where
  liftDerived sem = SharingFinalizer (return sem)
-- | Finalize the operand, then apply the unary @(sem)@ combinator.
instance
  (Eq letName, Hashable letName) =>
  LiftDerived1 (SharingFinalizer letName sem)
  where
  liftDerived1 f = SharingFinalizer . (f <$>) . unSharingFinalizer
-- | Finalize both operands in order, then apply the binary @(sem)@ combinator.
instance
  (Eq letName, Hashable letName) =>
  LiftDerived2 (SharingFinalizer letName sem)
  where
  liftDerived2 f a b =
    SharingFinalizer (liftA2 f (unSharingFinalizer a) (unSharingFinalizer b))
-- | Finalize the three operands in order, then apply the ternary @(sem)@ combinator.
instance
  (Eq letName, Hashable letName) =>
  LiftDerived3 (SharingFinalizer letName sem)
  where
  liftDerived3 f a b c =
    SharingFinalizer $
      liftA2 f (unSharingFinalizer a) (unSharingFinalizer b)
        <*> unSharingFinalizer c
-- | Finalize the four operands in order, then apply the quaternary @(sem)@ combinator.
instance
  (Eq letName, Hashable letName) =>
  LiftDerived4 (SharingFinalizer letName sem)
  where
  liftDerived4 f a b c d =
    SharingFinalizer $
      liftA2 f (unSharingFinalizer a) (unSharingFinalizer b)
        <*> unSharingFinalizer c
        <*> unSharingFinalizer d
-- | A 'ref' has nothing to finalize: it is passed through to @(sem)@.
instance
  ( Referenceable letName sem
  , Eq letName
  , Hashable letName
  , Show letName
  ) =>
  Referenceable letName (SharingFinalizer letName sem)
  where
  ref isRec name = liftDerived (ref isRec name)
-- | Finalize a 'define' introduced by 'sharingObserverNode':
--
-- * when its @letName@ is not among the referenced ones,
--   the 'define' node is dropped and only its body is kept;
-- * otherwise the body is finalized in isolation, its binding and the bindings
--   collected from it are written out, and the 'define' becomes a non-recursive 'ref'.
instance
  ( Referenceable letName sem
  , Eq letName
  , Hashable letName
  , Show letName
  ) =>
  Definable letName (SharingFinalizer letName sem)
  where
  define name body = SharingFinalizer $ do
    referenced <- MT.ask
    if not (HS.member name referenced)
      then -- Unreferenced: remove this 'define' node, keep its body in place.
        unSharingFinalizer body
      else do
        -- Referenced: capture the body's own bindings, then move them,
        -- along with this binding, into the enclosing output.
        -- This keeps the binding in scope even when some 'ref' to it
        -- lies outside of 'body' (which happens when a body-expression is shared).
        let (bodySem, bodyDefs) =
              MT.runWriter (MT.runReaderT (unSharingFinalizer body) referenced)
        MT.lift (MT.tell (HM.insert name (SomeLet bodySem) bodyDefs))
        return (ref False name)

-- * Class 'Letsable'

-- | Interpreters able to let-bind a whole map of 'LetBindings' around a term.
class Letsable letName sem where
  -- | @('lets' defs x)@ let-binds @(defs)@ in @(x)@.
  lets :: LetBindings letName sem -> sem a -> sem a
  default lets ::
    Derivable sem =>
    FromDerived1 (Letsable letName) sem =>
    LetBindings letName sem ->
    sem a ->
    sem a
  -- By default, 'derive' every binding, let-bind them
  -- in the 'Derived' interpreter, and lift the result back.
  lets defs =
    liftDerived1 $
      lets $
        (\(SomeLet val) -> SomeLet (derive val)) <$> defs

-- ** Type 'SomeLet'

-- | A @(sem)@ term whose result type is hidden,
-- so that terms of different types can live in the same 'LetBindings'.
data SomeLet sem = forall a. SomeLet (sem a)

-- ** Type 'LetBindings'

-- | Let-bound terms, indexed by their @letName@.
type LetBindings letName sem = HM.HashMap letName (SomeLet sem)

{-
-- | Not used but can be written nonetheless.
instance
  ( Letsable letName sem
  , Eq letName
  , Hashable letName
  , Show letName
  ) => Letsable letName (SharingFinalizer letName sem) where
  lets defs x = SharingFinalizer $ do
    ds <- traverse (\(SomeLet v) -> do
      r <- unSharingFinalizer v
      return (SomeLet r)
      ) defs
    MT.lift $ MT.tell ds
    unSharingFinalizer x
-}

-- ** Type 'OpenLetRecs'

-- | Mutually recursive terms, in open recursion style.
-- Its shape, @recs (recs a -> a)@ with @recs = 'LetRecs' letName@,
-- is the input expected by 'mutualFix'.
type OpenLetRecs letName a = LetRecs letName (OpenLetRec letName a)

-- | Mutually recursive term, in open recursion style.
-- The term is given a @final@ (aka. @self@) map
-- of other terms it can refer to (including itself).
type OpenLetRec letName a = LetRecs letName a -> a

-- | Recursive let bindings, indexed by their @letName@.
type LetRecs letName = HM.HashMap letName

-- | Least fixpoint combinator: @('fix' f)@ is the value @x@ such that @x = f x@,
-- built as a single self-referencing (hence shared) binding.
fix :: (a -> a) -> a
fix f = let knot = f knot in knot

-- | Least fixpoint combinator of mutually recursive terms.
-- @('mutualFix' opens)@ takes a container of terms
-- in the open recursion style @(opens)@,
-- and returns that container of terms with their knots tied up:
-- every term is applied to the resulting container itself.
--
-- Used to express mutual recursion and to transparently introduce memoization.
--
-- Here all mutually dependent functions are restricted to the same polymorphic type @(a)@.
-- See http://okmij.org/ftp/Computation/fixed-point-combinators.html#Poly-variadic
mutualFix :: forall recs a. Functor recs => recs ({-finals-} recs a -> a) -> recs a
mutualFix opens = finals
  where
    -- The result container, shared with every term referring to it.
    finals = (\open -> open finals) <$> opens