{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
-- | Haskell terms which are interesting
-- to pattern-match when optimizing.
module Symantic.Parser.Grammar.Pure where

import Data.Bool (Bool(..))
import Data.Either (Either(..))
import Data.Eq (Eq)
import Data.Maybe (Maybe(..))
import Data.Ord (Ord(..))
import Data.Kind (Type)
import Text.Show (Show(..), showParen, showString)
import qualified Data.Eq as Eq
import qualified Data.Function as Function
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH

import Symantic.Univariant.Trans

-- * Type 'ValueCode'
-- | A term held in two interpretations at once:
-- its compile-time 'value', and the corresponding 'code'
-- which can produce that value at runtime.
data ValueCode a = ValueCode
  { value :: Value a
    -- ^ Compile-time interpretation of the term.
  , code :: TH.CodeQ a
    -- ^ Typed Template Haskell code producing the term at runtime.
  }
-- | The compile-time value of a 'ValueCode', unwrapped from its 'Value'.
getValue :: ValueCode a -> a
getValue vc = unValue (value vc)
-- | The code part of a 'ValueCode'.
getCode :: ValueCode a -> TH.CodeQ a
getCode ValueCode{code = c} = c

-- ** Type 'Value'
-- | Compile-time interpretation of a term: the plain Haskell value,
-- wrapped in a newtype so that it can have its own instances
-- (eg. 'CombPurable').
newtype Value a = Value
  { unValue :: a }

-- * Class 'CombPurable'
-- | Final encoding of 'CombPure',
-- extended with useful terms.
--
-- This module interprets these terms as plain values ('Value'),
-- as typed Template Haskell code ('TH.CodeQ'), as both at once
-- ('ValueCode'), and as a pattern-matchable term ('CombPure').
class CombPurable (repr :: Type -> Type) where
  -- | Apply a function term to an argument term.
  (.@) :: repr (a->b) -> repr a -> repr b
  -- | Function composition.
  (.) :: repr ((b->c) -> (a->b) -> a -> c)
  -- | Function application.
  ($) :: repr ((a->b) -> a -> b)
  -- | Constant function: ignores its second argument.
  const :: repr (a -> b -> a)
  -- | Swaps the two arguments of a binary function.
  flip :: repr ((a -> b -> c) -> b -> a -> c)
  -- | Identity function.
  id :: repr (a->a)
  -- | @eq x@ tests its argument for equality with @x@.
  eq :: Eq a => repr a -> repr (a -> Bool)
  -- | Boolean literal.
  bool :: Bool -> repr Bool
  -- | Literal of any 'TH.Lift'able type (eg. a token).
  char :: TH.Lift tok => tok -> repr tok
  -- | The unit value.
  unit :: repr ()
  -- | The empty list.
  nil :: repr [a]
  -- | The list constructor.
  cons :: repr (a -> [a] -> [a])
  -- | The 'Nothing' constructor.
  nothing :: repr (Maybe a)
  -- | The 'Just' constructor.
  just :: repr (a -> Maybe a)
  -- | The 'Left' constructor.
  left :: repr (l -> Either l r)
  -- | The 'Right' constructor.
  right :: repr (r -> Either l r)

-- ** Type 'CombPure'
-- | Initial encoding of 'CombPurable',
-- useful for some optimizations in 'optimizeComb'.
--
-- Combinators that optimizations want to recognize get their own
-- constructor, so they can be pattern-matched; any other term is
-- kept as a 'CombPure' leaf holding its 'ValueCode'.
data CombPure a where
  -- | A term known only through its 'ValueCode'.
  CombPure :: ValueCode a -> CombPure a
  -- | Function composition.
  (:.) :: CombPure ((b->c) -> (a->b) -> a -> c)
  -- | Function application.
  (:$) :: CombPure ((a->b) -> a -> b)
  -- | Application of a function term to an argument term.
  (:@) :: CombPure (a->b) -> CombPure a -> CombPure b
  -- | The list constructor.
  Cons :: CombPure (a -> [a] -> [a])
  -- | Constant function.
  Const :: CombPure (a -> b -> a)
  -- | Equality test against the given term.
  Eq :: Eq a => CombPure a -> CombPure (a -> Bool)
  -- | Argument swap.
  Flip :: CombPure ((a -> b -> c) -> b -> a -> c)
  -- | Identity function.
  Id :: CombPure (a->a)
  -- | The unit value.
  Unit :: CombPure ()
infixr 0 $, :$
infixr 9 ., :.
infixl 9 .@, :@

{-
pattern (:.@) ::
  -- Dummy constraint to get the following constraint
  -- in scope when pattern-matching.
  () =>
  ((x -> y -> z) ~ ((b -> c) -> (a -> b) -> a -> c)) =>
  CombPure x -> CombPure y -> CombPure z
pattern (:.@) f g = (:.) :@ f :@ g
pattern FlipApp ::
  () =>
  ((x -> y) ~ ((a -> b -> c) -> b -> a -> c)) =>
  CombPure x -> CombPure y
pattern FlipApp f = Flip :@ f
pattern FlipConst ::
  () =>
  (x ~ (a -> b -> b)) =>
  CombPure x
pattern FlipConst = FlipApp Const
-}

-- | Render a 'CombPure' term with a Haskell-like syntax, for debugging dumps.
--
-- Composition is rendered at precedence 9, list cons at 5
-- (matching the fixity of (:), infixr 5) and application at 10;
-- each is parenthesized when the surrounding context binds tighter
-- (for composition and application: at least as tight).
-- 'CombPure' leaves carry nothing showable and render as @CombPure@.
instance Show (CombPure a) where
  showsPrec p = \case
    CombPure{} -> showString "CombPure"
    (:$) -> showString "($)"
    (:.) :@ f :@ g ->
      showParen (p >= 9)
      Function.$ showsPrec 9 f
      Function.. showString " . "
      Function.. showsPrec 9 g
    (:.) -> showString "(.)"
    -- (:) is infixr 5. Parenthesize whenever the context binds tighter
    -- (eg. an operand of (.) or an argument of an application); otherwise
    -- @x : xs . g@ would read back as @x : (xs . g)@.
    -- The head is shown at 6 so that a cons there gets parenthesized,
    -- the tail at 5 so that right-nested conses need no parentheses.
    Cons :@ x :@ xs ->
      showParen (p > 5)
      Function.$ showsPrec 6 x
      Function.. showString " : "
      Function.. showsPrec 5 xs
    Cons -> showString "cons"
    Const -> showString "const"
    Eq x ->
      showParen True
      Function.$ showString "== "
      Function.. showsPrec 0 x
    Flip -> showString "flip"
    Id -> showString "id"
    Unit -> showString "()"
    (:@) f x ->
      showParen (p >= 10)
      Function.$ showsPrec 10 f
      Function.. showString " "
      Function.. showsPrec 10 x

-- | Code of a 'CombPure' term: interpret it as a 'ValueCode'
-- and keep only its 'code' part.
instance Trans CombPure TH.CodeQ where
  trans term = code (trans term)
-- | Compile-time value of a 'CombPure' term: interpret it as a 'ValueCode'
-- and keep only its 'value' part.
instance Trans CombPure Value where
  trans term = value (trans term)
-- | Interpret a 'CombPure' term as a 'ValueCode': leaves are unwrapped,
-- and every other constructor maps to the 'CombPurable' method of the
-- same name, recursively for the sub-terms.
instance Trans CombPure ValueCode where
  trans (CombPure vc) = vc
  trans (:.) = (.)
  trans (:$) = ($)
  trans (f :@ x) = trans f .@ trans x
  trans Cons = cons
  trans Const = const
  trans (Eq x) = eq (trans x)
  trans Flip = flip
  trans Id = id
  trans Unit = unit
-- 'ValueCode' is declared as the 'Output' of 'CombPure'
-- (see "Symantic.Univariant.Trans" for how 'Output' is used).
type instance Output CombPure = ValueCode

-- | Embed a 'ValueCode' as a 'CombPure' leaf.
instance Trans ValueCode CombPure where
  trans vc = CombPure vc

-- | Build the initial encoding. Terms with a dedicated constructor stay
-- pattern-matchable; the others are interpreted as a 'ValueCode'
-- right away and wrapped as 'CombPure' leaves.
instance CombPurable CombPure where
  -- Terms having their own constructor.
  (.)   = (:.)
  ($)   = (:$)
  cons  = Cons
  const = Const
  eq    = Eq
  flip  = Flip
  id    = Id
  unit  = Unit
  -- Terms without a dedicated constructor.
  bool b  = CombPure (bool b)
  char c  = CombPure (char c)
  nil     = CombPure nil
  nothing = CombPure nothing
  just    = CombPure just
  left    = CombPure left
  right   = CombPure right
  -- Application, with a few reductions done eagerly,
  -- mainly to reduce dump sizes.
  f .@ x = case f of
    Id -> x                        -- id x = x
    Const :@ y -> y                -- const y x = y
    (Flip :@ Const) :@ _ -> x      -- flip const _ x = x
    _ -> f :@ x
-- | Interpret every term in both ways at once: the 'value' field uses the
-- 'Value' instance and the 'code' field uses the 'TH.CodeQ' instance.
instance CombPurable ValueCode where
  f .@ x = ValueCode
    { value = value f .@ value x
    , code  = code f .@ code x
    }
  eq x = ValueCode
    { value = eq (value x)
    , code  = eq (code x)
    }
  bool b  = ValueCode { value = bool b, code = bool b }
  char c  = ValueCode { value = char c, code = char c }
  (.)     = ValueCode { value = (.),     code = (.) }
  ($)     = ValueCode { value = ($),     code = ($) }
  const   = ValueCode { value = const,   code = const }
  flip    = ValueCode { value = flip,    code = flip }
  id      = ValueCode { value = id,      code = id }
  unit    = ValueCode { value = unit,    code = unit }
  nil     = ValueCode { value = nil,     code = nil }
  cons    = ValueCode { value = cons,    code = cons }
  nothing = ValueCode { value = nothing, code = nothing }
  just    = ValueCode { value = just,    code = just }
  left    = ValueCode { value = left,    code = left }
  right   = ValueCode { value = right,   code = right }
-- | Compile-time interpretation: each term is directly
-- the Haskell value it denotes.
instance CombPurable Value where
  -- Function plumbing.
  Value f .@ Value x = Value (f x)
  (.)   = Value (Function..)
  ($)   = Value (Function.$)
  const = Value Function.const
  flip  = Value Function.flip
  id    = Value Function.id
  -- Partially applied equality test.
  eq (Value x) = Value (x Eq.==)
  -- Literals.
  bool b = Value b
  char c = Value c
  unit   = Value ()
  -- Data constructors.
  nil     = Value []
  cons    = Value (:)
  nothing = Value Nothing
  just    = Value Just
  left    = Value Left
  right   = Value Right
-- | Interpretation as typed Template Haskell code: each term is quoted
-- so that the generated program computes it at runtime.
-- 'flip' and 'id' are emitted as explicit lambdas rather than
-- as references to "Data.Function".
instance CombPurable TH.CodeQ where
  (.)      = [|| (Function..) ||]
  ($)      = [|| (Function.$) ||]
  -- Splice both sub-terms into a single application.
  (.@) f x = [|| $$f $$x ||]
  -- The compile-time argument is lifted into the generated code
  -- (via its 'TH.Lift' instance).
  bool b   = [|| b ||]
  char c   = [|| c ||]
  cons     = [|| (:) ||]
  const    = [|| Function.const ||]
  -- Section comparing its argument against the spliced term @x@.
  eq x     = [|| ($$x Eq.==) ||]
  flip     = [|| \f x y -> f y x ||]
  id       = [|| \x -> x ||]
  nil      = [|| [] ||]
  unit     = [|| () ||]
  left     = [|| Left ||]
  right    = [|| Right ||]
  nothing  = [|| Nothing ||]
  just     = [|| Just ||]