Prithvi 100M model
This notebook will demonstrate basic usage of the Prithvi ViT model.
wget https://raw.githubusercontent.com/selvaje/SE_data/master/exercise/foundation_model_IIASA2024.ipynb
wget https://raw.githubusercontent.com/selvaje/SE_data/master/exercise/foundation_model_IIASA2024.py
Useful links: - Hugging Face page for this project - GitHub page
Getting started with Prithvi - Reconstruction
Get model files
To get started, clone the HuggingFace repository for Prithvi 100M by running the commands below
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M
# rename to a valid python module name
mv Prithvi-100M prithvi
Alternatively, you can directly download the weights, model class, and configuration file from the repository and place them inside a directory named prithvi.
A third alternative is to leverage the huggingface_hub library to download these files directly through code.
%pip install huggingface_hub
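A minimal sketch of that route (assuming the file names used later in this notebook: Prithvi.py, Prithvi_100M.pt, and Prithvi_100M_config.yaml) could look like this:
# Sketch: fetch the model code, weights, and config into a local "prithvi" directory via huggingface_hub
import os
import shutil
from huggingface_hub import hf_hub_download

os.makedirs("prithvi", exist_ok=True)
for fname in ["Prithvi.py", "Prithvi_100M.pt", "Prithvi_100M_config.yaml"]:
    cached_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename=fname)
    shutil.copy(cached_path, os.path.join("prithvi", fname))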
Treat it as a module
Next, let's add an __init__.py file to the downloaded directory, so we can treat it as a module and import the MaskedAutoencoderViT class from it. Simply create an empty file named __init__.py inside the prithvi directory by running the code below.
[1]:
with open("prithvi/__init__.py", "w") as f:
f.write("")
Relevant imports
To run this notebook, besides following the installation steps in the README, make sure to also install jupyter.
[7]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import yaml
from prithvi.Prithvi import MaskedAutoencoderViT
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
PERCENTILES = (0.1, 99.9)
Define some functions for visualization
[8]:
def load_raster(path, crop=None):
    with rasterio.open(path) as src:
        img = src.read()
        print('img.shape: ', img.shape)
        # load first 6 bands
        img = img[:6]
        # Handling No Data Values
        img = np.where(img == NO_DATA, NO_DATA_FLOAT, img)
        # Cropping
        if crop:
            img = img[:, -crop[0]:, -crop[1]:]
    return img
def enhance_raster_for_visualization(raster, ref_img=None):
    if ref_img is None:
        ref_img = raster
    channels = []
    # Loop through each channel (band) in the raster
    for channel in range(raster.shape[0]):
        valid_mask = np.ones_like(ref_img[channel], dtype=bool)
        valid_mask[ref_img[channel] == NO_DATA_FLOAT] = False
        mins, maxs = np.percentile(ref_img[channel][valid_mask], PERCENTILES)  # minimum and maximum values at the specified percentiles of the valid data
        normalized_raster = (raster[channel] - mins) / (maxs - mins)  # normalize the channel to the range [0, 1] using the calculated mins and maxs
        normalized_raster[~valid_mask] = 0  # set pixels that are not valid to 0 in the normalized raster
        clipped = np.clip(normalized_raster, 0, 1)  # clip the values to ensure they are within [0, 1]
        channels.append(clipped)
    clipped = np.stack(channels)
    channels_last = np.moveaxis(clipped, 0, -1)[..., :3]
    rgb = channels_last[..., ::-1]
    return rgb
[9]:
def plot_image_mask_reconstruction(normalized, mask_img, pred_img):
    # Mix visible and predicted patches
    rec_img = normalized.clone()
    rec_img[mask_img == 1] = pred_img[mask_img == 1]  # binary mask: 0 is keep, 1 is remove; masked regions are replaced by 'pred' values
    mask_img_np = mask_img.numpy().reshape(6, 224, 224).transpose((1, 2, 0))[..., :3]
    rec_img_np = (rec_img.numpy().reshape(6, 224, 224) * stds) + means
    fig, ax = plt.subplots(1, 3, figsize=(15, 6))
    for subplot in ax:
        subplot.axis('off')
    ax[0].imshow(enhance_raster_for_visualization(input_data))
    masked_img_np = enhance_raster_for_visualization(input_data).copy()
    masked_img_np[mask_img_np[..., 0] == 1] = 0
    ax[1].imshow(masked_img_np)
    ax[2].imshow(enhance_raster_for_visualization(rec_img_np, ref_img=input_data))
Loading the model
Assuming you have the relevant files under this directory
[10]:
# load weights
weights_path = "./prithvi/Prithvi_100M.pt"
checkpoint = torch.load(weights_path, map_location="cpu")
# read model config
model_cfg_path = "./prithvi/Prithvi_100M_config.yaml"
with open(model_cfg_path) as f:
    model_config = yaml.safe_load(f)
model_args, train_args = model_config["model_args"], model_config["train_params"]
# let us use only 1 frame for now (the model was trained on 3 frames)
model_args["num_frames"] = 1
# instantiate model
model = MaskedAutoencoderViT(**model_args)
model.eval()
# load weights into model
# strict=False since we are loading with only 1 frame; we drop the (mismatched) positional embeddings, so a warning about missing keys is expected
del checkpoint['pos_embed']
del checkpoint['decoder_pos_embed']
_ = model.load_state_dict(checkpoint, strict=False)
[11]:
# print(model)
Let’s try it out!
We can access the images directly from the HuggingFace space thanks to rasterio
[12]:
raster_path = "https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo/resolve/main/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"
input_data = load_raster(raster_path, crop=(224, 224))
print(f"Input data shape is {input_data.shape}")
raster_for_visualization = enhance_raster_for_visualization(input_data)
plt.imshow(raster_for_visualization)
img.shape: (6, 500, 540)
Input data shape is (6, 224, 224)
[12]:
<matplotlib.image.AxesImage at 0x146a9d76e3a0>

Let's call the model!
We pass: - The normalized input image, cropped to size (224, 224) - mask_ratio: the proportion of patches that will be masked
The model returns a tuple with: - loss - reconstructed image - mask used
[13]:
# statistics used to normalize images before passing to the model
means = np.array(train_args["data_mean"]).reshape(-1, 1, 1)
stds = np.array(train_args["data_std"]).reshape(-1, 1, 1)
def preprocess_image(image):
    # normalize image
    normalized = image.copy()
    normalized = ((image - means) / stds)
    normalized = torch.from_numpy(normalized.reshape(1, normalized.shape[0], 1, *normalized.shape[-2:])).to(torch.float32)
    return normalized
[14]:
normalized = preprocess_image(input_data)
with torch.no_grad():
    mask_ratio = 0.5
    _, pred, mask = model(normalized, mask_ratio=mask_ratio)
# Let's take a look at the shape of the model's output
# This is the flat array of patches. Number of patches: image size=(224x224); patch size=(16x16); number of patches = (224/16)^2=196
print('pred.shape: ', pred.shape)  # [batch, patches, dimension]; dimension = patch_size * patch_size * bands = 16*16*6 = 1536
print('mask.shape: ', mask.shape)
# Undo the patching, back to the original pixel space
mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
pred_img = model.unpatchify(pred).detach().cpu()
pred.shape: torch.Size([1, 196, 1536])
mask.shape: torch.Size([1, 196])
Let's use these to build a nice output visualization
[10]:
plot_image_mask_reconstruction(normalized, mask_img, pred_img)

Inference with finetuned Prithvi
This time, let's use the huggingface_hub library to directly download the files for the finetuned model.
[11]:
# %pip install huggingface_hub
[15]:
from mmcv import Config
from mmseg.models import build_segmentor
from mmseg.datasets.pipelines import Compose, LoadImageFromFile
from mmseg.apis import init_segmentor
from model_inference import inference_segmentor, process_test_pipeline
from huggingface_hub import hf_hub_download
import matplotlib
from torch import nn
[13]:
# Grab the config and model weights from huggingface
config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-sen1floods11", filename="sen1floods11_Prithvi_100M.py")
ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-sen1floods11", filename='sen1floods11_Prithvi_100M.pth')
# finetuned_model = init_segmentor(Config.fromfile(config_path), ckpt, device="cpu")
finetuned_model = init_segmentor(Config.fromfile(config_path), ckpt, device="cuda")
load checkpoint from local path: /home/ahf38/.cache/huggingface/hub/models--ibm-nasa-geospatial--Prithvi-100M-sen1floods11/snapshots/220f62f00f6a31a70daac7babf139e4bf265f1c0/sen1floods11_Prithvi_100M.pth
Let’s grab an image to do inference on
!wget https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-sen1floods11-demo/resolve/main/Spain_7370579_S2Hand.tif
!wget https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-sen1floods11-demo/resolve/main/India_900498_S2Hand.tif
!wget https://github.com/cloudtostreet/Sen1Floods11/blob/master/sample/S1/Spain_7370579_S1Hand.tif
[15]:
# We will load and print the image we want to do inference with
input_data_inference = load_raster("Spain_7370579_S2Hand.tif")
print(f"Image input shape is {input_data_inference.shape}")
raster_for_visualization = enhance_raster_for_visualization(input_data_inference)
plt.axis('off')
plt.imshow(raster_for_visualization)
img.shape: (13, 512, 512)
Image input shape is (6, 512, 512)
[15]:
<matplotlib.image.AxesImage at 0x152dd56c4280>

[16]:
# Let's take a look at the definition of the model's pipeline
custom_test_pipeline = process_test_pipeline(finetuned_model.cfg.data.test.pipeline)
print('custom_test_pipeline: ',custom_test_pipeline)
result = inference_segmentor(finetuned_model, "Spain_7370579_S2Hand.tif", custom_test_pipeline=custom_test_pipeline)
custom_test_pipeline: [{'type': 'LoadGeospatialImageFromFile', 'to_float32': False, 'nodata': -9999, 'nodata_replace': 0}, {'type': 'BandsExtract', 'bands': [1, 2, 3, 8, 11, 12]}, {'type': 'ConstantMultiply', 'constant': 0.0001}, {'type': 'ToTensor', 'keys': ['img']}, {'type': 'TorchPermute', 'keys': ['img'], 'order': (2, 0, 1)}, {'type': 'TorchNormalize', 'means': [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503], 'stds': [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205]}, {'type': 'Reshape', 'keys': ['img'], 'new_shape': (6, 1, -1, -1), 'look_up': {'2': 1, '3': 2}}, {'type': 'CastTensor', 'keys': ['img'], 'new_type': 'torch.FloatTensor'}, {'type': 'CollectTestList', 'keys': ['img'], 'meta_keys': ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']}]
[17]:
# The output of the model is a binary mask of the same size as the input image
print('result[0].shape: ',result[0].shape)
result[0].shape: (512, 512)
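As an optional quick check (not part of the original demo), the binary mask can be summarized directly with numpy, assuming class 1 corresponds to water:
# Fraction of pixels predicted as flooded (assuming label 1 = water for this finetuned model)
flooded_fraction = (result[0] == 1).mean()
print(f"Fraction of pixels predicted as flooded: {flooded_fraction:.2%}")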
[18]:
# Let's take a look at the model's prediction
fig, ax = plt.subplots(1, 3, figsize=(15, 10))
input_data_inference = load_raster("Spain_7370579_S2Hand.tif")
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
ax[0].imshow(enhance_raster_for_visualization(input_data_inference))
ax[1].imshow(result[0], norm=norm)
ax[2].imshow(enhance_raster_for_visualization(input_data_inference))
norm = matplotlib.colors.Normalize(vmin=0, vmax=2)
ax[2].imshow(result[0], cmap="jet", alpha=0.3, norm=norm)
for subplot in ax:
    subplot.axis('off')
img.shape: (13, 512, 512)

[19]:
# Inference with a second image
filename = "USA_430764_S2Hand.tif"
input_data_inference = load_raster(filename)
print(f"Image input shape is {input_data_inference.shape}")
raster_for_visualization = enhance_raster_for_visualization(input_data_inference)
plt.axis('off')
plt.imshow(raster_for_visualization)
# adapt this pipeline for Tif files with > 3 images
custom_test_pipeline = process_test_pipeline(finetuned_model.cfg.data.test.pipeline)
print('custom_test_pipeline: ',custom_test_pipeline)
result = inference_segmentor(finetuned_model, filename, custom_test_pipeline=custom_test_pipeline)
fig, ax = plt.subplots(1, 3, figsize=(15, 10))
input_data_inference = load_raster(filename)
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
ax[0].imshow(enhance_raster_for_visualization(input_data_inference))
# ax[1].imshow(result[0], norm=norm, cmap="jet")
ax[1].imshow(result[0], norm=norm)
ax[2].imshow(enhance_raster_for_visualization(input_data_inference))
norm = matplotlib.colors.Normalize(vmin=0, vmax=2)
ax[2].imshow(result[0], cmap="jet", alpha=0.3, norm=norm)
for subplot in ax:
    subplot.axis('off')
img.shape: (13, 512, 512)
Image input shape is (6, 512, 512)
custom_test_pipeline: [{'type': 'LoadGeospatialImageFromFile', 'to_float32': False, 'nodata': -9999, 'nodata_replace': 0}, {'type': 'BandsExtract', 'bands': [1, 2, 3, 8, 11, 12]}, {'type': 'ConstantMultiply', 'constant': 0.0001}, {'type': 'ToTensor', 'keys': ['img']}, {'type': 'TorchPermute', 'keys': ['img'], 'order': (2, 0, 1)}, {'type': 'TorchNormalize', 'means': [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503], 'stds': [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205]}, {'type': 'Reshape', 'keys': ['img'], 'new_shape': (6, 1, -1, -1), 'look_up': {'2': 1, '3': 2}}, {'type': 'CastTensor', 'keys': ['img'], 'new_type': 'torch.FloatTensor'}, {'type': 'CollectTestList', 'keys': ['img'], 'meta_keys': ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']}]
img.shape: (13, 512, 512)


[20]:
# Inference with a third image
filename = "India_900498_S2Hand.tif"
input_data_inference = load_raster(filename)
print(f"Image input shape is {input_data_inference.shape}")
raster_for_visualization = enhance_raster_for_visualization(input_data_inference)
plt.axis('off')
plt.imshow(raster_for_visualization)
# adapt this pipeline for Tif files with > 3 images
custom_test_pipeline = process_test_pipeline(finetuned_model.cfg.data.test.pipeline)
print('custom_test_pipeline: ',custom_test_pipeline)
result = inference_segmentor(finetuned_model, filename, custom_test_pipeline=custom_test_pipeline)
fig, ax = plt.subplots(1, 3, figsize=(15, 10))
input_data_inference = load_raster(filename)
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
ax[0].imshow(enhance_raster_for_visualization(input_data_inference))
# ax[1].imshow(result[0], norm=norm, cmap="jet")
ax[1].imshow(result[0], norm=norm)
ax[2].imshow(enhance_raster_for_visualization(input_data_inference))
norm = matplotlib.colors.Normalize(vmin=0, vmax=2)
ax[2].imshow(result[0], cmap="jet", alpha=0.3, norm=norm)
for subplot in ax:
    subplot.axis('off')
img.shape: (13, 512, 512)
Image input shape is (6, 512, 512)
custom_test_pipeline: [{'type': 'LoadGeospatialImageFromFile', 'to_float32': False, 'nodata': -9999, 'nodata_replace': 0}, {'type': 'BandsExtract', 'bands': [1, 2, 3, 8, 11, 12]}, {'type': 'ConstantMultiply', 'constant': 0.0001}, {'type': 'ToTensor', 'keys': ['img']}, {'type': 'TorchPermute', 'keys': ['img'], 'order': (2, 0, 1)}, {'type': 'TorchNormalize', 'means': [0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503], 'stds': [0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205]}, {'type': 'Reshape', 'keys': ['img'], 'new_shape': (6, 1, -1, -1), 'look_up': {'2': 1, '3': 2}}, {'type': 'CastTensor', 'keys': ['img'], 'new_type': 'torch.FloatTensor'}, {'type': 'CollectTestList', 'keys': ['img'], 'meta_keys': ['img_info', 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']}]
img.shape: (13, 512, 512)


Inference with finetuned Prithvi
Let’s explore a second finetuned model - Crop Classification
[34]:
config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
filename="multi_temporal_crop_classification_Prithvi_100M.py")
ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
filename='multi_temporal_crop_classification_Prithvi_100M.pth')
finetuned_model = init_segmentor(Config.fromfile(config_path), ckpt, device="cuda")
load checkpoint from local path: /home/ahf38/.cache/huggingface/hub/models--ibm-nasa-geospatial--Prithvi-100M-multi-temporal-crop-classification/snapshots/3b8de5aa922c79b4cf69d497732fcf22f0edd8c6/multi_temporal_crop_classification_Prithvi_100M.pth
[40]:
# Load a sample image
import matplotlib.patches as mpatches
filename = "chip_102_345_merged.tif"
input_data_inference = load_raster(filename)
print(f"Image input shape is {input_data_inference.shape}")
# adapt this pipeline for Tif files with > 3 images
custom_test_pipeline = process_test_pipeline(finetuned_model.cfg.data.test.pipeline)
result = inference_segmentor(finetuned_model, filename, custom_test_pipeline=custom_test_pipeline)
print('result.shape: ',result[0].shape)
fig, ax = plt.subplots(1, 2, figsize=(15, 10))
input_data_inference = load_raster(filename)
norm = matplotlib.colors.Normalize(vmin=0, vmax=13)
ax[0].imshow(enhance_raster_for_visualization(input_data_inference))
ax[1].imshow(result[0], norm=norm, cmap="tab20")
# ax[2].imshow(enhance_raster_for_visualization(input_data_inference))
# ax[2].imshow(result[0], cmap="jet", alpha=0.3, norm=norm)
# Turn off axis for all subplots
for subplot in ax:
    subplot.axis('off')
# Define the legend handles
legend_labels = [
    "Natural Vegetation",
    "Forest",
    "Corn",
    "Soybeans",
    "Wetlands",
    "Developed/Barren",
    "Open Water",
    "Winter Wheat",
    "Alfalfa",
    "Fallow/Idle Cropland",
    "Cotton",
    "Sorghum",
    "Other"]
colors = [plt.cm.tab20(norm(i)) for i in range(13)]
handles = [mpatches.Patch(color=colors[i], label=legend_labels[i]) for i in range(13)]
# Add the legend to ax[1]
ax[1].legend(handles=handles, loc='upper right', bbox_to_anchor=(1.5, 1))
img.shape: (18, 224, 224)
Image input shape is (6, 224, 224)
result.shape: (224, 224)
img.shape: (18, 224, 224)
[40]:
<matplotlib.legend.Legend at 0x146a6436be20>

Proposed exercises
Use some of the images listed below to run the flood prediction and/or crop classification models.
Download a new image for your region and run these models. Does the output match your expectations?
[19]:
# Images downloaded from the demo (https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo)
fig, ax = plt.subplots(1, 3, figsize=(15, 10))
input_data_inference = load_raster("./temporal/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
input_data_inference2 = load_raster("./temporal/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
input_data_inference3 = load_raster("./temporal/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
norm = matplotlib.colors.Normalize(vmin=0, vmax=2)
ax[0].imshow(enhance_raster_for_visualization(input_data_inference))
ax[1].imshow(enhance_raster_for_visualization(input_data_inference2))
ax[2].imshow(enhance_raster_for_visualization(input_data_inference3))
[19]:
<matplotlib.image.AxesImage at 0x148d7d86a4f0>

[1]:
# ADD YOUR CODE TO RUN INFERENCE HERE
Finetuning for your use case
To finetune, you can now write a PyTorch training loop as usual for your dataset. Simply extract the backbone from the model with some surgery and run only the encoder forward, with no masking!
In general, some recommendations are: - At least in the beginning, experiment with freezing the backbone. This will give you much faster iteration through experiments (see the sketch just below). - Err on the side of a smaller learning rate. - With an unfrozen encoder, regularization is your friend! (Weight decay, dropout, batchnorm…)
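For example, freezing the backbone is a one-liner in plain PyTorch; this is only a sketch, with the trainable task head left to you:
# Sketch: freeze every parameter of the pretrained Prithvi encoder so only your task head trains
for param in model.parameters():
    param.requires_grad_(False)
# later, pass only the head's parameters to the optimizer, e.g. torch.optim.AdamW(head.parameters(), lr=1e-4)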
[17]:
# if going with plain pytorch:
# - remember to normalize images beforehand (find the normalization statistics in the config file)
# - turn off masking by passing mask_ratio = 0
normalized = preprocess_image(input_data)
features, _, _ = model.forward_encoder(normalized, mask_ratio=0)
print('features.shape: ',features.shape)
features.shape: torch.Size([1, 197, 768])
These are the standard outputs of a ViT encoder: - Dim 1: batch size - Dim 2: [cls_token] + tokens representing the flattened image - Dim 3: embedding dimension
First, reshape the features into an "image-like" shape: - drop the cls_token - reshape into an HxW grid
[18]:
print(f"Encoder features have shape {features.shape}")
# drop cls token
reshaped_features = features[:, 1:, :]
# reshape
feature_img_side_length = int(np.sqrt(reshaped_features.shape[1]))
reshaped_features = reshaped_features.view(-1, feature_img_side_length, feature_img_side_length, model_args["embed_dim"])
# channels first
reshaped_features = reshaped_features.permute(0, 3, 1, 2)
print(f"Encoder features have new shape {reshaped_features.shape}")
Encoder features have shape torch.Size([1, 197, 768])
Encoder features have new shape torch.Size([1, 768, 14, 14])
A simple segmentation head can consist of a few upscaling blocks + a final head for classification
[20]:
num_classes = 2
upscaling_block = lambda in_channels, out_channels: nn.Sequential(
    nn.Upsample(scale_factor=2),
    nn.Conv2d(kernel_size=3, in_channels=in_channels, out_channels=out_channels, padding=1),
    nn.ReLU())
embed_dims = [model_args["embed_dim"] // (2**i) for i in range(5)]
segmentation_head = nn.Sequential(
    *[
        upscaling_block(embed_dims[i], embed_dims[i+1]) for i in range(4)
    ],
    nn.Conv2d(kernel_size=1, in_channels=embed_dims[-1], out_channels=num_classes))
Running features through the segmentation head
We now get an output of shape [batch_size, num_classes, height, width]
[28]:
print('len(embed_dims): ',len(embed_dims))
print('segmentation_head: ',segmentation_head)
print('segmentation_head(reshaped_features).shape: ', segmentation_head(reshaped_features).shape)
len(embed_dims): 5
segmentation_head: Sequential(
(0): Sequential(
(0): Upsample(scale_factor=2.0, mode=nearest)
(1): Conv2d(768, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): ReLU()
)
(1): Sequential(
(0): Upsample(scale_factor=2.0, mode=nearest)
(1): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): ReLU()
)
(2): Sequential(
(0): Upsample(scale_factor=2.0, mode=nearest)
(1): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): ReLU()
)
(3): Sequential(
(0): Upsample(scale_factor=2.0, mode=nearest)
(1): Conv2d(96, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(2): ReLU()
)
(4): Conv2d(48, 2, kernel_size=(1, 1), stride=(1, 1))
)
segmentation_head(reshaped_features).shape: torch.Size([1, 2, 224, 224])
[33]:
# new_model = nn.Sequential(model, segmentation_head)
# print(new_model)
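Putting the pieces above together, a plain-PyTorch finetuning step could look roughly like the sketch below. This is an illustration only: train_loader is a hypothetical DataLoader yielding already-normalized image tensors of shape [B, 6, 1, 224, 224] and integer label maps of shape [B, 224, 224].
# Sketch of a frozen-backbone finetuning loop: encoder features -> segmentation head -> cross-entropy
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(segmentation_head.parameters(), lr=1e-4, weight_decay=0.05)
segmentation_head.train()
for images, labels in train_loader:  # hypothetical DataLoader for your own dataset
    with torch.no_grad():  # backbone is frozen, so skip gradient tracking for it
        feats, _, _ = model.forward_encoder(images, mask_ratio=0)  # no masking at finetuning time
    feats = feats[:, 1:, :]  # drop the cls token
    side = int(np.sqrt(feats.shape[1]))  # 14 for 224x224 inputs with 16x16 patches
    feats = feats.view(-1, side, side, model_args["embed_dim"]).permute(0, 3, 1, 2)  # channels first
    logits = segmentation_head(feats)  # [B, num_classes, 224, 224]
    loss = criterion(logits, labels.long())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()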
Finetuning - MMSeg
Alternatively, finetune using the MMSegmentation extension we have open-sourced. - No model surgery required - No need to write boilerplate training code - Integrations with Tensorboard, MLFlow, … - Segmentation evaluation metrics / losses built in
Build your config file. Look here for examples, the README for some docs, and MMSeg for more general tutorials.
Collect your dataset in the format determined by MMSeg
mim train mmsegmentation <path to my config>
This is what the model looks like in the MMSeg configuration code.
All this composition we did above is done for you!
model = dict(
    type="TemporalEncoderDecoder",
    frozen_backbone=False,
    backbone=dict(
        type="TemporalViTEncoder",
        pretrained=pretrained_weights_path,
        img_size=img_size,
        patch_size=patch_size,
        num_frames=num_frames,
        tubelet_size=1,
        in_chans=len(bands),
        embed_dim=embed_dim,
        depth=num_layers,
        num_heads=num_heads,
        mlp_ratio=4.0,
        norm_pix_loss=False,
    ),
    neck=dict(
        type="ConvTransformerTokensToEmbeddingNeck",
        embed_dim=num_frames*embed_dim,
        output_embed_dim=embed_dim,
        drop_cls_token=True,
        Hp=img_size // patch_size,
        Wp=img_size // patch_size,
    ),
    decode_head=dict(
        num_classes=num_classes,
        in_channels=embed_dim,
        type="FCNHead",
        in_index=-1,
        ignore_index=ignore_index,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type="CrossEntropyLoss",
            use_sigmoid=False,
            loss_weight=1,
            class_weight=ce_weights,
            avg_non_ignore=True
        ),
    ),
    (...)