The library is built to work with both the Keras and TensorFlow Keras frameworks:
import segmentation_models as sm
# Segmentation Models: using `keras` framework.
By default it tries to import keras; if keras is not installed, it falls back to the tensorflow.keras framework.
There are several ways to choose the framework:
Provide the environment variable SM_FRAMEWORK=keras / SM_FRAMEWORK=tf.keras before importing segmentation_models
Call sm.set_framework('keras') / sm.set_framework('tf.keras')
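As a minimal sketch of the environment-variable route, the snippet below sets SM_FRAMEWORK from Python before the library is imported. The helper name select_framework and the VALID_FRAMEWORKS tuple are illustrative assumptions, not part of segmentation_models; only the variable name and its two values come from the text above.

```python
import os

# The two framework names mentioned above (assumed to be the only valid values).
VALID_FRAMEWORKS = ("keras", "tf.keras")


def select_framework(name):
    """Set SM_FRAMEWORK; call this before importing segmentation_models."""
    if name not in VALID_FRAMEWORKS:
        raise ValueError(f"unknown framework {name!r}, expected one of {VALID_FRAMEWORKS}")
    os.environ["SM_FRAMEWORK"] = name
    return name


select_framework("tf.keras")
# import segmentation_models as sm  # import only after the variable is set
```

Setting the variable in the shell before starting Python has the same effect; the point is only that it has to be in place before the first import.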
You can also specify which image_data_format to use; segmentation-models works with both channels_last and channels_first.
This can be useful for later conversion of the model to Nvidia TensorRT format or for optimizing the model for CPU/GPU computation.
import keras
# or from tensorflow import keras

keras.backend.set_image_data_format('channels_last')
# or keras.backend.set_image_data_format('channels_first')
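To make the difference between the two formats concrete, here is a small sketch of how the shape of a 4D image batch is ordered under each one. The LAYOUTS table and the convert_shape helper are hypothetical names introduced for illustration; they are not part of Keras or segmentation_models.

```python
# Axis order of a 4D image batch under each Keras image_data_format.
LAYOUTS = {
    "channels_last": ("batch", "height", "width", "channels"),
    "channels_first": ("batch", "channels", "height", "width"),
}


def convert_shape(shape, source, target):
    """Reorder a 4D batch shape from the `source` layout to the `target` layout."""
    if len(shape) != 4:
        raise ValueError("expected a 4D shape (one entry per axis)")
    named = dict(zip(LAYOUTS[source], shape))
    return tuple(named[axis] for axis in LAYOUTS[target])


print(convert_shape((16, 256, 256, 3), "channels_last", "channels_first"))
# prints (16, 3, 256, 256)
```

The data you feed the model has to follow whichever layout was set, so an input_shape such as (None, None, 3) under channels_last would become (3, None, None) under channels_first.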
The created segmentation model is just an instance of a Keras Model, which can be built as easily as:
model = sm.Unet()
Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and use pretrained weights to initialize it:
model = sm.Unet('resnet34', encoder_weights='imagenet')
Change the number of output classes in the model (choose your case):
# binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
model = sm.Unet('resnet34', classes=1, activation='sigmoid')

# multiclass segmentation with non-overlapping class masks (your classes + background)
model = sm.Unet('resnet34', classes=3, activation='softmax')

# multiclass segmentation with independent overlapping/non-overlapping class masks
model = sm.Unet('resnet34', classes=3, activation='sigmoid')
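The choice of activation follows from how the class masks relate to each other. The sketch below, written in plain Python rather than with the library, compares the two activations on one pixel; the logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Normalize logits so the class probabilities sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(logits):
    """Squash each logit into (0, 1) independently of the others."""
    return [1.0 / (1.0 + math.exp(-v)) for v in logits]


pixel_logits = [2.0, 1.5, -1.0]  # hypothetical logits for one pixel, 3 classes

print(softmax(pixel_logits))  # classes compete: probabilities sum to 1
print(sigmoid(pixel_logits))  # classes are scored independently
```

With softmax the classes compete for each pixel, which matches non-overlapping masks. With sigmoid several classes can score high on the same pixel, which is what overlapping masks need.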
Change the input shape of the model:
# if the number of input channels is not 3, you have to set encoder_weights=None
# how to handle this case with encoder_weights='imagenet' is described in the docs
model = sm.Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
Simple training pipeline
import segmentation_models as sm

BACKBONE = 'resnet34'
preprocess_input = sm.get_preprocessing(BACKBONE)

# load your data
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = sm.Unet(BACKBONE, encoder_weights='imagenet')
model.compile(
    'Adam',
    loss=sm.losses.bce_jaccard_loss,
    metrics=[sm.metrics.iou_score],
)

# fit model
# if you use data generator use model.fit_generator(...) instead of model.fit(...)
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
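The pipeline above trains with bce_jaccard_loss and reports iou_score. As a rough sketch of what those quantities measure, the plain-Python functions below compute an intersection-over-union score and a combined binary cross-entropy plus Jaccard loss on flattened masks. This illustrates the idea under simple assumptions (flat lists, a small smoothing constant); it is not the library's implementation.

```python
import math


def iou_score(y_true, y_pred, smooth=1e-5):
    """Intersection over union of two flattened masks with values in [0, 1]."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    union = sum(y_true) + sum(y_pred) - intersection
    return (intersection + smooth) / (union + smooth)


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(y_true)


def bce_jaccard(y_true, y_pred):
    """Binary cross-entropy plus Jaccard loss, where Jaccard loss = 1 - IoU."""
    return binary_crossentropy(y_true, y_pred) + (1.0 - iou_score(y_true, y_pred))


mask = [1, 1, 0, 0]        # hypothetical ground-truth pixels
prediction = [1, 0, 0, 0]  # hypothetical predicted pixels
print(iou_score(mask, prediction))    # about 0.5: one shared pixel out of two in the union
print(bce_jaccard(mask, prediction))
```

A perfect prediction gives an IoU of 1 and a combined loss close to 0, which is why the metric should rise while the loss falls during training.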
The same manipulations can be done with Linknet, PSPNet and FPN. For more detailed information about the models API and use cases, see Read the Docs.
Examples
Models training examples:
[Jupyter Notebook] Binary segmentation (cars) on the CamVid dataset.
[Jupyter Notebook] Multi-class segmentation (cars, pedestrians) on the CamVid dataset.
Models and Backbones
Models
Backbones
'vgg16' 'vgg19'
'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'
'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'
'resnext50' 'resnext101'
'seresnext50' 'seresnext101'
'senet154'
'densenet121' 'densenet169' 'densenet201'
'inceptionv3' 'inceptionresnetv2'
'mobilenet' 'mobilenetv2'
'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'
Installation
Requirements
PyPI stable package
PyPI latest package
Source latest version
Documentation
The latest documentation is available on Read the Docs.
Change Log
To see important changes between versions, look at CHANGELOG.md.
Citing
License
The project is distributed under the MIT License.