Skip to content
Snippets Groups Projects
Commit da09599b authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Forgot to add the files

parent 7f42a6d1
Branches
No related tags found
No related merge requests found
import configparser
import builtins
from distutils.util import strtobool as tobool
def get_the_args(path_to_config_file="../config_files/config.ini"):
"""This is the main function for extracting the args for a '.ini' file"""
config_parser = configparser.ConfigParser(comment_prefixes=('#'))
config_parser.read(path_to_config_file)
config_dict = {}
for section in config_parser:
config_dict[section] = {}
for key in config_parser[section]:
value = format_raw_arg(config_parser[section][key])
config_dict[section][key] = value
return config_dict
def format_raw_arg(raw_arg):
"""This function is used to convert the raw arg in a types value.
For example, 'list_int ; 10 20' will be formatted in [10,20]"""
function_name, raw_value = raw_arg.split(" ; ")
if function_name.startswith("list"):
function_name = function_name.split("_")[1]
raw_values = raw_value.split(" ")
value = [getattr(builtins, function_name)(raw_value)
if function_name != "bool" else bool(tobool(raw_value))
for raw_value in raw_values]
else:
if raw_value == "None":
value = None
else:
if function_name=="bool":
value = bool(tobool(raw_value))
else:
value = getattr(builtins, function_name)(raw_value)
return value
import os
import unittest
import numpy as np
from ...MonoMultiViewClassifiers.utils import configuration
class Test_get_the_args(unittest.TestCase):
def setUp(self):
self.path_to_config_file = "multiview_platform/Tests/tmp_tests/config_temp.ini"
os.mkdir("multiview_platform/Tests/tmp_tests")
config_file = open(self.path_to_config_file, "w")
config_file.write("[Base]\nfirst_arg = int ; 10\nsecond_arg = list_float ; 12.5 1e-06\n[Classification]\nthird_arg = bool ; yes")
config_file.close()
def tearDown(self):
os.remove("multiview_platform/Tests/tmp_tests/config_temp.ini")
os.rmdir("multiview_platform/Tests/tmp_tests")
def test_file_loading(self):
config_dict = configuration.get_the_args(self.path_to_config_file)
self.assertEqual(type(config_dict), dict)
def test_dict_format(self):
config_dict = configuration.get_the_args(self.path_to_config_file)
self.assertIn("Base", config_dict)
self.assertIn("Classification", config_dict)
self.assertIn("first_arg", config_dict["Base"])
self.assertIn("third_arg", config_dict["Classification"])
def test_arguments(self):
config_dict = configuration.get_the_args(self.path_to_config_file)
self.assertEqual(config_dict["Base"]["first_arg"], 10)
self.assertEqual(config_dict["Base"]["second_arg"], [12.5, 1e-06])
self.assertEqual(config_dict["Classification"]["third_arg"], True)
class Test_format_the_args(unittest.TestCase):
def test_bool(self):
value = configuration.format_raw_arg("bool ; yes")
self.assertEqual(value, True)
def test_int(self):
value = configuration.format_raw_arg("int ; 1")
self.assertEqual(value, 1)
def test_float(self):
value = configuration.format_raw_arg("float ; 1.5")
self.assertEqual(value, 1.5)
def test_string(self):
value = configuration.format_raw_arg("str ; chicken_is_heaven")
self.assertEqual(value, "chicken_is_heaven")
def test_list_bool(self):
value = configuration.format_raw_arg("list_bool ; yes no yes yes")
self.assertEqual(value, [True, False, True, True])
def test_list_int(self):
value = configuration.format_raw_arg("list_int ; 1 2 3 4")
self.assertEqual(value, [1,2,3,4])
def test_list_float(self):
value = configuration.format_raw_arg("list_float ; 1.5 1.6 1.7")
self.assertEqual(value, [1.5, 1.6, 1.7])
def test_list_string(self):
value = configuration.format_raw_arg("list_str ; list string")
self.assertEqual(value, ["list", "string"])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment