{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyDataDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# OPTIONS_GHC -Wno-missing-signatures #-}
{-# OPTIONS_GHC -Wno-orphans #-}

-- {-# OPTIONS_GHC -Wno-monomorphism-restriction #-}

module Symantic.Syntaxes where

import Control.Arrow (left)
import Data.Either (Either (..))
import Data.Eq (Eq (..))
import Data.Function (fix, ($), (.))
import Data.Int (Int)
import Data.Kind (Type)
import Data.Maybe (Maybe (..))
import Data.Monoid (Monoid (..))
import Data.Void (Void)
import Symantic.Semantics.Forall (Forall (..), PerSyntax, PerSyntaxable (..))
import Symantic.Syntaxes.Classes (Instantiable (..), Syntaxes, Unabstractable)
import Symantic.Syntaxes.Classes qualified as Sym
import Text.Show (Show (..))
import Type.Reflection (eqTypeRep, typeRep, (:~~:) (..))
import Prelude (error)

import Symantic
import Symantic.Typer ()

-- | Parse a 'TermAST' into a type-checked closed term (a 'TermVT' in the
-- empty context @'[]@) over the syntaxes @finalSyns@, which are the
-- user-supplied @syns@ prefixed with 'Instantiable' and 'Unabstractable'.
--
-- Leaves ('BinTree0') are handed to the 'Parsers' chain of @finalSyns@;
-- binary nodes ('BinTree2') are parsed here as function applications.
-- Parsing starts in the empty type context 'CtxTyZ'.
--
-- On failure, returns the error of the syntax that reported it, wrapped in
-- a 'PerSyntax'.
parse ::
  forall syns prov finalSyns.
  finalSyns ~ (Instantiable ': Unabstractable ': syns) =>
  (forall k. ProvenanceKindOf (Ty @k) prov) =>
  (forall k. ProvenanceKindOf (Var @k) prov) =>
  Show prov =>
  Monoid prov =>
  Syntaxes finalSyns (Forall finalSyns) =>
  Parsers finalSyns finalSyns prov =>
  TermAST prov ->
  Either (PerSyntax finalSyns (ErrorParser prov)) (TermVT finalSyns prov '[])
parse ast = unParser (fix openParser ast) CtxTyZ
  where
    -- Open recursion tied with 'fix': @final@ is 'openParser' itself, used
    -- to parse sub-trees.
    -- A leaf token is dispatched to the 'Parsers' chain of @finalSyns@.
    openParser final (BinTree0 e) = parsers @finalSyns @finalSyns @prov final e
    -- Instantiable and Unabstractable are always required.
    -- A binary node is an application: parse the function @fT@ and the
    -- argument @xT@ in the same context, then check that the function's
    -- type is @fxTy -> rTy@ with @fxTy@ unifiable with the argument's type.
    openParser final (BinTree2 fT xT) = Parser $ \ctx -> do
      TermVT fTy (f :: OpenTerm finalSyns fVS a2b) <- unParser (final fT) ctx
      TermVT xTy (a :: OpenTerm finalSyns aVS a) <- unParser (final xT) ctx
      case fTy of
        -- Match the type @(->) fxTy rTy@, skipping provenance wrappers.
        SkipTyProv
          TyApp
            { tyAppFun =
              SkipTyProv
                TyApp
                  { tyAppFun = TyConst{tyConst = eqTypeRep (typeRep @(->)) -> Just HRefl}
                  , tyAppArg = fxTy
                  }
            , tyAppArg = rTy
            } -> do
            -- Presumably puts the type variables of both types in a shared
            -- list so they can be unified together — confirm against 'appendVars'.
            let (fxTy', xTy') = appendVars fxTy xTy
            mgu <- (perSyntax . ErrorParserInstantiableUnify) `left` unifyType mempty fxTy' xTy'
            let fxTy'' = subst mgu fxTy'
            let xTy'' = subst mgu xTy'
            -- After substitution, the parameter and argument types must be equal.
            case eqTy fxTy'' xTy'' of
              Nothing -> Left $ perSyntax $ ErrorParserInstantiableArgumentMismatch (TyVT fxTy'') (TyVT xTy'')
              Just HRefl ->
                -- NOTE(review): 'mgu' is applied to the parameter and argument
                -- types only; @rTy@ goes through neither 'appendVars' nor
                -- @subst mgu@, so a result type sharing variables with the
                -- parameter (e.g. @id@ applied to an 'Int' literal) may stay
                -- more general than the unified type — confirm.
                normalizeVarsTy rTy $ \case
                  -- TyApp (TyApp _c qr') tr' ->
                  rTy' ->
                    Right $ TermVT rTy' (f .@ a)
        -- Any other type cannot be applied to an argument.
        _ -> Left $ perSyntax $ ErrorParserInstantiableNotAFunction (TyVT fTy)
-- | Errors reported by 'parse' when parsing an application ('BinTree2').
data instance ErrorParser prov Instantiable
  = ErrorParserInstantiableArgumentMismatch (TyVT prov) (TyVT prov)
  -- ^ After unification, the function's parameter type (first) and the
  -- argument's type (second) are still not equal.
  | ErrorParserInstantiableNotAFunction (TyVT prov)
  -- ^ The applied term's type is not a function type.
  | ErrorParserInstantiableUnify (ErrorUnify prov)
  -- ^ Unifying the parameter type with the argument type failed.
  deriving (Show)

-- * Class 'Parsers'
-- | Chains the 'TokenParser's of every syntax in the type-level list
-- @syns@, in list order. @finalSyns@ is the full list of syntaxes of the
-- resulting terms, which stays fixed while @syns@ is recursed on.
class Parsers syns finalSyns prov where
  -- | Parse one token, given the parser for whole trees.
  parsers :: OpenParser finalSyns prov
-- End of the chain: no syntax accepted the token, so it is reported as
-- invalid through the 'Unabstractable' error slot.
instance PerSyntaxable finalSyns Unabstractable => Parsers '[] finalSyns prov where
  parsers _final tok = Parser $ \_ctx ->
    Left $ perSyntax $ ErrorParserUnabstractableInvalid tok
-- | Errors attributed to 'Unabstractable'.
data instance ErrorParser prov Unabstractable
  = ErrorParserUnabstractableInvalid (TokenTerm prov)
  -- ^ No syntax of the 'Parsers' chain accepted this token.
  deriving (Show)
-- Cons case: try @syn@'s 'TokenParser' first, giving it the 'Parsers' of
-- the remaining @syns@ as its @next@ fallback.
instance
  ( TokenParser syn finalSyns prov
  , Parsers syns finalSyns prov
  ) =>
  Parsers (syn ': syns) finalSyns prov
  where
  parsers final = tokenParser @syn @finalSyns (parsers @syns @finalSyns final) final

-- ** Type 'OpenParser'
-- | A token parser in open-recursive style. It receives the @final@ parser
-- for whole 'TermAST's (used to parse sub-trees, such as an abstraction's
-- body) and the 'TokenTerm' to parse.
type OpenParser syns prov =
  Show prov =>
  Monoid prov =>
  {-final-} (TermAST prov -> Parser syns prov) ->
  TokenTerm prov ->
  Parser syns prov

-- ** Class 'TokenParser'
-- | Parser for the tokens of a single syntax @syn@. The instances in this
-- module pass any token they do not handle on to @next@.
class TokenParser syn finalSyns prov where
  tokenParser ::
    {-next-} (TokenTerm prov -> Parser finalSyns prov) ->
    OpenParser finalSyns prov
-- 'Instantiable' has no token of its own: applications are 'BinTree2'
-- nodes handled directly by 'parse', so every token is passed on to @next@.
instance
  ( Syntaxes finalSyns (Forall finalSyns)
  ) =>
  TokenParser Instantiable finalSyns prov
  where
  tokenParser next _final = next
-- | Errors attributed to 'AbstractableTy'.
data instance ErrorParser prov (AbstractableTy (Ty prov '[]))
  = ErrorParserAbstractableUnknown Name
  -- ^ A variable name that is not bound in the current context.
  | ErrorParserAbstractableNotAType (TyVT prov)
  -- ^ An abstraction's argument type is not of kind 'Type'.
  deriving (Show)
-- Parses variables ('TokenTermVar') and typed abstractions
-- ('TokenTermAbst'); any other token is passed on to @next@.
instance
  ( Syntaxes finalSyns (Forall finalSyns)
  , forall sem. Syntaxes finalSyns sem => Unabstractable sem
  , forall sem. Syntaxes finalSyns sem => Instantiable sem
  , PerSyntaxable finalSyns (AbstractableTy (Ty prov '[]))
  , forall k. ProvenanceKindOf (Ty @k) prov
  ) =>
  TokenParser (AbstractableTy (Ty prov '[])) finalSyns prov
  where
  -- A variable is looked up in the type context, innermost binding first:
  -- it becomes 'V' when it is the innermost variable, wrapped in one 'W'
  -- per enclosing binding skipped.
  tokenParser _next _final (TokenTermVar varName) = Parser go
    where
      go :: forall vs. CtxTy prov vs Void -> Either (PerSyntax finalSyns (ErrorParser prov)) (TermVT finalSyns prov vs)
      -- The whole context was searched: the variable is unbound.
      go CtxTyZ = Left $ perSyntax $ ErrorParserAbstractableUnknown varName
      -- Introduce 'V': the innermost binding has the wanted name.
      go (CtxTyS n ty _) | n == varName = Right $ TermVT ty V
      -- Introduce 'W': look further out, then weaken the result.
      go (CtxTyS _n _typ tys) = do
        TermVT ty ot <- go tys
        Right $ TermVT ty (W ot)
  -- An abstraction: check that the argument type has kind 'Type', parse the
  -- body with the argument bound in the context, then eliminate the bound
  -- variable from the body's term.
  tokenParser _next final (TokenTermAbst argName (TyT (argTy :: Ty prov '[] a)) bT) =
    Parser $ \ctx ->
      case eqTy (monoTy @Type @prov) (kindOf argTy) of
        Nothing -> Left $ perSyntax $ ErrorParserAbstractableNotAType (TyVT argTy)
        Just HRefl -> do
          TermVT resTy (res :: OpenTerm syn resVS res) <-
            unParser (final bT) (CtxTyS argName argTy ctx)
          -- Presumably re-indexes the closed @argTy@ over the type variables
          -- of @resTy@ so both sides of @~>@ agree — confirm against 'allocVarsR'.
          let argTy' = allocVarsR (lenVar resTy) argTy
          Right $
            TermVT (argTy' ~> resTy) $
              -- This appears to follow the E/V/N/W bracket abstraction of
              -- Kiselyov's "λ to SKI, semantically".
              case res of
                -- Closed body ('E'), not using the variable: @const d@.
                E d -> E (Sym.const .@ d)
                -- The body is the bound variable itself: @id@.
                V -> E Sym.id
                -- Weakened body ('W'), not using the bound variable: @const@ applied to it.
                W e -> E Sym.const `appOpenTerm` e
                -- 'N' presumably marks a body already abstracted over the
                -- bound variable; it is returned unchanged.
                N e -> e
  -- Not a variable nor an abstraction.
  tokenParser next _final e = next e
-- Parses integer literals (atoms that 'safeRead' reads as an 'Int') and the
-- atoms @neg@ and @add@; any other token is passed on to @next@.
instance
  ( Syntaxes finalSyns (Forall finalSyns)
  , forall sem. Syntaxes finalSyns sem => Addable sem
  ) =>
  TokenParser Addable finalSyns prov
  where
  tokenParser next _final = \case
    -- Integer literal: a closed ('E') 'lit' term whose 'Int' type is built
    -- with the current context's length (presumably to match its type
    -- variables — confirm against 'tyOfTypeRep').
    TokenTermAtom s
      | Right (i :: Int) <- safeRead s ->
          Parser $ \ctx ->
            Right
              $ TermVT
                (tyOfTypeRep (lenVar ctx) (typeRep @Int))
              $ E
              $ Forall @_ @Int
              $ lit i
    -- neg :: Int -> Int
    TokenTermAtom "neg" -> Parser $ \_ctx ->
      Right $ TermVT (monoTy @Int ~> monoTy @Int) $ E neg
    -- add :: Int -> Int -> Int
    TokenTermAtom "add" -> Parser $ \_ctx ->
      Right $ TermVT (monoTy @Int ~> monoTy @Int ~> monoTy @Int) $ E add
    -- Not an 'Addable' token.
    ast -> next ast
-- Parses the combinator atoms of 'Unabstractable', each as a closed ('E')
-- term with its polymorphic type; any other token is passed on to @next@.
instance
  ( Syntaxes syns (Forall syns)
  , forall sem. Syntaxes syns sem => Unabstractable sem
  ) =>
  TokenParser Unabstractable syns prov
  where
  tokenParser next _final = \case
    -- id :: a -> a
    TokenTermAtom "id" -> Parser $ \_ctx ->
      Right $ TermVT (aTy @'[] ~> aTy) $ E Sym.id
    -- (<*>) :: (a -> b -> c) -> (a -> b) -> a -> c
    TokenTermAtom "(<*>)" -> Parser $ \_ctx ->
      Right $
        TermVT ((aTy ~> bTy ~> cTy @'[]) ~> (aTy ~> bTy) ~> aTy ~> cTy) $
          E Sym.ap
    -- const :: a -> b -> a
    TokenTermAtom "const" -> Parser $ \_ctx ->
      Right $ TermVT (aTy ~> bTy @'[] ~> aTy) $ E Sym.const
    -- (.) :: (b -> c) -> (a -> b) -> a -> c
    TokenTermAtom "(.)" -> Parser $ \_ctx ->
      Right $ TermVT ((bTy ~> cTy @'[]) ~> (aTy ~> bTy) ~> aTy ~> cTy) $ E (Sym..)
    -- flip :: (a -> b -> c) -> b -> a -> c
    TokenTermAtom "flip" -> Parser $ \_ctx ->
      Right $ TermVT ((aTy ~> bTy ~> cTy @'[]) ~> bTy ~> aTy ~> cTy) $ E Sym.flip
    -- ($) :: (a -> b) -> a -> b
    TokenTermAtom "($)" -> Parser $ \_ctx ->
      Right $ TermVT ((aTy ~> bTy @'[]) ~> aTy ~> bTy) $ E (Sym.$)
    -- Not an 'Unabstractable' token.
    tok -> next tok

-- * Class 'Addable'
-- | Syntax for 'Int' literals and two 'Int' operations; their meaning is
-- given by each semantics @sem@.
class Addable sem where
  -- | An 'Int' literal.
  lit :: Int -> sem Int
  -- | Unary 'Int' operation (presumably negation, judging by its name).
  neg :: sem (Int -> Int)
  -- | Binary 'Int' operation (presumably addition, judging by its name).
  add :: sem (Int -> Int -> Int)
-- | 'Addable' defines no parsing errors of its own (no constructors).
data instance ErrorParser prov Addable
  deriving (Show)

-- * Class 'Mulable'
-- | Syntax for a binary 'Int' operation, presumably multiplication (its
-- meaning is given by each semantics @sem@). It is not part of
-- 'FinalSyntaxes'.
class Mulable sem where
  mul :: sem Int -> sem Int -> sem Int
-- Lifts 'Addable' to 'Forall': each method is built polymorphically in any
-- semantics @sem@ supporting the syntaxes @syn@.
instance (forall sem. Syntaxes syn sem => Addable sem) => Addable (Forall syn) where
  lit n = Forall (lit n)
  neg = Forall neg
  add = Forall add
-- Lifts 'Mulable' to 'Forall' by unwrapping both operands and rewrapping
-- the result.
instance (forall sem. Syntaxes syn sem => Mulable sem) => Mulable (Forall syn) where
  mul (Forall a) (Forall b) = Forall (mul a b)

--
-- Parsing
--

-- tree0Parser :: (forall sem. syn sem => Addable sem, syn (Forall syn)) => Either ErrMsg (Forall syn Int)
-- tree0Parser = unParser $ parse tree0Print

-- parseMulable ::
--  forall syn prov.
--  ( forall sem. syn sem => Mulable sem
--  , forall sem. syn sem => Addable sem
--  , syn (Forall syn)
--  ) =>
--  Monoid prov =>
--  Show prov =>
--  (TermAST prov -> Parser syn prov) ->
--  TermAST prov -> Parser syn prov
-- parseMulable final (BinTree2 (BinTree2 (BinTree0 (TokenTermAtom "mul")) aT) bT) =
--  Parser $ \ctx -> do
--    case (final aT, final bT) of
--      (Parser aE, Parser bE) -> do
--        TermVT aTy (Forall a :: Forall syn a) <- aE ctx
--        case eqTy aTy TyConst{tyConst = Const{constType = typeRep @Int, constMeta = mempty::prov}, tyConstLen = lenVar aTy} of
--          Nothing -> Left "TypeError: Mulable 1"
--          Just HRefl -> do
--            TermVT bTy (Forall b :: Forall syn b) <- bE ctx
--            case eqTy bTy TyConst{tyConst = Const{constType = typeRep @Int, constMeta = mempty::prov}, tyConstLen = lenVar bTy} of
--              Nothing -> Left "TypeError: Mulable 2"
--              Just HRefl -> do
--                Right $ TermVT aTy (Forall @_ @a $ mul a b)
-- parseMulable final t = parseAddable final t

--
-- Printing
--

-- | Print a syntax element with no arguments as a single atom leaf named @name@.
printConst name = Printer $ BinTree0 $ TokenTermAtom name

-- Each combinator prints as an atom carrying the very name that the
-- 'Unabstractable' 'TokenParser' recognizes, so printing then parsing
-- gives back the same combinator.
instance Sym.Unabstractable (Printer prov) where
  id = printConst "id"
  ap = printConst "(<*>)"
  const = printConst "const"
  (.) = printConst "(.)"
  flip = printConst "flip"
  ($) = printConst "($)"
-- A literal prints as the atom of its 'show'n value; the operations print
-- as the atoms recognized by the 'Addable' 'TokenParser'.
instance Addable (Printer prov) where
  add = printConst "add"
  neg = printConst "neg"
  lit = printConst . show
-- Prints a multiplication with 'print2' under the name "mul" (presumably
-- applying that atom to both operand trees — confirm against 'print2').
-- There is no matching 'TokenParser' for 'Mulable' in this module.
instance Mulable (Printer prov) where
  mul = print2 "mul"
-- An application prints as a binary node: the function's tree on the left,
-- the argument's tree on the right ('parse' reads 'BinTree2' back as an
-- application).
instance Instantiable (Printer prov) where
  (.@) (Printer fTree) (Printer aTree) = Printer (BinTree2 fTree aTree)
-- Prints an abstraction as a 'TokenTermAbst' leaf binding the name "x" with
-- argument type @xTy@; its body is obtained by applying @f@ to a
-- 'TokenTermVar' reference to "x".
--
-- NOTE(review): every abstraction binds the same name "x". The parser
-- resolves a variable to its innermost binding, so a nested abstraction
-- whose body refers to an outer variable (e.g. \x -> \y -> x) prints to a
-- tree that re-parses as \x -> \y -> y (variable capture). Round-tripping
-- such terms needs fresh names (e.g. derived from the nesting depth).
instance (Monoid prov) => AbstractableTy (Ty prov '[]) (Printer prov) where
  lamTy xTy (f :: Printer prov a -> Printer prov b) =
    Printer $
      BinTree0
        ( TokenTermAbst
            "x"
            (TyT xTy)
            (unPrinter (f (Printer (BinTree0 (TokenTermVar "x")))))
        )

-- * Type 'FinalSyntaxes'
-- | The syntaxes used to parse the examples below; 'parse' puts
-- 'Instantiable' and 'Unabstractable' in front of them.
type FinalSyntaxes prov = '[AbstractableTy (Ty prov '[]), Addable]

-- tree0 = lit 0
-- Example terms, written once against the syntax classes and usable with
-- any semantics providing them.
tree0 = add .@ lit 8 .@ (neg .@ (add .@ lit 1 .@ lit 2))

-- Disabled: it applies 'add' directly instead of with '.@', and 'Mulable'
-- is not part of 'FinalSyntaxes'.
-- tree1 = mul (add (lit 8) (neg (add (lit 1) (lit 2)))) (lit 2)
-- Abstractions; @\@()@ presumably fixes the provenance type to @()@.
tree2 = fun @() (\x -> add .@ x .@ lit 0)
tree3 = fun @() (\x -> add .@ x .@ x)
-- @(<*>) add id@: presumably @\x -> add x x@ in point-free form.
tree4 = Sym.ap .@ add .@ Sym.id
-- The examples rendered to 'TermAST's by the 'Printer' semantics.
tree0Print = print tree0

-- tree1Print = print tree1
tree2Print = print tree2
tree3Print = print tree3
tree4Print = print tree4

-- | Parse a printed term with the 'FinalSyntaxes' (provenance @()@) and
-- print the parsed term back. Calls 'error' with the 'show'n parse errors
-- when parsing fails. Shared by the @tree*ParsePrint@ examples.
roundTripFinal :: TermAST () -> TermAST ()
roundTripFinal ast = case parse @(FinalSyntaxes ()) @() ast of
  Left e -> error $ show e
  Right (TermVT _ty (unE -> Forall sem)) -> print sem

tree0ParsePrint :: TermAST ()
tree0ParsePrint = roundTripFinal tree0Print

-- tree1ParsePrint :: TermAST ()
-- tree1ParsePrint = roundTripFinal tree1Print

tree2ParsePrint :: TermAST ()
tree2ParsePrint = roundTripFinal tree2Print

tree3ParsePrint :: TermAST ()
tree3ParsePrint = roundTripFinal tree3Print

-- | A second round trip, starting from the already round-tripped
-- 'tree3ParsePrint'.
tree3ParsePrint' :: TermAST ()
tree3ParsePrint' = roundTripFinal tree3ParsePrint

tree4ParsePrint :: TermAST ()
tree4ParsePrint = roundTripFinal tree4Print