
import pickle
import os
import numpy as np

from datetime import datetime

class OrganizationStructureDataset:
    """Graph dataset stored as node/edge label arrays plus pickled split metadata.

    After :meth:`load`, an instance exposes:

    - ``nodes`` and ``edges``, loaded from ``data_nodes.npy`` and ``data_edges.npy``;
    - ``train_idx``, ``validation_idx`` and ``test_idx``, from ``data_meta.pkl``;
    - per-split counts (``*_count``);
    - batch cursors (``*_counter``), used by the ``next_*_batch`` methods.
    """

    def load(self, path, subset=1):
        """Load the dataset from ``path`` and optionally subsample each split.

        Parameters
        ----------
        path : str or os.PathLike
            Directory containing ``data_nodes.npy``, ``data_edges.npy`` and
            ``data_meta.pkl``.
        subset : float, optional
            Fraction of each split (train / validation / test) to keep. Indices
            are drawn at random without replacement, so even ``subset=1``
            returns every index but in a random order.

        Raises
        ------
        ValueError
            If ``subset`` is not within ``[0, 1]``.
        """
        if not 0 <= subset <= 1:
            raise ValueError('subset must be in [0, 1], got {!r}'.format(subset))

        self.nodes = np.load(os.path.join(path, 'data_nodes.npy'))
        self.edges = np.load(os.path.join(path, 'data_edges.npy'))

        # SECURITY: pickle.load can execute arbitrary code -- only load
        # data_meta.pkl files that come from a trusted source.
        with open(os.path.join(path, 'data_meta.pkl'), 'rb') as f:
            # Every entry of the pickled dict becomes an instance attribute; it
            # must provide at least train_idx, validation_idx and test_idx.
            self.__dict__.update(pickle.load(f))

        # Sampled in a fixed order (train, validation, test) so results are
        # reproducible for a given numpy global seed.
        self.train_idx = self._subsample(self.train_idx, subset)
        self.validation_idx = self._subsample(self.validation_idx, subset)
        self.test_idx = self._subsample(self.test_idx, subset)

        self.train_count = len(self.train_idx)
        self.validation_count = len(self.validation_idx)
        self.test_count = len(self.test_idx)

        # The index arrays were just replaced, so any cursor restored from the
        # pickle no longer refers to them; every split starts from the beginning.
        self.train_counter = 0
        self.validation_counter = 0
        self.test_counter = 0

        self.__len = self.train_count + self.validation_count + self.test_count

    @staticmethod
    def _subsample(idx, subset):
        """Return ``int(len(idx) * subset)`` random entries of ``idx``, without replacement."""
        return np.random.choice(idx, int(len(idx) * subset), replace=False)

    def matrices2graph(self, node_labels, edge_labels, strict=False):
        """
        Transforms matrix definition of a labeled graph into a graph instance.

        Currently, this function just glues inputs. In general, it can be used to transform it
        to some optimized representation.

        Parameters
        ----------
        node_labels : (nodes, )
            Numpy array with node types.
        edge_labels : (nodes, nodes, )
            2D numpy array with edge types.
        strict : bool, optional
            Currently unused.

        Returns
        -------
        tuple ((nodes,), (nodes, nodes))
            tuple representing a graph.

        """
        return node_labels, edge_labels

    def remove_seq2mol(self, seq, strict=False):
        """Decode a token sequence into a molecule through its SMILES string.

        NOTE(review): ``Chem`` (RDKit) is not imported in this module, and
        ``smiles_decoder_m`` is only available if ``data_meta.pkl`` provides it.
        This looks like a leftover from a molecular dataset; calling it raises
        ``NameError`` unless ``Chem`` is injected elsewhere -- confirm whether
        it is still needed.

        Parameters
        ----------
        seq : iterable of int
            Token ids; tokens equal to ``0`` are skipped (presumably padding).
        strict : bool, optional
            If True, sanitize the molecule and return ``None`` when that fails.

        Returns
        -------
        The object returned by ``Chem.MolFromSmiles``, or ``None`` when
        ``strict`` sanitization fails.
        """
        smiles = ''.join(self.smiles_decoder_m[e] for e in seq if e != 0)
        mol = Chem.MolFromSmiles(smiles)

        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Any sanitization failure marks the molecule invalid; catching
                # Exception (not everything) lets KeyboardInterrupt/SystemExit through.
                mol = None

        return mol

    def _next_batch(self, counter, count, idx, batch_size):
        """Return the next mini-batch of one split.

        Parameters
        ----------
        counter : int
            Current cursor into ``idx``.
        count : int
            Number of indices in the split.
        idx : numpy.ndarray
            Index array of the split; shuffled in place when an epoch wraps.
        batch_size : int or None
            Number of samples to return; ``None`` returns the whole split.

        Returns
        -------
        list
            ``[new_counter, nodes_batch, edges_batch]``.
        """
        if batch_size is None:
            return [counter] + [obj[idx] for obj in (self.nodes, self.edges)]

        # Wrap around (and reshuffle) only when the next batch would run past the
        # end of the split; a batch ending exactly at ``count`` is still served.
        if counter + batch_size > count:
            counter = 0
            np.random.shuffle(idx)

        window = idx[counter:counter + batch_size]
        output = [obj[window] for obj in (self.nodes, self.edges)]

        return [counter + batch_size] + output

    def next_train_batch(self, batch_size=None):
        """Return ``[nodes, edges]`` for the next training batch (all of it if ``batch_size`` is None)."""
        out = self._next_batch(counter=self.train_counter, count=self.train_count,
                               idx=self.train_idx, batch_size=batch_size)
        self.train_counter = out[0]

        return out[1:]

    def next_validation_batch(self, batch_size=None):
        """Return ``[nodes, edges]`` for the next validation batch (all of it if ``batch_size`` is None)."""
        out = self._next_batch(counter=self.validation_counter, count=self.validation_count,
                               idx=self.validation_idx, batch_size=batch_size)
        self.validation_counter = out[0]

        return out[1:]

    def next_test_batch(self, batch_size=None):
        """Return ``[nodes, edges]`` for the next test batch (all of it if ``batch_size`` is None)."""
        out = self._next_batch(counter=self.test_counter, count=self.test_count,
                               idx=self.test_idx, batch_size=batch_size)
        self.test_counter = out[0]

        return out[1:]

    @staticmethod
    def log(msg='', date=True):
        """Print ``msg``, prefixed with the current local time when ``date`` is true."""
        if date:
            print(datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ' ' + str(msg))
        else:
            print(str(msg))

    def __len__(self):
        """Total number of train + validation + test indices (available after :meth:`load`)."""
        return self.__len


if __name__ == '__main__':
    # Running this module directly does nothing; it only defines
    # OrganizationStructureDataset for import by other code.
    ...
