module Voting.Protocol.Arithmetic where

import Control.Monad (bind)
import Data.Argonaut.Core as JSON
import Data.Argonaut.Decode (class DecodeJson, decodeJson)
import Data.Argonaut.Encode (class EncodeJson, encodeJson)
import Data.Argonaut.Parser as JSON
import Data.BigInt (BigInt)
import Data.BigInt as BigInt
import Data.Boolean (otherwise)
import Data.Bounded (class Bounded, top)
import Data.Either (Either(..))
import Data.Eq (class Eq, (==), (/=))
import Data.EuclideanRing (class EuclideanRing, (/), mod)
import Data.Foldable (all)
import Data.Function (($), identity, (<<<))
import Data.Functor ((<$>))
import Data.HeytingAlgebra ((&&))
import Data.List (List, (:))
import Data.List.Lazy as LL
import Data.Maybe (Maybe(..), maybe)
import Data.Monoid (class Monoid, mempty, (<>))
import Data.Newtype (class Newtype, wrap, unwrap)
import Data.Ord (class Ord, (<), (<=))
import Data.Reflection (class Reifies, reflect)
import Data.Ring (class Ring, (-), negate)
import Data.Semiring (class Semiring, zero, (+), one, (*))
import Data.Show (class Show, show)
import Data.String.CodeUnits as String
import Effect.Exception.Unsafe (unsafeThrow)
import Type.Proxy (Proxy(..))

-- * Type 'Natural'
-- | A natural number backed by an arbitrary-precision 'BigInt'.
-- The representation itself does not forbid negative values; no 'Ring'
-- instance (hence no subtraction) is provided here, and the checked entry
-- point from 'Int' is 'nat'.
newtype Natural = Natural BigInt
derive instance newtypeNatural :: Newtype Natural _
-- All remaining instances are reused unchanged from 'BigInt'.
derive newtype instance semiringNatural      :: Semiring Natural
derive newtype instance euclideanRingNatural :: EuclideanRing Natural
derive newtype instance eqNatural            :: Eq Natural
derive newtype instance ordNatural           :: Ord Natural
derive newtype instance showNatural          :: Show Natural

-- * Class 'FromNatural'
-- | Types that can be built from a 'Natural'.
-- Instances may normalize the input (the instance for 'E', for example,
-- reduces it modulo 'groupOrder').
class FromNatural a where
  fromNatural :: Natural -> a

-- * Class 'ToNatural'
-- | Types whose values can be viewed as a 'Natural'.
class ToNatural a where
  nat :: a -> Natural
-- NOTE(review): despite its name this instance is for 'Natural' (not 'BigInt');
-- the name is kept because renaming changes the generated identifier.
instance toNaturalBigInt :: ToNatural Natural where
  nat n = n
-- | Checked conversion from 'Int': a negative input throws via 'unsafeThrow'.
instance toNaturalInt :: ToNatural Int where
  nat i
    | i < 0     = unsafeThrow "nat: given Int is negative"
    | otherwise = Natural (BigInt.fromInt i)


-- * Class 'Additive'
-- | An additive monoid: 'gadd' combines two values and 'gzero' is meant to
-- be its identity element. This is more than a semigroup, because of 'gzero'.
-- The instances here reuse 'zero' and '(+)' from 'Semiring'.
class Additive a where
  gzero :: a
  gadd  :: a -> a -> a
instance additiveBigInt :: Additive BigInt where
  gadd x y = x + y
  gzero = zero
instance additiveNatural :: Additive Natural where
  gadd x y = x + y
  gzero = zero

-- | `('power' b e)` raises the base `b` to the exponent held by `e`, using
-- square-and-multiply, so the recursion depth is logarithmic in the exponent.
-- The result is reduced only if the multiplication of the 'Semiring' `a`
-- itself reduces; this function performs no modular reduction of its own.
power :: forall crypto c a. Semiring a => a -> E crypto c -> a
power base (E expo) = loop expo
  where
  two :: Natural
  two = one + one
  loop :: Natural -> a
  loop k
    | k == zero = one
    | k == one  = base
    | otherwise =
        let half = loop (k / two)
            sq = half * half
        in if k `mod` two == zero then sq else sq * base
infixr 8 power as ^

-- * Class 'CryptoParams'
-- | Parameters of the cyclic subgroup the arithmetic works in: a generator
-- ('groupGen') and the subgroup's order ('groupOrder').
-- The superclasses require the element type 'G crypto c' to support
-- (Euclidean) ring arithmetic, conversion from/to 'Natural', equality,
-- ordering, 'Show' and JSON decoding/encoding, and require the type @c@ to
-- reify the @crypto@ parameters.
class
 ( EuclideanRing (G crypto c)
 , FromNatural   (G crypto c)
 , ToNatural     (G crypto c)
 , Eq            (G crypto c)
 , Ord           (G crypto c)
 , Show          (G crypto c)
 , DecodeJson    (G crypto c)
 , EncodeJson    (G crypto c)
 , Reifies c crypto
 ) <= CryptoParams crypto c where
  -- | A generator of the subgroup.
  groupGen   :: G crypto c
  -- | The order of the subgroup; 'E' arithmetic is reduced modulo this value.
  -- The 'Proxy' arguments only select the instance, since @crypto@ and @c@
  -- appear nowhere else in this member's type.
  groupOrder :: Proxy crypto -> Proxy c -> Natural
  
-- | 'groupGenPowers' returns the infinite lazy list
-- of powers of 'groupGen': @['groupGen' '^' i | i <- [0..]]@,
-- each computed from the previous one by a single multiplication.
--
-- Built with 'LL.iterate' so that the tail is produced only on demand.
-- PureScript evaluates arguments strictly, so a direct recursion such as
-- @go g = g LL.: go (g * groupGen)@ would evaluate the whole infinite tail
-- before building the first cell, and overflow the stack.
groupGenPowers :: forall crypto c. CryptoParams crypto c => LL.List (G crypto c)
groupGenPowers = LL.iterate (_ * groupGen) one

-- | 'groupGenInverses' returns the infinite lazy list
-- of 'inverse' powers of 'groupGen':
-- @['groupGen' '^' 'negate' i | i <- [0..]]@,
-- but by computing each value from the previous one.
--
-- Used by 'intervalDisjunctions'.
--
-- Built with 'LL.iterate' so that the tail is produced only on demand.
-- A direct recursion through 'LL.cons' would be evaluated eagerly in
-- PureScript and never yield a list. The inverse of the generator is
-- computed once and shared by every step.
groupGenInverses :: forall crypto c. CryptoParams crypto c => LL.List (G crypto c)
groupGenInverses = LL.iterate (_ * invGen) one
  where
  invGen :: G crypto c
  invGen = inverse groupGen

-- | @'inverse' a@ is @'one' / a@ in the given 'EuclideanRing'.
-- Where '(/)' is true field division, this is the multiplicative inverse;
-- for other Euclidean rings (e.g. 'Int') it is just the Euclidean quotient
-- of 'one' by @a@. Behaviour for @a = 'zero'@ is that of the type's '(/)'.
inverse :: forall a. EuclideanRing a => a -> a
inverse = (one / _)

-- ** Class 'ReifyCrypto'
-- | Cryptographic parameter values that can be lifted to a type-level
-- witness @c@, usable with the constrained API of this module.
class ReifyCrypto crypto where
  -- | Like 'reify' but augmented with the 'CryptoParams' constraint.
  -- The continuation must work for any type @c@ for which both
  -- @'Reifies' c crypto@ and @'CryptoParams' crypto c@ hold, and receives a
  -- 'Proxy' for it; the overall result has the continuation's type @r@.
  reifyCrypto :: forall r. crypto -> (forall c. Reifies c crypto => CryptoParams crypto c => Proxy c -> r) -> r

-- ** Type 'G'
-- | The type of the elements of a subgroup of a field.
-- The value is stored as a 'Natural'; @crypto@ and @c@ are phantom type
-- parameters, not used by the constructor, that tie a value to one set of
-- parameters.
-- NOTE(review): the instances required by 'CryptoParams' (arithmetic, JSON,
-- etc.) are not declared in this section; presumably defined elsewhere.
newtype G crypto c = G Natural

-- ** Type 'E'
-- | An exponent of a (cyclic) subgroup of a field.
-- Intended invariant: the value is always in @[0..'groupOrder'-1]@; the
-- 'Semiring', 'Ring' and 'FromNatural' instances are responsible for
-- keeping it there.
newtype E crypto c = E Natural
 -- deriving (Eq,Ord,Show)
 -- deriving newtype NFData
derive instance newtypeE :: Newtype (E crypto c) _
derive newtype instance eqE   :: Eq   (E crypto c)
derive newtype instance ordE  :: Ord  (E crypto c)
derive newtype instance showE :: Show (E crypto c)
-- | The 'Additive' structure of exponents, taken from their 'Semiring'
-- instance (addition modulo 'groupOrder').
instance additiveE :: CryptoParams crypto c => Additive (E crypto c) where
  gadd x y = x + y
  gzero = zero
-- | Exponent arithmetic: sums and products are reduced modulo 'groupOrder'.
instance semiringE :: CryptoParams crypto c => Semiring (E crypto c) where
  zero = E zero
  one = E one
  add (E a) (E b) = E ((a + b) `mod` q)
    where q = groupOrder (Proxy :: Proxy crypto) (Proxy :: Proxy c)
  mul (E a) (E b) = E ((a * b) `mod` q)
    where q = groupOrder (Proxy :: Proxy crypto) (Proxy :: Proxy c)
-- | Subtraction of exponents modulo 'groupOrder' (written @q@ below).
-- Computed as @(x + (q - y)) mod q@:
-- * the inner 'BigInt' subtraction stays non-negative because @y < q@ by
--   the 'E' invariant;
-- * the final reduction keeps the result in @[0..q-1]@. Without it,
--   e.g. @E x - E zero@ would be @E (x + q)@ and @negate (E zero)@ would be
--   @E q@.
instance ringE :: CryptoParams crypto c => Ring (E crypto c) where
  sub (E x) (E y) = E ((x + wrap (unwrap q - unwrap y)) `mod` q)
    where q = groupOrder (Proxy :: Proxy crypto) (Proxy :: Proxy c)
-- | Builds an exponent from any 'Natural' by reducing it modulo 'groupOrder'.
instance fromNaturalE :: CryptoParams crypto c => FromNatural (E crypto c) where
  fromNatural n = E (n `mod` order)
    where order = groupOrder (Proxy :: Proxy crypto) (Proxy :: Proxy c)
-- | The 'Natural' stored inside the exponent, returned without any reduction.
instance toNaturalE :: ToNatural (E crypto c) where
  nat = unwrap
-- | Exponents range from 'zero' up to @'groupOrder' - 1@.
instance boundedE :: CryptoParams crypto c => Bounded (E crypto c) where
  top =
    let order = unwrap (groupOrder (Proxy :: Proxy crypto) (Proxy :: Proxy c))
    in E (Natural (order - one))
  bottom = E zero
{-
instance enumE :: Reifies c crypto => Enum (E crypto c) where
  succ z = let z' = z + one in if z' > z then Just z' else Nothing
  pred z = let z' = z - one in if z' < z then Just z' else Nothing
instance boundedEnumE :: Reifies c crypto => BoundedEnum (E crypto c) where
  cardinality = Cardinality (toInt (undefined :: m) - 1)
  toEnum x = let z = mkE x in if runE z == x then Just z else Nothing
  fromEnum = runE
-}
-- | Encodes an exponent as a JSON string of its decimal digits (e.g. @"123"@),
-- matching the digits-only format parsed by the 'DecodeJson' instance.
-- 'BigInt.toString' is used instead of 'show': the 'Show' instance of
-- 'BigInt' (and thus of 'Natural') renders @fromString "123"@, which is not
-- a bare number and cannot be decoded back.
instance encodeJsonE :: EncodeJson (E crypto c) where
  encodeJson (E n) = encodeJson (BigInt.toString (unwrap n))
-- | Decodes an exponent from a JSON string holding its canonical decimal form.
-- Fails with @"String"@ when the JSON value is not a string, and with
-- @"Exponent"@ when the string:
-- * is empty,
-- * contains anything other than ASCII digits,
-- * has a leading zero (the single string @"0"@ excepted), or
-- * denotes a value not below 'groupOrder'.
instance decodeJsonE :: CryptoParams crypto c => DecodeJson (E crypto c) where
  decodeJson = JSON.caseJsonString (Left "String") $ \s ->
    maybe (Left "Exponent") Right $ do
      {head:c0} <- String.uncons s
      -- Reject leading zeros so every exponent has a single encoding, but
      -- accept "0" itself so that the zero exponent, a valid 'E', decodes.
      let canonical = if c0 == '0' then s == "0" else true
      if canonical && all isDigit (String.toCharArray s)
      then do
        n <- Natural <$> BigInt.fromString s
        if n < groupOrder (Proxy::Proxy crypto) (Proxy::Proxy c)
        then Just (E n)
        else Nothing
      else Nothing

-- | Whether a 'Char' is an ASCII decimal digit, '0' through '9'.
-- These ten code units are contiguous, so a range check suffices; other
-- Unicode digits are not accepted.
isDigit :: Char -> Boolean
isDigit ch = '0' <= ch && ch <= '9'