-- For `ProbabilityScale`
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE InstanceSigs #-}

module Numeric.Probability (
  Probability,
  ProbabilityScale,
  ProbabilityBounded (..),
  probability,
  safeProbability,
  runProbability,
  assertProbability,
  proba0,
  proba1,
)
where

import Control.Monad (Monad (..))
import Data.Bool (Bool)
import Data.Eq (Eq (..))
import Data.Function (id, on, (.))
import Data.Maybe (Maybe (..), fromJust)
import Data.Monoid (Monoid (..))
import Data.Ord (Ord (..), Ordering)
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..))
import Data.Validity (Validity (..), declare)
import Data.Word (Word64)
import GHC.Generics (Generic)
import GHC.Real (RealFrac (..))
import GHC.Stack (HasCallStack)
import GHC.TypeNats (Natural, natVal)
import Logic
import Logic.Theory.Bool (type (&&))
import Logic.Theory.Ord (type (<=))
import Numeric.Decimal (Decimal (..), MonadThrow (..))
import Numeric.Decimal qualified as Decimal
import System.Random (Random)
import Text.Show (Show (show))
import Prelude (Bounded (..), Enum, Fractional (..), Integral, Num (..), Rational, Real (..), error, (^))

-- | A probability: a fixed-point decimal with 'ProbabilityScale' digits after
-- the decimal point, rounded half-to-even, whose raw value is stored in a
-- 'ProbabilityBounded'.
type Probability = Decimal Decimal.RoundHalfEven ProbabilityScale ProbabilityBounded
-- | A 'Probability' is valid exactly when the 'ProbabilityBounded' it wraps
-- is valid.
instance Validity Probability where
  validate decimal = case decimal of
    Decimal bounded -> validate bounded

-- | Convert a 'Rational' into a 'Probability', rounding half-to-even to
-- 'ProbabilityScale' decimal digits.
--
-- Failure is reported through 'MonadThrow' (e.g. 'Nothing' when @m@ is
-- 'Maybe'). Presumably this happens when the rounded value falls outside the
-- bounds of 'ProbabilityBounded'; see @fromRationalDecimalBoundedWithRounding@.
probability :: MonadThrow m => Rational -> m Probability
probability rational = Decimal.fromRationalDecimalBoundedWithRounding rational
{-# INLINE probability #-}

-- | The probability 0.
proba0 :: Probability
proba0 = fromJust (probability 0)

-- | The probability 1.
proba1 :: Probability
proba1 = fromJust (probability 1)

-- | Build a 'Probability' from a 'Rational' that comes with a type-level proof
-- that it lies in the closed interval @[0, 1]@.
--
-- The proof exists only at the type level: the body ignores it and converts
-- the bare value. If a caller supplies a forged proof for an out-of-range
-- value, the conversion fails with the 'assertProbability' error message
-- rather than an anonymous 'fromJust' failure.
safeProbability :: r ::: Rational / 0 <= r && r <= 1 -> Probability
safeProbability (Named r) = assertProbability r

-- | Convert a 'Probability' back into the 'Rational' it represents.
runProbability :: Probability -> Rational
runProbability p = Decimal.toRationalDecimal p
{-# INLINE runProbability #-}

-- | Like 'probability', but calls 'error' (reporting the offending rational
-- and the caller's call stack) when the conversion fails.
assertProbability :: HasCallStack => Rational -> Probability
assertProbability rational =
  case probability rational of
    Nothing -> error message
    Just p -> p
  where
    message = "assertProbability: " <> show rational

-- | Checked arithmetic on probabilities inside 'Decimal.Arith'.
--
-- Each binary operator sequences both operands with 'Decimal.bindM2' and then
-- applies the matching bounded operation from "Numeric.Decimal". A failure in
-- either operand therefore propagates. A failure in the operation itself also
-- propagates, presumably when the result leaves the bounds of
-- 'ProbabilityBounded'.
--
-- NOTE(review): 'negate' is not defined here, so it falls back to the class
-- default (@0 - x@) and goes through the checked '(-)'.
instance Num (Decimal.Arith Probability) where
  x + y = Decimal.bindM2 Decimal.plusDecimalBounded x y
  x - y = Decimal.bindM2 Decimal.minusDecimalBounded x y
  x * y = Decimal.bindM2 Decimal.timesDecimalBoundedWithRounding x y
  -- The stored word is unsigned (minBound is 0), so 'abs' is the identity.
  abs x = x
  signum x = x >>= Decimal.signumDecimalBounded
  fromInteger n = Decimal.fromIntegerDecimalBoundedIntegral n

-- | Checked division on probabilities inside 'Decimal.Arith'. Rational
-- literals are converted with 'probability', so out-of-range literals fail
-- inside 'Decimal.Arith'.
instance Fractional (Decimal.Arith Probability) where
  fromRational r = probability r
  x / y = Decimal.bindM2 Decimal.divideDecimalBoundedWithRounding x y

{- Disabled: HasCallStack does not work well with these instances.

instance Eq (Decimal.Arith Probability) where
  (==) :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Bool
  (==) = (==) `on` Decimal.arithError

instance Ord (Decimal.Arith Probability) where
  compare :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Ordering
  compare = compare `on` Decimal.arithError

instance Real (Decimal.Arith Probability) where
  toRational = Decimal.toRationalDecimal . Decimal.arithError

instance RealFrac (Decimal.Arith Probability) where
  properFraction p = (n, return (assertProbability f))
    where
    (n,f) = properFraction (Decimal.toRationalDecimal (Decimal.arithError p))
-}

-- | Number of decimal digits after the decimal point stored in a
-- 'Probability'.
--
-- 19 is the largest scale for which @10 ^ ProbabilityScale@ (the upper bound
-- of 'ProbabilityBounded', which encodes the probability 1) still fits in a
-- 'Word64':
--
-- >>> 10^19 <= toRational (maxBound :: Word64)
-- True
-- >>> 10^20 <= toRational (maxBound :: Word64)
-- False
type ProbabilityScale = 19

-- | Raw storage for a 'Probability': a 'Word64' whose intended range is
-- @[0, 10 ^ ProbabilityScale]@, as stated by its 'Bounded' and 'Validity'
-- instances.
--
-- NOTE(review): the numeric classes are presumably derived via
-- GeneralizedNewtypeDeriving (stock deriving cannot produce 'Num' or 'Enum'
-- for this newtype). If so, raw arithmetic on this type is plain 'Word64'
-- arithmetic and does not itself respect the custom 'maxBound'.
newtype ProbabilityBounded = ProbabilityBounded {unProbabilityBounded :: Word64}
  deriving (Show, Eq, Ord, Num, Real, Integral, Enum, Random, Generic)
-- | The stored word ranges from 0 up to @10 ^ ProbabilityScale@, which is the
-- raw encoding of the probability 1 at scale 'ProbabilityScale'.
instance Bounded ProbabilityBounded where
  minBound = ProbabilityBounded minBound
  maxBound = ProbabilityBounded (10 ^ digits)
    where
      digits = natVal (Proxy :: Proxy ProbabilityScale)
-- | A 'ProbabilityBounded' is valid when its word does not exceed
-- @10 ^ ProbabilityScale@.
instance Validity ProbabilityBounded where
  validate (ProbabilityBounded word) =
    declare
      ("The contained word is smaller or equal to 10 ^ ProbabilityScale = " <> show limit)
      (word <= 10 ^ scale)
    where
      scale :: Natural
      scale = natVal (Proxy @ProbabilityScale)
      limit :: Natural
      limit = 10 ^ scale