1 -- For `ProbabilityScale`
2 {-# LANGUAGE DataKinds #-}
3 {-# LANGUAGE InstanceSigs #-}
5 module Numeric.Probability (
8 ProbabilityBounded (..),
18 import Control.Monad (Monad (..))
19 import Data.Bool (Bool)
20 import Data.Eq (Eq(..))
21 import Data.Function (id, on, (.))
22 import Data.Maybe (Maybe (..), fromJust)
23 import Data.Monoid (Monoid (..))
24 import Data.Ord (Ord (..), Ordering)
25 import Data.Proxy (Proxy (..))
26 import Data.Semigroup (Semigroup (..))
27 import Data.Validity (Validity (..), declare)
28 import Data.Word (Word64)
29 import GHC.Generics (Generic)
30 import GHC.Real (RealFrac(..))
31 import GHC.Stack (HasCallStack)
32 import GHC.TypeNats (Natural, natVal)
34 import Logic.Theory.Bool (type (&&))
35 import Logic.Theory.Ord (type (<=))
36 import Numeric.Decimal (Decimal (..), MonadThrow (..))
37 import Numeric.Decimal qualified as Decimal
38 import System.Random (Random)
39 import Text.Show (Show (show))
40 import Prelude (Bounded (..), Enum, Fractional (..), Integral, Num (..), Rational, Real(..), error, (^))
-- | A probability as a fixed-point decimal: 'ProbabilityScale' digits after
-- the point, stored as a 'ProbabilityBounded', with 'Decimal.RoundHalfEven'
-- as the rounding strategy.
type Probability =
  Decimal Decimal.RoundHalfEven ProbabilityScale ProbabilityBounded
-- | A 'Probability' is valid exactly when its underlying raw
-- 'ProbabilityBounded' is valid.
instance Validity Probability where
  validate (Decimal raw) = validate raw
-- | Smart constructor: round a 'Rational' to 'ProbabilityScale' digits and
-- build a 'Probability', throwing in @m@ when the rounded value is not
-- representable (presumably: outside the bounds of 'ProbabilityBounded').
probability :: MonadThrow m => Rational -> m Probability
probability q = Decimal.fromRationalDecimalBoundedWithRounding q
{-# INLINE probability #-}
-- The impossible event: the smallest raw value of 'ProbabilityBounded'.
-- Built directly with the 'Decimal' constructor, so it is total (the previous
-- @fromJust (probability 0)@ relied on a partial function).
proba0 = Decimal minBound :: Probability

-- The certain event: the largest raw value, @10 ^ ProbabilityScale@, which
-- is how a 'Probability' with scale 'ProbabilityScale' represents exactly 1.
proba1 = Decimal maxBound :: Probability
-- | Build a 'Probability' from a named 'Rational' together with a proof that
-- it lies in @[0, 1]@.
--
-- The precondition used to read @r <= 0 && r <= 1@, which only admits
-- non-positive values: any negative @r@ satisfied it and then crashed in
-- 'fromJust'. The lower bound is now @0 <= r@, and with it the rounded value
-- is always representable, so 'fromJust' cannot fail.
safeProbability :: r ::: Rational / 0 <= r && r <= 1 -> Probability
safeProbability (Named r) = fromJust (probability r)
-- | Convert a 'Probability' back to the 'Rational' it denotes.
runProbability :: Probability -> Rational
runProbability p = Decimal.toRationalDecimal p
{-# INLINE runProbability #-}
-- | Like 'probability', but calls 'error' (carrying the caller's call stack)
-- when the 'Rational' cannot be turned into a 'Probability'.
--
-- The @Just@ alternative was missing, which made the @case@ non-exhaustive:
-- every valid input would have crashed with a pattern-match failure.
assertProbability :: HasCallStack => Rational -> Probability
assertProbability r = case probability r of
  Just p -> p
  Nothing -> error ("assertProbability: " <> show r)
-- | Arithmetic on 'Probability' inside the 'Decimal.Arith' monad.
--
-- Each binary operation sequences its two operands with 'Decimal.bindM2' and
-- combines them with the corresponding safe-decimal @*DecimalBounded@
-- operation, so a result the bounded representation rejects (e.g.
-- @0.3 - 0.5@) presumably becomes an 'Decimal.Arith' error rather than a
-- wrapped value. 'negate' keeps the class default (@0 - x@), so negating a
-- non-zero probability yields such an error too.
instance Num (Decimal.Arith Probability) where
  (+) = Decimal.bindM2 Decimal.plusDecimalBounded
  (-) = Decimal.bindM2 Decimal.minusDecimalBounded
  (*) = Decimal.bindM2 Decimal.timesDecimalBoundedWithRounding
  -- A 'Probability' is never negative, so 'abs' is the identity. Without this
  -- definition 'abs' had no implementation and failed at runtime.
  abs = id
  signum m = m >>= Decimal.signumDecimalBounded
  fromInteger = Decimal.fromIntegerDecimalBoundedIntegral
-- | Division and rational literals for 'Probability' in 'Decimal.Arith'.
-- Division rounds with the 'Probability' rounding strategy; literals go
-- through the 'probability' smart constructor.
instance Fractional (Decimal.Arith Probability) where
  x / y = Decimal.bindM2 Decimal.divideDecimalBoundedWithRounding x y
  fromRational q = probability q
-- | Equality of arithmetic results: both sides are unwrapped with
-- 'Decimal.arithError', which throws if either holds an arithmetic error.
--
-- Note: an instance signature adding 'HasCallStack' would be rejected by
-- GHC, because it is more constrained than the class method's type.
instance Eq (Decimal.Arith Probability) where
  x == y = Decimal.arithError x == Decimal.arithError y
-- | Ordering of arithmetic results: both sides are unwrapped with
-- 'Decimal.arithError' (throwing on an arithmetic error) before comparing
-- the underlying 'Probability' values. The remaining 'Ord' methods use the
-- class defaults built on 'compare'.
instance Ord (Decimal.Arith Probability) where
  compare x y = compare (Decimal.arithError x) (Decimal.arithError y)
-- | Exact rational value of an arithmetic result; throws (via
-- 'Decimal.arithError') if the result is an arithmetic error.
instance Real (Decimal.Arith Probability) where
  toRational x = Decimal.toRationalDecimal (Decimal.arithError x)
-- | Integer/fraction splitting of arithmetic results, computed on the exact
-- 'Rational' obtained through 'toRational' (which throws on an arithmetic
-- error).
instance RealFrac (Decimal.Arith Probability) where
  properFraction p = (n, return (assertProbability f))
    where
      -- For a valid (non-negative) probability the fractional part is in
      -- [0, 1), hence itself a valid probability.
      (n, f) = properFraction (toRational p)

  -- The class default for 'round' computes with 'Num' and 'Eq' on
  -- 'Decimal.Arith Probability': it matches @signum (abs r - 0.5)@ against
  -- the literal @-1@, whose value (@negate 1@) is itself an underflow error.
  -- Comparing through 'Decimal.arithError' then throws, so the default
  -- 'round' always fails. Rounding the exact 'Rational' avoids this. The
  -- other methods are defined the same way for consistency; for them this
  -- is equivalent to the defaults.
  truncate = truncate . toRational
  round = round . toRational
  ceiling = ceiling . toRational
  floor = floor . toRational
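-- 'ProbabilityScale' is the largest scale @s@ for which @10 ^ s@ (the raw
-- 'Word64' representing probability 1) still fits in a 'Word64':
--
-- >>> 10^19 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- True
-- >>> 10^20 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- False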
-- | Number of decimal digits after the point in a 'Probability'. The raw
-- word @10 ^ ProbabilityScale@ stands for probability 1.
type ProbabilityScale = 19
-- | Raw storage of a 'Probability': a 'Word64' meant to stay within
-- @[0, 10 ^ ProbabilityScale]@, as stated by its 'Bounded' and 'Validity'
-- instances. The newtype-derived 'Num' instance does not enforce that range
-- itself.
newtype ProbabilityBounded = ProbabilityBounded {unProbabilityBounded :: Word64}
  deriving (Eq, Ord, Show, Generic, Enum, Num, Real, Integral, Random)
-- | Raw bounds: 0 (probability 0) up to @10 ^ ProbabilityScale@
-- (probability 1).
instance Bounded ProbabilityBounded where
  minBound = ProbabilityBounded 0
  maxBound = ProbabilityBounded (10 ^ scale)
    where
      scale = natVal (Proxy @ProbabilityScale)
-- | A raw 'ProbabilityBounded' is valid when it does not exceed
-- @10 ^ ProbabilityScale@, the raw representation of probability 1.
instance Validity ProbabilityBounded where
  validate (ProbabilityBounded raw) =
    mconcat
      [ declare
          ("The contained word is smaller or equal to 10 ^ ProbabilityScale = " <> show (10 ^ scale :: Natural))
          (raw <= 10 ^ scale)
      ]
    where
      scale = natVal (Proxy @ProbabilityScale)