ssl_framework.main.SelfTrainingClassifier.fit
- SelfTrainingClassifier.fit(X_labeled, y_labeled, X_unlabeled, X_val=None, y_val=None)
Fit the self-training classifier using semi-supervised learning.
This method iteratively trains the base model by:
  1. Training on the current labeled data
  2. Making predictions on the unlabeled data
  3. Selecting confident predictions using the selection strategy
  4. Integrating the new pseudo-labels using the integration strategy
  5. Repeating until a stopping criterion is met
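The five steps above can be sketched as a generic self-training loop. This is a minimal illustration under stated assumptions, not ssl_framework's implementation: the confidence-threshold selection, the append-style integration, and the names CentroidModel, self_train, threshold, and max_iter are all hypothetical. Two of the stop strings are borrowed from the stopping_reason_ attribute described below; "No unlabeled data left" is invented for the sketch. CentroidModel is a toy stand-in so the block runs with NumPy alone; any estimator exposing fit, predict_proba, and classes_ (for example a scikit-learn classifier) could take its place.

```python
import numpy as np


class CentroidModel:
    """Hypothetical toy base model: class probabilities from a softmax over
    negative Euclidean distances to per-class centroids."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.centroids_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict_proba(self, X):
        # Distance from every sample to every class centroid: (n_samples, n_classes).
        dist = np.linalg.norm(X[:, None, :] - self.centroids_[None, :, :], axis=2)
        logits = -(dist - dist.min(axis=1, keepdims=True))  # shift for stability
        p = np.exp(logits)
        return p / p.sum(axis=1, keepdims=True)


def self_train(model, X_l, y_l, X_u, threshold=0.9, max_iter=10):
    """Generic self-training loop (a sketch, not ssl_framework's implementation).

    Selection (assumed): keep predictions whose top-class probability >= threshold.
    Integration (assumed): append selected samples and pseudo-labels to the labeled set.
    """
    X_l, y_l, X_u = np.asarray(X_l, float), np.asarray(y_l), np.asarray(X_u, float)
    history = []
    for it in range(max_iter):
        model.fit(X_l, y_l)                        # 1. train on current labeled data
        if len(X_u) == 0:
            return model, X_l, y_l, history, "No unlabeled data left"
        proba = model.predict_proba(X_u)           # 2. predict on unlabeled data
        conf = proba.max(axis=1)
        mask = conf >= threshold                   # 3. select confident predictions
        history.append({
            "iteration": it,
            "labeled_data_count": len(X_l),
            "new_labels_count": int(mask.sum()),
            "average_confidence": float(conf[mask].mean()) if mask.any() else None,
        })
        if not mask.any():                         # nothing confident: converged
            return model, X_l, y_l, history, "Labeling convergence"
        pseudo = model.classes_[proba[mask].argmax(axis=1)]
        X_l = np.vstack([X_l, X_u[mask]])          # 4. integrate pseudo-labels
        y_l = np.concatenate([y_l, pseudo])
        X_u = X_u[~mask]                           # 5. repeat on the remainder
    model.fit(X_l, y_l)                            # refit on the final labeled set
    return model, X_l, y_l, history, "Maximum iterations reached"


X_labeled = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
y_labeled = np.array([0, 0, 1, 1])
X_unlabeled = np.array([[0.5, 0.5], [10.5, 10.5], [5, 5]])
model, X_new, y_new, history, reason = self_train(
    CentroidModel(), X_labeled, y_labeled, X_unlabeled)
print(reason, len(history), y_new.tolist())
# Labeling convergence 2 [0, 0, 1, 1, 0, 1]
```

On the same sample data used in the Examples section, the two points sitting on the class centroids are pseudo-labeled in the first round, while [5, 5] stays at roughly 0.80 confidence, below the assumed 0.9 threshold, so the second round adds nothing and the loop stops.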
- Parameters:
X_labeled (array-like of shape (n_labeled_samples, n_features)) – Initial labeled training data. Can be a NumPy array or a pandas DataFrame.
y_labeled (array-like of shape (n_labeled_samples,)) – Target values for the labeled data. Can be a NumPy array or a pandas Series.
X_unlabeled (array-like of shape (n_unlabeled_samples, n_features)) – Unlabeled training data to pseudo-label iteratively. Can be a NumPy array or a pandas DataFrame.
X_val (array-like of shape (n_val_samples, n_features), optional) – Validation data for early stopping. If provided together with y_val, enables early stopping when the validation score plateaus.
y_val (array-like of shape (n_val_samples,), optional) – Validation targets for early stopping.
- Returns:
self – Returns the fitted instance.
- Return type:
SelfTrainingClassifier
- classes_
The classes seen during fit.
- Type:
ndarray of shape (n_classes,)
- history_
Training history containing metrics for each iteration:
  - iteration: iteration number
  - labeled_data_count: number of labeled samples before adding new ones
  - new_labels_count: number of new pseudo-labels added
  - average_confidence: mean confidence of newly added samples
  - validation_score: validation score (if validation data provided)
  - stopping_reason: reason for stopping (if applicable)
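Because average_confidence is a per-iteration mean, an overall confidence across all pseudo-labels has to weight each iteration by its new_labels_count. A minimal sketch, assuming the history is a list of dicts keyed as listed above (the container type is not documented here); summarize_history is a hypothetical helper and the records are made up for illustration.

```python
def summarize_history(history):
    """Summarize a self-training history (hypothetical helper).

    Assumes `history` is a list of dicts with the documented keys;
    iterations that added no labels are skipped in the weighted mean.
    """
    added = [h for h in history if h.get("new_labels_count", 0) > 0]
    total_new = sum(h["new_labels_count"] for h in added)
    # Weight each per-iteration mean by the number of labels it covers.
    weighted = sum(h["average_confidence"] * h["new_labels_count"] for h in added)
    return {
        "iterations": len(history),
        "total_new_labels": total_new,
        "overall_average_confidence": weighted / total_new if total_new else None,
    }


# Made-up records for illustration only.
example = [
    {"iteration": 0, "labeled_data_count": 4, "new_labels_count": 2, "average_confidence": 0.9},
    {"iteration": 1, "labeled_data_count": 6, "new_labels_count": 1, "average_confidence": 0.6},
    {"iteration": 2, "labeled_data_count": 7, "new_labels_count": 0, "average_confidence": None},
]
print(summarize_history(example))
```

For these records the weighted confidence is (0.9 × 2 + 0.6 × 1) / 3 = 0.8, whereas a plain mean of the two per-iteration averages would give 0.75.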
- stopping_reason_
Reason why training stopped (e.g., “Maximum iterations reached”, “Early stopping: no improvement”, “Labeling convergence”).
- Type:
str
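The "Early stopping: no improvement" reason implies some plateau test on validation scores, but the exact criterion is not documented here. A minimal sketch under an assumed patience rule: stop when the best of the last patience scores does not beat the best earlier score by more than min_delta. The names early_stopping_reason, patience, and min_delta are hypothetical, and higher scores are assumed to be better.

```python
def early_stopping_reason(val_scores, patience=3, min_delta=1e-4):
    """Return a stop reason if validation scores have plateaued, else None.

    Hypothetical rule: stop when the best score from the last `patience`
    iterations is no better (by more than min_delta) than the best score
    seen before them. Higher scores are assumed to be better.
    """
    if len(val_scores) <= patience:
        return None  # not enough history to judge a plateau yet
    best_before = max(val_scores[:-patience])
    best_recent = max(val_scores[-patience:])
    if best_recent <= best_before + min_delta:
        return "Early stopping: no improvement"
    return None


# The score stops improving after the second iteration.
print(early_stopping_reason([0.70, 0.80, 0.80, 0.79, 0.80], patience=3))
# Early stopping: no improvement
```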
- feature_names_
Feature names if the input was a pandas DataFrame, None otherwise.
- Type:
list or None
Examples
>>> import numpy as np
>>> from sklearn.linear_model import LogisticRegression
>>> from ssl_framework.main import SelfTrainingClassifier
>>>
>>> # Create sample data
>>> X_labeled = np.array([[0, 0], [1, 1], [10, 10], [11, 11]])
>>> y_labeled = np.array([0, 0, 1, 1])
>>> X_unlabeled = np.array([[0.5, 0.5], [10.5, 10.5], [5, 5]])
>>>
>>> # Fit SSL classifier
>>> ssl_clf = SelfTrainingClassifier(LogisticRegression())
>>> ssl_clf.fit(X_labeled, y_labeled, X_unlabeled)
>>>
>>> # Check training progress
>>> print(f"Stopped due to: {ssl_clf.stopping_reason_}")
>>> print(f"Training iterations: {len(ssl_clf.history_)}")