module HUnit.Rank where

import Data.Bool
import Data.Eq (Eq(..))
import Data.Foldable (Foldable(..))
import Data.Function (($), (.))
import Data.Functor ((<$>))
import Data.List
import Data.Ord (Ord(..))
import Data.Ratio
import Data.Semigroup (Semigroup(..))
import GHC.Exts (IsList(..))
import Majority.Judgment
import Prelude (Integer, Num(..), fromIntegral)
import Test.Tasty
import Test.Tasty.HUnit
import Text.Show (Show(..))

import QuickCheck.Merit ()
import QuickCheck.Value ()

-- | Rank test tree: one lexicographic-ranking group and one majority-ranking
-- group, each instantiated over a fixed list of @(js, gs)@ cases.
hunit :: TestTree
hunit = testGroup "Rank"
 [ testGroup "lexicographic"
	 [ testLexRank js gs | (js, gs) <- lexCases ]
 , testGroup "majority"
	 [ testMajRank js gs | (js, gs) <- majCases ]
 ]
	where
	-- Cases checked by 'testLexRank', in test order.
	lexCases :: [(JS, GS)]
	lexCases = [(1,1), (5,4), (5,5), (10,5), (15,5)]
	-- Cases checked by 'testMajRank', in test order.
	-- Larger cases (25,4), (25,5), (20,6), (30,4), (30,5) and (10,10) are
	-- currently disabled; presumably for running time — TODO confirm.
	majCases :: [(JS, GS)]
	majCases =
		[ (1,1), (3,2), (5,4), (5,5), (9,5), (10,5)
		, (11,5), (12,5), (13,5), (14,5), (15,5)
		]

-- | Checks that 'lexicographicRankOfMerit' maps the merits listed by 'merits'
-- onto the consecutive ranks @0 .. lastRank js gs@. It also checks that it and
-- 'meritOfLexicographicRank' are mutual inverses on those ranks and merits.
testLexRank :: JS -> GS -> TestTree
testLexRank js gs = testGroup label
	 [ testCase "lexicographicRankOfMerit" $
		(toRank <$> allMerits) @?= allRanks
	 , testCase "lexRankOfMerit . meritOfLexRank == id" $
		(toRank . toMerit <$> allRanks) @?= allRanks
	 , testCase "meritOfLexRank . lexRankOfMerit == id" $
		(toMerit . toRank <$> allMerits) @?= allMerits
	 ]
	where
	label     = "js=" <> show js <> " gs=" <> show gs
	allRanks  = [0 .. lastRank js gs]
	allMerits = merits js gs
	toRank    = lexicographicRankOfMerit gs
	toMerit   = meritOfLexicographicRank js gs

-- | Checks majority-order ranking for @js@ and @gs@:
--
-- * summing 'countMedian' over @listMediansBefore js gs 0 (Median (gs,gs))@
--   yields @countMerits js gs@;
-- * 'majorityValueOfRank' maps @0 .. lastRank js gs@ onto the sorted
--   'majorityValues';
-- * 'rankOfMajorityValue' maps those sorted values back onto the ranks.
testMajRank :: JS -> GS -> TestTree
testMajRank js gs = testGroup label
	 [ testCase "listMediansBefore" $
		sum (countMedian js gs <$> listMediansBefore js gs 0 (Median (gs,gs)))
		 @?= total
	 , testCase "majorityValueOfRank" $
		(majorityValueOfRank js gs <$> allRanks) @?= sortedMVs
	 , testCase "rankOfMajorityValue" $
		(rankOfMajorityValue gs <$> sortedMVs) @?= allRanks
	 {- NOTE: already implied by the previous tests.
	 , testCase "rankOfMV . mvOfRank == id" $
		rankOfMajorityValue gs . majorityValueOfRank js gs
		 <$> allRanks @?= allRanks
	 , testCase "mvOfRank . rankOfMV == id" $
		majorityValueOfRank js gs . rankOfMajorityValue gs
		 <$> sortedMVs @?= sortedMVs
	 -}
	 ]
	where
	total     = countMerits js gs
	label     = "js=" <> show js <> " gs=" <> show gs <> " (" <> show total <> " merits)"
	allRanks  = [0 .. lastRank js gs]
	sortedMVs = majorityValues js gs

-- | Every way of giving @js0@ judgments grades from @0 .. gs-1@, in
-- lexicographic order. Each way is a sorted list of @js0@ grades, and the
-- lists with the most low grades come first.
merits :: JS -> GS -> [[G]]
merits js0 gs = spread 0 js0
	where
	-- Distribute @left@ judgments over the grades @grade .. gs-1@.
	-- The last grade absorbs whatever remains.
	spread grade left
	 | grade == gs - 1 = [replicate (fromIntegral left) grade]
	 | otherwise =
		[ replicate (fromIntegral n) grade <> rest
		| n <- reverse [0 .. left]
		, rest <- spread (grade + 1) (left - n)
		]

-- | The majority values of every way of giving @js0@ judgments grades from
-- @0 .. gs-1@. They are sorted by the 'Ord' instance of 'MajorityValue',
-- which is presumably the majority order.
majorityValues :: JS -> GS -> [MajorityValue (Ranked ())]
majorityValues js0 gs = sort [ majorityValue (fromList m) | m <- shares 0 js0 ]
	where
	-- Every way to share @left@ judgments among the grades @grade .. gs-1@,
	-- as (grade, count) pairs. The last grade absorbs whatever remains.
	shares grade left
	 | grade == gs - 1 = [[(Ranked (grade, ()), left % 1)]]
	 | otherwise =
		[ (Ranked (grade, ()), n % 1) : rest
		| n <- reverse [0 .. left]
		, rest <- shares (grade + 1) (left - n)
		]

-- | Rank of a merit among all merits with as many judgments, over @gsI@
-- grades. A merit is a sorted list of grades, one per judgment. The ranking
-- is in the lexicographic order listed by 'merits'.
--
-- For each judgment, the rank adds the first @grade - previous grade@ entries
-- of the current table row. The remaining rows drop those same entries.
-- The input is assumed sorted ascending and within @0 .. gsI-1@; this is not
-- checked.
lexicographicRankOfMerit :: GS -> [Integer] -> Integer
lexicographicRankOfMerit gsI dist = walk 0 table dist
	where
	-- One row per judgment: the first @gsI@ entries of a Pascal diagonal,
	-- reversed. The highest diagonal serves the first judgment.
	table = reverse
		[ reverse (take (fromIntegral gsI) diag)
		| diag <- take (length dist) pascalDiagonals
		]
	walk lo (row:rows) (grade:grades) =
		sum (take step row) + walk grade (drop step <$> rows) grades
		where step = fromIntegral (grade - lo)
	walk _ _ _ = 0

-- | Inverse of 'lexicographicRankOfMerit': the merit of @jsI@ judgments over
-- @gsI@ grades that has the given lexicographic rank. The rank is assumed to
-- lie within @0 .. lastRank jsI gsI@; this is not checked.
--
-- For each judgment, the grade rises past as many leading row entries as the
-- remaining rank can cover, i.e. the running totals of the row that stay
-- @<=@ the rank. The covered amount is subtracted before the next judgment.
meritOfLexicographicRank :: JS -> GS -> Integer -> [Integer]
meritOfLexicographicRank jsI gsI = descend 0 table
	where
	table = reverse
		[ reverse (take (fromIntegral gsI) diag)
		| diag <- take (fromIntegral jsI) pascalDiagonals
		]
	descend _ [] _ = []
	descend lo (row:rows) rest =
		grade : descend grade (drop n <$> rows) (rest - paid)
		where
		affordable = takeWhile (<= rest) (scanl1 (+) row)
		n          = length affordable
		grade      = lo + fromIntegral n
		-- Largest running total that fit, or 0 when none did.
		paid       = case reverse affordable of
			[]        -> 0
			covered:_ -> covered

-- | Diagonals of Pascal's triangle. Diagonal @k@ lists the binomial
-- coefficients @C(k+i, i)@ for @i = 0, 1, ..@. It starts from an infinite
-- run of ones, and each diagonal is the running sums of the previous one.
pascalDiagonals :: [[Integer]]
pascalDiagonals = iterate (scanl1 (+)) (repeat 1)