import os
import copy
import random
import pickle

import numpy as np

# Catalogue of node (org-unit) types, keyed by node type id.
# Fields used in this file:
#   'title'       - human-readable name (Russian); used in log messages.
#   'status'      - 0 = optional, 1 = mandatory, 2 = replaceable.
#                   A replaceable node may be omitted only if every node in its
#                   'replacement' list is present instead (see generate_children
#                   and check_children).
#   'replacement' - node ids that stand in for a replaceable node.
#   'children'    - node ids generated/validated as sub-nodes of this node.
#   'weight'      - only set on the placeholder type 0; not read anywhere in this file.
# Type 0 ('none') is a placeholder: a 0 in a node vector marks an absent unit.
node_type_dict = { # status: 0 - optional, 1 - mandatory, 2 - replaceable
    0:  {'title': 'none', 'status': 0, 'weight': 0}, 
    1:  {'title': 'Управление запасами (складом)', 'status': 2, 'replacement': [2, 3]},  # Inventory (warehouse) management
    2:  {'title': 'Управление материалами', 'status': 0},  # Materials management
    3:  {'title': 'Управление готовыми товарами', 'status': 0},  # Finished goods management
    4:  {'title': 'Планирование', 'status': 2, 'replacement': [5, 6, 7], 'children': [8]},  # Planning
    5:  {'title': 'Планирование закупок', 'status': 0},  # Procurement planning
    6:  {'title': 'Планирование запасов', 'status': 0},  # Inventory planning
    7:  {'title': 'Планирование перевозок', 'status': 0},  # Transportation planning
    8:  {'title': 'Аналитика', 'status': 0},  # Analytics
    9:  {'title': 'Аудит', 'status': 0},  # Audit
    10: {'title': 'Транспортировка', 'status': 1, 'children': [11]},  # Transportation
    11: {'title': 'Транспортное хозяйство', 'status': 0},  # Transport fleet
}

# Roots of the hierarchy: generate_model and check_nodes start from these node types.
top_level_nodes = [1, 4, 9, 10]

# Relation (edge) type matrix indexed by node type id: entry [i][j] is the edge
# type for row i / column j; 0 means no edge. The matrix is not symmetric.
# Value 1 appears exactly on the pairs listed in the 'replacement'/'children'
# fields above (1->2, 1->3, 4->5..8, 10->11), so it appears to mark a
# parent -> sub-node edge. The meaning of value 2 is not established in this
# file (presumably a generic dependency -- TODO confirm).
# generate_model zeroes the rows/columns of absent node types.
relations_dict = [
    # 0   1   2   3   4   5   6   7   8   9   10  11
    [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0], # 0
    [ 0,  0,  1,  1,  2,  2,  2,  0,  2,  2,  2,  2], # 1
    [ 0,  0,  0,  0,  2,  2,  0,  2,  2,  2,  2,  2], # 2
    [ 0,  0,  0,  0,  2,  0,  2,  2,  2,  2,  2,  0], # 3
    [ 0,  2,  2,  2,  0,  1,  1,  1,  1,  0,  2,  2], # 4
    [ 0,  2,  2,  0,  0,  0,  0,  2,  2,  0,  0,  2], # 5
    [ 0,  2,  0,  2,  0,  0,  0,  2,  2,  0,  0,  0], # 6
    [ 0,  0,  2,  2,  0,  2,  2,  0,  2,  0,  2,  2], # 7
    [ 0,  2,  2,  2,  0,  2,  2,  2,  0,  0,  2,  2], # 8
    [ 0,  2,  2,  2,  0,  0,  0,  0,  0,  0,  2,  2], # 9
    [ 0,  2,  2,  2,  2,  0,  0,  2,  2,  2,  0,  1], # 10
    [ 0,  2,  2,  0,  2,  2,  0,  2,  2,  2,  0,  0], # 11
]

# Number of node types
NODE_N_TYPES = len(node_type_dict)
# Number of edge types (values 0, 1, 2 in relations_dict)
EDGE_N_TYPES = 3
# Max number of vertices per graph (one slot per node type)
MAX_NODES_PER_GRAPH = 12

# Parametrization constants
# Upper bound of the load sampled for a present input unit (node 2 or 3).
upper_limit = 25
# Per-person load bounds: convert_values2persons draws a head count between
# ceil(load / max_person) and ceil(load / min_person).
min_person = 0.7
max_person = 1.4
# Lower bound of the load sampled for an absent input unit (node 2 or 3).
min_nonzero_value = 0.1
# Minimum unit size: a present unit needs load >= min_person * min_orgunit, and a
# present non-mandatory unit needs at least min_orgunit staff.
min_orgunit = 2
# An absent unit must have load <= max_person * req_orgunit; otherwise the
# sample is rejected as infeasible.
req_orgunit = 5

# Recursive function to generate element tree

def generate_children(node_list, force_nodes = False):
    """Randomly select which node types from `node_list` (and their subtrees) exist.

    Selection rules per node status:

    * mandatory (1), or any node when `force_nodes` is True: always kept;
    * optional (0): kept with probability 0.5;
    * replaceable (2): kept with probability 0.5; if dropped, every node of its
      'replacement' list is forced in instead (the list must exist).

    After that decision, the 'replacement' subtree (unless it was already
    forced) and the 'children' subtree are processed recursively with
    `force_nodes` False. Nodes with any other status are skipped unless forced.

    Parameters
    ----------
    node_list : iterable of int
        Node type ids to process.
    force_nodes : bool
        Keep every node of `node_list` regardless of its status.

    Returns
    -------
    list of int
        Selected node type ids, each node listed before its subtrees.
    """
    selected = []
    for key in node_list:
        spec = node_type_dict[key]
        status = spec['status']
        force_replacement = False
        if status == 1 or force_nodes:
            keep = True
        elif status == 0:
            keep = np.random.rand(1)[0] < 0.5
        elif status == 2:
            keep = np.random.rand(1)[0] < 0.5
            force_replacement = not keep
        else:
            continue

        if keep:
            selected.append(key)
        if force_replacement:
            # A dropped replaceable node must be substituted by all of its replacements.
            selected.extend(generate_children(spec['replacement'], True))
        elif 'replacement' in spec:
            selected.extend(generate_children(spec['replacement']))
        if 'children' in spec:
            selected.extend(generate_children(spec['children']))
    return selected

def check_org_unit_feasibility(nodes, load, unit_id, min_person, max_person, min_orgunit, req_orgunit, logging=False):
    """Check that a unit's existence agrees with its load.

    A unit is infeasible when it is absent (``nodes[unit_id] == 0``) but its
    load exceeds ``max_person * req_orgunit``, or when it is present
    (``nodes[unit_id] > 0``) but its load is below ``min_person * min_orgunit``.

    Returns
    -------
    bool
        True if feasible; False otherwise (the reason is printed when `logging`).
    """
    node = nodes[unit_id]
    unit_load = load[unit_id]
    overloaded_without_unit = node == 0 and unit_load > max_person * req_orgunit
    underloaded_with_unit = node > 0 and unit_load < min_person * min_orgunit
    if not (overloaded_without_unit or underloaded_with_unit):
        return True

    if logging:
        print(f"Node {unit_id} {node_type_dict[unit_id]['title']} doesn't meet the requirements: load = {load[unit_id]}, node = {nodes[unit_id]}.")
    return False

def generate_values(nodes, logging=False):
    """Sample a random load vector consistent with the node set `nodes`.

    The loads of the two input units (2 and 3) are drawn at random: a present
    unit gets a load in [min_person*min_orgunit, upper_limit), an absent one a
    load in [min_nonzero_value, max_person*req_orgunit). Every other load is
    derived from them with fixed coefficients. After each derived load (except
    unit 10) is computed, it is validated with check_org_unit_feasibility;
    generation stops at the first failure.

    Parameters
    ----------
    nodes : sequence of int
        Node-presence vector indexed by node type id (> 0 means present).
    logging : bool
        Print the sampled inputs and the reason for a failed check.

    Returns
    -------
    numpy.ndarray
        Length-12 load vector, or an empty array if any check fails.
    """
    load = np.zeros(12)

    for source in (2, 3):
        if nodes[source] > 0:
            load[source] = np.random.uniform(min_person*min_orgunit, upper_limit)
        else:
            load[source] = np.random.uniform(min_nonzero_value, max_person*req_orgunit)

    if logging:
        print("input v=", load)

    def share(unit_id):
        # Weight of a contributing unit's load: 0.2 if that unit exists, 1 if it does not.
        return 0.2 if nodes[unit_id] > 0 else 1

    def feasible(unit_id):
        return check_org_unit_feasibility(nodes, load, unit_id, min_person, max_person,
                                          min_orgunit, req_orgunit, logging=logging)

    load[1] = share(2) * load[2] + share(3) * load[3]
    if not feasible(1):
        return np.empty(0)

    # Unit 10 is not validated here.
    load[10] = load[2] + load[3]

    load[11] = load[10] * 0.2
    if not feasible(11):
        return np.empty(0)

    base_part = load[2] * 0.4 if nodes[2] > 0 else load[1] * 0.2
    extra_part = load[11] * 0.2 if nodes[11] > 0 else 0
    load[5] = base_part + extra_part
    if not feasible(5):
        return np.empty(0)

    load[6] = load[3] * 0.4 if nodes[3] > 0 else load[1] * 0.2
    if not feasible(6):
        return np.empty(0)

    load[7] = load[10] * 0.2
    if not feasible(7):
        return np.empty(0)

    load[8] = load[5] * 0.2 + load[6] * 0.2 + load[7] * 0.2
    if not feasible(8):
        return np.empty(0)

    load[4] = share(5) * load[5] + share(6) * load[6] + share(7) * load[7] + share(8) * load[8]
    if not feasible(4):
        return np.empty(0)

    load[9] = load[1] * 0.05 + load[2] * 0.05 + load[3] * 0.05 + load[10] * 0.05 + load[11] * 0.05
    if not feasible(9):
        return np.empty(0)

    return load

def convert_values2persons(nodes, load, min_person, max_person):
    """Draw a random head count for every present unit from its load.

    For a present unit (``nodes[i] > 0``) the head count is drawn uniformly
    from the integers in ``[ceil(load/max_person), ceil(load/min_person))``.
    In the normal case this keeps the per-person load within
    ``(min_person, max_person]``. Absent units get 0.

    Parameters
    ----------
    nodes : sequence of int
        Node-presence vector (> 0 means present).
    load : sequence of float
        Load per unit, same indexing as `nodes`.
    min_person, max_person : float
        Lower/upper bound of the load one person can carry.

    Returns
    -------
    numpy.ndarray
        Float array of head counts, same length as `nodes`.
    """
    staff = np.zeros(len(nodes))
    for i, node in enumerate(nodes):
        if node <= 0:
            continue
        low = int(np.ceil(load[i] / max_person))
        high = int(np.ceil(load[i] / min_person))
        # For very small loads both bounds round up to the same integer and
        # randint(low, high) would raise; fall back to the minimal head count.
        high = max(high, low + 1)
        staff[i] = np.random.randint(low, high, 1)[0]
    return staff

# Shuffle nodes and relations

def shuffle_nodes(nodes, relations):
    """Randomly permute the nodes together with the matching relation rows/columns.

    The same permutation is applied to `nodes` and to both axes of
    `relations`, so ``new_relations[i, j] == relations[perm[i], perm[j]]``.
    The permutation is drawn with the stdlib `random` module.

    Parameters
    ----------
    nodes : numpy.ndarray, shape (n,)
    relations : numpy.ndarray, shape (n, n)

    Returns
    -------
    tuple of numpy.ndarray
        The permuted nodes and relations.
    """
    keys = list(range(len(nodes)))
    random.shuffle(keys)
    return nodes[keys], relations[keys][:, keys]

def generate_model(shuffle=True):
    """Generate one random structure: node-presence vector and relation matrix.

    Node types selected by generate_children keep their id in the node
    vector; all other slots are set to 0, and their rows and columns of the
    relation matrix are zeroed.

    Parameters
    ----------
    shuffle : bool
        Randomly permute the result with shuffle_nodes.

    Returns
    -------
    tuple of numpy.ndarray
        (nodes, relations).
    """
    present = set(generate_children(top_level_nodes))

    nodes = np.array(list(node_type_dict.keys()))
    relations = np.array(relations_dict)

    absent = [node_key for node_key in node_type_dict if node_key not in present]
    nodes[absent] = 0
    relations[absent, :] = 0
    relations[:, absent] = 0

    if not shuffle:
        return nodes, relations
    return shuffle_nodes(nodes, relations)

def generate_parametrized_model(logging=False):
    """Sample structures until one admits a feasible load, then staff it.

    Repeatedly draws an unshuffled structure with generate_model and samples
    loads with generate_values, retrying whenever an empty load is returned.

    Returns
    -------
    tuple
        (nodes, relations, staff) of the first feasible sample.
    """
    while True:
        nodes, relations = generate_model(shuffle=False)
        if logging:
            print("\nnodes=", nodes)
        load = generate_values(nodes, logging)
        if len(load) == 0:
            continue
        if logging:
            print(load)
        staff = convert_values2persons(nodes, load, min_person, max_person)
        return nodes, relations, staff

def unshuffle(nodes, relations):
    """Reorder a shuffled structure so that node type k sits in slot k.

    Each non-zero node value is treated as the target slot of that node. The
    remaining slots are filled, in increasing order, with the positions of the
    zero entries. The same order is applied to both axes of `relations`.

    Returns
    -------
    tuple of numpy.ndarray
        (nodes, relations) in canonical order.
    """
    order = [-1] * len(nodes)
    for position, node in enumerate(nodes):
        if node != 0:
            order[node] = position

    # Reversed so pop() yields the zero positions in ascending order.
    free_positions = [position for position, node in enumerate(nodes) if node == 0][::-1]
    for slot in range(len(order)):
        if order[slot] == -1:
            order[slot] = free_positions.pop()

    return nodes[order], relations[order][:, order]

def check_children(top_level_nodes, node_list, force_nodes = False):
    """Validate a node-presence vector against the node hierarchy.

    Mirrors the rules of generate_children:

    * a mandatory node (status 1), or any node when `force_nodes` is True, must
      be present; its 'replacement' and 'children' subtrees are then checked;
    * a replaceable node (status 2) may be absent only if all of its
      replacement nodes are present (checked with force_nodes=True); when it
      is present its replacement subtree is checked as optional. Its children
      are checked in both cases;
    * an optional node (status 0) may be absent; its subtrees are checked.

    Parameters
    ----------
    top_level_nodes : iterable of int
        Node type ids to check.
    node_list : sequence of int
        Node-presence vector indexed by node type id (0 means absent).
    force_nodes : bool
        Require every node of `top_level_nodes` to be present.

    Returns
    -------
    tuple (bool, str)
        (True, '') if valid, otherwise (False, explanation of the first violation).
    """
    for key in top_level_nodes:
        spec = node_type_dict[key]
        status = spec['status']
        if status == 1 or force_nodes:  # mandatory (or forced replacement)
            if node_list[key] == 0:
                return False, f'Mandatory element {key} is missing.'
            for subtree in ('replacement', 'children'):
                if subtree in spec:
                    result, explanation = check_children(spec[subtree], node_list)
                    if not result:
                        return False, explanation
        elif status == 2:  # replaceable
            if node_list[key] == 0:
                if 'replacement' not in spec:
                    return False, f'Replacable element {key} is missing.'
                result, explanation = check_children(spec['replacement'], node_list, True)
                if not result:
                    return False, f'Replacable element {key} is missing. {explanation}'
            elif 'replacement' in spec:
                # generate_children also recurses (non-forced) into the
                # replacements of a kept replaceable node.
                result, explanation = check_children(spec['replacement'], node_list)
                if not result:
                    return False, explanation
            if 'children' in spec:
                result, explanation = check_children(spec['children'], node_list)
                if not result:
                    return False, explanation
        elif status == 0:  # optional
            for subtree in ('replacement', 'children'):
                if subtree in spec:
                    result, explanation = check_children(spec[subtree], node_list)
                    if not result:
                        return False, explanation
    return True, ''

def check_relations(nodes, relations):
    """Checks relations validity.

    Builds the expected relation matrix for the given node set: relations_dict
    with the rows and columns of absent node types zeroed, as generate_model
    does. It then compares that matrix with `relations` element-wise.

    Parameters
    ----------
    nodes : List
        The list of node types (0 marks an absent slot).
    relations : numpy.ndarray (n, n)
        Relation type matrix. Assumed to be in canonical (unshuffled) node
        order -- TODO confirm with callers.

    Returns
    -------
    result : bool
        Returns `True` if the whole set of edges is valid and consistent
        with the nodes.
    diff : numpy.ndarray (n, n) of bool
        Element-wise edge validity (`True` for valid edges).
    """
    present = set(np.asarray(nodes).ravel().tolist())
    target_relations = np.array(relations_dict)
    for node_key in node_type_dict:
        if node_key not in present:
            target_relations[node_key, :] = 0
            target_relations[:, node_key] = 0

    # Compare directly: wrapping the comparison in a list would add a spurious
    # leading axis, turning the documented (n, n) matrix into (1, n, n).
    relations_diff = np.asarray(relations) == target_relations
    return bool(relations_diff.all()), relations_diff

def check_nodes(nodes):
    """Return whether a node-presence vector satisfies the node hierarchy.

    Parameters
    ----------
    nodes : List
        The list of node types (0 marks an absent slot).

    Returns
    -------
    Bool
        `True` if check_children accepts `nodes` starting from top_level_nodes.
    """
    is_valid, _reason = check_children(top_level_nodes, nodes)
    return is_valid

def overlap(first, last, another_first, another_last)->bool:
    """Return True if [first, last] and [another_first, another_last] intersect.

    Both intervals are closed, so touching endpoints count as overlapping.
    Assumes ``first <= last`` and ``another_first <= another_last``.
    """
    latest_start = max(first, another_first)
    earliest_end = min(last, another_last)
    return earliest_end - latest_start >= 0

def _expected_load(nodes, load):
    """Return, per unit, the load implied by the per-unit `load` vector.

    Uses the same coefficients generate_values uses to derive dependent loads,
    but every term is read from `load` (e.g. a staff-derived lower or upper
    capacity bound) rather than from previously derived values.
    """
    def share(unit_id):
        # Weight of a contributing unit's load: 0.2 if that unit exists, 1 if it does not.
        return 0.2 if nodes[unit_id] > 0 else 1

    expected = np.zeros(12)
    expected[1] = share(2) * load[2] + share(3) * load[3]
    expected[10] = load[2] + load[3]
    expected[11] = load[10] * 0.2
    expected[5] = (load[2] * 0.4 if nodes[2] > 0 else load[1] * 0.2) + (load[11] * 0.2 if nodes[11] > 0 else 0)
    expected[6] = load[3] * 0.4 if nodes[3] > 0 else load[1] * 0.2
    expected[7] = load[10] * 0.2
    expected[8] = load[5] * 0.2 + load[6] * 0.2 + load[7] * 0.2
    expected[4] = share(5) * load[5] + share(6) * load[6] + share(7) * load[7] + share(8) * load[8]
    expected[9] = load[1] * 0.05 + load[2] * 0.05 + load[3] * 0.05 + load[10] * 0.05 + load[11] * 0.05
    return expected


def _capacity_ok(nodes, load_min, load_max, v_min, v_max, unit_id, logging=False):
    """Return False (printing the reason when `logging`) if present unit `unit_id`'s capacity misses its expected load."""
    if nodes[unit_id] > 0 and not overlap(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id]):
        if logging:
            print(f"Capacity of node {unit_id} {node_type_dict[unit_id]['title']} does not meet requirements.")
            print(load_min[unit_id], load_max[unit_id], v_min[unit_id], v_max[unit_id])
        return False
    return True


def check_paramater_feasibility(nodes, staff, logging=False):
    """Check whether a staffing vector is consistent with a node set.

    Two groups of checks are run, and False is returned at the first failure:

    1. Head counts: a present non-mandatory unit needs at least
       ``min_orgunit`` people, and an absent unit must have no staff.
    2. Capacities: every unit gets the range
       ``[staff*min_person, staff*max_person + min_person]``; for absent units
       the upper bound is ``req_orgunit*max_person``. For each present derived
       unit (1, 10, 11, 5, 6, 7, 8, 4, 9), its range must overlap the load
       range implied by the ranges of the units it depends on.

    Parameters
    ----------
    nodes : sequence of int
        Node-presence vector indexed by node type id (> 0 means present).
    staff : numpy.ndarray
        Head count per node type, same indexing as `nodes`.
    logging : bool
        Print the reason for the first failing check.

    Returns
    -------
    bool
    """
    for unit_id in node_type_dict:
        # A non-mandatory unit that exists has too few people.
        if (not node_type_dict[unit_id]['status'] == 1 and
            nodes[unit_id] > 0 and staff[unit_id] < min_orgunit):
            if logging:
                print(f"Node {unit_id} {node_type_dict[unit_id]['title']} doesn't meet the requirements: staff = {staff[unit_id]}")
            return False
        # A unit that does not exist has staff.
        if nodes[unit_id] == 0 and staff[unit_id] > 0:
            if logging:
                print(f"Non-existing Node {unit_id} {node_type_dict[unit_id]['title']} has staff: staff = {staff[unit_id]}")
            return False

    # Capacity range of every unit, derived from its head count.
    load_min = staff * min_person
    load_max = staff * max_person + min_person
    max_no_unit = req_orgunit * max_person
    for unit_id in node_type_dict:
        if nodes[unit_id] == 0:
            load_max[unit_id] = max_no_unit

    v_min = _expected_load(nodes, load_min)
    v_max = _expected_load(nodes, load_max)

    # Same order in which generate_values derives the loads.
    for unit_id in (1, 10, 11, 5, 6, 7, 8, 4, 9):
        if not _capacity_ok(nodes, load_min, load_max, v_min, v_max, unit_id, logging):
            return False

    return True

def generate_dataset(dataset_size):
    """Generate `dataset_size` feasible structures and stack them into arrays.

    Returns
    -------
    tuple of numpy.ndarray
        (nodes, edges, staff), each with a leading sample axis of length
        `dataset_size`.
    """
    samples = [generate_parametrized_model() for _ in range(dataset_size)]
    nodes = np.stack([sample[0] for sample in samples], axis=0)
    edges = np.stack([sample[1] for sample in samples], axis=0)
    staff = np.stack([sample[2] for sample in samples], axis=0)
    return nodes, edges, staff

def create_dataset(dataset_size, dataset_dir, validation=0.1, test=0.1):
    """Generate a dataset, split it randomly and save it to `dataset_dir`.

    Writes ``data_nodes.npy``, ``data_edges.npy``, ``data_staff.npy`` and a
    pickled ``data_meta.pkl``. The meta file holds the train/validation/test
    index arrays, their counts, zeroed counters and the graph size constants.

    Parameters
    ----------
    dataset_size : int
        Number of samples to generate.
    dataset_dir : str
        Output directory; created if it does not exist.
    validation, test : float
        Fractions of samples assigned to the validation and test splits; the
        remainder goes to train.

    Raises
    ------
    ValueError
        If the validation and test fractions together exceed the dataset.
    """
    validation_count = int(validation * dataset_size)
    test_count = int(test * dataset_size)
    train_count = dataset_size - validation_count - test_count
    if train_count < 0:
        raise ValueError(
            f"validation ({validation}) and test ({test}) fractions exceed the whole dataset")

    # Create the output directory before the slow generation step so that a
    # missing directory does not throw away all generated samples.
    if dataset_dir:
        os.makedirs(dataset_dir, exist_ok=True)

    nodes, edges, staff = generate_dataset(dataset_size)

    all_idx = np.random.permutation(dataset_size)
    train_idx = all_idx[:train_count]
    validation_idx = all_idx[train_count:train_count + validation_count]
    test_idx = all_idx[train_count + validation_count:]

    np.save(os.path.join(dataset_dir, 'data_nodes.npy'), nodes)
    np.save(os.path.join(dataset_dir, 'data_edges.npy'), edges)
    np.save(os.path.join(dataset_dir, 'data_staff.npy'), staff)

    meta = {'train_idx': train_idx,
            'train_count': train_count,
            'train_counter': 0,
            'validation_idx': validation_idx,
            'validation_count': validation_count,
            'validation_counter': 0,
            'test_idx': test_idx,
            'test_count': test_count,
            'test_counter': 0,

            'node_num_types': NODE_N_TYPES,
            'edge_num_types': EDGE_N_TYPES,
            'vertexes': MAX_NODES_PER_GRAPH,
            }
    with open(os.path.join(dataset_dir, 'data_meta.pkl'), 'wb') as f:
        pickle.dump(meta, f)

        
if __name__ == '__main__':
    # Seed both random sources used by this module (stdlib `random` in
    # shuffle_nodes, NumPy everywhere else) so the dataset is reproducible.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(1)

    create_dataset(10000, 'data', validation=0.1, test=0.1)
