-- DataKinds: needed for the type-level literal 'ProbabilityScale'.
{-# LANGUAGE DataKinds #-}
-- InstanceSigs: only used by the disabled instances kept in the block
-- comment near the end of this module.
{-# LANGUAGE InstanceSigs #-}

-- | Probabilities as fixed-point 'Decimal's whose raw 'Word64' value ranges
-- over @[0, 10 ^ ProbabilityScale]@, i.e. over the probabilities @[0, 1]@
-- (see the 'Bounded' instance of 'ProbabilityBounded').
module Numeric.Probability (
  Probability,
  ProbabilityScale,
  ProbabilityBounded (..),
  probability,
  safeProbability,
  runProbability,
  assertProbability,
  proba0,
  proba1,
) where

import Control.Monad (Monad (..))
import Data.Bool (Bool)
import Data.Eq (Eq (..))
import Data.Function (id, on, (.))
import Data.Maybe (Maybe (..), fromJust)
import Data.Monoid (Monoid (..))
import Data.Ord (Ord (..), Ordering)
import Data.Proxy (Proxy (..))
import Data.Semigroup (Semigroup (..))
import Data.Validity (Validity (..), declare)
import Data.Word (Word64)
import GHC.Generics (Generic)
import GHC.Real (RealFrac (..))
import GHC.Stack (HasCallStack)
import GHC.TypeNats (Natural, natVal)
import Logic
import Logic.Theory.Bool (type (&&))
import Logic.Theory.Ord (type (<=))
import Numeric.Decimal (Decimal (..), MonadThrow (..))
import Numeric.Decimal qualified as Decimal
import System.Random (Random)
import Text.Show (Show (show))
import Prelude (Bounded (..), Enum, Fractional (..), Integral, Num (..), Rational, Real (..), error, (^))

-- | A probability: a 'Decimal' with 'ProbabilityScale' digits after the
-- decimal point, rounding half-to-even, backed by 'ProbabilityBounded'.
type Probability = Decimal Decimal.RoundHalfEven ProbabilityScale ProbabilityBounded

-- | A 'Probability' is valid exactly when its raw 'ProbabilityBounded' is.
instance Validity Probability where
  validate (Decimal raw) = validate raw

-- | Convert a 'Rational' to a 'Probability', rounding half-to-even to
-- 'ProbabilityScale' decimal digits.
--
-- Fails in @m@ (through 'MonadThrow') when the rounded value does not fit
-- between the bounds of 'ProbabilityBounded', i.e. lies outside @[0, 1]@.
probability :: MonadThrow m => Rational -> m Probability
probability = Decimal.fromRationalDecimalBoundedWithRounding
{-# INLINE probability #-}

-- | The probability 0.
--
-- Built directly from the lower bound of the representation, so no partial
-- conversion is involved.
proba0 :: Probability
proba0 = Decimal minBound

-- | The probability 1.
--
-- At scale 'ProbabilityScale' the raw value @10 ^ ProbabilityScale@, which is
-- 'maxBound' of 'ProbabilityBounded', is exactly 1; building it directly
-- avoids a partial conversion.
proba1 :: Probability
proba1 = Decimal maxBound

-- | Convert a 'Rational' that carries a proof of its range.
--
-- NOTE(review): the precondition reads @r <= 0 && r <= 1@, which only states
-- @r <= 0@; it looks like a typo for @0 <= r && r <= 1@. It is left unchanged
-- because fixing it changes the type every caller must satisfy — TODO
-- confirm and fix together with the call sites.
--
-- Goes through 'assertProbability' so that, should the value nonetheless be
-- out of range, the error names the offending 'Rational' and has a call
-- stack.
safeProbability :: r ::: Rational / r <= 0 && r <= 1 -> Probability
safeProbability (Named r) = assertProbability r

-- | The 'Rational' value of a 'Probability'.
runProbability :: Probability -> Rational
runProbability = Decimal.toRationalDecimal
{-# INLINE runProbability #-}

-- | Like 'probability', but calls 'error' (with a call stack) when the
-- rational cannot be represented. Meant for values known to be in range.
assertProbability :: HasCallStack => Rational -> Probability
assertProbability r = case probability r of
  Just p -> p
  Nothing -> error ("assertProbability: " <> show r)

-- | Checked arithmetic through the bounded operations of "Numeric.Decimal":
-- a result outside the bounds of 'ProbabilityBounded' is reported in
-- 'Decimal.Arith' instead of wrapping the underlying 'Word64'.
instance Num (Decimal.Arith Probability) where
  (+) = Decimal.bindM2 Decimal.plusDecimalBounded
  (-) = Decimal.bindM2 Decimal.minusDecimalBounded
  (*) = Decimal.bindM2 Decimal.timesDecimalBoundedWithRounding
  -- The representation starts at 0 (minBound), so values are never negative
  -- and 'abs' has nothing to do.
  abs = id
  signum m = m >>= Decimal.signumDecimalBounded
  fromInteger = Decimal.fromIntegerDecimalBoundedIntegral

-- | Division rounds half-to-even; 'fromRational' is 'probability'.
instance Fractional (Decimal.Arith Probability) where
  (/) = Decimal.bindM2 Decimal.divideDecimalBoundedWithRounding
  fromRational = probability

{- Disabled: these instances would have to leave Decimal.Arith through
   Decimal.arithError, and HasCallStack does not work well for them.

instance Eq (Decimal.Arith Probability) where
  (==) :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Bool
  (==) = (==) `on` Decimal.arithError

instance Ord (Decimal.Arith Probability) where
  compare :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Ordering
  compare = compare `on` Decimal.arithError

instance Real (Decimal.Arith Probability) where
  toRational = Decimal.toRationalDecimal . Decimal.arithError

instance RealFrac (Decimal.Arith Probability) where
  properFraction p = (n, return (assertProbability f))
    where
      (n, f) = properFraction (Decimal.toRationalDecimal (Decimal.arithError p))
-}

-- Sanity check: the raw value 10 ^ ProbabilityScale must fit in a Word64.
--
-- >>> 10^19 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- True
-- >>> 10^20 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- False

-- | Number of decimal digits kept after the point in a 'Probability'.
type ProbabilityScale = 19

-- | Raw representation of a 'Probability': a 'Word64' restricted, through the
-- 'Bounded' instance, to @[0, 10 ^ ProbabilityScale]@.
--
-- The derived 'Num' instance is plain 'Word64' arithmetic and does not
-- enforce that bound; use the 'Decimal.Arith' instances for checked
-- arithmetic.
newtype ProbabilityBounded = ProbabilityBounded {unProbabilityBounded :: Word64}
  deriving (Show, Eq, Ord, Num, Real, Integral, Enum, Random, Generic)

-- | 'minBound' is probability 0; 'maxBound', @10 ^ ProbabilityScale@, is
-- probability 1.
instance Bounded ProbabilityBounded where
  minBound = ProbabilityBounded 0
  maxBound = ProbabilityBounded (10 ^ natVal (Proxy @ProbabilityScale))

-- | Valid when the raw word does not exceed @10 ^ ProbabilityScale@.
instance Validity ProbabilityBounded where
  validate (ProbabilityBounded w) =
    mconcat
      [ declare
          ("The contained word is smaller or equal to 10 ^ ProbabilityScale = " <> show (10 ^ n :: Natural))
          (w <= 10 ^ n)
      ]
    where
      n :: Natural
      n = natVal (Proxy @ProbabilityScale)