import os import argparse import importlib from torch.backends import cudnn from organ.solver import Solver import organ.demo.structure import organ.structure.models def str2bool(v): return v.lower() in ('true') def load_rules(python_class_str: str): """Creates an instance of a class by fully qualified name. Parameters ---------- python_class_str : str Fully qualified class name. MUST contain at least one dot symbol, separating module name and class name. Returns ------- An instance of the specified class. """ last_dot_pos = python_class_str.rfind('.') if last_dot_pos <= 0 or last_dot_pos >= len(python_class_str) - 1: raise ValueError('Must be a fully-qualified name') module_name = python_class_str[:last_dot_pos] class_name = python_class_str[last_dot_pos + 1:] rules_module = importlib.import_module(module_name) return getattr(rules_module, class_name)() def structure_rules(val: str): """Process 'rules' command-line argument.""" if val == 'demo': return organ.demo.structure.DemoOrganizationStructureModel() elif val == 'generic': return organ.structure.models.Generic() else: return load_rules(val) def main(config): # For fast training. cudnn.benchmark = True # Create directories if not exist. if not os.path.exists(config.log_dir): os.makedirs(config.log_dir) if not os.path.exists(config.model_save_dir): os.makedirs(config.model_save_dir) # Solver for training and testing OrGAN. solver = Solver(config) if config.mode == 'train': solver.train() elif config.mode == 'test': solver.test() if __name__ == '__main__': parser = argparse.ArgumentParser() # Organization structure rules configuration parser.add_argument('--rules', type=structure_rules, default='generic', help='organization structure rules description. ' 'Can be either "demo", "generic", or a ' 'fully-qualified class name') # Model configuration. parser.add_argument('--z_dim', type=int, default=8, help='input dimension of G') # Размерности группы полносвязных слоев в начале генератора parser.add_argument('--g_conv_dim', default=[128, 256, 512], help='neurons in the dense layers in the encoder of G') # Спецификация сложности преобразований, которые должны # быть реализованы дискриминатором (и аппроксиматором). # Состоит из трех компонент: # - список, описывающий параметры графовых сверток, в частности, # размерности представлений вершин, # - количество признаков в глобальном представлении графа, # - список, задающий количества нейронов в серии полносвязных слоев. parser.add_argument('--d_conv_dim', type=int, default=[[128, 64], 128, [128, 64]], help='specification of D') # Вес для штрафа на величину градиента в функции оптимизации parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty') # Метод постобработки сгенерированных графов parser.add_argument('--post_method', type=str, default='softmax', choices=['softmax', 'soft_gumbel', 'hard_gumbel']) # Training configuration. # Размер батча parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size') # Количество итераций (батчей) в процессе обучения parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D') # Количество итераций (перед последней, `num_iters`) в течение # которых будет осуществляться снижение константы обучения parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr') # Константа обучения для генератора parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G') # Константа обучения для дискриминатора parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D') # Дропаут (одно и то же значение используется везде, между # каждой парой слоев) parser.add_argument('--dropout', type=float, default=0., help='dropout rate') # Периодичность тренировки генератора # (каждые `n_critic` батчей) parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update') # beta1 для Adam (при обучении всех моделей) parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') # beta2 для Adam (при обучении всех моделей) parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') # Итерация, с которой нужно продолжить процесс обучения. # Если значение не 0, то все модели будут загружены из # точек сохранения и процесс продолжен. parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step') # Test configuration. # Указание на то, какую именно модель следует тестировать # (модель, созданную после test_iters итераций обучения). parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step') # Miscellaneous. parser.add_argument('--num_workers', type=int, default=1) parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) parser.add_argument('--use_tensorboard', type=str2bool, default=False) # Directories. # Директория с данными parser.add_argument('--data_dir', type=str, default='data') # Директория записи журнала (используется только с # Tensorboard) parser.add_argument('--log_dir', type=str, default='output/logs') # Директория для сохранения моделей # (из этой же директории они будут подгружаться при необходимости # продолжить обучение) parser.add_argument('--model_save_dir', type=str, default='output/models') # Настройка периодичности вывода информации # # Периодичность записи данных в журнал (для Tensorboard) parser.add_argument('--log_step', type=int, default=10) # Периодичность сохранения моделей parser.add_argument('--model_save_step', type=int, default=10000) # Периодичность изменения констант обучения. # Этим параметром регулируется то, как часто будет оцениваться # необходимость ревизии констант. См. также `num_iters_decay`. parser.add_argument('--lr_update_step', type=int, default=1000) config = parser.parse_args() print(config) main(config)