Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit cefdc44

Browse files
authored
fix compatibility with scikit-learn by dropping dependency from sklearn.testing (#158)
1 parent 34967a7 commit cefdc44

16 files changed

+202
-243
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
import numpy as np
22

3-
from sklearn.utils.testing import assert_equal
4-
53
from lightning.impl.datasets.samples_generator import make_nn_regression
64

75

86
def test_make_nn_regression():
97
X, y, w = make_nn_regression(n_samples=10, n_features=50, n_informative=5)
10-
assert_equal(X.shape[0], 10)
11-
assert_equal(X.shape[1], 50)
12-
assert_equal(y.shape[0], 10)
13-
assert_equal(w.shape[0], 50)
14-
assert_equal(np.sum(X.data != 0), 10 * 5)
8+
assert X.shape[0] == 10
9+
assert X.shape[1] == 50
10+
assert y.shape[0] == 10
11+
assert w.shape[0] == 50
12+
assert np.sum(X.data != 0) == 10 * 5
1513

1614
X, y, w = make_nn_regression(n_samples=10, n_features=50, n_informative=50)
17-
assert_equal(np.sum(X.data != 0), 10 * 50)
15+
assert np.sum(X.data != 0) == 10 * 50
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
import pickle
22
import numpy as np
3-
from numpy.testing import (assert_almost_equal, assert_array_equal,
4-
assert_equal)
3+
54
from lightning.impl.randomkit import RandomState
65
from six.moves import xrange
76

87

98
def test_randint():
109
rs = RandomState(seed=0)
1110
vals = [rs.randint(10) for t in xrange(10000)]
12-
assert_almost_equal(np.mean(vals), 5.018)
11+
np.testing.assert_almost_equal(np.mean(vals), 5.018)
1312

1413

1514
def test_shuffle():
1615
ind = np.arange(10)
1716
rs = RandomState(seed=0)
1817
rs.shuffle(ind)
19-
assert_array_equal(ind, [2, 8, 4, 9, 1, 6, 7, 3, 0, 5])
18+
np.testing.assert_array_equal(ind, [2, 8, 4, 9, 1, 6, 7, 3, 0, 5])
2019

2120

2221
def test_random_state_pickle():
@@ -25,4 +24,4 @@ def test_random_state_pickle():
2524
pickle_rs = pickle.dumps(rs)
2625
pickle_rs = pickle.loads(pickle_rs)
2726
pickle_random_integer = pickle_rs.randint(5)
28-
assert_equal(random_integer, pickle_random_integer)
27+
assert random_integer == pickle_random_integer

lightning/impl/tests/test_adagrad.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import numpy as np
22

33
from sklearn.datasets import load_iris
4-
from sklearn.utils.testing import assert_equal
5-
from sklearn.utils.testing import assert_almost_equal
64

75
from lightning.classification import AdaGradClassifier
86
from lightning.regression import AdaGradRegressor
@@ -20,44 +18,44 @@ def test_adagrad_elastic_hinge():
2018
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, n_iter=10, random_state=0)
2119
clf.fit(X_bin, y_bin)
2220
assert not hasattr(clf, "predict_proba")
23-
assert_equal(clf.score(X_bin, y_bin), 1.0)
21+
assert clf.score(X_bin, y_bin) == 1.0
2422

2523

2624
def test_adagrad_elastic_smooth_hinge():
2725
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, loss="smooth_hinge",
2826
n_iter=10, random_state=0)
2927
clf.fit(X_bin, y_bin)
3028
assert not hasattr(clf, "predict_proba")
31-
assert_equal(clf.score(X_bin, y_bin), 1.0)
29+
assert clf.score(X_bin, y_bin) == 1.0
3230

3331

3432
def test_adagrad_elastic_log():
3533
clf = AdaGradClassifier(alpha=0.1, l1_ratio=0.85, loss="log", n_iter=10,
3634
random_state=0)
3735
clf.fit(X_bin, y_bin)
38-
assert_equal(clf.score(X_bin, y_bin), 1.0)
36+
assert clf.score(X_bin, y_bin) == 1.0
3937
check_predict_proba(clf, X_bin)
4038

4139

4240
def test_adagrad_hinge_multiclass():
4341
clf = AdaGradClassifier(alpha=1e-2, n_iter=100, loss="hinge", random_state=0)
4442
clf.fit(X, y)
4543
assert not hasattr(clf, "predict_proba")
46-
assert_almost_equal(clf.score(X, y), 0.940, 3)
44+
np.testing.assert_almost_equal(clf.score(X, y), 0.940, 3)
4745

4846

4947
def test_adagrad_classes_binary():
5048
clf = AdaGradClassifier()
5149
assert not hasattr(clf, 'classes_')
5250
clf.fit(X_bin, y_bin)
53-
assert_equal(list(clf.classes_), [-1, 1])
51+
assert list(clf.classes_) == [-1, 1]
5452

5553

5654
def test_adagrad_classes_multiclass():
5755
clf = AdaGradClassifier()
5856
assert not hasattr(clf, 'classes_')
5957
clf.fit(X, y)
60-
assert_equal(list(clf.classes_), [0, 1, 2])
58+
assert list(clf.classes_) == [0, 1, 2]
6159

6260

6361
def test_adagrad_callback():
@@ -80,12 +78,12 @@ def __call__(self, clf, t):
8078
clf = AdaGradClassifier(alpha=0.5, l1_ratio=0.85, n_iter=10,
8179
callback=cb, random_state=0)
8280
clf.fit(X_bin, y_bin)
83-
assert_equal(cb.acc[-1], 1.0)
81+
assert cb.acc[-1] == 1.0
8482

8583

8684
def test_adagrad_regression():
8785
for loss in ("squared", "absolute"):
8886
reg = AdaGradRegressor(loss=loss)
8987
reg.fit(X_bin, y_bin)
9088
y_pred = np.sign(reg.predict(X_bin))
91-
assert_equal(np.mean(y_bin == y_pred), 1.0)
89+
assert np.mean(y_bin == y_pred) == 1.0

lightning/impl/tests/test_dataset.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
import numpy as np
33
import scipy.sparse as sp
44

5-
from sklearn.utils.testing import assert_array_equal
6-
from sklearn.utils.testing import assert_array_almost_equal
7-
from sklearn.utils.testing import assert_equal
85
from six.moves import xrange
96

107
from sklearn.datasets import make_classification
@@ -38,34 +35,34 @@ def test_contiguous_get_row():
3835
ind = np.arange(X.shape[1])
3936
for i in xrange(X.shape[0]):
4037
indices, data, n_nz = cds.get_row(i)
41-
assert_array_equal(indices, ind)
42-
assert_array_equal(data, X[i])
43-
assert_equal(n_nz, X.shape[1])
38+
np.testing.assert_array_equal(indices, ind)
39+
np.testing.assert_array_equal(data, X[i])
40+
assert n_nz == X.shape[1]
4441

4542

4643
def test_csr_get_row():
4744
for i in xrange(X.shape[0]):
4845
indices, data, n_nz = csr_ds.get_row(i)
4946
for jj in xrange(n_nz):
5047
j = indices[jj]
51-
assert_equal(X[i, j], data[jj])
48+
assert X[i, j] == data[jj]
5249

5350

5451
def test_fortran_get_column():
5552
ind = np.arange(X.shape[0])
5653
for j in xrange(X.shape[1]):
5754
indices, data, n_nz = fds.get_column(j)
58-
assert_array_equal(indices, ind)
59-
assert_array_equal(data, X[:, j])
60-
assert_equal(n_nz, X.shape[0])
55+
np.testing.assert_array_equal(indices, ind)
56+
np.testing.assert_array_equal(data, X[:, j])
57+
assert n_nz == X.shape[0]
6158

6259

6360
def test_csc_get_column():
6461
for j in xrange(X.shape[1]):
6562
indices, data, n_nz = csc_ds.get_column(j)
6663
for ii in xrange(n_nz):
6764
i = indices[ii]
68-
assert_equal(X[i, j], data[ii])
65+
assert X[i, j] == data[ii]
6966

7067

7168
def test_picklable_datasets():
@@ -74,5 +71,5 @@ def test_picklable_datasets():
7471
for dataset in [cds, csr_ds, fds, csc_ds]:
7572
pds = pickle.dumps(dataset)
7673
dataset = pickle.loads(pds)
77-
assert_equal(dataset.get_n_samples(), X.shape[0])
78-
assert_equal(dataset.get_n_features(), X.shape[1])
74+
assert dataset.get_n_samples() == X.shape[0]
75+
assert dataset.get_n_features() == X.shape[1]

lightning/impl/tests/test_dual_cd.py

+12-16
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55
from sklearn.datasets import make_regression
66
from six.moves import xrange
77

8-
from sklearn.utils.testing import assert_equal
9-
from sklearn.utils.testing import assert_greater
10-
from sklearn.utils.testing import assert_array_almost_equal
11-
128
from lightning.impl.datasets.samples_generator import make_classification
139
from lightning.impl.dual_cd import LinearSVC
1410
from lightning.impl.dual_cd import LinearSVR
@@ -40,16 +36,16 @@ def test_sparse_dot():
4036
K2[i, j] = sparse_dot(ds, i, j)
4137
K2[j, i] = K[i, j]
4238

43-
assert_array_almost_equal(K, K2)
39+
np.testing.assert_array_almost_equal(K, K2)
4440

4541

4642
def test_fit_linear_binary():
4743
for data in (bin_dense, bin_csr):
4844
for loss in ("l1", "l2"):
4945
clf = LinearSVC(loss=loss, random_state=0, max_iter=10)
5046
clf.fit(data, bin_target)
51-
assert_equal(list(clf.classes_), [0, 1])
52-
assert_equal(clf.score(data, bin_target), 1.0)
47+
assert list(clf.classes_) == [0, 1]
48+
assert clf.score(data, bin_target) == 1.0
5349
y_pred = clf.decision_function(data).ravel()
5450

5551

@@ -59,17 +55,17 @@ def test_fit_linear_binary_auc():
5955
clf = LinearSVC(loss=loss, criterion="auc", random_state=0,
6056
max_iter=25)
6157
clf.fit(data, bin_target)
62-
assert_equal(clf.score(data, bin_target), 1.0)
58+
assert clf.score(data, bin_target) == 1.0
6359

6460

6561
def test_fit_linear_multi():
6662
for data in (mult_dense, mult_sparse):
6763
clf = LinearSVC(random_state=0)
6864
clf.fit(data, mult_target)
69-
assert_equal(list(clf.classes_), [0, 1, 2])
65+
assert list(clf.classes_) == [0, 1, 2]
7066
y_pred = clf.predict(data)
7167
acc = np.mean(y_pred == mult_target)
72-
assert_greater(acc, 0.85)
68+
assert acc > 0.85
7369

7470

7571
def test_warm_start():
@@ -79,32 +75,32 @@ def test_warm_start():
7975

8076
clf.fit(bin_dense, bin_target)
8177
acc = clf.score(bin_dense, bin_target)
82-
assert_greater(acc, 0.99)
78+
assert acc > 0.99
8379

8480

8581
def test_linear_svr():
8682
reg = LinearSVR(random_state=0)
8783
reg.fit(reg_dense, reg_target)
88-
assert_greater(reg.score(reg_dense, reg_target), 0.99)
84+
assert reg.score(reg_dense, reg_target) > 0.99
8985

9086

9187
def test_linear_svr_fit_intercept():
9288
reg = LinearSVR(random_state=0, fit_intercept=True)
9389
reg.fit(reg_dense, reg_target)
94-
assert_greater(reg.score(reg_dense, reg_target), 0.99)
90+
assert reg.score(reg_dense, reg_target) > 0.99
9591

9692

9793
def test_linear_svr_l2():
9894
reg = LinearSVR(loss="l2", random_state=0)
9995
reg.fit(reg_dense, reg_target)
100-
assert_greater(reg.score(reg_dense, reg_target), 0.99)
96+
assert reg.score(reg_dense, reg_target) > 0.99
10197

10298

10399
def test_linear_svr_warm_start():
104100
reg = LinearSVR(C=1e-3, random_state=0, warm_start=True)
105101
reg.fit(reg_dense, reg_target)
106-
assert_greater(reg.score(reg_dense, reg_target), 0.96)
102+
assert reg.score(reg_dense, reg_target) > 0.96
107103

108104
reg.C = 1
109105
reg.fit(reg_dense, reg_target)
110-
assert_greater(reg.score(reg_dense, reg_target), 0.99)
106+
assert reg.score(reg_dense, reg_target) > 0.99

0 commit comments

Comments
 (0)