Source code for deepirtools.utils

import random
import math
import torch
from torch import nn
from torch.utils.data import Dataset
from pyro.distributions import Categorical, Distribution
import numpy as np
from typing import List, Optional
from itertools import chain


def manual_seed(seed: int):
    """
    Seed the random number generators so that results are reproducible.

    The same seed is applied, in order, to Python's ``random`` module,
    NumPy's global generator, and PyTorch's default generator.

    Parameters
    ----------
    seed : int
        Seed value handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
class ConvergenceChecker():
    """
    Track a moving average of the loss and flag convergence when it stalls.

    Every ``window_size`` iterations the mean of the most recent
    ``window_size`` losses is compared with the best mean seen so far.
    Each check that fails to improve on the best mean increments a counter;
    an improvement resets it. Once the counter reaches ``n_intervals`` the
    ``converged`` attribute is set to ``True`` (it is never reset).

    Parameters
    ----------
    n_intervals : int, default = 100
        Number of non-improving checks required to declare convergence.
    log_interval : int, default = 1
        Progress is printed every ``log_interval`` iterations.
    window_size : int, default = 100
        Number of losses averaged per check, and the number of iterations
        between checks.

    Raises
    ------
    ValueError
        If ``window_size`` is smaller than 1.
    """

    def __init__(self, n_intervals = 100, log_interval = 1, window_size = 100):
        if window_size < 1:
            raise ValueError("window_size must be a positive integer.")
        self.n_intervals = n_intervals
        self.log_interval = log_interval
        self.window_size = window_size
        self.converged = False
        self.best_avg_loss = None
        # Kept as a plain list so the attribute type seen by callers is unchanged.
        self.loss_list = []
        self.loss_improvement_counter = 0

    def check_convergence(self, epoch, global_step, loss):
        """
        Record one loss value, update the convergence state, and log progress.

        Parameters
        ----------
        epoch : int
            Current epoch (used for logging only).
        global_step : int
            Iteration counter; checks and logging are keyed on
            ``global_step - 1``.
        loss : float
            Loss for the current iteration.
        """
        self.loss_list.append(loss)
        if len(self.loss_list) > self.window_size:
            del self.loss_list[0]

        steps_done = global_step - 1
        if steps_done % self.window_size == 0 and global_step != 1:
            cur_mean_loss = np.mean(self.loss_list)
            if self.best_avg_loss is None:
                self.best_avg_loss = cur_mean_loss
            elif cur_mean_loss < self.best_avg_loss:
                self.best_avg_loss = cur_mean_loss
                self.loss_improvement_counter = 0
            elif cur_mean_loss >= self.best_avg_loss:
                # Explicit comparison (not a bare else) so that a NaN mean,
                # which fails both comparisons, leaves the counter untouched.
                self.loss_improvement_counter += 1
                if self.loss_improvement_counter >= self.n_intervals:
                    self.converged = True

        if steps_done % self.log_interval == 0:
            print("\rEpoch = {:8d}".format(epoch),
                  "Iter. = {:8d}".format(global_step),
                  "Cur. loss = {:7.2f}".format(loss),
                  "  Intervals no change = {:4d}".format(self.loss_improvement_counter),
                  end = "")


def get_thresholds(rng, n_cat):
    """
    Build category thresholds that are evenly spaced on the probability scale.

    The two bounds in ``rng`` are passed through the logistic function,
    ``n_cat + 1`` equally spaced probabilities are laid out between the
    results, the two end points are dropped, and the remaining probabilities
    are mapped back with the logit function.

    Parameters
    ----------
    rng : sequence of float
        Lower bound ``rng[0]`` and upper bound ``rng[1]`` on the logit scale.
    n_cat : int
        Number of categories.

    Returns
    -------
    thresholds : Tensor
        1D tensor holding the ``n_cat - 1`` interior thresholds.
    """
    lower_prob = 1 / (1 + math.exp(-rng[0]))
    upper_prob = 1 / (1 + math.exp(-rng[1]))
    interior_probs = torch.linspace(lower_prob, upper_prob, n_cat + 1)[1:-1]
    # -log(1 / p - 1) is the logit of p.
    return -interior_probs.pow(-1).add(-1).log()
def invert_factors(mat: torch.Tensor):
    """
    Flip the sign of every factor whose loadings sum to a negative value.

    Parameters
    ----------
    mat : Tensor
        Loadings matrix; one column per factor.

    Returns
    -------
    inverted_mat : Tensor
        Copy of the loadings matrix with the affected columns negated.
        The input tensor is left unmodified.
    """
    assert(mat.dim() == 2), "Loadings matrix must be 2D."

    flipped = mat.clone()
    # Sums are taken column by column before anything is modified.
    negative_cols = [col for col in range(flipped.shape[1])
                     if flipped[:, col].sum() < 0]
    if negative_cols:
        flipped[:, negative_cols] = -flipped[:, negative_cols]
    return flipped
def invert_cov(cov: torch.Tensor,
               mat: torch.Tensor,
              ):
    """
    Flip factor covariances according to loadings signs.

    For every factor whose loadings sum to a negative value, the
    covariances between that factor and all other factors change sign.
    Variances on the diagonal are unaffected.

    Parameters
    ----------
    cov : Tensor
        Factor covariance matrix (square, one row/column per factor).
    mat : Tensor
        Loadings matrix; one column per factor.

    Returns
    -------
    inverted_cov : Tensor
        Copy of the covariance matrix with the affected entries negated.
        The input tensor is left unmodified.
    """
    assert(len(cov.shape) == 2), "Factor covariance matrix must be 2D."
    assert(len(mat.shape) == 2), "Loadings matrix must be 2D."

    flipped = cov.clone()
    for factor_idx in range(mat.shape[1]):
        if mat[:, factor_idx].sum() < 0:
            # Negate the factor's row and then its column. The diagonal
            # entry is negated twice and therefore keeps its value.
            flipped[factor_idx, :] = -flipped[factor_idx, :]
            flipped[:, factor_idx] = -flipped[:, factor_idx]
    return flipped
def invert_mean(mean: torch.Tensor,
                mat: torch.Tensor,
               ):
    """
    Flip factor means according to loadings signs.

    The mean of every factor whose loadings sum to a negative value is
    negated.

    Parameters
    ----------
    mean : Tensor
        Factor mean vector; one entry per factor.
    mat : Tensor
        Loadings matrix; one column per factor.

    Returns
    -------
    inverted_mean : Tensor
        Copy of the mean vector with the affected entries negated.
        The input tensor is left unmodified.
    """
    assert(mean.dim() == 1), "Factor mean vector must be 1D."
    assert(mat.dim() == 2), "Loadings matrix must be 2D."

    flipped = mean.clone()
    negative_factors = [factor for factor in range(mat.shape[1])
                        if mat[:, factor].sum() < 0]
    if negative_factors:
        flipped[negative_factors] = -flipped[negative_factors]
    return flipped
def invert_latent_regression_weight(weight: torch.Tensor,
                                    mat: torch.Tensor,
                                   ):
    """
    Flip latent regression weights according to loadings signs.

    The weight row of every factor whose loadings sum to a negative value
    is negated.

    Parameters
    ----------
    weight : Tensor
        Latent regression weight matrix; one row per factor.
    mat : Tensor
        Loadings matrix; one column per factor.

    Returns
    -------
    inverted_weight : Tensor
        Copy of the weight matrix with the affected rows negated.
        The input tensor is left unmodified.
    """
    assert(weight.dim() == 2), "Latent regression weight matrix must be 2D."
    assert(mat.dim() == 2), "Loadings matrix must be 2D."

    flipped = weight.clone()
    negative_factors = [factor for factor in range(mat.shape[1])
                        if mat[:, factor].sum() < 0]
    if negative_factors:
        flipped[negative_factors, :] = -flipped[negative_factors, :]
    return flipped
def normalize_loadings(mat: torch.Tensor):
    """
    Convert loadings to normal ogive metric (only for IRT models).

    Loadings are divided by 1.702, and each row is then divided by
    ``sqrt(1 + sum of squared rescaled loadings in that row)``.

    Parameters
    ----------
    mat : Tensor
        Loadings matrix; one row per item.

    Returns
    -------
    normalized_mat : Tensor
        Normalized loadings matrix with the same shape as ``mat``.
    """
    assert(mat.dim() == 2), "Loadings matrix must be 2D."

    rescaled = mat / 1.702
    # keepdim leaves a trailing axis of size 1 so the division broadcasts per row.
    row_scale = torch.sqrt((rescaled ** 2).sum(dim = 1, keepdim = True) + 1)
    return rescaled / row_scale
def normalize_ints(ints: torch.Tensor,
                   mat: torch.Tensor,
                   n_cats: List[int],
                  ):
    """
    Convert intercepts to normal ogive metric (only for IRT models).

    The intercepts vector is read as consecutive groups, one per item,
    where an item with ``n_cat`` categories owns ``n_cat - 1`` intercepts.
    Each group is divided by its item's scale constant,
    ``sqrt(1 + sum of squared loadings / 1.702**2)``.

    Parameters
    ----------
    ints : Tensor
        Intercepts vector.
    mat : Tensor
        Loadings matrix; one row per item.
    n_cats : list of int
        Number of categories for each item. Any iterable of ints
        (list, tuple, array) is accepted.

    Returns
    -------
    normalized_ints : Tensor
        Normalized intercepts vector.
    """
    assert(len(ints.shape) == 1), "Intercepts vector must be 1D."
    assert(len(mat.shape) == 2), "Loadings matrix must be 2D."

    rescaled = mat / 1.702
    scale_const = rescaled.pow(2).sum(dim = 1).add(1).sqrt()

    # Walk through the intercepts with a running offset instead of building
    # ``[1] + n_cats``, which only behaves correctly for real lists.
    pieces = []
    start = 0
    for item_idx, n_cat in enumerate(n_cats):
        stop = start + n_cat - 1
        pieces.append(ints[start:stop] / scale_const[item_idx])
        start = stop
    return torch.cat(pieces, dim = 0)
class MultiCategorical(Distribution):
    """
    Multiple categorical distributions bundled into one distribution.

    Each component owns one slot along the last dimension of a value
    tensor; log-probabilities and entropies of the components are summed.

    https://github.com/pytorch/pytorch/issues/43250
    """

    def __init__(self, dists: List[Categorical]):
        super().__init__()
        self.dists = dists

    def log_prob(self, value):
        """Return the sum over components of each component's log-probability
        for its own slot of ``value``'s last dimension."""
        slots = torch.split(value, 1, dim=-1)
        per_component = [dist.log_prob(slot.squeeze(-1))
                         for dist, slot in zip(self.dists, slots)]
        return torch.stack(per_component, dim=-1).sum(dim=-1)

    def entropy(self):
        """Return the sum of the component entropies."""
        per_component = [dist.entropy() for dist in self.dists]
        return torch.stack(per_component, dim=-1).sum(dim=-1)

    def sample(self, sample_shape=torch.Size()):
        """Sample every component and stack the draws along a new last dimension."""
        draws = [dist.sample(sample_shape) for dist in self.dists]
        return torch.stack(draws, dim=-1)


def multi_categorical_maker(nvec):
    """
    Create a factory that turns a flat probability tensor into a
    ``MultiCategorical``.

    Parameters
    ----------
    nvec : iterable of int
        Number of categories of each component, in order.

    Returns
    -------
    get_multi_categorical : callable
        Function taking ``probs`` and slicing its last dimension into
        consecutive chunks of the sizes given by ``nvec``, one
        ``Categorical`` per chunk.
    """
    def get_multi_categorical(probs):
        components = []
        stop = 0
        for n_classes in nvec:
            start, stop = stop, stop + n_classes
            components.append(Categorical(probs=probs[..., start:stop]))
        return MultiCategorical(components)

    return get_multi_categorical


class tensor_dataset(Dataset):
    """
    Dataset that serves rows of a data tensor, optionally together with the
    matching rows of a mask and of a covariates tensor.
    """

    def __init__(self,
                 data,
                 mask = None,
                 covariates = None,
                ):
        self.data = data
        self.mask = mask
        self.covariates = covariates

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with key ``"y"`` and, when the corresponding source
        was supplied, keys ``"mask"`` and ``"covariates"``."""
        batch = {"y": self.data[idx]}
        optional_sources = (("mask", self.mask), ("covariates", self.covariates))
        for key, source in optional_sources:
            if source is not None:
                batch[key] = source[idx]
        return batch