-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
89 lines (69 loc) · 2.56 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from dataclasses import dataclass, field
from datetime import datetime
import yaml
import re
@dataclass
class Dataset:
    """Dataset section of the configuration.

    ``src`` and ``classes`` are required; every other field has a default.
    """

    # NOTE(review): presumably a path/URI to the data -- confirm with the
    # code that consumes this config.
    src: str
    classes: list
    class_mode: str = "categorical"
    batch: int = 32
    # default_factory gives every instance its own dict, so option
    # mappings are never shared between instances.
    train_options: dict = field(default_factory=dict)
    valid_options: dict = field(default_factory=dict)
@dataclass
class Model:
cls: str = field(default="ResNet50")
module: str = field(default="tensorflow.keras.applications")
name: str = field(default=None)
include_top: bool = field(default=False)
weights: str = field(default="imagenet")
input_shape: str = field(default="(244, 224, 3)")
loss: str = field(default="binary_crossentropy")
metrics: str | list | dict = field(default="accuracy")
fc_layer_activation: str = field(default="sigmoid")
class_options: dict = field(default_factory=dict)
compile_options: dict = field(default_factory=dict)
@property
def input_tuple(self) -> tuple:
return tuple(map(int, re.findall(r"(\d+)", self.input_shape)))
@dataclass
class Optimizer:
    """Optimizer section of the configuration.

    ``cls`` is required. NOTE(review): presumably the consumer looks up the
    class named ``cls`` inside ``module`` and passes ``options`` to it --
    that code is not part of this file, so confirm there.
    """

    cls: str
    module: str = "tensorflow.keras.optimizers"
    # Fresh dict per instance; never shared between Optimizer objects.
    options: dict = field(default_factory=dict)
@dataclass
class Schedule:
    """Schedule section of the configuration.

    ``cls`` is required, so a ``Schedule`` cannot be built from an empty
    mapping. NOTE(review): presumably ``cls`` names a class inside
    ``module`` and ``options`` holds its arguments -- confirm with the
    consumer of this config.
    """

    cls: str
    module: str = "tensorflow.keras.optimizers.schedules"
    # Fresh dict per instance; never shared between Schedule objects.
    options: dict = field(default_factory=dict)
@dataclass
class Callback:
    """One callback entry of the configuration.

    ``cls`` is required. NOTE(review): presumably ``cls`` names a class
    inside ``module`` and ``options`` holds its arguments -- confirm with
    the consumer of this config.
    """

    cls: str
    module: str = "tensorflow.keras.callbacks"
    # Fresh dict per instance; never shared between Callback objects.
    options: dict = field(default_factory=dict)
@dataclass
class Training:
epochs: int = field(default=10)
training_steps_per_epoch: int = field(default=None)
validation_steps_per_epoch: int = field(default=None)
options: dict = field(default_factory=dict)
class Config:
    """Top-level configuration assembled from plain mappings.

    Each argument is a mapping (or, for ``callbacks``, a list of mappings)
    whose keys are expanded into the matching section dataclass.

    Args:
        dataset: Keyword arguments for ``Dataset``.
        model: Keyword arguments for ``Model``.
        optimizer: Keyword arguments for ``Optimizer``.
        schedule: Keyword arguments for ``Schedule``, or ``None``/empty to
            configure no schedule; ``self.schedule`` is then ``None``.
        callbacks: List of keyword-argument mappings, one per ``Callback``;
            ``None`` yields an empty list.
        training: Keyword arguments for ``Training``; ``None`` yields a
            ``Training`` with all defaults.

    Raises:
        TypeError: If a mapping has unknown keys or lacks a required one.
    """

    def __init__(self, dataset, model, optimizer, schedule=None, callbacks=None, training=None):
        self.dataset = Dataset(**dataset)
        self.model = Model(**model)
        self.optimizer = Optimizer(**optimizer)
        # Schedule has a required ``cls`` field, so building it from an
        # empty mapping always raised TypeError even though ``schedule`` is
        # declared optional. An omitted schedule is stored as None instead.
        self.schedule = Schedule(**schedule) if schedule else None
        self.callbacks = [Callback(**entry) for entry in (callbacks or [])]
        self.training = Training(**(training or {}))
def parse_config(path: str) -> Config:
    """Load a YAML configuration file and build a ``Config`` from it.

    The file is parsed twice. The first pass only reads ``model.cls``; the
    raw text then has its placeholders substituted and is parsed again to
    produce the final configuration. Supported placeholders:

    * ``$now`` -- the current local time formatted as ``%Y%m%d-%H%M%S``.
    * ``$model_class`` -- the lower-cased ``model.cls`` value, falling back
      to the ``Model`` default when the file does not set it.

    Args:
        path: Location of the YAML file.

    Returns:
        The ``Config`` built from the substituted YAML document.

    Raises:
        OSError: If the file cannot be opened.
        TypeError: If the document lacks sections ``Config`` requires or
            contains keys the section dataclasses do not accept.
    """
    # YAML files are UTF-8 by convention; do not depend on the locale default.
    with open(path, encoding="utf-8") as config_file:
        raw = config_file.read()

    # An empty document loads as None; treat it as an empty mapping so the
    # failure surfaces in Config as a clear "missing argument" TypeError.
    unparsed = yaml.safe_load(raw) or {}
    model_section = unparsed.get("model") or {}

    placeholders = {
        "$now": datetime.now().strftime("%Y%m%d-%H%M%S"),
        # ``cls`` is optional in Model, so it must be optional here as well.
        "$model_class": model_section.get("cls", Model.cls).lower(),
    }

    parsed = raw
    for token, value in placeholders.items():
        # Literal replacement: unlike re.sub, str.replace never interprets
        # backslashes or group references in the substituted value.
        parsed = parsed.replace(token, value)

    return Config(**(yaml.safe_load(parsed) or {}))