]> Git — Sourcephile - gargantext.git/blob - src/Gargantext/Core/Text/List/Learn.hs
[VERSION] +1 to 0.0.1.91.0
[gargantext.git] / src / Gargantext / Core / Text / List / Learn.hs
1 {-|
2 Module : Gargantext.Core.Text.List.Learn
3 Description : Learn to make lists
4 Copyright : (c) CNRS, 2018-Present
5 License : AGPL + CECILL v3
6 Maintainer : team@gargantext.org
7 Stability : experimental
8 Portability : POSIX
9
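Learn to sort feature vectors into lists: train SVM classifiers on vectors grouped by 'ListType', predict the 'ListType' of new vectors, and grid-search the SVM hyper-parameters.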
11
12 -}
13
14 {-# OPTIONS_GHC -fno-warn-orphans #-}
15
16
17 module Gargantext.Core.Text.List.Learn
18 where
19
20 import Control.Monad.Reader (MonadReader)
-- TODO remove this dependency
22 import qualified Data.IntMap as IntMap
23 import qualified Data.List as List
24 import Data.Map (Map)
25 import qualified Data.Map as Map
26 import qualified Data.SVM as SVM
27 import qualified Data.Vector as Vec
28
29 import Gargantext.API.Admin.Types
30 import Gargantext.Core.Text.Metrics.Count (occurrencesWith)
31 import Gargantext.Core.Types.Main (ListType(..), listTypeId, fromListTypeId)
32 import Gargantext.Prelude
33 import Gargantext.Prelude.Utils
34
35 ------------------------------------------------------------------------
-- | Train an SVM using the C-SVC algorithm with an RBF kernel.
--
-- The first argument is handed to 'SVM.CSvc' and the second to 'SVM.RBF'
-- (presumably libsvm's cost @C@ and kernel @gamma@ — confirm against the
-- binding's documentation).
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train cost gamma problem =
  SVM.train algorithm kernel problem
  where
    algorithm = SVM.CSvc cost
    kernel    = SVM.RBF gamma
38
-- | Run an SVM model over a batch of dense vectors, producing one raw
-- prediction per vector, in input order.
--
-- Each vector is converted to an 'IntMap' keyed by 1-based feature index
-- (element @i@ of the vector becomes feature @i + 1@) before prediction.
predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict model = mapM classify
  where
    classify :: Vec.Vector Double -> IO Double
    classify = SVM.predict model . toFeatures

    toFeatures :: Vec.Vector Double -> IntMap.IntMap Double
    toFeatures = IntMap.fromList . zip [1 ..] . Vec.toList
43
44 ------------------------------------------------------------------------
-- | Train a classifier from dense vectors grouped by their 'ListType'.
--
-- Every vector of group @t@ becomes one training example labelled
-- @fromIntegral (listTypeId t)@. The resulting problem is passed to 'train'
-- together with the two hyper-parameters.
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList c gamma = train c gamma . toProblem
  where
    -- Re-key the groups by numeric label first, then flatten them in
    -- ascending label order. If two 'ListType's shared an id,
    -- 'Map.mapKeys' would keep only one of the groups.
    toProblem :: Map ListType [Vec.Vector Double] -> SVM.Problem
    toProblem byType =
      [ (label, toFeatures vector)
      | (label, vectors) <- Map.toList (Map.mapKeys toLabel byType)
      , vector <- vectors
      ]

    toLabel :: ListType -> Double
    toLabel = fromIntegral . listTypeId

    -- Dense vector -> map keyed by 1-based feature index.
    toFeatures :: Vec.Vector Double -> IntMap.IntMap Double
    toFeatures = IntMap.fromList . zip [1 ..] . Vec.toList
56
57
-- | Predict a 'ListType' for each vector using the model's SVM.
--
-- Each raw SVM output is rounded and mapped back through 'fromListTypeId'.
-- The result is 'Nothing' whenever 'fromListTypeId' does not accept the
-- rounded id.
predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList (ModelSVM svm _ _) vectors = do
    labels <- predict svm vectors
    pure (map toListType labels)
  where
    toListType :: Double -> Maybe ListType
    toListType = fromListTypeId . round
60
61 ------------------------------------------------------------------------
-- | A trained list classifier, together with the hyper-parameters it was
-- trained with when they are known.
data Model = ModelSVM
  { modelSVM :: SVM.Model
    -- ^ The trained libsvm model.
  , param1   :: Maybe Double
    -- ^ Value passed to 'SVM.CSvc' at training time (set by 'grid'), or
    -- 'Nothing' when unknown, e.g. for a model loaded from disk.
  , param2   :: Maybe Double
    -- ^ Value passed to 'SVM.RBF' at training time (set by 'grid'), or
    -- 'Nothing' when unknown, e.g. for a model loaded from disk.
  }
66 --{-
-- | Save only the underlying libsvm model to @path@. The recorded
-- hyper-parameters ('param1', 'param2') are not written.
instance SaveFile Model where
  saveFile' path model =
    case model of
      ModelSVM svm _ _ -> SVM.saveModel svm path
70
-- | Load a libsvm model from @path@. The training hyper-parameters are not
-- stored alongside it, so both are set to 'Nothing'.
instance ReadFile Model where
  readFile' path = fromSvm <$> SVM.loadModel path
    where
      fromSvm svm = ModelSVM svm Nothing Nothing
76 --}
77 ------------------------------------------------------------------------
-- TODO:
-- * shuffle the list
-- * split the list into train / test sets
-- * grid-search the parameters for the best result on the test set
82
-- | Training data: dense feature vectors grouped by the 'ListType' they
-- belong to.
type Train = Map ListType [Vec.Vector Double]

-- | Test data. It has the same shape as 'Train': vectors grouped by their
-- expected 'ListType'.
type Tests = Train

-- | Score of a model on test data, as computed by 'grid'.
type Score = Double

-- | A hyper-parameter value explored by 'grid'.
type Param = Double
87
-- | Grid-search the two SVM hyper-parameters and return the best model.
--
-- Every pair @(x, y)@ is tried, with both values taken from
-- @s, s + 1, s + 2, …@ (never exceeding @e@). For each pair:
--
-- * a model is trained on @tr@ with @'trainList' x y@;
-- * it is scored on each test set of @te@, where a test set's score is the
--   fraction of its vectors whose predicted 'ListType' equals the one they
--   are filed under;
-- * the model's score is the 'mean' over all test sets.
--
-- The highest-scoring model is returned; on ties, the pair tried last wins.
-- The result is 'Nothing' when the grid is empty (@s > e@).
--
-- Panics when the list of test sets is empty.
grid :: (MonadReader env m, MonadBase IO m, HasSettings env)
     => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid _ _ _ [] = panic "Gargantext.Core.Text.List.Learn.grid : empty test data"
grid s e tr te = do
  let
    -- Values tried for each parameter: s, s + 1, … up to e inclusive.
    -- This is not written @[s .. e]@: for 'Double' that enumeration runs
    -- up to @e + 1/2@, so e.g. @[1 .. 2.6] == [1, 2, 3]@ would overshoot e.
    steps :: [Param]
    steps = List.takeWhile (<= e) [s ..]

    -- Train one model with parameters (x, y) and score it on every test set.
    grid' :: (MonadReader env m, MonadBase IO m, HasSettings env)
          => Double -> Double
          -> Train
          -> [Tests]
          -> m (Score, Model)
    grid' x y tr' te' = do
      model'' <- liftBase $ trainList x y tr'

      let
        model' = ModelSVM model'' (Just x) (Just y)

        -- Outcome of each prediction: Just True (correct), Just False
        -- (wrong), Nothing (no 'ListType' predicted). 'occurrencesWith'
        -- turns these outcomes into per-outcome counts.
        score' :: [(ListType, Maybe ListType)] -> Map (Maybe Bool) Int
        score' = occurrencesWith (\(a,b) -> (==) <$> Just a <*> b)

        -- Fraction of correct predictions among all predictions. It is 0
        -- when there are no correct predictions, or no predictions at all.
        score'' :: Map (Maybe Bool) Int -> Double
        score'' m'' = maybe 0 (\t -> fromIntegral t / total) (Map.lookup (Just True) m'')
          where
            total = fromIntegral $ List.foldl' (+) 0 $ Map.elems m''

        -- Score model m on one test set t: pair each vector's expected
        -- ListType with the model's prediction for that vector.
        getScore m t = do
          let (res, toGuess) = List.unzip
                             $ List.concat
                             $ map (\(k,vs) -> zip (repeat k) vs)
                             $ Map.toList t

          res' <- liftBase $ predictList m toGuess
          pure $ score'' $ score' $ List.zip res res'

      score <- mapM (getScore model') te'
      pure (mean score, model')

  -- 'List.sortOn' is stable, so reversing it and taking the head picks the
  -- last of the best-scoring pairs ('head' here yields a Maybe).
  r <- head . List.reverse . List.sortOn fst
         <$> mapM (\(x,y) -> grid' x y tr te)
                  [(x,y) | x <- steps, y <- steps]

  printDebug "GRID SEARCH" (map fst r)
  --printDebug "file" fp
  --fp <- saveFile (ModelSVM model')
  --save best result
  pure $ snd <$> r