Skip to content
main.py 8.46 KiB
Newer Older
import os
import argparse
Andrew Ponomarev's avatar
Andrew Ponomarev committed
import importlib

from torch.backends import cudnn

Andrew Ponomarev's avatar
Andrew Ponomarev committed
from organ.solver import Solver
import organ.demo.structure
import organ.structure.models


def str2bool(v):
    return v.lower() in ('true')

Andrew Ponomarev's avatar
Andrew Ponomarev committed

def load_rules(python_class_str: str):
    """Creates an instance of a class by fully qualified name.

    Parameters
    ----------
    python_class_str : str
        Fully qualified class name. MUST contain at least one
        dot symbol, separating module name and class name.
    Returns
    -------
    An instance of the specified class.
    """
    last_dot_pos = python_class_str.rfind('.')
    if last_dot_pos <= 0 or last_dot_pos >= len(python_class_str) - 1:
        raise ValueError('Must be a fully-qualified name')

    module_name = python_class_str[:last_dot_pos]
    class_name = python_class_str[last_dot_pos + 1:]

    rules_module = importlib.import_module(module_name)
    return getattr(rules_module, class_name)()


def structure_rules(val: str):
    """Process 'rules' command-line argument."""
    if val == 'demo':
        return organ.demo.structure.DemoOrganizationStructureModel()
    elif val == 'generic':
        return organ.structure.models.Generic()
    else:
        return load_rules(val)


def main(config):
    # For fast training.
    cudnn.benchmark = True

    # Create directories if not exist.
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    if not os.path.exists(config.model_save_dir):
        os.makedirs(config.model_save_dir)

    # Solver for training and testing OrGAN.
    solver = Solver(config)

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'test':
        solver.test()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Organization structure rules configuration
    parser.add_argument('--rules', type=structure_rules, default='generic',
                        help='organization structure rules description. '
                        'Can be either "demo", "generic", or a '
                        'fully-qualified class name')

    # Model configuration.
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    parser.add_argument('--z_dim', type=int, default=8,
                        help='input dimension of G')
    # Размерности группы полносвязных слоев в начале генератора
    parser.add_argument('--g_conv_dim', default=[128, 256, 512],
                        help='neurons in the dense layers in the encoder of G')
    # Спецификация сложности преобразований, которые должны
    # быть реализованы дискриминатором (и аппроксиматором).
    # Состоит из трех компонент:
    #   - список, описывающий параметры графовых сверток, в частности,
    #     размерности представлений вершин,
    #   - количество признаков в глобальном представлении графа,
    #   - список, задающий количества нейронов в серии полносвязных слоев.
    parser.add_argument('--d_conv_dim', type=int,
                        default=[[128, 64], 128, [128, 64]],
                        help='specification of D')
    # Вес для штрафа на величину градиента в функции оптимизации
    parser.add_argument('--lambda_gp', type=float, default=10,
                        help='weight for gradient penalty')
    # Метод постобработки сгенерированных графов
    parser.add_argument('--post_method', type=str, default='softmax',
                        choices=['softmax', 'soft_gumbel', 'hard_gumbel'])

    # Training configuration.
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Размер батча
    parser.add_argument('--batch_size', type=int, default=16,
                        help='mini-batch size')
    # Количество итераций (батчей) в процессе обучения
    parser.add_argument('--num_iters', type=int, default=200000,
                        help='number of total iterations for training D')
    # Количество итераций (перед последней, `num_iters`) в течение
    # которых будет осуществляться снижение константы обучения
    parser.add_argument('--num_iters_decay', type=int, default=100000,
                        help='number of iterations for decaying lr')
    # Константа обучения для генератора
    parser.add_argument('--g_lr', type=float, default=0.0001,
                        help='learning rate for G')
    # Константа обучения для дискриминатора
    parser.add_argument('--d_lr', type=float, default=0.0001,
                        help='learning rate for D')
    # Дропаут (одно и то же значение используется везде, между
    # каждой парой слоев)
    parser.add_argument('--dropout', type=float, default=0.,
                        help='dropout rate')
    # Периодичность тренировки генератора
    # (каждые `n_critic` батчей)
    parser.add_argument('--n_critic', type=int, default=5,
                        help='number of D updates per each G update')
    # beta1 для Adam (при обучении всех моделей)
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='beta1 for Adam optimizer')
    # beta2 для Adam (при обучении всех моделей)
    parser.add_argument('--beta2', type=float, default=0.999,
                        help='beta2 for Adam optimizer')
    # Итерация, с которой нужно продолжить процесс обучения.
    # Если значение не 0, то все модели будут загружены из
    # точек сохранения и процесс продолжен.
    parser.add_argument('--resume_iters', type=int, default=None,
                        help='resume training from this step')

    # Test configuration.
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Указание на то, какую именно модель следует тестировать
    # (модель, созданную после test_iters итераций обучения).
    parser.add_argument('--test_iters', type=int, default=200000,
                        help='test model from this step')

    # Miscellaneous.
    parser.add_argument('--num_workers', type=int, default=1)
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    parser.add_argument('--mode', type=str, default='train',
                        choices=['train', 'test'])
    parser.add_argument('--use_tensorboard', type=str2bool, default=False)
    parser.add_argument('--augment', action='store_true', default=False,
                        help='apply augmentations during training')
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    parser.add_argument('--no_pretrain', action='store_false', default=True,
                        dest='pretrain', help='disable pretraining')

    # Directories.
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Директория с данными
    parser.add_argument('--data_dir', type=str, default='data')
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Директория записи журнала (используется только с
    # Tensorboard)
    parser.add_argument('--log_dir', type=str, default='output/logs')
    # Директория для сохранения моделей
    # (из этой же директории они будут подгружаться при необходимости
    # продолжить обучение)
    parser.add_argument('--model_save_dir', type=str, default='output/models')
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Настройка периодичности вывода информации
    #
    # Периодичность записи данных в журнал (для Tensorboard)
    parser.add_argument('--log_step', type=int, default=10)
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Периодичность сохранения моделей
    parser.add_argument('--model_save_step', type=int, default=10000)
Andrew Ponomarev's avatar
Andrew Ponomarev committed
    # Периодичность изменения констант обучения.
    # Этим параметром регулируется то, как часто будет оцениваться
    # необходимость ревизии констант. См. также `num_iters_decay`.
    parser.add_argument('--lr_update_step', type=int, default=1000)

    config = parser.parse_args()
    print(config)
    main(config)