-- `DataKinds` is needed for the type-level literal in `ProbabilityScale`.
2 {-# LANGUAGE DataKinds #-}
4 module Worksheets.Utils.Probability (
7 ProbabilityBounded (..),
16 -- import Data.Bool (Bool)
17 -- import Data.Function (id, on, (.))
18 import Data.Maybe (fromJust)
20 -- import Data.Monoid (Monoid (..))
21 -- import Data.Validity (Validity (..), declare)
22 import Data.Word (Word64)
24 -- import GHC.Real (RealFrac (..))
25 import Numeric.Decimal (Decimal (..), MonadThrow (..))
26 import Numeric.Decimal qualified as Decimal
28 -- import System.Random (Random)
29 import Worksheets.Utils.Prelude
30 import Prelude (Bounded (..), Fractional (..), Integral, error, (^))
-- | A probability as a fixed-point decimal.
--
-- It carries 'ProbabilityScale' digits after the point and rounds
-- half-to-even. It is stored in 'ProbabilityBounded', whose 'Bounded'
-- instance confines the raw value to @[0, 10 ^ ProbabilityScale]@, that is
-- the probabilities @[0, 1]@.
type Probability =
  Decimal.Decimal
    Decimal.RoundHalfEven
    ProbabilityScale
    ProbabilityBounded
34 -- instance Validity Probability where
35 -- validate (Decimal wb) = validate wb
-- | Smart constructor for 'Probability'.
--
-- The rational is rounded half-to-even to 'ProbabilityScale' digits. If the
-- result falls outside the bounds of 'ProbabilityBounded' (i.e. outside
-- @[0, 1]@), the failure is reported through 'MonadThrow' instead of
-- producing a value; for example, it yields 'Nothing' when @m ~ Maybe@.
probability :: MonadThrow m => Rational -> m Probability
probability r = Decimal.fromRationalDecimalBoundedWithRounding r
{-# INLINE probability #-}
-- | The probability 0 (the impossible event).
--
-- Built with 'assertProbability' rather than @fromJust . probability@: 0 is
-- always in range, but should the representation ever change, the failure
-- names the offending value and the call site instead of a bare
-- "fromJust: Nothing".
proba0 = assertProbability 0

-- | The probability 1 (the certain event); its raw encoding is exactly
-- @maxBound :: ProbabilityBounded@.
proba1 = assertProbability 1
-- safeProbability :: r ::: Rational / 0 <= r && r <= 1 -> Probability
47 -- safeProbability (Named r) = fromJust (probability r)
-- | The exact rational value of a 'Probability'.
--
-- No rounding is involved: the fixed-point decimal is converted as-is.
runProbability :: Probability -> Rational
runProbability p = Decimal.toRationalDecimal p
{-# INLINE runProbability #-}
-- | Like 'probability', but for inputs the caller knows lie in @[0, 1]@.
--
-- An out-of-range input is treated as a programming error. It aborts with
-- 'error', and the message contains the offending rational. Thanks to
-- 'HasCallStack', the reported stack also includes the caller's location.
assertProbability :: HasCallStack => Rational -> Probability
assertProbability r = case probability r of
  Just p -> p
  Nothing -> error ("assertProbability: " <> show r)
-- | Checked arithmetic on probabilities.
--
-- Every operation runs inside 'Decimal.Arith'. The @*Bounded@ operations
-- from safe-decimal check their result against 'minBound' / 'maxBound' of
-- 'ProbabilityBounded'. A result outside @[0, 1]@ (e.g. @0.7 + 0.6@ or
-- @0.2 - 0.5@) therefore becomes an arithmetic error instead of wrapping
-- around. Multiplication rounds half-to-even to 'ProbabilityScale' digits.
instance Num (Decimal.Arith Probability) where
  (+) = Decimal.bindM2 Decimal.plusDecimalBounded
  (-) = Decimal.bindM2 Decimal.minusDecimalBounded
  (*) = Decimal.bindM2 Decimal.timesDecimalBoundedWithRounding
  -- 'abs' must be defined explicitly. Leaving it out triggers
  -- -Wmissing-methods, and any call to 'abs' would crash at runtime.
  abs m = m >>= Decimal.absDecimalBounded
  signum m = m >>= Decimal.signumDecimalBounded
  fromInteger = Decimal.fromIntegerDecimalBoundedIntegral
-- | Checked division and rational literals for probabilities.
--
-- Both operands are evaluated left to right inside 'Decimal.Arith'. The
-- quotient is rounded half-to-even and checked against the bounds of
-- 'ProbabilityBounded', so a result outside @[0, 1]@ is reported as an
-- arithmetic error. Rational literals go through 'probability'.
instance Fractional (Decimal.Arith Probability) where
  mx / my = do
    x <- mx
    y <- my
    Decimal.divideDecimalBoundedWithRounding x y
  fromRational = probability
{- Disabled: HasCallStack does not work well for these instances.
72 instance Eq (Decimal.Arith Probability) where
73 (==) :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Bool
74 (==) = (==) `on` Decimal.arithError
76 instance Ord (Decimal.Arith Probability) where
77 compare :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Ordering
78 compare = compare `on` Decimal.arithError
80 instance Real (Decimal.Arith Probability) where
81 toRational = Decimal.toRationalDecimal . Decimal.arithError
83 instance RealFrac (Decimal.Arith Probability) where
84 properFraction p = (n, return (assertProbability f))
86 (n,f) = properFraction (Decimal.toRationalDecimal (Decimal.arithError p))
-- >>> 10^19 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- >>> 10^20 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- Number of decimal digits after the point carried by a 'Probability'.
-- 19 is the largest scale for which @10 ^ scale@ still fits in a 'Word64'.
-- That value is the raw encoding of a probability of 1.
type ProbabilityScale = 19

-- | Raw storage of a 'Probability': a 'Word64' holding the probability
-- scaled by @10 ^ ProbabilityScale@.
--
-- Arithmetic is inherited from 'Word64'. The custom 'Bounded' instance
-- narrows the valid range to @[0, 10 ^ ProbabilityScale]@, so the bounded
-- decimal operations reject anything outside the probabilities @[0, 1]@.
newtype ProbabilityBounded = ProbabilityBounded
  { unProbabilityBounded :: Word64
  }
  deriving (Show, Eq, Ord, Num, Real, Integral, Enum, Generic)

instance Bounded ProbabilityBounded where
  minBound = ProbabilityBounded 0
  maxBound = ProbabilityBounded (10 ^ scaleDigits)
    where
      -- Computed in 'Word64'; it fits only while ProbabilityScale <= 19.
      scaleDigits = natVal (Proxy @ProbabilityScale)
101 -- instance Validity ProbabilityBounded where
102 -- validate (ProbabilityBounded w) =
104 -- [ declare ("The contained word is smaller or equal to 10 ^ ProbabilityScale = " <> show (10 ^ n :: Natural)) (w <= 10 ^ n)
108 -- n = natVal (Proxy @ProbabilityScale)