diff --git a/scoutbot/__init__.py b/scoutbot/__init__.py
index b08779a..9ae169e 100644
--- a/scoutbot/__init__.py
+++ b/scoutbot/__init__.py
@@ -50,6 +50,8 @@
 )
 '''
 import cv2
+from PIL import Image
+import numpy as np
 from os.path import exists
 
 import pooch
@@ -235,7 +237,8 @@ def pipeline_v3(
     )
 
     det_result = tile_batched.get_sliced_prediction_batched(
-        cv2.imread(filepath),
+        #cv2.imread(filepath),
+        np.array(Image.open(filepath).convert("RGB")),
         batched_detection_model,
         slice_height=slice_height,
         slice_width=slice_width,
diff --git a/scoutbot/tile_batched/__init__.py b/scoutbot/tile_batched/__init__.py
index afd671c..008b58a 100644
--- a/scoutbot/tile_batched/__init__.py
+++ b/scoutbot/tile_batched/__init__.py
@@ -1 +1 @@
-from .main import get_sliced_prediction_batched, Yolov8DetectionModel # NOQA
\ No newline at end of file
+from .main import get_sliced_prediction_batched, Yolov8DetectionModel, HerdNetDetectionModel # NOQA
\ No newline at end of file
diff --git a/scoutbot/tile_batched/dla.py b/scoutbot/tile_batched/dla.py
new file mode 100644
index 0000000..dd87ccc
--- /dev/null
+++ b/scoutbot/tile_batched/dla.py
@@ -0,0 +1,590 @@
+__copyright__ = \
+    """
+    MIT License
+
+    Copyright (c) 2019 Xingyi Zhou
+    All rights reserved.
+
+    Permission is hereby granted, free of charge, to any person obtaining a copy
+    of this software and associated documentation files (the "Software"), to deal
+    in the Software without restriction, including without limitation the rights
+    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+    copies of the Software, and to permit persons to whom the Software is
+    furnished to do so, subject to the following conditions:
+
+    The above copyright notice and this permission notice shall be included in all
+    copies or substantial portions of the Software.
+ """ +__authors__ = "Xingyi Zhou, Dequan Wang, Philipp Krähenbühl" +__license__ = "MIT" + + +import math +from os.path import join +from posixpath import basename + +import torch +from torch import nn +import torch.utils.model_zoo as model_zoo + +import numpy as np + +BatchNorm = nn.BatchNorm2d + +def get_model_url(data='imagenet', name='dla34', hash='ba72cf86'): + return join('http://dl.yf.io/dla/models', data, '{}-{}.pth'.format(name, hash)) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, + stride=stride, padding=dilation, + bias=False, dilation=dilation) + self.bn1 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + stride=1, padding=dilation, + bias=False, dilation=dilation) + self.bn2 = BatchNorm(planes) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 2 + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(Bottleneck, self).__init__() + expansion = Bottleneck.expansion + bottle_planes = planes // expansion + self.conv1 = nn.Conv2d(inplanes, bottle_planes, + kernel_size=1, bias=False) + self.bn1 = BatchNorm(bottle_planes) + self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, + stride=stride, padding=dilation, + bias=False, dilation=dilation) + self.bn2 = BatchNorm(bottle_planes) + self.conv3 = nn.Conv2d(bottle_planes, planes, + kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class BottleneckX(nn.Module): + expansion = 2 + cardinality = 32 + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super(BottleneckX, self).__init__() + cardinality = BottleneckX.cardinality + bottle_planes = planes * cardinality // 32 + self.conv1 = nn.Conv2d(inplanes, bottle_planes, + kernel_size=1, bias=False) + self.bn1 = BatchNorm(bottle_planes) + self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, + stride=stride, padding=dilation, bias=False, + dilation=dilation, groups=cardinality) + self.bn2 = BatchNorm(bottle_planes) + self.conv3 = nn.Conv2d(bottle_planes, planes, + kernel_size=1, bias=False) + self.bn3 = BatchNorm(planes) + self.relu = nn.ReLU(inplace=True) + self.stride = stride + + def forward(self, x, residual=None): + if residual is None: + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += residual + out = self.relu(out) + + return out + + +class Root(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, residual): 
+ super(Root, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, + stride=1, bias=False, padding=(kernel_size - 1) // 2) + self.bn = BatchNorm(out_channels) + self.relu = nn.ReLU(inplace=True) + self.residual = residual + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.residual: + x += children[0] + x = self.relu(x) + + return x + + +class Tree(nn.Module): + def __init__(self, levels, block, in_channels, out_channels, stride=1, + level_root=False, root_dim=0, root_kernel_size=1, + dilation=1, root_residual=False): + super(Tree, self).__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + if levels == 1: + self.tree1 = block(in_channels, out_channels, stride, + dilation=dilation) + self.tree2 = block(out_channels, out_channels, 1, + dilation=dilation) + else: + self.tree1 = Tree(levels - 1, block, in_channels, out_channels, + stride, root_dim=0, + root_kernel_size=root_kernel_size, + dilation=dilation, root_residual=root_residual) + self.tree2 = Tree(levels - 1, block, out_channels, out_channels, + root_dim=root_dim + out_channels, + root_kernel_size=root_kernel_size, + dilation=dilation, root_residual=root_residual) + if levels == 1: + self.root = Root(root_dim, out_channels, root_kernel_size, + root_residual) + self.level_root = level_root + self.root_dim = root_dim + self.downsample = None + self.project = None + self.levels = levels + if stride > 1: + self.downsample = nn.MaxPool2d(stride, stride=stride) + if in_channels != out_channels: + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, + kernel_size=1, stride=1, bias=False), + BatchNorm(out_channels) + ) + + def forward(self, x, residual=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) if self.downsample else x + residual = self.project(bottom) if self.project else bottom + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, residual) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + def __init__(self, levels, channels, num_classes=1000, + block=BasicBlock, residual_root=False, return_levels=False, + pool_size=7, linear_root=False): + super(DLA, self).__init__() + self.channels = channels + self.return_levels = return_levels + self.num_classes = num_classes + self.base_layer = nn.Sequential( + nn.Conv2d(3, channels[0], kernel_size=7, stride=1, + padding=3, bias=False), + BatchNorm(channels[0]), + nn.ReLU(inplace=True)) + self.level0 = self._make_conv_level( + channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level( + channels[0], channels[1], levels[1], stride=2) + self.level2 = Tree(levels[2], block, channels[1], channels[2], 2, + level_root=False, + root_residual=residual_root) + self.level3 = Tree(levels[3], block, channels[2], channels[3], 2, + level_root=True, root_residual=residual_root) + self.level4 = Tree(levels[4], block, channels[3], channels[4], 2, + level_root=True, root_residual=residual_root) + self.level5 = Tree(levels[5], block, channels[4], channels[5], 2, + level_root=True, root_residual=residual_root) + + self.avgpool = nn.AvgPool2d(pool_size) + self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, + stride=1, padding=0, bias=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * 
m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_level(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes: + downsample = nn.Sequential( + nn.MaxPool2d(stride, stride=stride), + nn.Conv2d(inplanes, planes, + kernel_size=1, stride=1, bias=False), + BatchNorm(planes), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample=downsample)) + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d(inplanes, planes, kernel_size=3, + stride=stride if i == 0 else 1, + padding=dilation, bias=False, dilation=dilation), + BatchNorm(planes), + nn.ReLU(inplace=True)]) + inplanes = planes + return nn.Sequential(*modules) + + def forward(self, x): + y = [] + x = self.base_layer(x) + for i in range(6): + x = getattr(self, 'level{}'.format(i))(x) + y.append(x) + if self.return_levels: + return y + else: + x = self.avgpool(x) + x = self.fc(x) + x = x.view(x.size(0), -1) + + return x + + def load_pretrained_model(self, data='imagenet', name='dla34', hash='ba72cf86'): + fc = self.fc + if name.endswith('.pth'): + model_weights = torch.load(data + name) + else: + model_url = get_model_url(data, name, hash) + model_weights = model_zoo.load_url(model_url) + num_classes = len(model_weights[list(model_weights.keys())[-1]]) + self.fc = nn.Conv2d( + self.channels[-1], num_classes, + kernel_size=1, stride=1, padding=0, bias=True) + self.load_state_dict(model_weights) + self.fc = fc + + +def dla34(pretrained, **kwargs): # DLA-34 + model = DLA([1, 1, 1, 2, 2, 1], + [16, 32, 64, 128, 256, 512], + block=BasicBlock, **kwargs) + if pretrained: + model.load_pretrained_model(data='imagenet', name='dla34', hash='ba72cf86') + return model + + +def dla46_c(pretrained=None, **kwargs): # DLA-46-C + Bottleneck.expansion = 2 + model = DLA([1, 1, 1, 2, 2, 1], + [16, 32, 64, 64, 128, 256], + block=Bottleneck, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla46_c') + return model + + +def dla46x_c(pretrained=None, **kwargs): # DLA-X-46-C + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 2, 2, 1], + [16, 32, 64, 64, 128, 256], + block=BottleneckX, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla46x_c') + return model + + +def dla60x_c(pretrained, **kwargs): # DLA-X-60-C + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 2, 3, 1], + [16, 32, 64, 64, 128, 256], + block=BottleneckX, **kwargs) + if pretrained: + model.load_pretrained_model(data='imagenet', name='dla60x_c', hash='b870c45c') + return model + + +def dla60(pretrained=None, **kwargs): # DLA-60 + Bottleneck.expansion = 2 + model = DLA([1, 1, 1, 2, 3, 1], + [16, 32, 128, 256, 512, 1024], + block=Bottleneck, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla60') + return model + + +def dla60x(pretrained=None, **kwargs): # DLA-X-60 + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 2, 3, 1], + [16, 32, 128, 256, 512, 1024], + block=BottleneckX, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla60x') + return model + + +def dla102(pretrained=None, **kwargs): # DLA-102 + Bottleneck.expansion = 2 + model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], + 
block=Bottleneck, residual_root=True, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla102') + return model + + +def dla102x(pretrained=None, **kwargs): # DLA-X-102 + BottleneckX.expansion = 2 + model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckX, residual_root=True, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla102x') + return model + + +def dla102x2(pretrained=None, **kwargs): # DLA-X-102 64 + BottleneckX.cardinality = 64 + model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], + block=BottleneckX, residual_root=True, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla102x2') + return model + + +def dla169(pretrained=None, **kwargs): # DLA-169 + Bottleneck.expansion = 2 + model = DLA([1, 1, 2, 3, 5, 1], [16, 32, 128, 256, 512, 1024], + block=Bottleneck, residual_root=True, **kwargs) + if pretrained is not None: + model.load_pretrained_model(pretrained, 'dla169') + return model + + +class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +def fill_up_weights(up): + w = up.weight.data + f = math.ceil(w.size(2) / 2) + c = (2 * f - 1 - f % 2) / (2. * f) + for i in range(w.size(2)): + for j in range(w.size(3)): + w[0, 0, i, j] = \ + (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) + for c in range(1, w.size(0)): + w[c, 0, :, :] = w[0, 0, :, :] + + +class IDAUp(nn.Module): + def __init__(self, node_kernel, out_dim, channels, up_factors): + super(IDAUp, self).__init__() + self.channels = channels + self.out_dim = out_dim + for i, c in enumerate(channels): + if c == out_dim: + proj = Identity() + else: + proj = nn.Sequential( + nn.Conv2d(c, out_dim, + kernel_size=1, stride=1, bias=False), + BatchNorm(out_dim), + nn.ReLU(inplace=True)) + f = int(up_factors[i]) + if f == 1: + up = Identity() + else: + up = nn.ConvTranspose2d( + out_dim, out_dim, f * 2, stride=f, padding=f // 2, + output_padding=0, groups=out_dim, bias=False) + fill_up_weights(up) + setattr(self, 'proj_' + str(i), proj) + setattr(self, 'up_' + str(i), up) + + for i in range(1, len(channels)): + node = nn.Sequential( + nn.Conv2d(out_dim * 2, out_dim, + kernel_size=node_kernel, stride=1, + padding=node_kernel // 2, bias=False), + BatchNorm(out_dim), + nn.ReLU(inplace=True)) + setattr(self, 'node_' + str(i), node) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, BatchNorm): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, layers): + assert len(self.channels) == len(layers), \ + '{} vs {} layers'.format(len(self.channels), len(layers)) + layers = list(layers) + for i, l in enumerate(layers): + upsample = getattr(self, 'up_' + str(i)) + project = getattr(self, 'proj_' + str(i)) + layers[i] = upsample(project(l)) + x = layers[0] + y = [] + for i in range(1, len(layers)): + node = getattr(self, 'node_' + str(i)) + x = node(torch.cat([x, layers[i]], 1)) + y.append(x) + return x, y + + +class DLAUp(nn.Module): + def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): + super(DLAUp, self).__init__() + if in_channels is None: + in_channels = channels + self.channels = channels + channels = list(channels) + scales = np.array(scales, dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr(self, 'ida_{}'.format(i), + IDAUp(3, channels[j], in_channels[j:], + scales[j:] // scales[j])) + scales[j + 1:] = scales[j] + in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] + + def forward(self, layers): + layers = list(layers) + assert len(layers) > 1 + for i in range(len(layers) - 1): + ida = getattr(self, 'ida_{}'.format(i)) + x, y = ida(layers[-i - 2:]) + layers[-i - 1:] = y + return x + +def fill_fc_weights(layers): + for m in layers.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.001) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +class DLASeg(nn.Module): + def __init__(self, base_name, heads, + pretrained=True, down_ratio=4, head_conv=256): + super(DLASeg, self).__init__() + self.heads = heads + self.first_level = int(np.log2(down_ratio)) + self.base = globals()[base_name]( + pretrained=pretrained, return_levels=True) + channels = self.base.channels + scales = [2 ** i for i in range(len(channels[self.first_level:]))] + self.dla_up = DLAUp(channels[self.first_level:], scales=scales) + + for head in self.heads: + classes = self.heads[head] + if head_conv > 0: + fc = nn.Sequential( + nn.Conv2d(channels[self.first_level], head_conv, + kernel_size=3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(head_conv, classes, + kernel_size=1, stride=1, + padding=0, bias=True)) + if 'hm' in head: + fc[-1].bias.data.fill_(-2.19) + else: + fill_fc_weights(fc) + else: + fc = nn.Conv2d(channels[self.first_level], classes, + kernel_size=1, stride=1, + padding=0, bias=True) + if 'hm' in head: + fc.bias.data.fill_(-2.19) + else: + fill_fc_weights(fc) + self.__setattr__(head, fc) + + def forward(self, x): + x = self.base(x) + x = self.dla_up(x[self.first_level:]) + # x = self.fc(x) + # y = self.softmax(self.up(x)) + ret = {} + for head in self.heads: + ret[head] = self.__getattr__(head)(x) + return [ret] + +def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4): + model = DLASeg('dla{}'.format(num_layers), heads, + pretrained=True, + down_ratio=down_ratio, + head_conv=head_conv) + return model \ No newline at end of file diff --git a/scoutbot/tile_batched/herdnet_lmds.py b/scoutbot/tile_batched/herdnet_lmds.py new file mode 100644 index 0000000..18d2f3d --- /dev/null +++ b/scoutbot/tile_batched/herdnet_lmds.py @@ -0,0 +1,219 @@ +__copyright__ = \ + """ + Copyright (C) 2024 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life + All rights reserved. + + This source code is under the MIT License. + + Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. 
+ + Last modification: March 18, 2024 + """ +__author__ = "Alexandre Delplanque" +__license__ = "MIT License" +__version__ = "0.2.1" + +import torch +import numpy + +import torch.nn.functional as F + +from typing import Tuple, List + +__all__ = ['LMDS', 'HerdNetLMDS'] + + +class LMDS: + ''' Local Maxima Detection Strategy + + Adapted and enhanced from https://github.com/dk-liang/FIDTM (author: dklinag) + available under the MIT license ''' + + def __init__( + self, + kernel_size: tuple = (3,3), + adapt_ts: float = 100.0/255.0, + neg_ts: float = 0.1 + ) -> None: + ''' + Args: + kernel_size (tuple, optional): size of the kernel used to select local + maxima. Defaults to (3,3) (as in the paper). + adapt_ts (float, optional): adaptive threshold to select final points + from candidates. Defaults to 100.0/255.0 (as in the paper). + neg_ts (float, optional): negative sample threshold used to define if + an image is a negative sample or not. Defaults to 0.1 (as in the paper). + ''' + + assert kernel_size[0] == kernel_size[1], \ + f'The kernel shape must be a square, got {kernel_size[0]}x{kernel_size[1]}' + assert not kernel_size[0] % 2 == 0, \ + f'The kernel size must be odd, got {kernel_size[0]}' + + self.kernel_size = tuple(kernel_size) + self.adapt_ts = adapt_ts + self.neg_ts = neg_ts + + def __call__(self, est_map: torch.Tensor) -> Tuple[list,list,list,list]: + ''' + Args: + est_map (torch.Tensor): the estimated FIDT map + + Returns: + Tuple[list,list,list,list] + counts, labels, scores and locations per batch + ''' + batch_size, classes = est_map.shape[:2] + + b_counts, b_labels, b_scores, b_locs = [], [], [], [] + for b in range(batch_size): + counts, labels, scores, locs = [], [], [], [] + + for c in range(classes): + count, loc, score = self._lmds(est_map[b][c]) + counts.append(count) + labels = [*labels, *[c+1]*count] + scores = [*scores, *score] + locs = [*locs, *loc] + + b_counts.append(counts) + b_labels.append(labels) + b_scores.append(scores) + b_locs.append(locs) + + return b_counts, b_locs, b_labels, b_scores + + def _local_max(self, est_map: torch.Tensor) -> torch.Tensor: + ''' Shape: est_map = [B,C,H,W] ''' + + pad = int(self.kernel_size[0] / 2) + keep = torch.nn.functional.max_pool2d(est_map, kernel_size=self.kernel_size, stride=1, padding=pad) + keep = (keep == est_map).float() + est_map = keep * est_map + + return est_map + + def _get_locs_and_scores( + self, + locs_map: torch.Tensor, + scores_map: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + ''' Shapes: locs_map = [H,W] and scores_map = [H,W] ''' + + locs_map = locs_map.data.cpu().numpy() + scores_map = scores_map.data.cpu().numpy() + locs = [] + scores = [] + for i, j in numpy.argwhere(locs_map == 1): + locs.append((i,j)) + scores.append(scores_map[i][j]) + + return torch.Tensor(locs), torch.Tensor(scores) + + def _lmds(self, est_map: torch.Tensor) -> Tuple[int, list, list]: + ''' Shape: est_map = [H,W] ''' + + est_map_max = torch.max(est_map).item() + + # local maxima + est_map = self._local_max(est_map.unsqueeze(0).unsqueeze(0)) + + # adaptive threshold for counting + est_map[est_map < self.adapt_ts * est_map_max] = 0 + scores_map = torch.clone(est_map) + est_map[est_map > 0] = 1 + + # negative sample + if est_map_max < self.neg_ts: + est_map = est_map * 0 + + # count + count = int(torch.sum(est_map).item()) + + # locations and scores + locs, scores = self._get_locs_and_scores( + est_map.squeeze(0).squeeze(0), + scores_map.squeeze(0).squeeze(0) + ) + + return count, locs.tolist(), scores.tolist() + +class 
HerdNetLMDS(LMDS): + + def __init__( + self, + up: bool = True, + kernel_size: tuple = (3,3), + adapt_ts: float = 0.3, + neg_ts: float = 0.1 + ) -> None: + ''' + Args: + up (bool, optional): set to False to disable class maps upsampling. + Defaults to True. + kernel_size (tuple, optional): size of the kernel used to select local + maxima. Defaults to (3,3) (as in the paper). + adapt_ts (float, optional): adaptive threshold to select final points + from candidates. Defaults to 0.3. + neg_ts (float, optional): negative sample threshold used to define if + an image is a negative sample or not. Defaults to 0.1 (as in the paper). + ''' + + super().__init__(kernel_size=kernel_size, adapt_ts=adapt_ts, neg_ts=neg_ts) + + self.up = up + + def __call__(self, outputs: List[torch.Tensor]) -> Tuple[list, list, list, list, list]: + ''' + Args: + outmaps (torch.Tensor): outputs of HerdNet, i.e. 2 tensors: + - heatmap: [B,1,H,W], + - class map: [B,C,H/16,W/16], + + Returns: + Tuple[list,list,list,list,list] + counts, locations, labels, class scores and detection scores per batch + ''' + + heatmap, clsmap = outputs + + # upsample class map + if self.up: + scale_factor = 16 + clsmap = F.interpolate(clsmap, scale_factor=scale_factor, mode='nearest') + + # softmax + cls_scores = torch.softmax(clsmap, dim=1)[:,1:,:,:] + + # cat to heatmap + outmaps = torch.cat([heatmap, cls_scores], dim=1) + + # LMDS + batch_size, channels = outmaps.shape[:2] + + b_counts, b_labels, b_scores, b_locs, b_dscores = [], [], [], [], [] + for b in range(batch_size): + + _, locs, _ = self._lmds(heatmap[b][0]) + + cls_idx = torch.argmax(clsmap[b,1:,:,:], dim=0) + classes = torch.add(cls_idx, 1) + + h_idx = torch.Tensor([l[0] for l in locs]).long() + w_idx = torch.Tensor([l[1] for l in locs]).long() + labels = classes[h_idx, w_idx].long().tolist() + + chan_idx = cls_idx[h_idx, w_idx].long().tolist() + scores = cls_scores[b, chan_idx, h_idx, w_idx].float().tolist() + + dscores = heatmap[b, 0, h_idx, w_idx].float().tolist() + + counts = [labels.count(i) for i in range(1, channels)] + + b_labels.append(labels) + b_scores.append(scores) + b_locs.append(locs) + b_counts.append(counts) + b_dscores.append(dscores) + + return b_counts, b_locs, b_labels, b_scores, b_dscores \ No newline at end of file diff --git a/scoutbot/tile_batched/herdnet_model.py b/scoutbot/tile_batched/herdnet_model.py new file mode 100644 index 0000000..60672f0 --- /dev/null +++ b/scoutbot/tile_batched/herdnet_model.py @@ -0,0 +1,165 @@ +import torch +import torch.nn as nn +import numpy as np +import torchvision.transforms as T +from torch.utils.data import TensorDataset, DataLoader, SequentialSampler + +from typing import Optional + +from . import dla as dla_modules +from .herdnet_lmds import HerdNetLMDS +from typing import List, Optional, Union, Dict, Tuple + +class HerdNet(nn.Module): + ''' HerdNet architecture ''' + + def __init__( + self, + num_layers: int = 34, + num_classes: int = 2, + pretrained: bool = True, + down_ratio: Optional[int] = 2, + head_conv: int = 64 + ): + ''' + Args: + num_layers (int, optional): number of layers of DLA. Defaults to 34. + num_classes (int, optional): number of output classes, background included. + Defaults to 2. + pretrained (bool, optional): set False to disable pretrained DLA encoder parameters + from ImageNet. Defaults to True. + down_ratio (int, optional): downsample ratio. Possible values are 1, 2, 4, 8, or 16. + Set to 1 to get output of the same size as input (i.e. no downsample). + Defaults to 2. 
+ head_conv (int, optional): number of supplementary convolutional layers at the end + of decoder. Defaults to 64. + ''' + + super(HerdNet, self).__init__() + + assert down_ratio in [1, 2, 4, 8, 16], \ + f'Downsample ratio possible values are 1, 2, 4, 8 or 16, got {down_ratio}' + + base_name = 'dla{}'.format(num_layers) + + self.down_ratio = down_ratio + self.num_classes = num_classes + self.head_conv = head_conv + + self.first_level = int(np.log2(down_ratio)) + + # backbone + base = dla_modules.__dict__[base_name](pretrained=pretrained, return_levels=True) + setattr(self, 'base_0', base) + setattr(self, 'channels_0', base.channels) + + channels = self.channels_0 + + scales = [2 ** i for i in range(len(channels[self.first_level:]))] + self.dla_up = dla_modules.DLAUp(channels[self.first_level:], scales=scales) + + # bottleneck conv + self.bottleneck_conv = nn.Conv2d( + channels[-1], channels[-1], + kernel_size=1, stride=1, + padding=0, bias=True + ) + + # localization head + self.loc_head = nn.Sequential( + nn.Conv2d(channels[self.first_level], head_conv, + kernel_size=3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d( + head_conv, 1, + kernel_size=1, stride=1, + padding=0, bias=True + ), + nn.Sigmoid() + ) + + self.loc_head[-2].bias.data.fill_(0.00) + + # classification head + self.cls_head = nn.Sequential( + nn.Conv2d(channels[-1], head_conv, + kernel_size=3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d( + head_conv, self.num_classes, + kernel_size=1, stride=1, + padding=0, bias=True + ) + ) + + self.cls_head[-1].bias.data.fill_(0.00) + + # Local Maxima Detection Strategy + lmds_kwargs: dict = {'kernel_size': (3, 3), 'adapt_ts': 0.2, 'neg_ts': 0.1} + self.lmds = HerdNetLMDS(up=False, **lmds_kwargs) + + def forward(self, input: torch.Tensor): + encode = self.base_0(input) + bottleneck = self.bottleneck_conv(encode[-1]) + encode[-1] = bottleneck + decode_hm = self.dla_up(encode[self.first_level:]) + heatmap = self.loc_head(decode_hm) + clsmap = self.cls_head(bottleneck) + + return heatmap, clsmap + + def freeze(self, layers: list) -> None: + ''' Freeze all layers mentioned in the input list ''' + for layer in layers: + self._freeze_layer(layer) + + def _freeze_layer(self, layer_name: str) -> None: + for param in getattr(self, layer_name).parameters(): + param.requires_grad = False + + def reshape_classes(self, num_classes: int) -> None: + ''' Reshape architecture according to a new number of classes. 
+
+        Arg:
+            num_classes (int): new number of classes
+        '''
+
+        self.cls_head[-1] = nn.Conv2d(
+            self.head_conv, num_classes,
+            kernel_size=1, stride=1,
+            padding=0, bias=True
+        )
+
+        self.cls_head[-1].bias.data.fill_(0.00)
+
+        self.num_classes = num_classes
+
+    @torch.no_grad()
+    def batch_image_detection(self, images: List[np.ndarray], transforms: T.Compose, batch_size: int = 1, device: str = 'cuda:0'):
+        self.eval()
+        self.device = device
+        self.to(self.device)
+        # convert images to a tensor of shape [len(images), C, H, W]
+        images = torch.stack([transforms(image) for image in images])
+        dataset = TensorDataset(images)
+        dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
+        counts, locs, labels, scores, dscores = [], [], [], [], []
+        for patch in dataloader:
+            patch = patch[0].to(self.device)
+            outputs = self(patch)
+            heatmap = outputs[0]
+            clsmap = nn.functional.interpolate(outputs[1], scale_factor=16, mode='nearest')
+            outmaps = torch.cat([heatmap, clsmap], dim=1)
+            # (Upsample)
+            outmaps = nn.functional.interpolate(outmaps, scale_factor=2, mode='bilinear', align_corners=True)
+            heatmap, clsmap = outmaps[:,:1,:,:], outmaps[:,1:,:,:]
+            # Local Maxima Detection Strategy (LMDS)
+            counts_patch, locs_patch, labels_patch, scores_patch, dscores_patch = self.lmds((heatmap, clsmap))
+            counts.append(counts_patch)
+            locs.append(locs_patch)
+            labels.append(labels_patch)
+            scores.append(scores_patch)
+            dscores.append(dscores_patch)
+
+        return counts, locs, labels, scores, dscores
+
diff --git a/scoutbot/tile_batched/main.py b/scoutbot/tile_batched/main.py
index e14e295..bc31641 100644
--- a/scoutbot/tile_batched/main.py
+++ b/scoutbot/tile_batched/main.py
@@ -14,6 +14,12 @@
     PostprocessPredictions,
 )
 from sahi.models.yolov8 import Yolov8DetectionModel as Yolov8DetectionModelBase
+from sahi.models.base import DetectionModel
+from sahi.utils.cv import read_image_as_pil
+from .herdnet_model import HerdNet # NOQA
+import torch
+import torchvision.transforms as T
+from tqdm import tqdm
 
 POSTPROCESS_NAME_TO_CLASS = {
     "GREEDYNMM": GreedyNMMPostprocess,
@@ -71,6 +77,203 @@ def perform_inference(self, images: List[np.ndarray], batch_size: Optional[int]
         self._original_predictions = prediction_result
 
 
+class HerdNetDetectionModel(DetectionModel):
+
+    def __init__(
+        self,
+        model_path: Optional[str] = None,
+        device: Optional[str] = None,
+        confidence_threshold: float = 0.2,
+        category_mapping: Optional[Dict] = None,
+        category_remapping: Optional[Dict] = None,
+        load_at_init: bool = True,
+        image_size: int = None,
+        batch_size: Optional[int] = None,
+        dataset: str = 'general',
+    ):
+        self.dataset = dataset
+        self.batch_size = batch_size or DETECTOR_BATCH_SIZE
+        super().__init__(
+            model_path=model_path,
+            device=device,
+            confidence_threshold=confidence_threshold,
+            category_mapping=category_mapping,
+            category_remapping=category_remapping,
+            load_at_init=load_at_init,
+            image_size=image_size,
+        )
+
+    def load_model(self):
+        if self.model_path:
+            weights = os.path.join(torch.hub.get_dir(), "checkpoints", self.model_path)
+            checkpoint = torch.load(weights, map_location=torch.device(self.device))
+            self.CLASS_NAMES = checkpoint["classes"]
+            self.num_classes = len(self.CLASS_NAMES) + 1
+            self.img_mean = checkpoint['mean']
+            self.img_std = checkpoint['std']
+            self.transforms = T.Compose([
+                T.ToTensor(),
+                T.Normalize(mean=self.img_mean, std=self.img_std)
+            ])
+            self.model = HerdNet(num_classes=self.num_classes, pretrained=False)
+            state_dict = checkpoint['model_state_dict']
+            new_state_dict = {k.replace('model.', ''): v for k, v in state_dict.items() if k.startswith('model.')}
+            self.model.load_state_dict(new_state_dict, strict=True)
+
+        else:
+            self.model = HerdNet(pretrained=False)
+            self.transforms = T.Compose([
+                T.ToTensor()
+            ])
+
+    def perform_inference(self, images: List[np.ndarray], batch_size: Optional[int] = None):
+        """
+        Perform inference using the model and store predictions.
+
+        Args:
+            images (List[np.ndarray]): List of images as numpy arrays for prediction.
+            batch_size (Optional[int]): Batch size for inference (kept for interface compatibility,
+                but HerdNet always uses batch_size=1).
+
+        Note:
+            HerdNet always processes images with batch_size=1 regardless of the parameter value.
+        """
+        # HerdNet always uses batch_size=1, ignoring the parameter for consistency with the model's requirements
+        all_preds = []
+        counts, locs, labels, scores, dscores = self.model.batch_image_detection(images,
+                                                                                 self.transforms,
+                                                                                 batch_size=1,
+                                                                                 device = self.device)
+        all_preds = []
+        for i in range(len(counts)):
+            if sum(counts[i][0]) == 0:
+                all_preds.append([]) # add empty array to all_preds because there are no detections for this image
+                continue
+            preds_array_i = self.process_lmds_results(counts[i], locs[i], labels[i], scores[i], dscores[i], det_conf_thres=0.2, clf_conf_thres=0.2)
+            all_preds.append(preds_array_i)
+        self._original_predictions = all_preds
+
+    def process_lmds_results(self, counts, locs, labels, scores, dscores, det_conf_thres=0.2, clf_conf_thres=0.2):
+        """
+        Process the results from the Local Maxima Detection Strategy.
+
+        Args:
+            counts (list):
+                Number of detections for each species.
+            locs (list):
+                Locations of the detections.
+            labels (list):
+                Labels of the detections.
+            scores (list):
+                Scores of the detections.
+            dscores (list):
+                Detection scores.
+            det_conf_thres (float, optional):
+                Confidence threshold for detections. Defaults to 0.2.
+            clf_conf_thres (float, optional):
+                Confidence threshold for classification. Defaults to 0.2.
+
+        Returns:
+            numpy.ndarray: Processed detection results.
+ """ + # Flatten the lists since its a single image + counts = counts[0] + locs = locs[0] + labels = labels[0] + scores = scores[0] + dscores = dscores[0] + + total_detections = sum(counts) + preds_array = np.empty((total_detections, 6)) #xyxy, confidence, class_id + detection_idx = 0 + valid_detections_idx = 0 + # Loop through each species + for specie_idx in range(len(counts)): + count = counts[specie_idx] + if count == 0: + continue + + # Get the detections for this species + species_locs = np.array(locs[detection_idx : detection_idx + count]) + species_locs[:, [0, 1]] = species_locs[:, [1, 0]] # Swap x and y in species_locs (herdnet uses y, x format) + species_scores = np.array(scores[detection_idx : detection_idx + count]) + species_dscores = np.array(dscores[detection_idx : detection_idx + count]) + species_labels = np.array(labels[detection_idx : detection_idx + count]) + + # Apply the confidence threshold + valid_detections_by_clf_score = species_scores > clf_conf_thres + valid_detections_by_det_score = species_dscores > det_conf_thres + valid_detections = np.logical_and(valid_detections_by_clf_score, valid_detections_by_det_score) + valid_detections_count = np.sum(valid_detections) + valid_detections_idx += valid_detections_count + # Fill the preds_array with the valid detections + if valid_detections_count > 0: + preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, :2] = species_locs[valid_detections] - 2 + preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 2:4] = species_locs[valid_detections] + 2 + preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 4] = species_scores[valid_detections] + preds_array[valid_detections_idx - valid_detections_count : valid_detections_idx, 5] = species_labels[valid_detections] + + detection_idx += count # Move to the next species + + preds_array = preds_array[:valid_detections_idx] # Remove the empty rows + + return preds_array + + def _create_object_prediction_list_from_original_predictions( + self, + shift_amount_list: Optional[List[List[int]]] = [[0, 0]], + full_shape_list: Optional[List[List[int]]] = None, + ): + # Convert the predictions from HerdNet into a list of ObjectPrediction + original_predictions = self._original_predictions + object_prediction_list_per_image = [] + for image_ind, image_predictions in enumerate(original_predictions): + shift_amount = shift_amount_list[image_ind] + full_shape = None if full_shape_list is None else full_shape_list[image_ind] + object_prediction_list = [] + for prediction in image_predictions: + x1 = prediction[0] + y1 = prediction[1] + x2 = prediction[2] + y2 = prediction[3] + bbox = [x1, y1, x2, y2] + score = prediction[4] + category_id = int(prediction[5]) + #category_name = self.category_mapping[str(category_id)] + category_name = None # TODO: Get category name from category_mapping + + # fix negative box coords + bbox[0] = max(0, bbox[0]) + bbox[1] = max(0, bbox[1]) + bbox[2] = max(0, bbox[2]) + bbox[3] = max(0, bbox[3]) + + # fix out of image box coords + if full_shape is not None: + bbox[0] = min(full_shape[1], bbox[0]) + bbox[1] = min(full_shape[0], bbox[1]) + bbox[2] = min(full_shape[1], bbox[2]) + bbox[3] = min(full_shape[0], bbox[3]) + + if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]): + print(f"ignoring invalid prediction with bbox: {bbox}") + continue + + object_prediction = ObjectPrediction( + bbox=bbox, + category_id=category_id, + score=score, + segmentation=None, + category_name=category_name, 
+ shift_amount=shift_amount, + full_shape=full_shape, + ) + object_prediction_list.append(object_prediction) + object_prediction_list_per_image.append(object_prediction_list) + + self._object_prediction_list_per_image = object_prediction_list_per_image + + class PredictionResult: def __init__( self, @@ -340,7 +543,7 @@ def slice_image( # create sliced image and append to sliced_image_result sliced_image = SlicedImage( - image=image_pil_slice, starting_pixel=[slice_bbox[0], slice_bbox[1]] + image=image_pil_slice, starting_pixel=[tlx, tly] ) sliced_image_result.add_sliced_image(sliced_image)
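
A minimal usage sketch, not part of the patch: one way the new HerdNetDetectionModel could be driven through get_sliced_prediction_batched, mirroring the pipeline_v3 change above. The checkpoint filename, device, and tile sizes are assumed values; only the import path and argument names come from the diff, and tiler parameters not shown are left at their defaults.

import numpy as np
from PIL import Image

from scoutbot.tile_batched import HerdNetDetectionModel, get_sliced_prediction_batched

# load_model() resolves model_path under torch.hub's checkpoints/ directory;
# the filename here is hypothetical.
model = HerdNetDetectionModel(
    model_path='herdnet_general.pth',
    device='cuda:0',
    confidence_threshold=0.2,
    batch_size=16,
)

# pipeline_v3 now reads images with PIL (RGB) instead of cv2.imread (BGR)
image = np.array(Image.open('survey_image.jpg').convert('RGB'))

det_result = get_sliced_prediction_batched(
    image,
    model,
    slice_height=512,  # assumed tile size; pipeline_v3 passes its own slice_height/slice_width
    slice_width=512,
)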
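A toy illustration of the Local Maxima Detection Strategy that herdnet_lmds.py implements: candidate points are pixels that equal the maximum of their 3x3 neighbourhood, and an adaptive threshold relative to the global maximum prunes weak candidates. This is a self-contained sketch of the idea, not code from the patch; the 8x8 heatmap and the 0.3 threshold are illustrative values.

import torch
import torch.nn.functional as F

heat = torch.zeros(1, 1, 8, 8)
heat[0, 0, 2, 3] = 0.9   # strong peak
heat[0, 0, 5, 6] = 0.4   # weaker, isolated peak
heat[0, 0, 5, 7] = 0.35  # adjacent to (5, 6), so not a 3x3 local maximum

# keep only pixels equal to the max of their 3x3 neighbourhood (same trick as LMDS._local_max)
pooled = F.max_pool2d(heat, kernel_size=3, stride=1, padding=1)
peaks = (pooled == heat).float() * heat

# adaptive threshold relative to the global maximum (illustrative value)
peaks[peaks < 0.3 * heat.max()] = 0

print((peaks[0, 0] > 0).nonzero().tolist())  # [[2, 3], [5, 6]] in (y, x) order, as HerdNet reports locations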