🎉 First commit
commit 4f3f6de44a
22 changed files with 3123 additions and 0 deletions
dags/load_raw_data.py (new file, 93 lines)
@@ -0,0 +1,93 @@
"""
|
||||
Airflow DAG to load raw data from speadsheet into database.
|
||||
|
||||
Author
|
||||
------
|
||||
Nicolas Rojas
|
||||
"""
|
||||
|
||||
# imports
|
||||
import os
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
from airflow import DAG
|
||||
from airflow.operators.python import PythonOperator
|
||||
from airflow.providers.mysql.hooks.mysql import MySqlHook
|
||||
|
||||
|
||||
def check_table_exists():
|
||||
"""Check whether raw_clients table exists in raw_data database. If not, create it."""
|
||||
# count number of rows in raw data table
|
||||
query = 'SELECT COUNT(*) FROM information_schema.tables WHERE table_name="raw_clients"'
|
||||
mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
|
||||
connection = mysql_hook.get_conn()
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(query)
|
||||
results = cursor.fetchall()
|
||||
# check whether table exists
|
||||
if results[0][0] == 0:
|
||||
# create table
|
||||
print("----- table does not exists, creating it")
|
||||
create_sql = "CREATE TABLE `raw_clients`\
|
||||
(`id` BIGINT,\
|
||||
`age` SMALLINT,\
|
||||
`anual_income` BIGINT,\
|
||||
`credit_score` SMALLINT,\
|
||||
`loan_amount` BIGINT,\
|
||||
`loan_duration_years` TINYINT,\
|
||||
`number_of_open_accounts` SMALLINT,\
|
||||
`had_past_default` TINYINT,\
|
||||
`loan_approval` TINYINT\
|
||||
)"
|
||||
mysql_hook.run(create_sql)
|
||||
else:
|
||||
# no need to create table
|
||||
print("----- table already exists")
|
||||
|
||||
return "Table checked"
|
||||
|
||||
|
||||
def store_data():
|
||||
"""Store raw data in respective table and database."""
|
||||
# Path to the raw training data
|
||||
_data_root = "./data"
|
||||
_data_filename = "dataset.csv"
|
||||
_data_filepath = os.path.join(_data_root, _data_filename)
|
||||
|
||||
# read data and obtain variable names
|
||||
dataframe = pd.read_csv(_data_filepath)
|
||||
dataframe.rename(columns={"Unnamed: 0": "ID"}, inplace=True)
|
||||
sql_column_names = [name.lower() for name in dataframe.columns]
|
||||
|
||||
# insert every dataframe row into sql table
|
||||
mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
|
||||
conn = mysql_hook.get_conn()
|
||||
cur = conn.cursor()
|
||||
# VALUES in query are %s repeated as many columns are in dataframe
|
||||
sql_column_names = ", ".join(
|
||||
["`" + name + "`" for name in sql_column_names]
|
||||
)
|
||||
query = f"INSERT INTO `raw_clients` ({sql_column_names}) \
|
||||
VALUES ({', '.join(['%s' for _ in range(dataframe.shape[1])])})"
|
||||
dataframe = list(dataframe.itertuples(index=False, name=None))
|
||||
cur.executemany(query, dataframe)
|
||||
conn.commit()
|
||||
|
||||
return "Data stored"
|
||||
|
||||
|
||||
with DAG(
|
||||
"load_data",
|
||||
description="Read data from source and store it in raw_data database",
|
||||
start_date=datetime(2024, 9, 18, 0, 0),
|
||||
schedule_interval="@once",
|
||||
) as dag:
|
||||
|
||||
check_table_task = PythonOperator(
|
||||
task_id="check_table_exists", python_callable=check_table_exists
|
||||
)
|
||||
store_data_task = PythonOperator(
|
||||
task_id="store_data", python_callable=store_data
|
||||
)
|
||||
|
||||
check_table_task >> store_data_task
|
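Note: both tasks resolve the database through an Airflow connection with conn_id `raw_data`, which is not part of this commit. A minimal sketch of how such a connection could be supplied, assuming Airflow's standard `AIRFLOW_CONN_<CONN_ID>` environment-variable lookup; the user, password, and host below are placeholders:

# Hypothetical connection setup for the `raw_data` MySqlHook used above.
# Airflow resolves a connection id from an AIRFLOW_CONN_<CONN_ID> environment
# variable holding a connection URI; credentials and host are placeholders,
# not values from this repository.
import os

os.environ["AIRFLOW_CONN_RAW_DATA"] = "mysql://db_user:db_password@mysql:3306/raw_data"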
dags/preprocess_data.py (new file, 122 lines)
@@ -0,0 +1,122 @@
"""
|
||||
Airflow DAG to load raw data, process it, split it, and store in database.
|
||||
|
||||
Author
|
||||
------
|
||||
Nicolas Rojas
|
||||
"""
|
||||
|
||||
# imports
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import train_test_split
|
||||
from airflow import DAG
|
||||
from airflow.operators.python import PythonOperator
|
||||
from airflow.providers.mysql.hooks.mysql import MySqlHook
|
||||
|
||||
|
||||
def check_table_exists(table_name: str):
|
||||
"""Check whether table exists in clean_data database. If not, create it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
table_name : str
|
||||
Name of table to check.
|
||||
"""
|
||||
# count number of rows in data table
|
||||
query = f'SELECT COUNT(*) FROM information_schema.tables WHERE table_name="{table_name}"'
|
||||
mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
|
||||
connection = mysql_hook.get_conn()
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(query)
|
||||
results = cursor.fetchall()
|
||||
# check whether table exists
|
||||
if results[0][0] == 0:
|
||||
# create table
|
||||
print("----- table does not exists, creating it")
|
||||
create_sql = f"CREATE TABLE `{table_name}`\
|
||||
`age` SMALLINT,\
|
||||
`anual_income` BIGINT,\
|
||||
`credit_score` SMALLINT,\
|
||||
`loan_amount` BIGINT,\
|
||||
`loan_duration_years` TINYINT,\
|
||||
`number_of_open_accounts` SMALLINT,\
|
||||
`had_past_default` TINYINT,\
|
||||
`loan_approval` TINYINT\
|
||||
)"
|
||||
mysql_hook.run(create_sql)
|
||||
else:
|
||||
# no need to create table
|
||||
print("----- table already exists")
|
||||
|
||||
return "Table checked"
|
||||
|
||||
|
||||
def store_data(dataframe: pd.DataFrame, table_name: str):
|
||||
"""Store dataframe data in given table, in clean data database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataframe : pd.DataFrame
|
||||
Dataframe to store in database.
|
||||
table_name : str
|
||||
Name of the table to store the data.
|
||||
"""
|
||||
check_table_exists(table_name)
|
||||
# insert every dataframe row into sql table
|
||||
mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
|
||||
sql_column_names = ", ".join(
|
||||
["`" + name + "`" for name in dataframe.columns]
|
||||
)
|
||||
conn = mysql_hook.get_conn()
|
||||
cur = conn.cursor()
|
||||
# VALUES in query are %s repeated as many columns are in dataframe
|
||||
query = f"INSERT INTO `{table_name}` ({sql_column_names}) \
|
||||
VALUES ({', '.join(['%s' for _ in range(dataframe.shape[1])])})"
|
||||
dataframe = list(dataframe.itertuples(index=False, name=None))
|
||||
cur.executemany(query, dataframe)
|
||||
conn.commit()
|
||||
|
||||
return "Data stored"
|
||||
|
||||
|
||||
def preprocess_data():
|
||||
"""Preprocess raw data and store it in clean_data database."""
|
||||
# retrieve raw data
|
||||
mysql_hook = MySqlHook(mysql_conn_id="raw_data", schema="raw_data")
|
||||
conn = mysql_hook.get_conn()
|
||||
query = "SELECT * FROM `raw_clients`"
|
||||
dataframe = pd.read_sql(query, con=conn)
|
||||
|
||||
# drop useless column
|
||||
dataframe.drop(columns=["id"], inplace=True)
|
||||
# fill empty fields
|
||||
dataframe.fillna(0, inplace=True)
|
||||
|
||||
# split data: 70% train, 10% val, 20% test
|
||||
df_train, df_test = train_test_split(
|
||||
dataframe, test_size=0.2, shuffle=True, random_state=1337
|
||||
)
|
||||
df_train, df_val = train_test_split(
|
||||
df_train, test_size=0.125, shuffle=True, random_state=1337
|
||||
)
|
||||
|
||||
# store data partitions in database
|
||||
store_data(df_train, "clean_clients_train")
|
||||
store_data(df_val, "clean_clients_val")
|
||||
store_data(df_test, "clean_clients_test")
|
||||
|
||||
return "Data preprocessed"
|
||||
|
||||
|
||||
with DAG(
|
||||
"preprocess_data",
|
||||
description="Fetch raw data, preprocess it and save it in mysql database",
|
||||
start_date=datetime(2024, 9, 18, 0, 2),
|
||||
schedule_interval="@once",
|
||||
) as dag:
|
||||
|
||||
preprocess_task = PythonOperator(
|
||||
task_id="preprocess_data", python_callable=preprocess_data
|
||||
)
|
||||
preprocess_task
|
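The two-stage split yields the 70/10/20 proportions stated in the comment: `test_size=0.2` reserves 20% for test, and 0.125 of the remaining 80% is 10% of the full dataset. A small self-contained check of the arithmetic on a toy dataframe (synthetic data, not the project dataset):

# Toy check of the 70/10/20 split proportions used above.
import pandas as pd
from sklearn.model_selection import train_test_split

toy = pd.DataFrame({"x": range(1000)})
train, test = train_test_split(toy, test_size=0.2, shuffle=True, random_state=1337)
train, val = train_test_split(train, test_size=0.125, shuffle=True, random_state=1337)
print(len(train), len(val), len(test))  # 700 100 200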
dags/train_model.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""
|
||||
Airflow DAG to load clean data and train model with MLflow.
|
||||
|
||||
Author
|
||||
------
|
||||
Nicolas Rojas
|
||||
"""
|
||||
|
||||
# imports
|
||||
import os
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
from airflow import DAG
|
||||
from airflow.operators.python import PythonOperator
|
||||
from airflow.providers.mysql.hooks.mysql import MySqlHook
|
||||
from sklearn.ensemble import HistGradientBoostingClassifier
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.metrics import f1_score
|
||||
import mlflow
|
||||
from mlflow.models import infer_signature
|
||||
|
||||
|
||||
def get_data(table_name: str, target_variable: str):
|
||||
"""Get data from clean_data database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
table_name : str
|
||||
Name of table to get data from.
|
||||
target_variable : str
|
||||
Name of the target variable in the classification problem.
|
||||
"""
|
||||
# connect to clean database
|
||||
mysql_hook = MySqlHook(mysql_conn_id="clean_data", schema="clean_data")
|
||||
sql_connection = mysql_hook.get_conn()
|
||||
# get all available data
|
||||
query = f"SELECT * FROM `{table_name}`"
|
||||
dataframe = pd.read_sql(query, con=sql_connection)
|
||||
# return input features and target variable
|
||||
return dataframe.drop(columns=[target_variable]), dataframe[target_variable]
|
||||
|
||||
|
||||
def train_model():
|
||||
"""Train model with clean data and save artifacts with MLflow."""
|
||||
# get data partitions
|
||||
target_variable = "loan_approval"
|
||||
X_train, y_train = get_data("clean_clients_train", target_variable)
|
||||
X_val, y_val = get_data("clean_clients_val", target_variable)
|
||||
X_test, y_test = get_data("clean_clients_test", target_variable)
|
||||
|
||||
# define preprocessor and classifier
|
||||
categorical_feature = "had_past_default"
|
||||
numerical_features = [
|
||||
category
|
||||
for category in X_train.columns
|
||||
if category != categorical_feature
|
||||
]
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
("numerical", StandardScaler(), numerical_features),
|
||||
("categorical", "passthrough", [categorical_feature]),
|
||||
]
|
||||
)
|
||||
hyperparameters = {
|
||||
"classifier__n_estimators": 168,
|
||||
"classifier__max_depth": 6,
|
||||
"classifier__learning_rate": 0.001,
|
||||
}
|
||||
classifier = GradientBoostingClassifier(**hyperparameters)
|
||||
pipeline = Pipeline(
|
||||
steps=[("preprocessor", preprocessor), ("classifier", classifier)]
|
||||
)
|
||||
|
||||
# connect to mlflow
|
||||
os.environ["MLFLOW_S3_ENDPOINT_URL"] = "http://minio:8081"
|
||||
os.environ["AWS_ACCESS_KEY_ID"] = "access2024minio"
|
||||
os.environ["AWS_SECRET_ACCESS_KEY"] = "supersecretaccess2024"
|
||||
mlflow.set_tracking_uri("http://mlflow:8083")
|
||||
mlflow.set_experiment("mlflow_tracking_model")
|
||||
mlflow.sklearn.autolog(
|
||||
log_model_signatures=True,
|
||||
log_input_examples=True,
|
||||
registered_model_name="clients_model",
|
||||
)
|
||||
|
||||
# open mlflow run
|
||||
with mlflow.start_run(run_name="autolog_pipe_model") as run:
|
||||
# train model
|
||||
pipeline.fit(X_train, y_train)
|
||||
y_pred_val = pipeline.predict(X_val)
|
||||
y_pred_test = pipeline.predict(X_test)
|
||||
# log metrics
|
||||
mlflow.log_metric("f1_score_val", f1_score(y_val, y_pred_val))
|
||||
mlflow.log_metric("f1_score_test", f1_score(y_test, y_pred_test))
|
||||
# log model
|
||||
signature = infer_signature(X_test, y_pred_test)
|
||||
mlflow.sklearn.log_model(
|
||||
sk_model=pipeline,
|
||||
artifact_path="clients_model",
|
||||
signature=signature,
|
||||
registered_model_name="clients_model",
|
||||
)
|
||||
|
||||
|
||||
with DAG(
|
||||
"train_model",
|
||||
description="Fetch clean data from database, train model, save artifacts with MLflow and register model",
|
||||
start_date=datetime(2024, 9, 18, 0, 5),
|
||||
schedule_interval="@once",
|
||||
) as dag:
|
||||
|
||||
train_model_task = PythonOperator(
|
||||
task_id="train_model", python_callable=train_model
|
||||
)
|
||||
train_model_task
|
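After this DAG runs, the pipeline is registered in MLflow as `clients_model`. A hedged sketch of how a downstream process might load it for inference; the tracking URI matches the one used above, but the model version, the sample feature values, and access to the MinIO artifact store are assumptions:

# Hypothetical consumer of the registered model; version "1" and the sample
# feature values are assumptions, not part of this commit.
import mlflow
import pandas as pd

mlflow.set_tracking_uri("http://mlflow:8083")
model = mlflow.sklearn.load_model("models:/clients_model/1")

sample = pd.DataFrame([{
    "age": 35,
    "anual_income": 52000,
    "credit_score": 680,
    "loan_amount": 10000,
    "loan_duration_years": 5,
    "number_of_open_accounts": 3,
    "had_past_default": 0,
}])
print(model.predict(sample))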