Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""A script to create an augmented dataset."""
import os
import sys
import random
import argparse
import numpy as np
import organ.data.util
import organ.demo.structure
from organ.data.organization_structure_dataset import OrganizationStructureDataset # noqa E501
if __name__ == '__main__':
    # Command-line entry point: load a source dataset, generate `n`
    # augmented instances from it and store the result (together with
    # validation/test split sizes) in the destination directory.

    # Seed both the stdlib and the NumPy generators before anything else
    # so that repeated runs are reproducible.
    random.seed(1)
    np.random.seed(1)

    parser = argparse.ArgumentParser()
    # Dataset size
    parser.add_argument('n', type=int,
                        help='New instances to create.')
    # Source dataset
    parser.add_argument('source', type=str,
                        help='Source dataset directory.')
    # Destination directory (optional positional, falls back to 'data')
    parser.add_argument('destination', type=str, nargs='?',
                        default='data',
                        help='Destination directory.')
    # Potentially overwrite a dataset
    parser.add_argument('--force', action='store_true', default=False,
                        help='Force storing the augmented dataset in an '
                             'existing directory.')
    # Test size
    parser.add_argument('--test', type=int,
                        help='Test subset size.')
    # Validation set size
    parser.add_argument('--validation', type=int,
                        help='Validation subset size.')
    config = parser.parse_args()

    # When a split size is not given on the command line, fall back to 0.1.
    # NOTE(review): the CLI options are parsed as int while the fallback is
    # a float; presumably `save_dataset` treats a float as a fraction and
    # an int as an absolute count -- confirm against its implementation.
    test_size = config.test if config.test is not None else 0.1
    val_size = config.validation if config.validation is not None else 0.1

    if os.path.isdir(config.destination):
        # Refuse to write into an existing directory unless the user
        # explicitly asked for it. With --force the directory is reused
        # as-is; nothing is removed from it here.
        if not config.force:
            print('The destination directory exists and may contain a '
                  'dataset. Use --force flag to overwrite it.')
            sys.exit(1)
    else:
        os.makedirs(config.destination)

    # Load the dataset
    dataset = OrganizationStructureDataset(load_cond=True, load_params=True)
    dataset.load(config.source)

    # Make the components of the augmented dataset
    org_model = organ.demo.structure.DemoOrganizationStructureModel()
    nodes, edges, features, cond = organ.data.util.augment_dataset(
        dataset,
        config.n,
        org_model
    )

    # Store the augmented dataset
    organ.data.util.save_dataset(
        nodes,
        edges,
        features,
        cond,
        config.destination,
        org_model,
        val_size, test_size)