Newer
Older
from organ.solver import Solver
import organ.demo.structure
import organ.structure.models
def str2bool(v):
return v.lower() in ('true')
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def load_rules(python_class_str: str):
"""Creates an instance of a class by fully qualified name.
Parameters
----------
python_class_str : str
Fully qualified class name. MUST contain at least one
dot symbol, separating module name and class name.
Returns
-------
An instance of the specified class.
"""
last_dot_pos = python_class_str.rfind('.')
if last_dot_pos <= 0 or last_dot_pos >= len(python_class_str) - 1:
raise ValueError('Must be a fully-qualified name')
module_name = python_class_str[:last_dot_pos]
class_name = python_class_str[last_dot_pos + 1:]
rules_module = importlib.import_module(module_name)
return getattr(rules_module, class_name)()
def structure_rules(val: str):
"""Process 'rules' command-line argument."""
if val == 'demo':
return organ.demo.structure.DemoOrganizationStructureModel()
elif val == 'generic':
return organ.structure.models.Generic()
else:
return load_rules(val)
def main(config):
# For fast training.
cudnn.benchmark = True
# Create directories if not exist.
if not os.path.exists(config.log_dir):
os.makedirs(config.log_dir)
if not os.path.exists(config.model_save_dir):
os.makedirs(config.model_save_dir)
# Solver for training and testing OrGAN.
solver = Solver(config)
if config.mode == 'train':
solver.train()
elif config.mode == 'test':
solver.test()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Organization structure rules configuration
parser.add_argument('--rules', type=structure_rules, default='generic',
help='organization structure rules description. '
'Can be either "demo", "generic", or a '
'fully-qualified class name')
parser.add_argument('--z_dim', type=int, default=8,
help='input dimension of G')
# Размерности группы полносвязных слоев в начале генератора
parser.add_argument('--g_conv_dim', default=[128, 256, 512],
help='neurons in the dense layers in the encoder of G')
# Спецификация сложности преобразований, которые должны
# быть реализованы дискриминатором (и аппроксиматором).
# Состоит из трех компонент:
# - список, описывающий параметры графовых сверток, в частности,
# размерности представлений вершин,
# - количество признаков в глобальном представлении графа,
# - список, задающий количества нейронов в серии полносвязных слоев.
parser.add_argument('--d_conv_dim', type=int,
default=[[128, 64], 128, [128, 64]],
help='specification of D')
# Вес для штрафа на величину градиента в функции оптимизации
parser.add_argument('--lambda_gp', type=float, default=10,
help='weight for gradient penalty')
# Метод постобработки сгенерированных графов
parser.add_argument('--post_method', type=str, default='softmax',
choices=['softmax', 'soft_gumbel', 'hard_gumbel'])
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Размер батча
parser.add_argument('--batch_size', type=int, default=16,
help='mini-batch size')
# Количество итераций (батчей) в процессе обучения
parser.add_argument('--num_iters', type=int, default=200000,
help='number of total iterations for training D')
# Количество итераций (перед последней, `num_iters`) в течение
# которых будет осуществляться снижение константы обучения
parser.add_argument('--num_iters_decay', type=int, default=100000,
help='number of iterations for decaying lr')
# Константа обучения для генератора
parser.add_argument('--g_lr', type=float, default=0.0001,
help='learning rate for G')
# Константа обучения для дискриминатора
parser.add_argument('--d_lr', type=float, default=0.0001,
help='learning rate for D')
# Дропаут (одно и то же значение используется везде, между
# каждой парой слоев)
parser.add_argument('--dropout', type=float, default=0.,
help='dropout rate')
# Периодичность тренировки генератора
# (каждые `n_critic` батчей)
parser.add_argument('--n_critic', type=int, default=5,
help='number of D updates per each G update')
# beta1 для Adam (при обучении всех моделей)
parser.add_argument('--beta1', type=float, default=0.5,
help='beta1 for Adam optimizer')
# beta2 для Adam (при обучении всех моделей)
parser.add_argument('--beta2', type=float, default=0.999,
help='beta2 for Adam optimizer')
# Итерация, с которой нужно продолжить процесс обучения.
# Если значение не 0, то все модели будут загружены из
# точек сохранения и процесс продолжен.
parser.add_argument('--resume_iters', type=int, default=None,
help='resume training from this step')
# Указание на то, какую именно модель следует тестировать
# (модель, созданную после test_iters итераций обучения).
parser.add_argument('--test_iters', type=int, default=200000,
help='test model from this step')
# Miscellaneous.
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--mode', type=str, default='train',
choices=['train', 'test'])
parser.add_argument('--use_tensorboard', type=str2bool, default=False)
# Directories.
parser.add_argument('--data_dir', type=str, default='data')
# Директория записи журнала (используется только с
# Tensorboard)
parser.add_argument('--log_dir', type=str, default='output/logs')
# Директория для сохранения моделей
# (из этой же директории они будут подгружаться при необходимости
# продолжить обучение)
parser.add_argument('--model_save_dir', type=str, default='output/models')
# Настройка периодичности вывода информации
#
# Периодичность записи данных в журнал (для Tensorboard)
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--model_save_step', type=int, default=10000)
# Периодичность изменения констант обучения.
# Этим параметром регулируется то, как часто будет оцениваться
# необходимость ревизии констант. См. также `num_iters_decay`.
parser.add_argument('--lr_update_step', type=int, default=1000)
config = parser.parse_args()
print(config)
main(config)