import math
import os

from keras import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
import numpy as np
import pytest
import tensorflow as tf

from csrank.callbacks import EarlyStoppingWithWeights
from csrank.callbacks import LRScheduler
from csrank.tests.test_ranking import check_params_tunable

# Test id -> (callback class, constructor kwargs) used by test_callbacks.
callbacks_dict = {
    "EarlyStopping": (EarlyStoppingWithWeights, {"patience": 5, "min_delta": 5e-2}),
    "LRScheduler": (LRScheduler, {"epochs_drop": 5, "drop": 0.9}),
}


@pytest.fixture(scope="module")
def trivial_classification_problem():
    """Return a small, seeded, linearly separable binary classification problem.

    Returns
    -------
    x : ndarray of shape (200, 2)
        Standard-normal features.
    y_true : ndarray of shape (200,), dtype int64
        1 where sigmoid(x @ w) > 0.5 (i.e. x @ w > 0) for a random weight
        vector ``w``, else 0.
    """
    random_state = np.random.RandomState(123)
    x = random_state.randn(200, 2)
    w = random_state.rand(2)
    y = 1.0 / (1.0 + np.exp(-np.dot(x, w)))
    y_true = np.array(y > 0.5, dtype=np.int64)
    return x, y_true


def create_model():
    """Build and compile a small dense binary classifier.

    Returns
    -------
    model : keras.Sequential
        Three dense layers (10 relu, 5 relu, 1 sigmoid), compiled with SGD
        and binary cross-entropy.
    lr : float
        The initial learning rate given to SGD, so callers can compute the
        expected learning rate after scheduling.
    """
    lr = 0.015
    model = Sequential()
    model.add(Dense(10, activation="relu"))
    model.add(Dense(5, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    model.compile(optimizer=SGD(lr=lr), loss="binary_crossentropy")
    return model, lr


@pytest.mark.parametrize("name", list(callbacks_dict.keys()))
def test_callbacks(trivial_classification_problem, name):
    """Train a small model with the named callback and check its effect.

    For ``LRScheduler`` the optimizer's final learning rate must match a
    step decay ``lr * drop ** floor(epochs / epochs_drop)``; for
    ``EarlyStopping`` training must stop at a fixed epoch. Afterwards the
    callback is re-tuned and the new parameters are checked.
    """
    # tf.set_random_seed only exists in TF1; TF2 moved it to tf.random.set_seed.
    set_tf_seed = getattr(tf, "set_random_seed", None) or tf.random.set_seed
    set_tf_seed(0)
    # NOTE(review): keras chooses its backend when first imported (at the top
    # of this module), so this assignment likely has no effect here; kept to
    # preserve the original behavior.
    os.environ["KERAS_BACKEND"] = "tensorflow"
    np.random.seed(123)

    x, y = trivial_classification_problem
    epochs = 15
    model, init_lr = create_model()
    callback_cls, params = callbacks_dict[name]
    callback = callback_cls(**params)
    model.fit(x, y, epochs=epochs, callbacks=[callback], validation_split=0.1)

    rtol = 1e-2
    atol = 5e-4
    if name == "LRScheduler":
        # The test expects step decay: one multiplication by `drop` for every
        # completed block of `epochs_drop` epochs.
        epochs_drop, drop = params["epochs_drop"], params["drop"]
        step = math.floor(epochs / epochs_drop)
        expected_lr = init_lr * math.pow(drop, step)
        # Depending on the keras version the optimizer config stores the rate
        # under "learning_rate" or "lr".
        config = model.optimizer.get_config()
        key = "learning_rate" if "learning_rate" in config else "lr"
        learning_rate = config.get(key, 0.0)
        assert np.isclose(
            expected_lr, learning_rate, rtol=rtol, atol=atol, equal_nan=False
        )
    elif name == "EarlyStopping":
        assert callback.stopped_epoch == 6

    # This dict mixes parameters of both callback types. It assumes
    # set_tunable_parameters tolerates keys it does not use (TODO confirm).
    tunable_params = {
        "epochs_drop": 100,
        "drop": 0.5,
        "patience": 10,
        "min_delta": 1e-5,
    }
    callback.set_tunable_parameters(**tunable_params)
    check_params_tunable(callback, tunable_params, rtol, atol)