🎉 First commit

commit 4f3f6de44a
22 changed files with 3123 additions and 0 deletions
.github/workflows/airflow-publish-container.yaml (vendored) — 40 lines, new file

# This workflow uses actions to build and publish airflow image on a
# container registry.
name: Airflow
run-name: Publish Airflow Docker image
on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]
jobs:
  push_to_registry:
    name: Push Airflow Docker image to container registry
    runs-on: ubuntu-latest
    permissions:
      packages: write
      contents: read
    steps:
      - name: Check out the code
        uses: actions/checkout@v4
      - name: Log in to container registry
        uses: docker/login-action@v3
        with:
          registry: docker.registry
          username: dockeruser
          password: dockerpassword
      - name: Extract docker metadata
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: dockeruser/dockerimage
      - name: Build and push Docker image
        id: push
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./docker/airflow/Dockerfile
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
.gitignore (vendored) — 170 lines, new file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Project extra folders
db_clean_data
db_raw_data
db_ml_data
logs
minio_data
plugins
LICENSE — 13 lines, new file

Copyright (c) 2024 Nicolas Rojas

Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.

THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
README.md — 54 lines, new file

# ML-deploy

**By: Nicolas Rojas**

MLOps deployment pipeline.

This repository contains the source code to train and deploy a binary classification model, using Airflow, MinIO and MLflow. By default, the repository solves the classification problem defined by the [loan approval dataset](data/dataset.csv).

## Installation and execution

Run the following commands to install and run this project:

```shell
git clone https://github.com/nirogu/ML-deploy --depth 1
cd ML-deploy
docker compose -f docker-compose.yaml --env-file config.env up airflow-init --build
docker compose -f docker-compose.yaml --env-file config.env up --build -d
```

To stop the services, run the following command:

```shell
docker compose down --volumes --remove-orphans
```

## Usage

- The Apache Airflow dashboard is available at [localhost:8080](http://localhost:8080). From there, you can manage the DAGs as described in the [project documentation](https://airflow.apache.org/docs/apache-airflow/stable/index.html).
- The MLflow dashboard is available at [localhost:8083](http://localhost:8083). From there, you can manage experiments and models as described in the [project documentation](https://mlflow.org/docs/latest/index.html).
- The MinIO dashboard is available at [localhost:8082](http://localhost:8082). From there, you can manage the storage buckets used by MLflow as described in the [project documentation](https://min.io/docs/minio/linux/index.html), although this is usually not needed.
- A JupyterLab service is available at [localhost:8085](http://localhost:8085). From there, you can run experiments on the mounted Python environment with Jupyter notebooks, just as you would locally, as shown in the [example](notebooks/classification_experiments.ipynb).
- The inference API is exposed at [localhost:8086](http://localhost:8086), and its documentation can be found in the [docs endpoint](http://localhost:8086/docs); see the example request below.
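
For instance, a prediction can be requested with a short Python script like the following (a minimal sketch that assumes the stack is running locally with the default ports; the payload fields mirror [src/front/main.py](src/front/main.py)):

```python
import requests

# Example request against the inference API (assumes default local ports)
payload = {
    "id": 0,
    "age": 35.0,
    "annual_income": 107770.0,
    "credit_score": 331.0,
    "loan_amount": 31580.0,
    "loan_duration_years": 28,
    "number_of_open_accounts": 13.0,
    "had_past_default": 0,
}
response = requests.post("http://localhost:8086/predict/", json=payload, timeout=30)
print(response.json())  # e.g. {"prediction": 0}
```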

## Rationale

It is assumed that the data has already been collected and is available as a file. If that were not the case, it would be simple to edit the [data consuming DAG](dags/load_raw_data.py) to obtain the data as needed. It is also assumed that the data science work has already been done on the dataset, so a machine learning model has already been defined, trained and tested in a local environment. The main challenge therefore consists in getting the model to production, in a way that updates to the data, the model architecture, or the training process are handled automatically, and the model remains permanently available for inference upon request. This is why a complete MLOps pipeline is designed and explained here.

This project's architecture is presented in the following diagram:

![image](architecture.svg "architecture =900x540")

The architecture works as follows:

- First, an [Airflow](https://airflow.apache.org/) service is mounted to coordinate and monitor three different workflows: data loading, data preprocessing, and model training. These workflows are represented as Airflow DAGs.
- The data loading workflow can be found in [load_raw_data](dags/load_raw_data.py). It retrieves the data from its source and stores it, unprocessed, in a SQL database created for this purpose, called _raw-data_.
- The data preprocessing workflow can be found in [preprocess_data](dags/preprocess_data.py). It retrieves the raw data from the _raw-data_ database, cleans it, transforms it, splits it into train/validation/test partitions, and stores the result in a SQL database created for this purpose, called _clean-data_.
- The model training workflow can be found in [train_model](dags/train_model.py). It retrieves the clean data from the _clean-data_ database, defines a machine learning model (in this particular case, a [scikit-learn](https://scikit-learn.org/stable/index.html) model), trains it, tests it, and stores the artifacts and training metadata in an MLflow experiment.
- An [MLflow](https://mlflow.org/) service is mounted to coordinate the experimentation process, handling model versioning and serving the required pretrained models. It stores model artifacts in a [MinIO](https://min.io/) bucket and experiment metadata in a SQL database created for this purpose. MLflow also serves the model tagged as production-ready, which is consumed by the application API.
- A REST API is mounted with [FastAPI](https://fastapi.tiangolo.com) in the [backend script](src/back/main.py). It loads the _production_-tagged model from MLflow and uses it to make predictions on the structured data received in POST requests. Besides returning the prediction, it stores both the input and the prediction in a separate table of the _raw-data_ database. An example of how to consume the API is presented in the [frontend script](src/front/main.py).
- Every service is defined by its respective [docker image](./docker/), and everything is mounted with [Docker Compose](./docker-compose.yaml).

Thus, a typical workflow consists in running the Airflow DAGs to load and preprocess the data, training the model with its respective DAG, reviewing its performance and serving it with MLflow, and consuming it through POST requests to the API; the DAGs can be triggered from the dashboard or through Airflow's REST API, as sketched below.
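
For reference, the sketch below triggers the three DAGs through Airflow's stable REST API. It assumes the default development credentials created by `airflow-init` (`airflow`/`airflow`, see `docker-compose.yaml`); note that DAGs are paused at creation in this setup, so unpause them in the dashboard (or via the API) before triggering:

```python
import requests

AIRFLOW_URL = "http://localhost:8080/api/v1"
AUTH = ("airflow", "airflow")  # default credentials from docker-compose.yaml

# Trigger the three DAGs in their logical order; each call creates a new DAG run
for dag_id in ["load_data", "preprocess_data", "train_model"]:
    resp = requests.post(
        f"{AIRFLOW_URL}/dags/{dag_id}/dagRuns",
        auth=AUTH,
        json={"conf": {}},
        timeout=30,
    )
    print(dag_id, resp.status_code)
```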

- *Extra 1:* An additional (and optional) JupyterLab service is mounted to help data scientists run experiments as they would in a local environment. An example is presented in the [included notebook](notebooks/classification_experiments.ipynb).
- *Extra 2:* An alternative way of deploying the services is to push each container to a container registry (e.g. AWS ECR, DockerHub, etc.) and pull the images when building the complete project. This can be done with [GitHub Actions](https://docs.github.com/actions) workflows that build and push new versions of the docker images whenever the repository changes. An example of such a workflow is presented in [the GitHub Actions folder](.github/workflows/airflow-publish-container.yaml); remember to store any sensitive information, such as the registry credentials, as [GitHub secrets](https://docs.github.com/actions/security-for-github-actions/security-guides/using-secrets-in-github-actions) instead of writing it in the code.
architecture.svg — 4 lines, new file (image, 101 KiB)
File diff suppressed because one or more lines are too long
config.env — 28 lines, new file

# MYSQL configuration
MYSQL_DATABASE=mlflow_db
MYSQL_RAW_DATABASE=raw_data
MYSQL_CLEAN_DATABASE=clean_data

MYSQL_USER=sqluser
MYSQL_PASSWORD=supersecretaccess2024
MYSQL_ROOT_PASSWORD=supersecretaccess2024

MYSQL_PORT=3306
MYSQL_RAW_PORT=8088
MYSQL_CLEAN_PORT=8089

# MLflow configuration
MLFLOW_PORT=8083
MLFLOW_BUCKET_NAME=mlflow_bucket

# MinIO access keys - these are needed by MLflow
MINIO_ACCESS_KEY=access2024minio
MINIO_SECRET_ACCESS_KEY=supersecretaccess2024

# MinIO configuration
MINIO_ROOT_USER=minio_user
MINIO_ROOT_PASSWORD=minio_pwd
MINIO_PORT=8081
MINIO_CONSOLE_PORT=8082

AIRFLOW_UID=1000
dags/load_raw_data.py — 93 lines, new file

"""
Airflow DAG to load raw data from a spreadsheet into the database.

Author
------
Nicolas Rojas
"""

# imports
import os
from datetime import datetime
import pandas as pd
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook


def check_table_exists():
    """Check whether raw_clients table exists in raw_data database. If not, create it."""
    # count number of rows in raw data table
    query = 'SELECT COUNT(*) FROM information_schema.tables WHERE table_name="raw_clients"'
    mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
    connection = mysql_hook.get_conn()
    cursor = connection.cursor()
    cursor.execute(query)
    results = cursor.fetchall()
    # check whether table exists
    if results[0][0] == 0:
        # create table
        print("----- table does not exist, creating it")
        create_sql = "CREATE TABLE `raw_clients`\
            (`id` BIGINT,\
            `age` SMALLINT,\
            `anual_income` BIGINT,\
            `credit_score` SMALLINT,\
            `loan_amount` BIGINT,\
            `loan_duration_years` TINYINT,\
            `number_of_open_accounts` SMALLINT,\
            `had_past_default` TINYINT,\
            `loan_approval` TINYINT\
            )"
        mysql_hook.run(create_sql)
    else:
        # no need to create table
        print("----- table already exists")

    return "Table checked"


def store_data():
    """Store raw data in respective table and database."""
    # Path to the raw training data
    _data_root = "./data"
    _data_filename = "dataset.csv"
    _data_filepath = os.path.join(_data_root, _data_filename)

    # read data and obtain variable names
    dataframe = pd.read_csv(_data_filepath)
    dataframe.rename(columns={"Unnamed: 0": "ID"}, inplace=True)
    sql_column_names = [name.lower() for name in dataframe.columns]

    # insert every dataframe row into sql table
    mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
    conn = mysql_hook.get_conn()
    cur = conn.cursor()
    # VALUES in query are %s repeated as many columns are in dataframe
    sql_column_names = ", ".join(["`" + name + "`" for name in sql_column_names])
    query = f"INSERT INTO `raw_clients` ({sql_column_names}) \
        VALUES ({', '.join(['%s' for _ in range(dataframe.shape[1])])})"
    dataframe = list(dataframe.itertuples(index=False, name=None))
    cur.executemany(query, dataframe)
    conn.commit()

    return "Data stored"


with DAG(
    "load_data",
    description="Read data from source and store it in raw_data database",
    start_date=datetime(2024, 9, 18, 0, 0),
    schedule_interval="@once",
) as dag:

    check_table_task = PythonOperator(
        task_id="check_table_exists", python_callable=check_table_exists
    )
    store_data_task = PythonOperator(
        task_id="store_data", python_callable=store_data
    )

    check_table_task >> store_data_task
dags/preprocess_data.py — 122 lines, new file

"""
Airflow DAG to load raw data, process it, split it, and store it in the database.

Author
------
Nicolas Rojas
"""

# imports
from datetime import datetime
import pandas as pd
from sklearn.model_selection import train_test_split
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook


def check_table_exists(table_name: str):
    """Check whether table exists in clean_data database. If not, create it.

    Parameters
    ----------
    table_name : str
        Name of table to check.
    """
    # count number of rows in data table
    query = f'SELECT COUNT(*) FROM information_schema.tables WHERE table_name="{table_name}"'
    mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
    connection = mysql_hook.get_conn()
    cursor = connection.cursor()
    cursor.execute(query)
    results = cursor.fetchall()
    # check whether table exists
    if results[0][0] == 0:
        # create table
        print("----- table does not exist, creating it")
        create_sql = f"CREATE TABLE `{table_name}`\
            (`age` SMALLINT,\
            `anual_income` BIGINT,\
            `credit_score` SMALLINT,\
            `loan_amount` BIGINT,\
            `loan_duration_years` TINYINT,\
            `number_of_open_accounts` SMALLINT,\
            `had_past_default` TINYINT,\
            `loan_approval` TINYINT\
            )"
        mysql_hook.run(create_sql)
    else:
        # no need to create table
        print("----- table already exists")

    return "Table checked"


def store_data(dataframe: pd.DataFrame, table_name: str):
    """Store dataframe data in given table, in clean data database.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Dataframe to store in database.
    table_name : str
        Name of the table to store the data.
    """
    check_table_exists(table_name)
    # insert every dataframe row into sql table
    mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
    sql_column_names = ", ".join(["`" + name + "`" for name in dataframe.columns])
    conn = mysql_hook.get_conn()
    cur = conn.cursor()
    # VALUES in query are %s repeated as many columns are in dataframe
    query = f"INSERT INTO `{table_name}` ({sql_column_names}) \
        VALUES ({', '.join(['%s' for _ in range(dataframe.shape[1])])})"
    dataframe = list(dataframe.itertuples(index=False, name=None))
    cur.executemany(query, dataframe)
    conn.commit()

    return "Data stored"


def preprocess_data():
    """Preprocess raw data and store it in clean_data database."""
    # retrieve raw data
    mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
    conn = mysql_hook.get_conn()
    query = "SELECT * FROM `raw_clients`"
    dataframe = pd.read_sql(query, con=conn)

    # drop useless column
    dataframe.drop(columns=["id"], inplace=True)
    # fill empty fields
    dataframe.fillna(0, inplace=True)

    # split data: 70% train, 10% val, 20% test
    df_train, df_test = train_test_split(
        dataframe, test_size=0.2, shuffle=True, random_state=1337
    )
    df_train, df_val = train_test_split(
        df_train, test_size=0.125, shuffle=True, random_state=1337
    )

    # store data partitions in database
    store_data(df_train, "clean_clients_train")
    store_data(df_val, "clean_clients_val")
    store_data(df_test, "clean_clients_test")

    return "Data preprocessed"


with DAG(
    "preprocess_data",
    description="Fetch raw data, preprocess it and save it in mysql database",
    start_date=datetime(2024, 9, 18, 0, 2),
    schedule_interval="@once",
) as dag:

    preprocess_task = PythonOperator(
        task_id="preprocess_data", python_callable=preprocess_data
    )
    preprocess_task
dags/train_model.py — 117 lines, new file

"""
Airflow DAG to load clean data and train model with MLflow.

Author
------
Nicolas Rojas
"""

# imports
import os
from datetime import datetime
import pandas as pd
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import f1_score
import mlflow
from mlflow.models import infer_signature


def get_data(table_name: str, target_variable: str):
    """Get data from clean_data database.

    Parameters
    ----------
    table_name : str
        Name of table to get data from.
    target_variable : str
        Name of the target variable in the classification problem.
    """
    # connect to clean database
    mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
    sql_connection = mysql_hook.get_conn()
    # get all available data
    query = f"SELECT * FROM `{table_name}`"
    dataframe = pd.read_sql(query, con=sql_connection)
    # return input features and target variable
    return dataframe.drop(columns=[target_variable]), dataframe[target_variable]


def train_model():
    """Train model with clean data and save artifacts with MLflow."""
    # get data partitions
    target_variable = "loan_approval"
    X_train, y_train = get_data("clean_clients_train", target_variable)
    X_val, y_val = get_data("clean_clients_val", target_variable)
    X_test, y_test = get_data("clean_clients_test", target_variable)

    # define preprocessor and classifier
    categorical_feature = "had_past_default"
    numerical_features = [
        category for category in X_train.columns if category != categorical_feature
    ]
    preprocessor = ColumnTransformer(
        transformers=[
            ("numerical", StandardScaler(), numerical_features),
            ("categorical", "passthrough", [categorical_feature]),
        ]
    )
    hyperparameters = {
        "classifier__n_estimators": 168,
        "classifier__max_depth": 6,
        "classifier__learning_rate": 0.001,
    }
    classifier = GradientBoostingClassifier()
    pipeline = Pipeline(
        steps=[("preprocessor", preprocessor), ("classifier", classifier)]
    )
    # apply the hyperparameters (keys use the pipeline-level `classifier__` prefix)
    pipeline.set_params(**hyperparameters)

    # connect to mlflow
    os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://minio:8081"
    os.environ["AWS_ACCESS_KEY_ID"] = "access2024minio"
    os.environ["AWS_SECRET_ACCESS_KEY"] = "supersecretaccess2024"
    mlflow.set_tracking_uri("http://mlflow:8083")
    mlflow.set_experiment("mlflow_tracking_model")
    mlflow.sklearn.autolog(
        log_model_signatures=True,
        log_input_examples=True,
        registered_model_name="clients_model",
    )

    # open mlflow run
    with mlflow.start_run(run_name="autolog_pipe_model") as run:
        # train model
        pipeline.fit(X_train, y_train)
        y_pred_val = pipeline.predict(X_val)
        y_pred_test = pipeline.predict(X_test)
        # log metrics
        mlflow.log_metric("f1_score_val", f1_score(y_val, y_pred_val))
        mlflow.log_metric("f1_score_test", f1_score(y_test, y_pred_test))
        # log model
        signature = infer_signature(X_test, y_pred_test)
        mlflow.sklearn.log_model(
            sk_model=pipeline,
            artifact_path="clients_model",
            signature=signature,
            registered_model_name="clients_model",
        )


with DAG(
    "train_model",
    description="Fetch clean data from database, train model, save artifacts with MLflow and register model",
    start_date=datetime(2024, 9, 18, 0, 5),
    schedule_interval="@once",
) as dag:

    train_model_task = PythonOperator(
        task_id="train_model", python_callable=train_model
    )
    train_model_task
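Note that the backend in `src/back/main.py` loads `models:/clients_model/production`, while the DAG above only registers new model versions; moving a reviewed version to the production stage is done from the MLflow dashboard. For reference, the following is a minimal sketch of the same promotion done programmatically with the MLflow client, assuming the tracking server is reachable at the URI used by the DAG and that version 1 (a hypothetical value) is the one you validated:

```python
import mlflow
from mlflow.tracking import MlflowClient

# Point the client at the tracking server used by the DAGs
mlflow.set_tracking_uri("http://mlflow:8083")
client = MlflowClient()

# Promote a reviewed version of the registered model to the "Production" stage,
# so that src/back/main.py can resolve models:/clients_model/production
client.transition_model_version_stage(
    name="clients_model",
    version=1,  # hypothetical version number; use the one you validated
    stage="Production",
)
```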
data/dataset.csv — 1001 lines, new file
File diff suppressed because it is too large
docker-compose.yaml — 384 lines, new file

---
x-airflow-common:
  &airflow-common
  # In order to add custom dependencies or upgrade provider packages you can use your extended image.
  # Comment the image line, place your Dockerfile in the directory where you placed the docker-compose.yaml
  # and uncomment the "build" line below, then run `docker-compose build` to build the images.
  #image: ${AIRFLOW_IMAGE_NAME:-apache/airflow:2.6.0}
  build: ./docker/airflow
  # build: .
  environment:
    &airflow-common-env
    AIRFLOW__CORE__EXECUTOR: CeleryExecutor
    AIRFLOW__DATABASE__SQL_ALCHEMY_CONN: postgresql+psycopg2://airflow:airflow@postgres/airflow
    AIRFLOW__CELERY__RESULT_BACKEND: db+postgresql://airflow:airflow@postgres/airflow
    AIRFLOW__CELERY__BROKER_URL: redis://:@redis:6379/0
    AIRFLOW__CORE__FERNET_KEY: ''
    AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION: 'true'
    AIRFLOW__CORE__LOAD_EXAMPLES: 'False'
    AIRFLOW__API__AUTH_BACKENDS: 'airflow.api.auth.backend.basic_auth,airflow.api.auth.backend.session'
    AIRFLOW_CONN_RAW_DATA: 'mysql://${MYSQL_USER}:${MYSQL_PASSWORD}@db_raw:${MYSQL_RAW_PORT}/${MYSQL_RAW_DATABASE}'
    AIRFLOW_CONN_CLEAN_DATA: 'mysql://${MYSQL_USER}:${MYSQL_PASSWORD}@db_clean:${MYSQL_CLEAN_PORT}/${MYSQL_CLEAN_DATABASE}'
    # yamllint disable rule:line-length
    # Use simple http server on scheduler for health checks
    # See https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/logging-monitoring/check-health.html#scheduler-health-check-server
    # yamllint enable rule:line-length
    AIRFLOW__SCHEDULER__ENABLE_HEALTH_CHECK: 'true'
    # WARNING: Use _PIP_ADDITIONAL_REQUIREMENTS option ONLY for quick checks
    # for other purposes (development, test and especially production usage) build/extend Airflow image.
    _PIP_ADDITIONAL_REQUIREMENTS: ${_PIP_ADDITIONAL_REQUIREMENTS:-}
  volumes:
    - ${AIRFLOW_PROJ_DIR:-.}/dags:/opt/airflow/dags
    - ${AIRFLOW_PROJ_DIR:-.}/logs:/opt/airflow/logs
    - ${AIRFLOW_PROJ_DIR:-.}/plugins:/opt/airflow/plugins
  user: "${AIRFLOW_UID:-50000}:0"
  depends_on:
    &airflow-common-depends-on
    redis:
      condition: service_healthy
    postgres:
      condition: service_healthy

services:
  db:
    restart: always
    image: mysql/mysql-server:5.7.28
    container_name: mlflow_db
    ports:
      - "${MYSQL_PORT}:3306"
    networks:
      - backend
    environment:
      - MYSQL_DATABASE=${MYSQL_DATABASE}
      - MYSQL_USER=${MYSQL_USER}
      - MYSQL_PASSWORD=${MYSQL_PASSWORD}
      - MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
    volumes:
      - ./db_ml_data:/var/lib/mysql

  db_raw:
    restart: always
    image: mysql/mysql-server:5.7.28
    container_name: raw_data_db
    ports:
      - "${MYSQL_RAW_PORT}:3306"
    networks:
      - backend
    environment:
      - MYSQL_DATABASE=${MYSQL_RAW_DATABASE}
      - MYSQL_USER=${MYSQL_USER}
      - MYSQL_PASSWORD=${MYSQL_PASSWORD}
      - MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
    volumes:
      - ./db_raw_data:/var/lib/mysql

  db_clean:
    restart: always
    image: mysql/mysql-server:5.7.28
    container_name: clean_data_db
    ports:
      - "${MYSQL_CLEAN_PORT}:3306"
    networks:
      - backend
    environment:
      - MYSQL_DATABASE=${MYSQL_CLEAN_DATABASE}
      - MYSQL_USER=${MYSQL_USER}
      - MYSQL_PASSWORD=${MYSQL_PASSWORD}
      - MYSQL_ROOT_PASSWORD=${MYSQL_ROOT_PASSWORD}
    volumes:
      - ./db_clean_data:/var/lib/mysql

  fast_api:
    build: ./docker/fast_api
    ports:
      - 8086:8086
    container_name: fast_api
    depends_on:
      - db_raw
      - db_clean
    networks:
      - backend
    volumes:
      - ./src/back:/opt/code/
    command: "uvicorn opt.code.main:app --host 0.0.0.0 --port 8086 --reload"

  minio:
    container_name: minio
    command: server /data --console-address ":${MINIO_CONSOLE_PORT}" --address ':${MINIO_PORT}'
    environment:
      - MINIO_ROOT_USER=${MINIO_ROOT_USER}
      - MINIO_ROOT_PASSWORD=${MINIO_ROOT_PASSWORD}
    image: quay.io/minio/minio:latest
    ports:
      - "${MINIO_PORT}:8081"
      - "${MINIO_CONSOLE_PORT}:8082"
    networks:
      - backend
    volumes:
      - ./minio_data:/data
    restart: unless-stopped

  mlflow:
    restart: always
    build: ./docker/mlflow
    image: mlflow
    container_name: mlflow
    depends_on:
      - db
      - minio
    ports:
      - "${MLFLOW_PORT}:8083"
    networks:
      - backend
    environment:
      - AWS_ACCESS_KEY_ID=${MINIO_ACCESS_KEY}
      - AWS_SECRET_ACCESS_KEY=${MINIO_SECRET_ACCESS_KEY}
      - MLFLOW_S3_ENDPOINT_URL=http://minio:${MINIO_PORT}
      - MLFLOW_TRACKING_URI=http://mlflow:${MLFLOW_PORT}
      - BACKEND_STORE_URI=mysql+pymysql://${MYSQL_USER}:${MYSQL_PASSWORD}@db:${MYSQL_PORT}/${MYSQL_DATABASE}
      - DEFAULT_ARTIFACT_ROOT=s3://minio:${MINIO_PORT}
      - MLFLOW_S3_IGNORE_TLS=true
    command: >
      mlflow server
      --backend-store-uri mysql+pymysql://${MYSQL_USER}:${MYSQL_PASSWORD}@db:${MYSQL_PORT}/${MYSQL_DATABASE}
      --host 0.0.0.0
      --port ${MLFLOW_PORT}
      --serve-artifacts
      --artifacts-destination s3://${MLFLOW_BUCKET_NAME}/
      --default-artifact-root s3://${MLFLOW_BUCKET_NAME}/

  work_jupyter:
    build: ./docker/jupyter
    volumes:
      - ./:/home/jovyan/work
    ports:
      - 8085:8888
    container_name: jupyter
    networks:
      - backend
    command: "jupyter lab --ip=0.0.0.0 --allow-root --NotebookApp.token='' --NotebookApp.password=''"

  postgres:
    image: postgres:13
    environment:
      POSTGRES_USER: airflow
      POSTGRES_PASSWORD: airflow
      POSTGRES_DB: airflow
    volumes:
      - postgres-db-volume:/var/lib/postgresql/data
    healthcheck:
      test: ["CMD", "pg_isready", "-U", "airflow"]
      interval: 10s
      retries: 5
      start_period: 5s
    restart: always
    networks:
      - backend

  redis:
    image: redis:7.2-bookworm
    expose:
      - 6379
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 30s
      retries: 50
      start_period: 30s
    restart: always
    networks:
      - backend

  airflow-webserver:
    <<: *airflow-common
    command: webserver
    ports:
      - "8080:8080"
    healthcheck:
      test: ["CMD", "curl", "--fail", "http://localhost:8080/health"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 30s
    restart: always
    depends_on:
      <<: *airflow-common-depends-on
      airflow-init:
        condition: service_completed_successfully
    networks:
      - backend

  airflow-scheduler:
    <<: *airflow-common
    command: scheduler
    healthcheck:
      test: ["CMD", "curl", "--fail", "http://localhost:8974/health"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 30s
    restart: always
    depends_on:
      <<: *airflow-common-depends-on
      airflow-init:
        condition: service_completed_successfully
    networks:
      - backend

  airflow-worker:
    <<: *airflow-common
    command: celery worker
    healthcheck:
      test:
        - "CMD-SHELL"
        - 'celery --app airflow.providers.celery.executors.celery_executor.app inspect ping -d "celery@$${HOSTNAME}" || celery --app airflow.executors.celery_executor.app inspect ping -d "celery@$${HOSTNAME}"'
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 30s
    environment:
      <<: *airflow-common-env
      # Required to handle warm shutdown of the celery workers properly
      # See https://airflow.apache.org/docs/docker-stack/entrypoint.html#signal-propagation
      DUMB_INIT_SETSID: "0"
    restart: always
    depends_on:
      <<: *airflow-common-depends-on
      airflow-init:
        condition: service_completed_successfully
    networks:
      - backend

  airflow-triggerer:
    <<: *airflow-common
    command: triggerer
    healthcheck:
      test: ["CMD-SHELL", 'airflow jobs check --job-type TriggererJob --hostname "$${HOSTNAME}"']
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 30s
    restart: always
    depends_on:
      <<: *airflow-common-depends-on
      airflow-init:
        condition: service_completed_successfully
    networks:
      - backend

  airflow-init:
    <<: *airflow-common
    entrypoint: /bin/bash
    # yamllint disable rule:line-length
    command:
      - -c
      - |
        if [[ -z "${AIRFLOW_UID}" ]]; then
          echo
          echo -e "\033[1;33mWARNING!!!: AIRFLOW_UID not set!\e[0m"
          echo "If you are on Linux, you SHOULD follow the instructions below to set "
          echo "AIRFLOW_UID environment variable, otherwise files will be owned by root."
          echo "For other operating systems you can get rid of the warning with manually created .env file:"
          echo "    See: https://airflow.apache.org/docs/apache-airflow/stable/howto/docker-compose/index.html#setting-the-right-airflow-user"
          echo
        fi
        one_meg=1048576
        mem_available=$$(($$(getconf _PHYS_PAGES) * $$(getconf PAGE_SIZE) / one_meg))
        cpus_available=$$(grep -cE 'cpu[0-9]+' /proc/stat)
        disk_available=$$(df / | tail -1 | awk '{print $$4}')
        warning_resources="false"
        if (( mem_available < 4000 )) ; then
          echo
          echo -e "\033[1;33mWARNING!!!: Not enough memory available for Docker.\e[0m"
          echo "At least 4GB of memory required. You have $$(numfmt --to iec $$((mem_available * one_meg)))"
          echo
          warning_resources="true"
        fi
        if (( cpus_available < 2 )); then
          echo
          echo -e "\033[1;33mWARNING!!!: Not enough CPUS available for Docker.\e[0m"
          echo "At least 2 CPUs recommended. You have $${cpus_available}"
          echo
          warning_resources="true"
        fi
        if (( disk_available < one_meg * 10 )); then
          echo
          echo -e "\033[1;33mWARNING!!!: Not enough Disk space available for Docker.\e[0m"
          echo "At least 10 GBs recommended. You have $$(numfmt --to iec $$((disk_available * 1024 )))"
          echo
          warning_resources="true"
        fi
        if [[ $${warning_resources} == "true" ]]; then
          echo
          echo -e "\033[1;33mWARNING!!!: You have not enough resources to run Airflow (see above)!\e[0m"
          echo "Please follow the instructions to increase amount of resources available:"
          echo "   https://airflow.apache.org/docs/apache-airflow/stable/howto/docker-compose/index.html#before-you-begin"
          echo
        fi
        mkdir -p /sources/logs /sources/dags /sources/plugins
        chown -R "${AIRFLOW_UID}:0" /sources/{logs,dags,plugins}
        exec /entrypoint airflow version
    # yamllint enable rule:line-length
    environment:
      <<: *airflow-common-env
      _AIRFLOW_DB_UPGRADE: 'true'
      _AIRFLOW_WWW_USER_CREATE: 'true'
      _AIRFLOW_WWW_USER_USERNAME: ${_AIRFLOW_WWW_USER_USERNAME:-airflow}
      _AIRFLOW_WWW_USER_PASSWORD: ${_AIRFLOW_WWW_USER_PASSWORD:-airflow}
      _PIP_ADDITIONAL_REQUIREMENTS: ''
    user: "0:0"
    volumes:
      - ${AIRFLOW_PROJ_DIR:-.}:/sources
    networks:
      - backend

  airflow-cli:
    <<: *airflow-common
    profiles:
      - debug
    environment:
      <<: *airflow-common-env
      CONNECTION_CHECK_MAX_COUNT: "0"
    # Workaround for entrypoint issue. See: https://github.com/apache/airflow/issues/16252
    command:
      - bash
      - -c
      - airflow
    networks:
      - backend

  # You can enable flower by adding "--profile flower" option e.g. docker-compose --profile flower up
  # or by explicitly targeting it on the command line e.g. docker-compose up flower.
  # See: https://docs.docker.com/compose/profiles/
  flower:
    <<: *airflow-common
    command: celery flower
    profiles:
      - flower
    ports:
      - "5555:5555"
    healthcheck:
      test: ["CMD", "curl", "--fail", "http://localhost:5555/"]
      interval: 30s
      timeout: 10s
      retries: 5
      start_period: 30s
    restart: always
    depends_on:
      <<: *airflow-common-depends-on
      airflow-init:
        condition: service_completed_successfully
    networks:
      - backend

volumes:
  postgres-db-volume:
  db_ml_data:
  db_raw_data:
  db_clean_data:
  minio_data:

networks:
  backend:
    driver: bridge
docker/airflow/Dockerfile — 5 lines, new file

FROM apache/airflow:2.10.1-python3.9
# Install python packages
COPY ./requirements.txt /code/requirements.txt
#RUN pip install --user --upgrade pip
RUN pip install -r /code/requirements.txt
docker/airflow/requirements.txt — 8 lines, new file

mlflow==2.3.0
scikit-learn==1.2.2
scipy==1.11.4
pandas==1.5.3
boto3==1.26.121
requests==2.28.2
joblib==1.3.0
apache-airflow-providers-mysql==5.7.0
docker/fast_api/Dockerfile — 4 lines, new file

FROM python:3.9
# Install python packages
COPY ./requirements.txt /code/requirements.txt
RUN pip install -r /code/requirements.txt
docker/fast_api/requirements.txt — 11 lines, new file

mlflow==2.3.0
fastapi>=0.68.0,<0.69.0
uvicorn>=0.15.0,<0.16.0
scikit-learn==1.2.2
numpy==1.24.3
pandas==1.5.3
scipy==1.11.4
psutil==5.9.5
typing_extensions==4.11.0
boto3==1.26.121
mysql-connector-python==8.4.0
docker/jupyter/Dockerfile — 4 lines, new file

FROM python:3.9
# Install python packages
COPY ./requirements.txt /code/requirements.txt
RUN pip install -r /code/requirements.txt
docker/jupyter/requirements.txt — 9 lines, new file

mlflow==2.3.0
jupyter==1.0.0
jupyterlab==3.6.1
pandas==1.5.3
scikit-learn==1.4.2
scipy==1.11.4
boto3==1.26.121
requests==2.28.2
joblib==1.3.0
docker/mlflow/Dockerfile — 4 lines, new file

FROM python:3.10-slim-buster
# Install python packages
COPY ./requirements.txt /code/requirements.txt
RUN pip install -r /code/requirements.txt
docker/mlflow/requirements.txt — 3 lines, new file

mlflow==2.3.0
boto3==1.26.121
pymysql==1.1.1
notebooks/classification_experiments.ipynb — 832 lines, new file
File diff suppressed because one or more lines are too long
src/back/main.py — 180 lines, new file

"""
Backend module for FastAPI application.

Author
------
Nicolas Rojas
"""

# imports
import os
from pydantic import BaseModel
import pandas as pd
from fastapi import FastAPI, HTTPException
import mlflow
import mysql.connector


def check_table_exists(table_name: str):
    """Check whether table exists in raw_data database. If not, create it.

    Parameters
    ----------
    table_name : str
        Name of table to check.
    """
    # count number of rows in predictions data table
    query = f'SELECT COUNT(*) FROM information_schema.tables WHERE table_name="{table_name}"'
    connection = mysql.connector.connect(
        host="db_raw",
        port=8088,
        user="sqluser",
        password="supersecretaccess2024",
        database="raw_data",
    )
    cursor = connection.cursor()
    cursor.execute(query)
    results = cursor.fetchall()
    # check whether table exists
    if results[0][0] == 0:
        # create table
        print("----- table does not exist, creating it")
        create_sql = f"CREATE TABLE `{table_name}`\
            (`id` BIGINT,\
            `age` SMALLINT,\
            `anual_income` BIGINT,\
            `credit_score` SMALLINT,\
            `loan_amount` BIGINT,\
            `loan_duration_years` TINYINT,\
            `number_of_open_accounts` SMALLINT,\
            `had_past_default` TINYINT,\
            `loan_approval` TINYINT\
            )"
        cursor.execute(create_sql)
    else:
        # no need to create table
        print("----- table already exists")

    cursor.close()
    connection.close()


def store_data(dataframe: pd.DataFrame, table_name: str):
    """Store dataframe data in given table, in raw data database.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Dataframe to store in database.
    table_name : str
        Name of the table to store the data.
    """
    check_table_exists(table_name)
    # insert every dataframe row into sql table
    connection = mysql.connector.connect(
        host="db_raw",
        port=8088,
        user="sqluser",
        password="supersecretaccess2024",
        database="raw_data",
    )
    sql_column_names = ", ".join(["`" + name + "`" for name in dataframe.columns])
    cur = connection.cursor()
    # VALUES in query are %s repeated as many columns are in dataframe
    query = f"INSERT INTO `{table_name}` ({sql_column_names}) \
        VALUES ({', '.join(['%s' for _ in range(dataframe.shape[1])])})"
    dataframe = list(dataframe.itertuples(index=False, name=None))
    cur.executemany(query, dataframe)
    connection.commit()

    cur.close()
    connection.close()


# connect to mlflow
os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://minio:8081"
os.environ["AWS_ACCESS_KEY_ID"] = "access2024minio"
os.environ["AWS_SECRET_ACCESS_KEY"] = "supersecretaccess2024"
mlflow.set_tracking_uri("http://mlflow:8083")
mlflow.set_experiment("mlflow_tracking_model")

# load model
MODEL_NAME = "clients_model"
MODEL_PRODUCTION_URI = f"models:/{MODEL_NAME}/production"
loaded_model = mlflow.pyfunc.load_model(model_uri=MODEL_PRODUCTION_URI)

# create FastAPI app
app = FastAPI()


class ModelInput(BaseModel):
    """Input model for FastAPI endpoint."""

    id: int
    age: float
    annual_income: float
    credit_score: float
    loan_amount: float
    loan_duration_years: int
    number_of_open_accounts: float
    had_past_default: int


@app.post("/predict/")
def predict(item: ModelInput):
    """Predict with loaded model over client data.

    Parameters
    ----------
    item : ModelInput
        Input data for model, received as JSON.

    Returns
    -------
    dict
        Dictionary with prediction.

    Raises
    ------
    HTTPException
        When receiving bad request.
    """
    try:
        global loaded_model
        # get data from model_input
        received_data = item.model_dump()
        # preprocess data
        preprocessed_data = received_data.copy()
        preprocessed_data.pop("id")
        # transform data into DataFrame
        preprocessed_data = pd.DataFrame(
            {key: [value] for key, value in preprocessed_data.items()}
        )
        # fill nan
        preprocessed_data.fillna(0, inplace=True)
        # predict with model
        prediction = loaded_model.predict(preprocessed_data)
        prediction = int(prediction[0])
        # store data in raw_data database
        received_data = pd.DataFrame(
            {key: [value] for key, value in received_data.items()}
        )
        received_data["loan_approval"] = prediction
        store_data(received_data, "predictions_data")
        # return prediction as JSON
        return {"prediction": prediction}

    except Exception as error:
        raise HTTPException(
            status_code=400, detail=f"Bad Request\n{error}"
        ) from error
src/front/main.py — 37 lines, new file

"""
Example file of interaction with the FastAPI server.

Author
------
Nicolas Rojas
"""

# This script shows how to interact with the API serving the model.
# Given the context of this problem, a simple program sending a request makes
# more sense than a graphical user interface, although building one with
# libraries like gradio or streamlit would be trivial given the current script

import requests

ENDPOINT_URL = "http://localhost:8086/predict/"

try:
    data = {
        "id": 0,
        "age": 35.0,
        "annual_income": 107770.0,
        "credit_score": 331.0,
        "loan_amount": 31580.0,
        "loan_duration_years": 28,
        "number_of_open_accounts": 13.0,
        "had_past_default": 0,
    }

    response = requests.post(ENDPOINT_URL, json=data, timeout=30)
    if response.status_code == 200:
        print(response.json())
    else:
        print(f"Failed with status code: {response.status_code}")

except requests.exceptions.RequestException as error:
    print(f"Failed to connect to FastAPI server:\n{error}")