]> Git — Sourcephile - gargantext.git/blob - src/Gargantext/Core/Text/List/Learn.hs
[TextFlow] SpeGen scores (WIP)
[gargantext.git] / src / Gargantext / Core / Text / List / Learn.hs
1 {-|
2 Module : Gargantext.Core.Text.List.Learn
3 Description : Learn to make lists
4 Copyright : (c) CNRS, 2018-Present
5 License : AGPL + CECILL v3
6 Maintainer : team@gargantext.org
7 Stability : experimental
8 Portability : POSIX
9
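10 Learn to build term lists: train an SVM on labelled feature vectors, predict the 'ListType' of new vectors, and grid-search the SVM hyper-parameters.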
11
12 -}
13
14 {-# OPTIONS_GHC -fno-warn-orphans #-}
15
16
17 module Gargantext.Core.Text.List.Learn
18 where
19
20 import Control.Monad.Reader (MonadReader)
21 -- TODO: remove this dependency
22 import Gargantext.API.Admin.Settings
23 import Data.Map (Map)
24 import Gargantext.Core.Types.Main (ListType(..), listTypeId, fromListTypeId)
25 import Gargantext.Prelude
26 import Gargantext.Prelude.Utils
27 import Gargantext.Core.Text.Metrics.Count (occurrencesWith)
28 import qualified Data.IntMap as IntMap
29 import qualified Data.List as List
30 import qualified Data.Map as Map
31 import qualified Data.SVM as SVM
32 import qualified Data.Vector as Vec
33
34 ------------------------------------------------------------------------
-- | Train a C-SVC classifier with an RBF kernel on a labelled problem.
--
-- The first argument is handed to 'SVM.CSvc' (the C-SVC cost), the second
-- to 'SVM.RBF' (the kernel parameter).
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train cost kernelParam problem =
  SVM.train (SVM.CSvc cost) (SVM.RBF kernelParam) problem
37
-- | Run a trained SVM model on each feature vector, in input order.
--
-- Every vector is turned into an 'IntMap.IntMap' whose keys are the
-- feature positions counted from 1.
predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict m = mapM classify
  where
    classify = SVM.predict m . IntMap.fromList . zip [1 ..] . Vec.toList
42
43 ------------------------------------------------------------------------
-- | Train an SVM ('train') on feature vectors grouped by their 'ListType'.
--
-- Each vector is labelled with the 'listTypeId' of its group (as a
-- 'Double'); the two 'Double' arguments are forwarded to 'train'.
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList x y = train x y . toProblem
  where
    -- Flatten the grouped vectors into (label, features) pairs, walking the
    -- groups in ascending order of their numeric label.
    toProblem :: Map ListType [Vec.Vector Double] -> SVM.Problem
    toProblem grouped =
      [ (label, toFeatures v)
      | (label, vs) <- Map.toList (Map.mapKeys (fromIntegral . listTypeId) grouped)
      , v <- vs
      ]

    -- Feature vector as an IntMap keyed by position, starting at 1.
    toFeatures :: Vec.Vector Double -> IntMap.IntMap Double
    toFeatures = IntMap.fromList . zip [1 ..] . Vec.toList
55
56
-- | Classify feature vectors into list types with a trained 'Model'.
--
-- Each raw SVM output is rounded and decoded with 'fromListTypeId';
-- presumably 'Nothing' marks an output that is not a known list type id.
predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList (ModelSVM svm _ _) vectors = do
  labels <- predict svm vectors
  pure (map (fromListTypeId . round) labels)
59
60 ------------------------------------------------------------------------
-- | A trained SVM classifier, together with the hyper-parameters it was
-- trained with when they are known.
data Model = ModelSVM
  { modelSVM :: SVM.Model     -- ^ the underlying trained SVM model
  , param1   :: Maybe Double  -- ^ value passed to 'SVM.CSvc' ('Nothing' when loaded from a file)
  , param2   :: Maybe Double  -- ^ value passed to 'SVM.RBF' ('Nothing' when loaded from a file)
  }
65 --{-
instance SaveFile Model where
  -- Only the underlying SVM model is written; the hyper-parameters are dropped.
  saveFile' path (ModelSVM svm _ _) = SVM.saveModel svm path
69
instance ReadFile Model where
  -- Only the SVM model is read back, so the hyper-parameters are unknown.
  readFile' path = (\svm -> ModelSVM svm Nothing Nothing) <$> SVM.loadModel path
75 --}
76 ------------------------------------------------------------------------
-- TODO:
--   * shuffle the labelled list
--   * split it into a training set and a test set
--   * keep the grid parameters that score best on the test set

-- | Labelled feature vectors used to fit the SVM, grouped by list type.
type Train = Map ListType [Vec.Vector Double]

-- | One labelled held-out set, grouped by list type ('grid' averages the
-- score over several of these).
type Tests = Map ListType [Vec.Vector Double]

-- | Fraction of correctly predicted list types, as computed by 'grid'.
type Score = Double

-- | A hyper-parameter value explored by 'grid'.
type Param = Double
86
-- | Naive grid search over the two SVM hyper-parameters.
--
-- Every pair @(x, y)@ with both components drawn from @[s .. e]@ is tried:
-- a model is trained on the training set with @'trainList' x y@ and scored
-- by its mean accuracy over all test sets. The model with the highest mean
-- accuracy is returned; on ties, the pair tried last wins. The result is
-- 'Nothing' when the parameter range is empty (e.g. @s > e@).
--
-- NOTE: @[s .. e]@ uses the 'Enum' instance of 'Double': it steps by 1 and
-- may include a value up to @e + 0.5@ (e.g. @[1 .. 2.5] == [1, 2, 3]@).
--
-- Calls 'panic' when the list of test sets is empty.
grid :: (MonadReader env m, MonadBase IO m, HasSettings env)
     => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid _ _ _ [] = panic "Gargantext.Core.Text.List.Learn.grid : empty test data"
grid s e tr te = do
  scored <- mapM (uncurry evalParams) [ (x, y) | x <- [s .. e], y <- [s .. e] ]
  let best = pickBest scored
  printDebug "GRID SEARCH" (fst <$> best)
  -- TODO: persist the best model (e.g. with 'saveFile') before returning it.
  pure (snd <$> best)
  where
    -- Train one model with the given parameters; score it on every test set.
    evalParams :: MonadBase IO n => Param -> Param -> n (Score, Model)
    evalParams x y = do
      svm <- liftBase $ trainList x y tr
      let model = ModelSVM svm (Just x) (Just y)
      scores <- mapM (liftBase . testScore model) te
      pure (mean scores, model)

    -- Accuracy of a model on a single test set.
    testScore :: Model -> Tests -> IO Score
    testScore model t = do
      let (expected, toGuess) =
            List.unzip [ (k, v) | (k, vs) <- Map.toList t, v <- vs ]
      predicted <- predictList model toGuess
      pure $ accuracy $ List.zip expected predicted

    -- Fraction of pairs whose prediction is exactly the expected list type.
    -- A prediction of 'Nothing' counts as a miss; an empty input scores 0.
    accuracy :: [(ListType, Maybe ListType)] -> Score
    accuracy pairs =
      maybe 0 (\hits -> fromIntegral hits / total) (Map.lookup (Just True) counts)
      where
        counts :: Map (Maybe Bool) Int
        counts = occurrencesWith (\(a, b) -> (==) <$> Just a <*> b) pairs
        total  = fromIntegral (sum (Map.elems counts))

    -- Highest score wins; '>=' makes a later candidate win a tie.
    pickBest :: [(Score, Model)] -> Maybe (Score, Model)
    pickBest = List.foldl' keep Nothing
      where
        keep Nothing cand = Just cand
        keep (Just cur) cand
          | fst cand >= fst cur = Just cand
          | otherwise           = Just cur