"""
|
|
Data Driven Initialization for Neural Network Models.
|
|
|
|
Author
|
|
------
|
|
Nicolas Rojas
|
|
"""
|
|
|
|
import torch
|
|
|
|
|
|
def linear_classification_weights(X: torch.Tensor, method: str) -> torch.Tensor:
    """Initialize the weights of a linear layer, using the input data.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the linear classification.
    method : str
        Method to be used for initialization. Can be either "mean" or "median".

    Returns
    -------
    torch.Tensor
        Weights of the linear layer.
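
    Example
    -------
    A minimal sketch with toy data (the shapes are the point, values are arbitrary):

    >>> X = torch.rand(8, 4)
    >>> weights = linear_classification_weights(X, "mean")
    >>> weights.shape
    torch.Size([1, 4])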
"""
|
|
# sometimes, there are no elements for the given class. Return 0 if that is the case
|
|
if len(X) == 0:
|
|
return torch.zeros(1, X.shape[1])
|
|
# weights can be mean or median of input data
|
|
if method == "mean":
|
|
weights = torch.mean(X, dim=0, keepdim=True)
|
|
elif method == "median":
|
|
weights = torch.median(X, dim=0, keepdim=True)[0]
|
|
# normalize weights
|
|
weights /= torch.norm(weights)
|
|
|
|
return weights
|
|
|
|
|
|
def separe_class_projections(
    p0: torch.Tensor, p1: torch.Tensor, method: str
) -> torch.Tensor:
    """Find the value that best separates the projections of two different classes.

    Parameters
    ----------
    p0 : torch.Tensor
        Projections of the first class.
    p1 : torch.Tensor
        Projections of the second class.
    method : str
        Method to be used for separation. Can be either "mean" or "quadratic".

    Returns
    -------
    torch.Tensor
        Value that best separates the projections of the two classes.
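
    Example
    -------
    A minimal sketch with non-overlapping toy projections:

    >>> p0 = torch.tensor([0.1, 0.2, 0.3])
    >>> p1 = torch.tensor([0.7, 0.8, 0.9])
    >>> separe_class_projections(p0, p1, "mean")
    tensor(0.5000)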
"""
|
|
# if classes dont overlap, return middle point of dividing space
|
|
if torch.all(p1 >= p0.max()):
|
|
separation = (p1.min() + p0.max()) / 2
|
|
else:
|
|
# case where classes overlap
|
|
p0_overlap = p0 > p1.min()
|
|
p1_overlap = p1 < p0.max()
|
|
|
|
if method == "mean":
|
|
# separation value is the weighted mean of the overlapping points
|
|
n0 = p0_overlap.sum()
|
|
n1 = p1_overlap.sum()
|
|
sum0 = p0[p0_overlap].sum()
|
|
sum1 = p1[p1_overlap].sum()
|
|
separation = (n0 / n1 * sum1 + n1 / n0 * sum0) / (n1 + n0)
|
|
|
|
elif method == "quadratic":
|
|
# separation value is mid point of quadratic function
|
|
# (points histogram represented as a parabola)
|
|
q = torch.tensor([0, 0.05, 0.50, 0.95, 1])
|
|
lq = len(q)
|
|
pp0_overlap = p0[p0_overlap]
|
|
pp1_overlap = p1[p1_overlap]
|
|
pp1_vals = torch.quantile(pp1_overlap, q, dim=0, keepdim=True).reshape(
|
|
-1, 1
|
|
)
|
|
pp0_vals = torch.quantile(pp0_overlap, 1 - q, dim=0, keepdim=True).reshape(
|
|
-1, 1
|
|
)
|
|
A1, A0 = torch.ones(lq, 3), torch.ones(lq, 3)
|
|
A1[0:, 0] = ((q / 100)) ** 2
|
|
A1[0:, 1] = q / 100
|
|
A0[0:, 0] = (1 - (q / 100)) ** 2
|
|
A0[0:, 1] = 1 - (q / 100)
|
|
coeff0 = torch.linalg.pinv(A0.T @ A0) @ (A0.T @ pp0_vals).squeeze()
|
|
coeff1 = torch.linalg.pinv(A1.T @ A1) @ (A1.T @ pp1_vals).squeeze()
|
|
a0, b0, c0 = coeff0[0], coeff0[1], coeff0[2]
|
|
a1, b1, c1 = coeff1[0], coeff1[1], coeff1[2]
|
|
a = a1 - a0
|
|
b = b1 + 2 * a0 + b0
|
|
c = c1 - a0 - c0
|
|
i1 = (-b + (b**2 - 4 * a * c) ** 0.5) / (2 * a)
|
|
i2 = (-b - (b**2 - 4 * a * c) ** 0.5) / (2 * a)
|
|
separation = max(i1, i2)
|
|
|
|
return separation
|
|
|
|
|
|
def linear_classification_bias(
    X0: torch.Tensor, X1: torch.Tensor, weights: torch.Tensor, method: str
) -> torch.Tensor:
    """Find the bias of a feed-forward classification layer, given its weights.

    Parameters
    ----------
    X0 : torch.Tensor
        Input data of the first class.
    X1 : torch.Tensor
        Input data of the second class.
    weights : torch.Tensor
        Weights of the linear layer.
    method : str
        Method to be used for bias calculation. Can be either "mean" or "quadratic".

    Returns
    -------
    torch.Tensor
        Bias of the linear layer.
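
    Example
    -------
    A minimal sketch with two separable toy classes:

    >>> X0 = torch.zeros(4, 2)
    >>> X1 = torch.ones(4, 2)
    >>> weights = linear_classification_weights(X1, "mean")
    >>> linear_classification_bias(X0, X1, weights, "mean")
    tensor(0.7071)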
"""
|
|
# sometimes, there are no elements for the given class. Return 0 if that is the case
|
|
if len(X0) == 0 or len(X1) == 0:
|
|
return torch.tensor(0)
|
|
|
|
# project observations over weights to get 1D vectors
|
|
p0 = (X0 @ weights.T).squeeze()
|
|
p1 = (X1 @ weights.T).squeeze()
|
|
|
|
# find bias according to class projections
|
|
bias = separe_class_projections(p0, p1, method)
|
|
return bias
|
|
|
|
|
|
def linear_classification_output_layer(
    X: torch.Tensor, y: torch.Tensor, weights_method: str, bias_method: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """Initialize the output (linear) layer of a classification neural network.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the classification.
    y : torch.Tensor
        Labels of the input data.
    weights_method : str
        Method to be used for weights initialization. Can be either "mean" or "median".
    bias_method : str
        Method to be used for bias initialization. Can be either "mean" or "quadratic".

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Weights and bias of the output layer.
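
    Example
    -------
    A minimal sketch with a toy binary problem:

    >>> X = torch.cat((torch.zeros(4, 3), torch.ones(4, 3)))
    >>> y = torch.cat((torch.zeros(4), torch.ones(4))).long()
    >>> weights, biases = linear_classification_output_layer(X, y, "mean", "mean")
    >>> weights.shape, biases.shape
    (torch.Size([1, 3]), torch.Size([1]))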
"""
|
|
all_classes = torch.unique(y)
|
|
all_weights = []
|
|
all_biases = []
|
|
binary = len(all_classes) == 2
|
|
|
|
# when there are only 2 classes, it is only necessary to consider class 1
|
|
if binary:
|
|
all_classes = all_classes[1:]
|
|
|
|
for yi in all_classes:
|
|
# for each class, initialize weights and biases
|
|
X0 = X[y != yi]
|
|
X1 = X[y == yi]
|
|
|
|
weights = linear_classification_weights(X1, weights_method)
|
|
bias = linear_classification_bias(X0, X1, weights, bias_method)
|
|
|
|
all_weights.append(weights)
|
|
all_biases.append(bias)
|
|
|
|
# transform lists of tensors to tensors
|
|
if binary:
|
|
all_weights = all_weights[0]
|
|
else:
|
|
all_weights = torch.cat(all_weights)
|
|
all_biases = torch.tensor(all_biases)
|
|
|
|
return all_weights, all_biases
|
|
|
|
|
|
def select_support_vectors(
    X: torch.Tensor,
    y: torch.Tensor,
    distances: torch.Tensor,
    num_neurons: int,
    num_classes: int,
) -> tuple[torch.Tensor, list[int]]:
    """
    Find the support vectors for a set of vectors.

    Support vectors are defined as the elements of a certain class
    that are closest to the elements of the other classes.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the classification.
    y : torch.Tensor
        Labels of the input data.
    distances : torch.Tensor
        Pairwise distances between input data.
    num_neurons : int
        Number of neurons to be used in the layer.
    num_classes : int
        Number of classes in the input data.

    Returns
    -------
    tuple[torch.Tensor, list[int]]
        Support vectors and their corresponding classes.
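
    Example
    -------
    A minimal sketch selecting one support vector per class:

    >>> X = torch.tensor([[0.0], [1.0], [2.0], [3.0]])
    >>> y = torch.tensor([0, 0, 1, 1])
    >>> distances = torch.cdist(X, X)
    >>> vectors, classes = select_support_vectors(X, y, distances, 2, 2)
    >>> vectors.shape, classes
    (torch.Size([2, 1]), [0, 1])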
"""
|
|
# get how many neurons should belong to each class
|
|
quotient, remainder = divmod(num_neurons, num_classes)
|
|
neurons_classes = [quotient] * num_classes
|
|
for idx in range(remainder):
|
|
neurons_classes[idx] += 1
|
|
|
|
vectors = []
|
|
classes = []
|
|
# iterate over each class with its number of corresponding neurons
|
|
for label, num_vectors in enumerate(neurons_classes):
|
|
# get elements belonging to desired class
|
|
label_indices = y == label
|
|
X1 = X[label_indices]
|
|
# obtain distances between elements belonging to class and elements
|
|
label_outside_distances = distances[label_indices, :][:, ~label_indices]
|
|
# get mean distance for each element in class and order from lesser to greater
|
|
label_outside_distances = label_outside_distances.mean(dim=1)
|
|
min_elements = label_outside_distances.argsort()[:num_vectors]
|
|
# vectors closer to elements from other classes are support vectors
|
|
vectors.append(X1[min_elements])
|
|
classes.extend([label] * num_vectors)
|
|
vectors = torch.cat(vectors)
|
|
# return vectors with their corresponding classes
|
|
return vectors, classes
|
|
|
|
|
|
def linear_classification_hidden_layer(
    X: torch.Tensor,
    y: torch.Tensor,
    num_neurons: int,
    num_classes: int,
    weights_method: str,
    bias_method: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Initialize a hidden linear layer of a classification neural network.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the classification.
    y : torch.Tensor
        Labels of the input data.
    num_neurons : int
        Number of neurons to be used in the layer.
    num_classes : int
        Number of classes in the input data.
    weights_method : str
        Method to be used for weights initialization. Can be either "mean" or "median".
    bias_method : str
        Method to be used for bias initialization. Can be either "mean" or "quadratic".

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Weights and bias of the hidden layer.
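
    Example
    -------
    A minimal sketch with one neuron per class:

    >>> X = torch.tensor([[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.8, 1.0]])
    >>> y = torch.tensor([0, 0, 1, 1])
    >>> weights, bias = linear_classification_hidden_layer(X, y, 2, 2, "mean", "mean")
    >>> weights.shape, bias.shape
    (torch.Size([2, 2]), torch.Size([2]))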
"""
|
|
# get pairwise distances of input observations
|
|
distances = torch.cdist(X, X)
|
|
# get support vectors for the layer
|
|
characteristic_vectors, cv_classes = select_support_vectors(
|
|
X, y, distances, num_neurons, num_classes
|
|
)
|
|
|
|
# get distances between support vectors and every element
|
|
distances = torch.cdist(characteristic_vectors, X)
|
|
# neighborhoods are Voronoi regions of support vectors
|
|
neighborhoods = distances.argmin(dim=0)
|
|
|
|
layer_weights = []
|
|
layer_bias = []
|
|
for neighborhood in range(num_neurons):
|
|
# get points belonging to neighborhood
|
|
close_points = neighborhoods == neighborhood
|
|
close_X = X[close_points]
|
|
close_y = y[close_points]
|
|
|
|
# get elements belonging to same class as support vector
|
|
k = cv_classes[neighborhood]
|
|
X0 = close_X[close_y != k]
|
|
X1 = close_X[close_y == k]
|
|
|
|
# get weights and biases of layer using elements in same class as support vector
|
|
weights = linear_classification_weights(
|
|
X1 - X1.mean(dim=1, keepdim=True), weights_method
|
|
)
|
|
bias = linear_classification_bias(X0, X1, weights, bias_method)
|
|
|
|
layer_weights.append(weights)
|
|
layer_bias.append(bias)
|
|
|
|
layer_weights = torch.cat(layer_weights)
|
|
layer_bias = torch.tensor(layer_bias)
|
|
# return weights and biases of each layer as single tensor
|
|
return layer_weights, layer_bias
|
|
|
|
|
|
def rnn_bias(
    X0: torch.Tensor,
    X1: torch.Tensor,
    weights: torch.Tensor,
    h0: torch.Tensor,
    h1: torch.Tensor,
    h_weights: torch.Tensor,
    method: str,
) -> torch.Tensor:
    """Find the bias of a recurrent classification layer, given all of its weights.

    Parameters
    ----------
    X0 : torch.Tensor
        Input data of the first class.
    X1 : torch.Tensor
        Input data of the second class.
    weights : torch.Tensor
        Weights of the linear layer.
    h0 : torch.Tensor
        Hidden state of the first class.
    h1 : torch.Tensor
        Hidden state of the second class.
    h_weights : torch.Tensor
        Weights of the hidden state.
    method : str
        Method to be used for bias calculation. Can be either "mean" or "quadratic".

    Returns
    -------
    torch.Tensor
        Bias of the recurrent layer.
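
    Example
    -------
    A minimal sketch with separable toy inputs and hidden states:

    >>> X0, h0 = torch.zeros(4, 2), torch.zeros(4, 3)
    >>> X1, h1 = torch.ones(4, 2), torch.ones(4, 3)
    >>> weights = torch.full((1, 2), 0.5)
    >>> h_weights = torch.full((1, 3), 0.5)
    >>> rnn_bias(X0, X1, weights, h0, h1, h_weights, "mean")
    tensor(1.2500)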
"""
|
|
# sometimes, there are no elements for the given class. Return 0 if that is the case
|
|
if len(X0) == 0 or len(X1) == 0:
|
|
return torch.tensor(0)
|
|
|
|
# project observations over weights to get 1D vectors
|
|
p0 = (X0 @ weights.T + h0 @ h_weights.T).squeeze()
|
|
p1 = (X1 @ weights.T + h1 @ h_weights.T).squeeze()
|
|
|
|
# find bias according to class projections
|
|
bias = separe_class_projections(p0, p1, method)
|
|
return bias
|
|
|
|
|
|
def rnn_hidden_layer(
    X: torch.Tensor,
    h: torch.Tensor,
    y: torch.Tensor,
    num_neurons: int,
    num_classes: int,
    weights_method: str,
    bias_method: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Initialize a recurrent layer of a classification neural network.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the classification.
    h : torch.Tensor
        Hidden state of the input data.
    y : torch.Tensor
        Labels of the input data.
    num_neurons : int
        Number of neurons to be used in the layer.
    num_classes : int
        Number of classes in the input data.
    weights_method : str
        Method to be used for weights initialization. Can be either "mean" or "median".
    bias_method : str
        Method to be used for bias initialization. Can be either "mean" or "quadratic".

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        Weights, hidden state weights, and bias of the recurrent layer.
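
    Example
    -------
    A minimal sketch with one neuron per class and toy hidden states:

    >>> X = torch.tensor([[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.8, 1.0]])
    >>> h = 0.5 * X
    >>> y = torch.tensor([0, 0, 1, 1])
    >>> w, hw, b = rnn_hidden_layer(X, h, y, 2, 2, "mean", "mean")
    >>> w.shape, hw.shape, b.shape
    (torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([2]))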
"""
|
|
# get pairwise distances of input observations
|
|
distances = torch.cdist(X, X)
|
|
# get support vectors for the layer
|
|
characteristic_vectors, cv_classes = select_support_vectors(
|
|
X, y, distances, num_neurons, num_classes
|
|
)
|
|
|
|
# get distances between support vectors and every element
|
|
distances = torch.cdist(characteristic_vectors, X)
|
|
# neighborhoods are Voronoi regions of support vectors
|
|
neighborhoods = distances.argmin(dim=0)
|
|
|
|
layer_weights = []
|
|
layer_h_weights = []
|
|
layer_bias = []
|
|
for neighborhood in range(num_neurons):
|
|
# get points from X, y, and h belonging to neighborhood
|
|
close_points = neighborhoods == neighborhood
|
|
close_X = X[close_points]
|
|
close_h = h[close_points]
|
|
close_y = y[close_points]
|
|
|
|
# get elements belonging to same class as support vector
|
|
k = cv_classes[neighborhood]
|
|
X0 = close_X[close_y != k]
|
|
X1 = close_X[close_y == k]
|
|
h0 = close_h[close_y != k]
|
|
h1 = close_h[close_y == k]
|
|
|
|
# get weights and biases of layer using elements in same class as support vector
|
|
weights = linear_classification_weights(
|
|
X1 - X1.mean(dim=1, keepdim=True), weights_method
|
|
)
|
|
h_weights = linear_classification_weights(
|
|
h1 - h1.mean(dim=1, keepdim=True), weights_method
|
|
)
|
|
bias = rnn_bias(X0, X1, weights, h0, h1, h_weights, bias_method)
|
|
|
|
layer_weights.append(weights)
|
|
layer_h_weights.append(h_weights)
|
|
layer_bias.append(bias)
|
|
|
|
layer_weights = torch.cat(layer_weights)
|
|
layer_h_weights = torch.cat(layer_h_weights)
|
|
layer_bias = torch.tensor(layer_bias)
|
|
# return weights and biases of each layer as single tensor
|
|
return layer_weights, layer_h_weights, layer_bias
|
|
|
|
|
|
def conv_bias(
    X: torch.Tensor,
    y: torch.Tensor,
    kernel: torch.Tensor,
    kernel_classes: list[int],
    method: str,
) -> torch.Tensor:
    """Find the bias of a convolutional classification layer, given its kernel.

    Parameters
    ----------
    X : torch.Tensor
        Input data of the classification.
    y : torch.Tensor
        Labels of the input data.
    kernel : torch.Tensor
        Kernel of the convolutional layer.
    kernel_classes : list[int]
        Classes of the kernel.
    method : str
        Method to be used for bias calculation. Can be either "mean" or "quadratic".

    Returns
    -------
    torch.Tensor
        Bias of the convolutional layer.
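
    Example
    -------
    A minimal sketch with a single 2x2 averaging kernel and toy images:

    >>> X = torch.cat((torch.zeros(4, 1, 4, 4), torch.ones(4, 1, 4, 4)))
    >>> y = torch.cat((torch.zeros(4), torch.ones(4))).long()
    >>> kernel = torch.full((1, 1, 2, 2), 0.25)
    >>> conv_bias(X, y, kernel, [1], "mean")
    tensor([-0.5000])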
"""
|
|
# relevant values are kernel dimensions
|
|
out_dims, in_dims, rows, columns = kernel.shape
|
|
layer_bias = []
|
|
# find the bias of each individual kernel layer
|
|
for iteration in range(out_dims):
|
|
# find the points corresponding to the same class as the kernel layer
|
|
iteration_class = kernel_classes[iteration]
|
|
X0 = X[y == iteration_class]
|
|
X1 = X[y != iteration_class]
|
|
iteration_kernel = kernel[iteration : iteration + 1]
|
|
|
|
# get convolution of observations and kernel layer to get 1D vectors
|
|
p0 = torch.nn.functional.conv2d(X0, iteration_kernel).flatten()
|
|
p1 = torch.nn.functional.conv2d(X1, iteration_kernel).flatten()
|
|
|
|
# find bias according to class projections
|
|
bias = separe_class_projections(p0, p1, method)
|
|
layer_bias.append(bias)
|
|
|
|
# return inverse of kernel to center class division
|
|
layer_bias = -torch.tensor(layer_bias)
|
|
return layer_bias
|
|
|
|
|
|
def init_weights_conv(
    X: torch.Tensor,
    y: torch.Tensor,
    out_channels: int,
    kernel_row: int,
    kernel_col: int,
    num_classes: int,
    bias_method: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Initialize a convolutional layer of a classification neural network.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the classification.
    y : torch.Tensor
        Labels of the input data.
    out_channels : int
        Number of output channels of the convolutional layer.
    kernel_row : int
        Number of rows of the kernel.
    kernel_col : int
        Number of columns of the kernel.
    num_classes : int
        Number of classes in the input data.
    bias_method : str
        Method to be used for bias initialization. Can be either "mean" or "quadratic".

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Weights and bias of the convolutional layer.
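
    Example
    -------
    A minimal sketch with toy images (the shapes are the point, values are random):

    >>> X = torch.rand(8, 1, 6, 6)
    >>> y = torch.cat((torch.zeros(4), torch.ones(4))).long()
    >>> weights, bias = init_weights_conv(X, y, 2, 3, 3, 2, "mean")
    >>> weights.shape, bias.shape
    (torch.Size([2, 1, 3, 3]), torch.Size([2]))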
"""
|
|
# get dimensions of input data
|
|
num_data, num_channels, rows, columns = X.shape
|
|
fragments = []
|
|
fragments_classes = []
|
|
# get number of row and column partitions over data
|
|
row_iters = rows // kernel_row
|
|
col_iters = columns // kernel_col
|
|
|
|
for unique_class in range(num_classes):
|
|
# get median element of each class to represent it
|
|
X_class = X[y == unique_class]
|
|
X_class = X.median(dim=0).values
|
|
# iterate over groups of rows and columns in median object
|
|
for row_idx in range(row_iters):
|
|
for col_idx in range(col_iters):
|
|
# flatten fragment to use as kernel candidate
|
|
fragment = X_class[
|
|
:,
|
|
kernel_row * row_idx : kernel_row * (row_idx + 1),
|
|
kernel_col * col_idx : kernel_col * (col_idx + 1),
|
|
]
|
|
fragment = fragment.flatten()
|
|
if fragment.norm() <= 0.01:
|
|
continue
|
|
fragments.append(fragment)
|
|
fragments_classes.append(unique_class)
|
|
|
|
fragments = torch.stack(fragments)
|
|
# normalize fragments
|
|
fragments -= fragments.mean(1, keepdim=True)
|
|
# if there are not enough fragments, use existing fragments plus a
|
|
# random noise as extra fragments until fragment number is enough
|
|
# (at least equal to the number of output channels)
|
|
while fragments.shape[0] < out_channels:
|
|
difference = out_channels - fragments.shape[0]
|
|
difference = fragments[: min(difference, fragments.shape[0])]
|
|
difference += torch.normal(0, 0.1, size=difference.shape)
|
|
fragments = torch.cat((fragments, difference), 0)
|
|
|
|
# get pairwise correlations of kernel candidates
|
|
correlations = torch.zeros(len(fragments), len(fragments))
|
|
for idx1 in range(len(fragments)):
|
|
for idx2 in range(idx1, len(fragments)):
|
|
correlations[idx1, idx2] = abs(
|
|
torch.nn.functional.cosine_similarity(
|
|
fragments[idx1], fragments[idx2], dim=0
|
|
)
|
|
)
|
|
correlations[idx2, idx1] = correlations[idx1, idx2]
|
|
|
|
fragments_classes = torch.tensor(fragments_classes)
|
|
# find optimal kernels from kernel candidates, using support vectors method
|
|
characteristic_vectors, kernel_classes = select_support_vectors(
|
|
fragments, fragments_classes, correlations, out_channels, num_classes
|
|
)
|
|
current_num_weights = characteristic_vectors.shape[0]
|
|
# un-flatten selected kernels
|
|
characteristic_vectors = characteristic_vectors.reshape(
|
|
(current_num_weights, num_channels, kernel_row, kernel_col)
|
|
)
|
|
# normalize selected kernels
|
|
for weight in range(current_num_weights):
|
|
for channel in range(num_channels):
|
|
characteristic_vectors[weight, channel, :, :] /= torch.linalg.matrix_norm(
|
|
characteristic_vectors[weight, channel, :, :]
|
|
)
|
|
|
|
# if there are not enough kernels, use existing kernels plus a
|
|
# random noise as extra kernels until kernel number is enough
|
|
# (at least equal to the number of output channels)
|
|
while current_num_weights < out_channels:
|
|
difference = out_channels - current_num_weights
|
|
difference = characteristic_vectors[: min(difference, current_num_weights)]
|
|
difference += torch.normal(0, 0.1, size=difference.shape)
|
|
characteristic_vectors = torch.cat((characteristic_vectors, difference), 0)
|
|
current_num_weights = characteristic_vectors.shape[0]
|
|
|
|
# find layer biases using selected kernels
|
|
layer_bias = conv_bias(X, y, characteristic_vectors, kernel_classes, bias_method)
|
|
return characteristic_vectors, layer_bias
|
|
|
|
|
|
def init_weights_classification(
    model: torch.nn.Module,
    X: torch.Tensor,
    y: torch.Tensor,
    weights_method: str = "mean",
    bias_method: str = "mean",
) -> torch.nn.Module:
    """Initialize a classification neural network.

    Parameters
    ----------
    model : torch.nn.Module
        Neural network model to be initialized.
    X : torch.Tensor
        Input data to the classification.
    y : torch.Tensor
        Labels of the input data.
    weights_method : str, optional
        Method to be used for weights initialization. Can be either "mean" or "median".
        Default is "mean".
    bias_method : str, optional
        Method to be used for bias initialization. Can be either "mean" or "quadratic".
        Default is "mean".

    Returns
    -------
    torch.nn.Module
        Initialized neural network model.

    Raises
    ------
    ValueError
        If X and y do not have the same number of samples.
        If the input data has fewer than 2 classes.

    Example
    -------
    >>> import torch
    >>> from torch import nn
    >>> num_classes = 3
    >>> num_data = 20
    >>> num_features = 10
    >>> hidden_size = 5
    >>> X = torch.rand(num_data, num_features)
    >>> y = torch.randint(0, num_classes, (num_data,))
    >>> model = nn.Sequential(
    ...     nn.Linear(num_features, hidden_size),
    ...     nn.ReLU(),
    ...     nn.Linear(hidden_size, num_classes)
    ... )
    >>> model = init_weights_classification(model, X, y)
    """
    # Throw error when X and y have different numbers of samples
    if X.shape[0] != y.shape[0]:
        raise ValueError("X and y must have the same number of samples")

    num_classes = len(torch.unique(y))
    # There must be at least two classes. Else, throw error
    if num_classes < 2:
        raise ValueError("Input data must present at least 2 classes")

    # Model is a single linear layer
    if isinstance(model, torch.nn.Linear):
        # find weights and biases of single layer
        weights, bias = linear_classification_output_layer(
            X, y, weights_method, bias_method
        )
        model.weight = torch.nn.Parameter(weights)
        model.bias = torch.nn.Parameter(bias)
        return model

    # Model is a sequential model with at least one linear layer (output)
    num_layers = len(model) - 1

    for layer in range(num_layers):
        if isinstance(model[layer], torch.nn.Linear):
            # layer is a linear layer
            num_neurons = model[layer].out_features
            layer_weights, layer_bias = linear_classification_hidden_layer(
                X, y, num_neurons, num_classes, weights_method, bias_method
            )
            model[layer].weight = torch.nn.Parameter(layer_weights)
            model[layer].bias = torch.nn.Parameter(layer_bias)
        elif isinstance(model[layer], torch.nn.Conv2d):
            # layer is a 2D convolutional layer
            kernel_row, kernel_col = model[layer].kernel_size
            out_channels = model[layer].out_channels
            layer_weights, layer_bias = init_weights_conv(
                X, y, out_channels, kernel_row, kernel_col, num_classes, bias_method
            )
            model[layer].weight = torch.nn.Parameter(layer_weights)
            model[layer].bias = torch.nn.Parameter(layer_bias)
        elif isinstance(model[layer], torch.nn.RNN):
            # layer is a (stack of) recurrent layer(s). This kind of
            # layer requires special treatment due to the way it is
            # represented in pytorch: not as a group of layers but as
            # a single object with multiple weights and biases
            num_rnn_layers = model[layer].num_layers
            num_neurons = model[layer].hidden_size
            activation = (
                torch.nn.functional.tanh
                if model[layer].nonlinearity == "tanh"
                else torch.nn.functional.relu
            )
            # get last element in sequence (assumes batch-first inputs)
            layer_X = X[:, -1, :].detach().clone()
            for layer_idx in range(num_rnn_layers):
                # initialize the x-weights of each recurrent layer in stack
                layer_weights, layer_bias = linear_classification_hidden_layer(
                    layer_X, y, num_neurons, num_classes, weights_method, bias_method
                )
                setattr(
                    model[layer],
                    f"weight_ih_l{layer_idx}",
                    torch.nn.Parameter(layer_weights),
                )
                setattr(
                    model[layer],
                    f"bias_ih_l{layer_idx}",
                    torch.nn.Parameter(layer_bias),
                )
                # propagate layer_X through recurrent stack
                layer_X = activation(layer_X @ layer_weights.T + layer_bias)
            # obtain final state of h for each recurrent layer in stack to use as h0
            _, h0 = model[layer](X)
            # get last element in sequence
            layer_X = X[:, -1, :].detach().clone()
            for layer_idx in range(num_rnn_layers):
                # initialize the x-weights and h-weights of each recurrent layer in stack
                layer_weights, layer_h_weights, layer_bias = rnn_hidden_layer(
                    layer_X,
                    h0[layer_idx],
                    y,
                    num_neurons,
                    num_classes,
                    weights_method,
                    bias_method,
                )
                setattr(
                    model[layer],
                    f"weight_ih_l{layer_idx}",
                    torch.nn.Parameter(layer_weights),
                )
                setattr(
                    model[layer],
                    f"bias_ih_l{layer_idx}",
                    torch.nn.Parameter(layer_bias),
                )
                setattr(
                    model[layer],
                    f"weight_hh_l{layer_idx}",
                    torch.nn.Parameter(layer_h_weights),
                )
                setattr(
                    model[layer],
                    f"bias_hh_l{layer_idx}",
                    torch.nn.Parameter(torch.zeros_like(layer_bias)),
                )
                # propagate layer_X through recurrent stack
                layer_X = activation(
                    layer_X @ layer_weights.T
                    + h0[layer_idx] @ layer_h_weights.T
                    + layer_bias
                )
        # propagate X regardless of the layer type
        X = model[layer](X)
        if isinstance(model[layer], torch.nn.RNN):
            # recurrent layers return (output, h_n): keep the last step of the output
            X = X[0][:, -1, :]

    # Last layer (linear output layer)
    layer_weights, layer_bias = linear_classification_output_layer(
        X, y, weights_method, bias_method
    )
    model[num_layers].weight = torch.nn.Parameter(layer_weights)
    model[num_layers].bias = torch.nn.Parameter(layer_bias)

    return model


def linear_regression(
    X: torch.Tensor, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fit a single linear regression over input data.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the regression.
    y : torch.Tensor
        Target values of the input data.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Weights and bias of the linear regression.
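
    Example
    -------
    A minimal sketch recovering a known linear relation from toy data:

    >>> X = torch.tensor([[0.0], [1.0], [2.0], [3.0]])
    >>> y = 2 * X[:, 0] + 1
    >>> weights, bias = linear_regression(X, y)
    >>> weights.round(), bias.round()
    (tensor([[2.]]), tensor(1.))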
"""
|
|
# expand X with a column of ones to find bias along with weights
|
|
ones = torch.ones(X.shape[0], 1)
|
|
X = torch.cat((ones, X), dim=1)
|
|
# find parameters by multiplying pseudoinverse of X with y
|
|
weights = torch.linalg.pinv(X) @ y
|
|
# bias is first column of parameters and weights are the remaining columns
|
|
bias = weights[0]
|
|
weights = torch.unsqueeze(weights[1:], dim=0)
|
|
# return weights and biases
|
|
return weights, bias
|
|
|
|
|
|
def piecewise_linear_regression(
    X: torch.Tensor, y: torch.Tensor, num_pieces: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fit multiple linear regressions over different sections of input data.

    Parameters
    ----------
    X : torch.Tensor
        Input data to the regression.
    y : torch.Tensor
        Target values of the input data.
    num_pieces : int
        Number of segments to divide the input data.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Weights and bias of the piecewise linear regression.
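
    Example
    -------
    A minimal sketch splitting toy data into two segments:

    >>> X = torch.arange(8, dtype=torch.float).unsqueeze(1)
    >>> y = torch.cat((torch.zeros(4), torch.ones(4)))
    >>> weights, biases = piecewise_linear_regression(X, y, 2)
    >>> weights.shape, biases.shape
    (torch.Size([2, 1]), torch.Size([2]))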
"""
|
|
# order data according to X dimensions values
|
|
ordered_idx = torch.argsort(X, dim=0)[:, 0]
|
|
X = X[ordered_idx]
|
|
y = y[ordered_idx]
|
|
# find the size of a segment
|
|
piece_length = len(y) // num_pieces
|
|
all_weights = []
|
|
all_biases = []
|
|
|
|
# iterate over every segment
|
|
for piece in range(num_pieces):
|
|
# get data belonging to segment
|
|
piece_idx = range(piece_length * piece, piece_length * (piece + 1))
|
|
partial_X = X[piece_idx]
|
|
partial_y = y[piece_idx]
|
|
# fit linear regression over segment to obtain partial weights and biases
|
|
weights, bias = linear_regression(partial_X, partial_y)
|
|
all_weights.append(weights)
|
|
all_biases.append(bias)
|
|
# merge all weights and biases into individual tensors
|
|
all_weights = torch.cat(all_weights, dim=0)
|
|
all_biases = torch.tensor(all_biases)
|
|
# return results
|
|
return all_weights, all_biases
|
|
|
|
|
|
def init_weights_regression(
    model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor
) -> torch.nn.Module:
    """Initialize a regression neural network.

    Parameters
    ----------
    model : torch.nn.Module
        Neural network model to be initialized.
    X : torch.Tensor
        Input data to the regression.
    y : torch.Tensor
        Target values of the input data.

    Returns
    -------
    torch.nn.Module
        Initialized neural network model.

    Raises
    ------
    ValueError
        If X and y do not have the same number of samples.

    Example
    -------
    >>> import torch
    >>> from torch import nn
    >>> num_data = 20
    >>> num_features = 10
    >>> hidden_size = 5
    >>> X = torch.rand(num_data, num_features)
    >>> y = torch.rand(num_data)
    >>> model = nn.Sequential(
    ...     nn.Linear(num_features, hidden_size),
    ...     nn.ReLU(),
    ...     nn.Linear(hidden_size, 1)
    ... )
    >>> model = init_weights_regression(model, X, y)
    """
    # Throw error when X and y have different numbers of samples
    if X.shape[0] != y.shape[0]:
        raise ValueError("X and y must have the same number of samples")

    # Model is a single linear layer
    if isinstance(model, torch.nn.Linear):
        # fit a single linear regression and set the model's parameters
        weights, bias = linear_regression(X, y)
        model.weight = torch.nn.Parameter(weights)
        model.bias = torch.nn.Parameter(bias)
        return model

    # Model is a sequential model with at least one linear layer (output)
    num_layers = len(model) - 1
    for layer in range(num_layers):
        if isinstance(model[layer], torch.nn.Linear):
            # layer is a linear layer
            layer_weights, layer_bias = piecewise_linear_regression(
                X, y, num_pieces=model[layer].out_features
            )
            model[layer].weight = torch.nn.Parameter(layer_weights)
            model[layer].bias = torch.nn.Parameter(layer_bias)
        # propagate X regardless of the layer type
        X = model[layer](X)

    # Last layer (linear output layer)
    layer_weights, layer_bias = linear_regression(X, y)
    model[num_layers].weight = torch.nn.Parameter(layer_weights)
    model[num_layers].bias = torch.nn.Parameter(layer_bias)

    return model