]> Git — Sourcephile - gargantext.git/blob - src/Gargantext/Text/List/Learn.hs
[FIX] Metrics adding filtering.
[gargantext.git] / src / Gargantext / Text / List / Learn.hs
1 {-|
2 Module : Gargantext.Text.List.Learn
3 Description : Learn to make lists
4 Copyright : (c) CNRS, 2018-Present
5 License : AGPL + CECILL v3
6 Maintainer : team@gargantext.org
7 Stability : experimental
8 Portability : POSIX
9
10 Learn to sort terms into lists: train an SVM classifier on labelled vectors, predict list types, and grid-search its training parameters.
11
12 -}
13
14 {-# OPTIONS_GHC -fno-warn-orphans #-}
15
16 {-# LANGUAGE FlexibleContexts #-}
17 {-# LANGUAGE NoImplicitPrelude #-}
18 {-# LANGUAGE OverloadedStrings #-}
19
20 module Gargantext.Text.List.Learn
21 where
22
23 import Control.Monad.Reader (MonadReader)
24 -- TODO remove this dependency
25 import Gargantext.API.Admin.Settings
26 import Data.Map (Map)
27 import Data.Maybe (maybe)
28 import Gargantext.Core.Types.Main (ListType(..), listTypeId, fromListTypeId)
29 import Gargantext.Prelude
30 import Gargantext.Prelude.Utils
31 import Gargantext.Text.Metrics.Count (occurrencesWith)
32 import qualified Data.IntMap as IntMap
33 import qualified Data.List as List
34 import qualified Data.Map as Map
35 import qualified Data.SVM as SVM
36 import qualified Data.Vector as Vec
37
38 ------------------------------------------------------------------------
-- | Train a C-SVC classifier with an RBF kernel on the given problem.
--
-- The first 'Double' is given to 'SVM.CSvc' and the second to 'SVM.RBF'
-- (presumably the cost @C@ and the kernel @gamma@ — confirm against the
-- "Data.SVM" documentation).
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train c gamma problem = SVM.train algorithm kernel problem
  where
    algorithm = SVM.CSvc c
    kernel    = SVM.RBF gamma
41
-- | Run the model on each feature vector, returning the predictions in
-- input order.
--
-- Each dense vector is first converted to a sparse 'IntMap.IntMap' keyed by
-- 1-based feature index, the representation 'SVM.predict' is given here.
predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict m vs = mapM (SVM.predict m . toSparse) vs
  where
    toSparse :: Vec.Vector Double -> IntMap.IntMap Double
    toSparse v = IntMap.fromList (zip [1 ..] (Vec.toList v))
46
47 ------------------------------------------------------------------------
-- | Train a model from example vectors grouped by their expected 'ListType'.
--
-- Every 'ListType' key is turned into a numeric class label through
-- 'listTypeId'. Every vector stored under that key becomes one training
-- example carrying that label. The two 'Double's are forwarded unchanged to
-- 'train'.
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList x y = train x y . trainList'
  where
    -- Relabel the classes numerically, then flatten into an SVM problem.
    trainList' :: Map ListType [Vec.Vector Double] -> SVM.Problem
    trainList' = mapVec2problem . Map.mapKeys (fromIntegral . listTypeId)

    -- One (label, sparse vector) example per input vector, grouped by label
    -- in ascending key order.
    mapVec2problem :: Map Double [Vec.Vector Double] -> SVM.Problem
    mapVec2problem labelled =
      [ (label, sparse)
      | (label, vectors) <- Map.toList labelled
      , sparse           <- vecs2maps vectors
      ]

    -- Sparse form of each vector, keyed by 1-based feature index.
    vecs2maps :: [Vec.Vector Double] -> [IntMap.IntMap Double]
    vecs2maps = map (\v -> IntMap.fromList (zip [1 ..] (Vec.toList v)))
59
60
-- | Predict a 'ListType' for each vector, in input order.
--
-- The raw numeric prediction is rounded and decoded with 'fromListTypeId'.
-- A 'Nothing' presumably means the rounded value is not a known list type
-- id (confirm against 'fromListTypeId'). The training parameters stored in
-- the 'Model' are ignored.
predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList (ModelSVM svm _ _) vs = do
  labels <- predict svm vs
  pure (map (fromListTypeId . round) labels)
63
64 ------------------------------------------------------------------------
-- | A trained SVM together with the two parameters it was trained with.
--
-- 'grid' fills both parameters with the values passed to 'trainList'.
-- Models read back through the 'ReadFile' instance carry 'Nothing' for
-- them.
data Model = ModelSVM
  { modelSVM :: SVM.Model     -- ^ the underlying SVM model
  , param1   :: Maybe Double  -- ^ first training parameter, if known
  , param2   :: Maybe Double  -- ^ second training parameter, if known
  }
69 --{-
-- | Persist only the underlying SVM model; the training parameters are not
-- written.
instance SaveFile Model
  where
    saveFile' fp model = SVM.saveModel (modelSVM model) fp

-- | Load an SVM model. The training parameters were never saved, so they
-- come back as 'Nothing'.
instance ReadFile Model
  where
    readFile' fp = (\svm -> ModelSVM svm Nothing Nothing) <$> SVM.loadModel fp
79 --}
80 ------------------------------------------------------------------------
-- TODO
--   * shuffle the list
--   * split it into train / test sets
--   * grid-search the parameters on the best result on the test set

-- | Training data: example vectors grouped by their expected 'ListType'.
type Train = Map ListType [Vec.Vector Double]

-- | One test set, with the same shape as 'Train'.
type Tests = Map ListType [Vec.Vector Double]

-- | Score of a model (in 'grid': mean accuracy over the test sets).
type Score = Double

-- | Value of a training parameter.
type Param = Double
90
-- | Grid-search the two training parameters and return the best model.
--
-- Every pair @(x, y)@ with @x@ and @y@ drawn from @[s..e]@ is tried. Note
-- that the 'Enum' instance of 'Double' steps by 1 and can overshoot @e@ by
-- up to 0.5.
--
-- For each pair, a model is trained on @tr@ with 'trainList' and scored on
-- every test set in @te@. The score of a pair is the mean accuracy over
-- those test sets. The highest-scoring model is returned; on ties, the pair
-- enumerated last wins.
--
-- Returns 'Nothing' when the grid is empty (@s > e@). Calls 'panic' when no
-- test set is given.
grid :: (MonadReader env m, MonadBase IO m, HasSettings env)
     => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
grid _ _ _ [] = panic "Gargantext.Text.List.Learn.grid : empty test data"
grid s e tr te = do
  results <- mapM (\(x, y) -> evalParams x y)
                  [(x, y) | x <- [s..e], y <- [s..e]]

  -- Single O(n) pass instead of sort + reverse + head. 'List.maximumBy'
  -- returns the last of equal maxima, which is the element the stable sort
  -- followed by reverse used to pick.
  let best = case results of
        [] -> Nothing
        _  -> Just (List.maximumBy (\a b -> compare (fst a) (fst b)) results)

  printDebug "GRID SEARCH" (fst <$> best)
  -- TODO: save the best model (e.g. with 'saveFile')
  pure (snd <$> best)
  where
    -- Train a model with the given parameters and score it on every test set.
    evalParams :: MonadBase IO n => Param -> Param -> n (Score, Model)
    evalParams x y = do
      svm <- liftBase $ trainList x y tr
      let model = ModelSVM svm (Just x) (Just y)
      scores <- mapM (scoreOn model) te
      -- 'te' is non-empty here (the empty case is handled above).
      pure (mean scores, model)

    -- Accuracy of a model on one test set.
    scoreOn :: MonadBase IO n => Model -> Tests -> n Score
    scoreOn model t = do
      let (expected, inputs) =
            List.unzip [ (k, v) | (k, vs) <- Map.toList t, v <- vs ]
      predicted <- liftBase $ predictList model inputs
      pure $ accuracy (List.zip expected predicted)

    -- Fraction of pairs whose prediction is @Just@ the expected list type.
    -- 'Nothing' predictions count as misses. Yields 0 when there is no hit,
    -- in particular for an empty test set, so there is no division by zero.
    accuracy :: [(ListType, Maybe ListType)] -> Score
    accuracy pairs =
      maybe 0 (\hits -> fromIntegral hits / total) (Map.lookup (Just True) counts)
      where
        counts :: Map (Maybe Bool) Int
        counts = occurrencesWith (\(a, b) -> (a ==) <$> b) pairs

        total :: Double
        total = fromIntegral (sum (Map.elems counts))
137
138
139