Skip to content

Commit 9899c44

Browse files
author
Edward J Kim
authored
Add xgboost classification script (aws#225)
* Add xgboost classification script * Update xgboost regression script
1 parent ae6dd0b commit 9899c44

File tree

2 files changed

+132
-3
lines changed

2 files changed

+132
-3
lines changed

examples/xgboost/scripts/xgboost_abalone_basic_hook_demo.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,15 @@ def load_abalone(train_split=0.8, seed=42):
6262
return train_file.name, valid_file.name
6363

6464

65-
def create_tornasole_hook(out_dir, shap_data=None, frequency=1):
65+
def create_tornasole_hook(out_dir, train_data=None, validation_data=None, frequency=1):
6666

6767
save_config = SaveConfig(save_interval=frequency)
6868
hook = TornasoleHook(
6969
out_dir=out_dir,
7070
save_config=save_config,
71-
shap_data=shap_data)
71+
train_data=train_data,
72+
validation_data=validation_data
73+
)
7274

7375
return hook
7476

@@ -105,7 +107,7 @@ def main():
105107
hook = create_tornasole_hook(
106108
out_dir=output_uri,
107109
frequency=args.tornasole_frequency,
108-
shap_data=dtrain)
110+
train_data=dtrain)
109111

110112
bst = xgboost.train(
111113
params=params,
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import argparse
2+
import bz2
3+
import json
4+
import os
5+
import pickle
6+
import random
7+
import tempfile
8+
import urllib.request
9+
10+
import xgboost
11+
from tornasole import SaveConfig
12+
from tornasole.xgboost import TornasoleHook
13+
14+
15+
def parse_args():
16+
17+
parser = argparse.ArgumentParser()
18+
19+
parser.add_argument("--max_depth", type=int, default=5)
20+
parser.add_argument("--eta", type=float, default=0.05) # 0.2
21+
parser.add_argument("--gamma", type=int, default=4)
22+
parser.add_argument("--min_child_weight", type=int, default=6)
23+
parser.add_argument("--silent", type=int, default=0)
24+
parser.add_argument("--objective", type=str, default="multi:softmax")
25+
parser.add_argument("--num_class", type=int, default=10)
26+
parser.add_argument("--num_round", type=int, default=10)
27+
parser.add_argument("--tornasole_path", type=str, default=None)
28+
parser.add_argument("--tornasole_frequency", type=int, default=1)
29+
parser.add_argument("--output_uri", type=str, default="/opt/ml/output/tensors",
30+
help="S3 URI of the bucket where tensor data will be stored.")
31+
32+
parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
33+
parser.add_argument('--validation', type=str, default=os.environ.get('SM_CHANNEL_VALIDATION'))
34+
35+
args = parser.parse_args()
36+
37+
return args
38+
39+
40+
def load_mnist(train_split=0.8, seed=42):
41+
42+
if not (0 < train_split <= 1):
43+
raise ValueError("'train_split' must be between 0 and 1.")
44+
45+
url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2"
46+
47+
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as mnist_bz2:
48+
urllib.request.urlretrieve(url, mnist_bz2.name)
49+
50+
with bz2.open(mnist_bz2.name, "r") as fin:
51+
content = fin.read().decode("utf-8")
52+
lines = content.strip().split('\n')
53+
n = sum(1 for line in lines)
54+
indices = list(range(n))
55+
random.seed(seed)
56+
random.shuffle(indices)
57+
train_indices = set(indices[:int(n * 0.8)])
58+
59+
with tempfile.NamedTemporaryFile(mode='w', delete=False) as train_file:
60+
with tempfile.NamedTemporaryFile(mode='w', delete=False) as valid_file:
61+
for idx, line in enumerate(lines):
62+
if idx in train_indices:
63+
train_file.write(line + '\n')
64+
else:
65+
valid_file.write(line + '\n')
66+
67+
return train_file.name, valid_file.name
68+
69+
70+
def create_tornasole_hook(out_dir, train_data=None, validation_data=None, frequency=1):
71+
72+
save_config = SaveConfig(save_interval=frequency)
73+
hook = TornasoleHook(
74+
out_dir=out_dir,
75+
save_config=save_config,
76+
train_data=train_data,
77+
validation_data=validation_data)
78+
79+
return hook
80+
81+
82+
def main():
83+
84+
args = parse_args()
85+
86+
if args.train and args.validation:
87+
train, validation = args.train, args.validation
88+
else:
89+
train, validation = load_mnist()
90+
91+
dtrain = xgboost.DMatrix(train)
92+
dval = xgboost.DMatrix(validation)
93+
94+
watchlist = [(dtrain, "train"), (dval, "validation")]
95+
96+
params = {
97+
"max_depth": args.max_depth,
98+
"eta": args.eta,
99+
"gamma": args.gamma,
100+
"min_child_weight": args.min_child_weight,
101+
"silent": args.silent,
102+
"objective": args.objective,
103+
"num_class": args.num_class}
104+
105+
# The output_uri is a the URI for the s3 bucket where the metrics will be
106+
# saved.
107+
output_uri = (
108+
args.tornasole_path if args.tornasole_path is not None
109+
else args.output_uri)
110+
111+
hook = create_tornasole_hook(
112+
out_dir=output_uri,
113+
frequency=args.tornasole_frequency,
114+
train_data=dtrain,
115+
validation_data=dval)
116+
117+
bst = xgboost.train(
118+
params=params,
119+
dtrain=dtrain,
120+
evals=watchlist,
121+
num_boost_round=args.num_round,
122+
callbacks=[hook])
123+
124+
125+
if __name__ == "__main__":
126+
127+
main()

0 commit comments

Comments
 (0)