{-# LANGUAGE PatternSynonyms #-} -- For Instr
{-# LANGUAGE ViewPatterns #-} -- For unSomeInstr
-- | Initial encoding with bottom-up optimizations of 'Instr'uctions,
-- re-optimizing downward as needed after each optimization.
-- There is only one optimization so far (in 'pushValue'),
-- but the introspection enabled by the 'Instr' data-type
-- is also useful for optimizing with more context in the 'Machine'.
module Symantic.Parser.Machine.Optimize where

import Data.Bifunctor (second)
import Data.Bool (Bool(..))
import Data.Either (Either)
import Data.Function ((.))
import Data.Kind (Constraint)
import Data.Maybe (Maybe(..))
import Data.Set (Set)
import Data.String (String)
import Type.Reflection (Typeable, typeRep, eqTypeRep, (:~~:)(..))
import qualified Data.Functor as Functor
import qualified Language.Haskell.TH as TH

import Symantic.Syntaxes.Derive
import Symantic.Parser.Grammar
import Symantic.Parser.Machine.Input
import Symantic.Parser.Machine.Instructions

-- * Data family 'Instr'
-- | 'Instr'uctions of the 'Machine'.
-- This is an extensible data-type:
-- it is indexed by an instruction symantic class @instr@,
-- and each such class gets its own data-instance below,
-- whose constructors reify the methods of that class.
data family Instr
  (instr :: ReprInstr -> Constraint)
  :: ReprInstr -> ReprInstr
-- Deriving an @'Instr' instr repr inp vs@ yields a @repr inp vs@,
-- the representation the reified instructions are turned back into.
type instance Derived (Instr instr repr inp vs) = repr inp vs

-- | Match-only pattern (it cannot build a 'SomeInstr')
-- succeeding when the instruction wrapped in a 'SomeInstr'
-- belongs to the @'Instr' instr@ data-instance.
-- It is a view through 'unSomeInstr', hence 'Comment' wrappers
-- are skipped unless @instr@ is 'InstrComment' itself.
pattern Instr :: Typeable instr => Instr instr repr inp vs a -> SomeInstr repr inp vs a
pattern Instr found <- (unSomeInstr -> Just found)

-- ** Type 'SomeInstr'
-- | Some 'Instr'uction existentialized over the actual instruction symantic class.
-- Useful to handle a list of 'Instr'uctions
-- without requiring impredicative quantification.
-- Must be used by pattern-matching
-- on the 'SomeInstr' data-constructor,
-- to bring the constraints in scope.
--
-- As in 'SomeComb', a first pass of optimizations
-- is directly applied in it
-- to avoid introducing an extra newtype,
-- which also makes the code easier to understand.
-- So far, the only such optimization lives in the
-- @'InstrValuable' ('SomeInstr' repr)@ instance, where a 'PushValue'
-- immediately followed by a 'PopValue' is dropped.
data SomeInstr repr inp vs a =
  forall instr.
  -- 'Derivable' lets the wrapped instruction be turned back into @repr@ ('derive');
  -- 'Typeable' lets 'unSomeInstr' recover the concrete @instr@ at runtime.
  ( Derivable (Instr instr repr inp vs)
  , Typeable instr
  ) => SomeInstr (Instr instr repr inp vs a)

-- A 'SomeInstr' derives to the same representation as the instruction it wraps.
type instance Derived (SomeInstr repr inp vs) = repr inp vs
instance Derivable (SomeInstr repr inp vs) where
  -- Delegate to the 'Derivable' dictionary packed inside the constructor.
  derive = \case
    SomeInstr wrapped -> derive wrapped

-- | @('unSomeInstr' i :: 'Maybe' ('Instr' instr repr inp vs a))@
-- returns the instruction wrapped in the given 'SomeInstr'
-- iff. it belongs to the @('Instr' instr repr inp vs a)@ data-instance.
-- 'Comment' wrappers are looked through (recursively),
-- except when @instr@ is 'InstrComment' itself,
-- in which case the outermost 'Comment' is returned as is.
unSomeInstr ::
  forall instr repr inp vs a.
  Typeable instr =>
  SomeInstr repr inp vs a ->
  Maybe (Instr instr repr inp vs a)
unSomeInstr (SomeInstr (wrapped :: Instr i repr inp vs a)) =
  case eqTypeRep (typeRep @instr) (typeRep @i) of
    -- The wrapped instruction already is of the requested data-instance.
    Just HRefl -> Just wrapped
    -- Otherwise, if it is a 'Comment', search the instruction it annotates.
    Nothing
      | Just HRefl <- eqTypeRep (typeRep @InstrComment) (typeRep @i)
      , Comment _msg annotated <- wrapped
      -> unSomeInstr annotated
    _ -> Nothing

-- InstrComment
-- Annotation instructions; note that 'unSomeInstr' looks through them.
data instance Instr InstrComment repr inp vs a where
  -- | @'Comment' msg k@ attaches the message @msg@ to the instruction @k@
  -- (reified 'comment').
  Comment ::
    String ->
    SomeInstr repr inp vs a ->
    Instr InstrComment repr inp vs a
instance InstrComment repr => Derivable (Instr InstrComment repr inp vs) where
  -- Re-emit the comment around the derived annotated instruction.
  derive (Comment msg annotated) = comment msg (derive annotated)
instance InstrComment repr => InstrComment (SomeInstr repr) where
  -- No optimization applies to comments: they are simply reified.
  comment msg annotated = SomeInstr (Comment msg annotated)

-- InstrValuable
-- Instructions manipulating the stack of values indexed by @vs@;
-- each constructor reifies the same-named 'InstrValuable' method.
data instance Instr InstrValuable repr inp vs a where
  -- | @'PushValue' v k@ continues with @k@, which finds @v@ on top of the stack.
  PushValue ::
    Splice v ->
    SomeInstr repr inp (v ': vs) a ->
    Instr InstrValuable repr inp vs a
  -- | @'PopValue' k@ drops the value on top of the stack, then continues with @k@.
  PopValue ::
    SomeInstr repr inp vs a ->
    Instr InstrValuable repr inp (v ': vs) a
  -- | @'Lift2Value' f k@ replaces the two topmost values
  -- (@y@ on top, @x@ just below) by a single @z@ computed with @f@,
  -- then continues with @k@.
  Lift2Value ::
    Splice (x -> y -> z) ->
    SomeInstr repr inp (z : vs) a ->
    Instr InstrValuable repr inp (y : x : vs) a
  -- | @'SwapValue' k@ exchanges the two topmost values, then continues with @k@.
  SwapValue ::
    SomeInstr repr inp (x ': y ': vs) a ->
    Instr InstrValuable repr inp (y ': x ': vs) a
instance InstrValuable repr => Derivable (Instr InstrValuable repr inp vs) where
  derive (PushValue v next) = pushValue v (derive next)
  derive (PopValue next) = popValue (derive next)
  derive (Lift2Value f next) = lift2Value f (derive next)
  derive (SwapValue next) = swapValue (derive next)
instance InstrValuable repr => InstrValuable (SomeInstr repr) where
  pushValue v next =
    case next of
      -- Pushing a value only to pop it right away is a no-op,
      -- hence both instructions are dropped.
      Instr (PopValue afterPop) -> afterPop
      _ -> SomeInstr (PushValue v next)
  popValue next = SomeInstr (PopValue next)
  lift2Value f next = SomeInstr (Lift2Value f next)
  swapValue next = SomeInstr (SwapValue next)

-- InstrExceptionable
-- Instructions raising and handling exceptions;
-- each constructor reifies the same-named 'InstrExceptionable' method.
data instance Instr InstrExceptionable repr inp vs a where
  -- | Reified 'raise' of the given 'ExceptionLabel'; it has no continuation.
  Raise ::
    ExceptionLabel ->
    Instr InstrExceptionable repr inp vs a
  -- | Reified 'fail' with the given set of 'SomeFailure's; it has no continuation.
  Fail ::
    Set SomeFailure ->
    Instr InstrExceptionable repr inp vs a
  -- | @'Commit' exn k@: reified 'commit', then continues with @k@.
  Commit ::
    Exception ->
    SomeInstr repr inp vs ret ->
    Instr InstrExceptionable repr inp vs ret
  -- | @'Catch' exn l r@: reified 'catch'.
  -- @r@ starts with an @'InputPosition' inp@ on top of its value stack.
  -- NOTE(review): presumably @l@ is the protected code and @r@ the handler;
  -- confirm against 'InstrExceptionable'.
  Catch ::
    Exception ->
    SomeInstr repr inp vs ret ->
    SomeInstr repr inp (InputPosition inp ': vs) ret ->
    Instr InstrExceptionable repr inp vs ret
instance InstrExceptionable repr => Derivable (Instr InstrExceptionable repr inp vs) where
  derive (Raise exn) = raise exn
  derive (Fail failures) = fail failures
  derive (Commit exn next) = commit exn (derive next)
  derive (Catch exn body handler) = catch exn (derive body) (derive handler)
instance InstrExceptionable repr => InstrExceptionable (SomeInstr repr) where
  raise exn = SomeInstr (Raise exn)
  fail failures = SomeInstr (Fail failures)
  commit exn next = SomeInstr (Commit exn next)
  catch exn body handler = SomeInstr (Catch exn body handler)

-- InstrBranchable
-- Instructions selecting a continuation from the value on top of the stack;
-- each constructor reifies the same-named 'InstrBranchable' method.
data instance Instr InstrBranchable repr inp vs a where
  -- | @'CaseBranch' l r@ consumes an @'Either' x y@ from the top of the stack:
  -- only @l@ can receive an @x@, and only @r@ a @y@.
  CaseBranch ::
    SomeInstr repr inp (x ': vs) a ->
    SomeInstr repr inp (y ': vs) a ->
    Instr InstrBranchable repr inp (Either x y ': vs) a
  -- | @'ChoicesBranch' bs d@ consumes a @v@ from the top of the stack
  -- and continues either with one of the branches of @bs@,
  -- each guarded by a predicate on @v@, or with the default @d@.
  -- NOTE(review): presumably the first branch whose predicate holds is taken,
  -- and @d@ only when none does; confirm in the 'Machine' interpreters.
  ChoicesBranch ::
    [(Splice (v -> Bool), SomeInstr repr inp vs a)] ->
    SomeInstr repr inp vs a ->
    Instr InstrBranchable repr inp (v ': vs) a
instance InstrBranchable repr => Derivable (Instr InstrBranchable repr inp vs) where
  derive (CaseBranch onLeft onRight) =
    caseBranch (derive onLeft) (derive onRight)
  derive (ChoicesBranch choices fallback) =
    -- Only the continuation of each branch is derived, its predicate is kept.
    choicesBranch (Functor.fmap (second derive) choices) (derive fallback)
instance InstrBranchable repr => InstrBranchable (SomeInstr repr) where
  caseBranch onLeft onRight = SomeInstr (CaseBranch onLeft onRight)
  choicesBranch choices fallback = SomeInstr (ChoicesBranch choices fallback)

-- InstrCallable
-- Instructions defining and invoking let-bound code;
-- each constructor reifies the same-named 'InstrCallable' method.
data instance Instr InstrCallable repr inp vs a where
  -- | @'DefLet' subs k@ binds @subs@ (keyed by 'TH.Name'),
  -- each starting with an empty value stack, then continues with @k@.
  DefLet ::
    LetBindings TH.Name (SomeInstr repr inp '[]) ->
    SomeInstr repr inp vs a ->
    Instr InstrCallable repr inp vs a
  -- | @'Call' isRec n k@ invokes @n@ and continues with @k@,
  -- which receives a @v@ (as typed by @'LetName' v@) on top of the stack.
  -- NOTE(review): the 'Bool' presumably flags a recursive call; confirm.
  Call ::
    Bool ->
    LetName v ->
    SomeInstr repr inp (v ': vs) a ->
    Instr InstrCallable repr inp vs a
  -- | Ends with the single value left on the stack, which has the result type.
  Ret ::
    Instr InstrCallable repr inp '[a] a
  -- | @'Jump' isRec n@ continues with @n@ from an empty value stack;
  -- @n@ has the same result type @a@ as the current code.
  -- NOTE(review): presumably a tail call, with the 'Bool' flagging recursion; confirm.
  Jump ::
    Bool ->
    LetName a ->
    Instr InstrCallable repr inp '[] a
instance InstrCallable repr => Derivable (Instr InstrCallable repr inp vs) where
  derive (DefLet subs next) =
    -- Each binding is derived under its 'SomeLet' wrapper.
    defLet
      (Functor.fmap (\(SomeLet sub) -> SomeLet (derive sub)) subs)
      (derive next)
  derive (Jump isRec callee) = jump isRec callee
  derive (Call isRec callee next) = call isRec callee (derive next)
  derive Ret = ret
instance InstrCallable repr => InstrCallable (SomeInstr repr) where
  defLet subs next = SomeInstr (DefLet subs next)
  jump isRec callee = SomeInstr (Jump isRec callee)
  call isRec callee next = SomeInstr (Call isRec callee next)
  ret = SomeInstr Ret

-- InstrJoinable
-- Instructions defining and referencing named join points;
-- each constructor reifies the same-named 'InstrJoinable' method.
data instance Instr InstrJoinable repr inp vs a where
  -- | @'DefJoin' n sub k@ names @sub@ as @n@, then continues with @k@;
  -- @sub@ expects a @v@ on top of the stack.
  DefJoin ::
    LetName v ->
    SomeInstr repr inp (v ': vs) a ->
    SomeInstr repr inp vs a ->
    Instr InstrJoinable repr inp vs a
  -- | @'RefJoin' n@ continues with the code named @n@,
  -- which expects the @v@ currently on top of the stack.
  RefJoin ::
    LetName v ->
    Instr InstrJoinable repr inp (v ': vs) a
instance InstrJoinable repr => Derivable (Instr InstrJoinable repr inp vs) where
  derive (DefJoin joinName sub next) = defJoin joinName (derive sub) (derive next)
  derive (RefJoin joinName) = refJoin joinName
instance InstrJoinable repr => InstrJoinable (SomeInstr repr) where
  defJoin joinName sub next = SomeInstr (DefJoin joinName sub next)
  refJoin joinName = SomeInstr (RefJoin joinName)

-- InstrInputable
-- Instructions moving the input position to and from the value stack.
data instance Instr InstrInputable repr inp vs a where
  -- | Reified 'saveInput' (note the different name):
  -- continues with @k@, which receives an @'InputPosition' inp@ on top of the stack.
  PushInput ::
    SomeInstr repr inp (InputPosition inp ': vs) a ->
    Instr InstrInputable repr inp vs a
  -- | Reified 'loadInput': consumes the @'InputPosition' inp@ on top of the stack,
  -- then continues with @k@.
  -- NOTE(review): presumably it restores that position as the current input; confirm.
  LoadInput ::
    SomeInstr repr inp vs a ->
    Instr InstrInputable repr inp (InputPosition inp ': vs) a
instance InstrInputable repr => Derivable (Instr InstrInputable repr inp vs) where
  -- 'PushInput' is the reified form of 'saveInput'.
  derive (PushInput next) = saveInput (derive next)
  derive (LoadInput next) = loadInput (derive next)
instance InstrInputable repr => InstrInputable (SomeInstr repr) where
  saveInput next = SomeInstr (PushInput next)
  loadInput next = SomeInstr (LoadInput next)

-- InstrReadable
-- The @tok@ index only selects the 'InstrReadable' class instance;
-- deriving these instructions requires @tok ~ 'InputToken' inp@.
data instance Instr (InstrReadable tok) repr inp vs a where
  -- | @'Read' fs p k@: reified 'read'.
  -- Continues with @k@, which receives an @'InputToken' inp@ on top of the stack;
  -- @p@ is a predicate on that token.
  -- NOTE(review): @fs@ presumably lists the failures reported
  -- when @p@ rejects the token; confirm.
  Read ::
    Set SomeFailure ->
    Splice (InputToken inp -> Bool) ->
    SomeInstr repr inp (InputToken inp ': vs) a ->
    Instr (InstrReadable tok) repr inp vs a
instance
  ( tok ~ InputToken inp, InstrReadable tok repr ) =>
  Derivable (Instr (InstrReadable tok) repr inp vs) where
  derive (Read failures accept next) = read failures accept (derive next)
instance
  ( Typeable tok, InstrReadable tok repr ) =>
  InstrReadable tok (SomeInstr repr) where
  -- 'Typeable' is required by the 'SomeInstr' constructor.
  read failures accept next = SomeInstr (Read failures accept next)

-- InstrIterable
data instance Instr InstrIterable repr inp vs a where
  -- | @'Iter' n op k@: reified 'iter'.
  -- @op@ starts with an empty value stack and is named by @n@;
  -- @k@ receives an @'InputPosition' inp@ on top of the stack.
  -- NOTE(review): presumably @op@ is the loop body and @k@ its exit; confirm.
  Iter ::
    LetName a ->
    SomeInstr repr inp '[] a ->
    SomeInstr repr inp (InputPosition inp ': vs) a ->
    Instr InstrIterable repr inp vs a
instance InstrIterable repr => Derivable (Instr InstrIterable repr inp vs) where
  derive (Iter iterName body next) = iter iterName (derive body) (derive next)
instance InstrIterable repr => InstrIterable (SomeInstr repr) where
  iter iterName body next = SomeInstr (Iter iterName body next)

-- InstrRegisterable
-- Instructions moving values between the stack and 'UnscopedRegister's;
-- each constructor reifies the same-named 'InstrRegisterable' method.
data instance Instr InstrRegisterable repr inp vs a where
  -- | @'NewRegister' r k@ consumes the @v@ on top of the stack,
  -- then continues with @k@.
  -- NOTE(review): presumably that @v@ becomes the initial content of @r@; confirm.
  NewRegister ::
    UnscopedRegister v ->
    SomeInstr repr inp vs a ->
    Instr InstrRegisterable repr inp (v : vs) a
  -- | @'ReadRegister' r k@ continues with @k@, which receives a @v@ on top of the stack.
  -- NOTE(review): presumably the current content of @r@; confirm.
  ReadRegister ::
    UnscopedRegister v ->
    SomeInstr repr inp (v : vs) a ->
    Instr InstrRegisterable repr inp vs a
  -- | @'WriteRegister' r k@ consumes the @v@ on top of the stack,
  -- then continues with @k@.
  -- NOTE(review): presumably that @v@ is stored into @r@; confirm.
  WriteRegister ::
    UnscopedRegister v ->
    SomeInstr repr inp vs a ->
    Instr InstrRegisterable repr inp (v : vs) a
instance InstrRegisterable repr => Derivable (Instr InstrRegisterable repr inp vs) where
  derive (NewRegister reg next) = newRegister reg (derive next)
  derive (ReadRegister reg next) = readRegister reg (derive next)
  derive (WriteRegister reg next) = writeRegister reg (derive next)
instance InstrRegisterable repr => InstrRegisterable (SomeInstr repr) where
  newRegister reg next = SomeInstr (NewRegister reg next)
  readRegister reg next = SomeInstr (ReadRegister reg next)
  writeRegister reg next = SomeInstr (WriteRegister reg next)