Model Factory

Module: model_factory.py

Functions

import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16
 
def get_image_model(model_type: str, input_shape: tuple) -> tf.keras.Model:
    """
    Build and return an image classification model of the requested architecture.

    Parameters:
        model_type (str): Model architecture to use (e.g., "SimpleCNN", "ResNet", "VGG").
        input_shape (tuple): Shape of the input images, e.g., (224, 224, 3) in Keras's default channels-last layout.

    Returns:
        tf.keras.Model: Compiled Keras model instance.
    """
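
The docstring above describes a factory that maps a string to a compiled Keras model. The following is a minimal sketch of how such a dispatch could look, not the module's actual implementation. Several things here are assumptions that do not appear in the original: the function name `get_image_model_sketch`, the helper functions, the `num_classes` parameter, the layer sizes of the small CNN, the mapping of "ResNet" to `ResNet50` and "VGG" to `VGG16` (suggested only by the imports), and the optimizer and loss used in `compile`.

```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16


def build_simple_cnn(input_shape, num_classes):
    """Hypothetical small CNN; the layer sizes are illustrative only."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])


def build_from_backbone(backbone_cls, input_shape, num_classes):
    """Attach a softmax head to a Keras application used as a feature extractor."""
    # weights=None avoids downloading ImageNet weights;
    # pooling="avg" makes the backbone output a flat feature vector.
    backbone = backbone_cls(
        include_top=False,
        weights=None,
        input_shape=input_shape,
        pooling="avg",
    )
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(backbone.output)
    return tf.keras.Model(inputs=backbone.input, outputs=outputs)


def get_image_model_sketch(model_type, input_shape, num_classes=10):
    """Return a compiled model for the given architecture name (sketch only)."""
    # Assumed mapping: "ResNet" -> ResNet50 and "VGG" -> VGG16.
    builders = {
        "SimpleCNN": lambda: build_simple_cnn(input_shape, num_classes),
        "ResNet": lambda: build_from_backbone(ResNet50, input_shape, num_classes),
        "VGG": lambda: build_from_backbone(VGG16, input_shape, num_classes),
    }
    if model_type not in builders:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(builders)}"
        )

    model = builders[model_type]()
    # Assumed training setup: integer class labels and the Adam optimizer.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
```

Under these assumptions, a call such as `get_image_model_sketch("ResNet", (224, 224, 3), num_classes=5)` would return a compiled model with a five-way softmax output. Wrapping each builder in a lambda means only the requested architecture is constructed, and an unrecognized name fails early with a `ValueError` rather than silently returning `None`.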