]> Git — Sourcephile - gargantext.git/blob - src/Gargantext/Core/Text/List/Learn.hs
[FEAT] Patch Scores key with max value if non positive case
[gargantext.git] / src / Gargantext / Core / Text / List / Learn.hs
1 {-|
2 Module : Gargantext.Core.Text.List.Learn
3 Description : Learn to make lists
4 Copyright : (c) CNRS, 2018-Present
5 License : AGPL + CECILL v3
6 Maintainer : team@gargantext.org
7 Stability : experimental
8 Portability : POSIX
9
10 Train an SVM classifier that predicts the 'ListType' of feature vectors, and grid-search its hyper-parameters.
11
12 -}
13
14 {-# OPTIONS_GHC -fno-warn-orphans #-}
15
16
17 module Gargantext.Core.Text.List.Learn
18 where
19
20 import qualified Data.IntMap as IntMap
21 import qualified Data.List as List
22 import Data.Map (Map)
23 import qualified Data.Map as Map
24 import qualified Data.SVM as SVM
25 import qualified Data.Vector as Vec
26
27 import Gargantext.Core.Text.Metrics.Count (occurrencesWith)
28 import Gargantext.Core.Types.Main (ListType(..), listTypeId, fromListTypeId)
29 import Gargantext.Prelude
30 import Gargantext.Prelude.Utils
31
32 ------------------------------------------------------------------------
-- | Train a C-SVC classifier with an RBF kernel on the given problem.
--
-- The first argument is handed to 'SVM.CSvc' and the second to 'SVM.RBF'
-- (presumably libsvm's cost @C@ and kernel @gamma@ respectively).
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train c gamma problem =
  let algorithm = SVM.CSvc c
      kernel    = SVM.RBF gamma
  in  SVM.train algorithm kernel problem
35
-- | Run the SVM on every feature vector and return one numeric prediction
-- per input vector, in input order.
--
-- Each vector is handed to the SVM as a sparse map keyed by 1-based
-- position.
predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict model = mapM classify
  where
    classify :: Vec.Vector Double -> IO Double
    classify = SVM.predict model . IntMap.fromList . zip [1 ..] . Vec.toList
40
41 ------------------------------------------------------------------------
-- | Train an SVM (see 'train') on labelled vectors grouped by 'ListType'.
--
-- Each vector becomes one training sample. Its label is the numeric id of
-- its 'ListType' (via 'listTypeId'), and its features are the vector
-- entries keyed by 1-based position.
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList x y labelled = train x y problem
  where
    -- Re-key by numeric label first. Samples are then emitted in ascending
    -- label order, and each group keeps its vectors' original order.
    byLabel :: Map Double [Vec.Vector Double]
    byLabel = Map.mapKeys (fromIntegral . listTypeId) labelled

    problem :: SVM.Problem
    problem =
      [ (label, toFeatures vector)
      | (label, vectors) <- Map.toList byLabel
      , vector <- vectors
      ]

    toFeatures :: Vec.Vector Double -> IntMap.IntMap Double
    toFeatures = IntMap.fromList . zip [1 ..] . Vec.toList
53
54
-- | Predict a 'ListType' for each feature vector.
--
-- The SVM's numeric prediction is rounded and mapped back with
-- 'fromListTypeId'. 'Nothing' presumably marks a rounded label that
-- 'fromListTypeId' does not recognise.
predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList ModelSVM { modelSVM = svm } vectors = do
  labels <- predict svm vectors
  pure [ fromListTypeId (round label) | label <- labels ]
57
58 ------------------------------------------------------------------------
-- | A trained SVM classifier plus the two hyper-parameters it was trained
-- with. The parameters are 'Nothing' when unknown, e.g. after loading the
-- model from a file.
data Model = ModelSVM
  { modelSVM :: SVM.Model
    -- ^ The underlying SVM model.
  , param1   :: Maybe Double
    -- ^ First hyper-parameter given to 'trainList'.
  , param2   :: Maybe Double
    -- ^ Second hyper-parameter given to 'trainList'.
  }
--{-
instance SaveFile Model
  where
    -- Only the SVM model itself is written; 'param1' and 'param2' are
    -- not persisted.
    saveFile' path ModelSVM { modelSVM = svm } = SVM.saveModel svm path

instance ReadFile Model
  where
    -- The hyper-parameters are not stored with the model, so they come
    -- back as 'Nothing'.
    readFile' path = (\svm -> ModelSVM svm Nothing Nothing) <$> SVM.loadModel path
--}
74 ------------------------------------------------------------------------
-- TODO:
--   * shuffle the labelled data
--   * split it into train / test sets
--   * keep the grid parameters giving the best result on the test set

-- | Labelled training vectors, grouped by 'ListType'.
type Train = Map ListType [Vec.Vector Double]

-- | Labelled test vectors, grouped by 'ListType' (same shape as 'Train').
type Tests = Map ListType [Vec.Vector Double]

-- | Score used by 'grid' to rank candidate models.
type Score = Double

-- | A hyper-parameter bound for 'grid'.
type Param = Double
84
-- | Grid search over the two SVM hyper-parameters.
--
-- Every pair @(x, y)@ with @x@ and @y@ drawn from @[s .. e]@ (x outer,
-- y inner) is used to train a model on the training data ('trainList').
-- Each model is scored by its mean accuracy over the test sets, and the
-- best-scoring model is returned. Among equal scores, the pair enumerated
-- last wins, because the stable sort is reversed before taking the head.
--
-- Returns 'Nothing' when the parameter range is empty. Panics when the list
-- of test sets is empty.
--
-- NOTE(review): @[s .. e]@ on 'Double' steps by 1 and, per the Haskell
-- Report's 'Enum' instance for fractional types, stops at @e + 1/2@, so it
-- can include one value above @e@ (e.g. @[1 .. 5.6]@ ends at 6). Confirm
-- this is intended.
grid :: (MonadBase IO m)
     => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid _ _ _ [] = panic "Gargantext.Core.Text.List.Learn.grid : empty test data"
grid s e tr te = do
  let
    grid' :: (MonadBase IO m)
          => Double -> Double
          -> Train
          -> [Tests]
          -> m (Score, Model)
    grid' x y tr' te' = do
      model'' <- liftBase $ trainList x y tr'

      let
        model' = ModelSVM model'' (Just x) (Just y)

        -- Tally the predictions: Just True = correct, Just False = wrong,
        -- Nothing = no 'ListType' was predicted for that vector.
        score' :: [(ListType, Maybe ListType)] -> Map (Maybe Bool) Int
        score' = occurrencesWith (\(expected, predicted) -> (expected ==) <$> predicted)

        -- Fraction of correct predictions; 0 when none is correct
        -- (including when there is nothing to predict).
        score'' :: Map (Maybe Bool) Int -> Double
        score'' m'' = maybe 0 (\t -> fromIntegral t / total) (Map.lookup (Just True) m'')
          where
            -- Strict left fold: the lazy 'foldl' would build a chain of thunks.
            total = fromIntegral $ List.foldl' (+) 0 $ Map.elems m''

        -- Accuracy of model @m@ on one test set @t@.
        getScore m t = do
          let (res, toGuess) = List.unzip
                             $ List.concat
                             $ map (\(k,vs) -> zip (repeat k) vs)
                             $ Map.toList t

          res' <- liftBase $ predictList m toGuess
          pure $ score'' $ score' $ List.zip res res'

      score <- mapM (getScore model') te'
      pure (mean score, model')

  r <- head . List.reverse
            . List.sortOn fst
            <$> mapM (\(x,y) -> grid' x y tr te)
                     [(x,y) | x <- [s..e], y <- [s..e]]

  printDebug "GRID SEARCH" (map fst r)
  -- TODO: save the best model (e.g. with 'saveFile')
  pure $ snd <$> r