import numpy as np

import generate_dataset

class OrganizationMetrics(object):
    """Utility functions to calculate organization quality and validity metrics.

    Organizations are defined by a tuple:
      vector of node types
      matrix of edge types

    An organization may also be ``None``; every method here treats such an
    entry as invalid instead of failing on it.
    """

    @staticmethod
    def valid_lambda(org):
        """Check the validity of a single organization structure.

        A ``None`` organization is invalid. Otherwise the node vector
        ``org[0]`` must pass ``generate_dataset.check_nodes`` and the pair
        ``(org[0], org[1])`` must pass ``generate_dataset.check_relations``
        (the first element of its result is used as the verdict).

        Returns a truthy value when the organization is valid, a falsy one
        otherwise.
        """
        if org is None:
            return False
        # The relation check only runs when the node check succeeded.
        return generate_dataset.check_nodes(org[0]) and \
            generate_dataset.check_relations(org[0], org[1])[0]

    @staticmethod
    def edge_validness_scores(orgs):
        """Estimate edge validness for multiple organizations.

        Returns a float32 array with one entry per organization; ``None``
        organizations score 0.
        """
        scores = [
            generate_dataset.check_relations(org[0], org[1])[0] if org is not None else False
            for org in orgs
        ]
        return np.array(scores, dtype=np.float32)

    @staticmethod
    def node_validness_scores(orgs):
        """Estimate node validness for multiple organizations.

        Returns a float32 array with one entry per organization; ``None``
        organizations score 0.
        """
        scores = [
            generate_dataset.check_nodes(org[0]) if org is not None else False
            for org in orgs
        ]
        return np.array(scores, dtype=np.float32)

    @staticmethod
    def valid_scores(orgs):
        """Return a float32 array holding the ``valid_lambda`` verdict of each organization."""
        return np.array([OrganizationMetrics.valid_lambda(org) for org in orgs],
                        dtype=np.float32)

    @staticmethod
    def valid_filter(orgs):
        """Return the list of organizations that pass ``valid_lambda``, in input order."""
        return [org for org in orgs if OrganizationMetrics.valid_lambda(org)]

    @staticmethod
    def valid_total_score(orgs):
        """Return the mean of ``valid_scores(orgs)``.

        NumPy yields ``nan`` (with a RuntimeWarning) when ``orgs`` is empty.
        """
        return OrganizationMetrics.valid_scores(orgs).mean()

    @staticmethod
    def sample_organization_metric(orgs, norm=False):
        """Return a per-organization validity array that keeps ``None`` placeholders.

        Each non-``None`` organization is replaced by its ``valid_lambda``
        verdict; ``None`` entries stay ``None`` so callers can filter them out.
        ``norm`` is accepted for interface compatibility and is currently unused.
        """
        scores = [
            None if org is None else OrganizationMetrics.valid_lambda(org)
            for org in orgs
        ]
        # No explicit dtype: NumPy infers it from the collected entries.
        return np.array(scores)
    
def all_scores(orgs, data, norm=False, reconstruction=False):
    """Collect per-structure and per-batch validity scores for ``orgs``.

    ``norm`` is forwarded to ``OrganizationMetrics.sample_organization_metric``.
    ``data`` and ``reconstruction`` are not used by this function.

    Returns a pair ``(m0, m1)``:
      m0 -- dict mapping a score name to a list with one value per structure
            (``None`` entries removed)
      m1 -- dict mapping a score name to a single batch-level value, scaled
            by 100
    """
    # One value per structure; entries that are None are dropped.
    sample_scores = OrganizationMetrics.sample_organization_metric(orgs, norm=norm)
    m0 = {'Sample score': [score for score in sample_scores if score is not None]}

    # One value per batch (used, e.g., in batch log reporting).
    valid_pct = OrganizationMetrics.valid_total_score(orgs) * 100
    node_pct = np.mean(OrganizationMetrics.node_validness_scores(orgs)) * 100
    edge_pct = np.mean(OrganizationMetrics.edge_validness_scores(orgs)) * 100
    m1 = {
        'valid score': valid_pct,
        'node score': node_pct,
        'edge score': edge_pct,
    }

    return m0, m1
