2 Module : Gargantext.Text.List.Learn
3 Description : Learn to make lists
4 Copyright : (c) CNRS, 2018-Present
5 License : AGPL + CECILL v3
6 Maintainer : team@gargantext.org
7 Stability : experimental
10 Train SVM classifiers on feature vectors labelled with a ListType, use them to predict list types, and grid-search their parameters.
14 {-# OPTIONS_GHC -fno-warn-orphans #-}
16 {-# LANGUAGE NoImplicitPrelude #-}
17 {-# LANGUAGE OverloadedStrings #-}
19 module Gargantext.Text.List.Learn
22 import Control.Monad.Reader (MonadReader)
23 import Control.Monad.IO.Class (MonadIO, liftIO)
24 import Gargantext.API.Settings
26 import Data.Maybe (maybe)
27 import Gargantext.Core.Types.Main (ListType(..), listTypeId, fromListTypeId)
28 import Gargantext.Prelude
29 import Gargantext.Prelude.Utils
30 import Gargantext.Text.Metrics.Count (occurrencesWith)
31 import qualified Data.IntMap as IntMap
32 import qualified Data.List as List
33 import qualified Data.Map as Map
34 import qualified Data.SVM as SVM
35 import qualified Data.Vector as Vec
37 ------------------------------------------------------------------------
-- | Train a C-SVC classifier with an RBF kernel.
--
-- The first argument is given to 'SVM.CSvc' (libsvm's cost parameter C),
-- the second to 'SVM.RBF' (the kernel parameter).
train :: Double -> Double -> SVM.Problem -> IO SVM.Model
train cost kernelParam problem = SVM.train algorithm kernel problem
  where
    algorithm = SVM.CSvc cost
    kernel    = SVM.RBF kernelParam
-- | Run a trained model on every feature vector and return one raw
-- prediction per vector, in input order.
--
-- Each dense vector is converted to the 'IntMap' form taken by
-- 'SVM.predict', with feature indices starting at 1.
predict :: SVM.Model -> [Vec.Vector Double] -> IO [Double]
predict m vs = mapM (SVM.predict m . toFeatures) vs
  where
    toFeatures = IntMap.fromList . zip [1 ..] . Vec.toList
46 ------------------------------------------------------------------------
-- | Train a model (see 'train') directly from vectors grouped by their
-- 'ListType'. @x@ and @y@ are forwarded unchanged to 'train'.
trainList :: Double -> Double -> Map ListType [Vec.Vector Double] -> IO SVM.Model
trainList x y labelled = train x y (trainList' labelled)
-- | Build an 'SVM.Problem' from labelled vectors. The class label of each
-- group is the numeric id of its 'ListType' (via 'listTypeId').
trainList' :: Map ListType [Vec.Vector Double] -> SVM.Problem
trainList' byListType = mapVec2problem labelledById
  where
    labelledById = Map.mapKeys (fromIntegral . listTypeId) byListType
-- | Flatten labelled vectors into an 'SVM.Problem'.
--
-- Every vector becomes one training example tagged with the label of the
-- group it belongs to. Groups are visited in ascending label order
-- ('Map.toList'), and vectors keep their order within a group.
mapVec2problem :: Map Double [Vec.Vector Double] -> SVM.Problem
mapVec2problem labelled =
  [ (label, features)
  | (label, vectors) <- Map.toList labelled
  , features         <- vecs2maps vectors
  ]
-- | Convert dense vectors to the 'IntMap' representation used by the SVM
-- library, keyed from 1. Every component is kept, zeros included.
vecs2maps :: [Vec.Vector Double] -> [IntMap.IntMap Double]
vecs2maps vectors =
  [ IntMap.fromList (zip [1 ..] (Vec.toList v)) | v <- vectors ]
-- | Predict a 'ListType' for each vector, in input order.
--
-- Raw predictions are rounded to the nearest integer and decoded with
-- 'fromListTypeId', which returns a 'Maybe'. Presumably the result is
-- 'Nothing' for an id that matches no 'ListType'; confirm in
-- "Gargantext.Core.Types.Main".
predictList :: Model -> [Vec.Vector Double] -> IO [Maybe ListType]
predictList (ModelSVM svm _ _) vectors = do
  scores <- predict svm vectors
  pure [ fromListTypeId (round s) | s <- scores ]
63 ------------------------------------------------------------------------
-- | A trained SVM classifier plus the two parameters it was trained with.
--
-- During grid search these are @Just@ the values given to 'trainList'. A
-- model rebuilt by the 'ReadFile' instance carries 'Nothing' for both.
data Model = ModelSVM
  { modelSVM :: SVM.Model      -- ^ the underlying libsvm model
  , param1   :: Maybe Double   -- ^ first training parameter, if known
  , param2   :: Maybe Double   -- ^ second training parameter, if known
  }
-- | Persist only the underlying SVM model. 'param1' and 'param2' are not
-- written.
instance SaveFile Model where
  saveFile' path ModelSVM { modelSVM = svm } = SVM.saveModel svm path
-- | Rebuild a 'Model' from a file. Only the SVM model is restored. The
-- 'SaveFile' instance for 'Model' does not write 'param1' / 'param2', so
-- both come back as 'Nothing'.
73 instance ReadFile Model
-- NOTE(review): the lines that bind @m@ are not visible in this view.
-- Presumably they load the SVM model from the given path; confirm.
77 pure $ ModelSVM m Nothing Nothing
79 ------------------------------------------------------------------------
82 -- Split the labelled data into a training set and test sets.
83 -- Grid-search the parameters and keep the best result on the test sets.
-- | Training data: feature vectors grouped by their known 'ListType'.
type Train = Map ListType [Vec.Vector Double]

-- | One test set, with the same shape as 'Train'. Its known labels are
-- compared with a model's predictions to score that model.
type Tests = Train
90 grid :: (MonadReader env m, MonadIO m, HasSettings env)
91 => Param -> Param -> Train -> [Tests] -> m (Maybe Model)
92 grid _ _ _ [] = panic "Gargantext.Text.List.Learn.grid : empty test data"
95 grid' :: (MonadReader env m, MonadIO m, HasSettings env)
100 grid' x y tr' te' = do
101 model'' <- liftIO $ trainList x y tr'
104 model' = ModelSVM model'' (Just x) (Just y)
106 score' :: [(ListType, Maybe ListType)] -> Map (Maybe Bool) Int
107 score' = occurrencesWith (\(a,b) -> (==) <$> Just a <*> b)
109 score'' :: Map (Maybe Bool) Int -> Double
110 score'' m'' = maybe 0 (\t -> (fromIntegral t)/total) (Map.lookup (Just True) m'')
112 total = fromIntegral $ foldl (+) 0 $ Map.elems m''
115 let (res, toGuess) = List.unzip
117 $ map (\(k,vs) -> zip (repeat k) vs)
120 res' <- liftIO $ predictList m toGuess
121 pure $ score'' $ score' $ List.zip res res'
123 score <- mapM (getScore model') te'
124 pure (mean score, model')
126 r <- head . List.reverse
128 <$> mapM (\(x,y) -> grid' x y tr te)
129 [(x,y) | x <- [s..e], y <- [s..e]]
131 printDebug "GRID SEARCH" (map fst r)
132 --printDebug "file" fp
133 --fp <- saveFile (ModelSVM model')