1
+ import argparse
2
+ import bz2
3
+ import json
4
+ import os
5
+ import pickle
6
+ import random
7
+ import tempfile
8
+ import urllib .request
9
+
10
+ import xgboost
11
+ from tornasole import SaveConfig
12
+ from tornasole .xgboost import TornasoleHook
13
+
14
+
15
def parse_args():
    """Parse the command-line options of the training script.

    Options cover the XGBoost booster parameters, the number of boosting
    rounds, the Tornasole output location and save frequency, and the paths
    of the training and validation data. ``--train`` and ``--validation``
    default to the ``SM_CHANNEL_TRAIN`` and ``SM_CHANNEL_VALIDATION``
    environment variables (``None`` when those are unset).

    Returns:
        argparse.Namespace: The parsed arguments, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # Booster hyperparameters.
    parser.add_argument("--max_depth", type=int, default=5)
    parser.add_argument("--eta", type=float, default=0.05)  # 0.2
    # gamma and min_child_weight are real-valued in XGBoost; parsing them as
    # float lets fractional values such as 0.5 through instead of having
    # argparse reject them. The integer defaults are left untouched.
    parser.add_argument("--gamma", type=float, default=4)
    parser.add_argument("--min_child_weight", type=float, default=6)
    parser.add_argument("--silent", type=int, default=0)
    parser.add_argument("--objective", type=str, default="multi:softmax")
    parser.add_argument("--num_class", type=int, default=10)
    parser.add_argument("--num_round", type=int, default=10)

    # Tornasole output settings.
    parser.add_argument("--tornasole_path", type=str, default=None)
    parser.add_argument("--tornasole_frequency", type=int, default=1)
    parser.add_argument(
        "--output_uri",
        type=str,
        default="/opt/ml/output/tensors",
        help="S3 URI of the bucket where tensor data will be stored.",
    )

    # Input data locations.
    parser.add_argument(
        "--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument(
        "--validation", type=str,
        default=os.environ.get("SM_CHANNEL_VALIDATION"))

    return parser.parse_args()
38
+
39
+
40
def load_mnist(train_split=0.8, seed=42, url=None):
    """Download MNIST in LIBSVM format and split it into two text files.

    The bz2 archive is fetched, decompressed and its rows are divided at
    random into a training file and a validation file.

    Args:
        train_split: Fraction of the rows written to the training file.
            Must satisfy ``0 < train_split <= 1``.
        seed: Seed of the shuffle that selects the training rows.
        url: Location of the bz2-compressed data. Defaults to the copy on
            the LIBSVM datasets page; any URL accepted by
            ``urllib.request.urlretrieve`` (including ``file://``) works.

    Returns:
        tuple[str, str]: Paths of the training and validation files. Both
        are temporary files created with ``delete=False``, so removing them
        is left to the caller.

    Raises:
        ValueError: If ``train_split`` is outside ``(0, 1]``.
    """
    if not (0 < train_split <= 1):
        raise ValueError("'train_split' must be between 0 and 1.")

    if url is None:
        url = ("https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/"
               "multiclass/mnist.bz2")

    # Reserve a path for the download and release the handle before
    # urlretrieve writes to that path.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as mnist_bz2:
        archive_path = mnist_bz2.name

    try:
        urllib.request.urlretrieve(url, archive_path)
        with bz2.open(archive_path, "r") as fin:
            content = fin.read().decode("utf-8")
    finally:
        # The archive is only an intermediate artifact; do not leak it.
        os.remove(archive_path)

    # One LIBSVM row per line.
    lines = content.strip().splitlines()
    n = len(lines)

    # A private generator yields the same permutation as seeding the global
    # one, without disturbing other users of the `random` module.
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    train_indices = set(indices[:int(n * train_split)])

    with tempfile.NamedTemporaryFile(mode="w", delete=False) as train_file, \
            tempfile.NamedTemporaryFile(mode="w", delete=False) as valid_file:
        for idx, line in enumerate(lines):
            target = train_file if idx in train_indices else valid_file
            target.write(line + "\n")

    return train_file.name, valid_file.name
68
+
69
+
70
def create_tornasole_hook(out_dir, train_data=None, validation_data=None, frequency=1):
    """Construct the TornasoleHook used as an XGBoost training callback.

    Args:
        out_dir: Output location handed to the hook.
        train_data: Training data forwarded to the hook (optional).
        validation_data: Validation data forwarded to the hook (optional).
        frequency: Value passed to ``SaveConfig`` as ``save_interval``.

    Returns:
        The configured ``TornasoleHook`` instance.
    """
    return TornasoleHook(
        out_dir=out_dir,
        save_config=SaveConfig(save_interval=frequency),
        train_data=train_data,
        validation_data=validation_data,
    )
80
+
81
+
82
def main():
    """Train an XGBoost model with a Tornasole hook attached.

    Reads the command-line arguments, loads the training and validation
    data (falling back to ``load_mnist()`` unless both ``--train`` and
    ``--validation`` are set), builds the hook and runs ``xgboost.train``
    with the hook registered as a callback.
    """
    args = parse_args()

    if args.train and args.validation:
        train, validation = args.train, args.validation
    else:
        # Either path is missing: use the downloaded MNIST split instead.
        train, validation = load_mnist()

    dtrain = xgboost.DMatrix(train)
    dval = xgboost.DMatrix(validation)

    watchlist = [(dtrain, "train"), (dval, "validation")]

    params = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        "silent": args.silent,
        "objective": args.objective,
        "num_class": args.num_class,
    }

    # Location the hook writes to: --tornasole_path takes precedence when it
    # is given, otherwise --output_uri is used.
    if args.tornasole_path is not None:
        output_uri = args.tornasole_path
    else:
        output_uri = args.output_uri

    hook = create_tornasole_hook(
        out_dir=output_uri,
        frequency=args.tornasole_frequency,
        train_data=dtrain,
        validation_data=dval,
    )

    # The returned booster is not used by this script, so it is not bound to
    # a name.
    xgboost.train(
        params=params,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        callbacks=[hook],
    )
123
+
124
+
125
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments