From 55bd81e81e998196f4243147241b3dac3bc43fc4 Mon Sep 17 00:00:00 2001
From: "Andrew (New server)" 
Date: Mon, 3 Oct 2022 13:42:27 +0000
Subject: [PATCH] Initial public commit.

---
 .gitignore                             |   1 +
 README.md                              |  25 +-
 data/organization_structure_dataset.py | 112 ++++
 generate_dataset.py                    | 501 ++++++++++++++++++
 layers.py                              |  57 +++
 main.py                                |  81 +++
 models.py                              |  90 ++++
 notebooks/ModelGenerator.ipynb         | 680 +++++++++++++++++++++++++
 solver.py                              | 401 +++++++++++++++
 utils.py                               |  66 +++
 10 files changed, 2013 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore
 create mode 100644 data/organization_structure_dataset.py
 create mode 100644 generate_dataset.py
 create mode 100644 layers.py
 create mode 100644 main.py
 create mode 100644 models.py
 create mode 100644 notebooks/ModelGenerator.ipynb
 create mode 100644 solver.py
 create mode 100644 utils.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..bee8a64
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+__pycache__
diff --git a/README.md b/README.md
index 8b6647b..1ecc9fd 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,26 @@
 # OrGAN
-Organization GAN (OrGAN) - using semantically enhanced GANs to generate organization structures.
\ No newline at end of file
+GAN-based generation of organization structures.
+
+It is based on the ideas of MolGAN: "An implicit generative model for small molecular graphs" (https://arxiv.org/abs/1805.11973).
+
+This implementation is based on:
+* [yongqyu/MolGAN-pytorch](https://github.com/yongqyu/MolGAN-pytorch)
+* [nicola-decao/MolGAN](https://github.com/nicola-decao/MolGAN)
+
+Please note that the base repository (by yongqyu) has an open issue reporting that its results do not match those in the MolGAN paper (the official implementation is the one by nicola-decao).
+
+
+## Dependencies
+
+* **python>=3.5**
+* **pytorch>=0.4.1**: https://pytorch.org
+* **numpy**
+
+## Structure
+* `data`: should contain the organization structure dataset. You can use the `generate_dataset.py` script to generate one.
+
+## Usage
+```
+python main.py
+```
diff --git a/data/organization_structure_dataset.py b/data/organization_structure_dataset.py
new file mode 100644
index 0000000..9cad10b
--- /dev/null
+++ b/data/organization_structure_dataset.py
@@ -0,0 +1,112 @@
+
+import pickle
+import os
+import numpy as np
+
+from datetime import datetime
+
+class OrganizationStructureDataset:
+
+    def load(self, path, subset=1):
+
+        # Train part
+        self.nodes = np.load(os.path.join(path, 'data_nodes.npy'))
+        self.edges = np.load(os.path.join(path, 'data_edges.npy'))
+
+        with open(os.path.join(path, 'data_meta.pkl'), 'rb') as f:
+            self.__dict__.update(pickle.load(f))
+
+        self.train_idx = np.random.choice(self.train_idx, int(len(self.train_idx) * subset), replace=False)
+        self.validation_idx = np.random.choice(self.validation_idx, int(len(self.validation_idx) * subset),
+                                               replace=False)
+        self.test_idx = np.random.choice(self.test_idx, int(len(self.test_idx) * subset), replace=False)
+
+        self.train_count = len(self.train_idx)
+        self.validation_count = len(self.validation_idx)
+        self.test_count = len(self.test_idx)
+
+        self.__len = self.train_count + self.validation_count + self.test_count
+
+
+    def matrices2graph(self, node_labels, edge_labels, strict=False):
+        """
+        Transforms the matrix definition of a labeled graph into a graph instance.
+
+        Currently, this function just glues its inputs together. In general, it can be
+        used to transform them into some optimized representation.
+
+        Parameters
+        ----------
+        node_labels : (nodes, )
+            Numpy array with node types.
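+            Entry `i` is the type id of node `i`; type 0 ('none') marks an absent node.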
+        edge_labels : (nodes, nodes, )
+            2D numpy array with edge types.
+
+        Returns
+        -------
+        tuple ((nodes,), (nodes, nodes))
+            Tuple representing a graph.
+
+        """
+
+        return node_labels, edge_labels
+
+
+    # NOTE: leftover from MolGAN. It requires rdkit (`from rdkit import Chem`)
+    # and a `smiles_decoder_m` attribute, neither of which is part of this
+    # dataset; it is not used for organization structures.
+    def remove_seq2mol(self, seq, strict=False):
+        mol = Chem.MolFromSmiles(''.join([self.smiles_decoder_m[e] for e in seq if e != 0]))
+
+        if strict:
+            try:
+                Chem.SanitizeMol(mol)
+            except Exception:
+                mol = None
+
+        return mol
+
+    def _next_batch(self, counter, count, idx, batch_size):
+        if batch_size is not None:
+            if counter + batch_size >= count:
+                counter = 0
+                np.random.shuffle(idx)
+
+            output = [obj[idx[counter:counter + batch_size]]
+                      for obj in (self.nodes, self.edges)]
+
+            counter += batch_size
+        else:
+            output = [obj[idx] for obj in (self.nodes, self.edges)]
+
+        return [counter] + output
+
+    def next_train_batch(self, batch_size=None):
+        out = self._next_batch(counter=self.train_counter, count=self.train_count,
+                               idx=self.train_idx, batch_size=batch_size)
+        self.train_counter = out[0]
+
+        return out[1:]
+
+    def next_validation_batch(self, batch_size=None):
+        out = self._next_batch(counter=self.validation_counter, count=self.validation_count,
+                               idx=self.validation_idx, batch_size=batch_size)
+        self.validation_counter = out[0]
+
+        return out[1:]
+
+    def next_test_batch(self, batch_size=None):
+        out = self._next_batch(counter=self.test_counter, count=self.test_count,
+                               idx=self.test_idx, batch_size=batch_size)
+        self.test_counter = out[0]
+
+        return out[1:]
+
+    @staticmethod
+    def log(msg='', date=True):
+        print((str(datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + ' ' + str(msg)) if date else str(msg))
+
+    def __len__(self):
+        return self.__len
+
+
+if __name__ == '__main__':
+
+    pass
diff --git a/generate_dataset.py b/generate_dataset.py
new file mode 100644
index 0000000..c6be886
--- /dev/null
+++ b/generate_dataset.py
@@ -0,0 +1,501 @@
+import os
+import copy
+import random
+import pickle
+
+import numpy as np
+
+node_type_dict = { # status: 0 - optional, 1 - mandatory, 2 - replaceable
+    0: {'title': 'none', 'status': 0, 'weight': 0},
+    1: {'title': 'Управление запасами (складом)', 'status': 2, 'replacement': [2, 3]},  # Inventory (warehouse) management
+    2: {'title': 'Управление материалами', 'status': 0},  # Materials management
+    3: {'title': 'Управление готовыми товарами', 'status': 0},  # Finished goods management
+    4: {'title': 'Планирование', 'status': 2, 'replacement': [5, 6, 7], 'children': [8]},  # Planning
+    5: {'title': 'Планирование закупок', 'status': 0},  # Procurement planning
+    6: {'title': 'Планирование запасов', 'status': 0},  # Inventory planning
+    7: {'title': 'Планирование перевозок', 'status': 0},  # Transportation planning
+    8: {'title': 'Аналитика', 'status': 0},  # Analytics
+    9: {'title': 'Аудит', 'status': 0},  # Audit
+    10: {'title': 'Транспортировка', 'status': 1, 'children': [11]},  # Transportation
+    11: {'title': 'Транспортное хозяйство', 'status': 0},  # Transport facilities
+}
+
+top_level_nodes = [1, 4, 9, 10]
+
+relations_dict = [
+    # 0  1  2  3  4  5  6  7  8  9 10 11
+    [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # 0
+    [ 0, 0, 1, 1, 2, 2, 2, 0, 2, 2, 2, 2], # 1
+    [ 0, 0, 0, 0, 2, 2, 0, 2, 2, 2, 2, 2], # 2
+    [ 0, 0, 0, 0, 2, 0, 2, 2, 2, 2, 2, 0], # 3
+    [ 0, 2, 2, 2, 0, 1, 1, 1, 1, 0, 2, 2], # 4
+    [ 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2], # 5
+    [ 0, 2, 0, 2, 0, 0, 0, 2, 2, 0, 0, 0], # 6
+    [ 0, 0, 2, 2, 0, 2, 2, 0, 2, 0, 2, 2], # 7
+    [ 0, 2, 2, 2, 0, 2, 2, 2, 0, 0, 2, 2], # 8
+    [ 0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 2, 2], # 9
+    [ 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 0, 1], # 10
+    [ 0, 2, 2, 0, 2, 2, 0, 2, 2, 2, 0, 0], # 11
+]
+
+# Number of node types
+NODE_N_TYPES = len(node_type_dict)
+# Number of edge types
+EDGE_N_TYPES = 3
+# Max number of vertices per graph
+MAX_NODES_PER_GRAPH = 12
+
+# Parametrization constants
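+# Assumed interpretation (not stated explicitly in the code, inferred from the
+# feasibility checks below): loads are measured in person-equivalents of work;
+# each employee covers between min_person and max_person units of load; an
+# existing department needs at least min_orgunit employees; and a load above
+# req_orgunit person-equivalents cannot remain without a dedicated department.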
+upper_limit = 25
+min_person = 0.7
+max_person = 1.4
+min_nonzero_value = 0.1
+min_orgunit = 2
+req_orgunit = 5
+
+# Recursive function to generate the element tree
+
+def generate_children(node_list, force_nodes=False):
+    child_nodes = []
+    for key in node_list:
+        if node_type_dict[key]['status'] == 1 or force_nodes: # mandatory
+            child_nodes.append(key)
+            if 'replacement' in node_type_dict[key]:
+                child_nodes.extend(generate_children(node_type_dict[key]['replacement']))
+            if 'children' in node_type_dict[key]:
+                child_nodes.extend(generate_children(node_type_dict[key]['children']))
+        elif node_type_dict[key]['status'] == 0: # optional
+            if np.random.rand(1)[0] < 0.5:
+                child_nodes.append(key)
+                if 'replacement' in node_type_dict[key]:
+                    child_nodes.extend(generate_children(node_type_dict[key]['replacement']))
+                if 'children' in node_type_dict[key]:
+                    child_nodes.extend(generate_children(node_type_dict[key]['children']))
+        elif node_type_dict[key]['status'] == 2: # replaceable: either keep it or force its replacements
+            if np.random.rand(1)[0] < 0.5:
+                child_nodes.append(key)
+                if 'replacement' in node_type_dict[key]:
+                    child_nodes.extend(generate_children(node_type_dict[key]['replacement']))
+                if 'children' in node_type_dict[key]:
+                    child_nodes.extend(generate_children(node_type_dict[key]['children']))
+            else:
+                child_nodes.extend(generate_children(node_type_dict[key]['replacement'], True))
+                if 'children' in node_type_dict[key]:
+                    child_nodes.extend(generate_children(node_type_dict[key]['children']))
+
+    return child_nodes
+
+def check_org_unit_feasibility(nodes, load, unit_id, min_person, max_person, min_orgunit, req_orgunit, logging=False):
+    # A unit is infeasible if it is absent but carries more load than
+    # req_orgunit people could handle, or if it exists but its load is below
+    # the minimum for min_orgunit people.
+    if ((nodes[unit_id]==0 and load[unit_id]>max_person*req_orgunit) or
+            (nodes[unit_id]>0 and load[unit_id]<min_person*min_orgunit)):
+        if logging:
+            print(f"Unit {unit_id} is infeasible: load = {load[unit_id]}")
+        return False
+    return True
+
+def generate_values(nodes, logging=False):
+    v = np.zeros(12)
+
+    if nodes[2] > 0:
+        v[2] = np.random.uniform(min_person*min_orgunit, upper_limit)
+    else:
+        v[2] = np.random.uniform(min_nonzero_value, max_person*req_orgunit)
+
+    if nodes[3] > 0:
+        v[3] = np.random.uniform(min_person*min_orgunit, upper_limit)
+    else:
+        v[3] = np.random.uniform(min_nonzero_value, max_person*req_orgunit)
+
+    if logging:
+        print("input v=", v)
+
+    v2_k = 0.2 if nodes[2] > 0 else 1
+    v3_k = 0.2 if nodes[3] > 0 else 1
+
+    v[1] = v2_k * v[2] + v3_k * v[3]
+    if not check_org_unit_feasibility(nodes, v, 1, min_person, max_person, min_orgunit, req_orgunit, logging=logging):
+        return np.empty(0)
+
+    v[10] = v[2] * 1 + v[3] * 1
+
+    v[11] = v[10] * 0.2
+    if not check_org_unit_feasibility(nodes, v, 11, min_person, max_person, min_orgunit, req_orgunit, logging=logging):
+        return np.empty(0)
+
+    v12_k = v[2] * 0.4 if nodes[2] > 0 else v[1] * 0.2
+    v11_k = v[11] * 0.2 if nodes[11] > 0 else 0
+    v[5] = v12_k + v11_k
+    if not check_org_unit_feasibility(nodes, v, 5, min_person, max_person, min_orgunit, req_orgunit, logging=logging):
+        return np.empty(0)
+
+    v[6] = v[3] * 0.4 if nodes[3] > 0 else v[1] * 0.2
+    if not check_org_unit_feasibility(nodes, v, 6, min_person, max_person, min_orgunit, req_orgunit, logging=logging):
+        return np.empty(0)
+
+    v[7] = v[10] * 0.2
+    if not check_org_unit_feasibility(nodes, v, 7, min_person, max_person, min_orgunit, req_orgunit, logging=logging):
+        return np.empty(0)
+
+    v[8] = v[5] * 0.2 + v[6] * 0.2 + v[7] * 0.2
+    if not check_org_unit_feasibility(nodes, v, 8, min_person, max_person, min_orgunit, req_orgunit, logging=logging):
+        return np.empty(0)
+
+    v5_k = 0.2 if nodes[5] > 0 else 1
+    v6_k = 0.2 if nodes[6] > 0 else 1
+    v7_k = 0.2 if nodes[7] > 0 else 1
+    v8_k = 0.2 if nodes[8] > 0 else 1
+    v[4] = v5_k * v[5] + v6_k * v[6] + v7_k * v[7] + v8_k * v[8]
+    if not 
check_org_unit_feasibility(nodes, v, 4, min_person, max_person, min_orgunit, req_orgunit, logging=logging): + return np.empty(0) + + v[9] = v[1] * 0.05 + v[2] * 0.05 + v[3] * 0.05 + v[10] * 0.05 + v[11] * 0.05 + if not check_org_unit_feasibility(nodes, v, 9, min_person, max_person, min_orgunit, req_orgunit, logging=logging): + return np.empty(0) + + return v + +def convert_values2persons(nodes, load, min_person, max_person): + staff = np.zeros(12) + for i in range(len(nodes)): + #print(i) + if nodes[i] > 0: + #print(load[i], load[i]/max_person, load[i]/min_person) + staff[i] = np.random.randint(np.ceil(load[i]/max_person), np.ceil(load[i]/min_person), 1)[0] + return staff + +# Shuffle nodes and relations + +def shuffle_nodes(nodes, relations): + keys = list(range(12)) + random.shuffle(keys) + nodes = nodes[keys] + relations = relations[keys][:, keys] + #print(keys) + #print(nodes) + #print(relations) + return nodes, relations + +def generate_model(shuffle=True): + #generate nodes + tmp_nodes = generate_children(top_level_nodes) + + #fill list with generated nodes and fill all relations + nodes = list(node_type_dict.keys()) + relations = np.array(copy.deepcopy(relations_dict)) + + for node_key in node_type_dict: + if node_key not in tmp_nodes: + nodes[node_key] = 0 + relations[node_key, :] = 0 + relations[:, node_key] = 0 + + nodes = np.array(nodes) + + if shuffle: + nodes, relations = shuffle_nodes(nodes, relations) + return nodes, relations + +def generate_parametrized_model(logging=False): + while True: + nodes, relations = generate_model(False) + if logging: + print("\nnodes=", nodes) + load = generate_values(nodes, logging) + if len(load) > 0: + if logging: + print(load) + staff = convert_values2persons(nodes, load, min_person, max_person) + break + return nodes, relations, staff + +def unshuffle(nodes, relations): + keys = [-1] * len(nodes) + zeros = [i for i, node in enumerate(nodes) if node == 0] + non_zeros = [(i, node) for i, node in enumerate(nodes) if node != 0] + #print(zeros) + #print(non_zeros) + + for i, node in non_zeros: + keys[node] = i + #print(keys) + + for i in range(len(keys)): + if keys[i] == -1: + keys[i] = zeros.pop(0) + #print(keys) + + return nodes[keys], relations[keys][:, keys] + +def check_children(top_level_nodes, node_list, force_nodes = False): + #print('top_level_nodes=', top_level_nodes) + #print('force_nodes', force_nodes) + for key in top_level_nodes: + #print('Key = ', key) + if node_type_dict[key]['status'] == 1 or force_nodes: # mandatory + #print('Mandatory') + if node_list[key] == 0: + return False, f'Mandatory element {key} is missing.' + if 'replacement' in node_type_dict[key]: + result, explanation = check_children(node_type_dict[key]['replacement'], node_list) + if not result: + return False, explanation + if 'children' in node_type_dict[key]: + result, explanation = check_children(node_type_dict[key]['children'], node_list) + if not result: + return False, explanation + if node_type_dict[key]['status'] == 2 or force_nodes: # replacable + #print('Replacable') + if node_list[key] == 0: + if 'replacement' in node_type_dict[key]: + result, explanation = check_children(node_type_dict[key]['replacement'], node_list, True) + if not result: + return False, f'Replacable element {key} is missing. {explanation}' + else: + return False, f'Replacable element {key} is missing.' 
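+            # The children of a replaceable element are checked regardless of
+            # whether the element itself is present or replaced.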
+            if 'children' in node_type_dict[key]:
+                result, explanation = check_children(node_type_dict[key]['children'], node_list)
+                if not result:
+                    return False, explanation
+        elif node_type_dict[key]['status'] == 0:
+            #print('Optional')
+            if 'replacement' in node_type_dict[key]:
+                result, explanation = check_children(node_type_dict[key]['replacement'], node_list)
+                if not result:
+                    return False, explanation
+            if 'children' in node_type_dict[key]:
+                result, explanation = check_children(node_type_dict[key]['children'], node_list)
+                if not result:
+                    return False, explanation
+    return True, ''
+
+def check_relations(nodes, relations):
+    """Checks the validity of relations.
+
+    Parameters
+    ----------
+    nodes : list
+        The list of node types.
+    relations : numpy.ndarray (n, n)
+        Relation type matrix.
+
+    Returns
+    -------
+    result : bool
+        `True` if the set of edges is valid and consistent with the nodes.
+    diff
+        Boolean matrix of edge validity (`True` for valid edges).
+    """
+    target_relations = np.array(copy.deepcopy(relations_dict))
+    for node_key in node_type_dict:
+        if node_key not in nodes:
+            target_relations[node_key, :] = 0
+            target_relations[:, node_key] = 0
+
+    relations_diff = np.array([relations == target_relations])
+    result = relations_diff.all()
+    return result, relations_diff
+
+def check_nodes(nodes):
+    """Checks the validity of node types.
+
+    Parameters
+    ----------
+    nodes : list
+        The list of node types.
+
+    Returns
+    -------
+    bool
+        `True` if the structure contains a valid set of nodes.
+    """
+    return check_children(top_level_nodes, nodes)[0]
+
+def overlap(first, last, another_first, another_last) -> bool:
+    # True if the intervals [first, last] and [another_first, another_last] intersect
+    #print(first, last, another_first, another_last)
+    return min(last, another_last) - max(first, another_first) >= 0
+
+def check_paramater_feasibility(nodes, staff, logging=False):
+    for unit_id in node_type_dict:
+        # An optional department has too few employees
+        if (node_type_dict[unit_id]['status'] != 1 and
+                nodes[unit_id] > 0 and staff[unit_id] < min_orgunit):
+            #error
+            if logging:
+                print(f"Node {unit_id} {node_type_dict[unit_id]['title']} doesn't meet the requirements: staff = {staff[unit_id]}")
+            return False
+        # A non-existent department has employees
+        if nodes[unit_id] == 0 and staff[unit_id] > 0:
+            #error
+            if logging:
+                print(f"Non-existing Node {unit_id} {node_type_dict[unit_id]['title']} has staff: staff = {staff[unit_id]}")
+            return False
+
+    # Consistency of department workload ranges
+    load_min = staff * min_person
+    load_max = staff * max_person + min_person
+    max_no_unit = req_orgunit * max_person
+    for unit_id in node_type_dict:
+        if (nodes[unit_id] == 0):
+            load_max[unit_id] = max_no_unit
+    #print(load_min)
+    #print(load_max)
+
+    v_min = np.zeros(12)
+    v_max = np.zeros(12)
+
+    unit_id = 1
+    v_min[unit_id] = (0.2 if nodes[2] > 0 else 1) * load_min[2] + (0.2 if nodes[3] > 0 else 1) * load_min[3]
+    v_max[unit_id] = (0.2 if nodes[2] > 0 else 1) * load_max[2] + (0.2 if nodes[3] > 0 else 1) * load_max[3]
+    if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]):
+        #error
+        if logging:
+            print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.")
+            print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id])
+        return False
+
+    unit_id = 10
+    v_min[10] = load_min[2] * 1 + load_min[3] * 1
+    v_max[10] = load_max[2] * 1 + load_max[3] * 1
+    if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]):
+        #error
+        if 
logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + unit_id = 11 + v_min[11] = load_min[10] * 0.2 + v_max[11] = load_max[10] * 0.2 + if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]): + #error + if logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + unit_id = 5 + v_min[5] = (load_min[2] * 0.4 if nodes[2] > 0 else load_min[1] * 0.2) + (load_min[11] * 0.2 if nodes[11] > 0 else 0) + v_max[5] = (load_max[2] * 0.4 if nodes[2] > 0 else load_max[1] * 0.2) + (load_max[11] * 0.2 if nodes[11] > 0 else 0) + if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]): + #error + if logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + unit_id = 6 + v_min[6] = load_min[3] * 0.4 if nodes[3] > 0 else load_min[1] * 0.2 + v_max[6] = load_max[3] * 0.4 if nodes[3] > 0 else load_max[1] * 0.2 + if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]): + #error + if logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + unit_id = 7 + v_min[7] = load_min[10] * 0.2 + v_max[7] = load_max[10] * 0.2 + if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]): + #error + if logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + unit_id = 8 + v_min[8] = load_min[5] * 0.2 + load_min[6] * 0.2 + load_min[7] * 0.2 + v_max[8] = load_max[5] * 0.2 + load_max[6] * 0.2 + load_max[7] * 0.2 + if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]): + #error + if logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + unit_id = 4 + v_min[4] = (0.2 if nodes[5] > 0 else 1) * load_min[5] + \ + (0.2 if nodes[6] > 0 else 1) * load_min[6] + \ + (0.2 if nodes[7] > 0 else 1) * load_min[7] + \ + (0.2 if nodes[8] > 0 else 1) * load_min[8] + v_max[4] = (0.2 if nodes[5] > 0 else 1) * load_max[5] + \ + (0.2 if nodes[6] > 0 else 1) * load_max[6] + \ + (0.2 if nodes[7] > 0 else 1) * load_max[7] + \ + (0.2 if nodes[8] > 0 else 1) * load_max[8] + if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]): + #error + if logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + unit_id = 9 + v_min[9] = load_min[1] * 0.05 + load_min[2] * 0.05 + load_min[3] * 0.05 + load_min[10] * 0.05 + load_min[11] * 0.05 + v_max[9] = load_max[1] * 0.05 + load_max[2] * 0.05 + load_max[3] * 0.05 + load_max[10] 
* 0.05 + load_max[11] * 0.05 + if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]): + #error + if logging: + print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.") + print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]) + return False + + return True + +def generate_dataset(dataset_size): + nodes_list = [] + edges_list = [] + staff_list = [] + for i in range(dataset_size): + #nodes, relations = generate_model() + nodes, relations, staff = generate_parametrized_model() + nodes_list.append(nodes) + edges_list.append(relations) + staff_list.append(staff) + return np.stack(nodes_list, axis=0), \ + np.stack(edges_list, axis=0), \ + np.stack(staff_list, axis=0) + +def create_dataset(dataset_size, dataset_dir, validation=0.1, test=0.1): + nodes, edges, staff = generate_dataset(dataset_size) + + validation = int(validation * dataset_size) + test = int(test * dataset_size) + train = dataset_size - validation - test + + all_idx = np.random.permutation(dataset_size) + train_idx = all_idx[0:train] + validation_idx = all_idx[train:train + validation] + test_idx = all_idx[train + validation:] + + np.save(os.path.join(dataset_dir, 'data_nodes.npy'), nodes) + np.save(os.path.join(dataset_dir, 'data_edges.npy'), edges) + np.save(os.path.join(dataset_dir, 'data_staff.npy'), staff) + + with open(os.path.join(dataset_dir, 'data_meta.pkl'), 'wb') as f: + pickle.dump({'train_idx': train_idx, + 'train_count': train, + 'train_counter': 0, + 'validation_idx': validation_idx, + 'validation_count': validation, + 'validation_counter': 0, + 'test_idx': test_idx, + 'test_count': test, + 'test_counter': 0, + + 'node_num_types': NODE_N_TYPES, + 'edge_num_types': EDGE_N_TYPES, + 'vertexes': MAX_NODES_PER_GRAPH, + }, f) + + +if __name__ == '__main__': + + # Make it reproducible + random.seed(1) + np.random.seed(1) + + create_dataset(10000, 'data', 0.1, 0.1) diff --git a/layers.py b/layers.py new file mode 100644 index 0000000..d8b1835 --- /dev/null +++ b/layers.py @@ -0,0 +1,57 @@ +import math +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter +from torch.nn.modules.module import Module + + +class GraphConvolution(Module): + + def __init__(self, in_features, out_feature_list, b_dim, dropout): + super(GraphConvolution, self).__init__() + self.in_features = in_features + self.out_feature_list = out_feature_list + + self.linear1 = nn.Linear(in_features, out_feature_list[0]) + self.linear2 = nn.Linear(out_feature_list[0], out_feature_list[1]) + + self.dropout = nn.Dropout(dropout) + + def forward(self, input, adj, activation=None): + # input : batch x n_nodes x n_node_types + # adj : batch x n_edge_types x n_nodes x n_nodes + + hidden = torch.stack([self.linear1(input) for _ in range(adj.size(1))], 1) + hidden = torch.einsum('bijk,bikl->bijl', (adj, hidden)) + hidden = torch.sum(hidden, 1) + self.linear1(input) + hidden = activation(hidden) if activation is not None else hidden + hidden = self.dropout(hidden) + + output = torch.stack([self.linear2(hidden) for _ in range(adj.size(1))], 1) + output = torch.einsum('bijk,bikl->bijl', (adj, output)) + output = torch.sum(output, 1) + self.linear2(hidden) + output = activation(output) if activation is not None else output + output = self.dropout(output) + + return output + + +class GraphAggregation(Module): + + def __init__(self, in_features, out_features, m_dim, dropout): + super(GraphAggregation, self).__init__() + 
self.sigmoid_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features),
+                                            nn.Sigmoid())
+        self.tanh_linear = nn.Sequential(nn.Linear(in_features+m_dim, out_features),
+                                         nn.Tanh())
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, input, activation):
+        # Gated aggregation over nodes (as in MolGAN): a sigmoid gate i weights
+        # tanh-transformed features j, and the gated features are summed across
+        # the node dimension.
+        i = self.sigmoid_linear(input)
+        j = self.tanh_linear(input)
+        output = torch.sum(torch.mul(i,j), 1)
+        output = activation(output) if activation is not None\
+                 else output
+        output = self.dropout(output)
+
+        return output
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..ade55f7
--- /dev/null
+++ b/main.py
@@ -0,0 +1,81 @@
+import os
+import argparse
+from solver import Solver
+from torch.backends import cudnn
+
+def str2bool(v):
+    return v.lower() == 'true'
+
+def main(config):
+    # For fast training.
+    cudnn.benchmark = True
+
+    # Create directories if not exist.
+    if not os.path.exists(config.log_dir):
+        os.makedirs(config.log_dir)
+    if not os.path.exists(config.model_save_dir):
+        os.makedirs(config.model_save_dir)
+    if not os.path.exists(config.sample_dir):
+        os.makedirs(config.sample_dir)
+    if not os.path.exists(config.result_dir):
+        os.makedirs(config.result_dir)
+
+    # Solver for training and testing OrGAN.
+    solver = Solver(config)
+
+    if config.mode == 'train':
+        solver.train()
+    elif config.mode == 'test':
+        solver.test()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    # Model configuration.
+    parser.add_argument('--z_dim', type=int, default=8, help='dimension of the latent vector z')
+    parser.add_argument('--g_conv_dim', default=[128,256,512], help='hidden layer dimensions of G')
+    parser.add_argument('--d_conv_dim', type=int, default=[[128, 64], 128, [128, 64]], help='dimensions of D: [graph conv dims, aux dim, linear dims]')
+    parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G (currently unused)')
+    parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D (currently unused)')
+    parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss')
+    parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss')
+    parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty')
+    parser.add_argument('--post_method', type=str, default='softmax', choices=['softmax', 'soft_gumbel', 'hard_gumbel'])
+
+    # Training configuration.
+    parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size')
+    parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D')
+    parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr')
+    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G')
+    parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D')
+    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
+    parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update')
+    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer')
+    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
+    parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step')
+
+    # Test configuration.
+    parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step')
+
+    # Miscellaneous.
+    parser.add_argument('--num_workers', type=int, default=1)
+    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])
+    parser.add_argument('--use_tensorboard', type=str2bool, default=False)
+
+    # Directories.
+    parser.add_argument('--data_dir', type=str, default='data')
+    parser.add_argument('--log_dir', type=str, default='organ/logs')
+    parser.add_argument('--model_save_dir', type=str, default='organ/models')
+    parser.add_argument('--sample_dir', type=str, default='organ/samples')
+    parser.add_argument('--result_dir', type=str, default='organ/results')
+
+    # Step size.
+    parser.add_argument('--log_step', type=int, default=10)
+    parser.add_argument('--sample_step', type=int, default=1000)
+    parser.add_argument('--model_save_step', type=int, default=10000)
+    parser.add_argument('--lr_update_step', type=int, default=1000)
+
+    config = parser.parse_args()
+    print(config)
+    main(config)
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..498094d
--- /dev/null
+++ b/models.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from layers import GraphConvolution, GraphAggregation
+
+
+class ResidualBlock(nn.Module):
+    """Residual Block with instance normalization."""
+    def __init__(self, dim_in, dim_out):
+        super(ResidualBlock, self).__init__()
+        self.main = nn.Sequential(
+            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
+            nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True))
+
+    def forward(self, x):
+        return x + self.main(x)
+
+
+class Generator(nn.Module):
+    """Generator network."""
+    def __init__(self, conv_dims, z_dim, vertexes, edges, nodes, dropout):
+        super(Generator, self).__init__()
+
+        self.vertexes = vertexes
+        self.edges = edges
+        self.nodes = nodes
+
+        layers = []
+        for c0, c1 in zip([z_dim]+conv_dims[:-1], conv_dims):
+            layers.append(nn.Linear(c0, c1))
+            layers.append(nn.Tanh())
+            layers.append(nn.Dropout(p=dropout, inplace=True))
+        self.layers = nn.Sequential(*layers)
+
+        self.edges_layer = nn.Linear(conv_dims[-1], edges * vertexes * vertexes)
+        self.nodes_layer = nn.Linear(conv_dims[-1], vertexes * nodes)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x):
+        output = self.layers(x)
+        edges_logits = self.edges_layer(output)\
+                       .view(-1,self.edges,self.vertexes,self.vertexes)
+        # Symmetrize the edge logits so the generated adjacency tensor is symmetric.
+        edges_logits = (edges_logits + edges_logits.permute(0,1,3,2))/2
+        edges_logits = self.dropout(edges_logits.permute(0,2,3,1))
+
+        nodes_logits = self.nodes_layer(output)
+        nodes_logits = self.dropout(nodes_logits.view(-1,self.vertexes,self.nodes))
+
+        return edges_logits, nodes_logits
+
+
+class Discriminator(nn.Module):
+    """Discriminator network based on graph convolution and aggregation."""
+    def __init__(self, conv_dim, m_dim, b_dim, dropout):
+        super(Discriminator, self).__init__()
+
+        graph_conv_dim, aux_dim, linear_dim = conv_dim
+        # discriminator
+        self.gcn_layer = GraphConvolution(m_dim, graph_conv_dim, b_dim, dropout)
+        self.agg_layer = GraphAggregation(graph_conv_dim[-1], aux_dim, m_dim, dropout)
+
+        # multi dense layer
+        layers = []
+        for c0, c1 in zip([aux_dim]+linear_dim[:-1], linear_dim):
+            layers.append(nn.Linear(c0,c1))
+            layers.append(nn.Dropout(dropout))
+        self.linear_layer = nn.Sequential(*layers)
+
+        self.output_layer = nn.Linear(linear_dim[-1], 1)
+
+    def forward(self, adj, hidden, 
node, activation=None):
+        adj = adj[:,:,:,1:].permute(0,3,1,2)
+        annotations = torch.cat((hidden, node), -1) if hidden is not None else node
+        h = self.gcn_layer(annotations, adj)
+        annotations = torch.cat((h, hidden, node) if hidden is not None\
+                                else (h, node), -1)
+        h = self.agg_layer(annotations, torch.tanh)
+        h = self.linear_layer(h)
+
+        # TODO: implement the batch discriminator #
+        ###########################################
+
+        output = self.output_layer(h)
+        output = activation(output) if activation is not None else output
+
+        return output, h
diff --git a/notebooks/ModelGenerator.ipynb b/notebooks/ModelGenerator.ipynb
new file mode 100644
index 0000000..035f07a
--- /dev/null
+++ b/notebooks/ModelGenerator.ipynb
@@ -0,0 +1,680 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "
## Imports
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import copy\n", + "import random" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
## Initial data
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "node_type_dict = { # status: 0 - optional, 1 - mandatory, 2 - replaceble\n", + " 0: {'title': 'none', 'status': 0, 'weight': 0}, \n", + " 1: {'title': 'Управление запасами (складом)', 'status': 2, 'replacement': [2, 3]}, \n", + " 2: {'title': 'Управление материалами', 'status': 0},\n", + " 3: {'title': 'Управление готовыми товарами', 'status': 0},\n", + " 4: {'title': 'Планирование', 'status': 2, 'replacement': [5, 6, 7], 'children': [8]},\n", + " 5: {'title': 'Планирование закупок', 'status': 0},\n", + " 6: {'title': 'Планирование запасов', 'status': 0},\n", + " 7: {'title': 'Планирование перевозок', 'status': 0},\n", + " 8: {'title': 'Аналитика', 'status': 0},\n", + " 9: {'title': 'Аудит', 'status': 0},\n", + " 10: {'title': 'Транспортировка', 'status': 1},\n", + " 11: {'title': 'Транспортное хозяйство', 'status': 0},\n", + "}\n", + "\n", + "top_level_nodes = [1, 4, 9, 10]\n", + "\n", + "relations_dict = [\n", + " # 0 1 2 3 4 5 6 7 8 9 10 11\n", + " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # 0\n", + " [ 0, 0, 1, 1, 2, 2, 2, 0, 2, 2, 2, 2], # 1\n", + " [ 0, 0, 0, 0, 2, 2, 0, 2, 2, 2, 2, 2], # 2\n", + " [ 0, 0, 0, 0, 2, 0, 2, 2, 2, 2, 2, 0], # 3\n", + " [ 0, 2, 2, 2, 0, 1, 1, 1, 1, 0, 2, 2], # 4\n", + " [ 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 2], # 5\n", + " [ 0, 2, 0, 2, 0, 0, 0, 2, 2, 0, 0, 0], # 6\n", + " [ 0, 0, 2, 2, 0, 2, 2, 0, 2, 0, 2, 2], # 7\n", + " [ 0, 2, 2, 2, 0, 2, 2, 2, 0, 0, 2, 2], # 8\n", + " [ 0, 2, 2, 2, 0, 0, 0, 0, 0, 0, 2, 2], # 9\n", + " [ 0, 2, 2, 2, 2, 0, 0, 2, 2, 2, 0, 1], # 10\n", + " [ 0, 2, 2, 0, 2, 2, 0, 2, 2, 2, 0, 0], # 11\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
## Auxiliary functions
" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Recursive function to generate element tree\n", + "\n", + "def generate_children(node_list, force_nodes = False):\n", + " child_nodes = []\n", + " for key in node_list:\n", + " if node_type_dict[key]['status'] == 1 or force_nodes: # mandatory\n", + " child_nodes.append(key)\n", + " if 'replacement' in node_type_dict[key]:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['replacement']))\n", + " if 'children' in node_type_dict[key]:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['children']))\n", + " elif node_type_dict[key]['status'] == 0:\n", + " if np.random.rand(1)[0] < 0.5:\n", + " child_nodes.append(key) \n", + " if 'replacement' in node_type_dict[key]:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['replacement']))\n", + " if 'children' in node_type_dict[key]:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['children']))\n", + " elif node_type_dict[key]['status'] == 2:\n", + " if np.random.rand(1)[0] < 0.5:\n", + " child_nodes.append(key) \n", + " if 'replacement' in node_type_dict[key]:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['replacement']))\n", + " if 'children' in node_type_dict[key]:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['children']))\n", + " else:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['replacement'], True))\n", + " if 'children' in node_type_dict[key]:\n", + " child_nodes.extend(generate_children(node_type_dict[key]['children']))\n", + "\n", + " #generate children\n", + " return child_nodes\n", + "\n", + "# Shuffle nodes and relations\n", + "\n", + "def shuffle_nodes(nodes, relations):\n", + " keys = list(range(12))\n", + " random.shuffle(keys)\n", + " nodes = nodes[keys]\n", + " relations = relations[keys][:, keys]\n", + " #print(keys)\n", + " #print(nodes)\n", + " #print(relations) \n", + " return nodes, relations\n", + "\n", + "def generate_model(shuffle = True):\n", + " #generate nodes\n", + " tmp_nodes = generate_children(top_level_nodes)\n", + " \n", + " #fill list with generated nodes and fill all relations\n", + " nodes = list(node_type_dict.keys())\n", + " relations = np.array(copy.deepcopy(relations_dict))\n", + "\n", + " for node_key in node_type_dict:\n", + " if node_key not in tmp_nodes:\n", + " nodes[node_key] = 0\n", + " relations[node_key, :] = 0\n", + " relations[:, node_key] = 0\n", + " \n", + " nodes = np.array(nodes)\n", + " if shuffle:\n", + " nodes, relations = shuffle_nodes(nodes, relations)\n", + " return nodes, relations\n", + "\n", + "def unshuffle(nodes, relations):\n", + " keys = [-1] * len(nodes)\n", + " zeros = [i for i, node in enumerate(nodes) if node == 0]\n", + " non_zeros = [(i, node) for i, node in enumerate(nodes) if node != 0]\n", + " #print(zeros)\n", + " #print(non_zeros)\n", + "\n", + " for i, node in non_zeros:\n", + " keys[node] = i\n", + " #print(keys)\n", + "\n", + " for i in range(len(keys)):\n", + " if keys[i] == -1:\n", + " keys[i] = zeros.pop(0)\n", + " #print(keys)\n", + "\n", + " return nodes[keys], relations[keys][:, keys]\n", + "\n", + "def check_children(top_level_nodes, node_list, force_nodes = False):\n", + " #print('top_level_nodes=', top_level_nodes)\n", + " #print('force_nodes', force_nodes)\n", + " for key in top_level_nodes:\n", + " #print('Key = ', key)\n", + " if node_type_dict[key]['status'] == 1 or force_nodes: # mandatory\n", + " #print('Mandatory')\n", 
+ " if node_list[key] == 0:\n", + " return False, f'Mandatory element {key} is missing.'\n", + " if 'replacement' in node_type_dict[key]:\n", + " result, explanation = check_children(node_type_dict[key]['replacement'], node_list)\n", + " if not result:\n", + " return False, explanation\n", + " if 'children' in node_type_dict[key]:\n", + " result, explanation = check_children(node_type_dict[key]['children'], node_list)\n", + " if not result:\n", + " return False, explanation\n", + " if node_type_dict[key]['status'] == 2 or force_nodes: # replacable\n", + " #print('Replacable')\n", + " if node_list[key] == 0:\n", + " if 'replacement' in node_type_dict[key]:\n", + " result, explanation = check_children(node_type_dict[key]['replacement'], node_list, True)\n", + " if not result:\n", + " return False, f'Replacable element {key} is missing. {explanation}'\n", + " else:\n", + " return False, f'Replacable element {key} is missing.'\n", + " if 'children' in node_type_dict[key]:\n", + " result, explanation = check_children(node_type_dict[key]['children'], node_list)\n", + " if not result:\n", + " return False, explanation\n", + " elif node_type_dict[key]['status'] == 0:\n", + " #print('Optional')\n", + " if 'replacement' in node_type_dict[key]:\n", + " result, explanation = check_children(node_type_dict[key]['replacement'], node_list)\n", + " if not result:\n", + " return False, explanation\n", + " if 'children' in node_type_dict[key]:\n", + " result, explanation = check_children(node_type_dict[key]['children'], node_list)\n", + " if not result:\n", + " return False, explanation\n", + " return True, ''\n", + "\n", + "def check_relations(nodes, relations):\n", + " target_relations = np.array(copy.deepcopy(relations_dict))\n", + " for node_key in node_type_dict:\n", + " if node_key not in nodes:\n", + " target_relations[node_key, :] = 0\n", + " target_relations[:, node_key] = 0\n", + "\n", + " relations_diff = np.array([relations == target_relations])\n", + " result = relations_diff.all()\n", + " return result, relations_diff" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
## Example
" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Model 1\n", + "[ 0 0 10 0 0 6 2 3 5 0 7 0]\n", + "[[0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 2 2 0 0 2 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 2 0 0 2 0]\n", + " [0 0 2 0 0 0 0 0 2 0 2 0]\n", + " [0 0 2 0 0 2 0 0 0 0 2 0]\n", + " [0 0 0 0 0 0 2 0 0 0 2 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 2 0 0 2 2 2 2 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]]\n", + "Replacable element 1 is missing. Mandatory element 2 is missing.\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True False True False True True False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 2\n", + "[ 8 9 0 6 3 2 0 0 0 4 0 10]\n", + "[[0 0 0 2 2 2 0 0 0 0 0 2]\n", + " [0 0 0 0 2 2 0 0 0 0 0 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 0 2 0 0 0 0 0 0 0]\n", + " [2 2 0 2 0 0 0 0 0 2 0 2]\n", + " [2 2 0 0 0 0 0 0 0 2 0 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [1 0 0 1 2 2 0 0 0 0 0 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 2 0 0 2 2 0 0 0 2 0 0]]\n", + "Replacable element 1 is missing. 
Mandatory element 2 is missing.\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True False True True True False False False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 3\n", + "[ 0 0 3 10 5 0 8 7 0 2 0 6]\n", + "[[0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 2 0 0 2 2 0 0 0 2]\n", + " [0 0 2 0 0 0 2 2 0 2 0 0]\n", + " [0 0 0 0 0 0 2 2 0 2 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 2 2 2 0 0 2 0 2 0 2]\n", + " [0 0 2 2 2 0 2 0 0 2 0 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 2 2 0 2 2 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 2 0 0 0 2 2 0 0 0 0]]\n", + "Replacable element 1 is missing. Mandatory element 2 is missing.\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True False True False False True False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 4\n", + "[ 5 0 6 10 4 3 0 0 1 2 0 0]\n", + "[[0 0 0 0 0 0 0 0 2 2 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 2 0 0 2 0 0 0]\n", + " [0 0 0 0 2 2 0 0 2 2 0 0]\n", + " [1 0 1 2 0 2 0 0 2 2 0 0]\n", + " [0 0 2 2 2 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 2 2 2 1 0 0 0 1 0 0]\n", + " [2 0 0 2 2 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]]\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True False False True True True True False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True 
True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 5\n", + "[ 2 10 0 3 0 0 0 4 5 0 0 0]\n", + "[[0 2 0 0 0 0 0 2 2 0 0 0]\n", + " [2 0 0 2 0 0 0 2 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 2 0 0 0 0 0 2 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 2 0 2 0 0 0 0 1 0 0 0]\n", + " [2 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]]\n", + "Replacable element 1 is missing. Mandatory element 2 is missing.\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True False False True True True True False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 6\n", + "[ 0 0 10 1 7 2 4 0 0 0 6 0]\n", + "[[0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 2 2 2 2 0 0 0 0 0]\n", + " [0 0 2 0 0 1 2 0 0 0 2 0]\n", + " [0 0 2 0 0 2 0 0 0 0 2 0]\n", + " [0 0 2 0 2 0 2 0 0 0 0 0]\n", + " [0 0 2 2 1 2 0 0 0 0 1 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 2 2 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]]\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True False True True False True True False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 7\n", + "[ 3 5 0 10 0 0 7 0 8 0 6 1]\n", + "[[0 0 0 2 0 0 2 0 2 0 2 0]\n", + " [0 0 0 0 0 0 2 0 2 0 0 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 0 0 0 2 0 2 0 0 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + 
" [2 2 0 2 0 0 0 0 2 0 2 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 2 0 2 0 0 2 0 0 0 2 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 0 0 0 2 0 2 0 0 2]\n", + " [1 2 0 2 0 0 0 0 2 0 2 0]]\n", + "\n", + " Model 8\n", + "[ 2 6 0 3 9 0 0 0 10 4 0 0]\n", + "[[0 0 0 0 2 0 0 0 2 2 0 0]\n", + " [0 0 0 2 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 2 0 0 2 0 0 0 2 2 0 0]\n", + " [2 0 0 2 0 0 0 0 2 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 2 2 0 0 0 0 2 0 0]\n", + " [2 1 0 2 0 0 0 0 2 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]]\n", + "Replacable element 1 is missing. Mandatory element 2 is missing.\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True False True True True True False False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 9\n", + "[ 2 0 0 0 4 5 9 7 0 0 10 3]\n", + "[[0 0 0 0 2 2 2 2 0 0 2 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 0 0 1 0 1 0 0 2 2]\n", + " [2 0 0 0 0 0 0 2 0 0 0 0]\n", + " [2 0 0 0 0 0 0 0 0 0 2 2]\n", + " [2 0 0 0 0 2 0 0 0 0 2 2]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 0 2 0 2 2 0 0 0 2]\n", + " [0 0 0 0 2 0 2 2 0 0 2 0]]\n", + "Replacable element 1 is missing. 
Mandatory element 2 is missing.\n", + "[[[ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True True True False False True False True False False\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True False True True True True True True True True\n", + " True]\n", + " [ True True True True True True True True True True True\n", + " True]]]\n", + "\n", + " Model 10\n", + "[ 1 5 0 0 10 0 0 7 9 0 6 0]\n", + "[[0 2 0 0 2 0 0 0 2 0 2 0]\n", + " [2 0 0 0 0 0 0 2 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 0 0 0 0 2 2 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [0 2 0 0 2 0 0 0 0 0 2 0]\n", + " [2 0 0 0 2 0 0 0 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]\n", + " [2 0 0 0 0 0 0 2 0 0 0 0]\n", + " [0 0 0 0 0 0 0 0 0 0 0 0]]\n" + ] + } + ], + "source": [ + "model_quantity = 10# Quantity of models to generate\n", + "for i in range(model_quantity):\n", + " \n", + " # Generate model\n", + " nodes, relations = generate_model() # optional parameter \"shuffle=True\"\n", + " print(\"\\n Model \", i+1)\n", + " print(nodes)\n", + " print(relations)\n", + " \n", + " # Check model\n", + " nodes, relations = unshuffle(nodes, relations)\n", + " nodes[2] = 0 # To demonstrate a error sometimes. 
Another possibility: nodes[10] = 0 - will always produce errors since 10 is a mandatory node\n", + " result, explanation = check_children(top_level_nodes, nodes)\n", + " if not result:\n", + " print(explanation)\n", + " result, explanation = check_relations(nodes, relations)\n", + " if not result:\n", + " print(explanation) " + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "\n", + "def generate_dataset(dataset_size):\n", + " nodes_list = []\n", + " edges_list = []\n", + " for i in range(dataset_size):\n", + " nodes, relations = generate_model() # optional parameter \"shuffle=True\"\n", + " nodes_list.append(nodes)\n", + " edges_list.append(relations)\n", + " return np.stack(nodes_list, axis=0), \\\n", + " np.stack(edges_list, axis=0)\n", + "\n", + "def create_dataset(dataset_size, validation=0.1, test=0.1):\n", + " nodes, edges = generate_dataset(dataset_size)\n", + " \n", + " validation = int(validation * dataset_size)\n", + " test = int(test * dataset_size)\n", + " train = dataset_size - validation - test\n", + "\n", + " all_idx = np.random.permutation(dataset_size)\n", + " train_idx = all_idx[0:train]\n", + " validation_idx = all_idx[train:train + validation]\n", + " test_idx = all_idx[train + validation:]\n", + "\n", + " np.save('data_nodes.npy', nodes)\n", + " np.save('data_edges.npy', edges)\n", + "\n", + " with open('data_meta.pkl', 'wb') as f:\n", + " pickle.dump({'train_idx': train_idx,\n", + " 'train_count': train,\n", + " 'train_counter': 0,\n", + " 'validation_idx': validation_idx,\n", + " 'validation_count': validation,\n", + " 'validation_counter': 0,\n", + " 'test_idx': test_idx,\n", + " 'test_count': test,\n", + " 'test_counter': 0\n", + " }, f)\n", + " \n", + " \n", + "#a, b = generate_dataset(10)\n", + "#np.save('data_nodes', a)\n", + "#np.save('data_edges', b)\n", + "\n", + "create_dataset(1000, 0.1, 0.1)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/solver.py b/solver.py new file mode 100644 index 0000000..d6149fa --- /dev/null +++ b/solver.py @@ -0,0 +1,401 @@ +import numpy as np +import os +import time +import datetime + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +from utils import * +from models import Generator, Discriminator +from data.organization_structure_dataset import OrganizationStructureDataset + + +class Solver(object): + """Solver for training and testing OrGAN.""" + + def __init__(self, config): + """Initialize configurations.""" + + # Dataset. + self.data = OrganizationStructureDataset() + self.data.load(config.data_dir) + + # Model configurations. + self.z_dim = config.z_dim + self.m_dim = self.data.node_num_types + self.b_dim = self.data.edge_num_types + self.g_conv_dim = config.g_conv_dim + self.d_conv_dim = config.d_conv_dim + self.g_repeat_num = config.g_repeat_num + self.d_repeat_num = config.d_repeat_num + self.lambda_cls = config.lambda_cls + self.lambda_rec = config.lambda_rec + self.lambda_gp = config.lambda_gp + self.post_method = config.post_method + + self.metric = 'all' + + # Training configurations. 
+ self.batch_size = config.batch_size + self.num_iters = config.num_iters + self.num_iters_decay = config.num_iters_decay + self.g_lr = config.g_lr + self.d_lr = config.d_lr + self.dropout = config.dropout + self.n_critic = config.n_critic + self.beta1 = config.beta1 + self.beta2 = config.beta2 + self.resume_iters = config.resume_iters + + # Test configurations. + self.test_iters = config.test_iters + + # Miscellaneous. + self.use_tensorboard = config.use_tensorboard + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Directories. + self.log_dir = config.log_dir + self.sample_dir = config.sample_dir + self.model_save_dir = config.model_save_dir + self.result_dir = config.result_dir + + # Step size. + self.log_step = config.log_step + self.sample_step = config.sample_step + self.model_save_step = config.model_save_step + self.lr_update_step = config.lr_update_step + + # For the log to be informative, it should contain quality characteristics + # of only generated structures + assert self.log_step % self.n_critic == 0 + + # Build the model and tensorboard. + self.build_model() + if self.use_tensorboard: + self.build_tensorboard() + + def build_model(self): + """Create a generator and a discriminator.""" + + print('Max nodes:', self.data.vertexes) + print('Node types:', self.data.node_num_types, self.m_dim) + print('Edge types:', self.data.edge_num_types, self.b_dim) + + self.G = Generator(self.g_conv_dim, self.z_dim, + self.data.vertexes, + self.data.edge_num_types, + self.data.node_num_types, + self.dropout) + self.D = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout) + self.V = Discriminator(self.d_conv_dim, self.m_dim, self.b_dim, self.dropout) + + self.g_optimizer = torch.optim.Adam(list(self.G.parameters())+list(self.V.parameters()), + self.g_lr, [self.beta1, self.beta2]) + self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) + self.print_network(self.G, 'G') + self.print_network(self.D, 'D') + + self.G.to(self.device) + self.D.to(self.device) + self.V.to(self.device) + + def print_network(self, model, name): + """Print out the network information.""" + num_params = 0 + for p in model.parameters(): + num_params += p.numel() + print(model) + print(name) + print("The number of parameters: {}".format(num_params)) + + def restore_model(self, resume_iters): + """Restore the trained generator and discriminator.""" + print('Loading the trained models from step {}...'.format(resume_iters)) + G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters)) + D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters)) + V_path = os.path.join(self.model_save_dir, '{}-V.ckpt'.format(resume_iters)) + self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage)) + self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage)) + self.V.load_state_dict(torch.load(V_path, map_location=lambda storage, loc: storage)) + + def build_tensorboard(self): + """Build a tensorboard logger.""" + from logger import Logger + self.logger = Logger(self.log_dir) + + def update_lr(self, g_lr, d_lr): + """Decay learning rates of the generator and discriminator.""" + for param_group in self.g_optimizer.param_groups: + param_group['lr'] = g_lr + for param_group in self.d_optimizer.param_groups: + param_group['lr'] = d_lr + + def reset_grad(self): + """Reset the gradient buffers.""" + self.g_optimizer.zero_grad() + self.d_optimizer.zero_grad() + + def 
denorm(self, x):
+        """Convert the range from [-1, 1] to [0, 1]."""
+        out = (x + 1) / 2
+        return out.clamp_(0, 1)
+
+    def gradient_penalty(self, y, x):
+        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
+        weight = torch.ones(y.size()).to(self.device)
+        dydx = torch.autograd.grad(outputs=y,
+                                   inputs=x,
+                                   grad_outputs=weight,
+                                   retain_graph=True,
+                                   create_graph=True,
+                                   only_inputs=True)[0]
+
+        dydx = dydx.view(dydx.size(0), -1)
+        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
+        return torch.mean((dydx_l2norm-1)**2)
+
+    def label2onehot(self, labels, dim):
+        """Convert label indices to one-hot vectors."""
+        out = torch.zeros(list(labels.size())+[dim]).to(self.device)
+        out.scatter_(len(out.size())-1, labels.unsqueeze(-1), 1.)
+        return out
+
+    def sample_z(self, batch_size):
+        return np.random.normal(0, 1, size=(batch_size, self.z_dim))
+
+    def postprocess(self, inputs, method, temperature=1.):
+
+        def listify(x):
+            return x if isinstance(x, (list, tuple)) else [x]
+
+        def delistify(x):
+            return x if len(x) > 1 else x[0]
+
+        if method == 'soft_gumbel':
+            softmax = [F.gumbel_softmax(e_logits.contiguous().view(-1, e_logits.size(-1))
+                       / temperature, hard=False).view(e_logits.size())
+                       for e_logits in listify(inputs)]
+        elif method == 'hard_gumbel':
+            softmax = [F.gumbel_softmax(e_logits.contiguous().view(-1, e_logits.size(-1))
+                       / temperature, hard=True).view(e_logits.size())
+                       for e_logits in listify(inputs)]
+        else:
+            softmax = [F.softmax(e_logits / temperature, -1)
+                       for e_logits in listify(inputs)]
+
+        return [delistify(e) for e in softmax]
+
+    def reward(self, orgs):
+        """Structural reward.
+
+        The method calculates a vector of structural reward values for the given
+        batch of organization descriptions. The definition of structural reward
+        can be project-specific (the list of metrics is defined in `self.metric`)
+        and relies on the validity metrics defined in `OrganizationMetrics`.
+
+        Parameters
+        ----------
+        orgs : list
+            A list of organization specifications.
+
+        Returns
+        -------
+        numpy.ndarray, shape (batch_size, 1)
+            Batch of reward values.
+
+        """
+        rr = 1.
+        for m in ('nodes,edges' if self.metric == 'all' else self.metric).split(','):
+            if m == 'nodes':
+                rr *= OrganizationMetrics.node_validness_scores(orgs)
+            elif m == 'edges':
+                rr *= OrganizationMetrics.edge_validness_scores(orgs)
+            else:
+                raise RuntimeError('{} is not defined as a metric'.format(m))
+        return rr.reshape(-1, 1)
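+
+    # (Editor's note) A minimal sketch of how a further project-specific metric
+    # could be wired into `reward` above. Both the 'depth' key and the
+    # `OrganizationMetrics.depth_scores` helper are hypothetical and not part of
+    # this patch; any callable that maps a batch of (nodes, edges) tuples to a
+    # vector of per-sample scores can be multiplied into `rr` the same way:
+    #
+    #     elif m == 'depth':
+    #         rr *= OrganizationMetrics.depth_scores(orgs)
+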
+    def train(self):
+
+        # Learning rate cache for decaying.
+        g_lr = self.g_lr
+        d_lr = self.d_lr
+
+        # Start training from scratch or resume training.
+        start_iters = 0
+        if self.resume_iters:
+            start_iters = self.resume_iters
+            self.restore_model(self.resume_iters)
+
+        # Start training.
+        print('Start training...')
+        start_time = time.time()
+        for i in range(start_iters, self.num_iters):
+            if (i+1) % self.log_step == 0:
+                x, a = self.data.next_validation_batch()
+                z = self.sample_z(a.shape[0])
+                print('[Valid]', '')
+            else:
+                x, a = self.data.next_train_batch(self.batch_size)
+                z = self.sample_z(self.batch_size)
+            orgs = list(zip(x, a))
+
+            # Batch contents:
+            # orgs is a (batch_size, ) list of (nodes, edges) pairs (used for reward checking)
+            # a is a (batch_size, vertexes, vertexes) numpy array of edge types (vertexes=12 here)
+            # x is a (batch_size, vertexes) numpy array of node types (categorical, 0 for no-node)
+
+            # =================================================================================== #
+            #                             1. Preprocess input data                                #
+            # =================================================================================== #
+
+            a = torch.from_numpy(a).to(self.device).long()   # Adjacency.
+            x = torch.from_numpy(x).to(self.device).long()   # Nodes.
+            a_tensor = self.label2onehot(a, self.b_dim)
+            x_tensor = self.label2onehot(x, self.m_dim)
+            z = torch.from_numpy(z).to(self.device).float()
+
+            # =================================================================================== #
+            #                             2. Train the discriminator                              #
+            # =================================================================================== #
+
+            # Compute loss with real samples.
+            logits_real, features_real = self.D(a_tensor, None, x_tensor)
+            d_loss_real = - torch.mean(logits_real)
+
+            # Compute loss with fake samples.
+            edges_logits, nodes_logits = self.G(z)
+            # Postprocess with Gumbel softmax
+            (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
+            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
+            d_loss_fake = torch.mean(logits_fake)
+
+            # Compute loss for gradient penalty.
+            eps = torch.rand(logits_real.size(0), 1, 1, 1).to(self.device)
+            x_int0 = (eps * a_tensor + (1. - eps) * edges_hat).requires_grad_(True)
+            x_int1 = (eps.squeeze(-1) * x_tensor + (1. - eps.squeeze(-1)) * nodes_hat).requires_grad_(True)
+            grad0, grad1 = self.D(x_int0, None, x_int1)
+            d_loss_gp = self.gradient_penalty(grad0, x_int0) + self.gradient_penalty(grad1, x_int1)
+
+            # Backward and optimize.
+            d_loss = d_loss_fake + d_loss_real + self.lambda_gp * d_loss_gp
+            self.reset_grad()
+            d_loss.backward()
+            self.d_optimizer.step()
+
+            # Logging.
+            loss = {}
+            loss['D/loss_real'] = d_loss_real.item()
+            loss['D/loss_fake'] = d_loss_fake.item()
+            loss['D/loss_gp'] = d_loss_gp.item()
+
+            # =================================================================================== #
+            #                               3. Train the generator                                #
+            # =================================================================================== #
+
+            if (i+1) % self.n_critic == 0:
+                # Z-to-target
+                edges_logits, nodes_logits = self.G(z)
+                # Postprocess with Gumbel softmax
+                (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
+                logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
+                g_loss_fake = - torch.mean(logits_fake)
+
+                # Real Reward
+                rewardR = torch.from_numpy(self.reward(orgs)).to(self.device)
+                # Fake Reward
+                (edges_hard, nodes_hard) = self.postprocess((edges_logits, nodes_logits), 'hard_gumbel')
+                edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(nodes_hard, -1)[1]
+                orgs = [self.data.matrices2graph(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
+                        for e_, n_ in zip(edges_hard, nodes_hard)]
+                rewardF = torch.from_numpy(self.reward(orgs)).to(self.device)
+
+                # Value loss
+                value_logit_real,_ = self.V(a_tensor, None, x_tensor, torch.sigmoid)
+                value_logit_fake,_ = self.V(edges_hat, None, nodes_hat, torch.sigmoid)
+                g_loss_value = torch.mean((value_logit_real - rewardR) ** 2 + (
+                    value_logit_fake - rewardF) ** 2)
+                #rl_loss= -value_logit_fake
+                #f_loss = (torch.mean(features_real, 0) - torch.mean(features_fake, 0)) ** 2
+
+                # Backward and optimize.
+                g_loss = g_loss_fake + g_loss_value
+                self.reset_grad()
+                g_loss.backward()
+                self.g_optimizer.step()
+
+                # Logging.
+                loss['G/loss_fake'] = g_loss_fake.item()
+                loss['G/loss_value'] = g_loss_value.item()
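+
+                # (Editor's note) The value network V shares g_optimizer with G,
+                # so minimizing `g_loss_value` teaches V to regress the structural
+                # reward of real and generated batches, in the spirit of the
+                # MolGAN value-function objective. The commented-out `rl_loss`
+                # and `f_loss` lines appear to be alternative objectives carried
+                # over from the MolGAN reference implementations.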
+
+            # =================================================================================== #
+            #                                 4. Miscellaneous                                     #
+            # =================================================================================== #
+
+            # Print out training information.
+            if (i+1) % self.log_step == 0:
+                et = time.time() - start_time
+                et = str(datetime.timedelta(seconds=et))[:-7]
+                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
+
+                # Log update
+                m0, m1 = all_scores(orgs, self.data, norm=True)     # 'orgs' is output of Fake Reward
+                m0 = {k: np.array(v)[np.nonzero(v)].mean() for k, v in m0.items()}
+                m0.update(m1)
+                loss.update(m0)
+                for tag, value in loss.items():
+                    log += ", {}: {:.4f}".format(tag, value)
+                print(log)
+
+                if self.use_tensorboard:
+                    for tag, value in loss.items():
+                        self.logger.scalar_summary(tag, value, i+1)
+
+            # Save model checkpoints.
+            if (i+1) % self.model_save_step == 0:
+                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
+                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
+                V_path = os.path.join(self.model_save_dir, '{}-V.ckpt'.format(i+1))
+                torch.save(self.G.state_dict(), G_path)
+                torch.save(self.D.state_dict(), D_path)
+                torch.save(self.V.state_dict(), V_path)
+                print('Saved model checkpoints into {}...'.format(self.model_save_dir))
+
+            # Decay learning rates.
+            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
+                g_lr -= (self.g_lr / float(self.num_iters_decay))
+                d_lr -= (self.d_lr / float(self.num_iters_decay))
+                self.update_lr(g_lr, d_lr)
+                print('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))
+
+
+    def test(self):
+        # Load the trained generator.
+        self.restore_model(self.test_iters)
+
+        with torch.no_grad():
+            # next_test_batch() returns (nodes, edges); the original nine-value
+            # unpacking here was a leftover from the MolGAN data pipeline.
+            x, a = self.data.next_test_batch()
+            z = torch.from_numpy(self.sample_z(a.shape[0])).to(self.device).float()
+
+            # Z-to-target
+            edges_logits, nodes_logits = self.G(z)
+            # Postprocess with Gumbel softmax
+            (edges_hat, nodes_hat) = self.postprocess((edges_logits, nodes_logits), self.post_method)
+            logits_fake, features_fake = self.D(edges_hat, None, nodes_hat)
+            g_loss_fake = - torch.mean(logits_fake)
+
+            # Fake Reward
+            (edges_hard, nodes_hard) = self.postprocess((edges_logits, nodes_logits), 'hard_gumbel')
+            edges_hard, nodes_hard = torch.max(edges_hard, -1)[1], torch.max(nodes_hard, -1)[1]
+            orgs = [self.data.matrices2graph(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True)
+                    for e_, n_ in zip(edges_hard, nodes_hard)]
+
+            # Log update
+            log = 'Test [{}]'.format(self.test_iters)
+            m0, m1 = all_scores(orgs, self.data, norm=True)     # 'orgs' is output of Fake Reward
+            m0 = {k: np.array(v)[np.nonzero(v)].mean() for k, v in m0.items()}
+            m0.update(m1)
+            for tag, value in m0.items():
+                log += ", {}: {:.4f}".format(tag, value)
+            print(log)
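For reference, here is a minimal sampling sketch (an editor's addition, not part of the patch) showing how a trained generator can be used outside `Solver`; it mirrors the hard-decoding path that `train()` uses for the fake reward. The `conv_dim` list, `Z_DIM`, batch size, and checkpoint path are illustrative assumptions and must match whatever `main.py` was actually run with (checkpoint naming follows `restore_model`); the 12/3/12 dimensions come from the constants in `generate_dataset.py`.

```
import torch
from models import Generator

Z_DIM = 8                                      # illustrative; must match training
VERTEXES, EDGE_TYPES, NODE_TYPES = 12, 3, 12   # per generate_dataset.py constants

# conv_dim below is an illustrative MolGAN-style value; dropout off for eval.
G = Generator([128, 256, 512], Z_DIM, VERTEXES, EDGE_TYPES, NODE_TYPES, dropout=0.0)
G.load_state_dict(torch.load('models/5000-G.ckpt', map_location='cpu'))
G.eval()

with torch.no_grad():
    edges_logits, nodes_logits = G(torch.randn(16, Z_DIM))
    edges = edges_logits.argmax(-1)  # (16, 12, 12) matrix of edge types per sample
    nodes = nodes_logits.argmax(-1)  # (16, 12) vector of node types per sample

print(nodes[0].tolist())
print(edges[0].tolist())
```
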
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..4652224
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,66 @@
+import numpy as np
+
+import generate_dataset
+
+class OrganizationMetrics(object):
+    """Utility functions to calculate organization quality and validity metrics.
+
+    Organizations are defined by a tuple:
+        vector of node types,
+        matrix of edge types.
+
+    """
+
+    @staticmethod
+    def valid_lambda(org):
+        """
+        Checks validity of the organization structure.
+
+        An organization is considered valid if its node set passes
+        `generate_dataset.check_nodes` and its relation matrix passes
+        `generate_dataset.check_relations`.
+        """
+        return org is not None and generate_dataset.check_nodes(org[0]) and \
+               generate_dataset.check_relations(org[0], org[1])[0]
+
+    @staticmethod
+    def edge_validness_scores(orgs):
+        """Estimates edge validity for multiple organizations."""
+        meth = lambda x: generate_dataset.check_relations(x[0], x[1])[0]
+        return np.array(list(map(meth, orgs)), dtype=np.float32)
+
+    @staticmethod
+    def node_validness_scores(orgs):
+        """Estimates node validity for multiple organizations."""
+        meth = lambda x: generate_dataset.check_nodes(x[0])
+        return np.array(list(map(meth, orgs)), dtype=np.float32)
+
+    @staticmethod
+    def valid_scores(orgs):
+        return np.array(list(map(OrganizationMetrics.valid_lambda, orgs)), dtype=np.float32)
+
+    @staticmethod
+    def valid_filter(orgs):
+        return list(filter(OrganizationMetrics.valid_lambda, orgs))
+
+    @staticmethod
+    def valid_total_score(orgs):
+        return np.array(list(map(OrganizationMetrics.valid_lambda, orgs)), dtype=np.float32).mean()
+
+    @staticmethod
+    def sample_organization_metric(orgs, norm=False):
+        scores = [OrganizationMetrics.valid_lambda(org) if org is not None else None for org in orgs]
+        scores = np.array(scores)
+
+        return scores
+
+def all_scores(orgs, data, norm=False, reconstruction=False):
+    # These are one-value-per-structure scores
+    m0 = {k: list(filter(lambda e: e is not None, v)) for k, v in {
+        'Sample score': OrganizationMetrics.sample_organization_metric(orgs, norm=norm)}.items()
+    }
+    # These are one-value-per-batch scores (used, e.g., in batch log reporting)
+    m1 = {'valid score': OrganizationMetrics.valid_total_score(orgs) * 100,
+          'node score': np.mean(OrganizationMetrics.node_validness_scores(orgs)) * 100,
+          'edge score': np.mean(OrganizationMetrics.edge_validness_scores(orgs)) * 100
+          }
+
+    return m0, m1
--
GitLab