ssl_framework.main.SelfTrainingClassifier.fit

SelfTrainingClassifier.fit(X_labeled, y_labeled, X_unlabeled, X_val=None, y_val=None)

Fit the self-training classifier using semi-supervised learning.

This method iteratively trains the base model by:

  1. Training on the current labeled data
  2. Making predictions on the unlabeled data
  3. Selecting confident predictions using the selection strategy
  4. Integrating new pseudo-labels using the integration strategy
  5. Repeating until the stopping criteria are met
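The five steps above can be sketched as a small, dependency-free loop. This is an illustration only, not the library's implementation: the nearest-centroid base model, the fixed confidence `threshold`, the append-only integration, and the helper names `fit_centroids`, `predict_with_confidence`, and `self_train` are all assumptions made for the example. The stopping-reason strings mirror the examples documented for `stopping_reason_`, but the conditions that trigger them here are assumed.

```python
import math


def fit_centroids(X, y):
    """Stand-in base model (assumption): one centroid (mean point) per class."""
    groups = {}
    for row, label in zip(X, y):
        groups.setdefault(label, []).append(row)
    return {
        label: [sum(col) / len(rows) for col in zip(*rows)]
        for label, rows in groups.items()
    }


def predict_with_confidence(centroids, row):
    """Return (label, confidence) from a softmax over negative distances."""
    labels = sorted(centroids)
    dists = [math.dist(row, centroids[label]) for label in labels]
    nearest = min(dists)
    # Shift by the smallest distance so the largest weight is exactly 1.0.
    weights = [math.exp(nearest - d) for d in dists]
    best = weights.index(max(weights))
    return labels[best], weights[best] / sum(weights)


def self_train(X_labeled, y_labeled, X_unlabeled, threshold=0.95, max_iter=10):
    """Hypothetical self-training loop following steps 1-5."""
    X_lab, y_lab = list(X_labeled), list(y_labeled)
    pool = list(X_unlabeled)
    history = []
    reason = "Maximum iterations reached"
    for iteration in range(1, max_iter + 1):
        # Step 1: train on the current labeled data.
        centroids = fit_centroids(X_lab, y_lab)
        # Step 2: predict on the remaining unlabeled data.
        preds = [predict_with_confidence(centroids, row) for row in pool]
        # Step 3: select confident predictions (assumed threshold rule).
        chosen = {i for i, (_, conf) in enumerate(preds) if conf >= threshold}
        # Step 5: stop once nothing new can be labeled.
        if not chosen:
            reason = "Labeling convergence"
            break
        history.append({
            "iteration": iteration,
            "labeled_data_count": len(X_lab),  # count before adding new labels
            "new_labels_count": len(chosen),
            "average_confidence": sum(preds[i][1] for i in chosen) / len(chosen),
        })
        # Step 4: integrate the pseudo-labels (assumed append-only rule).
        for i in sorted(chosen):
            X_lab.append(pool[i])
            y_lab.append(preds[i][0])
        pool = [row for i, row in enumerate(pool) if i not in chosen]
    return fit_centroids(X_lab, y_lab), history, reason


X_labeled = [[0, 0], [1, 1], [10, 10], [11, 11]]
y_labeled = [0, 0, 1, 1]
X_unlabeled = [[0.5, 0.5], [10.5, 10.5], [5, 5]]

model, history, reason = self_train(X_labeled, y_labeled, X_unlabeled)
print(reason, len(history))
```

With this toy data the two points near the class centroids are pseudo-labeled in the first iteration, while the midpoint `[5, 5]` scores a confidence of about 0.80 and stays unlabeled at the assumed threshold of 0.95, so the loop stops on the second pass.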

Parameters:
  • X_labeled (array-like of shape (n_labeled_samples, n_features)) – Initial labeled training data. Can be a NumPy array or a pandas DataFrame.

  • y_labeled (array-like of shape (n_labeled_samples,)) – Target values for the labeled data. Can be a NumPy array or a pandas Series.

  • X_unlabeled (array-like of shape (n_unlabeled_samples, n_features)) – Unlabeled training data to pseudo-label iteratively. Can be a NumPy array or a pandas DataFrame.

  • X_val (array-like of shape (n_val_samples, n_features), optional) – Validation data for early stopping. If provided together with y_val, training stops early when the validation score plateaus.

  • y_val (array-like of shape (n_val_samples,), optional) – Validation targets for early stopping.
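One common way to detect a validation-score plateau is a patience rule, sketched below. The function name `should_stop_early` and the `patience` and `min_delta` parameters are hypothetical; the documentation above does not state how this method defines a plateau, so treat this as one plausible reading rather than the actual criterion.

```python
def should_stop_early(validation_scores, patience=3, min_delta=0.0):
    """Return True if the last `patience` scores did not beat the earlier best.

    `patience` and `min_delta` are assumed knobs for this sketch.
    """
    if len(validation_scores) <= patience:
        # Not enough iterations yet to judge a plateau.
        return False
    best_before = max(validation_scores[:-patience])
    recent_best = max(validation_scores[-patience:])
    return recent_best <= best_before + min_delta


# Illustrative scores only, one per self-training iteration.
plateaued = [0.80, 0.84, 0.86, 0.86, 0.85, 0.86]
improving = [0.80, 0.84, 0.86, 0.88]

print(should_stop_early(plateaued))
print(should_stop_early(improving))
```

A larger `patience` tolerates more flat iterations before stopping, which can help when pseudo-labels cause short-lived dips in the validation score.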

Returns:

self – Returns the fitted instance.

Return type:

SelfTrainingClassifier

classes_

The classes seen during fit.

Type:

ndarray of shape (n_classes,)

history_

Training history containing metrics for each iteration:

  • iteration: iteration number
  • labeled_data_count: number of labeled samples before adding new ones
  • new_labels_count: number of new pseudo-labels added
  • average_confidence: mean confidence of newly added samples
  • validation_score: validation score (if validation data provided)
  • stopping_reason: reason for stopping (if applicable)

Type:

list of dict
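Because the history is a plain list of dicts, it can be summarized with ordinary Python. The sketch below uses a hand-written list with made-up values in place of a real `history_`, and the helper name `summarize_history` is hypothetical; only the key names come from the description above.

```python
def summarize_history(history):
    """Aggregate per-iteration entries into a few totals."""
    total_new = sum(entry["new_labels_count"] for entry in history)
    last = history[-1]
    # labeled_data_count is recorded before the new labels are added,
    # so the final size is the last count plus the last batch.
    final_labeled = last["labeled_data_count"] + last["new_labels_count"]
    # Weight each iteration's mean confidence by how many labels it added.
    weighted_conf = sum(
        entry["average_confidence"] * entry["new_labels_count"] for entry in history
    )
    mean_conf = weighted_conf / total_new if total_new else None
    return {
        "iterations": len(history),
        "total_new_labels": total_new,
        "final_labeled_count": final_labeled,
        "mean_confidence": mean_conf,
    }


# Made-up entries for illustration; a fitted classifier would supply history_.
example_history = [
    {"iteration": 1, "labeled_data_count": 4, "new_labels_count": 2,
     "average_confidence": 0.97},
    {"iteration": 2, "labeled_data_count": 6, "new_labels_count": 1,
     "average_confidence": 0.91},
]

summary = summarize_history(example_history)
print(summary)
```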

stopping_reason_

Reason why training stopped (e.g., “Maximum iterations reached”, “Early stopping: no improvement”, “Labeling convergence”).

Type:

str

feature_names_

Feature names if the input was a pandas DataFrame; None otherwise.

Type:

list or None

Examples

>>> import numpy as np
>>> from sklearn.linear_model import LogisticRegression
>>> from ssl_framework.main import SelfTrainingClassifier
>>>
>>> # Create sample data
>>> X_labeled = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
>>> y_labeled = np.array([0, 0, 1, 1])
>>> X_unlabeled = np.array([[0.5, 0.5], [10.5, 10.5], [5, 5]])
>>>
>>> # Fit SSL classifier
>>> ssl_clf = SelfTrainingClassifier(LogisticRegression())
>>> ssl_clf.fit(X_labeled, y_labeled, X_unlabeled)
>>>
>>> # Check training progress
>>> print(f"Stopped due to: {ssl_clf.stopping_reason_}")
>>> print(f"Training iterations: {len(ssl_clf.history_)}")