Source code for neuralee._aux.error

import torch
import math

@torch.no_grad()
def error_ee(X, Wp, Wn, lam):
    """Elastic embedding loss function.

    Straightforward implementation that materialises the full pairwise
    squared-distance and kernel matrices at once, so it may be inapplicable
    when the number of samples is large; :func:`error_ee_split` evaluates the
    same sum chunk by chunk to relieve memory pressure.

    :param X: sample-coordinates matrix, one sample per row.
    :type X: torch.FloatTensor
    :param Wp: attractive weights, one entry per pair of samples.
    :type Wp: torch.FloatTensor
    :param Wn: repulsive weights, one entry per pair of samples.
    :type Wn: torch.FloatTensor
    :param lam: trade-off factor of elastic embedding function.
    :returns: elastic embedding loss value and the kernel matrix
        ``exp(-sqdist(X))``.
    """
    sqd = sqdist(X)
    ker = torch.exp(-sqd)
    # reshape(-1) instead of view(-1): view raises on non-contiguous inputs
    # (e.g. a transposed weight matrix), reshape copies only when it must.
    attractive = Wp.reshape(-1).dot(sqd.reshape(-1))
    repulsive = Wn.reshape(-1).dot(ker.reshape(-1))
    error = attractive + lam * repulsive
    return error, ker
@torch.no_grad()
def sqdist(X):
    """Pairwise squared Euclidean distances between the rows of X.

    Uses the expansion ``|a - b|^2 = |a|^2 - 2 a.b + |b|^2``. The diagonal is
    then set to exactly zero and any negative entries (which can only come
    from floating-point cancellation) are clamped to zero.

    :param X: sample-coordinates matrix, one sample per row.
    :type X: torch.Tensor
    :return: square matrix of squared Euclidean distances.
    :rtype: torch.Tensor
    """
    sq_norms = (X ** 2).sum(dim=1, keepdim=True)
    sqd = sq_norms - 2 * X @ X.t() + sq_norms.t()
    # Zero the diagonal in place with a scalar fill: this follows the dtype
    # and device of ``sqd`` (no hard-coded float32 source tensor) and avoids
    # building a Python list of N indices.
    sqd.fill_diagonal_(0)
    return sqd.clamp_min(0)
@torch.no_grad()
def error_ee_split(X, Wp, Wn, lam, memory=2, device=None):
    """Elastic embedding loss function evaluated in row chunks.

    It splits X, Wp, Wn into blocks of rows and sums the respective loss
    values, so that only a slice of the pairwise matrices exists at a time.
    Unlike :func:`sqdist`, the per-chunk squared distances are neither
    clamped at zero nor zeroed on the diagonal.

    :param X: sample-coordinates matrix, one sample per row.
    :type X: torch.FloatTensor
    :param Wp: attractive weights.
    :type Wp: torch.FloatTensor
    :param Wn: repulsive weights.
    :type Wn: torch.FloatTensor
    :param lam: trade-off factor of elastic embedding function.
    :param memory: memory (GB) allocated to compute the error; it determines
        how many rows are processed per chunk.
    :type memory: int or float
    :param device: device chosen to operate. If None, the device of X is
        used.
    :type device: torch.device
    :returns: elastic embedding loss value.
    """
    device = X.device if device is None else device
    X = X.to(device)
    N = X.shape[0]
    if N == 0:
        # No pairs to sum over (and the chunk-size formula would divide by 0).
        return torch.zeros((), dtype=X.dtype, device=device)

    # Rows per chunk: budget / (2 * N * 8). Presumably two chunk-by-N
    # matrices (sqd and ker) at 8 bytes per element -- TODO confirm.
    # At least one row per chunk, otherwise the loop would never advance.
    B = max(1, math.floor((memory * 1024 ** 3) / (2 * N * 8)))

    x2 = (X ** 2).sum(dim=1, keepdim=True)
    x2_row = x2.t()
    Xt = X.t()

    error = 0
    for i1 in range(0, N, B):
        i2 = min(N, i1 + B)
        sqd = x2[i1:i2] - 2 * X[i1:i2, :] @ Xt + x2_row
        ker = (-sqd).exp()
        # reshape(-1) rather than view(-1) so non-contiguous weight slices
        # (e.g. from a transposed matrix) are accepted.
        wp = Wp[i1:i2, :].to(device).reshape(-1)
        wn = Wn[i1:i2, :].to(device).reshape(-1)
        error += wp.dot(sqd.reshape(-1)) + lam * wn.dot(ker.reshape(-1))
    return error