]> Git — Sourcephile - gargantext.git/blob - src/Gargantext/Text/List/Learn.hs
Merge branch 'dev' into dev-phylo
[gargantext.git] / src / Gargantext / Text / List / Learn.hs
1 {-|
2 Module : Gargantext.Text.List.Learn
3 Description : Learn to make lists
4 Copyright : (c) CNRS, 2018-Present
5 License : AGPL + CECILL v3
6 Maintainer : team@gargantext.org
7 Stability : experimental
8 Portability : POSIX
9
10 Train SVM models on vectors labelled by ListType and use them to predict list types.
11
12 -}
13
14 {-# OPTIONS_GHC -fno-warn-orphans #-}
15
16 {-# LANGUAGE NoImplicitPrelude #-}
17 {-# LANGUAGE OverloadedStrings #-}
18
19 module Gargantext.Text.List.Learn
20 where
21
22 import Control.Monad.Reader (MonadReader)
23 import Control.Monad.IO.Class (MonadIO, liftIO)
24 import Gargantext.API.Settings
25 import Data.Map (Map)
26 import Data.Maybe (maybe)
27 import Gargantext.Core.Types.Main (ListType(..), listTypeId, fromListTypeId)
28 import Gargantext.Prelude
29 import Gargantext.Prelude.Utils
30 import Gargantext.Text.Metrics.Count (occurrencesWith)
31 import qualified Data.IntMap as IntMap
32 import qualified Data.List as List
33 import qualified Data.Map as Map
34 import qualified Data.SVM as SVM
35 import qualified Data.Vector as Vec
36
37 ------------------------------------------------------------------------
-- | Fit an SVM model on a problem.
--
-- The first argument is wrapped in 'SVM.CSvc' and the second in 'SVM.RBF'
-- before both are handed to 'SVM.train' together with the problem.
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train svcParam rbfParam problem =
  SVM.train (SVM.CSvc svcParam) (SVM.RBF rbfParam) problem
40
-- | Run the model on every vector, in order, and collect the raw predictions.
--
-- Each dense vector is first converted to an 'IntMap.IntMap' whose keys are
-- the component positions counted from 1.
predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict model = mapM (SVM.predict model . indexed)
  where
    -- Dense vector -> position-indexed map (first component gets key 1).
    indexed = IntMap.fromList . zip [1..] . Vec.toList
45
46 ------------------------------------------------------------------------
-- | Train an SVM model from vectors grouped by 'ListType'.
--
-- The map keys are converted to numeric labels with 'listTypeId', the
-- result is flattened into an 'SVM.Problem' by 'mapVec2problem', and the
-- problem is passed to 'train' with the two given parameters.
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList x y samples = train x y problem
  where
    -- Same vectors, keyed by the numeric id of their list type.
    labelled :: Map Double [Vec.Vector Double]
    labelled = Map.mapKeys (fromIntegral . listTypeId) samples

    problem :: SVM.Problem
    problem = mapVec2problem labelled
52
-- | Flatten a map from label to vectors into an 'SVM.Problem'.
--
-- Every vector is paired with the label it was stored under. Labels are
-- visited in 'Map.toList' order and, within a label, vectors keep their
-- original order.
mapVec2problem :: Map Double [Vec.Vector Double] -> SVM.Problem
mapVec2problem labelled =
  [ (label, features)
  | (label, vectors) <- Map.toList labelled
  , features         <- vecs2maps vectors
  ]
55
-- | Convert each dense vector into an 'IntMap.IntMap' keyed by component
-- position, counting from 1.
vecs2maps :: [Vec.Vector Double] -> [IntMap.IntMap Double]
vecs2maps vectors =
  [ IntMap.fromList (zip [1..] (Vec.toList vector))
  | vector <- vectors
  ]
58
59
-- | Predict a 'ListType' for every vector.
--
-- The raw predictions returned by 'predict' are rounded and converted with
-- 'fromListTypeId'; a prediction that does not convert yields 'Nothing'.
predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList model vectors = do
  rawLabels <- predict (modelSVM model) vectors
  pure $ map (fromListTypeId . round) rawLabels
62
63 ------------------------------------------------------------------------
-- A trained SVM model together with the two parameters it was trained with.
-- The parameters are 'Nothing' when the model was read back from a file.
data Model = ModelSVM
  { modelSVM :: SVM.Model     -- the underlying trained SVM model
  , param1   :: Maybe Double  -- first training parameter (wrapped in SVM.CSvc by 'train')
  , param2   :: Maybe Double  -- second training parameter (wrapped in SVM.RBF by 'train')
  }
68 --{-
-- Only the underlying SVM model is written; param1 and param2 are not saved.
instance SaveFile Model where
  saveFile' path model = SVM.saveModel (modelSVM model) path
72
-- A model loaded from a file carries no training parameters: both are Nothing.
instance ReadFile Model where
  readFile' path = withoutParams <$> SVM.loadModel path
    where
      withoutParams svm = ModelSVM svm Nothing Nothing
78 --}
79 ------------------------------------------------------------------------
80 -- | TODO
81 -- shuffle list
82 -- split list : train / test
83 -- grid parameters on best result on test
84
-- Vectors grouped by the list type they belong to, used for training.
type Train = Map ListType [Vec.Vector Double]
-- Same shape as Train, used for evaluating a trained model.
type Tests = Train
-- Score of a model on test data.
type Score = Double
-- A training parameter explored by the grid search.
type Param = Double
89
-- | Grid search over the two training parameters.
--
-- For every pair @(x, y)@ with @x@ and @y@ both drawn from @[s..e]@, a model
-- is trained on the training set with 'trainList'. Each model is scored on
-- every test set as the fraction of vectors whose predicted 'ListType'
-- equals the expected one, and the per-test-set scores are averaged with
-- 'mean'. The model with the highest mean score is returned.
--
-- * Returns 'Nothing' when @[s..e]@ produces no parameter pair.
-- * Panics when the list of test sets is empty.
grid :: (MonadReader env m, MonadIO m, HasSettings env)
     => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid _ _ _ []  = panic "Gargantext.Text.List.Learn.grid : empty test data"
grid s e tr te = do
  let
    -- Train one model with the parameters (x, y) and compute its mean score
    -- over all the test sets.
    grid' :: (MonadReader env m, MonadIO m, HasSettings env)
          => Double -> Double
          -> Train
          -> [Tests]
          -> m (Score, Model)
    grid' x y tr' te' = do
      model'' <- liftIO $ trainList x y tr'

      let
        model' = ModelSVM model'' (Just x) (Just y)

        -- Tally the outcomes of (expected, predicted) pairs:
        -- Just True = match, Just False = mismatch,
        -- Nothing = the prediction did not map to a ListType.
        score' :: [(ListType, Maybe ListType)] -> Map (Maybe Bool) Int
        score' = occurrencesWith (\(a, b) -> (a ==) <$> b)

        -- Share of matches among all tallied outcomes; 0 when there is no
        -- match at all (which also avoids dividing by an empty total).
        score'' :: Map (Maybe Bool) Int -> Double
        score'' m'' = maybe 0 (\t -> fromIntegral t / total) (Map.lookup (Just True) m'')
          where
            -- Strict left fold: no thunk chain is built while summing.
            total = fromIntegral $ List.foldl' (+) 0 $ Map.elems m''

        -- Score a model on a single test set.
        getScore m t = do
          let (res, toGuess) = List.unzip
                             $ List.concatMap (\(k, vs) -> zip (repeat k) vs)
                             $ Map.toList t

          res' <- liftIO $ predictList m toGuess
          pure $ score'' $ score' $ List.zip res res'

      score <- mapM (getScore model') te'
      pure (mean score, model')

  -- Sort ascending by score, then take the first of the reversed list, i.e.
  -- the best-scoring candidate (Nothing when there is no candidate).
  r <- head . List.reverse . List.sortOn fst
         <$> mapM (\(x, y) -> grid' x y tr te)
                  [(x, y) | x <- [s..e], y <- [s..e]]

  printDebug "GRID SEARCH" (map fst r)
  --printDebug "file" fp
  --fp <- saveFile (ModelSVM model')
  --save best result
  pure $ snd <$> r
136
137
138