-- DataKinds: for the type-level natural `ProbabilityScale` and the type-level literals in `safeProbability`'s signature
2 {-# LANGUAGE DataKinds #-}
4 module Numeric.Probability (
7 ProbabilityBounded (..),
17 import Control.Monad (Monad (..))
19 import Data.Function (id)
20 import Data.Maybe (Maybe (..), fromJust)
21 import Data.Monoid (Monoid (..))
22 import Data.Ord (Ord (..))
23 import Data.Proxy (Proxy (..))
24 import Data.Semigroup (Semigroup (..))
25 import Data.Validity (Validity (..), declare)
26 import Data.Word (Word64)
27 import GHC.Generics (Generic)
28 import GHC.TypeNats (Natural, natVal)
30 import Logic.Theory.Bool (type (&&))
31 import Logic.Theory.Ord (type (<=))
32 import Numeric.Decimal (Decimal (..), MonadThrow (..))
33 import Numeric.Decimal qualified as Decimal
34 import System.Random (Random)
35 import Text.Show (Show (show))
36 import Prelude (Bounded (..), Enum, Fractional (..), Integral, Num (..), Rational, Real, error, (^))
-- | A probability, represented as a fixed-point decimal with
-- 'ProbabilityScale' fractional digits, round-half-even rounding, and
-- 'ProbabilityBounded' as the underlying storage.
type Probability = Decimal Decimal.RoundHalfEven ProbabilityScale ProbabilityBounded

-- | A 'Probability' is valid exactly when its underlying
-- 'ProbabilityBounded' value is valid.
instance Validity Probability where
  validate d = case d of
    Decimal raw -> validate raw
-- | Smart constructor for 'Probability'.
--
-- The rational is rounded half-even to 'ProbabilityScale' decimal digits.
-- Values outside the 'Bounded' range of 'ProbabilityBounded' (i.e. outside
-- @[0, 1]@) are rejected through 'MonadThrow' (for example as 'Nothing'
-- when @m ~ Maybe@).
probability :: MonadThrow m => Rational -> m Probability
probability r = Decimal.fromRationalDecimalBoundedWithRounding r
{-# INLINE probability #-}
-- The endpoints come straight from the 'Bounded' range of
-- 'ProbabilityBounded': 'minBound' is the raw encoding of 0 and 'maxBound'
-- (@10 ^ ProbabilityScale@) is the raw encoding of 1, so no partial
-- 'fromJust' is needed.
proba0 = Decimal minBound :: Probability
proba1 = Decimal maxBound :: Probability
-- | Build a 'Probability' from a rational together with a proof that it lies
-- in the closed interval @[0, 1]@.
--
-- The precondition was previously written @r <= 0 && r <= 1@, which only
-- admits non-positive values; the intended bound is @0 <= r && r <= 1@.
-- Every rational in @[0, 1]@ rounds into the range of 'ProbabilityBounded',
-- so 'probability' cannot fail here and the 'Nothing' branch is unreachable.
safeProbability :: r ::: Rational / 0 <= r && r <= 1 -> Probability
safeProbability (Named r) = case probability r of
  Just p -> p
  Nothing -> error ("safeProbability: impossible, out of range: " <> show r)
-- | Convert a 'Probability' back to its exact 'Rational' value.
runProbability :: Probability -> Rational
runProbability p = Decimal.toRationalDecimal p
{-# INLINE runProbability #-}
58 assertProbability :: Rational -> Probability
59 assertProbability r = case probability r of
61 Nothing -> error ("assertProbability: " <> show r)
63 instance Num (Decimal.Arith Probability) where
64 (+) = Decimal.bindM2 Decimal.plusDecimalBounded
65 (-) = Decimal.bindM2 Decimal.minusDecimalBounded
66 (*) = Decimal.bindM2 Decimal.timesDecimalBoundedWithRounding
68 signum m = m >>= Decimal.signumDecimalBounded
69 fromInteger = Decimal.fromIntegerDecimalBoundedIntegral
-- | Division and rational literals for probabilities computed inside
-- 'Decimal.Arith'. Rational literals go through 'probability', so values
-- outside @[0, 1]@ surface as failures in 'Decimal.Arith'.
instance Fractional (Decimal.Arith Probability) where
  fromRational q = probability q
  ma / mb = Decimal.bindM2 Decimal.divideDecimalBoundedWithRounding ma mb
-- >>> 10^19 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- >>> 10^20 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- | Number of decimal digits after the point in a 'Probability'.
--
-- 19 is the largest scale for which @10 ^ scale@ (the raw encoding of the
-- probability 1, see the 'Bounded' instance of 'ProbabilityBounded') still
-- fits in a 'Word64': @10 ^ 19 <= 2 ^ 64 - 1 < 10 ^ 20@.
type ProbabilityScale = 19
-- | Raw storage for a 'Probability': a 'Word64' holding the probability
-- scaled by @10 ^ ProbabilityScale@, so that 0 is stored as 0 and 1 as
-- @10 ^ ProbabilityScale@ (its 'maxBound').
--
-- NOTE(review): the non-stock classes below (Num, Enum, Random, ...) are
-- presumably newtype-derived from 'Word64', so they know nothing about the
-- @[0, 10 ^ ProbabilityScale]@ range (e.g. 'Random' can yield values above
-- 'maxBound') — confirm callers only rely on the bounded Decimal operations.
newtype ProbabilityBounded = ProbabilityBounded
  { unProbabilityBounded :: Word64
  }
  deriving
    ( Eq
    , Ord
    , Show
    , Generic
    , Enum
    , Num
    , Real
    , Integral
    , Random
    )
-- | Valid raw values run from 0 (probability 0) up to
-- @10 ^ ProbabilityScale@ (probability 1).
instance Bounded ProbabilityBounded where
  minBound = ProbabilityBounded 0
  maxBound = ProbabilityBounded (10 ^ scale)
    where
      scale = natVal (Proxy :: Proxy ProbabilityScale)
86 instance Validity ProbabilityBounded where
87 validate (ProbabilityBounded w) =
89 [ declare ("The contained word is smaller or equal to 10 ^ ProbabilityScale = " <> show (10 ^ n :: Natural)) (w <= 10 ^ n)
93 n = natVal (Proxy @ProbabilityScale)