-- DataKinds: needed for the type-level literal in `ProbabilityScale`.
{-# LANGUAGE DataKinds #-}

-- | Probabilities as bounded fixed-point decimals.
--
-- A 'Probability' is a 'Decimal' with 'ProbabilityScale' fractional digits,
-- rounded half-to-even, whose raw value is a 'Word64' restricted by
-- 'ProbabilityBounded' to @[0, 10 ^ ProbabilityScale]@, i.e. to the
-- probabilities @[0, 1]@.
module Worksheets.Utils.Probability (
  Probability,
  ProbabilityScale,
  ProbabilityBounded (..),
  probability,
  runProbability,
  assertProbability,
  proba0,
  proba1,
) where

-- import Data.Bool (Bool)
-- import Data.Function (id, on, (.))
import Data.Maybe (fromJust)
-- import Data.Monoid (Monoid (..))
-- import Data.Validity (Validity (..), declare)
import Data.Word (Word64)
-- import GHC.Real (RealFrac (..))
import Numeric.Decimal (Decimal (..), MonadThrow (..))
import Numeric.Decimal qualified as Decimal
-- import System.Random (Random)
import Worksheets.Utils.Prelude
import Prelude (Bounded (..), Fractional (..), Integral, error, (^))

-- | A probability: a fixed-point decimal with 'ProbabilityScale' fractional
-- digits, rounded half-to-even, bounded by 'ProbabilityBounded'.
type Probability =
  Decimal Decimal.RoundHalfEven ProbabilityScale ProbabilityBounded

-- instance Validity Probability where
--   validate (Decimal wb) = validate wb

-- | Convert a 'Rational' to a 'Probability', rounding half-to-even to
-- 'ProbabilityScale' digits.
--
-- Fails in @m@ (through 'MonadThrow') when the rounded value does not fit
-- within the bounds of 'ProbabilityBounded'.
probability :: MonadThrow m => Rational -> m Probability
probability = Decimal.fromRationalDecimalBoundedWithRounding
{-# INLINE probability #-}

-- | The probability @0@.
proba0 :: Probability
-- 'fromJust' cannot fail: 0 is the lower bound of 'ProbabilityBounded'.
proba0 = fromJust (probability 0)

-- | The probability @1@.
proba1 :: Probability
-- 'fromJust' cannot fail: 1 is the upper bound of 'ProbabilityBounded'.
proba1 = fromJust (probability 1)

-- safeProbability :: r ::: Rational / 0 <= r && r <= 1 -> Probability
-- safeProbability (Named r) = fromJust (probability r)

-- | Convert a 'Probability' back to an exact 'Rational'.
runProbability :: Probability -> Rational
runProbability = Decimal.toRationalDecimal
{-# INLINE runProbability #-}

-- | Like 'probability', but calls 'error' (reporting the call stack) when
-- the 'Rational' cannot be represented. Use it only where the value is
-- known to be in range.
assertProbability :: HasCallStack => Rational -> Probability
assertProbability r =
  case probability r of
    Just p -> p
    Nothing -> error ("assertProbability: " <> show r)

-- | Checked arithmetic on probabilities inside 'Decimal.Arith'.
--
-- Every operation uses the bounded variant from "Numeric.Decimal", so a
-- result outside the bounds of 'ProbabilityBounded' fails inside
-- 'Decimal.Arith' instead of silently wrapping the underlying 'Word64'.
instance Num (Decimal.Arith Probability) where
  (+) = Decimal.bindM2 Decimal.plusDecimalBounded
  (-) = Decimal.bindM2 Decimal.minusDecimalBounded
  (*) = Decimal.bindM2 Decimal.timesDecimalBoundedWithRounding
  -- The lower bound is 0, so a probability is never negative and 'abs'
  -- has nothing to do.
  abs = id
  signum m = m >>= Decimal.signumDecimalBounded
  fromInteger = Decimal.fromIntegerDecimalBoundedIntegral

-- | Checked division on probabilities inside 'Decimal.Arith'. Fractional
-- literals go through 'probability', so they are range-checked as well.
instance Fractional (Decimal.Arith Probability) where
  (/) = Decimal.bindM2 Decimal.divideDecimalBoundedWithRounding
  fromRational = probability

{- Disabled: HasCallStack does not work well for these instances.

instance Eq (Decimal.Arith Probability) where
  (==) :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Bool
  (==) = (==) `on` Decimal.arithError

instance Ord (Decimal.Arith Probability) where
  compare :: HasCallStack => Decimal.Arith Probability -> Decimal.Arith Probability -> Ordering
  compare = compare `on` Decimal.arithError

instance Real (Decimal.Arith Probability) where
  toRational = Decimal.toRationalDecimal . Decimal.arithError

instance RealFrac (Decimal.Arith Probability) where
  properFraction p = (n, return (assertProbability f))
    where
      (n, f) = properFraction (Decimal.toRationalDecimal (Decimal.arithError p))
-}

-- Why 19: it is the largest scale for which @10 ^ scale@ (the raw value of
-- probability 1) still fits in a 'Word64':
--
-- >>> 10^19 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- True
-- >>> 10^20 <= (fromIntegral (maxBound :: Word64) :: Integer)
-- False

-- | Number of fractional decimal digits of a 'Probability'.
type ProbabilityScale = 19

-- | Raw storage of a 'Probability': a 'Word64' whose 'Bounded' instance caps
-- it at @10 ^ ProbabilityScale@ rather than at @maxBound :: Word64@.
--
-- Note: the derived 'Num' instance is plain 'Word64' arithmetic and does not
-- consult these bounds; only the checked operations from "Numeric.Decimal"
-- (as used by the instances above) enforce them.
newtype ProbabilityBounded = ProbabilityBounded {unProbabilityBounded :: Word64}
  deriving (Show, Eq, Ord, Num, Real, Integral, Enum, Generic)

-- | Bounds @[0, 10 ^ ProbabilityScale]@, i.e. the probabilities @[0, 1]@
-- once interpreted at scale 'ProbabilityScale'.
instance Bounded ProbabilityBounded where
  minBound = ProbabilityBounded 0
  maxBound = ProbabilityBounded (10 ^ natVal (Proxy @ProbabilityScale))

-- instance Validity ProbabilityBounded where
--   validate (ProbabilityBounded w) =
--     mconcat
--       [ declare ("The contained word is smaller or equal to 10 ^ ProbabilityScale = " <> show (10 ^ n :: Natural)) (w <= 10 ^ n)
--       ]
--     where
--       n :: Natural
--       n = natVal (Proxy @ProbabilityScale)