From 0b1b44711d9cfa72e7ecddf48e5bfdc019382b30 Mon Sep 17 00:00:00 2001 From: Andrew Ponomarev Date: Sat, 31 Dec 2022 13:48:48 +0300 Subject: [PATCH] Public version 0.2. --- main.py | 2 + organ/models.py | 92 ++++++++++++++++++- organ/solver.py | 214 ++++++++++++++++++++++++++++++++----------- organ/utils.py | 1 + tests/test_models.py | 14 +-- tests/test_solver.py | 7 +- 6 files changed, 265 insertions(+), 65 deletions(-) diff --git a/main.py b/main.py index 8e30c85..292280e 100644 --- a/main.py +++ b/main.py @@ -147,6 +147,8 @@ if __name__ == '__main__': parser.add_argument('--use_tensorboard', type=str2bool, default=False) parser.add_argument('--augment', action='store_true', default=False, help='apply augmentations during training') + parser.add_argument('--no_pretrain', action='store_false', default=True, + dest='pretrain', help='disable pretraining') # Directories. # Директория с данными diff --git a/organ/models.py b/organ/models.py index 4291e5d..efc2240 100644 --- a/organ/models.py +++ b/organ/models.py @@ -6,11 +6,13 @@ This module defines the generator and discriminator networks import torch import torch.nn as nn +import organ.tingle + from organ.layers import GraphConvolution, GraphAggregation, \ EdgeConvolution, edge_aggregation -class Generator(nn.Module): +class SimpleGenerator(nn.Module): """Generator network for OrGAN. Generator is a non-linear neural transformation from an input @@ -49,7 +51,7 @@ class Generator(nn.Module): dropout : float Droupout [0; 1] (applied to each layer, including output). """ - super(Generator, self).__init__() + super(SimpleGenerator, self).__init__() assert vertexes == nodes @@ -103,7 +105,7 @@ class Generator(nn.Module): edges_logits = self.edges_layer(output) \ .view(-1, self.edges, self.vertexes, self.vertexes) # Получение симметричной (!) 
матрицы смежности - edges_logits = (edges_logits + edges_logits.permute(0, 1, 3, 2)) / 2 + # edges_logits = (edges_logits + edges_logits.permute(0, 1, 3, 2)) / 2 # TODO: (hatter) Мне странно применение дропаута к выходному слою edges_logits = self.dropoout(edges_logits.permute(0, 2, 3, 1)) @@ -116,6 +118,90 @@ class Generator(nn.Module): return edges_logits, nodes_logits +class EdgeAwareGenerator(nn.Module): + """Generator that creates edges based on types of nodes.""" + + def __init__(self, conv_dims, edge_conv_dims, z_dim, + vertexes, edges, nodes, dropout): + """Constructor. + + Parameters + ---------- + conv_dims : list + List, describing the FC layers in the beginning of the + generator. + edge_conv_dims : list + List, describing the edge layers. + z_dim : int + Input dimensions. + vertexes : int + Number of vertexes in the graph. Must be equal to `nodes`. + edges : int + Number of edge types. + nodes : int + Number of types of nodes. Must be equal to `vertexes`. + dropout : float + Dropout probability [0; 1] (applied after each initial FC layer). 
+ """ + super(EdgeAwareGenerator, self).__init__() + + assert vertexes == nodes + + self.vertexes = vertexes + self.edges = edges + self.nodes = nodes + + layers = [] + for c0, c1 in zip([z_dim] + conv_dims[:-1], conv_dims): + layers.append(nn.Linear(c0, c1)) + layers.append(nn.Tanh()) + layers.append(nn.Dropout(p=dropout, inplace=True)) + self.layers = nn.Sequential(*layers) + + self.edges_ctx_layer = nn.Linear(conv_dims[-1], + 32) + edge_layers = [] + for c0, c1 in zip([nodes + 32] + edge_conv_dims[:-1], edge_conv_dims): + edge_layers.append(nn.Linear(c0, c1)) + edge_layers.append(nn.Tanh()) + self.edge_layers = nn.Sequential(*edge_layers) + + self.edges_layer = nn.Linear(edge_conv_dims[-1], + edges) + self.nodes_layer = nn.Linear(conv_dims[-1], + vertexes) + self.dropoout = nn.Dropout(p=dropout) + + def forward(self, x): + + # Apply the initial group of fully connected layers + output = self.layers(x) + + # Obtain the node specification + nodes_logits = self.nodes_layer(output) + + # Expand the node description into its unfolded form + nodes_sigm = torch.sigmoid(nodes_logits) + nodes_hat = torch.diag_embed(nodes_sigm) + nodes_hat[:, :, 0] += (1 - nodes_sigm) + + # Obtain the specification of the graph edges + # Graph generation context + ctx = self.edges_ctx_layer(output) + # Descriptions (types) of the nodes incident + # to the edge + cc = organ.tingle._cartesian(nodes_hat) + # Context + data about the incident nodes + edges_data = torch.cat([cc[0] - cc[1], + ctx.view(-1, 1, 1, 32). + expand(-1, self.nodes, self.nodes, 32)], + axis=-1) + edges = self.edge_layers(edges_data) + edges_logits = self.edges_layer(edges) + + return edges_logits, nodes_logits + + class Discriminator(nn.Module): """Discriminator for OrGAN. 
diff --git a/organ/solver.py b/organ/solver.py index 21051aa..af950e1 100644 --- a/organ/solver.py +++ b/organ/solver.py @@ -9,7 +9,7 @@ import datetime import torch import torch.nn.functional as F -from organ.models import Generator, Discriminator +from organ.models import EdgeAwareGenerator, Discriminator from organ.data.organization_structure_dataset \ import OrganizationStructureDataset from organ.utils import MetricsAggregator, all_scores @@ -127,6 +127,9 @@ class Solver(object): # необходимость ревизии констант. См. также `num_iters_decay`. self.lr_update_step = config.lr_update_step + # Should we pretrain? + self.pretrain = config.pretrain + # For the log to be informative, it should contain quality # characteristics of only generated structures assert self.log_step % self.n_critic == 0 @@ -144,11 +147,14 @@ class Solver(object): print('Node types:', self.data.node_num_types, self.m_dim) print('Edge types:', self.data.edge_num_types, self.b_dim) - self.G = Generator(self.g_conv_dim, self.z_dim, - self.data.vertexes, - self.data.edge_num_types, - self.data.node_num_types, - self.dropout) + self.G = EdgeAwareGenerator([128, 128], # self.g_conv_dim, + [128, 64, 32], + self.z_dim, + self.data.vertexes, + self.data.edge_num_types, + self.data.node_num_types, + self.dropout) + # NOTE: Архитектуры дискриминатора и аппроксиматора полностью # идентичны. 
self.D = Discriminator(self.d_conv_dim, @@ -180,6 +186,20 @@ class Solver(object): self.D.to(self.device) self.V.to(self.device) + def load_pretrained(self): + """Load pretrained models.""" + + for model_code, model in [('G', self.G), + ('D', self.D), + ('V', self.V)]: + # if there are pre-trained models and they are compatible + path = os.path.join(self.model_save_dir, f'pre-{model_code}.ckpt') + try: + model.load_state_dict(torch.load(path, map_location=lambda storage, loc: storage)) # noqa: E501 + print(f'Pretrained {model_code} has been loaded.') + except Exception: + print(f"Can't load pre-trained {model_code} model, starting from scratch.") # noqa: E501 + def print_network(self, model, name): """Print model description. @@ -394,44 +414,6 @@ class Solver(object): def train(self): """Training cycle.""" - def next_batch(mode): - """Получение очередного батча, подготовка и загрузка на - устройство. - """ - if mode == 'train': - x, a = self.data.next_train_batch(self.batch_size) - elif mode == 'validation': - x, a = self.data.next_validation_batch() - else: - raise ValueError(f'Unknown mode: \'{mode}\'. ' - 'Only ''train'' and ''validation'' supported') - z = self.sample_z(x.shape[0]) # Батчи одинакового размера - orgs = list(zip(x, a)) - - # orgs is an (self.batch_size, ) numpy array - organization instances (used for checking) # noqa: E501 - # a is a (self.batch_size, 12, 12) numpy array - adjacency matrices (a_ij is the number of connections) # noqa: E501 - # x is a (self.batch_size, 12) numpy array - node type (categorical, 0 for no-node) # noqa: E501 - - # Загрузим данные на вычислительное устройство и приведем в вид, - # ожидаемый нейронными сетями - - a = torch.from_numpy(a).to(self.device).long() # Adjacency. - x = torch.from_numpy(x).to(self.device).long() # Nodes. 
- a_tensor = self.label2onehot(a, self.b_dim) - x_tensor = self.label2onehot(x, self.m_dim) - z = torch.from_numpy(z).to(self.device).float() - - return x_tensor, a_tensor, orgs, z - - def invoke_G(z): - """Генерация батча графов.""" - edges_logits, nodes_logits = self.G(z) - # Postprocess with Gumbel softmax - edges_hat = self.postprocess((edges_logits, ), - self.post_method)[0] - nodes_hat = self.postprocess_nodes(nodes_logits) - return edges_hat, nodes_hat - def compute_gp_loss(a_tensor, x_tensor, edges_hat, nodes_hat): eps = torch.rand(a_tensor.size(0), 1, 1, 1).to(self.device) x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True) # noqa: E501 @@ -474,7 +456,7 @@ class Solver(object): d_loss_real = -torch.mean(torch.log(torch.sigmoid(logits_real))) # Compute loss with fake structures. - edges_hat, nodes_hat = invoke_G(z) + edges_hat, nodes_hat = self._invoke_G(z) logits_fake, features_fake = self.D(edges_hat, None, nodes_hat) d_loss_fake = -torch.mean(torch.log(1 - torch.sigmoid(logits_fake))) # noqa: E501 @@ -506,7 +488,7 @@ class Solver(object): # =========================================================== # # Получить батч из генератора - edges_hat, nodes_hat = invoke_G(z) + edges_hat, nodes_hat = self._invoke_G(z) # Получить оценку настоящих образцов с помощью "черного ящика" # Real Reward rewardR = torch.from_numpy(self.reward(orgs)).to(self.device) @@ -540,7 +522,7 @@ class Solver(object): # =========================================================== # # Получить батч из генератора - edges_hat, nodes_hat = invoke_G(z) + edges_hat, nodes_hat = self._invoke_G(z) # Оценить правдоподобие с точки зрения дискриминатора logits_fake, features_fake = self.D(edges_hat, None, nodes_hat) @@ -591,6 +573,10 @@ class Solver(object): if self.resume_iters: start_iters = self.resume_iters self.restore_model(self.resume_iters) + elif self.pretrain: + print('Start pre-training...') + self._pretrain() + # self.load_pretrained() # Start training. 
print('Start training...') @@ -599,7 +585,7 @@ class Solver(object): # Получение очередного батча, его подготовка и загрузка на # устройство - x_tensor, a_tensor, orgs, z = next_batch('train') + x_tensor, a_tensor, orgs, z = self._next_batch('train') # Обработка обучающего батча, пересчет весов orgs, loss = process_batch(a_tensor, x_tensor, @@ -611,7 +597,7 @@ class Solver(object): if (i+1) % self.log_step == 0: # Получение валидационного батча - x_tensor, a_tensor, orgs, z = next_batch('validation') + x_tensor, a_tensor, orgs, z = self._next_batch('validation') # Обработка обучающего батча, пересчет весов orgs, loss = process_batch(a_tensor, x_tensor, @@ -712,9 +698,10 @@ class Solver(object): for k, v in m0.items()} m0.update(m1) - log = '' - for tag, value in m0.items(): - log += ", {}: {:.4f}".format(tag, value) + log = 'Testing on {} structures: '.format(a.shape[0]) + if m0: + log += ', '.join(["{}: {:.4f}".format(tag, value) + for tag, value in m0.items()]) print(log) def generate(self, batch_size: int = 1, ctx=None): @@ -749,3 +736,126 @@ class Solver(object): return nodes_hard.cpu().numpy(), \ edges_hard.cpu().numpy(), None, None + + def _pretrain(self): + """Pretrain models.""" + + BATCH_SIZE = min(50, self.data.train_count) + + def pretrain_validator(model, + target_nodes, + target_edges, + max_iters=1000): + loss_fn = torch.nn.BCELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + for i in range(max_iters): + # A batch of "garbage" from the generator + optimizer.zero_grad() + input_z = self.sample_z(BATCH_SIZE) + z = torch.from_numpy(input_z).to(self.device).float() + edges_hat, nodes_hat = self._invoke_G(z) + value_bad, _ = model(edges_hat, + None, + nodes_hat, + torch.sigmoid) + loss_bad = loss_fn(value_bad, torch.zeros((BATCH_SIZE, 1), + dtype=torch.float32, + device=self.device)) + loss_bad.backward() + optimizer.step() + loss_bad = loss_bad.detach().cpu().item() + # Good samples + optimizer.zero_grad() + value_good, _ = model(target_edges, + 
None, + target_nodes, + torch.sigmoid) + loss_good = loss_fn(value_good, torch.ones((BATCH_SIZE, 1), + dtype=torch.float32, + device=self.device)) + loss_good.backward() + optimizer.step() + + def pretrain_generator(z, target_nodes, target_edges, + max_iters=1000, loss_eps=0.01): + loss_fn = torch.nn.BCELoss() + optimizer = torch.optim.Adam(self.G.parameters(), lr=0.001) + for i in range(max_iters): + optimizer.zero_grad() + edges_hat, nodes_hat = self._invoke_G(z) + loss = loss_fn(nodes_hat, target_nodes) + \ + loss_fn(edges_hat, target_edges) + loss.backward() + optimizer.step() + loss_value = loss.detach().cpu().item() + if loss_value < loss_eps: + break + print(f'Generator loss@pretrain: {loss_value:.5f}') + + # Get a batch of real examples + x_tensor, a_tensor, _, z = self._next_batch('train', BATCH_SIZE) + + # Use these examples to give the validator and discriminator + # ideas of what is good and evil + pretrain_validator(self.V, x_tensor, a_tensor, max_iters=10) + pretrain_validator(self.D, x_tensor, a_tensor, max_iters=10) + + # Use the noise to pretrain the generator. + # It tries to map each point to the respective + # sample + pretrain_generator(z, x_tensor, a_tensor, + max_iters=1000, loss_eps=0.01) + + def _invoke_G(self, z): + """Generate a batch of graphs.""" + edges_logits, nodes_logits = self.G(z) + # Postprocess with Gumbel softmax + edges_hat = self.postprocess((edges_logits, ), + self.post_method)[0] + nodes_hat = self.postprocess_nodes(nodes_logits) + return edges_hat, nodes_hat + + def _next_batch(self, mode: str, batch_size=None): + """Retrieve next batch and load it to the device. + + Parameters + ---------- + mode: str + Specification of what set to use: 'train' or 'validation'. + + Returns + ------- + tuple + A tensor of nodes (batch, nodes, nodes), a tensor + of edges (batch, nodes, nodes, edges), a list + of structures corresponding to the batch, and z-noise + to use as an input for the generator. 
+ """ + if batch_size is None: + batch_size = self.batch_size + + if mode == 'train': + x, a = self.data.next_train_batch(batch_size) + elif mode == 'validation': + x, a = self.data.next_validation_batch() + else: + raise ValueError(f'Unknown mode: \'{mode}\'. ' + "Only 'train' and 'validation' supported") + + z = self.sample_z(x.shape[0]) # Same batch size as x + orgs = list(zip(x, a)) + + # orgs is a list of (x, a) pairs - organization instances (used for checking) # noqa: E501 + # a is a (self.batch_size, 12, 12) numpy array - adjacency matrices (a_ij is the number of connections) # noqa: E501 + # x is a (self.batch_size, 12) numpy array - node type (categorical, 0 for no-node) # noqa: E501 + + # Move the data to the compute device and convert it to the form + # expected by the neural networks + + a = torch.from_numpy(a).to(self.device).long() # Adjacency. + x = torch.from_numpy(x).to(self.device).long() # Nodes. + a_tensor = self.label2onehot(a, self.b_dim) + x_tensor = self.label2onehot(x, self.m_dim) + z = torch.from_numpy(z).to(self.device).float() + + return x_tensor, a_tensor, orgs, z diff --git a/organ/utils.py b/organ/utils.py index f875595..fbf3937 100644 --- a/organ/utils.py +++ b/organ/utils.py @@ -121,6 +121,7 @@ def all_scores(metrics_aggregator, m1 = {k: np.mean(v) for k, v in metrics_aggregator.get_scores(orgs).items() } + m1.update({'Accuracy': metrics_aggregator.valid_total_score(orgs)}) return m0, m1 diff --git a/tests/test_models.py b/tests/test_models.py index abd60c0..a234bb9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,7 @@ import torch sys.path.append('.') import organ.models # noqa: E402 -import tests.util # noqa: E402 +import tests.util # noqa: E402 def test_generator(): @@ -16,12 +16,12 @@ def test_generator(): # слот под вершину определенного типа и есть определенная позиция. 
n_node_types = n_nodes n_edge_types = 2 - g = organ.models.Generator([4, 5], - z_dim, # Размерность входного вектора - n_nodes, # Количество вершин в графе - n_edge_types, # Количество типов дуг - n_node_types, # Количество типов вершин - 0.0) + g = organ.models.SimpleGenerator([4, 5], + z_dim, # Input vector dimensionality # noqa: E501 + n_nodes, # Number of vertices in the graph + n_edge_types, # Number of edge types + n_node_types, # Number of node types + 0.0) edges, nodes = g(torch.randn(7, z_dim)) # Проверка размерности генерируемых массивов, описывающих граф assert edges.shape == (7, n_nodes, n_nodes, n_edge_types) diff --git a/tests/test_solver.py b/tests/test_solver.py index fe31bae..535912b 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -13,7 +13,7 @@ from organ.demo.structure import DemoOrganizationStructureModel # noqa: E402 import organ.structure.models # noqa: E402 -def make_config(): +def make_config(pretrain: bool = False): args = Namespace(rules=DemoOrganizationStructureModel(), z_dim=8, g_conv_dim=[128, 256, 512], @@ -40,7 +40,8 @@ def make_config(): model_save_dir='.tmp/organ/models', log_step=10, model_save_step=10, - lr_update_step=1000) + lr_update_step=1000, + pretrain=pretrain) return args @@ -123,7 +124,7 @@ def test_fake(): @pytest.mark.integration def test_training_and_testing(): - config = make_config() + config = make_config(pretrain=True) # Необходимо, чтобы эти папки были созданы заранее if not os.path.exists(config.log_dir): -- GitLab