Skip to content

Commit

Permalink
Update train_model.py
Browse files: browse the repository at this point in the history
  • Loading branch information
karimosman89 authored Sep 21, 2024
1 parent 2d8868c commit 4fca105
Showing 1 changed file with 30 additions and 20 deletions.
50 changes: 30 additions & 20 deletions src/train_model.py
Original file line number | Diff line number | Diff line change
@@ -1,32 +1,42 @@
import mlflow
import mlflow.tensorflow
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import Adam
import pandas as pd
from mlflow import log_metric, log_param, log_artifact
import mlflow

def create_model(input_shape=(32, 32, 3), num_classes=10):
def train_model(data_dir, model_save_path):
"""
Train a CNN on CIFAR-10 and log the results to MLflow.
"""
mlflow.start_run()

model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Flatten(),
Dense(128, activation='relu'),
Dense(num_classes, activation='softmax')
Dense(10, activation='softmax')
])
return model

if __name__ == "__main__":
mlflow.start_run() # Start MLflow run
mlflow.tensorflow.autolog() # Automatically log TensorFlow model

model = create_model()
model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Load preprocessed data from S3
X_train = pd.read_csv('/tmp/X_train_preprocessed.csv').values
y_train = pd.read_csv('/tmp/y_train.csv').values

model.fit(X_train, y_train, epochs=10, batch_size=64)
# image_dataset_from_directory is called below with its default label_mode,
# which yields integer class indices rather than one-hot vectors. The sparse
# loss matches that label format; plain 'categorical_crossentropy' expects
# one-hot targets and is incompatible with these labels.
model.compile(optimizer=Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Single source of truth so the logged parameter cannot drift from the
# batch size actually used by the data loader.
batch_size = 32

log_param("optimizer", "Adam")
log_param("batch_size", batch_size)

# Load data: one sub-directory per class under data_dir, images resized to 32x32
train_data = tf.keras.preprocessing.image_dataset_from_directory(data_dir, batch_size=batch_size, image_size=(32, 32))

# Train
history = model.fit(train_data, epochs=10)

# Record the final-epoch training accuracy, then persist the model to disk
# and attach the saved file to the active MLflow run.
log_metric("accuracy", history.history['accuracy'][-1])
model.save(model_save_path)
mlflow.log_artifact(model_save_path)

# Close the MLflow run that was started before the model was built
mlflow.end_run()

# Script entry point: train on the processed dataset and write the model file.
if __name__ == '__main__':
    train_model('data/processed/', 'src/model/model.h5')

0 comments on commit 4fca105

Please sign in to comment.