🎉 First commit
This commit is contained in:
commit
aee862f75d
20 changed files with 10389 additions and 0 deletions
160
.gitignore
vendored
Normal file
160
.gitignore
vendored
Normal file
|
@ -0,0 +1,160 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
9
CITATION.cff
Normal file
9
CITATION.cff
Normal file
|
@ -0,0 +1,9 @@
|
|||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- family-names: "Rojas"
|
||||
given-names: "Nicolas"
|
||||
title: "Data driven initialization for neural network models"
|
||||
date-released: 2024-05-15
|
||||
license: ISC
|
||||
url: "https://github.com/nirogu/data-driven-init"
|
5
LICENSE
Normal file
5
LICENSE
Normal file
|
@ -0,0 +1,5 @@
|
|||
Copyright (c) 2024 Nicolas Rojas
|
||||
|
||||
Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
13
README.md
Normal file
13
README.md
Normal file
|
@ -0,0 +1,13 @@
|
|||
# Data driven initialization for neural network models
|
||||
|
||||
This repository contains the reference implementation of the IDEAL method to initialize the parameters of a neural network, using the training data to find adequate initial values for the model's weights and biases.
|
||||
|
||||
## Repository structure
|
||||
|
||||
The initialization methods are implemented in `ideal_init/initialization.py` and can be used with PyTorch models. An example of how to use them with PyTorch Lightning modules can be found in `ideal_init/lightning_model.py`. These modules are used in all the examples in the remaining folders: `tabular_classification`, `tabular_regression`, `image_classification` and `sequence_classification`. These examples are Jupyter notebooks that can be directly run if PyTorch, scikit-learn and PyTorch Lightning are already installed.
|
||||
|
||||
## Method performance
|
||||
|
||||
Although each notebook in the examples can be run independently from the others, we chose to visualize all the results in a single figure, using the minimalistic code in `plot_training.ipynb`. The results are presented in the following image, where IDEAL is shown in blue, the Kaiming He method in orange, the solid lines represent the mean of 10 similar experiments, and the shadows around the lines represent a 95% confidence interval.
|
||||
|
||||

|
932
ideal_init/initialization.py
Normal file
932
ideal_init/initialization.py
Normal file
|
@ -0,0 +1,932 @@
|
|||
"""
|
||||
Data Driven Initialization for Neural Network Models.
|
||||
|
||||
Author
|
||||
------
|
||||
Nicolas Rojas
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def linear_classification_weights(X: torch.Tensor, method: str) -> torch.Tensor:
|
||||
"""Initialize the weights of a linear layer, using the input data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the linear classification.
|
||||
method : str
|
||||
Method to be used for initialization. Can be either "mean" or "median".
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Weights of the linear layer.
|
||||
"""
|
||||
# sometimes, there are no elements for the given class. Return 0 if that is the case
|
||||
if len(X) == 0:
|
||||
return torch.zeros(1, X.shape[1])
|
||||
# weights can be mean or median of input data
|
||||
if method == "mean":
|
||||
weights = torch.mean(X, dim=0, keepdim=True)
|
||||
elif method == "median":
|
||||
weights = torch.median(X, dim=0, keepdim=True)[0]
|
||||
# normalize weights
|
||||
weights /= torch.norm(weights)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def separe_class_projections(
|
||||
p0: torch.Tensor, p1: torch.Tensor, method: str
|
||||
) -> torch.Tensor:
|
||||
"""Find the value that best separates the projections of two different classes.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
p0 : torch.Tensor
|
||||
Projections of the first class.
|
||||
p1 : torch.Tensor
|
||||
Projections of the second class.
|
||||
method : str
|
||||
Method to be used for separation. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Value that best separates the projections of the two classes.
|
||||
"""
|
||||
# if classes dont overlap, return middle point of dividing space
|
||||
if torch.all(p1 >= p0.max()):
|
||||
separation = (p1.min() + p0.max()) / 2
|
||||
else:
|
||||
# case where classes overlap
|
||||
p0_overlap = p0 > p1.min()
|
||||
p1_overlap = p1 < p0.max()
|
||||
|
||||
if method == "mean":
|
||||
# separation value is the weighted mean of the overlapping points
|
||||
n0 = p0_overlap.sum()
|
||||
n1 = p1_overlap.sum()
|
||||
sum0 = p0[p0_overlap].sum()
|
||||
sum1 = p1[p1_overlap].sum()
|
||||
separation = (n0 / n1 * sum1 + n1 / n0 * sum0) / (n1 + n0)
|
||||
|
||||
elif method == "quadratic":
|
||||
# separation value is mid point of quadratic function
|
||||
# (points histogram represented as a parabola)
|
||||
q = torch.tensor([0, 0.05, 0.50, 0.95, 1])
|
||||
lq = len(q)
|
||||
pp0_overlap = p0[p0_overlap]
|
||||
pp1_overlap = p1[p1_overlap]
|
||||
pp1_vals = torch.quantile(pp1_overlap, q, dim=0, keepdim=True).reshape(
|
||||
-1, 1
|
||||
)
|
||||
pp0_vals = torch.quantile(pp0_overlap, 1 - q, dim=0, keepdim=True).reshape(
|
||||
-1, 1
|
||||
)
|
||||
A1, A0 = torch.ones(lq, 3), torch.ones(lq, 3)
|
||||
A1[0:, 0] = ((q / 100)) ** 2
|
||||
A1[0:, 1] = q / 100
|
||||
A0[0:, 0] = (1 - (q / 100)) ** 2
|
||||
A0[0:, 1] = 1 - (q / 100)
|
||||
coeff0 = torch.linalg.pinv(A0.T @ A0) @ (A0.T @ pp0_vals).squeeze()
|
||||
coeff1 = torch.linalg.pinv(A1.T @ A1) @ (A1.T @ pp1_vals).squeeze()
|
||||
a0, b0, c0 = coeff0[0], coeff0[1], coeff0[2]
|
||||
a1, b1, c1 = coeff1[0], coeff1[1], coeff1[2]
|
||||
a = a1 - a0
|
||||
b = b1 + 2 * a0 + b0
|
||||
c = c1 - a0 - c0
|
||||
i1 = (-b + (b**2 - 4 * a * c) ** 0.5) / (2 * a)
|
||||
i2 = (-b - (b**2 - 4 * a * c) ** 0.5) / (2 * a)
|
||||
separation = max(i1, i2)
|
||||
|
||||
return separation
|
||||
|
||||
|
||||
def linear_classification_bias(
|
||||
X0: torch.Tensor, X1: torch.Tensor, weights: torch.Tensor, method: str
|
||||
) -> torch.Tensor:
|
||||
"""Find the bias of a feed-forward classification layer, given its weights.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X0 : torch.Tensor
|
||||
Input data of the first class.
|
||||
X1 : torch.Tensor
|
||||
Input data of the second class.
|
||||
weights : torch.Tensor
|
||||
Weights of the linear layer.
|
||||
method : str
|
||||
Method to be used for bias calculation. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Bias of the linear layer.
|
||||
"""
|
||||
# sometimes, there are no elements for the given class. Return 0 if that is the case
|
||||
if len(X0) == 0 or len(X1) == 0:
|
||||
return torch.tensor(0)
|
||||
|
||||
# project observations over weights to get 1D vectors
|
||||
p0 = (X0 @ weights.T).squeeze()
|
||||
p1 = (X1 @ weights.T).squeeze()
|
||||
|
||||
# find bias according to class projections
|
||||
bias = separe_class_projections(p0, p1, method)
|
||||
return bias
|
||||
|
||||
|
||||
def linear_classification_output_layer(
|
||||
X: torch.Tensor, y: torch.Tensor, weights_method: str, bias_method: str
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the output (linear) layer of a classification neural network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the classification.
|
||||
y : torch.Tensor
|
||||
Labels of the input data.
|
||||
weights_method : str
|
||||
Method to be used for weights initialization. Can be either "mean" or "median".
|
||||
bias_method : str
|
||||
Method to be used for bias initialization. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
Weights and bias of the output layer.
|
||||
"""
|
||||
all_classes = torch.unique(y)
|
||||
all_weights = []
|
||||
all_biases = []
|
||||
binary = len(all_classes) == 2
|
||||
|
||||
# when there are only 2 classes, it is only necessary to consider class 1
|
||||
if binary:
|
||||
all_classes = all_classes[1:]
|
||||
|
||||
for yi in all_classes:
|
||||
# for each class, initialize weights and biases
|
||||
X0 = X[y != yi]
|
||||
X1 = X[y == yi]
|
||||
|
||||
weights = linear_classification_weights(X1, weights_method)
|
||||
bias = linear_classification_bias(X0, X1, weights, bias_method)
|
||||
|
||||
all_weights.append(weights)
|
||||
all_biases.append(bias)
|
||||
|
||||
# transform lists of tensors to tensors
|
||||
if binary:
|
||||
all_weights = all_weights[0]
|
||||
else:
|
||||
all_weights = torch.cat(all_weights)
|
||||
all_biases = torch.tensor(all_biases)
|
||||
|
||||
return all_weights, all_biases
|
||||
|
||||
|
||||
def select_support_vectors(
|
||||
X: torch.Tensor, y: torch.Tensor, distances, num_neurons: int, num_classes: int
|
||||
) -> tuple[torch.Tensor, list[int]]:
|
||||
"""
|
||||
Find the support vectors for a set of vectors.
|
||||
|
||||
Support vectors are defined as the elements of a certain class
|
||||
that are closer to elements from a different class.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the classification.
|
||||
y : torch.Tensor
|
||||
Labels of the input data.
|
||||
distances : torch.Tensor
|
||||
Pairwise distances between input data.
|
||||
num_neurons : int
|
||||
Number of neurons to be used in the layer.
|
||||
num_classes : int
|
||||
Number of classes in the input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[torch.Tensor, list[int]]
|
||||
Support vectors and their corresponding classes.
|
||||
"""
|
||||
# get how many neurons should belong to each class
|
||||
quotient, remainder = divmod(num_neurons, num_classes)
|
||||
neurons_classes = [quotient] * num_classes
|
||||
for idx in range(remainder):
|
||||
neurons_classes[idx] += 1
|
||||
|
||||
vectors = []
|
||||
classes = []
|
||||
# iterate over each class with its number of corresponding neurons
|
||||
for label, num_vectors in enumerate(neurons_classes):
|
||||
# get elements belonging to desired class
|
||||
label_indices = y == label
|
||||
X1 = X[label_indices]
|
||||
# obtain distances between elements belonging to class and elements
|
||||
label_outside_distances = distances[label_indices, :][:, ~label_indices]
|
||||
# get mean distance for each element in class and order from lesser to greater
|
||||
label_outside_distances = label_outside_distances.mean(dim=1)
|
||||
min_elements = label_outside_distances.argsort()[:num_vectors]
|
||||
# vectors closer to elements from other classes are support vectors
|
||||
vectors.append(X1[min_elements])
|
||||
classes.extend([label] * num_vectors)
|
||||
vectors = torch.cat(vectors)
|
||||
# return vectors with their corresponding classes
|
||||
return vectors, classes
|
||||
|
||||
|
||||
def linear_classification_hidden_layer(
|
||||
X: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
num_neurons: int,
|
||||
num_classes: int,
|
||||
weights_method: str,
|
||||
bias_method: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize a hidden linear layer of a classification neural network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the classification.
|
||||
y : torch.Tensor
|
||||
Labels of the input data.
|
||||
num_neurons : int
|
||||
Number of neurons to be used in the layer.
|
||||
num_classes : int
|
||||
Number of classes in the input data.
|
||||
weights_method : str
|
||||
Method to be used for weights initialization. Can be either "mean" or "median".
|
||||
bias_method : str
|
||||
Method to be used for bias initialization. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
Weights and bias of the hidden layer.
|
||||
"""
|
||||
# get pairwise distances of input observations
|
||||
distances = torch.cdist(X, X)
|
||||
# get support vectors for the layer
|
||||
characteristic_vectors, cv_classes = select_support_vectors(
|
||||
X, y, distances, num_neurons, num_classes
|
||||
)
|
||||
|
||||
# get distances between support vectors and every element
|
||||
distances = torch.cdist(characteristic_vectors, X)
|
||||
# neighborhoods are Voronoi regions of support vectors
|
||||
neighborhoods = distances.argmin(dim=0)
|
||||
|
||||
layer_weights = []
|
||||
layer_bias = []
|
||||
for neighborhood in range(num_neurons):
|
||||
# get points belonging to neighborhood
|
||||
close_points = neighborhoods == neighborhood
|
||||
close_X = X[close_points]
|
||||
close_y = y[close_points]
|
||||
|
||||
# get elements belonging to same class as support vector
|
||||
k = cv_classes[neighborhood]
|
||||
X0 = close_X[close_y != k]
|
||||
X1 = close_X[close_y == k]
|
||||
|
||||
# get weights and biases of layer using elements in same class as support vector
|
||||
weights = linear_classification_weights(
|
||||
X1 - X1.mean(dim=1, keepdim=True), weights_method
|
||||
)
|
||||
bias = linear_classification_bias(X0, X1, weights, bias_method)
|
||||
|
||||
layer_weights.append(weights)
|
||||
layer_bias.append(bias)
|
||||
|
||||
layer_weights = torch.cat(layer_weights)
|
||||
layer_bias = torch.tensor(layer_bias)
|
||||
# return weights and biases of each layer as single tensor
|
||||
return layer_weights, layer_bias
|
||||
|
||||
|
||||
def rnn_bias(
|
||||
X0: torch.Tensor,
|
||||
X1: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
h0: torch.Tensor,
|
||||
h1: torch.Tensor,
|
||||
h_weights: torch.Tensor,
|
||||
method: str,
|
||||
) -> torch.Tensor:
|
||||
"""Find the bias of a recurrent classification layer, given all of its weights.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X0 : torch.Tensor
|
||||
Input data of the first class.
|
||||
X1 : torch.Tensor
|
||||
Input data of the second class.
|
||||
weights : torch.Tensor
|
||||
Weights of the linear layer.
|
||||
h0 : torch.Tensor
|
||||
Hidden state of the first class.
|
||||
h1 : torch.Tensor
|
||||
Hidden state of the second class.
|
||||
h_weights : torch.Tensor
|
||||
Weights of the hidden state.
|
||||
method : str
|
||||
Method to be used for bias calculation. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Bias of the recurrent layer.
|
||||
"""
|
||||
# sometimes, there are no elements for the given class. Return 0 if that is the case
|
||||
if len(X0) == 0 or len(X1) == 0:
|
||||
return torch.tensor(0)
|
||||
|
||||
# project observations over weights to get 1D vectors
|
||||
p0 = (X0 @ weights.T + h0 @ h_weights.T).squeeze()
|
||||
p1 = (X1 @ weights.T + h1 @ h_weights.T).squeeze()
|
||||
|
||||
# find bias according to class projections
|
||||
bias = separe_class_projections(p0, p1, method)
|
||||
return bias
|
||||
|
||||
|
||||
def rnn_hidden_layer(
|
||||
X: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
num_neurons: int,
|
||||
num_classes: int,
|
||||
weights_method: str,
|
||||
bias_method: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Initialize a recurrent layer of a classification neural network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the classification.
|
||||
h : torch.Tensor
|
||||
Hidden state of the input data.
|
||||
y : torch.Tensor
|
||||
Labels of the input data.
|
||||
num_neurons : int
|
||||
Number of neurons to be used in the layer.
|
||||
num_classes : int
|
||||
Number of classes in the input data.
|
||||
weights_method : str
|
||||
Method to be used for weights initialization. Can be either "mean" or "median".
|
||||
bias_method : str
|
||||
Method to be used for bias initialization. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
Weights, hidden state weights, and bias of the recurrent layer.
|
||||
"""
|
||||
# get pairwise distances of input observations
|
||||
distances = torch.cdist(X, X)
|
||||
# get support vectors for the layer
|
||||
characteristic_vectors, cv_classes = select_support_vectors(
|
||||
X, y, distances, num_neurons, num_classes
|
||||
)
|
||||
|
||||
# get distances between support vectors and every element
|
||||
distances = torch.cdist(characteristic_vectors, X)
|
||||
# neighborhoods are Voronoi regions of support vectors
|
||||
neighborhoods = distances.argmin(dim=0)
|
||||
|
||||
layer_weights = []
|
||||
layer_h_weights = []
|
||||
layer_bias = []
|
||||
for neighborhood in range(num_neurons):
|
||||
# get points from X, y, and h belonging to neighborhood
|
||||
close_points = neighborhoods == neighborhood
|
||||
close_X = X[close_points]
|
||||
close_h = h[close_points]
|
||||
close_y = y[close_points]
|
||||
|
||||
# get elements belonging to same class as support vector
|
||||
k = cv_classes[neighborhood]
|
||||
X0 = close_X[close_y != k]
|
||||
X1 = close_X[close_y == k]
|
||||
h0 = close_h[close_y != k]
|
||||
h1 = close_h[close_y == k]
|
||||
|
||||
# get weights and biases of layer using elements in same class as support vector
|
||||
weights = linear_classification_weights(
|
||||
X1 - X1.mean(dim=1, keepdim=True), weights_method
|
||||
)
|
||||
h_weights = linear_classification_weights(
|
||||
h1 - h1.mean(dim=1, keepdim=True), weights_method
|
||||
)
|
||||
bias = rnn_bias(X0, X1, weights, h0, h1, h_weights, bias_method)
|
||||
|
||||
layer_weights.append(weights)
|
||||
layer_h_weights.append(h_weights)
|
||||
layer_bias.append(bias)
|
||||
|
||||
layer_weights = torch.cat(layer_weights)
|
||||
layer_h_weights = torch.cat(layer_h_weights)
|
||||
layer_bias = torch.tensor(layer_bias)
|
||||
# return weights and biases of each layer as single tensor
|
||||
return layer_weights, layer_h_weights, layer_bias
|
||||
|
||||
|
||||
def conv_bias(
|
||||
X: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
kernel_classes: list[int],
|
||||
method: str,
|
||||
) -> torch.Tensor:
|
||||
"""Find the bias of a convolutional classification layer, given its kernel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data of the classification.
|
||||
y : torch.Tensor
|
||||
Labels of the input data.
|
||||
kernel : torch.Tensor
|
||||
Kernel of the convolutional layer.
|
||||
kernel_classes : list[int]
|
||||
Classes of the kernel.
|
||||
method : str
|
||||
Method to be used for bias calculation. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Bias of the convolutional layer.
|
||||
"""
|
||||
# relevant values are kernel dimensions
|
||||
out_dims, in_dims, rows, columns = kernel.shape
|
||||
layer_bias = []
|
||||
# find the bias of each individual kernel layer
|
||||
for iteration in range(out_dims):
|
||||
# find the points corresponding to the same class as the kernel layer
|
||||
iteration_class = kernel_classes[iteration]
|
||||
X0 = X[y == iteration_class]
|
||||
X1 = X[y != iteration_class]
|
||||
iteration_kernel = kernel[iteration : iteration + 1]
|
||||
|
||||
# get convolution of observations and kernel layer to get 1D vectors
|
||||
p0 = torch.nn.functional.conv2d(X0, iteration_kernel).flatten()
|
||||
p1 = torch.nn.functional.conv2d(X1, iteration_kernel).flatten()
|
||||
|
||||
# find bias according to class projections
|
||||
bias = separe_class_projections(p0, p1, method)
|
||||
layer_bias.append(bias)
|
||||
|
||||
# return inverse of kernel to center class division
|
||||
layer_bias = -torch.tensor(layer_bias)
|
||||
return layer_bias
|
||||
|
||||
|
||||
def init_weights_conv(
|
||||
X: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
out_channels: int,
|
||||
kernel_row: int,
|
||||
kernel_col: int,
|
||||
num_classes: int,
|
||||
bias_method: str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize a convolutional layer of a classification neural network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the classification.
|
||||
y : torch.Tensor
|
||||
Labels of the input data.
|
||||
out_channels : int
|
||||
Number of output channels of the convolutional layer.
|
||||
kernel_row : int
|
||||
Number of rows of the kernel.
|
||||
kernel_col : int
|
||||
Number of columns of the kernel.
|
||||
num_classes : int
|
||||
Number of classes in the input data.
|
||||
bias_method : str
|
||||
Method to be used for bias initialization. Can be either "mean" or "quadratic".
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
Weights and bias of the convolutional layer.
|
||||
"""
|
||||
# get dimensions of input data
|
||||
num_data, num_channels, rows, columns = X.shape
|
||||
fragments = []
|
||||
fragments_classes = []
|
||||
# get number of row and column partitions over data
|
||||
row_iters = rows // kernel_row
|
||||
col_iters = columns // kernel_col
|
||||
|
||||
for unique_class in range(num_classes):
|
||||
# get median element of each class to represent it
|
||||
X_class = X[y == unique_class]
|
||||
X_class = X.median(dim=0).values
|
||||
# iterate over groups of rows and columns in median object
|
||||
for row_idx in range(row_iters):
|
||||
for col_idx in range(col_iters):
|
||||
# flatten fragment to use as kernel candidate
|
||||
fragment = X_class[
|
||||
:,
|
||||
kernel_row * row_idx : kernel_row * (row_idx + 1),
|
||||
kernel_col * col_idx : kernel_col * (col_idx + 1),
|
||||
]
|
||||
fragment = fragment.flatten()
|
||||
if fragment.norm() <= 0.01:
|
||||
continue
|
||||
fragments.append(fragment)
|
||||
fragments_classes.append(unique_class)
|
||||
|
||||
fragments = torch.stack(fragments)
|
||||
# normalize fragments
|
||||
fragments -= fragments.mean(1, keepdim=True)
|
||||
# if there are not enough fragments, use existing fragments plus a
|
||||
# random noise as extra fragments until fragment number is enough
|
||||
# (at least equal to the number of output channels)
|
||||
while fragments.shape[0] < out_channels:
|
||||
difference = out_channels - fragments.shape[0]
|
||||
difference = fragments[: min(difference, fragments.shape[0])]
|
||||
difference += torch.normal(0, 0.1, size=difference.shape)
|
||||
fragments = torch.cat((fragments, difference), 0)
|
||||
|
||||
# get pairwise correlations of kernel candidates
|
||||
correlations = torch.zeros(len(fragments), len(fragments))
|
||||
for idx1 in range(len(fragments)):
|
||||
for idx2 in range(idx1, len(fragments)):
|
||||
correlations[idx1, idx2] = abs(
|
||||
torch.nn.functional.cosine_similarity(
|
||||
fragments[idx1], fragments[idx2], dim=0
|
||||
)
|
||||
)
|
||||
correlations[idx2, idx1] = correlations[idx1, idx2]
|
||||
|
||||
fragments_classes = torch.tensor(fragments_classes)
|
||||
# find optimal kernels from kernel candidates, using support vectors method
|
||||
characteristic_vectors, kernel_classes = select_support_vectors(
|
||||
fragments, fragments_classes, correlations, out_channels, num_classes
|
||||
)
|
||||
current_num_weights = characteristic_vectors.shape[0]
|
||||
# un-flatten selected kernels
|
||||
characteristic_vectors = characteristic_vectors.reshape(
|
||||
(current_num_weights, num_channels, kernel_row, kernel_col)
|
||||
)
|
||||
# normalize selected kernels
|
||||
for weight in range(current_num_weights):
|
||||
for channel in range(num_channels):
|
||||
characteristic_vectors[weight, channel, :, :] /= torch.linalg.matrix_norm(
|
||||
characteristic_vectors[weight, channel, :, :]
|
||||
)
|
||||
|
||||
# if there are not enough kernels, use existing kernels plus a
|
||||
# random noise as extra kernels until kernel number is enough
|
||||
# (at least equal to the number of output channels)
|
||||
while current_num_weights < out_channels:
|
||||
difference = out_channels - current_num_weights
|
||||
difference = characteristic_vectors[: min(difference, current_num_weights)]
|
||||
difference += torch.normal(0, 0.1, size=difference.shape)
|
||||
characteristic_vectors = torch.cat((characteristic_vectors, difference), 0)
|
||||
current_num_weights = characteristic_vectors.shape[0]
|
||||
|
||||
# find layer biases using selected kernels
|
||||
layer_bias = conv_bias(X, y, characteristic_vectors, kernel_classes, bias_method)
|
||||
return characteristic_vectors, layer_bias
|
||||
|
||||
|
||||
def init_weights_classification(
|
||||
model: torch.nn.Module,
|
||||
X: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
weights_method: str = "mean",
|
||||
bias_method: str = "mean",
|
||||
) -> torch.nn.Module:
|
||||
"""Initialize a classification neural network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : torch.nn.Module
|
||||
Neural network model to be initialized.
|
||||
X : torch.Tensor
|
||||
Input data to the classification.
|
||||
y : torch.Tensor
|
||||
Labels of the input data.
|
||||
weights_method : str, optional
|
||||
Method to be used for weights initialization. Can be either "mean" or "median".
|
||||
Default is "mean".
|
||||
bias_method : str, optional
|
||||
Method to be used for bias initialization. Can be either "mean" or "quadratic".
|
||||
Default is "mean".
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.nn.Module
|
||||
Initialized neural network model.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If X and y have different number of samples.
|
||||
If input data has less than 2 classes.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import torch
|
||||
>>> from torch import nn
|
||||
>>> num_classes = 3
|
||||
>>> num_data = 20
|
||||
>>> num_features = 10
|
||||
>>> hidden_size = 5
|
||||
>>> X = torch.rand(num_data, num_features)
|
||||
>>> y = torch.randint(0, num_classes, (num_data,))
|
||||
>>> model = nn.Sequential(
|
||||
>>> nn.Linear(num_features, hidden_size),
|
||||
>>> nn.ReLU(),
|
||||
>>> nn.Linear(hidden_size, num_classes)
|
||||
>>> )
|
||||
>>> model = init_weights_classification(model, X, y)
|
||||
"""
|
||||
# Throw error when X and y have different number of samples
|
||||
if X.shape[0] != y.shape[0]:
|
||||
raise ValueError("X and y must have the same number of samples")
|
||||
|
||||
num_classes = len(torch.unique(y))
|
||||
# There must be at least two classes. Else, throw error
|
||||
if num_classes < 2:
|
||||
raise ValueError("Input data must present at least 2 classes")
|
||||
|
||||
# Model is a single linear layer
|
||||
if isinstance(model, torch.nn.Linear):
|
||||
# find weights and biases of single layer
|
||||
weights, bias = linear_classification_output_layer(
|
||||
X, y, weights_method, bias_method
|
||||
)
|
||||
model.weight = torch.nn.Parameter(weights)
|
||||
model.bias = torch.nn.Parameter(bias)
|
||||
return model
|
||||
|
||||
# Model is a sequential model with at least one linear layer (output)
|
||||
num_layers = len(model) - 1
|
||||
|
||||
for layer in range(num_layers):
|
||||
if isinstance(model[layer], torch.nn.Linear):
|
||||
# layer is linear layer
|
||||
num_neurons = model[layer].out_features
|
||||
layer_weights, layer_bias = linear_classification_hidden_layer(
|
||||
X, y, num_neurons, num_classes, weights_method, bias_method
|
||||
)
|
||||
model[layer].weight = torch.nn.Parameter(layer_weights)
|
||||
model[layer].bias = torch.nn.Parameter(layer_bias)
|
||||
elif isinstance(model[layer], torch.nn.Conv2d):
|
||||
# layer is 2D convolutional layer
|
||||
kernel_row, kernel_col = model[layer].kernel_size
|
||||
out_channels = model[layer].out_channels
|
||||
layer_weights, layer_bias = init_weights_conv(
|
||||
X, y, out_channels, kernel_row, kernel_col, num_classes, bias_method
|
||||
)
|
||||
model[layer].weight = torch.nn.Parameter(layer_weights)
|
||||
model[layer].bias = torch.nn.Parameter(layer_bias)
|
||||
elif isinstance(model[layer], torch.nn.RNN):
|
||||
# layer is (stack of) recurrent layer. This kind of
|
||||
# layer requires an special treatment due to the way
|
||||
# it is represented in pytorch: not as a group of
|
||||
# layers but as a single object with multiple weights
|
||||
# and biases
|
||||
num_rnn_layers = model[layer].num_layers
|
||||
num_neurons = model[layer].hidden_size
|
||||
activation = (
|
||||
torch.nn.functional.tanh
|
||||
if model[layer].nonlinearity == "tanh"
|
||||
else torch.nn.functional.relu
|
||||
)
|
||||
# get last element in sequence
|
||||
layer_X = X[:, -1, :].detach().clone()
|
||||
for layer_idx in range(num_rnn_layers):
|
||||
# initialize the x-weights of each recurrent layer in stack
|
||||
layer_weights, layer_bias = linear_classification_hidden_layer(
|
||||
layer_X, y, num_neurons, num_classes, weights_method, bias_method
|
||||
)
|
||||
setattr(
|
||||
model[layer],
|
||||
f"weight_ih_l{layer_idx}",
|
||||
torch.nn.Parameter(layer_weights),
|
||||
)
|
||||
setattr(
|
||||
model[layer],
|
||||
f"bias_ih_l{layer_idx}",
|
||||
torch.nn.Parameter(layer_bias),
|
||||
)
|
||||
# propagate layer_x through recurrent stack
|
||||
layer_X = activation(layer_X @ layer_weights.T + layer_bias)
|
||||
# obtain final state of h for each recurrent layer in stack to use as h0
|
||||
_, h0 = model[layer](X)
|
||||
# get last element in sequence
|
||||
layer_X = X[:, -1, :].detach().clone()
|
||||
for layer_idx in range(num_rnn_layers):
|
||||
# initialize the x-weights and h-weights of each recurrent layer in stack
|
||||
layer_weights, layer_h_weights, layer_bias = rnn_hidden_layer(
|
||||
layer_X,
|
||||
h0[layer_idx],
|
||||
y,
|
||||
num_neurons,
|
||||
num_classes,
|
||||
weights_method,
|
||||
bias_method,
|
||||
)
|
||||
setattr(
|
||||
model[layer],
|
||||
f"weight_ih_l{layer_idx}",
|
||||
torch.nn.Parameter(layer_weights),
|
||||
)
|
||||
setattr(
|
||||
model[layer],
|
||||
f"bias_ih_l{layer_idx}",
|
||||
torch.nn.Parameter(layer_bias),
|
||||
)
|
||||
setattr(
|
||||
model[layer],
|
||||
f"weight_hh_l{layer_idx}",
|
||||
torch.nn.Parameter(layer_h_weights),
|
||||
)
|
||||
setattr(
|
||||
model[layer],
|
||||
f"bias_hh_l{layer_idx}",
|
||||
torch.nn.Parameter(torch.zeros_like(layer_bias)),
|
||||
)
|
||||
# propagate layer_x through recurrent stack
|
||||
layer_X = activation(
|
||||
layer_X @ layer_weights.T
|
||||
+ h0[layer_idx] @ layer_h_weights.T
|
||||
+ layer_bias
|
||||
)
|
||||
# propagate X no matter the layers type
|
||||
X = model[layer](X)
|
||||
|
||||
# Last layer (linear output layer)
|
||||
layer_weights, layer_bias = linear_classification_output_layer(
|
||||
X, y, weights_method, bias_method
|
||||
)
|
||||
model[num_layers].weight = torch.nn.Parameter(layer_weights)
|
||||
model[num_layers].bias = torch.nn.Parameter(layer_bias)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def linear_regression(
|
||||
X: torch.Tensor, y: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Fit a single linear regression over input data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the regression.
|
||||
y : torch.Tensor
|
||||
Target values of the input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
Weights and bias of the linear regression.
|
||||
"""
|
||||
# expand X with a column of ones to find bias along with weights
|
||||
ones = torch.ones(X.shape[0], 1)
|
||||
X = torch.cat((ones, X), dim=1)
|
||||
# find parameters by multiplying pseudoinverse of X with y
|
||||
weights = torch.linalg.pinv(X) @ y
|
||||
# bias is first column of parameters and weights are the remaining columns
|
||||
bias = weights[0]
|
||||
weights = torch.unsqueeze(weights[1:], dim=0)
|
||||
# return weights and biases
|
||||
return weights, bias
|
||||
|
||||
|
||||
def piecewise_linear_regression(
|
||||
X: torch.Tensor, y: torch.Tensor, num_pieces: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Fit multiple linear regressions over different sections of input data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : torch.Tensor
|
||||
Input data to the regression.
|
||||
y : torch.Tensor
|
||||
Target values of the input data.
|
||||
num_pieces : int
|
||||
Number of segments to divide the input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
Weights and bias of the piecewise linear regression.
|
||||
"""
|
||||
# order data according to X dimensions values
|
||||
ordered_idx = torch.argsort(X, dim=0)[:, 0]
|
||||
X = X[ordered_idx]
|
||||
y = y[ordered_idx]
|
||||
# find the size of a segment
|
||||
piece_length = len(y) // num_pieces
|
||||
all_weights = []
|
||||
all_biases = []
|
||||
|
||||
# iterate over every segment
|
||||
for piece in range(num_pieces):
|
||||
# get data belonging to segment
|
||||
piece_idx = range(piece_length * piece, piece_length * (piece + 1))
|
||||
partial_X = X[piece_idx]
|
||||
partial_y = y[piece_idx]
|
||||
# fit linear regression over segment to obtain partial weights and biases
|
||||
weights, bias = linear_regression(partial_X, partial_y)
|
||||
all_weights.append(weights)
|
||||
all_biases.append(bias)
|
||||
# merge all weights and biases into individual tensors
|
||||
all_weights = torch.cat(all_weights, dim=0)
|
||||
all_biases = torch.tensor(all_biases)
|
||||
# return results
|
||||
return all_weights, all_biases
|
||||
|
||||
|
||||
def init_weights_regression(
|
||||
model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor
|
||||
) -> torch.nn.Module:
|
||||
"""Initialize a regression neural network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : torch.nn.Module
|
||||
Neural network model to be initialized.
|
||||
X : torch.Tensor
|
||||
Input data to the regression.
|
||||
y : torch.Tensor
|
||||
Target values of the input data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.nn.Module
|
||||
Initialized neural network model.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If X and y have different number of samples.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import torch
|
||||
>>> from torch import nn
|
||||
>>> num_data = 20
|
||||
>>> num_features = 10
|
||||
>>> hidden_size = 5
|
||||
>>> X = torch.rand(num_data, num_features)
|
||||
>>> y = torch.rand(num_data)
|
||||
>>> model = nn.Sequential(
|
||||
>>> nn.Linear(num_features, hidden_size),
|
||||
>>> nn.ReLU(),
|
||||
>>> nn.Linear(hidden_size, 1)
|
||||
>>> )
|
||||
>>> model = init_weights_regression(model, X, y)
|
||||
"""
|
||||
# Throw error when X and y have different number of samples
|
||||
if X.shape[0] != y.shape[0]:
|
||||
raise ValueError("X and y must have the same number of samples")
|
||||
|
||||
# Model is a single linear layer
|
||||
if isinstance(model, torch.nn.Linear):
|
||||
# fit singular linear regression and set models parameters
|
||||
weights, bias = linear_regression(X, y)
|
||||
model.weight = torch.nn.Parameter(weights)
|
||||
model.bias = torch.nn.Parameter(bias)
|
||||
return model
|
||||
|
||||
# Model is a sequential model with at least one linear layer (output)
|
||||
num_layers = len(model) - 1
|
||||
for layer in range(num_layers):
|
||||
if isinstance(model[layer], torch.nn.Linear):
|
||||
# layer is linear layer
|
||||
layer_weights, layer_bias = piecewise_linear_regression(
|
||||
X, y, num_pieces=model[layer].out_features
|
||||
)
|
||||
model[layer].weight = torch.nn.Parameter(layer_weights)
|
||||
model[layer].bias = torch.nn.Parameter(layer_bias)
|
||||
# propagate X no matter the layers type
|
||||
X = model[layer](X)
|
||||
|
||||
# Last layer (linear output layer)
|
||||
layer_weights, layer_bias = linear_regression(X, y)
|
||||
model[num_layers].weight = torch.nn.Parameter(layer_weights)
|
||||
model[num_layers].bias = torch.nn.Parameter(layer_bias)
|
||||
|
||||
return model
|
459
ideal_init/lightning_model.py
Normal file
459
ideal_init/lightning_model.py
Normal file
|
@ -0,0 +1,459 @@
|
|||
"""
|
||||
Pytorch lightning models using IDEAL initialization.
|
||||
|
||||
Author
|
||||
------
|
||||
Nicolas Rojas
|
||||
"""
|
||||
from time import time
|
||||
from numpy import ndarray
|
||||
from pandas import read_csv, concat
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import linalg
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from torchmetrics import Accuracy, R2Score
|
||||
from lightning import LightningModule
|
||||
from .initialization import init_weights_regression, init_weights_classification
|
||||
|
||||
|
||||
class NNClassifier(LightningModule):
|
||||
def __init__(self, X_train: ndarray, y_train: ndarray, X_val: ndarray, y_val: ndarray, X_test: ndarray, y_test: ndarray, initialize: bool = False, hidden_sizes: tuple[int] = None, learning_rate: float = 1e-3, batch_size: int = 64, num_workers: int = 4):
|
||||
super().__init__()
|
||||
|
||||
# Set our init args as class attributes
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
# Normalize data
|
||||
scaler = StandardScaler()
|
||||
X_train = scaler.fit_transform(X_train)
|
||||
|
||||
# Transform numpy matrices to torch tensors
|
||||
X_train = torch.from_numpy(X_train).float()
|
||||
y_train = torch.from_numpy(y_train)
|
||||
|
||||
self.in_dims = X_train.shape[1]
|
||||
self.n_classes = len(torch.unique(y_train))
|
||||
binary = self.n_classes == 2
|
||||
# Define PyTorch model
|
||||
if binary:
|
||||
self.metric = Accuracy(task="binary")
|
||||
self.loss_fn = F.binary_cross_entropy_with_logits
|
||||
self.out_activation = nn.Sigmoid()
|
||||
out_shape = 1
|
||||
else:
|
||||
self.metric = Accuracy(task="multiclass", num_classes=self.n_classes)
|
||||
self.loss_fn = F.cross_entropy
|
||||
self.out_activation = nn.Softmax(dim=1)
|
||||
out_shape = self.n_classes
|
||||
|
||||
self.init_time = time()
|
||||
# Check if model is or is not multilayer
|
||||
if hidden_sizes is None:
|
||||
self.model = nn.Linear(self.in_dims, out_shape)
|
||||
else:
|
||||
last_shape = self.in_dims
|
||||
self.model = nn.Sequential()
|
||||
for hidden_size in hidden_sizes:
|
||||
self.model.append(nn.Linear(last_shape, hidden_size))
|
||||
self.model.append(nn.ReLU())
|
||||
last_shape = hidden_size
|
||||
self.model.append(nn.Linear(last_shape, out_shape))
|
||||
|
||||
# Initialize model weights if needed
|
||||
if initialize:
|
||||
self.model = init_weights_classification(self.model, X_train, y_train, weights_method="mean", bias_method="mean")
|
||||
self.init_time = time() - self.init_time
|
||||
|
||||
# Create datasets
|
||||
if binary:
|
||||
y_train = y_train.float().unsqueeze(dim=1)
|
||||
y_val = torch.from_numpy(y_val).float().unsqueeze(dim=1)
|
||||
y_test = torch.from_numpy(y_test).float().unsqueeze(dim=1)
|
||||
else:
|
||||
y_train = y_train.long()
|
||||
y_val = torch.from_numpy(y_val).long()
|
||||
y_test = torch.from_numpy(y_test).long()
|
||||
X_val = torch.from_numpy(scaler.transform(X_val)).float()
|
||||
X_test = torch.from_numpy(scaler.transform(X_test)).float()
|
||||
|
||||
self.train_data = TensorDataset(X_train, y_train)
|
||||
self.val_data = TensorDataset(X_val, y_val)
|
||||
self.test_data = TensorDataset(X_test, y_test)
|
||||
|
||||
def forward(self, x):
|
||||
logits = self.model(x)
|
||||
probas = self.out_activation(logits)
|
||||
return probas
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("train_loss", loss, prog_bar=False, on_step=True, on_epoch=False)
|
||||
self.log("train_metric", metric, prog_bar=False, on_step=True, on_epoch=False)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("val_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
|
||||
self.log("val_metric", metric, prog_bar=True, on_step=True, on_epoch=False)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("test_loss", loss, prog_bar=True)
|
||||
self.log("test_metric", metric, prog_bar=True)
|
||||
|
||||
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
|
||||
x, y = batch
|
||||
probas = self(x)
|
||||
return probas
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
|
||||
class NNRegressor(LightningModule):
|
||||
def __init__(self, X_train: ndarray, y_train: ndarray, X_val: ndarray, y_val: ndarray, X_test: ndarray, y_test: ndarray, initialize: bool = False, hidden_sizes: tuple[int] = None, learning_rate: float = 1e-3, batch_size: int = 64, num_workers: int = 4):
|
||||
super().__init__()
|
||||
|
||||
# Set our init args as class attributes
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
# Normalize data
|
||||
scaler = StandardScaler()
|
||||
X_train = scaler.fit_transform(X_train)
|
||||
|
||||
# Transform numpy matrices to torch tensors
|
||||
X_train = torch.from_numpy(X_train).float()
|
||||
y_train = torch.from_numpy(y_train).float()
|
||||
|
||||
self.in_dims = X_train.shape[1]
|
||||
|
||||
# Define PyTorch model
|
||||
self.init_time = time()
|
||||
if hidden_sizes is None:
|
||||
self.model = nn.Linear(self.in_dims, 1)
|
||||
else:
|
||||
last_shape = self.in_dims
|
||||
self.model = nn.Sequential()
|
||||
for hidden_size in hidden_sizes:
|
||||
self.model.append(nn.Linear(last_shape, hidden_size))
|
||||
self.model.append(nn.ReLU())
|
||||
last_shape = hidden_size
|
||||
self.model.append(nn.Linear(last_shape, 1))
|
||||
|
||||
self.metric = R2Score()
|
||||
self.loss_fn = F.mse_loss
|
||||
|
||||
# Initialize model weights if needed
|
||||
if initialize:
|
||||
self.model = init_weights_regression(self.model, X_train, y_train)
|
||||
self.init_time = time() - self.init_time
|
||||
|
||||
# Create datasets
|
||||
y_train = y_train.unsqueeze(dim=1)
|
||||
y_val = torch.from_numpy(y_val).float().unsqueeze(dim=1)
|
||||
y_test = torch.from_numpy(y_test).float().unsqueeze(dim=1)
|
||||
X_val = torch.from_numpy(scaler.transform(X_val)).float()
|
||||
X_test = torch.from_numpy(scaler.transform(X_test)).float()
|
||||
self.train_data = TensorDataset(X_train, y_train)
|
||||
self.val_data = TensorDataset(X_val, y_val)
|
||||
self.test_data = TensorDataset(X_test, y_test)
|
||||
|
||||
def forward(self, x):
|
||||
logits = self.model(x)
|
||||
return logits
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("train_loss", loss, prog_bar=False, on_step=True, on_epoch=False)
|
||||
self.log("train_metric", metric, prog_bar=False, on_step=True, on_epoch=False)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("val_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
|
||||
self.log("val_metric", metric, prog_bar=True, on_step=True, on_epoch=False)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("test_loss", loss, prog_bar=True)
|
||||
self.log("test_metric", metric, prog_bar=True)
|
||||
|
||||
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
|
||||
x, y = batch
|
||||
probas = self(x)
|
||||
return probas
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
|
||||
class CNNClassifier(LightningModule):
|
||||
def __init__(self, X_train: ndarray, y_train: ndarray, X_val: ndarray, y_val: ndarray, X_test: ndarray, y_test: ndarray, initialize: bool = False, learning_rate: float = 1e-3, batch_size: int = 64, num_workers: int = 4):
|
||||
super().__init__()
|
||||
|
||||
# Set our init args as class attributes
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
self.in_dims = X_train.shape[1]
|
||||
self.n_classes = len(torch.unique(y_train))
|
||||
binary = self.n_classes == 2
|
||||
# Define PyTorch model
|
||||
if binary:
|
||||
self.metric = Accuracy(task="binary")
|
||||
self.loss_fn = F.binary_cross_entropy_with_logits
|
||||
self.out_activation = nn.Sigmoid()
|
||||
out_shape = 1
|
||||
else:
|
||||
self.metric = Accuracy(task="multiclass", num_classes=self.n_classes)
|
||||
self.loss_fn = F.cross_entropy
|
||||
self.out_activation = nn.Softmax(dim=1)
|
||||
out_shape = self.n_classes
|
||||
|
||||
self.init_time = time()
|
||||
self.model = nn.Sequential(
|
||||
nn.Conv2d(1, 5, kernel_size=5),
|
||||
nn.ReLU(),
|
||||
nn.Flatten(),
|
||||
nn.Linear(2880, out_shape),
|
||||
)
|
||||
|
||||
X_train /= 255.0
|
||||
X_val /= 255.0
|
||||
X_test /= 255.0
|
||||
# Initialize model weights if needed
|
||||
if initialize:
|
||||
self.model = init_weights_classification(self.model, X_train, y_train, weights_method="mean", bias_method="mean")
|
||||
self.init_time = time() - self.init_time
|
||||
|
||||
# Create datasets
|
||||
self.train_data = TensorDataset(X_train, y_train)
|
||||
self.val_data = TensorDataset(X_val, y_val)
|
||||
self.test_data = TensorDataset(X_test, y_test)
|
||||
|
||||
def forward(self, x):
|
||||
logits = self.model(x)
|
||||
probas = self.out_activation(logits)
|
||||
return probas
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("train_loss", loss, prog_bar=False, on_step=True, on_epoch=False)
|
||||
self.log("train_metric", metric, prog_bar=False, on_step=True, on_epoch=False)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("val_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
|
||||
self.log("val_metric", metric, prog_bar=True, on_step=True, on_epoch=False)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("test_loss", loss, prog_bar=True)
|
||||
self.log("test_metric", metric, prog_bar=True)
|
||||
|
||||
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
|
||||
x, y = batch
|
||||
probas = self(x)
|
||||
return probas
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
|
||||
class extract_tensor(nn.Module):
|
||||
def forward(self,x):
|
||||
tensor, _ = x
|
||||
return tensor[:, -1, :]
|
||||
|
||||
|
||||
class RNNClassifier(LightningModule):
|
||||
def __init__(self, X_train: ndarray, y_train: ndarray, X_val: ndarray, y_val: ndarray, X_test: ndarray, y_test: ndarray, initialize: bool = False, learning_rate: float = 1e-3, batch_size: int = 64, num_workers: int = 4):
|
||||
super().__init__()
|
||||
|
||||
# Set our init args as class attributes
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
self.in_dims = X_train.shape[2]
|
||||
self.n_classes = len(torch.unique(y_train))
|
||||
binary = self.n_classes == 2
|
||||
# Define PyTorch model
|
||||
if binary:
|
||||
self.metric = Accuracy(task="binary")
|
||||
self.loss_fn = F.binary_cross_entropy_with_logits
|
||||
self.out_activation = nn.Sigmoid()
|
||||
out_shape = 1
|
||||
else:
|
||||
self.metric = Accuracy(task="multiclass", num_classes=self.n_classes)
|
||||
self.loss_fn = F.cross_entropy
|
||||
self.out_activation = nn.Softmax(dim=1)
|
||||
out_shape = self.n_classes
|
||||
|
||||
self.init_time = time()
|
||||
self.model = nn.Sequential(
|
||||
nn.RNN(self.in_dims, 256, num_layers=2, batch_first=True),
|
||||
extract_tensor(),
|
||||
nn.Linear(256, out_shape)
|
||||
)
|
||||
|
||||
# Initialize model weights if needed
|
||||
if initialize:
|
||||
self.model = init_weights_classification(self.model, X_train, y_train, weights_method="mean", bias_method="mean")
|
||||
self.init_time = time() - self.init_time
|
||||
|
||||
self.train_data = TensorDataset(X_train, y_train)
|
||||
self.val_data = TensorDataset(X_val, y_val)
|
||||
self.test_data = TensorDataset(X_test, y_test)
|
||||
|
||||
def forward(self, x):
|
||||
logits = self.model(x)
|
||||
probas = self.out_activation(logits)
|
||||
return probas
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("train_loss", loss, prog_bar=False, on_step=True, on_epoch=False)
|
||||
self.log("train_metric", metric, prog_bar=False, on_step=True, on_epoch=False)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("val_loss", loss, prog_bar=True, on_step=True, on_epoch=False)
|
||||
self.log("val_metric", metric, prog_bar=True, on_step=True, on_epoch=False)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self.model(x)
|
||||
loss = self.loss_fn(logits, y)
|
||||
metric = self.metric(logits, y)
|
||||
|
||||
self.log("test_loss", loss, prog_bar=True)
|
||||
self.log("test_metric", metric, prog_bar=True)
|
||||
|
||||
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
|
||||
x, y = batch
|
||||
probas = self(x)
|
||||
return probas
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def test_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
def predict_dataloader(self):
|
||||
return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)
|
||||
|
||||
|
||||
# Get logs from both models
|
||||
def merge_logs(init_model_logs: str, no_init_model_logs: str):
|
||||
init_logs = read_csv(init_model_logs, usecols=["step", "val_metric"]).dropna(axis=0)
|
||||
init_logs["method"] = "IDEAL"
|
||||
no_init_logs = read_csv(no_init_model_logs, usecols=["step", "val_metric"]).dropna(axis=0)
|
||||
no_init_logs["method"] = "He"
|
||||
full_logs = concat([init_logs, no_init_logs], ignore_index=True)
|
||||
return full_logs
|
731
image_classification/emnist.ipynb
Normal file
731
image_classification/emnist.ipynb
Normal file
|
@ -0,0 +1,731 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.utils import resample\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from torchvision import datasets\n",
|
||||
"from ideal_init.lightning_model import CNNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 2**13\n",
|
||||
"LEARNING_RATE = 1e-1\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 10\n",
|
||||
"DATASET_NAME = \"EMNIST\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" raw_data_train = datasets.EMNIST(\"data\", split=\"balanced\", train=True, download=True)\n",
|
||||
" X = raw_data_train.data.float().unsqueeze(1)\n",
|
||||
" y = raw_data_train.targets\n",
|
||||
" X, y = resample(X, y, replace=False, n_samples=100000, stratify=y)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)\n",
|
||||
"\n",
|
||||
" raw_data_test = datasets.EMNIST(\"data\", split=\"balanced\", train=False, download=True)\n",
|
||||
" X_test = raw_data_test.data.float().unsqueeze(1)\n",
|
||||
" y_test = raw_data_test.targets\n",
|
||||
" X_test, y_test = resample(X, y, replace=False, n_samples=10000, stratify=y_test)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
728
image_classification/fashion_mnist.ipynb
Normal file
728
image_classification/fashion_mnist.ipynb
Normal file
|
@ -0,0 +1,728 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from torchvision import datasets\n",
|
||||
"from ideal_init.lightning_model import CNNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 2**11\n",
|
||||
"LEARNING_RATE = 1e-2\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 20\n",
|
||||
"DATASET_NAME = \"Fashion MNIST\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" raw_data_train = datasets.FashionMNIST(\"data\", train=True, download=True)\n",
|
||||
" X = raw_data_train.data.float().unsqueeze(1)\n",
|
||||
" y = raw_data_train.targets\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)\n",
|
||||
"\n",
|
||||
" raw_data_test = datasets.FashionMNIST(\"data\", train=False, download=True)\n",
|
||||
" X_test = raw_data_test.data.float().unsqueeze(1)\n",
|
||||
" y_test = raw_data_test.targets\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
728
image_classification/kmnist.ipynb
Normal file
728
image_classification/kmnist.ipynb
Normal file
|
@ -0,0 +1,728 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from torchvision import datasets\n",
|
||||
"from ideal_init.lightning_model import CNNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 2**11\n",
|
||||
"LEARNING_RATE = 1e-2\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 20\n",
|
||||
"DATASET_NAME = \"KMNIST\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" raw_data_train = datasets.KMNIST(\"data\", train=True, download=True)\n",
|
||||
" X = raw_data_train.data.float().unsqueeze(1)\n",
|
||||
" y = raw_data_train.targets\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)\n",
|
||||
"\n",
|
||||
" raw_data_test = datasets.KMNIST(\"data\", train=False, download=True)\n",
|
||||
" X_test = raw_data_test.data.float().unsqueeze(1)\n",
|
||||
" y_test = raw_data_test.targets\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
728
image_classification/mnist.ipynb
Normal file
728
image_classification/mnist.ipynb
Normal file
|
@ -0,0 +1,728 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from torchvision import datasets\n",
|
||||
"from ideal_init.lightning_model import CNNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 2**11\n",
|
||||
"LEARNING_RATE = 1e-2\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 13\n",
|
||||
"DATASET_NAME = \"MNIST\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" raw_data_train = datasets.MNIST(\"data\", train=True, download=True)\n",
|
||||
" X = raw_data_train.data.float().unsqueeze(1)\n",
|
||||
" y = raw_data_train.targets\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)\n",
|
||||
"\n",
|
||||
" raw_data_test = datasets.MNIST(\"data\", train=False, download=True)\n",
|
||||
" X_test = raw_data_test.data.float().unsqueeze(1)\n",
|
||||
" y_test = raw_data_test.targets\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = CNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
50
plot_training.ipynb
Normal file
50
plot_training.ipynb
Normal file
File diff suppressed because one or more lines are too long
BIN
results.png
Normal file
BIN
results.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 597 KiB |
752
sequence_classification/20newsgroups.ipynb
Normal file
752
sequence_classification/20newsgroups.ipynb
Normal file
|
@ -0,0 +1,752 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.datasets import fetch_20newsgroups\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"import torch\n",
|
||||
"from torchtext.vocab import GloVe\n",
|
||||
"from torch.nn.utils.rnn import pad_sequence\n",
|
||||
"from torchtext.data.utils import get_tokenizer\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import RNNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 2**11\n",
|
||||
"LEARNING_RATE = 1e-2\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 10\n",
|
||||
"DATASET_NAME = \"20 newsgroups\"\n",
|
||||
"\n",
|
||||
"categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']\n",
|
||||
"tokenizer = get_tokenizer(\"spacy\")\n",
|
||||
"embedding_function = GloVe(name=\"42B\", dim=300)\n",
|
||||
"def text_embedding(words):\n",
|
||||
" embeddings = [embedding_function[word] for word in words]\n",
|
||||
" embeddings = torch.stack(embeddings)\n",
|
||||
" return embeddings\n",
|
||||
"\n",
|
||||
"def text_data(subset):\n",
|
||||
" texts, labels = fetch_20newsgroups(return_X_y=True, subset=subset, remove=(\"headers\", \"footers\", \"quotes\"), categories=categories)\n",
|
||||
" X = []\n",
|
||||
" y = []\n",
|
||||
" max_len = 400\n",
|
||||
" for label, text in zip(labels, texts):\n",
|
||||
" try:\n",
|
||||
" tokens = tokenizer(text.lower())\n",
|
||||
" if len(tokens) > max_len:\n",
|
||||
" tokens = tokens[:max_len]\n",
|
||||
" X.append(text_embedding(tokens))\n",
|
||||
" y.append(label)\n",
|
||||
" except:\n",
|
||||
" continue\n",
|
||||
" X = pad_sequence(X, batch_first=True)\n",
|
||||
" y = torch.tensor(y)\n",
|
||||
" return X, y\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" X_train, y_train = text_data(\"train\")\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train)\n",
|
||||
" X_test, y_test = text_data(\"test\")\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = RNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = RNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
744
sequence_classification/sequences.ipynb
Normal file
744
sequence_classification/sequences.ipynb
Normal file
|
@ -0,0 +1,744 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"import numpy as np\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"import torch\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import RNNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 2**10\n",
|
||||
"LEARNING_RATE = 1e-2\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 20\n",
|
||||
"DATASET_NAME = \"Sequences\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" # Define the sequence length and number of features\n",
|
||||
" num_sequences = 1000\n",
|
||||
" seq_length = 20\n",
|
||||
" num_features = 10\n",
|
||||
"\n",
|
||||
" # Generate the dataset\n",
|
||||
" X = np.concatenate([\n",
|
||||
" np.random.uniform(low=-1, high=1, size=(num_sequences, seq_length, num_features)),\n",
|
||||
" np.random.normal(loc=0, scale=2, size=(num_sequences, seq_length, num_features)),\n",
|
||||
" np.random.exponential(scale=0.5, size=(num_sequences, seq_length, num_features))\n",
|
||||
" ], axis=0)\n",
|
||||
" y = np.concatenate([\n",
|
||||
" np.repeat(0, num_sequences),\n",
|
||||
" np.repeat(1, num_sequences),\n",
|
||||
" np.repeat(2, num_sequences)\n",
|
||||
" ])\n",
|
||||
" X, X_test, y, y_test = train_test_split(X, y, test_size=0.2, stratify=y)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.125, stratify=y)\n",
|
||||
" X_train = torch.from_numpy(X_train).float()\n",
|
||||
" X_val = torch.from_numpy(X_val).float()\n",
|
||||
" X_test = torch.from_numpy(X_test).float()\n",
|
||||
" y_train = torch.from_numpy(y_train)\n",
|
||||
" y_val = torch.from_numpy(y_val)\n",
|
||||
" y_test = torch.from_numpy(y_test)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = RNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = RNNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" torch.cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
725
tabular_classification/breast_cancer.ipynb
Normal file
725
tabular_classification/breast_cancer.ipynb
Normal file
|
@ -0,0 +1,725 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.datasets import load_breast_cancer\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import NNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 256\n",
|
||||
"LEARNING_RATE = 1e-1\n",
|
||||
"HIDDEN_SIZES = (10,)\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 10\n",
|
||||
"DATASET_NAME = \"Breast cancer\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" X, y = load_breast_cancer(return_X_y=True)\n",
|
||||
" # Split data: 70% train, 10% validation, 20% test\n",
|
||||
" X, X_test, y, y_test = train_test_split(X, y, test_size=0.2)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.125)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
725
tabular_classification/digits.ipynb
Normal file
725
tabular_classification/digits.ipynb
Normal file
|
@ -0,0 +1,725 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.datasets import load_digits\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import NNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 256\n",
|
||||
"LEARNING_RATE = 1e-1\n",
|
||||
"HIDDEN_SIZES = (10,)\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 10\n",
|
||||
"DATASET_NAME = \"Handwritten digits\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" X, y = load_digits(return_X_y=True)\n",
|
||||
" # Split data: 70% train, 10% validation, 20% test\n",
|
||||
" X, X_test, y, y_test = train_test_split(X, y, test_size=0.2)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.125)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
725
tabular_classification/iris.ipynb
Normal file
725
tabular_classification/iris.ipynb
Normal file
|
@ -0,0 +1,725 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.datasets import load_iris\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import NNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 128\n",
|
||||
"LEARNING_RATE = 1e-1\n",
|
||||
"HIDDEN_SIZES = (10,)\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 10\n",
|
||||
"DATASET_NAME = \"Iris plants\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" X, y = load_iris(return_X_y=True)\n",
|
||||
" # Split data: 70% train, 10% validation, 20% test\n",
|
||||
" X, X_test, y, y_test = train_test_split(X, y, test_size=0.2)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.125)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
725
tabular_classification/wine.ipynb
Normal file
725
tabular_classification/wine.ipynb
Normal file
|
@ -0,0 +1,725 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.datasets import load_wine\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import NNClassifier, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 256\n",
|
||||
"LEARNING_RATE = 1e-1\n",
|
||||
"HIDDEN_SIZES = (10,)\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 10\n",
|
||||
"DATASET_NAME = \"Wine\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" X, y = load_wine(return_X_y=True)\n",
|
||||
" # Split data: 70% train, 10% validation, 20% test\n",
|
||||
" X, X_test, y, y_test = train_test_split(X, y, test_size=0.2)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.125)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_acc_before_train\", \"IDEAL_acc_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_acc_before_train\", \"He_acc_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = NNClassifier(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_acc_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_acc_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_acc_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
725
tabular_regression/california_housing.ipynb
Normal file
725
tabular_regression/california_housing.ipynb
Normal file
|
@ -0,0 +1,725 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.datasets import fetch_california_housing\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import NNRegressor, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 2048\n",
|
||||
"LEARNING_RATE = 1e-5\n",
|
||||
"HIDDEN_SIZES = (10, 10, 10)\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 10\n",
|
||||
"DATASET_NAME = \"California housing\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" X, y = fetch_california_housing(return_X_y=True)\n",
|
||||
" # Split data: 70% train, 10% validation, 20% test\n",
|
||||
" X, X_test, y, y_test = train_test_split(X, y, test_size=0.2)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.125)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_r2_before_train\", \"IDEAL_r2_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_r2_before_train\", \"He_r2_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = NNRegressor(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = NNRegressor(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_r2_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_r2_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_r2_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_r2_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
725
tabular_regression/diabetes.ipynb
Normal file
725
tabular_regression/diabetes.ipynb
Normal file
|
@ -0,0 +1,725 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Import libraries and read data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import logging\n",
|
||||
"logging.getLogger(\"lightning.pytorch.utilities.rank_zero\").setLevel(logging.WARNING)\n",
|
||||
"logging.getLogger(\"lightning.pytorch.accelerators.cuda\").setLevel(logging.WARNING)\n",
|
||||
"from time import time\n",
|
||||
"from pandas import DataFrame\n",
|
||||
"from sklearn.datasets import load_diabetes\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from torch import cuda\n",
|
||||
"from lightning import Trainer\n",
|
||||
"from lightning.pytorch.loggers import CSVLogger\n",
|
||||
"from ideal_init.lightning_model import NNRegressor, merge_logs\n",
|
||||
"\n",
|
||||
"# Define training parameters\n",
|
||||
"BATCH_SIZE = 256\n",
|
||||
"LEARNING_RATE = 8e-7\n",
|
||||
"HIDDEN_SIZES = (10, 10, 10)\n",
|
||||
"DIRECTORY = \"./\"\n",
|
||||
"EPOCHS = 30\n",
|
||||
"DATASET_NAME = \"Diabetes\"\n",
|
||||
"\n",
|
||||
"def experiment_data():\n",
|
||||
" X, y = load_diabetes(return_X_y=True)\n",
|
||||
" # Split data: 70% train, 10% validation, 20% test\n",
|
||||
" X, X_test, y, y_test = train_test_split(X, y, test_size=0.2)\n",
|
||||
" X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.125)\n",
|
||||
" return X_train, y_train, X_val, y_val, X_test, y_test"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Run experiments and measure performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# keep record of the results of all the iterations\n",
|
||||
"values=[\"IDEAL_r2_before_train\", \"IDEAL_r2_after_train\", \"IDEAL_init_time\", \"IDEAL_train_time\", \"He_r2_before_train\", \"He_r2_after_train\", \"He_init_time\", \"He_train_time\"]\n",
|
||||
"results = {value: [] for value in values}\n",
|
||||
"\n",
|
||||
"# repeat experiment multiple times\n",
|
||||
"NUM_EXPERIMENTS = 10\n",
|
||||
"for experiment in range(NUM_EXPERIMENTS):\n",
|
||||
"\n",
|
||||
" print(f\"Running experiment {experiment+1}\")\n",
|
||||
" X_train, y_train, X_val, y_val, X_test, y_test = experiment_data()\n",
|
||||
"\n",
|
||||
" # Create models\n",
|
||||
" init_model = NNRegressor(X_train, y_train, X_val, y_val, X_test, y_test, initialize=True, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
" no_init_model = NNRegressor(X_train, y_train, X_val, y_val, X_test, y_test, initialize=False, hidden_sizes=HIDDEN_SIZES, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)\n",
|
||||
"\n",
|
||||
" # Create trainers\n",
|
||||
" init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
" no_init_trainer = Trainer(default_root_dir=DIRECTORY, accelerator=\"auto\", devices=\"auto\", max_epochs=EPOCHS, logger=CSVLogger(save_dir=DIRECTORY), enable_checkpointing=False, enable_progress_bar=False, enable_model_summary=False, val_check_interval=1, log_every_n_steps=1, limit_val_batches=1, precision=64)\n",
|
||||
"\n",
|
||||
" # Test models before training\n",
|
||||
" results[\"IDEAL_r2_before_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_r2_before_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Train models and plot training comparison\n",
|
||||
" init_time = time()\n",
|
||||
" init_trainer.validate(init_model)\n",
|
||||
" init_trainer.fit(init_model)\n",
|
||||
" init_time = time() - init_time\n",
|
||||
" init_model_logs = f\"{init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" no_init_time = time()\n",
|
||||
" no_init_trainer.validate(no_init_model)\n",
|
||||
" no_init_trainer.fit(no_init_model)\n",
|
||||
" no_init_time = time() - no_init_time\n",
|
||||
" no_init_model_logs = f\"{no_init_trainer.logger.log_dir}/metrics.csv\"\n",
|
||||
"\n",
|
||||
" logs = merge_logs(init_model_logs, no_init_model_logs)\n",
|
||||
" logs[\"dataset\"] = DATASET_NAME\n",
|
||||
"\n",
|
||||
" # Test models after training\n",
|
||||
" results[\"IDEAL_r2_after_train\"].append(init_trainer.test(init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
" results[\"He_r2_after_train\"].append(no_init_trainer.test(no_init_model, verbose=False)[0][\"test_metric\"])\n",
|
||||
"\n",
|
||||
" # Init and train times\n",
|
||||
" results[\"IDEAL_init_time\"].append(init_model.init_time)\n",
|
||||
" results[\"IDEAL_train_time\"].append(init_time)\n",
|
||||
" results[\"He_init_time\"].append(no_init_model.init_time)\n",
|
||||
" results[\"He_train_time\"].append(no_init_time)\n",
|
||||
"\n",
|
||||
" # store logs\n",
|
||||
" with open(\"results.csv\", \"a\", encoding=\"utf8\") as results_file:\n",
|
||||
" logs.to_csv(results_file, header=False, index=False, quoting=1)\n",
|
||||
"\n",
|
||||
" #clear cache\n",
|
||||
" del init_model, no_init_model\n",
|
||||
" del init_trainer, no_init_trainer\n",
|
||||
" del X_train, y_train, X_val, y_val, X_test, y_test\n",
|
||||
" del logs\n",
|
||||
" cuda.empty_cache()\n",
|
||||
"\n",
|
||||
"results = DataFrame(results)\n",
|
||||
"mean = results.mean()\n",
|
||||
"error = results.sem()*1.96\n",
|
||||
"print(\"Mean values:\")\n",
|
||||
"print(mean.astype(str).apply(lambda x: x[:6]))\n",
|
||||
"print(\"Confidence error:\")\n",
|
||||
"print(error.astype(str).apply(lambda x: x[:6]))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"availableInstances": [
|
||||
{
|
||||
"_defaultOrder": 0,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.t3.medium",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 1,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.t3.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 2,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.t3.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 3,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.t3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 4,
|
||||
"_isFastLaunch": true,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 5,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 6,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 7,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 8,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 9,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 10,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 11,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 12,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.m5d.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 13,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.m5d.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 14,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.m5d.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 15,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.m5d.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 16,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.m5d.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 17,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.m5d.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 18,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.m5d.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 19,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.m5d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 20,
|
||||
"_isFastLaunch": false,
|
||||
"category": "General purpose",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": true,
|
||||
"memoryGiB": 0,
|
||||
"name": "ml.geospatial.interactive",
|
||||
"supportedImageNames": [
|
||||
"sagemaker-geospatial-v1-0"
|
||||
],
|
||||
"vcpuNum": 0
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 21,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 4,
|
||||
"name": "ml.c5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 22,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 8,
|
||||
"name": "ml.c5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 23,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.c5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 24,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.c5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 25,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 72,
|
||||
"name": "ml.c5.9xlarge",
|
||||
"vcpuNum": 36
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 26,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 96,
|
||||
"name": "ml.c5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 27,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 144,
|
||||
"name": "ml.c5.18xlarge",
|
||||
"vcpuNum": 72
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 28,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Compute optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.c5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 29,
|
||||
"_isFastLaunch": true,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g4dn.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 30,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g4dn.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 31,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g4dn.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 32,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g4dn.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 33,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g4dn.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 34,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g4dn.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 35,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 61,
|
||||
"name": "ml.p3.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 36,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 244,
|
||||
"name": "ml.p3.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 37,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 488,
|
||||
"name": "ml.p3.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 38,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.p3dn.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 39,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.r5.large",
|
||||
"vcpuNum": 2
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 40,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.r5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 41,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.r5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 42,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.r5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 43,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.r5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 44,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.r5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 45,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 512,
|
||||
"name": "ml.r5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 46,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Memory Optimized",
|
||||
"gpuNum": 0,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.r5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 47,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 16,
|
||||
"name": "ml.g5.xlarge",
|
||||
"vcpuNum": 4
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 48,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 32,
|
||||
"name": "ml.g5.2xlarge",
|
||||
"vcpuNum": 8
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 49,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 64,
|
||||
"name": "ml.g5.4xlarge",
|
||||
"vcpuNum": 16
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 50,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 128,
|
||||
"name": "ml.g5.8xlarge",
|
||||
"vcpuNum": 32
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 51,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 1,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 256,
|
||||
"name": "ml.g5.16xlarge",
|
||||
"vcpuNum": 64
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 52,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 192,
|
||||
"name": "ml.g5.12xlarge",
|
||||
"vcpuNum": 48
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 53,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 4,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 384,
|
||||
"name": "ml.g5.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 54,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 768,
|
||||
"name": "ml.g5.48xlarge",
|
||||
"vcpuNum": 192
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 55,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4d.24xlarge",
|
||||
"vcpuNum": 96
|
||||
},
|
||||
{
|
||||
"_defaultOrder": 56,
|
||||
"_isFastLaunch": false,
|
||||
"category": "Accelerated computing",
|
||||
"gpuNum": 8,
|
||||
"hideHardwareSpecs": false,
|
||||
"memoryGiB": 1152,
|
||||
"name": "ml.p4de.24xlarge",
|
||||
"vcpuNum": 96
|
||||
}
|
||||
],
|
||||
"instance_type": "ml.g4dn.xlarge",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue