from typing import Optional

import numpy as np
from numpy.typing import ArrayLike
from scipy.sparse import issparse
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import (
    _approximate_mode,
    _safe_indexing,
    check_array,
    check_consistent_length,
)


def _conditional_shuffle(nbrs: ArrayLike, replace: bool = False, seed=None) -> ArrayLike:
    """Compute a permutation of neighbors with restrictions.

    Parameters
    ----------
    nbrs : ArrayLike of shape (n_samples, k)
        The k-nearest-neighbors for each sample index. Each row corresponds to the
        original sample. Each element corresponds to another sample index that is deemed
        as the k-nearest neighbors with respect to the original sample.
    replace : bool, optional
        Whether or not to allow replacement of samples, by default False.
    seed : int, optional
        Random seed, by default None. Any value accepted by
        ``np.random.default_rng`` (including an existing ``Generator``) works.

    Returns
    -------
    restricted_perm : ArrayLike of shape (n_samples)
        The final permutation order of the sample indices. There may be
        repeating samples. See Notes for details.

    Notes
    -----
    Restricted permutation goes through random samples and looks at the k-nearest
    neighbors (columns of ``nbrs``) and shuffles the closest neighbor index only if it
    has not been used to permute another sample. If it has been, then the algorithm
    looks at the next nearest-neighbor and so on. If all k-nearest neighbors of a
    sample has been checked, then a random neighbor is chosen. In this manner, the
    algorithm tries to perform permutation without replacement, but if necessary,
    will choose a repeating neighbor sample.
    """
    n_samples, k_dims = nbrs.shape
    rng = np.random.default_rng(seed=seed)

    # initialize the final permutation order
    restricted_perm = np.zeros((n_samples,), dtype=np.intp)

    # generate a random order of samples to go through
    random_order = rng.permutation(n_samples)

    # keep track of values we have already used
    used = set()

    # go through the random order
    for idx in random_order:
        if replace:
            possible_nbrs = nbrs[idx, :]
            restricted_perm[idx] = rng.choice(possible_nbrs, size=1).squeeze()
        else:
            m = 0
            use_idx = nbrs[idx, m]

            # if the current nbr is already used, continue incrementing
            # until we have either found a new sample to use, or if
            # we have reached the maximum number of neighbors to consider
            while (use_idx in used) and (m < k_dims - 1):
                m += 1
                use_idx = nbrs[idx, m]

            # check whether or not we have exhaustively checked all kNN.
            # BUGFIX: the loop above stops at m == k_dims - 1 (m can never
            # reach k_dims), so the exhausted case is ``m == k_dims - 1``
            # with the last candidate still used; the previous comparison
            # against ``k_dims`` made this branch unreachable.
            if use_idx in used and m == k_dims - 1:
                # XXX: Note this step is not in the original paper
                # choose a random neighbor to permute
                restricted_perm[idx] = rng.choice(nbrs[idx, :], size=1)
            else:
                # permute with the existing neighbor
                restricted_perm[idx] = use_idx
            used.add(use_idx)

    return restricted_perm
def conditional_resample(
    conditional_array: ArrayLike,
    *arrays,
    nn_estimator=None,
    replace: bool = True,
    replace_nbrs: bool = True,
    n_samples: Optional[int] = None,
    random_state: Optional[int] = None,
    stratify: Optional[ArrayLike] = None,
):
    """Conditionally resample arrays or sparse matrices in a consistent way.

    The default strategy implements one step of the bootstrapping procedure.
    Conditional resampling is a modification of the bootstrap technique that
    preserves the conditional distribution of the data. This is done by fitting
    a nearest neighbors estimator on the conditional array and then resampling
    the nearest neighbors of each sample.

    Parameters
    ----------
    conditional_array : array-like of shape (n_samples, n_features)
        The array, which we preserve the conditional distribution of.
    *arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Indexable data-structures can be arrays, lists, dataframes or scipy
        sparse matrices with consistent first dimension.
    nn_estimator : estimator object, default=None
        The nearest neighbors estimator to use. If None, then a
        :class:`sklearn.neighbors.NearestNeighbors` instance is used.
    replace : bool, default=True
        Implements resampling with replacement. If False, this will implement
        (sliced) random permutations. The replacement will take place at the level
        of the sample index.
    replace_nbrs : bool, default=True
        Implements resampling with replacement at the level of the nearest neighbors.
    n_samples : int, default=None
        Number of samples to generate. If left to None this is
        automatically set to the first dimension of the arrays.
        If replace is False it should not be larger than the length of arrays.
    random_state : int, RandomState instance or None, default=None
        Determines random number generation for shuffling
        the data.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.
    stratify : array-like of shape (n_samples,) or (n_samples, n_outputs), \
            default=None
        If not None, data is split in a stratified fashion, using this as
        the class labels.

    Returns
    -------
    resampled_arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Sequence of resampled copies of the collections. The original arrays
        are not impacted.
    """
    # ``n_samples`` is reassigned below to the array length, so keep the
    # user-requested sample count under a separate name first.
    max_n_samples = n_samples
    rng = np.random.default_rng(random_state)

    # nothing to resample
    if len(arrays) == 0:
        return None

    # determine the number of samples from the first indexable
    # (arrays without ``.shape`` fall back to ``len``)
    first = arrays[0]
    n_samples = first.shape[0] if hasattr(first, "shape") else len(first)

    if max_n_samples is None:
        max_n_samples = n_samples
    elif (max_n_samples > n_samples) and (not replace):
        # without replacement we cannot draw more than the available samples
        raise ValueError(
            f"Cannot sample {max_n_samples} out of arrays with dim "
            f"{n_samples} when replace is False"
        )

    check_consistent_length(conditional_array, *arrays)

    # fit nearest neighbors onto the conditional array
    if nn_estimator is None:
        nn_estimator = NearestNeighbors()
    nn_estimator.fit(conditional_array)

    # first draw the row indices to resample, either uniformly or stratified
    if stratify is None:
        if replace:
            # bootstrap draw with replacement
            indices = rng.integers(0, n_samples, size=(max_n_samples,))
        else:
            # sliced random permutation (no replacement)
            indices = np.arange(n_samples)
            rng.shuffle(indices)
            indices = indices[:max_n_samples]
    else:
        # Code adapted from StratifiedShuffleSplit()
        y = check_array(stratify, ensure_2d=False, dtype=None)
        if y.ndim == 2:
            # for multi-label y, map each distinct row to a string repr
            # using join because str(row) uses an ellipsis if len(row) > 1000
            y = np.array([" ".join(row.astype("str")) for row in y])

        classes, y_indices = np.unique(y, return_inverse=True)
        n_classes = classes.shape[0]

        class_counts = np.bincount(y_indices)

        # Find the sorted list of instances for each class:
        # (np.unique above performs a sort, so code is O(n logn) already)
        class_indices = np.split(
            np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1]
        )

        # per-class sample counts that approximately preserve class proportions
        n_i = _approximate_mode(class_counts, max_n_samples, random_state)

        indices = []
        for i in range(n_classes):
            indices_i = rng.choice(class_indices[i], n_i[i], replace=replace)
            indices.extend(indices_i)
        # shuffle so classes are interleaved rather than grouped
        indices = rng.permutation(indices)

    # now get the kNN indices for each sample (n_samples, n_neighbors)
    # NOTE(review): this row-indexes ``conditional_array`` directly, so it
    # presumably expects a dense 2D array — confirm for sparse inputs.
    sample_nbrs = nn_estimator.kneighbors(
        X=conditional_array[indices, :], return_distance=False
    )

    # actually sample the indices using a conditional permutation
    # (the shared ``rng`` is forwarded so draws stay on one reproducible stream)
    indices = _conditional_shuffle(sample_nbrs, replace=replace_nbrs, seed=rng)

    # convert sparse matrices to CSR for row-based indexing
    arrays_ = [a.tocsr() if issparse(a) else a for a in arrays]
    resampled_arrays = [_safe_indexing(a, indices) for a in arrays_]

    if len(resampled_arrays) == 1:
        # syntactic sugar for the unit argument case
        return resampled_arrays[0]
    else:
        return resampled_arrays