{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
-- | Haskell terms which are interesting
-- to pattern-match when optimizing.
module Symantic.Parser.Haskell where

import Data.Bool (Bool(..))
import Data.Either (Either(..))
import Data.Eq (Eq)
import Data.Maybe (Maybe(..))
import Data.Ord (Ord(..))
import Data.Kind (Type)
import Text.Show (Show(..), showParen, showString)
import qualified Data.Eq as Eq
import qualified Data.Function as Function
import qualified Language.Haskell.TH as TH
import qualified Language.Haskell.TH.Syntax as TH

import Symantic.Univariant.Trans

-- * Type 'ValueCode'
-- | A 'value' known at compile time, paired with
-- the 'code' which can rebuild that same value at runtime.
data ValueCode a = ValueCode
  { value :: Value a
    -- ^ The compile-time value.
  , code :: TH.CodeQ a
    -- ^ Typed Template Haskell code producing the value at runtime.
  }

-- | The compile-time value, stripped of its 'Value' wrapper.
getValue :: ValueCode a -> a
getValue vc = unValue (value vc)

-- | The typed Template Haskell code of a 'ValueCode'.
getCode :: ValueCode a -> TH.CodeQ a
getCode vc = code vc

-- ** Type 'Value'
-- | A plain Haskell value, wrapped so that it can be
-- the @repr@ of a 'Haskellable' instance.
newtype Value a = Value
  { unValue :: a
  }

-- * Class 'Haskellable'
-- | Final encoding of some Haskell functions
-- useful for some optimizations in 'optimizeComb'.
--
-- Each method denotes one Haskell term. The instances in this module
-- give it a plain 'Value', typed Template Haskell code,
-- both at once ('ValueCode'), or an inspectable 'Haskell' term.
class Haskellable (repr :: Type -> Type) where
  -- | Function composition.
  (.) :: repr ((b -> c) -> (a -> b) -> a -> c)
  -- | Function application, as an operator term.
  ($) :: repr ((a -> b) -> a -> b)
  -- | Apply a represented function to a represented argument.
  (.@) :: repr (a -> b) -> repr a -> repr b
  -- | A 'Bool' literal.
  bool :: Bool -> repr Bool
  -- | A token literal, which must be liftable into Template Haskell.
  char :: TH.Lift tok => tok -> repr tok
  -- | Prepend an element to a list.
  cons :: repr (a -> [a] -> [a])
  -- | Ignore the second argument.
  const :: repr (a -> b -> a)
  -- | Test for equality with the given term.
  eq :: Eq a => repr a -> repr (a -> Bool)
  -- | Swap the first two arguments of a function.
  flip :: repr ((a -> b -> c) -> b -> a -> c)
  -- | The identity function.
  id :: repr (a -> a)
  -- | The empty list.
  nil :: repr [a]
  -- | The unit value.
  unit :: repr ()
  -- | The 'Left' constructor.
  left :: repr (l -> Either l r)
  -- | The 'Right' constructor.
  right :: repr (r -> Either l r)
  -- | The 'Nothing' constructor.
  nothing :: repr (Maybe a)
  -- | The 'Just' constructor.
  just :: repr (a -> Maybe a)

-- ** Type 'Haskell'
-- | Initial encoding of 'Haskellable'.
--
-- Unlike the final encoding, these constructors can be pattern-matched,
-- so terms can be inspected (e.g. simplified, or shown).
data Haskell a where
  -- | An opaque term, known only through its value and code.
  Haskell :: ValueCode a -> Haskell a
  -- | Function composition.
  (:.) :: Haskell ((b -> c) -> (a -> b) -> a -> c)
  -- | Function application, as an operator term.
  (:$) :: Haskell ((a -> b) -> a -> b)
  -- | Application of a term to another term.
  (:@) :: Haskell (a -> b) -> Haskell a -> Haskell b
  -- | List cons.
  Cons :: Haskell (a -> [a] -> [a])
  -- | Constant function.
  Const :: Haskell (a -> b -> a)
  -- | Equality test with the given term.
  Eq :: Eq a => Haskell a -> Haskell (a -> Bool)
  -- | Argument flipping.
  Flip :: Haskell ((a -> b -> c) -> b -> a -> c)
  -- | Identity function.
  Id :: Haskell (a -> a)
  -- | Unit value.
  Unit :: Haskell ()
-- Fixities mirror those of Prelude's ($) and (.);
-- application operators bind as tightly as (.) but to the left.
infixr 0 $
infixr 0 :$
infixr 9 .
infixr 9 :.
infixl 9 .@
infixl 9 :@

{-
pattern (:.@) ::
  -- Dummy constraint to get the following constraint
  -- in scope when pattern-matching.
  () =>
  ((x -> y -> z) ~ ((b -> c) -> (a -> b) -> a -> c)) =>
  Haskell x -> Haskell y -> Haskell z
pattern (:.@) f g = (:.) :@ f :@ g
pattern FlipApp ::
  () =>
  ((x -> y) ~ ((a -> b -> c) -> b -> a -> c)) =>
  Haskell x -> Haskell y
pattern FlipApp f = Flip :@ f
pattern FlipConst ::
  () =>
  (x ~ (a -> b -> b)) =>
  Haskell x
pattern FlipConst = FlipApp Const
-}

-- | Debug rendering of a 'Haskell' term.
-- Opaque 'Haskell' leaves are rendered only as @Haskell@.
instance Show (Haskell a) where
  showsPrec p term =
    case term of
      Haskell{} -> showString "Haskell"
      (:$) -> showString "($)"
      (:.) -> showString "(.)"
      Cons -> showString "cons"
      Const -> showString "const"
      Flip -> showString "flip"
      Id -> showString "id"
      Unit -> showString "()"
      Eq x ->
        showParen True
        Function.$ showString "== "
        Function.. showsPrec 0 x
      -- Saturated composition, rendered as @f . g@.
      (:.) :@ f :@ g -> showInfix p 9 " . " f g
      -- Saturated cons, rendered as @x : xs@.
      Cons :@ x :@ xs -> showInfix p 10 " : " x xs
      -- Any other application, rendered as @f x@.
      f :@ x -> showInfix p 10 " " f x
    where
      -- Render @l sep r@, parenthesized when the context precedence
      -- is at least the operator's own precedence.
      -- Kept closed (no free local variables) so it generalizes
      -- over the operand types.
      -- NOTE(review): both operands are rendered at the operator's
      -- own precedence, so left-nested applications print as @(f x) y@.
      showInfix ctx prec sep l r =
        showParen (ctx >= prec)
          Function.$ showsPrec prec l
          Function.. showString sep
          Function.. showsPrec prec r
-- | Evaluate a 'Haskell' term to its 'Value',
-- going through its 'ValueCode' interpretation.
instance Trans Haskell Value where
  trans term = value (trans term)
-- | Turn a 'Haskell' term into typed Template Haskell code,
-- going through its 'ValueCode' interpretation.
instance Trans Haskell TH.CodeQ where
  trans term = code (trans term)
-- | Interpret a 'Haskell' term as a 'ValueCode'.
-- Each constructor is rebuilt with the matching 'Haskellable' method,
-- and opaque 'Haskell' leaves are unwrapped as-is.
instance Trans Haskell ValueCode where
  trans term = case term of
    Haskell vc -> vc
    f :@ x -> trans f .@ trans x
    (:.) -> (.)
    (:$) -> ($)
    Cons -> cons
    Const -> const
    Eq x -> eq (trans x)
    Flip -> flip
    Id -> id
    Unit -> unit
-- | Embed a 'ValueCode' as an opaque 'Haskell' leaf.
instance Trans ValueCode Haskell where
  trans vc = Haskell vc
-- | Declare 'ValueCode' as the 'Output' of 'Haskell'
-- (see "Symantic.Univariant.Trans" for how 'Output' is consumed).
type instance Output Haskell = ValueCode

-- | Build inspectable terms.
instance Haskellable Haskell where
  -- Combinators kept as pattern-matchable constructors.
  (.)     = (:.)
  ($)     = (:$)
  cons    = Cons
  const   = Const
  eq      = Eq
  flip    = Flip
  id      = Id
  unit    = Unit
  -- Application, with a few beta-reductions performed eagerly,
  -- mainly to reduce dump sizes:
  --   id x           ==> x
  --   const k _      ==> k
  --   flip const _ x ==> x
  f .@ x = case f of
    Id -> x
    Const :@ k -> k
    Flip :@ Const :@ _ -> x
    _ -> f :@ x
  -- Everything else is an opaque leaf built by the 'ValueCode' instance.
  bool b  = Haskell (bool b)
  char c  = Haskell (char c)
  nil     = Haskell nil
  left    = Haskell left
  right   = Haskell right
  nothing = Haskell nothing
  just    = Haskell just
-- | Build both interpretations side by side: the 'value' field uses
-- the 'Value' instance, and the 'code' field the 'TH.CodeQ' one.
instance Haskellable ValueCode where
  (.)     = ValueCode { value = (.), code = (.) }
  ($)     = ValueCode { value = ($), code = ($) }
  f .@ x  = ValueCode
    { value = value f .@ value x
    , code  = code f .@ code x
    }
  bool b  = ValueCode { value = bool b, code = bool b }
  char c  = ValueCode { value = char c, code = char c }
  cons    = ValueCode { value = cons, code = cons }
  const   = ValueCode { value = const, code = const }
  eq x    = ValueCode { value = eq (value x), code = eq (code x) }
  flip    = ValueCode { value = flip, code = flip }
  id      = ValueCode { value = id, code = id }
  nil     = ValueCode { value = nil, code = nil }
  unit    = ValueCode { value = unit, code = unit }
  left    = ValueCode { value = left, code = left }
  right   = ValueCode { value = right, code = right }
  nothing = ValueCode { value = nothing, code = nothing }
  just    = ValueCode { value = just, code = just }
-- | Evaluate terms directly to plain Haskell values.
-- Matching on the 'Value' newtype constructor never forces anything,
-- so this is as lazy as using 'unValue'.
instance Haskellable Value where
  (.)                = Value (Function..)
  ($)                = Value (Function.$)
  Value f .@ Value x = Value (f x)
  bool               = Value
  char               = Value
  cons               = Value (:)
  const              = Value Function.const
  eq (Value x)       = Value (x Eq.==)
  flip               = Value Function.flip
  id                 = Value Function.id
  nil                = Value []
  unit               = Value ()
  left               = Value Left
  right              = Value Right
  nothing            = Value Nothing
  just               = Value Just
-- | Produce typed Template Haskell code for each term.
instance Haskellable TH.CodeQ where
  -- Operators and application.
  (.)      = [|| (Function..) ||]
  ($)      = [|| (Function.$) ||]
  f .@ x   = [|| $$f $$x ||]
  -- Literals, lifted from compile-time values.
  bool b   = [|| b ||]
  char c   = [|| c ||]
  -- Combinators.
  const    = [|| Function.const ||]
  eq x     = [|| ($$x Eq.==) ||]
  flip     = [|| \f x y -> f y x ||]
  id       = [|| \x -> x ||]
  -- Data constructors.
  cons     = [|| (:) ||]
  nil      = [|| [] ||]
  unit     = [|| () ||]
  left     = [|| Left ||]
  right    = [|| Right ||]
  nothing  = [|| Nothing ||]
  just     = [|| Just ||]