]> Git — Sourcephile - gargantext.git/blob - src/Gargantext/Core/Text/List/Learn.hs
Merge remote-tracking branch 'origin/dev-hackathon-fixes' into dev
[gargantext.git] / src / Gargantext / Core / Text / List / Learn.hs
1 {-|
2 Module : Gargantext.Core.Text.List.Learn
3 Description : Learn to make lists
4 Copyright : (c) CNRS, 2018-Present
5 License : AGPL + CECILL v3
6 Maintainer : team@gargantext.org
7 Stability : experimental
8 Portability : POSIX
9
10 Train SVM models on vectors grouped by list type, predict list types, and grid-search the SVM parameters.
11
12 -}
13
14 {-# OPTIONS_GHC -fno-warn-orphans #-}
15
16
17 module Gargantext.Core.Text.List.Learn
18 where
19
20 import qualified Data.IntMap as IntMap
21 import qualified Data.List as List
22 import Data.Map (Map)
23 import qualified Data.Map as Map
24 import qualified Data.SVM as SVM
25 import qualified Data.Vector as Vec
26
27 import Gargantext.Core
28 import Gargantext.Core.Text.Metrics.Count (occurrencesWith)
29 import Gargantext.Core.Types.Main (ListType(..))
30 import Gargantext.Prelude
31 import Gargantext.Database.GargDB
32
33 ------------------------------------------------------------------------
-- | Train an SVM model on a problem.
--
-- The first argument is the parameter given to the 'SVM.CSvc' algorithm, the
-- second one is the parameter given to the 'SVM.RBF' kernel.
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train cost gamma problem = SVM.train (SVM.CSvc cost) (SVM.RBF gamma) problem
36
-- | Run the model on every vector, in order, and collect the raw predictions.
--
-- Each dense vector is converted to the sparse representation expected by
-- 'SVM.predict': an 'IntMap.IntMap' whose keys start at 1.
predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict model = mapM (SVM.predict model . toSparse)
  where
    -- Number the components from 1 and store them in an IntMap.
    toSparse = IntMap.fromList . zip [1..] . Vec.toList
41
42 ------------------------------------------------------------------------
-- | Train an SVM model from vectors grouped by 'ListType'.
--
-- Every 'ListType' key is turned into a numeric class label through
-- 'toDBid'; the data is then flattened into an 'SVM.Problem' and handed to
-- 'train' together with the two parameters.
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList x y byList = train x y problem
  where
    problem :: SVM.Problem
    problem = mapVec2problem (Map.mapKeys label byList)

    -- Numeric class label of a list type.
    label :: ListType -> Double
    label = fromIntegral . toDBid
48
-- | Flatten a map from class label to vectors into an 'SVM.Problem':
-- one @(label, sparse vector)@ pair per input vector, labels in ascending
-- key order.
mapVec2problem :: Map Double [Vec.Vector Double] -> SVM.Problem
mapVec2problem labelled =
  [ (label, sparse)
  | (label, vecs) <- Map.toList labelled
  , sparse        <- vecs2maps vecs
  ]
51
-- | Convert dense vectors to sparse ones: the components of each vector are
-- numbered from 1 and stored in an 'IntMap.IntMap'.
vecs2maps :: [Vec.Vector Double] -> [IntMap.IntMap Double]
vecs2maps vectors =
  [ IntMap.fromList (zip [1..] (Vec.toList v)) | v <- vectors ]
54
55
-- | Predict a 'ListType' for every vector.
--
-- Each raw prediction of the SVM model is rounded and converted back with
-- 'fromDBid'; every result is wrapped in 'Just'.
predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList (ModelSVM svm _ _) vectors = do
  labels <- predict svm vectors
  pure [ Just (fromDBid (round label)) | label <- labels ]
58
59 ------------------------------------------------------------------------
-- | A trained SVM model together with the parameters used to train it,
-- when they are known.
data Model = ModelSVM
  { modelSVM :: SVM.Model
    -- ^ The underlying SVM model.
  , param1   :: Maybe Double
    -- ^ Parameter given to 'SVM.CSvc' at training time; 'Nothing' when the
    -- model was loaded from a file.
  , param2   :: Maybe Double
    -- ^ Parameter given to 'SVM.RBF' at training time; 'Nothing' when the
    -- model was loaded from a file.
  }
64 --{-
-- | Save a 'Model' by writing its underlying SVM model with 'SVM.saveModel'.
-- The stored parameters ('param1', 'param2') are not saved.
instance SaveFile Model where
  saveFile' path (ModelSVM svm _ _) = SVM.saveModel svm path
68
-- | Load a 'Model' with 'SVM.loadModel'. The training parameters are not
-- read back, so 'param1' and 'param2' are set to 'Nothing'.
instance ReadFile Model where
  readFile' path = withoutParams <$> SVM.loadModel path
    where
      withoutParams svm = ModelSVM svm Nothing Nothing
74 --}
75 ------------------------------------------------------------------------
76 -- | TODO
77 -- shuffle list
78 -- split list : train / test
79 -- grid parameters on best result on test
80
-- | Training data: vectors grouped by the 'ListType' they belong to.
type Train = Map ListType [Vec.Vector Double]

-- | Test data, with the same shape as 'Train'.
type Tests = Map ListType [Vec.Vector Double]

-- | Score of a model on test data, as computed by 'grid'.
type Score = Double

-- | A parameter value explored by 'grid'.
type Param = Double
85
-- | Grid search over the two SVM parameters.
--
-- Every pair @(x, y)@ with @x@ and @y@ taken from @[s..e]@ is used to train a
-- model on the training data. Each model is scored on every test set (share
-- of vectors whose predicted list type equals the expected one) and the mean
-- of those scores is kept. The model with the highest mean score is
-- returned, or 'Nothing' when no parameter pair was tried.
--
-- Panics when the list of test sets is empty.
grid :: (MonadBase IO m)
     => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid _ _ _ [] = panic "Gargantext.Core.Text.List.Learn.grid : empty test data"
grid s e tr te = do
  candidates <- mapM evaluate [ (x, y) | x <- [s..e], y <- [s..e] ]

  -- Ascending sort on the score, then reverse: the best candidate comes first.
  let best = head (List.reverse (List.sortOn fst candidates))

  printDebug "GRID SEARCH" (map fst best)
  -- TODO: save the best model to a file.
  pure $ snd <$> best
  where
    -- Train with one parameter pair and compute the mean score of the
    -- resulting model over all the test sets.
    evaluate :: (MonadBase IO n) => (Param, Param) -> n (Score, Model)
    evaluate (x, y) = do
      svm <- liftBase $ trainList x y tr
      let candidate = ModelSVM svm (Just x) (Just y)
      scores <- mapM (scoreOn candidate) te
      pure (mean scores, candidate)

    -- Score a model on one test set: predict every vector and compare the
    -- predictions with the expected list types.
    scoreOn :: (MonadBase IO n) => Model -> Tests -> n Score
    scoreOn model tests = do
      let (expected, vectors) =
            List.unzip
              [ (listType, vec)
              | (listType, vecs) <- Map.toList tests
              , vec              <- vecs
              ]
      guessed <- liftBase $ predictList model vectors
      pure $ accuracy $ List.zip expected guessed

    -- Share of pairs whose prediction equals the expected list type;
    -- 0 when there is no correct prediction at all.
    accuracy :: [(ListType, Maybe ListType)] -> Score
    accuracy pairs =
      maybe 0 (\hits -> fromIntegral hits / total) (Map.lookup (Just True) tally)
      where
        -- Number of pairs per comparison outcome.
        tally :: Map (Maybe Bool) Int
        tally = occurrencesWith (\(a, b) -> (==) <$> Just a <*> b) pairs

        total = fromIntegral $ foldl (+) 0 $ Map.elems tally