""" Utilities for training kmeans model. Author * Pooneh Mousavi 2023 """ import os from tqdm.contrib import tqdm from speechbrain.utils.logger import get_logger try: from sklearn.cluster import MiniBatchKMeans except ImportError: err_msg = "The optional dependency sklearn is needed to use this module\n" err_msg += "Cannot import sklearn.cluster.MiniBatchKMeans to use KMeans/\n" err_msg += "Please follow the instructions below\n" err_msg += "=============================\n" err_msg += "pip install -U scikit-learn\n" raise ImportError(err_msg) import joblib logger = get_logger(__name__) def accumulate_and_extract_features( batch, features_list, ssl_model, ssl_layer_num, device ): """Extract features (output of SSL model) and acculamte them on cpu to be used for clustering. Arguments --------- batch : tensor Single batch of data. features_list : list accumulate features list. ssl_model : torch.nn.Module SSL-model used to extract features used for clustering. ssl_layer_num : int specify output of which layer of the ssl_model should be used. device : str `cpu` or `cuda` device. """ batch = batch.to(device) wavs, wav_lens = batch.sig wavs, wav_lens = ( wavs.to(device), wav_lens.to(device), ) feats = ssl_model(wavs, wav_lens)[ssl_layer_num].flatten(end_dim=-2) features_list.extend(feats.to("cpu").detach().numpy()) def fetch_kmeans_model( n_clusters, init, max_iter, batch_size, tol, max_no_improvement, n_init, reassignment_ratio, random_state, checkpoint_path, ): """Return a k-means clustering model with specified parameters. Arguments --------- n_clusters : MiniBatchKMeans The number of clusters to form as well as the number of centroids to generate. init : int Method for initialization: {'k-means++'', ''random''} max_iter : int Maximum number of iterations over the complete dataset before stopping independently of any early stopping criterion heuristics. batch_size : int Size of the mini batches. tol : float Control early stopping based on the relative center changes as measured by a smoothed, variance-normalized of the mean center squared position changes. max_no_improvement :int Control early stopping based on the consecutive number of mini batches that does not yield an improvement on the smoothed inertia. n_init : int Number of random initializations that are tried reassignment_ratio : float Control the fraction of the maximum number of counts for a center to be reassigned. random_state :int Determines random number generation for centroid initialization and random reassignment. checkpoint_path : str Path to saved model. Returns ------- MiniBatchKMeans a k-means clustering model with specified parameters. """ if os.path.exists(checkpoint_path): logger.info(f"The checkpoint is loaded from {checkpoint_path}.") return joblib.load(checkpoint_path) logger.info( f"No checkpoint is found at {checkpoint_path}. New model is initialized for training." ) return MiniBatchKMeans( n_clusters=n_clusters, init=init, max_iter=max_iter, batch_size=batch_size, tol=tol, max_no_improvement=max_no_improvement, n_init=n_init, reassignment_ratio=reassignment_ratio, random_state=random_state, verbose=1, compute_labels=True, init_size=None, ) def process_chunks(data, chunk_size, model): """Process data in chunks of a specified size. Arguments --------- data : list The list of integers to be processed. chunk_size : int The size of each chunk. model : MiniBatchKMeans The initial kmeans model for training. """ for i in range(0, len(data), chunk_size): chunk = data[i : i + chunk_size] # Skip processing if the chunk size is smaller than chunk_size if len(chunk) < chunk_size: break model = model.partial_fit(chunk) def train( model, train_set, ssl_model, save_path, ssl_layer_num, kmeans_batch_size=1000, device="cpu", checkpoint_interval=10, ): """Train a Kmeans model . Arguments --------- model : MiniBatchKMeans The initial kmeans model for training. train_set : Dataloader Batches of tarining data. ssl_model : torch.nn.Module SSL-model used to extract features used for clustering. save_path: string Path to save intra-checkpoints and dataloader. ssl_layer_num : int Specify output of which layer of the ssl_model should be used. kmeans_batch_size : int Size of the mini batches. device : str `cpu` or `cuda` device. checkpoint_interval: int Determine at which iterations to save the checkpoints. """ logger.info("Start training kmeans model.") features_list = [] iteration = 0 with tqdm( train_set, dynamic_ncols=True, ) as t: for batch in t: # extract features from the SSL model accumulate_and_extract_features( batch, features_list, ssl_model, ssl_layer_num, device ) # train a kmeans model on a single batch if features_list reaches the kmeans_batch_size. if len(features_list) >= kmeans_batch_size: process_chunks(features_list, kmeans_batch_size, model) iteration += 1 features_list = [] if (iteration + 1) % checkpoint_interval == 0: logger.info( f"Saving intra-checkpoints for iteration {iteration}." ) train_set._speechbrain_save( os.path.join(save_path, "dataloader-TRAIN.ckpt") ) checkpoint_path = os.path.join( save_path, f"kmeans-cluster-{model.n_clusters}-layer-{ssl_layer_num}.pt", ) save_model(model, checkpoint_path) if len(features_list) >= kmeans_batch_size: process_chunks(features_list, kmeans_batch_size, model) def save_model(model, checkpoint_path): """Save a Kmeans model . Arguments --------- model : MiniBatchKMeans The kmeans model to be saved. checkpoint_path : str Path to save the model. """ joblib.dump(model, open(checkpoint_path, "wb"))
Memory