diff --git a/docs/intro.rst b/docs/intro.rst index 501ef6ba0178719328664cb335923cc15da400a5..8c3a2022fa6e15449d52e59b79a0e26c8d66017c 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -21,7 +21,7 @@ virtual environment and install dependencies: .. code-block:: shell - $ git clone http://cais.iias.spb.su/gitlab/Hatter/organ-private.git organ + $ git clone http://cais.iias.spb.su/gitlab/Hatter/organ.git organ $ cd organ $ python3 -m venv venv $ source venv/bin/activate @@ -52,7 +52,7 @@ the model one shoud augment the training dataset: .. code-block:: shell - $ python augment_dataset.py 1000 demo_data/logistics data demo_logistics + $ python augment_dataset.py demo_logistics 1000 demo_data/logistics data This script will create 1000 organization samples in the :file:`data` directory, the dataset format is discussed in `Data`_. @@ -188,7 +188,7 @@ connect only existing nodes: .. code-block:: python - def soft_constraints(self, nodes, edges, ignore, ignore): + def soft_constraints(self, nodes, edges, *ignored): return 0.1 * organ.structure.constraints.edge_consistent(nodes, edges) + \ 0.1 * organ.structure.constraints.edge_symmetric(edges) @@ -309,6 +309,7 @@ Dict with dataset description contains a number of keys: - `node_num_types` - number of node types (including 0-type), must be (`f` - 1), - `edge_num_types` - number of edge types (including 0-type), -- `vertexes` - must be equal to `node_num_types`. - - +- `vertexes` - must be equal to `node_num_types`, +- `features_per_node` - number of feartures per node, +- `condition_dim` - number of features representing the generation + context (goal organization parameters). diff --git a/docs/modules.rst b/docs/modules.rst index 2f6131883ce68d25c077974483bd6e63918c513f..2b69a7ff7ffe7e6926638b0cefd614a414ccbcc6 100644 --- a/docs/modules.rst +++ b/docs/modules.rst @@ -37,8 +37,8 @@ TiNGLe .. automodule:: organ.tingle :members: -demo.structure --------------- +demo +---- -.. automodule:: organ.demo.structure +.. automodule:: organ.demo :members: diff --git a/organ/demo.py b/organ/demo.py index f0fcc3fd97fb1c666a77917383f769984f2ded22..7b9b2a6f7e0efe49c89ad23586617cb4d7b78faa 100644 --- a/organ/demo.py +++ b/organ/demo.py @@ -2,6 +2,7 @@ import os import copy import random +import torch import numpy as np @@ -762,13 +763,13 @@ class ManagementStructureModel: # Number of node types NODE_N_TYPES = len(node_type_dict) # Number of edge types - EDGE_N_TYPES = 6 + EDGE_N_TYPES = 7 # Max number of vertices per graph MAX_NODES_PER_GRAPH = 9 # Parametrization constants - node_param_0_min_allowed = [0.0, 0.0, 1000.0, 100.0, 0.0, 0.0, 500.0, 0.0, 0.0] - node_param_0_min_required = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1001.0, 0.0, 0.0] + node_param_0_min_allowed = [0.0, 0.0, 1000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + node_param_0_min_required = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1200.0, 0.0, 0.0] param_0_max = 2000.0 def __init__(self): @@ -971,29 +972,9 @@ class ManagementStructureModel: bool Returns `True` if the parameters are valid. """ - + log = "Valid configuration" - if round(staff[8, 0]) != round(ctx[0]): - log = f"Value 0 of node 8 {self.node_type_dict[8]['title']} does not meet the context. " - log += f"Actual value: {str(staff[8, 0])} " - log += f"Context value: {str(ctx[0])}" - return False, log - - tmp_param_min = 0 - tmp_param_max = self.param_0_max - for i, node in enumerate(nodes): - if node > 0: - tmp_param_min = max(self.node_param_0_min_allowed[i], tmp_param_min) - elif self.node_param_0_min_required[i] > 0: - tmp_param_max = min(self.node_param_0_min_required[i], tmp_param_max) - - if staff[8, 0] > tmp_param_max or staff[8, 0] < tmp_param_min: - log = f"Value 0 of node 8 {self.node_type_dict[8]['title']} does not meet the configuration. " - log += str(tmp_param_min) + ' <= ' - log += f" actual value {str(staff[8, 0])} <= " - log += str(tmp_param_max) - return False, log - + if ctx[1] != int(nodes[4] > 0): log = f"Structure does not meet the context. " log += f"{self.node_type_dict[4]['title']} = {nodes[4]}. " @@ -1222,3 +1203,62 @@ class ManagementModel(ManagementStructureModel): org.node_features, ctx=org.condition)[0] } + + def soft_constraints(self, nodes, edges, params, ctx): + """Soft constraints for this scenario. + + The function describes some relationships between + node parameters, context, and organization structure + to simplify the training of a generator. + + Parameters + ---------- + nodes : torch.tensor + Nodes description in an 'internal' format: + (batch, nodes, node_types). Value is the probability that + a node of the specific type is located in a certain position. + Zero-type corresponds to the absense of a node. Non-zero + values can be only on the matrix diagonal or zeroth column. + edges : torch.tensor + Edges representation in an 'internal' format: + (batch, nodes, nodes, edge_types). + params : torch.tensor + Node features: (batch, nodes, features_per_node). + ctx : torch.tensor + Generation context: (batch, context_features). + + Returns + ------- + torch.tensor + Value tensor (0-dimensional). Non-negative loss for violation + of the constraints. + """ + marketing_cnt = torch.mean(torch.abs(nodes[:, 4, 4] - ctx[:, 1])) + + node_8_cnt = torch.mean(torch.abs(params[:, 8, 0] - ctx[:, 0])) / self.param_0_max + + node_existence = torch.sum(nodes[:, :, 1:], axis=-1) + + pm_nodes = torch.stack([ + ctx[:, 0] * 0, + ctx[:, 0] / 200, + ctx[:, 0] / 200, + ctx[:, 0] / 200, + ctx[:, 0] / ctx[:, 0] * 5, + ctx[:, 0] / 100, + ctx[:, 0] / 500, + ctx[:, 0] * 0, + ctx[:, 0] / 100], dim=-1) + pm = torch.sum(node_existence * pm_nodes, axis=-1) + node_7_cnt = torch.mean(torch.abs(pm-params[:, 7, 1])) + + node_2_cnt = torch.mean(torch.nn.functional.relu((1000 - ctx[:, 0]) / 1000) * + nodes[:, 2, 2]) + node_6_cnt = torch.mean(torch.nn.functional.relu((ctx[:, 0] - 1200) / 1200) * + (1-nodes[:, 6, 6])) + + return torch.mean(torch.stack([marketing_cnt, + node_8_cnt, + node_7_cnt, + node_2_cnt, + node_6_cnt])) diff --git a/organ/models.py b/organ/models.py index 87ed79c91f3d894b5444ea4c7d5e555fb56368eb..cdc15fee6a9a8bd4d1576fcbe00e83cf558cdf62 100644 --- a/organ/models.py +++ b/organ/models.py @@ -490,9 +490,9 @@ class CPDiscriminator(nn.Module): cond = None # Collect a group from nodes, edges and parameters - comps = [h, # graph nodes - h1, # graph edges - (1 - nodes[:, :, 0]).view(-1, 12), # node presence + comps = [h, # graph nodes + h1, # graph edges + (1 - nodes[:, :, 0]).view(-1, self.n_nodes), # node presence ] # Condition, if present if cond is not None: diff --git a/organ/solver.py b/organ/solver.py index 35dc786f283f8ba40b99f30d46d14dfec0e3ea37..3540da1f4747e9bb08607c37873e9c0837cd92e2 100644 --- a/organ/solver.py +++ b/organ/solver.py @@ -20,20 +20,26 @@ from organ.utils import MetricsAggregator, all_scores class Normalizer: - def __init__(self, per_feature=False): + def __init__(self, device, per_feature=False): self.per_feature = per_feature + self.device = device def fit(self, x: np.ndarray): if self.per_feature: self.m = np.max(x, axis=0) else: self.m = np.max(x) + self.mt = torch.tensor(self.m).to(self.device) - def transform(self, x: np.ndarray) -> np.ndarray: - return x / self.m + def transform(self, x): + if isinstance(x, np.ndarray): + return x / self.m + return x / self.mt - def reverse_transform(self, x: np.ndarray) -> np.ndarray: - return x * self.m + def reverse_transform(self, x): + if isinstance(x, np.ndarray): + return x * self.m + return x * self.mt class Solver(object): @@ -181,11 +187,11 @@ class Solver(object): # Build normalizers for float features of the dataset if self.parametric: - self.node_features_normalizer = Normalizer() + self.node_features_normalizer = Normalizer(self.device) self.node_features_normalizer.fit(self.data.node_params) if self.conditional: - self.cond_normalizer = Normalizer(per_feature=True) + self.cond_normalizer = Normalizer(self.device, per_feature=True) self.cond_normalizer.fit(self.data.cond) def build_model(self): @@ -625,10 +631,18 @@ class Solver(object): # Тут также может быть расчет других, дифференцируемых, # характеристик сгенерированной структуры if hasattr(self.org_model, 'soft_constraints'): + # User-level function has to deal with non-normalized + # values + params_hat_ = self.node_features_normalizer.\ + reverse_transform(params_hat) \ + if params_hat is not None else None + cond_ = self.cond_normalizer.reverse_transform( + cond) if cond is not None else None g_loss_soft_constraints = self.org_model.soft_constraints( - nodes_hat, edges_hat) + nodes_hat, edges_hat, params_hat_, cond_) else: - g_loss_soft_constraints = 0.0 + g_loss_soft_constraints = torch.tensor(0.0, + device=self.device) # В итоге функция потерь для генератора складывается из # потерь неправдоподобности (g_loss_fake) потерь, связанных с @@ -643,6 +657,7 @@ class Solver(object): # Logging. loss['G/loss_fake'] = g_loss_fake.item() loss['G/loss_value'] = g_loss_value.item() + loss['G/loss_soft'] = g_loss_soft_constraints.item() return orgs, loss @@ -1052,8 +1067,8 @@ class Solver(object): '\nStaff:\n', org.node_features, '\nEdges:\n', org.edges, '\nCheck results:\n', self.org_model.check_paramater_feasibility(org.nodes, # noqa: E501 - org.node_features.ravel(), # noqa: E501 - logging=True, # noqa: E501 + org.node_features, # noqa: E501 + # logging=True, # noqa: E501 ctx=org.condition), # noqa: E501 file=f) print('=======', file=f) diff --git a/organ/structure/models.py b/organ/structure/models.py index 5441bb3ce8ba86a8fa95100d8ab4b2bb86a7ce83..754f9c86842864b016a96625664f7f4801d52612 100644 --- a/organ/structure/models.py +++ b/organ/structure/models.py @@ -35,7 +35,7 @@ class OrganizationModel(ABC): def metrics(self, org: Organization) -> dict: pass - def soft_constraints(self, nodes, edges): + def soft_constraints(self, nodes, edges, features, cond): return 0.0 @@ -105,6 +105,6 @@ class Generic(OrganizationModel): 'edge score': self.check_relations(org.nodes, org.edges)[0], } - def soft_constraints(self, nodes, edges): + def soft_constraints(self, nodes, edges, *ignored): """Differentiable constraints on the graph structure.""" return C.edge_consistent(nodes, edges) + C.edge_symmetric(edges) diff --git a/tests/test_org_models.py b/tests/test_org_models.py index c5b5a39fac96772890b86a790a06ef9057711491..c60d300030588820963f719516d25af692bda370 100644 --- a/tests/test_org_models.py +++ b/tests/test_org_models.py @@ -1,8 +1,10 @@ import torch +import numpy as np import pytest import organ.structure.models +import organ.demo def test_generic_conforming(): @@ -72,3 +74,223 @@ def test_generic_nonconforming(): assert 'node score' in metrics assert 'edge score' in metrics assert org_model.soft_constraints(nodes, edges).item() > 0.0 + + +def test_demo_management_soft_constraints(): + nodes = torch.tensor( + [[[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 0 + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 1 + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 2 + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 3 + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0], # node 4 Marketing + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # node 5 + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 6 + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], # node 7 + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]], # node 8 + [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 0 + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 1 + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 2 + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 3 + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 4 Marketing + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], # node 5 + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # node 6 + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], # node 7 + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # node 8 + ]], requires_grad=True) + + edges1 = torch.zeros(4, 4) + edges1[0, 1] = 1.0 + edges1 = edges1 + edges1.t() + edges = torch.stack([1 - edges1, edges1], dim=-1) + edges = torch.unsqueeze(edges, 0) + + params = torch.tensor([[[0.0, 0.0], # node 0 + [0.0, 0.0], # node 1 + [0.0, 0.0], # node 2 + [0.0, 0.0], # node 3 + [0.0, 0.0], # node 4 Marketing + [0.0, 0.0], # node 5 + [0.0, 0.0], # node 6 + [0.0, 30.0], # node 7 + [1000.0, 0.0]], # node 8 + [[0.0, 0.0], # node 0 + [0.0, 0.0], # node 1 + [0.0, 0.0], # node 2 + [0.0, 0.0], # node 3 + [0.0, 0.0], # node 4 Marketing + [0.0, 0.0], # node 5 + [0.0, 0.0], # node 6 + [0.0, 12.0], # node 7 + [1000.0, 0.0], # node 8 + ]], requires_grad=True) + + ctx = torch.tensor([[1000.0, 1.0], + [500.0, 1.0]], requires_grad=True) + + org = organ.demo.ManagementModel() + res = org.soft_constraints(nodes, edges, params, ctx) + assert res > 0 + assert nodes.grad is None + res.backward() + assert nodes.grad is not None + + +def test_demo_management_validness(): + org_model = organ.demo.ManagementModel() + # valid configuration + org = organ.structure.models.Organization( + nodes=np.array([0, 0, 2, 0, 0, 5, 0, 7, 8]), + edges=np.array([ # 0 1 2 3 4 5 6 7 8 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 0 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 1 + [0, 0, 0, 0, 0, 0, 0, 2, 0], # 2 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 3 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 4 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 5 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 6 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 7 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 8 + ]), + node_features=np.array([[0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 25], + [1000, 0]]), + condition=np.array([1000, 0])) + + assert org_model.validness(org) + + # invalid nodes + org = organ.structure.models.Organization( + nodes=np.array([0, 0, 0, 0, 0, 5, 0, 7, 8]), + edges=np.array([ # 0 1 2 3 4 5 6 7 8 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 0 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 1 + [0, 0, 0, 0, 0, 0, 0, 2, 0], # 2 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 3 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 4 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 5 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 6 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 7 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 8 + ]), + node_features=np.array([[0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 25], + [1000, 0]]), + condition=np.array([1000, 0])) + + assert not org_model.validness(org) + + # invalid relations + org = organ.structure.models.Organization( + nodes=np.array([0, 1, 2, 0, 0, 5, 0, 7, 8]), + edges=np.array([ # 0 1 2 3 4 5 6 7 8 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 0 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 1 + [0, 0, 0, 0, 0, 0, 0, 2, 0], # 2 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 3 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 4 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 5 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 6 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 7 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 8 + ]), + node_features=np.array([[0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 25], + [1000, 0]]), + condition=np.array([10000, 0])) + + assert not org_model.validness(org) + + # invalid parameters + org = organ.structure.models.Organization( + nodes=np.array([0, 0, 2, 0, 0, 5, 0, 7, 8]), + edges=np.array([ # 0 1 2 3 4 5 6 7 8 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 0 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 1 + [0, 0, 0, 0, 0, 0, 0, 2, 0], # 2 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 3 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 4 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 5 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 6 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 7 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 8 + ]), + node_features=np.array([[0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 20], + [1000, 0]]), + condition=np.array([1000, 0])) + + assert not org_model.validness(org) + + +def test_demo_management_generate_key_values(): + org_model = organ.demo.ManagementModel() + nodes = np.array([0, 0, 2, 0, 1, 0, 0, 0, 0]) + + v = org_model.generate_key_values(nodes) + assert (v[0] > 999.99 and + v[0] < 1201.01 and + v[1] == pytest.approx(1.0)) + + +def test_demo_management_generate_augmentation(): + org_model = organ.demo.ManagementModel() + # valid configuration + org = organ.structure.models.Organization( + nodes=np.array([0, 0, 2, 0, 0, 5, 0, 7, 8]), + edges=np.array([ # 0 1 2 3 4 5 6 7 8 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 0 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 1 + [0, 0, 0, 0, 0, 0, 0, 2, 0], # 2 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 3 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 4 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 5 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 6 + [0, 0, 0, 0, 0, 0, 0, 0, 0], # 7 + [0, 0, 0, 0, 0, 0, 0, 3, 0], # 8 + ]), + node_features=np.array([[0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 25], + [1000, 0]]), + condition=np.array([1000, 0])) + aug_configuration = org_model.generate_augmentation( + org.nodes, + org.edges, + org.node_features, + logging=False, + max_iterations=1000) + aug_org = organ.structure.models.Organization( + nodes=aug_configuration[0], + edges=aug_configuration[1], + node_features=aug_configuration[2], + condition=aug_configuration[3]) + assert org_model.validness(aug_org)