mirror of
https://github.com/RootKit-Org/AI-Aimbot.git
synced 2025-06-21 02:41:01 +08:00
Updated yolov5 dependency
This commit is contained in:
parent
6dca4d84aa
commit
c9b239078f
223
models/common.py
223
models/common.py
@ -17,15 +17,15 @@ import pandas as pd
|
|||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import yaml
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.cuda import amp
|
from torch.cuda import amp
|
||||||
|
|
||||||
from utils.dataloaders import exif_transpose, letterbox
|
from utils.dataloaders import exif_transpose, letterbox
|
||||||
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
|
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
|
||||||
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
|
increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
|
||||||
|
yaml_load)
|
||||||
from utils.plots import Annotator, colors, save_one_box
|
from utils.plots import Annotator, colors, save_one_box
|
||||||
from utils.torch_utils import copy_attr, time_sync
|
from utils.torch_utils import copy_attr, smart_inference_mode
|
||||||
|
|
||||||
|
|
||||||
def autopad(k, p=None): # kernel, padding
|
def autopad(k, p=None): # kernel, padding
|
||||||
@ -322,13 +322,10 @@ class DetectMultiBackend(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
w = str(weights[0] if isinstance(weights, list) else weights)
|
w = str(weights[0] if isinstance(weights, list) else weights)
|
||||||
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
|
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self._model_type(w) # get backend
|
||||||
w = attempt_download(w) # download if not local
|
w = attempt_download(w) # download if not local
|
||||||
fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16
|
fp16 &= pt or jit or onnx or engine # FP16
|
||||||
stride, names = 32, [f'class{i}' for i in range(1000)] # assign defaults
|
stride = 32 # default stride
|
||||||
if data: # assign class names (optional)
|
|
||||||
with open(data, errors='ignore') as f:
|
|
||||||
names = yaml.safe_load(f)['names']
|
|
||||||
|
|
||||||
if pt: # PyTorch
|
if pt: # PyTorch
|
||||||
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
||||||
@ -341,8 +338,10 @@ class DetectMultiBackend(nn.Module):
|
|||||||
extra_files = {'config.txt': ''} # model metadata
|
extra_files = {'config.txt': ''} # model metadata
|
||||||
model = torch.jit.load(w, _extra_files=extra_files)
|
model = torch.jit.load(w, _extra_files=extra_files)
|
||||||
model.half() if fp16 else model.float()
|
model.half() if fp16 else model.float()
|
||||||
if extra_files['config.txt']:
|
if extra_files['config.txt']: # load metadata dict
|
||||||
d = json.loads(extra_files['config.txt']) # extra_files dict
|
d = json.loads(extra_files['config.txt'],
|
||||||
|
object_hook=lambda d: {int(k) if k.isdigit() else k: v
|
||||||
|
for k, v in d.items()})
|
||||||
stride, names = int(d['stride']), d['names']
|
stride, names = int(d['stride']), d['names']
|
||||||
elif dnn: # ONNX OpenCV DNN
|
elif dnn: # ONNX OpenCV DNN
|
||||||
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
||||||
@ -350,7 +349,7 @@ class DetectMultiBackend(nn.Module):
|
|||||||
net = cv2.dnn.readNetFromONNX(w)
|
net = cv2.dnn.readNetFromONNX(w)
|
||||||
elif onnx: # ONNX Runtime
|
elif onnx: # ONNX Runtime
|
||||||
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
||||||
cuda = torch.cuda.is_available()
|
cuda = torch.cuda.is_available() and device.type != 'cpu'
|
||||||
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
||||||
@ -380,23 +379,30 @@ class DetectMultiBackend(nn.Module):
|
|||||||
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
||||||
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
||||||
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
||||||
|
if device.type == 'cpu':
|
||||||
|
device = torch.device('cuda:0')
|
||||||
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
||||||
logger = trt.Logger(trt.Logger.INFO)
|
logger = trt.Logger(trt.Logger.INFO)
|
||||||
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
||||||
model = runtime.deserialize_cuda_engine(f.read())
|
model = runtime.deserialize_cuda_engine(f.read())
|
||||||
|
context = model.create_execution_context()
|
||||||
bindings = OrderedDict()
|
bindings = OrderedDict()
|
||||||
fp16 = False # default updated below
|
fp16 = False # default updated below
|
||||||
|
dynamic = False
|
||||||
for index in range(model.num_bindings):
|
for index in range(model.num_bindings):
|
||||||
name = model.get_binding_name(index)
|
name = model.get_binding_name(index)
|
||||||
dtype = trt.nptype(model.get_binding_dtype(index))
|
dtype = trt.nptype(model.get_binding_dtype(index))
|
||||||
shape = tuple(model.get_binding_shape(index))
|
if model.binding_is_input(index):
|
||||||
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
|
if -1 in tuple(model.get_binding_shape(index)): # dynamic
|
||||||
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
|
dynamic = True
|
||||||
if model.binding_is_input(index) and dtype == np.float16:
|
context.set_binding_shape(index, tuple(model.get_profile_shape(0, index)[2]))
|
||||||
fp16 = True
|
if dtype == np.float16:
|
||||||
|
fp16 = True
|
||||||
|
shape = tuple(context.get_binding_shape(index))
|
||||||
|
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
||||||
|
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
||||||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
||||||
context = model.create_execution_context()
|
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
||||||
batch_size = bindings['images'].shape[0]
|
|
||||||
elif coreml: # CoreML
|
elif coreml: # CoreML
|
||||||
LOGGER.info(f'Loading {w} for CoreML inference...')
|
LOGGER.info(f'Loading {w} for CoreML inference...')
|
||||||
import coremltools as ct
|
import coremltools as ct
|
||||||
@ -440,9 +446,16 @@ class DetectMultiBackend(nn.Module):
|
|||||||
input_details = interpreter.get_input_details() # inputs
|
input_details = interpreter.get_input_details() # inputs
|
||||||
output_details = interpreter.get_output_details() # outputs
|
output_details = interpreter.get_output_details() # outputs
|
||||||
elif tfjs:
|
elif tfjs:
|
||||||
raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
|
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
|
||||||
else:
|
else:
|
||||||
raise Exception(f'ERROR: {w} is not a supported format')
|
raise NotImplementedError(f'ERROR: {w} is not a supported format')
|
||||||
|
|
||||||
|
# class names
|
||||||
|
if 'names' not in locals():
|
||||||
|
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
|
||||||
|
if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
|
||||||
|
names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
|
||||||
|
|
||||||
self.__dict__.update(locals()) # assign all variables to self
|
self.__dict__.update(locals()) # assign all variables to self
|
||||||
|
|
||||||
def forward(self, im, augment=False, visualize=False, val=False):
|
def forward(self, im, augment=False, visualize=False, val=False):
|
||||||
@ -452,7 +465,9 @@ class DetectMultiBackend(nn.Module):
|
|||||||
im = im.half() # to FP16
|
im = im.half() # to FP16
|
||||||
|
|
||||||
if self.pt: # PyTorch
|
if self.pt: # PyTorch
|
||||||
y = self.model(im, augment=augment, visualize=visualize)[0]
|
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
||||||
|
if isinstance(y, tuple):
|
||||||
|
y = y[0]
|
||||||
elif self.jit: # TorchScript
|
elif self.jit: # TorchScript
|
||||||
y = self.model(im)[0]
|
y = self.model(im)[0]
|
||||||
elif self.dnn: # ONNX OpenCV DNN
|
elif self.dnn: # ONNX OpenCV DNN
|
||||||
@ -466,7 +481,13 @@ class DetectMultiBackend(nn.Module):
|
|||||||
im = im.cpu().numpy() # FP32
|
im = im.cpu().numpy() # FP32
|
||||||
y = self.executable_network([im])[self.output_layer]
|
y = self.executable_network([im])[self.output_layer]
|
||||||
elif self.engine: # TensorRT
|
elif self.engine: # TensorRT
|
||||||
assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
|
if self.dynamic and im.shape != self.bindings['images'].shape:
|
||||||
|
i_in, i_out = (self.model.get_binding_index(x) for x in ('images', 'output'))
|
||||||
|
self.context.set_binding_shape(i_in, im.shape) # reshape if dynamic
|
||||||
|
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
||||||
|
self.bindings['output'].data.resize_(tuple(self.context.get_binding_shape(i_out)))
|
||||||
|
s = self.bindings['images'].shape
|
||||||
|
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
||||||
self.binding_addrs['images'] = int(im.data_ptr())
|
self.binding_addrs['images'] = int(im.data_ptr())
|
||||||
self.context.execute_v2(list(self.binding_addrs.values()))
|
self.context.execute_v2(list(self.binding_addrs.values()))
|
||||||
y = self.bindings['output'].data
|
y = self.bindings['output'].data
|
||||||
@ -510,14 +531,14 @@ class DetectMultiBackend(nn.Module):
|
|||||||
# Warmup model by running inference once
|
# Warmup model by running inference once
|
||||||
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
|
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
|
||||||
if any(warmup_types) and self.device.type != 'cpu':
|
if any(warmup_types) and self.device.type != 'cpu':
|
||||||
im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
||||||
for _ in range(2 if self.jit else 1): #
|
for _ in range(2 if self.jit else 1): #
|
||||||
self.forward(im) # warmup
|
self.forward(im) # warmup
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def model_type(p='path/to/model.pt'):
|
def _model_type(p='path/to/model.pt'):
|
||||||
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
||||||
# from export import export_formats
|
from export import export_formats
|
||||||
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
|
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
|
||||||
check_suffix(p, suffixes) # checks
|
check_suffix(p, suffixes) # checks
|
||||||
p = Path(p).name # eliminate trailing separators
|
p = Path(p).name # eliminate trailing separators
|
||||||
@ -529,25 +550,9 @@ class DetectMultiBackend(nn.Module):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_metadata(f='path/to/meta.yaml'):
|
def _load_metadata(f='path/to/meta.yaml'):
|
||||||
# Load metadata from meta.yaml if it exists
|
# Load metadata from meta.yaml if it exists
|
||||||
with open(f, errors='ignore') as f:
|
d = yaml_load(f)
|
||||||
d = yaml.safe_load(f)
|
|
||||||
return d['stride'], d['names'] # assign stride, names
|
return d['stride'], d['names'] # assign stride, names
|
||||||
|
|
||||||
def export_formats():
|
|
||||||
# YOLOv5 export formats
|
|
||||||
x = [
|
|
||||||
['PyTorch', '-', '.pt', True, True],
|
|
||||||
['TorchScript', 'torchscript', '.torchscript', True, True],
|
|
||||||
['ONNX', 'onnx', '.onnx', True, True],
|
|
||||||
['OpenVINO', 'openvino', '_openvino_model', True, False],
|
|
||||||
['TensorRT', 'engine', '.engine', False, True],
|
|
||||||
['CoreML', 'coreml', '.mlmodel', True, False],
|
|
||||||
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
|
|
||||||
['TensorFlow GraphDef', 'pb', '.pb', True, True],
|
|
||||||
['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
|
||||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
|
|
||||||
['TensorFlow.js', 'tfjs', '_web_model', False, False],]
|
|
||||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
|
||||||
|
|
||||||
class AutoShape(nn.Module):
|
class AutoShape(nn.Module):
|
||||||
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
||||||
@ -567,6 +572,9 @@ class AutoShape(nn.Module):
|
|||||||
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
|
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
|
||||||
self.pt = not self.dmb or model.pt # PyTorch model
|
self.pt = not self.dmb or model.pt # PyTorch model
|
||||||
self.model = model.eval()
|
self.model = model.eval()
|
||||||
|
if self.pt:
|
||||||
|
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
||||||
|
m.inplace = False # Detect.inplace=False for safe multithread inference
|
||||||
|
|
||||||
def _apply(self, fn):
|
def _apply(self, fn):
|
||||||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
||||||
@ -579,10 +587,10 @@ class AutoShape(nn.Module):
|
|||||||
m.anchor_grid = list(map(fn, m.anchor_grid))
|
m.anchor_grid = list(map(fn, m.anchor_grid))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@torch.no_grad()
|
@smart_inference_mode()
|
||||||
def forward(self, imgs, size=640, augment=False, profile=False):
|
def forward(self, ims, size=640, augment=False, profile=False):
|
||||||
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
# Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
|
||||||
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
|
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
||||||
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
||||||
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
||||||
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
||||||
@ -590,65 +598,67 @@ class AutoShape(nn.Module):
|
|||||||
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
||||||
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
||||||
|
|
||||||
t = [time_sync()]
|
dt = (Profile(), Profile(), Profile())
|
||||||
p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # for device, type
|
with dt[0]:
|
||||||
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
if isinstance(size, int): # expand
|
||||||
if isinstance(imgs, torch.Tensor): # torch
|
size = (size, size)
|
||||||
with amp.autocast(autocast):
|
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
||||||
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
|
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
||||||
|
if isinstance(ims, torch.Tensor): # torch
|
||||||
|
with amp.autocast(autocast):
|
||||||
|
return self.model(ims.to(p.device).type_as(p), augment, profile) # inference
|
||||||
|
|
||||||
# Pre-process
|
# Pre-process
|
||||||
n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs]) # number, list of images
|
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
||||||
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
||||||
for i, im in enumerate(imgs):
|
for i, im in enumerate(ims):
|
||||||
f = f'image{i}' # filename
|
f = f'image{i}' # filename
|
||||||
if isinstance(im, (str, Path)): # filename or uri
|
if isinstance(im, (str, Path)): # filename or uri
|
||||||
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
||||||
im = np.asarray(exif_transpose(im))
|
im = np.asarray(exif_transpose(im))
|
||||||
elif isinstance(im, Image.Image): # PIL Image
|
elif isinstance(im, Image.Image): # PIL Image
|
||||||
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
|
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
|
||||||
files.append(Path(f).with_suffix('.jpg').name)
|
files.append(Path(f).with_suffix('.jpg').name)
|
||||||
if im.shape[0] < 5: # image in CHW
|
if im.shape[0] < 5: # image in CHW
|
||||||
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
||||||
im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
|
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
||||||
s = im.shape[:2] # HWC
|
s = im.shape[:2] # HWC
|
||||||
shape0.append(s) # image shape
|
shape0.append(s) # image shape
|
||||||
g = (size / max(s)) # gain
|
g = max(size) / max(s) # gain
|
||||||
shape1.append([y * g for y in s])
|
shape1.append([y * g for y in s])
|
||||||
imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
||||||
shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
|
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
|
||||||
x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad
|
x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
|
||||||
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
||||||
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
||||||
t.append(time_sync())
|
|
||||||
|
|
||||||
with amp.autocast(autocast):
|
with amp.autocast(autocast):
|
||||||
# Inference
|
# Inference
|
||||||
y = self.model(x, augment, profile) # forward
|
with dt[1]:
|
||||||
t.append(time_sync())
|
y = self.model(x, augment, profile) # forward
|
||||||
|
|
||||||
# Post-process
|
# Post-process
|
||||||
y = non_max_suppression(y if self.dmb else y[0],
|
with dt[2]:
|
||||||
self.conf,
|
y = non_max_suppression(y if self.dmb else y[0],
|
||||||
self.iou,
|
self.conf,
|
||||||
self.classes,
|
self.iou,
|
||||||
self.agnostic,
|
self.classes,
|
||||||
self.multi_label,
|
self.agnostic,
|
||||||
max_det=self.max_det) # NMS
|
self.multi_label,
|
||||||
for i in range(n):
|
max_det=self.max_det) # NMS
|
||||||
scale_coords(shape1, y[i][:, :4], shape0[i])
|
for i in range(n):
|
||||||
|
scale_coords(shape1, y[i][:, :4], shape0[i])
|
||||||
|
|
||||||
t.append(time_sync())
|
return Detections(ims, y, files, dt, self.names, x.shape)
|
||||||
return Detections(imgs, y, files, t, self.names, x.shape)
|
|
||||||
|
|
||||||
|
|
||||||
class Detections:
|
class Detections:
|
||||||
# YOLOv5 detections class for inference results
|
# YOLOv5 detections class for inference results
|
||||||
def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
|
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
d = pred[0].device # device
|
d = pred[0].device # device
|
||||||
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
|
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
|
||||||
self.imgs = imgs # list of images as numpy arrays
|
self.ims = ims # list of images as numpy arrays
|
||||||
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
||||||
self.names = names # class names
|
self.names = names # class names
|
||||||
self.files = files # image filenames
|
self.files = files # image filenames
|
||||||
@ -658,12 +668,12 @@ class Detections:
|
|||||||
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
||||||
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
||||||
self.n = len(self.pred) # number of images (batch size)
|
self.n = len(self.pred) # number of images (batch size)
|
||||||
self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
|
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
|
||||||
self.s = shape # inference BCHW shape
|
self.s = shape # inference BCHW shape
|
||||||
|
|
||||||
def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
||||||
crops = []
|
crops = []
|
||||||
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
|
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
|
||||||
s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
||||||
if pred.shape[0]:
|
if pred.shape[0]:
|
||||||
for c in pred[:, -1].unique():
|
for c in pred[:, -1].unique():
|
||||||
@ -698,7 +708,7 @@ class Detections:
|
|||||||
if i == self.n - 1:
|
if i == self.n - 1:
|
||||||
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
||||||
if render:
|
if render:
|
||||||
self.imgs[i] = np.asarray(im)
|
self.ims[i] = np.asarray(im)
|
||||||
if crop:
|
if crop:
|
||||||
if save:
|
if save:
|
||||||
LOGGER.info(f'Saved results to {save_dir}\n')
|
LOGGER.info(f'Saved results to {save_dir}\n')
|
||||||
@ -721,7 +731,7 @@ class Detections:
|
|||||||
|
|
||||||
def render(self, labels=True):
|
def render(self, labels=True):
|
||||||
self.display(render=True, labels=labels) # render results
|
self.display(render=True, labels=labels) # render results
|
||||||
return self.imgs
|
return self.ims
|
||||||
|
|
||||||
def pandas(self):
|
def pandas(self):
|
||||||
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
||||||
@ -736,9 +746,9 @@ class Detections:
|
|||||||
def tolist(self):
|
def tolist(self):
|
||||||
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
||||||
r = range(self.n) # iterable
|
r = range(self.n) # iterable
|
||||||
x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
||||||
# for d in x:
|
# for d in x:
|
||||||
# for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
||||||
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -754,10 +764,13 @@ class Classify(nn.Module):
|
|||||||
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
||||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
|
c_ = 1280 # efficientnet_b0 size
|
||||||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
|
self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
|
||||||
self.flat = nn.Flatten()
|
self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
|
||||||
|
self.drop = nn.Dropout(p=0.0, inplace=True)
|
||||||
|
self.linear = nn.Linear(c_, c2) # to x(b,c2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
|
if isinstance(x, list):
|
||||||
return self.flat(self.conv(z)) # flatten to x(b,c2)
|
x = torch.cat(x, 1)
|
||||||
|
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
||||||
|
@ -8,7 +8,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from models.common import Conv
|
|
||||||
from utils.downloads import attempt_download
|
from utils.downloads import attempt_download
|
||||||
|
|
||||||
|
|
||||||
@ -79,9 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
|
|||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
|
|
||||||
|
|
||||||
# Compatibility updates
|
# Model compatibility updates
|
||||||
|
if not hasattr(ckpt, 'stride'):
|
||||||
|
ckpt.stride = torch.tensor([32.])
|
||||||
|
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
||||||
|
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
||||||
|
|
||||||
|
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||||
|
|
||||||
|
# Module compatibility updates
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
t = type(m)
|
t = type(m)
|
||||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
||||||
@ -89,16 +95,17 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
|
|||||||
if t is Detect and not isinstance(m.anchor_grid, list):
|
if t is Detect and not isinstance(m.anchor_grid, list):
|
||||||
delattr(m, 'anchor_grid')
|
delattr(m, 'anchor_grid')
|
||||||
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
|
setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
|
||||||
elif t is Conv:
|
|
||||||
m._non_persistent_buffers_set = set() # torch 1.6.0 compatibility
|
|
||||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
||||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||||
|
|
||||||
|
# Return model
|
||||||
if len(model) == 1:
|
if len(model) == 1:
|
||||||
return model[-1] # return model
|
return model[-1]
|
||||||
|
|
||||||
|
# Return detection ensemble
|
||||||
print(f'Ensemble created with {weights}\n')
|
print(f'Ensemble created with {weights}\n')
|
||||||
for k in 'names', 'nc', 'yaml':
|
for k in 'names', 'nc', 'yaml':
|
||||||
setattr(model, k, getattr(model[0], k))
|
setattr(model, k, getattr(model[0], k))
|
||||||
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
|
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
|
||||||
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
|
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
|
||||||
return model # return ensemble
|
return model
|
||||||
|
@ -7,7 +7,7 @@ Usage:
|
|||||||
$ python models/tf.py --weights yolov5s.pt
|
$ python models/tf.py --weights yolov5s.pt
|
||||||
|
|
||||||
Export:
|
Export:
|
||||||
$ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
|
$ python export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
169
models/yolo.py
169
models/yolo.py
@ -3,10 +3,11 @@
|
|||||||
YOLO-specific modules
|
YOLO-specific modules
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
$ python path/to/models/yolo.py --cfg yolov5s.yaml
|
$ python models/yolo.py --cfg yolov5s.yaml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
@ -36,7 +37,7 @@ except ImportError:
|
|||||||
|
|
||||||
class Detect(nn.Module):
|
class Detect(nn.Module):
|
||||||
stride = None # strides computed during build
|
stride = None # strides computed during build
|
||||||
onnx_dynamic = False # ONNX export parameter
|
dynamic = False # force grid reconstruction
|
||||||
export = False # export mode
|
export = False # export mode
|
||||||
|
|
||||||
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
||||||
@ -45,11 +46,11 @@ class Detect(nn.Module):
|
|||||||
self.no = nc + 5 # number of outputs per anchor
|
self.no = nc + 5 # number of outputs per anchor
|
||||||
self.nl = len(anchors) # number of detection layers
|
self.nl = len(anchors) # number of detection layers
|
||||||
self.na = len(anchors[0]) // 2 # number of anchors
|
self.na = len(anchors[0]) // 2 # number of anchors
|
||||||
self.grid = [torch.zeros(1)] * self.nl # init grid
|
self.grid = [torch.empty(1)] * self.nl # init grid
|
||||||
self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
|
self.anchor_grid = [torch.empty(1)] * self.nl # init anchor grid
|
||||||
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
||||||
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
||||||
self.inplace = inplace # use in-place ops (e.g. slice assignment)
|
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
z = [] # inference output
|
z = [] # inference output
|
||||||
@ -59,7 +60,7 @@ class Detect(nn.Module):
|
|||||||
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
||||||
|
|
||||||
if not self.training: # inference
|
if not self.training: # inference
|
||||||
if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
|
if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
|
||||||
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
|
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
|
||||||
|
|
||||||
y = x[i].sigmoid()
|
y = x[i].sigmoid()
|
||||||
@ -75,22 +76,75 @@ class Detect(nn.Module):
|
|||||||
|
|
||||||
return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
|
return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
|
||||||
|
|
||||||
def _make_grid(self, nx=20, ny=20, i=0):
|
def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
|
||||||
d = self.anchors[i].device
|
d = self.anchors[i].device
|
||||||
t = self.anchors[i].dtype
|
t = self.anchors[i].dtype
|
||||||
shape = 1, self.na, ny, nx, 2 # grid shape
|
shape = 1, self.na, ny, nx, 2 # grid shape
|
||||||
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
||||||
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
|
yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x) # torch>=0.7 compatibility
|
||||||
yv, xv = torch.meshgrid(y, x, indexing='ij')
|
|
||||||
else:
|
|
||||||
yv, xv = torch.meshgrid(y, x)
|
|
||||||
grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5
|
grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5
|
||||||
anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
|
anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
|
||||||
return grid, anchor_grid
|
return grid, anchor_grid
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class BaseModel(nn.Module):
|
||||||
# YOLOv5 model
|
# YOLOv5 base model
|
||||||
|
def forward(self, x, profile=False, visualize=False):
|
||||||
|
return self._forward_once(x, profile, visualize) # single-scale inference, train
|
||||||
|
|
||||||
|
def _forward_once(self, x, profile=False, visualize=False):
|
||||||
|
y, dt = [], [] # outputs
|
||||||
|
for m in self.model:
|
||||||
|
if m.f != -1: # if not from previous layer
|
||||||
|
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
||||||
|
if profile:
|
||||||
|
self._profile_one_layer(m, x, dt)
|
||||||
|
x = m(x) # run
|
||||||
|
y.append(x if m.i in self.save else None) # save output
|
||||||
|
if visualize:
|
||||||
|
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _profile_one_layer(self, m, x, dt):
|
||||||
|
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
||||||
|
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||||
|
t = time_sync()
|
||||||
|
for _ in range(10):
|
||||||
|
m(x.copy() if c else x)
|
||||||
|
dt.append((time_sync() - t) * 100)
|
||||||
|
if m == self.model[0]:
|
||||||
|
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
||||||
|
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
||||||
|
if c:
|
||||||
|
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
||||||
|
|
||||||
|
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
||||||
|
LOGGER.info('Fusing layers... ')
|
||||||
|
for m in self.model.modules():
|
||||||
|
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
||||||
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||||
|
delattr(m, 'bn') # remove batchnorm
|
||||||
|
m.forward = m.forward_fuse # update forward
|
||||||
|
self.info()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def info(self, verbose=False, img_size=640): # print model information
|
||||||
|
model_info(self, verbose, img_size)
|
||||||
|
|
||||||
|
def _apply(self, fn):
|
||||||
|
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
||||||
|
self = super()._apply(fn)
|
||||||
|
m = self.model[-1] # Detect()
|
||||||
|
if isinstance(m, Detect):
|
||||||
|
m.stride = fn(m.stride)
|
||||||
|
m.grid = list(map(fn, m.grid))
|
||||||
|
if isinstance(m.anchor_grid, list):
|
||||||
|
m.anchor_grid = list(map(fn, m.anchor_grid))
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionModel(BaseModel):
|
||||||
|
# YOLOv5 detection model
|
||||||
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if isinstance(cfg, dict):
|
if isinstance(cfg, dict):
|
||||||
@ -118,7 +172,7 @@ class Model(nn.Module):
|
|||||||
if isinstance(m, Detect):
|
if isinstance(m, Detect):
|
||||||
s = 256 # 2x min stride
|
s = 256 # 2x min stride
|
||||||
m.inplace = self.inplace
|
m.inplace = self.inplace
|
||||||
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
|
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))]) # forward
|
||||||
check_anchor_order(m) # must be in pixel-space (not grid-space)
|
check_anchor_order(m) # must be in pixel-space (not grid-space)
|
||||||
m.anchors /= m.stride.view(-1, 1, 1)
|
m.anchors /= m.stride.view(-1, 1, 1)
|
||||||
self.stride = m.stride
|
self.stride = m.stride
|
||||||
@ -148,19 +202,6 @@ class Model(nn.Module):
|
|||||||
y = self._clip_augmented(y) # clip augmented tails
|
y = self._clip_augmented(y) # clip augmented tails
|
||||||
return torch.cat(y, 1), None # augmented inference, train
|
return torch.cat(y, 1), None # augmented inference, train
|
||||||
|
|
||||||
def _forward_once(self, x, profile=False, visualize=False):
|
|
||||||
y, dt = [], [] # outputs
|
|
||||||
for m in self.model:
|
|
||||||
if m.f != -1: # if not from previous layer
|
|
||||||
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
|
||||||
if profile:
|
|
||||||
self._profile_one_layer(m, x, dt)
|
|
||||||
x = m(x) # run
|
|
||||||
y.append(x if m.i in self.save else None) # save output
|
|
||||||
if visualize:
|
|
||||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def _descale_pred(self, p, flips, scale, img_size):
|
def _descale_pred(self, p, flips, scale, img_size):
|
||||||
# de-scale predictions following augmented inference (inverse operation)
|
# de-scale predictions following augmented inference (inverse operation)
|
||||||
if self.inplace:
|
if self.inplace:
|
||||||
@ -189,19 +230,6 @@ class Model(nn.Module):
|
|||||||
y[-1] = y[-1][:, i:] # small
|
y[-1] = y[-1][:, i:] # small
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def _profile_one_layer(self, m, x, dt):
|
|
||||||
c = isinstance(m, Detect) # is final layer, copy input as inplace fix
|
|
||||||
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
|
||||||
t = time_sync()
|
|
||||||
for _ in range(10):
|
|
||||||
m(x.copy() if c else x)
|
|
||||||
dt.append((time_sync() - t) * 100)
|
|
||||||
if m == self.model[0]:
|
|
||||||
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
|
||||||
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
|
||||||
if c:
|
|
||||||
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
|
||||||
|
|
||||||
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
|
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
|
||||||
# https://arxiv.org/abs/1708.02002 section 3.3
|
# https://arxiv.org/abs/1708.02002 section 3.3
|
||||||
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
|
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
|
||||||
@ -212,41 +240,34 @@ class Model(nn.Module):
|
|||||||
b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
|
b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
|
||||||
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
||||||
|
|
||||||
def _print_biases(self):
|
|
||||||
m = self.model[-1] # Detect() module
|
|
||||||
for mi in m.m: # from
|
|
||||||
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
|
|
||||||
LOGGER.info(
|
|
||||||
('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
|
|
||||||
|
|
||||||
# def _print_weights(self):
|
Model = DetectionModel # retain YOLOv5 'Model' class for backwards compatibility
|
||||||
# for m in self.model.modules():
|
|
||||||
# if type(m) is Bottleneck:
|
|
||||||
# LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
|
|
||||||
|
|
||||||
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
|
||||||
LOGGER.info('Fusing layers... ')
|
|
||||||
for m in self.model.modules():
|
|
||||||
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
|
||||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
|
||||||
delattr(m, 'bn') # remove batchnorm
|
|
||||||
m.forward = m.forward_fuse # update forward
|
|
||||||
self.info()
|
|
||||||
return self
|
|
||||||
|
|
||||||
def info(self, verbose=False, img_size=640): # print model information
|
class ClassificationModel(BaseModel):
|
||||||
model_info(self, verbose, img_size)
|
# YOLOv5 classification model
|
||||||
|
def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
|
||||||
|
super().__init__()
|
||||||
|
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
|
||||||
|
|
||||||
def _apply(self, fn):
|
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
||||||
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
# Create a YOLOv5 classification model from a YOLOv5 detection model
|
||||||
self = super()._apply(fn)
|
if isinstance(model, DetectMultiBackend):
|
||||||
m = self.model[-1] # Detect()
|
model = model.model # unwrap DetectMultiBackend
|
||||||
if isinstance(m, Detect):
|
model.model = model.model[:cutoff] # backbone
|
||||||
m.stride = fn(m.stride)
|
m = model.model[-1] # last layer
|
||||||
m.grid = list(map(fn, m.grid))
|
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
|
||||||
if isinstance(m.anchor_grid, list):
|
c = Classify(ch, nc) # Classify()
|
||||||
m.anchor_grid = list(map(fn, m.anchor_grid))
|
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
|
||||||
return self
|
model.model[-1] = c # replace
|
||||||
|
self.model = model.model
|
||||||
|
self.stride = model.stride
|
||||||
|
self.save = []
|
||||||
|
self.nc = nc
|
||||||
|
|
||||||
|
def _from_yaml(self, cfg):
|
||||||
|
# Create a YOLOv5 classification model from a *.yaml file
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
|
||||||
def parse_model(d, ch): # model_dict, input_channels(3)
|
def parse_model(d, ch): # model_dict, input_channels(3)
|
||||||
@ -259,10 +280,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|||||||
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
||||||
m = eval(m) if isinstance(m, str) else m # eval strings
|
m = eval(m) if isinstance(m, str) else m # eval strings
|
||||||
for j, a in enumerate(args):
|
for j, a in enumerate(args):
|
||||||
try:
|
with contextlib.suppress(NameError):
|
||||||
args[j] = eval(a) if isinstance(a, str) else a # eval strings
|
args[j] = eval(a) if isinstance(a, str) else a # eval strings
|
||||||
except NameError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
|
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
|
||||||
if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
|
if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
|
||||||
@ -322,7 +341,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# Options
|
# Options
|
||||||
if opt.line_profile: # profile layer by layer
|
if opt.line_profile: # profile layer by layer
|
||||||
_ = model(im, profile=True)
|
model(im, profile=True)
|
||||||
|
|
||||||
elif opt.profile: # profile forward-backward
|
elif opt.profile: # profile forward-backward
|
||||||
results = profile(input=im, ops=[model], n=3)
|
results = profile(input=im, ops=[model], n=3)
|
||||||
|
@ -3,6 +3,33 @@
|
|||||||
utils/initialization
|
utils/initialization
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class TryExcept(contextlib.ContextDecorator):
|
||||||
|
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
|
||||||
|
def __init__(self, msg=''):
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, value, traceback):
|
||||||
|
if value:
|
||||||
|
print(f'{self.msg}{value}')
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def threaded(func):
|
||||||
|
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
return thread
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def notebook_init(verbose=True):
|
def notebook_init(verbose=True):
|
||||||
# Check system software and hardware
|
# Check system software and hardware
|
||||||
@ -11,10 +38,12 @@ def notebook_init(verbose=True):
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from utils.general import check_requirements, emojis, is_colab
|
from utils.general import check_font, check_requirements, emojis, is_colab
|
||||||
from utils.torch_utils import select_device # imports
|
from utils.torch_utils import select_device # imports
|
||||||
|
|
||||||
check_requirements(('psutil', 'IPython'))
|
check_requirements(('psutil', 'IPython'))
|
||||||
|
check_font()
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
from IPython import display # to display images and clear console output
|
from IPython import display # to display images and clear console output
|
||||||
|
|
||||||
|
@ -8,15 +8,22 @@ import random
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as T
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
|
||||||
from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
|
from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
|
||||||
from utils.metrics import bbox_ioa
|
from utils.metrics import bbox_ioa
|
||||||
|
|
||||||
|
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
||||||
|
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
||||||
|
|
||||||
|
|
||||||
class Albumentations:
|
class Albumentations:
|
||||||
# YOLOv5 Albumentations class (optional, only used if package is installed)
|
# YOLOv5 Albumentations class (optional, only used if package is installed)
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.transform = None
|
self.transform = None
|
||||||
|
prefix = colorstr('albumentations: ')
|
||||||
try:
|
try:
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||||
@ -31,11 +38,11 @@ class Albumentations:
|
|||||||
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
|
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
|
||||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
||||||
|
|
||||||
LOGGER.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))
|
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||||
except ImportError: # package not installed, skip
|
except ImportError: # package not installed, skip
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.info(colorstr('albumentations: ') + f'{e}')
|
LOGGER.info(f'{prefix}{e}')
|
||||||
|
|
||||||
def __call__(self, im, labels, p=1.0):
|
def __call__(self, im, labels, p=1.0):
|
||||||
if self.transform and random.random() < p:
|
if self.transform and random.random() < p:
|
||||||
@ -44,6 +51,18 @@ class Albumentations:
|
|||||||
return im, labels
|
return im, labels
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
|
||||||
|
# Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
|
||||||
|
return TF.normalize(x, mean, std, inplace=inplace)
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
|
||||||
|
# Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
|
||||||
|
for i in range(3):
|
||||||
|
x[:, i] = x[:, i] * std[i] + mean[i]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
||||||
# HSV color-space augmentation
|
# HSV color-space augmentation
|
||||||
if hgain or sgain or vgain:
|
if hgain or sgain or vgain:
|
||||||
@ -282,3 +301,96 @@ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
|
|||||||
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
||||||
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
||||||
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
||||||
|
|
||||||
|
|
||||||
|
def classify_albumentations(augment=True,
|
||||||
|
size=224,
|
||||||
|
scale=(0.08, 1.0),
|
||||||
|
hflip=0.5,
|
||||||
|
vflip=0.0,
|
||||||
|
jitter=0.4,
|
||||||
|
mean=IMAGENET_MEAN,
|
||||||
|
std=IMAGENET_STD,
|
||||||
|
auto_aug=False):
|
||||||
|
# YOLOv5 classification Albumentations (optional, only used if package is installed)
|
||||||
|
prefix = colorstr('albumentations: ')
|
||||||
|
try:
|
||||||
|
import albumentations as A
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
|
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||||
|
if augment: # Resize and crop
|
||||||
|
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
||||||
|
if auto_aug:
|
||||||
|
# TODO: implement AugMix, AutoAug & RandAug in albumentation
|
||||||
|
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
|
||||||
|
else:
|
||||||
|
if hflip > 0:
|
||||||
|
T += [A.HorizontalFlip(p=hflip)]
|
||||||
|
if vflip > 0:
|
||||||
|
T += [A.VerticalFlip(p=vflip)]
|
||||||
|
if jitter > 0:
|
||||||
|
color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue
|
||||||
|
T += [A.ColorJitter(*color_jitter, 0)]
|
||||||
|
else: # Use fixed crop for eval set (reproducibility)
|
||||||
|
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
||||||
|
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
||||||
|
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||||
|
return A.Compose(T)
|
||||||
|
|
||||||
|
except ImportError: # package not installed, skip
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.info(f'{prefix}{e}')
|
||||||
|
|
||||||
|
|
||||||
|
def classify_transforms(size=224):
|
||||||
|
# Transforms to apply if albumentations not installed
|
||||||
|
assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
|
||||||
|
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||||
|
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||||
|
|
||||||
|
|
||||||
|
class LetterBox:
|
||||||
|
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||||
|
def __init__(self, size=(640, 640), auto=False, stride=32):
|
||||||
|
super().__init__()
|
||||||
|
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||||
|
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
||||||
|
self.stride = stride # used with auto
|
||||||
|
|
||||||
|
def __call__(self, im): # im = np.array HWC
|
||||||
|
imh, imw = im.shape[:2]
|
||||||
|
r = min(self.h / imh, self.w / imw) # ratio of new/old
|
||||||
|
h, w = round(imh * r), round(imw * r) # resized image
|
||||||
|
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
|
||||||
|
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
|
||||||
|
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
|
||||||
|
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
return im_out
|
||||||
|
|
||||||
|
|
||||||
|
class CenterCrop:
|
||||||
|
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
|
||||||
|
def __init__(self, size=640):
|
||||||
|
super().__init__()
|
||||||
|
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||||
|
|
||||||
|
def __call__(self, im): # im = np.array HWC
|
||||||
|
imh, imw = im.shape[:2]
|
||||||
|
m = min(imh, imw) # min dimension
|
||||||
|
top, left = (imh - m) // 2, (imw - m) // 2
|
||||||
|
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
|
||||||
|
class ToTensor:
|
||||||
|
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||||
|
def __init__(self, half=False):
|
||||||
|
super().__init__()
|
||||||
|
self.half = half
|
||||||
|
|
||||||
|
def __call__(self, im): # im = np.array HWC in BGR order
|
||||||
|
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
|
||||||
|
im = torch.from_numpy(im) # to torch
|
||||||
|
im = im.half() if self.half else im.float() # uint8 to fp16/32
|
||||||
|
im /= 255.0 # 0-255 to 0.0-1.0
|
||||||
|
return im
|
||||||
|
@ -10,7 +10,8 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from utils.general import LOGGER, colorstr, emojis
|
from utils import TryExcept
|
||||||
|
from utils.general import LOGGER, colorstr
|
||||||
|
|
||||||
PREFIX = colorstr('AutoAnchor: ')
|
PREFIX = colorstr('AutoAnchor: ')
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ def check_anchor_order(m):
|
|||||||
m.anchors[:] = m.anchors.flip(0)
|
m.anchors[:] = m.anchors.flip(0)
|
||||||
|
|
||||||
|
|
||||||
|
@TryExcept(f'{PREFIX}ERROR: ')
|
||||||
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
||||||
# Check anchor fit to data, recompute if necessary
|
# Check anchor fit to data, recompute if necessary
|
||||||
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
||||||
@ -45,14 +47,11 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|||||||
bpr, aat = metric(anchors.cpu().view(-1, 2))
|
bpr, aat = metric(anchors.cpu().view(-1, 2))
|
||||||
s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
|
s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
|
||||||
if bpr > 0.98: # threshold to recompute
|
if bpr > 0.98: # threshold to recompute
|
||||||
LOGGER.info(emojis(f'{s}Current anchors are a good fit to dataset ✅'))
|
LOGGER.info(f'{s}Current anchors are a good fit to dataset ✅')
|
||||||
else:
|
else:
|
||||||
LOGGER.info(emojis(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...'))
|
LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
|
||||||
na = m.anchors.numel() // 2 # number of anchors
|
na = m.anchors.numel() // 2 # number of anchors
|
||||||
try:
|
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
||||||
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
|
||||||
except Exception as e:
|
|
||||||
LOGGER.info(f'{PREFIX}ERROR: {e}')
|
|
||||||
new_bpr = metric(anchors)[0]
|
new_bpr = metric(anchors)[0]
|
||||||
if new_bpr > bpr: # replace anchors
|
if new_bpr > bpr: # replace anchors
|
||||||
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
|
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
|
||||||
@ -62,7 +61,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|||||||
s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
|
s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
|
||||||
else:
|
else:
|
||||||
s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
|
s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
|
||||||
LOGGER.info(emojis(s))
|
LOGGER.info(s)
|
||||||
|
|
||||||
|
|
||||||
def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
|
def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
|
||||||
@ -124,7 +123,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
|
|||||||
i = (wh0 < 3.0).any(1).sum()
|
i = (wh0 < 3.0).any(1).sum()
|
||||||
if i:
|
if i:
|
||||||
LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
|
LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
|
||||||
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
|
wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels
|
||||||
# wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
|
# wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
|
||||||
|
|
||||||
# Kmeans init
|
# Kmeans init
|
||||||
@ -167,4 +166,4 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
|
|||||||
if verbose:
|
if verbose:
|
||||||
print_results(k, verbose)
|
print_results(k, verbose)
|
||||||
|
|
||||||
return print_results(k)
|
return print_results(k).astype(np.float32)
|
||||||
|
@ -8,7 +8,7 @@ from copy import deepcopy
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from utils.general import LOGGER, colorstr, emojis
|
from utils.general import LOGGER, colorstr
|
||||||
from utils.torch_utils import profile
|
from utils.torch_utils import profile
|
||||||
|
|
||||||
|
|
||||||
@ -18,7 +18,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
|
|||||||
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
||||||
|
|
||||||
|
|
||||||
def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
|
def autobatch(model, imgsz=640, fraction=0.8, batch_size=16):
|
||||||
# Automatically estimate best batch size to use `fraction` of available CUDA memory
|
# Automatically estimate best batch size to use `fraction` of available CUDA memory
|
||||||
# Usage:
|
# Usage:
|
||||||
# import torch
|
# import torch
|
||||||
@ -47,7 +47,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
|
|||||||
# Profile batch sizes
|
# Profile batch sizes
|
||||||
batch_sizes = [1, 2, 4, 8, 16]
|
batch_sizes = [1, 2, 4, 8, 16]
|
||||||
try:
|
try:
|
||||||
img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
|
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||||
results = profile(img, model, n=3, device=device)
|
results = profile(img, model, n=3, device=device)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f'{prefix}{e}')
|
LOGGER.warning(f'{prefix}{e}')
|
||||||
@ -60,7 +60,10 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
|
|||||||
i = results.index(None) # first fail index
|
i = results.index(None) # first fail index
|
||||||
if b >= batch_sizes[i]: # y intercept above failure point
|
if b >= batch_sizes[i]: # y intercept above failure point
|
||||||
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
||||||
|
if b < 1 or b > 1024: # b outside of safe range
|
||||||
|
b = batch_size
|
||||||
|
LOGGER.warning(f'{prefix}WARNING: ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
|
||||||
|
|
||||||
fraction = np.polyval(p, b) / t # actual fraction predicted
|
fraction = np.polyval(p, b) / t # actual fraction predicted
|
||||||
LOGGER.info(emojis(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅'))
|
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
|
||||||
return b
|
return b
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
Callback utils
|
Callback utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
class Callbacks:
|
class Callbacks:
|
||||||
""""
|
""""
|
||||||
@ -55,17 +57,20 @@ class Callbacks:
|
|||||||
"""
|
"""
|
||||||
return self._callbacks[hook] if hook else self._callbacks
|
return self._callbacks[hook] if hook else self._callbacks
|
||||||
|
|
||||||
def run(self, hook, *args, **kwargs):
|
def run(self, hook, *args, thread=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Loop through the registered actions and fire all callbacks
|
Loop through the registered actions and fire all callbacks on main thread
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hook: The name of the hook to check, defaults to all
|
hook: The name of the hook to check, defaults to all
|
||||||
args: Arguments to receive from YOLOv5
|
args: Arguments to receive from YOLOv5
|
||||||
|
thread: (boolean) Run callbacks in daemon thread
|
||||||
kwargs: Keyword Arguments to receive from YOLOv5
|
kwargs: Keyword Arguments to receive from YOLOv5
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
||||||
|
|
||||||
for logger in self._callbacks[hook]:
|
for logger in self._callbacks[hook]:
|
||||||
logger['callback'](*args, **kwargs)
|
if thread:
|
||||||
|
threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
|
||||||
|
else:
|
||||||
|
logger['callback'](*args, **kwargs)
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
Dataloaders and dataset utils
|
Dataloaders and dataset utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import glob
|
import glob
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
@ -21,22 +22,25 @@ from zipfile import ZipFile
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
import yaml
|
import yaml
|
||||||
from PIL import ExifTags, Image, ImageOps
|
from PIL import ExifTags, Image, ImageOps
|
||||||
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
|
from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
|
||||||
|
letterbox, mixup, random_perspective)
|
||||||
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
|
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
|
||||||
cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
|
cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
|
||||||
from utils.torch_utils import torch_distributed_zero_first
|
from utils.torch_utils import torch_distributed_zero_first
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
||||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
|
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
|
||||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
||||||
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
|
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
|
||||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
|
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||||
|
|
||||||
# Get orientation exif tag
|
# Get orientation exif tag
|
||||||
for orientation in ExifTags.TAGS.keys():
|
for orientation in ExifTags.TAGS.keys():
|
||||||
@ -55,13 +59,10 @@ def get_hash(paths):
|
|||||||
def exif_size(img):
|
def exif_size(img):
|
||||||
# Returns exif-corrected PIL size
|
# Returns exif-corrected PIL size
|
||||||
s = img.size # (width, height)
|
s = img.size # (width, height)
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
rotation = dict(img._getexif().items())[orientation]
|
rotation = dict(img._getexif().items())[orientation]
|
||||||
if rotation in [6, 8]: # rotation 270 or 90
|
if rotation in [6, 8]: # rotation 270 or 90
|
||||||
s = (s[1], s[0])
|
s = (s[1], s[0])
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
@ -83,7 +84,7 @@ def exif_transpose(image):
|
|||||||
5: Image.TRANSPOSE,
|
5: Image.TRANSPOSE,
|
||||||
6: Image.ROTATE_270,
|
6: Image.ROTATE_270,
|
||||||
7: Image.TRANSVERSE,
|
7: Image.TRANSVERSE,
|
||||||
8: Image.ROTATE_90,}.get(orientation)
|
8: Image.ROTATE_90}.get(orientation)
|
||||||
if method is not None:
|
if method is not None:
|
||||||
image = image.transpose(method)
|
image = image.transpose(method)
|
||||||
del exif[0x0112]
|
del exif[0x0112]
|
||||||
@ -91,6 +92,13 @@ def exif_transpose(image):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def seed_worker(worker_id):
|
||||||
|
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
|
||||||
|
worker_seed = torch.initial_seed() % 2 ** 32
|
||||||
|
np.random.seed(worker_seed)
|
||||||
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(path,
|
def create_dataloader(path,
|
||||||
imgsz,
|
imgsz,
|
||||||
batch_size,
|
batch_size,
|
||||||
@ -130,13 +138,17 @@ def create_dataloader(path,
|
|||||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(0)
|
||||||
return loader(dataset,
|
return loader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=shuffle and sampler is None,
|
shuffle=shuffle and sampler is None,
|
||||||
num_workers=nw,
|
num_workers=nw,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=True,
|
pin_memory=PIN_MEMORY,
|
||||||
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
|
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
generator=generator), dataset
|
||||||
|
|
||||||
|
|
||||||
class InfiniteDataLoader(dataloader.DataLoader):
|
class InfiniteDataLoader(dataloader.DataLoader):
|
||||||
@ -175,7 +187,7 @@ class _RepeatSampler:
|
|||||||
|
|
||||||
class LoadImages:
|
class LoadImages:
|
||||||
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
||||||
def __init__(self, path, img_size=640, stride=32, auto=True):
|
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None):
|
||||||
files = []
|
files = []
|
||||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||||
p = str(Path(p).resolve())
|
p = str(Path(p).resolve())
|
||||||
@ -199,8 +211,9 @@ class LoadImages:
|
|||||||
self.video_flag = [False] * ni + [True] * nv
|
self.video_flag = [False] * ni + [True] * nv
|
||||||
self.mode = 'image'
|
self.mode = 'image'
|
||||||
self.auto = auto
|
self.auto = auto
|
||||||
|
self.transforms = transforms # optional
|
||||||
if any(videos):
|
if any(videos):
|
||||||
self.new_video(videos[0]) # new video
|
self._new_video(videos[0]) # new video
|
||||||
else:
|
else:
|
||||||
self.cap = None
|
self.cap = None
|
||||||
assert self.nf > 0, f'No images or videos found in {p}. ' \
|
assert self.nf > 0, f'No images or videos found in {p}. ' \
|
||||||
@ -218,103 +231,69 @@ class LoadImages:
|
|||||||
if self.video_flag[self.count]:
|
if self.video_flag[self.count]:
|
||||||
# Read video
|
# Read video
|
||||||
self.mode = 'video'
|
self.mode = 'video'
|
||||||
ret_val, img0 = self.cap.read()
|
ret_val, im0 = self.cap.read()
|
||||||
while not ret_val:
|
while not ret_val:
|
||||||
self.count += 1
|
self.count += 1
|
||||||
self.cap.release()
|
self.cap.release()
|
||||||
if self.count == self.nf: # last video
|
if self.count == self.nf: # last video
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
path = self.files[self.count]
|
path = self.files[self.count]
|
||||||
self.new_video(path)
|
self._new_video(path)
|
||||||
ret_val, img0 = self.cap.read()
|
ret_val, im0 = self.cap.read()
|
||||||
|
|
||||||
self.frame += 1
|
self.frame += 1
|
||||||
|
# im0 = self._cv2_rotate(im0) # for use if cv2 auto rotation is False
|
||||||
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Read image
|
# Read image
|
||||||
self.count += 1
|
self.count += 1
|
||||||
img0 = cv2.imread(path) # BGR
|
im0 = cv2.imread(path) # BGR
|
||||||
assert img0 is not None, f'Image Not Found {path}'
|
assert im0 is not None, f'Image Not Found {path}'
|
||||||
s = f'image {self.count}/{self.nf} {path}: '
|
s = f'image {self.count}/{self.nf} {path}: '
|
||||||
|
|
||||||
# Padded resize
|
if self.transforms:
|
||||||
img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
|
im = self.transforms(im0) # transforms
|
||||||
|
else:
|
||||||
|
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
|
||||||
|
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
||||||
|
im = np.ascontiguousarray(im) # contiguous
|
||||||
|
|
||||||
# Convert
|
return path, im, im0, self.cap, s
|
||||||
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
|
||||||
img = np.ascontiguousarray(img)
|
|
||||||
|
|
||||||
return path, img, img0, self.cap, s
|
def _new_video(self, path):
|
||||||
|
# Create a new video capture object
|
||||||
def new_video(self, path):
|
|
||||||
self.frame = 0
|
self.frame = 0
|
||||||
self.cap = cv2.VideoCapture(path)
|
self.cap = cv2.VideoCapture(path)
|
||||||
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
|
||||||
|
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
|
||||||
|
|
||||||
|
def _cv2_rotate(self, im):
|
||||||
|
# Rotate a cv2 video manually
|
||||||
|
if self.orientation == 0:
|
||||||
|
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
|
||||||
|
elif self.orientation == 180:
|
||||||
|
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||||
|
elif self.orientation == 90:
|
||||||
|
return cv2.rotate(im, cv2.ROTATE_180)
|
||||||
|
return im
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.nf # number of files
|
return self.nf # number of files
|
||||||
|
|
||||||
|
|
||||||
class LoadWebcam: # for inference
|
|
||||||
# YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
|
|
||||||
def __init__(self, pipe='0', img_size=640, stride=32):
|
|
||||||
self.img_size = img_size
|
|
||||||
self.stride = stride
|
|
||||||
self.pipe = eval(pipe) if pipe.isnumeric() else pipe
|
|
||||||
self.cap = cv2.VideoCapture(self.pipe) # video capture object
|
|
||||||
self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
self.count = -1
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
self.count += 1
|
|
||||||
if cv2.waitKey(1) == ord('q'): # q to quit
|
|
||||||
self.cap.release()
|
|
||||||
cv2.destroyAllWindows()
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
# Read frame
|
|
||||||
ret_val, img0 = self.cap.read()
|
|
||||||
img0 = cv2.flip(img0, 1) # flip left-right
|
|
||||||
|
|
||||||
# Print
|
|
||||||
assert ret_val, f'Camera Error {self.pipe}'
|
|
||||||
img_path = 'webcam.jpg'
|
|
||||||
s = f'webcam {self.count}: '
|
|
||||||
|
|
||||||
# Padded resize
|
|
||||||
img = letterbox(img0, self.img_size, stride=self.stride)[0]
|
|
||||||
|
|
||||||
# Convert
|
|
||||||
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
|
||||||
img = np.ascontiguousarray(img)
|
|
||||||
|
|
||||||
return img_path, img, img0, None, s
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
class LoadStreams:
|
class LoadStreams:
|
||||||
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
||||||
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
|
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None):
|
||||||
|
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||||
self.mode = 'stream'
|
self.mode = 'stream'
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
sources = Path(sources).read_text().rsplit() if Path(sources).is_file() else [sources]
|
||||||
if os.path.isfile(sources):
|
|
||||||
with open(sources) as f:
|
|
||||||
sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
|
|
||||||
else:
|
|
||||||
sources = [sources]
|
|
||||||
|
|
||||||
n = len(sources)
|
n = len(sources)
|
||||||
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
|
|
||||||
self.sources = [clean_str(x) for x in sources] # clean source names for later
|
self.sources = [clean_str(x) for x in sources] # clean source names for later
|
||||||
self.auto = auto
|
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
|
||||||
for i, s in enumerate(sources): # index, source
|
for i, s in enumerate(sources): # index, source
|
||||||
# Start thread to read frames from video stream
|
# Start thread to read frames from video stream
|
||||||
st = f'{i + 1}/{n}: {s}... '
|
st = f'{i + 1}/{n}: {s}... '
|
||||||
@ -341,8 +320,10 @@ class LoadStreams:
|
|||||||
LOGGER.info('') # newline
|
LOGGER.info('') # newline
|
||||||
|
|
||||||
# check for common shapes
|
# check for common shapes
|
||||||
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
|
s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
|
||||||
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
||||||
|
self.auto = auto and self.rect
|
||||||
|
self.transforms = transforms # optional
|
||||||
if not self.rect:
|
if not self.rect:
|
||||||
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
||||||
|
|
||||||
@ -351,8 +332,7 @@ class LoadStreams:
|
|||||||
n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
|
n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
|
||||||
while cap.isOpened() and n < f:
|
while cap.isOpened() and n < f:
|
||||||
n += 1
|
n += 1
|
||||||
# _, self.imgs[index] = cap.read()
|
cap.grab() # .read() = .grab() followed by .retrieve()
|
||||||
cap.grab()
|
|
||||||
if n % read == 0:
|
if n % read == 0:
|
||||||
success, im = cap.retrieve()
|
success, im = cap.retrieve()
|
||||||
if success:
|
if success:
|
||||||
@ -373,18 +353,15 @@ class LoadStreams:
|
|||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
|
||||||
# Letterbox
|
im0 = self.imgs.copy()
|
||||||
img0 = self.imgs.copy()
|
if self.transforms:
|
||||||
img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
|
im = np.stack([self.transforms(x) for x in im0]) # transforms
|
||||||
|
else:
|
||||||
|
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
|
||||||
|
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
|
||||||
|
im = np.ascontiguousarray(im) # contiguous
|
||||||
|
|
||||||
# Stack
|
return self.sources, im, im0, None, ''
|
||||||
img = np.stack(img, 0)
|
|
||||||
|
|
||||||
# Convert
|
|
||||||
img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
|
|
||||||
img = np.ascontiguousarray(img)
|
|
||||||
|
|
||||||
return self.sources, img, img0, None, ''
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
||||||
@ -444,7 +421,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||||
assert self.im_files, f'{prefix}No images found'
|
assert self.im_files, f'{prefix}No images found'
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
|
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
|
||||||
|
|
||||||
# Check cache
|
# Check cache
|
||||||
self.label_files = img2label_paths(self.im_files) # labels
|
self.label_files = img2label_paths(self.im_files) # labels
|
||||||
@ -463,13 +440,15 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
|
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
|
||||||
if cache['msgs']:
|
if cache['msgs']:
|
||||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||||
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
|
assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
|
||||||
|
|
||||||
# Read cache
|
# Read cache
|
||||||
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
||||||
labels, shapes, self.segments = zip(*cache.values())
|
labels, shapes, self.segments = zip(*cache.values())
|
||||||
|
nl = len(np.concatenate(labels, 0)) # number of labels
|
||||||
|
assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
|
||||||
self.labels = list(labels)
|
self.labels = list(labels)
|
||||||
self.shapes = np.array(shapes, dtype=np.float64)
|
self.shapes = np.array(shapes)
|
||||||
self.im_files = list(cache.keys()) # update
|
self.im_files = list(cache.keys()) # update
|
||||||
self.label_files = img2label_paths(cache.keys()) # update
|
self.label_files = img2label_paths(cache.keys()) # update
|
||||||
n = len(shapes) # number of images
|
n = len(shapes) # number of images
|
||||||
@ -560,7 +539,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
if msgs:
|
if msgs:
|
||||||
LOGGER.info('\n'.join(msgs))
|
LOGGER.info('\n'.join(msgs))
|
||||||
if nf == 0:
|
if nf == 0:
|
||||||
LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
|
LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. {HELP_URL}')
|
||||||
x['hash'] = get_hash(self.label_files + self.im_files)
|
x['hash'] = get_hash(self.label_files + self.im_files)
|
||||||
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||||
x['msgs'] = msgs # warnings
|
x['msgs'] = msgs # warnings
|
||||||
@ -671,8 +650,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
||||||
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
|
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
|
||||||
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
||||||
else:
|
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
|
||||||
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
|
|
||||||
|
|
||||||
def cache_images_to_disk(self, i):
|
def cache_images_to_disk(self, i):
|
||||||
# Saves an image as an *.npy file for faster loading
|
# Saves an image as an *.npy file for faster loading
|
||||||
@ -823,7 +801,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def collate_fn4(batch):
|
def collate_fn4(batch):
|
||||||
img, label, path, shapes = zip(*batch) # transposed
|
im, label, path, shapes = zip(*batch) # transposed
|
||||||
n = len(shapes) // 4
|
n = len(shapes) // 4
|
||||||
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
|
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
|
||||||
|
|
||||||
@ -833,13 +811,13 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
|
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
|
||||||
i *= 4
|
i *= 4
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
|
im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
|
||||||
align_corners=False)[0].type(img[i].type())
|
align_corners=False)[0].type(im[i].type())
|
||||||
lb = label[i]
|
lb = label[i]
|
||||||
else:
|
else:
|
||||||
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
|
im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
|
||||||
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
|
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
|
||||||
im4.append(im)
|
im4.append(im1)
|
||||||
label4.append(lb)
|
label4.append(lb)
|
||||||
|
|
||||||
for i, lb in enumerate(label4):
|
for i, lb in enumerate(label4):
|
||||||
@ -849,25 +827,20 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
# Ancillary functions --------------------------------------------------------------------------------------------------
|
# Ancillary functions --------------------------------------------------------------------------------------------------
|
||||||
def create_folder(path='./new'):
|
|
||||||
# Create folder
|
|
||||||
if os.path.exists(path):
|
|
||||||
shutil.rmtree(path) # delete output folder
|
|
||||||
os.makedirs(path) # make new output folder
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_recursive(path=DATASETS_DIR / 'coco128'):
|
def flatten_recursive(path=DATASETS_DIR / 'coco128'):
|
||||||
# Flatten a recursive directory by bringing all files to top level
|
# Flatten a recursive directory by bringing all files to top level
|
||||||
new_path = Path(str(path) + '_flat')
|
new_path = Path(f'{str(path)}_flat')
|
||||||
create_folder(new_path)
|
if os.path.exists(new_path):
|
||||||
for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
|
shutil.rmtree(new_path) # delete output folder
|
||||||
|
os.makedirs(new_path) # make new output folder
|
||||||
|
for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
|
||||||
shutil.copyfile(file, new_path / Path(file).name)
|
shutil.copyfile(file, new_path / Path(file).name)
|
||||||
|
|
||||||
|
|
||||||
def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
|
def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
|
||||||
# Convert detection dataset into classification dataset, with one directory per class
|
# Convert detection dataset into classification dataset, with one directory per class
|
||||||
path = Path(path) # images dir
|
path = Path(path) # images dir
|
||||||
shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
|
shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None # remove existing
|
||||||
files = list(path.rglob('*.*'))
|
files = list(path.rglob('*.*'))
|
||||||
n = len(files) # number of files
|
n = len(files) # number of files
|
||||||
for im_file in tqdm(files, total=n):
|
for im_file in tqdm(files, total=n):
|
||||||
@ -913,13 +886,15 @@ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), ann
|
|||||||
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
|
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
|
||||||
|
|
||||||
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
|
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
|
||||||
[(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
|
for x in txt:
|
||||||
|
if (path.parent / x).exists():
|
||||||
|
(path.parent / x).unlink() # remove existing
|
||||||
|
|
||||||
print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
|
print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
|
||||||
for i, img in tqdm(zip(indices, files), total=n):
|
for i, img in tqdm(zip(indices, files), total=n):
|
||||||
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
|
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
|
||||||
with open(path.parent / txt[i], 'a') as f:
|
with open(path.parent / txt[i], 'a') as f:
|
||||||
f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
|
f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
|
||||||
|
|
||||||
|
|
||||||
def verify_image_label(args):
|
def verify_image_label(args):
|
||||||
@ -959,7 +934,7 @@ def verify_image_label(args):
|
|||||||
if len(i) < nl: # duplicate row check
|
if len(i) < nl: # duplicate row check
|
||||||
lb = lb[i] # remove duplicates
|
lb = lb[i] # remove duplicates
|
||||||
if segments:
|
if segments:
|
||||||
segments = segments[i]
|
segments = [segments[x] for x in i]
|
||||||
msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
|
msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
|
||||||
else:
|
else:
|
||||||
ne = 1 # label empty
|
ne = 1 # label empty
|
||||||
@ -974,21 +949,35 @@ def verify_image_label(args):
|
|||||||
return [None, None, None, None, nm, nf, ne, nc, msg]
|
return [None, None, None, None, nm, nf, ne, nc, msg]
|
||||||
|
|
||||||
|
|
||||||
def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
|
class HUBDatasetStats():
|
||||||
""" Return dataset statistics dictionary with images and instances counts per split per class
|
""" Return dataset statistics dictionary with images and instances counts per split per class
|
||||||
To run in parent directory: export PYTHONPATH="$PWD/yolov5"
|
To run in parent directory: export PYTHONPATH="$PWD/yolov5"
|
||||||
Usage1: from utils.dataloaders import *; dataset_stats('coco128.yaml', autodownload=True)
|
Usage1: from utils.dataloaders import *; HUBDatasetStats('coco128.yaml', autodownload=True)
|
||||||
Usage2: from utils.dataloaders import *; dataset_stats('path/to/coco128_with_yaml.zip')
|
Usage2: from utils.dataloaders import *; HUBDatasetStats('path/to/coco128_with_yaml.zip')
|
||||||
Arguments
|
Arguments
|
||||||
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
|
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
|
||||||
autodownload: Attempt to download dataset if not found locally
|
autodownload: Attempt to download dataset if not found locally
|
||||||
verbose: Print stats dictionary
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _round_labels(labels):
|
def __init__(self, path='coco128.yaml', autodownload=False):
|
||||||
# Update labels to integer class and 6 decimal place floats
|
# Initialize class
|
||||||
return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
|
zipped, data_dir, yaml_path = self._unzip(Path(path))
|
||||||
|
try:
|
||||||
|
with open(check_yaml(yaml_path), errors='ignore') as f:
|
||||||
|
data = yaml.safe_load(f) # data dict
|
||||||
|
if zipped:
|
||||||
|
data['path'] = data_dir
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("error/HUB/dataset_stats/yaml_load") from e
|
||||||
|
|
||||||
|
check_dataset(data, autodownload) # download dataset if missing
|
||||||
|
self.hub_dir = Path(data['path'] + '-hub')
|
||||||
|
self.im_dir = self.hub_dir / 'images'
|
||||||
|
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
||||||
|
self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _find_yaml(dir):
|
def _find_yaml(dir):
|
||||||
# Return data.yaml file
|
# Return data.yaml file
|
||||||
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
|
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
|
||||||
@ -999,26 +988,25 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profil
|
|||||||
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
|
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
|
||||||
return files[0]
|
return files[0]
|
||||||
|
|
||||||
def _unzip(path):
|
def _unzip(self, path):
|
||||||
# Unzip data.zip
|
# Unzip data.zip
|
||||||
if str(path).endswith('.zip'): # path is data.zip
|
if not str(path).endswith('.zip'): # path is data.yaml
|
||||||
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
|
|
||||||
ZipFile(path).extractall(path=path.parent) # unzip
|
|
||||||
dir = path.with_suffix('') # dataset directory == zip name
|
|
||||||
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
|
|
||||||
return True, str(dir), _find_yaml(dir) # zipped, data_dir, yaml_path
|
|
||||||
else: # path is data.yaml
|
|
||||||
return False, None, path
|
return False, None, path
|
||||||
|
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
|
||||||
|
ZipFile(path).extractall(path=path.parent) # unzip
|
||||||
|
dir = path.with_suffix('') # dataset directory == zip name
|
||||||
|
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
|
||||||
|
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
|
||||||
|
|
||||||
def _hub_ops(f, max_dim=1920):
|
def _hub_ops(self, f, max_dim=1920):
|
||||||
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
|
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
|
||||||
f_new = im_dir / Path(f).name # dataset-hub image filename
|
f_new = self.im_dir / Path(f).name # dataset-hub image filename
|
||||||
try: # use PIL
|
try: # use PIL
|
||||||
im = Image.open(f)
|
im = Image.open(f)
|
||||||
r = max_dim / max(im.height, im.width) # ratio
|
r = max_dim / max(im.height, im.width) # ratio
|
||||||
if r < 1.0: # image too large
|
if r < 1.0: # image too large
|
||||||
im = im.resize((int(im.width * r), int(im.height * r)))
|
im = im.resize((int(im.width * r), int(im.height * r)))
|
||||||
im.save(f_new, 'JPEG', quality=75, optimize=True) # save
|
im.save(f_new, 'JPEG', quality=50, optimize=True) # save
|
||||||
except Exception as e: # use OpenCV
|
except Exception as e: # use OpenCV
|
||||||
print(f'WARNING: HUB ops PIL failure {f}: {e}')
|
print(f'WARNING: HUB ops PIL failure {f}: {e}')
|
||||||
im = cv2.imread(f)
|
im = cv2.imread(f)
|
||||||
@ -1028,69 +1016,111 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profil
|
|||||||
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
|
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
|
||||||
cv2.imwrite(str(f_new), im)
|
cv2.imwrite(str(f_new), im)
|
||||||
|
|
||||||
zipped, data_dir, yaml_path = _unzip(Path(path))
|
def get_json(self, save=False, verbose=False):
|
||||||
try:
|
# Return dataset JSON for Ultralytics HUB
|
||||||
with open(check_yaml(yaml_path), errors='ignore') as f:
|
def _round(labels):
|
||||||
data = yaml.safe_load(f) # data dict
|
# Update labels to integer class and 6 decimal place floats
|
||||||
if zipped:
|
return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
|
||||||
data['path'] = data_dir # TODO: should this be dir.resolve()?`
|
|
||||||
except Exception:
|
|
||||||
raise Exception("error/HUB/dataset_stats/yaml_load")
|
|
||||||
|
|
||||||
check_dataset(data, autodownload) # download dataset if missing
|
for split in 'train', 'val', 'test':
|
||||||
hub_dir = Path(data['path'] + ('-hub' if hub else ''))
|
if self.data.get(split) is None:
|
||||||
stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
|
self.stats[split] = None # i.e. no test set
|
||||||
for split in 'train', 'val', 'test':
|
continue
|
||||||
if data.get(split) is None:
|
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
|
||||||
stats[split] = None # i.e. no test set
|
x = np.array([
|
||||||
continue
|
np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
|
||||||
x = []
|
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')]) # shape(128x80)
|
||||||
dataset = LoadImagesAndLabels(data[split]) # load dataset
|
self.stats[split] = {
|
||||||
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
|
'instance_stats': {
|
||||||
x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
|
'total': int(x.sum()),
|
||||||
x = np.array(x) # shape(128x80)
|
'per_class': x.sum(0).tolist()},
|
||||||
stats[split] = {
|
'image_stats': {
|
||||||
'instance_stats': {
|
'total': dataset.n,
|
||||||
'total': int(x.sum()),
|
'unlabelled': int(np.all(x == 0, 1).sum()),
|
||||||
'per_class': x.sum(0).tolist()},
|
'per_class': (x > 0).sum(0).tolist()},
|
||||||
'image_stats': {
|
'labels': [{
|
||||||
'total': dataset.n,
|
str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
|
||||||
'unlabelled': int(np.all(x == 0, 1).sum()),
|
|
||||||
'per_class': (x > 0).sum(0).tolist()},
|
|
||||||
'labels': [{
|
|
||||||
str(Path(k).name): _round_labels(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
|
|
||||||
|
|
||||||
if hub:
|
# Save, print and return
|
||||||
im_dir = hub_dir / 'images'
|
if save:
|
||||||
im_dir.mkdir(parents=True, exist_ok=True)
|
stats_path = self.hub_dir / 'stats.json'
|
||||||
for _ in tqdm(ThreadPool(NUM_THREADS).imap(_hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
|
print(f'Saving {stats_path.resolve()}...')
|
||||||
|
with open(stats_path, 'w') as f:
|
||||||
|
json.dump(self.stats, f) # save stats.json
|
||||||
|
if verbose:
|
||||||
|
print(json.dumps(self.stats, indent=2, sort_keys=False))
|
||||||
|
return self.stats
|
||||||
|
|
||||||
|
def process_images(self):
|
||||||
|
# Compress images for Ultralytics HUB
|
||||||
|
for split in 'train', 'val', 'test':
|
||||||
|
if self.data.get(split) is None:
|
||||||
|
continue
|
||||||
|
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
|
||||||
|
desc = f'{split} images'
|
||||||
|
for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
|
||||||
pass
|
pass
|
||||||
|
print(f'Done. All images saved to {self.im_dir}')
|
||||||
|
return self.im_dir
|
||||||
|
|
||||||
# Profile
|
|
||||||
stats_path = hub_dir / 'stats.json'
|
|
||||||
if profile:
|
|
||||||
for _ in range(1):
|
|
||||||
file = stats_path.with_suffix('.npy')
|
|
||||||
t1 = time.time()
|
|
||||||
np.save(file, stats)
|
|
||||||
t2 = time.time()
|
|
||||||
x = np.load(file, allow_pickle=True)
|
|
||||||
print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
|
|
||||||
|
|
||||||
file = stats_path.with_suffix('.json')
|
# Classification dataloaders -------------------------------------------------------------------------------------------
|
||||||
t1 = time.time()
|
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||||
with open(file, 'w') as f:
|
"""
|
||||||
json.dump(stats, f) # save stats *.json
|
YOLOv5 Classification Dataset.
|
||||||
t2 = time.time()
|
Arguments
|
||||||
with open(file) as f:
|
root: Dataset path
|
||||||
x = json.load(f) # load hyps dict
|
transform: torchvision transforms, used by default
|
||||||
print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
|
album_transform: Albumentations transforms, used if installed
|
||||||
|
"""
|
||||||
|
|
||||||
# Save, print and return
|
def __init__(self, root, augment, imgsz, cache=False):
|
||||||
if hub:
|
super().__init__(root=root)
|
||||||
print(f'Saving {stats_path.resolve()}...')
|
self.torch_transforms = classify_transforms(imgsz)
|
||||||
with open(stats_path, 'w') as f:
|
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||||
json.dump(stats, f) # save stats.json
|
self.cache_ram = cache is True or cache == 'ram'
|
||||||
if verbose:
|
self.cache_disk = cache == 'disk'
|
||||||
print(json.dumps(stats, indent=2, sort_keys=False))
|
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||||
return stats
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||||
|
if self.cache_ram and im is None:
|
||||||
|
im = self.samples[i][3] = cv2.imread(f)
|
||||||
|
elif self.cache_disk:
|
||||||
|
if not fn.exists(): # load npy
|
||||||
|
np.save(fn.as_posix(), cv2.imread(f))
|
||||||
|
im = np.load(fn)
|
||||||
|
else: # read image
|
||||||
|
im = cv2.imread(f) # BGR
|
||||||
|
if self.album_transforms:
|
||||||
|
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
||||||
|
else:
|
||||||
|
sample = self.torch_transforms(im)
|
||||||
|
return sample, j
|
||||||
|
|
||||||
|
|
||||||
|
def create_classification_dataloader(path,
|
||||||
|
imgsz=224,
|
||||||
|
batch_size=16,
|
||||||
|
augment=True,
|
||||||
|
cache=False,
|
||||||
|
rank=-1,
|
||||||
|
workers=8,
|
||||||
|
shuffle=True):
|
||||||
|
# Returns Dataloader object to be used with YOLOv5 Classifier
|
||||||
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||||
|
dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
|
||||||
|
batch_size = min(batch_size, len(dataset))
|
||||||
|
nd = torch.cuda.device_count()
|
||||||
|
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
|
||||||
|
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator.manual_seed(0)
|
||||||
|
return InfiniteDataLoader(dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle and sampler is None,
|
||||||
|
num_workers=nw,
|
||||||
|
sampler=sampler,
|
||||||
|
pin_memory=PIN_MEMORY,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
generator=generator) # or DataLoader(persistent_workers=True)
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
# Image is CUDA-optimized for YOLOv5 single/multi-GPU training and inference
|
# Image is CUDA-optimized for YOLOv5 single/multi-GPU training and inference
|
||||||
|
|
||||||
# Start FROM NVIDIA PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
|
# Start FROM NVIDIA PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
|
||||||
FROM nvcr.io/nvidia/pytorch:22.06-py3
|
FROM nvcr.io/nvidia/pytorch:22.07-py3
|
||||||
RUN rm -rf /opt/pytorch # remove 1.2GB dir
|
RUN rm -rf /opt/pytorch # remove 1.2GB dir
|
||||||
|
|
||||||
# Downloads to user config dir
|
# Downloads to user config dir
|
||||||
@ -15,7 +15,7 @@ RUN apt update && apt install --no-install-recommends -y zip htop screen libgl1-
|
|||||||
# Install pip packages
|
# Install pip packages
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN python -m pip install --upgrade pip wheel
|
RUN python -m pip install --upgrade pip wheel
|
||||||
RUN pip uninstall -y Pillow torchtext # torch torchvision
|
RUN pip uninstall -y Pillow torchtext torch torchvision
|
||||||
RUN pip install --no-cache -r requirements.txt albumentations wandb gsutil notebook Pillow>=9.1.0 \
|
RUN pip install --no-cache -r requirements.txt albumentations wandb gsutil notebook Pillow>=9.1.0 \
|
||||||
'opencv-python<4.6.0.66' \
|
'opencv-python<4.6.0.66' \
|
||||||
--extra-index-url https://download.pytorch.org/whl/cu113
|
--extra-index-url https://download.pytorch.org/whl/cu113
|
||||||
@ -25,8 +25,8 @@ RUN mkdir -p /usr/src/app
|
|||||||
WORKDIR /usr/src/app
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
# Copy contents
|
# Copy contents
|
||||||
COPY . /usr/src/app
|
# COPY . /usr/src/app (issues as not a .git directory)
|
||||||
RUN git clone https://github.com/ultralytics/yolov5 /usr/src/yolov5
|
RUN git clone https://github.com/ultralytics/yolov5 /usr/src/app
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
ENV OMP_NUM_THREADS=8
|
ENV OMP_NUM_THREADS=8
|
||||||
@ -49,11 +49,8 @@ ENV OMP_NUM_THREADS=8
|
|||||||
# Kill all image-based
|
# Kill all image-based
|
||||||
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)
|
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)
|
||||||
|
|
||||||
# Bash into running container
|
# DockerHub tag update
|
||||||
# sudo docker exec -it 5a9b5863d93d bash
|
# t=ultralytics/yolov5:latest tnew=ultralytics/yolov5:v6.2 && sudo docker pull $t && sudo docker tag $t $tnew && sudo docker push $tnew
|
||||||
|
|
||||||
# Bash into stopped container
|
|
||||||
# id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash
|
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
# docker system prune -a --volumes
|
# docker system prune -a --volumes
|
||||||
|
@ -11,8 +11,7 @@ ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Aria
|
|||||||
# Install linux packages
|
# Install linux packages
|
||||||
RUN apt update
|
RUN apt update
|
||||||
RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt install -y tzdata
|
RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt install -y tzdata
|
||||||
RUN apt install --no-install-recommends -y python3-pip git zip curl htop gcc \
|
RUN apt install --no-install-recommends -y python3-pip git zip curl htop gcc libgl1-mesa-glx libglib2.0-0 libpython3-dev
|
||||||
libgl1-mesa-glx libglib2.0-0 libpython3.8-dev
|
|
||||||
# RUN alias python=python3
|
# RUN alias python=python3
|
||||||
|
|
||||||
# Install pip packages
|
# Install pip packages
|
||||||
@ -29,8 +28,8 @@ RUN mkdir -p /usr/src/app
|
|||||||
WORKDIR /usr/src/app
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
# Copy contents
|
# Copy contents
|
||||||
COPY . /usr/src/app
|
# COPY . /usr/src/app (issues as not a .git directory)
|
||||||
RUN git clone https://github.com/ultralytics/yolov5 /usr/src/yolov5
|
RUN git clone https://github.com/ultralytics/yolov5 /usr/src/app
|
||||||
|
|
||||||
|
|
||||||
# Usage Examples -------------------------------------------------------------------------------------------------------
|
# Usage Examples -------------------------------------------------------------------------------------------------------
|
||||||
|
@ -11,14 +11,15 @@ ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Aria
|
|||||||
# Install linux packages
|
# Install linux packages
|
||||||
RUN apt update
|
RUN apt update
|
||||||
RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt install -y tzdata
|
RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt install -y tzdata
|
||||||
RUN apt install --no-install-recommends -y python3-pip git zip curl htop libgl1-mesa-glx libglib2.0-0 libpython3.8-dev
|
RUN apt install --no-install-recommends -y python3-pip git zip curl htop libgl1-mesa-glx libglib2.0-0 libpython3-dev
|
||||||
# RUN alias python=python3
|
# RUN alias python=python3
|
||||||
|
|
||||||
# Install pip packages
|
# Install pip packages
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN python3 -m pip install --upgrade pip wheel
|
RUN python3 -m pip install --upgrade pip wheel
|
||||||
RUN pip install --no-cache -r requirements.txt albumentations gsutil notebook \
|
RUN pip install --no-cache -r requirements.txt albumentations gsutil notebook \
|
||||||
coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu tensorflowjs \
|
coremltools onnx onnx-simplifier onnxruntime tensorflow-cpu tensorflowjs \
|
||||||
|
# openvino-dev \
|
||||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
# Create working directory
|
# Create working directory
|
||||||
@ -26,8 +27,8 @@ RUN mkdir -p /usr/src/app
|
|||||||
WORKDIR /usr/src/app
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
# Copy contents
|
# Copy contents
|
||||||
COPY . /usr/src/app
|
# COPY . /usr/src/app (issues as not a .git directory)
|
||||||
RUN git clone https://github.com/ultralytics/yolov5 /usr/src/yolov5
|
RUN git clone https://github.com/ultralytics/yolov5 /usr/src/app
|
||||||
|
|
||||||
|
|
||||||
# Usage Examples -------------------------------------------------------------------------------------------------------
|
# Usage Examples -------------------------------------------------------------------------------------------------------
|
||||||
|
@ -16,12 +16,14 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def is_url(url):
|
def is_url(url, check_online=True):
|
||||||
# Check if online file exists
|
# Check if online file exists
|
||||||
try:
|
try:
|
||||||
r = urllib.request.urlopen(url) # response
|
url = str(url)
|
||||||
return r.getcode() == 200
|
result = urllib.parse.urlparse(url)
|
||||||
except urllib.request.HTTPError:
|
assert all([result.scheme, result.netloc, result.path]) # check if is url
|
||||||
|
return (urllib.request.urlopen(url).getcode() == 200) if check_online else True # check if exists online
|
||||||
|
except (AssertionError, urllib.request.HTTPError):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +33,12 @@ def gsutil_getsize(url=''):
|
|||||||
return eval(s.split(' ')[0]) if len(s) else 0 # bytes
|
return eval(s.split(' ')[0]) if len(s) else 0 # bytes
|
||||||
|
|
||||||
|
|
||||||
|
def url_getsize(url='https://ultralytics.com/images/bus.jpg'):
|
||||||
|
# Return downloadable file size in bytes
|
||||||
|
response = requests.head(url, allow_redirects=True)
|
||||||
|
return int(response.headers.get('content-length', -1))
|
||||||
|
|
||||||
|
|
||||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||||
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
||||||
from utils.general import LOGGER
|
from utils.general import LOGGER
|
||||||
@ -42,24 +50,26 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
|||||||
torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
|
torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
|
||||||
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
|
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
|
||||||
except Exception as e: # url2
|
except Exception as e: # url2
|
||||||
file.unlink(missing_ok=True) # remove partial downloads
|
if file.exists():
|
||||||
|
file.unlink() # remove partial downloads
|
||||||
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
||||||
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
||||||
finally:
|
finally:
|
||||||
if not file.exists() or file.stat().st_size < min_bytes: # check
|
if not file.exists() or file.stat().st_size < min_bytes: # check
|
||||||
file.unlink(missing_ok=True) # remove partial downloads
|
if file.exists():
|
||||||
|
file.unlink() # remove partial downloads
|
||||||
LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
|
LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
|
||||||
LOGGER.info('')
|
LOGGER.info('')
|
||||||
|
|
||||||
|
|
||||||
def attempt_download(file, repo='ultralytics/yolov5', release='v6.1'):
|
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
||||||
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.1', etc.
|
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
||||||
from utils.general import LOGGER
|
from utils.general import LOGGER
|
||||||
|
|
||||||
def github_assets(repository, version='latest'):
|
def github_assets(repository, version='latest'):
|
||||||
# Return GitHub repo tag (i.e. 'v6.1') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
|
# Return GitHub repo tag (i.e. 'v6.2') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
|
||||||
if version != 'latest':
|
if version != 'latest':
|
||||||
version = f'tags/{version}' # i.e. tags/v6.1
|
version = f'tags/{version}' # i.e. tags/v6.2
|
||||||
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
||||||
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
||||||
|
|
||||||
@ -110,8 +120,10 @@ def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
|
|||||||
file = Path(file)
|
file = Path(file)
|
||||||
cookie = Path('cookie') # gdrive cookie
|
cookie = Path('cookie') # gdrive cookie
|
||||||
print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
|
print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
|
||||||
file.unlink(missing_ok=True) # remove existing file
|
if file.exists():
|
||||||
cookie.unlink(missing_ok=True) # remove existing cookie
|
file.unlink() # remove existing file
|
||||||
|
if cookie.exists():
|
||||||
|
cookie.unlink() # remove existing cookie
|
||||||
|
|
||||||
# Attempt file download
|
# Attempt file download
|
||||||
out = "NUL" if platform.system() == "Windows" else "/dev/null"
|
out = "NUL" if platform.system() == "Windows" else "/dev/null"
|
||||||
@ -121,11 +133,13 @@ def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
|
|||||||
else: # small file
|
else: # small file
|
||||||
s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
|
s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
|
||||||
r = os.system(s) # execute, capture return
|
r = os.system(s) # execute, capture return
|
||||||
cookie.unlink(missing_ok=True) # remove existing cookie
|
if cookie.exists():
|
||||||
|
cookie.unlink() # remove existing cookie
|
||||||
|
|
||||||
# Error check
|
# Error check
|
||||||
if r != 0:
|
if r != 0:
|
||||||
file.unlink(missing_ok=True) # remove partial
|
if file.exists():
|
||||||
|
file.unlink() # remove partial
|
||||||
print('Download error ') # raise Exception('Download error')
|
print('Download error ') # raise Exception('Download error')
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
"""
|
"""
|
||||||
Run a Flask REST API exposing a YOLOv5s model
|
Run a Flask REST API exposing one or more YOLOv5s models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -11,12 +11,13 @@ from flask import Flask, request
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
models = {}
|
||||||
|
|
||||||
DETECTION_URL = "/v1/object-detection/yolov5s"
|
DETECTION_URL = "/v1/object-detection/<model>"
|
||||||
|
|
||||||
|
|
||||||
@app.route(DETECTION_URL, methods=["POST"])
|
@app.route(DETECTION_URL, methods=["POST"])
|
||||||
def predict():
|
def predict(model):
|
||||||
if request.method != "POST":
|
if request.method != "POST":
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -30,17 +31,18 @@ def predict():
|
|||||||
im_bytes = im_file.read()
|
im_bytes = im_file.read()
|
||||||
im = Image.open(io.BytesIO(im_bytes))
|
im = Image.open(io.BytesIO(im_bytes))
|
||||||
|
|
||||||
results = model(im, size=640) # reduce size=320 for faster inference
|
if model in models:
|
||||||
return results.pandas().xyxy[0].to_json(orient="records")
|
results = models[model](im, size=640) # reduce size=320 for faster inference
|
||||||
|
return results.pandas().xyxy[0].to_json(orient="records")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Flask API exposing YOLOv5 model")
|
parser = argparse.ArgumentParser(description="Flask API exposing YOLOv5 model")
|
||||||
parser.add_argument("--port", default=5000, type=int, help="port number")
|
parser.add_argument("--port", default=5000, type=int, help="port number")
|
||||||
|
parser.add_argument('--model', nargs='+', default=['yolov5s'], help='model(s) to run, i.e. --model yolov5n yolov5s')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
# Fix known issue urllib.error.HTTPError 403: rate limit exceeded https://github.com/ultralytics/yolov5/pull/7210
|
for m in opt.model:
|
||||||
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
|
models[m] = torch.hub.load("ultralytics/yolov5", m, force_reload=True, skip_validation=True)
|
||||||
|
|
||||||
model = torch.hub.load("ultralytics/yolov5", "yolov5s", force_reload=True) # force_reload to recache
|
|
||||||
app.run(host="0.0.0.0", port=opt.port) # debug=True causes Restarting with stat
|
app.run(host="0.0.0.0", port=opt.port) # debug=True causes Restarting with stat
|
||||||
|
240
utils/general.py
240
utils/general.py
@ -14,7 +14,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import sys
|
||||||
import time
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -33,6 +33,7 @@ import torch
|
|||||||
import torchvision
|
import torchvision
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from utils import TryExcept
|
||||||
from utils.downloads import gsutil_getsize
|
from utils.downloads import gsutil_getsize
|
||||||
from utils.metrics import box_iou, fitness
|
from utils.metrics import box_iou, fitness
|
||||||
|
|
||||||
@ -55,20 +56,42 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
|||||||
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
|
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
|
||||||
|
|
||||||
|
|
||||||
|
def is_ascii(s=''):
|
||||||
|
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
||||||
|
s = str(s) # convert list, tuple, None, etc. to str
|
||||||
|
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
||||||
|
|
||||||
|
|
||||||
|
def is_chinese(s='人工智能'):
|
||||||
|
# Is string composed of any Chinese characters?
|
||||||
|
return bool(re.search('[\u4e00-\u9fff]', str(s)))
|
||||||
|
|
||||||
|
|
||||||
|
def is_colab():
|
||||||
|
# Is environment a Google Colab instance?
|
||||||
|
return 'COLAB_GPU' in os.environ
|
||||||
|
|
||||||
|
|
||||||
def is_kaggle():
|
def is_kaggle():
|
||||||
# Is environment a Kaggle Notebook?
|
# Is environment a Kaggle Notebook?
|
||||||
try:
|
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
||||||
assert os.environ.get('PWD') == '/kaggle/working'
|
|
||||||
assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
|
||||||
|
def is_docker() -> bool:
|
||||||
|
"""Check if the process runs inside a docker container."""
|
||||||
|
if Path("/.dockerenv").exists():
|
||||||
return True
|
return True
|
||||||
except AssertionError:
|
try: # check if docker is in control groups
|
||||||
|
with open("/proc/self/cgroup") as file:
|
||||||
|
return any("docker" in line for line in file)
|
||||||
|
except OSError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_writeable(dir, test=False):
|
def is_writeable(dir, test=False):
|
||||||
# Return True if directory has write permissions, test opening a file with write permissions if test=True
|
# Return True if directory has write permissions, test opening a file with write permissions if test=True
|
||||||
if not test:
|
if not test:
|
||||||
return os.access(dir, os.R_OK) # possible issues on Windows
|
return os.access(dir, os.W_OK) # possible issues on Windows
|
||||||
file = Path(dir) / 'tmp.txt'
|
file = Path(dir) / 'tmp.txt'
|
||||||
try:
|
try:
|
||||||
with open(file, 'w'): # open file with write permissions
|
with open(file, 'w'): # open file with write permissions
|
||||||
@ -81,7 +104,7 @@ def is_writeable(dir, test=False):
|
|||||||
|
|
||||||
def set_logging(name=None, verbose=VERBOSE):
|
def set_logging(name=None, verbose=VERBOSE):
|
||||||
# Sets level and returns logger
|
# Sets level and returns logger
|
||||||
if is_kaggle():
|
if is_kaggle() or is_colab():
|
||||||
for h in logging.root.handlers:
|
for h in logging.root.handlers:
|
||||||
logging.root.removeHandler(h) # remove all handlers associated with the root logger object
|
logging.root.removeHandler(h) # remove all handlers associated with the root logger object
|
||||||
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
||||||
@ -96,6 +119,9 @@ def set_logging(name=None, verbose=VERBOSE):
|
|||||||
|
|
||||||
set_logging() # run before defining LOGGER
|
set_logging() # run before defining LOGGER
|
||||||
LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.)
|
LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.)
|
||||||
|
if platform.system() == 'Windows':
|
||||||
|
for fn in LOGGER.info, LOGGER.warning:
|
||||||
|
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
|
||||||
|
|
||||||
|
|
||||||
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
|
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
|
||||||
@ -115,16 +141,27 @@ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
|
|||||||
|
|
||||||
|
|
||||||
class Profile(contextlib.ContextDecorator):
|
class Profile(contextlib.ContextDecorator):
|
||||||
# Usage: @Profile() decorator or 'with Profile():' context manager
|
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
||||||
|
def __init__(self, t=0.0):
|
||||||
|
self.t = t
|
||||||
|
self.cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.start = time.time()
|
self.start = self.time()
|
||||||
|
return self
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
print(f'Profile results: {time.time() - self.start:.5f}s')
|
self.dt = self.time() - self.start # delta-time
|
||||||
|
self.t += self.dt # accumulate dt
|
||||||
|
|
||||||
|
def time(self):
|
||||||
|
if self.cuda:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
return time.time()
|
||||||
|
|
||||||
|
|
||||||
class Timeout(contextlib.ContextDecorator):
|
class Timeout(contextlib.ContextDecorator):
|
||||||
# Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
# YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
||||||
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
|
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
|
||||||
self.seconds = int(seconds)
|
self.seconds = int(seconds)
|
||||||
self.timeout_message = timeout_msg
|
self.timeout_message = timeout_msg
|
||||||
@ -158,64 +195,50 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
|||||||
os.chdir(self.cwd)
|
os.chdir(self.cwd)
|
||||||
|
|
||||||
|
|
||||||
def try_except(func):
|
|
||||||
# try-except function. Usage: @try_except decorator
|
|
||||||
def handler(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
func(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
return handler
|
|
||||||
|
|
||||||
|
|
||||||
def threaded(func):
|
|
||||||
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
|
||||||
thread.start()
|
|
||||||
return thread
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def methods(instance):
|
def methods(instance):
|
||||||
# Get class/instance methods
|
# Get class/instance methods
|
||||||
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
||||||
|
|
||||||
|
|
||||||
def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
|
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
||||||
# Print function arguments (optional args dict)
|
# Print function arguments (optional args dict)
|
||||||
x = inspect.currentframe().f_back # previous frame
|
x = inspect.currentframe().f_back # previous frame
|
||||||
file, _, fcn, _, _ = inspect.getframeinfo(x)
|
file, _, func, _, _ = inspect.getframeinfo(x)
|
||||||
if args is None: # get args automatically
|
if args is None: # get args automatically
|
||||||
args, _, _, frm = inspect.getargvalues(x)
|
args, _, _, frm = inspect.getargvalues(x)
|
||||||
args = {k: v for k, v in frm.items() if k in args}
|
args = {k: v for k, v in frm.items() if k in args}
|
||||||
s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
|
try:
|
||||||
|
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
|
||||||
|
except ValueError:
|
||||||
|
file = Path(file).stem
|
||||||
|
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
|
||||||
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
||||||
|
|
||||||
|
|
||||||
def init_seeds(seed=0, deterministic=False):
|
def init_seeds(seed=0, deterministic=False):
|
||||||
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
||||||
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
|
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
|
|
||||||
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
|
||||||
torch.use_deterministic_algorithms(True)
|
|
||||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
|
||||||
# os.environ['PYTHONHASHSEED'] = str(seed)
|
|
||||||
|
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
|
torch.cuda.manual_seed(seed)
|
||||||
# torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
||||||
# torch.cuda.manual_seed_all(seed) # for multi GPU, exception safe
|
torch.backends.cudnn.benchmark = True # for faster training
|
||||||
|
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
||||||
|
torch.use_deterministic_algorithms(True)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
|
||||||
|
|
||||||
def intersect_dicts(da, db, exclude=()):
|
def intersect_dicts(da, db, exclude=()):
|
||||||
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
||||||
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
|
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_args(func):
|
||||||
|
# Get func() default arguments
|
||||||
|
signature = inspect.signature(func)
|
||||||
|
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
|
||||||
|
|
||||||
|
|
||||||
def get_latest_run(search_dir='.'):
|
def get_latest_run(search_dir='.'):
|
||||||
@ -224,36 +247,6 @@ def get_latest_run(search_dir='.'):
|
|||||||
return max(last_list, key=os.path.getctime) if last_list else ''
|
return max(last_list, key=os.path.getctime) if last_list else ''
|
||||||
|
|
||||||
|
|
||||||
def is_docker():
|
|
||||||
# Is environment a Docker container?
|
|
||||||
return Path('/workspace').exists() # or Path('/.dockerenv').exists()
|
|
||||||
|
|
||||||
|
|
||||||
def is_colab():
|
|
||||||
# Is environment a Google Colab instance?
|
|
||||||
try:
|
|
||||||
import google.colab
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_pip():
|
|
||||||
# Is file in a pip package?
|
|
||||||
return 'site-packages' in Path(__file__).resolve().parts
|
|
||||||
|
|
||||||
|
|
||||||
def is_ascii(s=''):
|
|
||||||
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
|
||||||
s = str(s) # convert list, tuple, None, etc. to str
|
|
||||||
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
|
||||||
|
|
||||||
|
|
||||||
def is_chinese(s='人工智能'):
|
|
||||||
# Is string composed of any Chinese characters?
|
|
||||||
return bool(re.search('[\u4e00-\u9fff]', str(s)))
|
|
||||||
|
|
||||||
|
|
||||||
def emojis(str=''):
|
def emojis(str=''):
|
||||||
# Return platform-dependent emoji-safe version of string
|
# Return platform-dependent emoji-safe version of string
|
||||||
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
|
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
|
||||||
@ -302,25 +295,32 @@ def git_describe(path=ROOT): # path must be a directory
|
|||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
@try_except
|
@TryExcept()
|
||||||
@WorkingDirectory(ROOT)
|
@WorkingDirectory(ROOT)
|
||||||
def check_git_status():
|
def check_git_status(repo='ultralytics/yolov5', branch='master'):
|
||||||
# Recommend 'git pull' if code is out of date
|
# YOLOv5 status check, recommend 'git pull' if code is out of date
|
||||||
msg = ', for updates see https://github.com/ultralytics/yolov5'
|
url = f'https://github.com/{repo}'
|
||||||
|
msg = f', for updates see {url}'
|
||||||
s = colorstr('github: ') # string
|
s = colorstr('github: ') # string
|
||||||
assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
|
assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
|
||||||
assert not is_docker(), s + 'skipping check (Docker image)' + msg
|
|
||||||
assert check_online(), s + 'skipping check (offline)' + msg
|
assert check_online(), s + 'skipping check (offline)' + msg
|
||||||
|
|
||||||
cmd = 'git fetch && git config --get remote.origin.url'
|
splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
|
||||||
url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
|
matches = [repo in s for s in splits]
|
||||||
branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
|
if any(matches):
|
||||||
n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
|
remote = splits[matches.index(True) - 1]
|
||||||
|
else:
|
||||||
|
remote = 'ultralytics'
|
||||||
|
check_output(f'git remote add {remote} {url}', shell=True)
|
||||||
|
check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
|
||||||
|
local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
|
||||||
|
n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
|
||||||
if n > 0:
|
if n > 0:
|
||||||
s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
|
pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
|
||||||
|
s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
|
||||||
else:
|
else:
|
||||||
s += f'up to date with {url} ✅'
|
s += f'up to date with {url} ✅'
|
||||||
LOGGER.info(emojis(s)) # emoji-safe
|
LOGGER.info(s)
|
||||||
|
|
||||||
|
|
||||||
def check_python(minimum='3.7.0'):
|
def check_python(minimum='3.7.0'):
|
||||||
@ -332,17 +332,17 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
|
|||||||
# Check version vs. required version
|
# Check version vs. required version
|
||||||
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
||||||
result = (current == minimum) if pinned else (current >= minimum) # bool
|
result = (current == minimum) if pinned else (current >= minimum) # bool
|
||||||
s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
|
s = f'WARNING: ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed' # string
|
||||||
if hard:
|
if hard:
|
||||||
assert result, s # assert min requirements met
|
assert result, emojis(s) # assert min requirements met
|
||||||
if verbose and not result:
|
if verbose and not result:
|
||||||
LOGGER.warning(s)
|
LOGGER.warning(s)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@try_except
|
@TryExcept()
|
||||||
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
||||||
# Check installed dependencies meet requirements (pass *.txt file or list of packages)
|
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
|
||||||
prefix = colorstr('red', 'bold', 'requirements:')
|
prefix = colorstr('red', 'bold', 'requirements:')
|
||||||
check_python() # check python version
|
check_python() # check python version
|
||||||
if isinstance(requirements, (str, Path)): # requirements.txt file
|
if isinstance(requirements, (str, Path)): # requirements.txt file
|
||||||
@ -374,7 +374,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
|
|||||||
source = file.resolve() if 'file' in locals() else requirements
|
source = file.resolve() if 'file' in locals() else requirements
|
||||||
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
|
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
|
||||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
||||||
LOGGER.info(emojis(s))
|
LOGGER.info(s)
|
||||||
|
|
||||||
|
|
||||||
def check_img_size(imgsz, s=32, floor=0):
|
def check_img_size(imgsz, s=32, floor=0):
|
||||||
@ -436,6 +436,9 @@ def check_file(file, suffix=''):
|
|||||||
torch.hub.download_url_to_file(url, file)
|
torch.hub.download_url_to_file(url, file)
|
||||||
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
|
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
|
||||||
return file
|
return file
|
||||||
|
elif file.startswith('clearml://'): # ClearML Dataset ID
|
||||||
|
assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
|
||||||
|
return file
|
||||||
else: # search
|
else: # search
|
||||||
files = []
|
files = []
|
||||||
for d in 'data', 'models', 'utils': # search directories
|
for d in 'data', 'models', 'utils': # search directories
|
||||||
@ -450,7 +453,7 @@ def check_font(font=FONT, progress=False):
|
|||||||
font = Path(font)
|
font = Path(font)
|
||||||
file = CONFIG_DIR / font.name
|
file = CONFIG_DIR / font.name
|
||||||
if not font.exists() and not file.exists():
|
if not font.exists() and not file.exists():
|
||||||
url = "https://ultralytics.com/assets/" + font.name
|
url = f'https://ultralytics.com/assets/{font.name}'
|
||||||
LOGGER.info(f'Downloading {url} to {file}...')
|
LOGGER.info(f'Downloading {url} to {file}...')
|
||||||
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
||||||
|
|
||||||
@ -461,7 +464,7 @@ def check_dataset(data, autodownload=True):
|
|||||||
# Download (optional)
|
# Download (optional)
|
||||||
extract_dir = ''
|
extract_dir = ''
|
||||||
if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
|
if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
|
||||||
download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
|
download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
|
||||||
data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
|
data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
|
||||||
extract_dir, autodownload = data.parent, False
|
extract_dir, autodownload = data.parent, False
|
||||||
|
|
||||||
@ -471,11 +474,11 @@ def check_dataset(data, autodownload=True):
|
|||||||
data = yaml.safe_load(f) # dictionary
|
data = yaml.safe_load(f) # dictionary
|
||||||
|
|
||||||
# Checks
|
# Checks
|
||||||
for k in 'train', 'val', 'nc':
|
for k in 'train', 'val', 'names':
|
||||||
assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
|
assert k in data, f"data.yaml '{k}:' field missing ❌"
|
||||||
if 'names' not in data:
|
if isinstance(data['names'], (list, tuple)): # old array format
|
||||||
LOGGER.warning(emojis("data.yaml 'names:' field missing ⚠, assigning default names 'class0', 'class1', etc."))
|
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
||||||
data['names'] = [f'class{i}' for i in range(data['nc'])] # default names
|
data['nc'] = len(data['names'])
|
||||||
|
|
||||||
# Resolve paths
|
# Resolve paths
|
||||||
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
|
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
|
||||||
@ -490,9 +493,9 @@ def check_dataset(data, autodownload=True):
|
|||||||
if val:
|
if val:
|
||||||
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
||||||
if not all(x.exists() for x in val):
|
if not all(x.exists() for x in val):
|
||||||
LOGGER.info(emojis('\nDataset not found ⚠, missing paths %s' % [str(x) for x in val if not x.exists()]))
|
LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
|
||||||
if not s or not autodownload:
|
if not s or not autodownload:
|
||||||
raise Exception(emojis('Dataset not found ❌'))
|
raise Exception('Dataset not found ❌')
|
||||||
t = time.time()
|
t = time.time()
|
||||||
root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
|
root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
|
||||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||||
@ -510,7 +513,7 @@ def check_dataset(data, autodownload=True):
|
|||||||
r = exec(s, {'yaml': data}) # return None
|
r = exec(s, {'yaml': data}) # return None
|
||||||
dt = f'({round(time.time() - t, 1)}s)'
|
dt = f'({round(time.time() - t, 1)}s)'
|
||||||
s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
|
s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
|
||||||
LOGGER.info(emojis(f"Dataset download {s}"))
|
LOGGER.info(f"Dataset download {s}")
|
||||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
|
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
|
||||||
return data # dictionary
|
return data # dictionary
|
||||||
|
|
||||||
@ -529,20 +532,32 @@ def check_amp(model):
|
|||||||
|
|
||||||
prefix = colorstr('AMP: ')
|
prefix = colorstr('AMP: ')
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
if device.type == 'cpu':
|
if device.type in ('cpu', 'mps'):
|
||||||
return False # AMP disabled on CPU
|
return False # AMP only used on CUDA devices
|
||||||
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
|
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
|
||||||
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
|
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
|
||||||
try:
|
try:
|
||||||
assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
|
assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
|
||||||
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
|
LOGGER.info(f'{prefix}checks passed ✅')
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
|
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
|
||||||
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
|
LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_load(file='data.yaml'):
|
||||||
|
# Single-line safe yaml loading
|
||||||
|
with open(file, errors='ignore') as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_save(file='data.yaml', data={}):
|
||||||
|
# Single-line safe yaml saving
|
||||||
|
with open(file, 'w') as f:
|
||||||
|
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
def url2file(url):
|
def url2file(url):
|
||||||
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
||||||
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
||||||
@ -550,7 +565,7 @@ def url2file(url):
|
|||||||
|
|
||||||
|
|
||||||
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
||||||
# Multi-threaded file download and unzip function, used in data.yaml for autodownload
|
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
||||||
def download_one(url, dir):
|
def download_one(url, dir):
|
||||||
# Download 1 file
|
# Download 1 file
|
||||||
success = True
|
success = True
|
||||||
@ -562,7 +577,8 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
|
|||||||
for i in range(retry + 1):
|
for i in range(retry + 1):
|
||||||
if curl:
|
if curl:
|
||||||
s = 'sS' if threads > 1 else '' # silent
|
s = 'sS' if threads > 1 else '' # silent
|
||||||
r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
|
r = os.system(
|
||||||
|
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
|
||||||
success = r == 0
|
success = r == 0
|
||||||
else:
|
else:
|
||||||
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
||||||
@ -574,10 +590,12 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
|
|||||||
else:
|
else:
|
||||||
LOGGER.warning(f'Failed to download {url}...')
|
LOGGER.warning(f'Failed to download {url}...')
|
||||||
|
|
||||||
if unzip and success and f.suffix in ('.zip', '.gz'):
|
if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
|
||||||
LOGGER.info(f'Unzipping {f}...')
|
LOGGER.info(f'Unzipping {f}...')
|
||||||
if f.suffix == '.zip':
|
if f.suffix == '.zip':
|
||||||
ZipFile(f).extractall(path=dir) # unzip
|
ZipFile(f).extractall(path=dir) # unzip
|
||||||
|
elif f.suffix == '.tar':
|
||||||
|
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
||||||
elif f.suffix == '.gz':
|
elif f.suffix == '.gz':
|
||||||
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
||||||
if delete:
|
if delete:
|
||||||
@ -587,7 +605,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
|
|||||||
dir.mkdir(parents=True, exist_ok=True) # make directory
|
dir.mkdir(parents=True, exist_ok=True) # make directory
|
||||||
if threads > 1:
|
if threads > 1:
|
||||||
pool = ThreadPool(threads)
|
pool = ThreadPool(threads)
|
||||||
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
|
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
||||||
pool.close()
|
pool.close()
|
||||||
pool.join()
|
pool.join()
|
||||||
else:
|
else:
|
||||||
|
@ -5,17 +5,19 @@ Logging utils
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pkg_resources as pkg
|
import pkg_resources as pkg
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from utils.general import colorstr, cv2, emojis
|
from utils.general import colorstr, cv2
|
||||||
|
from utils.loggers.clearml.clearml_utils import ClearmlLogger
|
||||||
from utils.loggers.wandb.wandb_utils import WandbLogger
|
from utils.loggers.wandb.wandb_utils import WandbLogger
|
||||||
from utils.plots import plot_images, plot_results
|
from utils.plots import plot_images, plot_labels, plot_results
|
||||||
from utils.torch_utils import de_parallel
|
from utils.torch_utils import de_parallel
|
||||||
|
|
||||||
LOGGERS = ('csv', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases
|
LOGGERS = ('csv', 'tb', 'wandb', 'clearml') # *.csv, TensorBoard, Weights & Biases, ClearML
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -32,6 +34,13 @@ try:
|
|||||||
except (ImportError, AssertionError):
|
except (ImportError, AssertionError):
|
||||||
wandb = None
|
wandb = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import clearml
|
||||||
|
|
||||||
|
assert hasattr(clearml, '__version__') # verify package import not local dir
|
||||||
|
except (ImportError, AssertionError):
|
||||||
|
clearml = None
|
||||||
|
|
||||||
|
|
||||||
class Loggers():
|
class Loggers():
|
||||||
# YOLOv5 Loggers class
|
# YOLOv5 Loggers class
|
||||||
@ -40,6 +49,7 @@ class Loggers():
|
|||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.hyp = hyp
|
self.hyp = hyp
|
||||||
|
self.plots = not opt.noplots # plot results
|
||||||
self.logger = logger # for printing results to console
|
self.logger = logger # for printing results to console
|
||||||
self.include = include
|
self.include = include
|
||||||
self.keys = [
|
self.keys = [
|
||||||
@ -61,11 +71,15 @@ class Loggers():
|
|||||||
setattr(self, k, None) # init empty logger dictionary
|
setattr(self, k, None) # init empty logger dictionary
|
||||||
self.csv = True # always log to csv
|
self.csv = True # always log to csv
|
||||||
|
|
||||||
# Message
|
# Messages
|
||||||
if not wandb:
|
if not wandb:
|
||||||
prefix = colorstr('Weights & Biases: ')
|
prefix = colorstr('Weights & Biases: ')
|
||||||
s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)"
|
s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs in Weights & Biases"
|
||||||
self.logger.info(emojis(s))
|
self.logger.info(s)
|
||||||
|
if not clearml:
|
||||||
|
prefix = colorstr('ClearML: ')
|
||||||
|
s = f"{prefix}run 'pip install clearml' to automatically track, visualize and remotely train YOLOv5 🚀 in ClearML"
|
||||||
|
self.logger.info(s)
|
||||||
|
|
||||||
# TensorBoard
|
# TensorBoard
|
||||||
s = self.save_dir
|
s = self.save_dir
|
||||||
@ -82,36 +96,57 @@ class Loggers():
|
|||||||
self.wandb = WandbLogger(self.opt, run_id)
|
self.wandb = WandbLogger(self.opt, run_id)
|
||||||
# temp warn. because nested artifacts not supported after 0.12.10
|
# temp warn. because nested artifacts not supported after 0.12.10
|
||||||
if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.11'):
|
if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.11'):
|
||||||
self.logger.warning(
|
s = "YOLOv5 temporarily requires wandb version 0.12.10 or below. Some features may not work as expected."
|
||||||
"YOLOv5 temporarily requires wandb version 0.12.10 or below. Some features may not work as expected."
|
self.logger.warning(s)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.wandb = None
|
self.wandb = None
|
||||||
|
|
||||||
|
# ClearML
|
||||||
|
if clearml and 'clearml' in self.include:
|
||||||
|
self.clearml = ClearmlLogger(self.opt, self.hyp)
|
||||||
|
else:
|
||||||
|
self.clearml = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remote_dataset(self):
|
||||||
|
# Get data_dict if custom dataset artifact link is provided
|
||||||
|
data_dict = None
|
||||||
|
if self.clearml:
|
||||||
|
data_dict = self.clearml.data_dict
|
||||||
|
if self.wandb:
|
||||||
|
data_dict = self.wandb.data_dict
|
||||||
|
|
||||||
|
return data_dict
|
||||||
|
|
||||||
def on_train_start(self):
|
def on_train_start(self):
|
||||||
# Callback runs on train start
|
# Callback runs on train start
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_pretrain_routine_end(self):
|
def on_pretrain_routine_end(self, labels, names):
|
||||||
# Callback runs on pre-train routine end
|
# Callback runs on pre-train routine end
|
||||||
paths = self.save_dir.glob('*labels*.jpg') # training labels
|
if self.plots:
|
||||||
if self.wandb:
|
plot_labels(labels, names, self.save_dir)
|
||||||
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
|
paths = self.save_dir.glob('*labels*.jpg') # training labels
|
||||||
|
if self.wandb:
|
||||||
|
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
|
||||||
|
# if self.clearml:
|
||||||
|
# pass # ClearML saves these images automatically using hooks
|
||||||
|
|
||||||
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
|
def on_train_batch_end(self, model, ni, imgs, targets, paths):
|
||||||
# Callback runs on train batch end
|
# Callback runs on train batch end
|
||||||
if plots:
|
# ni: number integrated batches (since train start)
|
||||||
if ni == 0:
|
if self.plots:
|
||||||
if not self.opt.sync_bn: # --sync known issue https://github.com/ultralytics/yolov5/issues/3754
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
|
||||||
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
|
||||||
if ni < 3:
|
if ni < 3:
|
||||||
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
||||||
plot_images(imgs, targets, paths, f)
|
plot_images(imgs, targets, paths, f)
|
||||||
if self.wandb and ni == 10:
|
if ni == 0 and self.tb and not self.opt.sync_bn:
|
||||||
|
log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
|
||||||
|
if ni == 10 and (self.wandb or self.clearml):
|
||||||
files = sorted(self.save_dir.glob('train*.jpg'))
|
files = sorted(self.save_dir.glob('train*.jpg'))
|
||||||
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
|
if self.wandb:
|
||||||
|
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
|
||||||
|
if self.clearml:
|
||||||
|
self.clearml.log_debug_samples(files, title='Mosaics')
|
||||||
|
|
||||||
def on_train_epoch_end(self, epoch):
|
def on_train_epoch_end(self, epoch):
|
||||||
# Callback runs on train epoch end
|
# Callback runs on train epoch end
|
||||||
@ -122,12 +157,17 @@ class Loggers():
|
|||||||
# Callback runs on val image end
|
# Callback runs on val image end
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
self.wandb.val_one_image(pred, predn, path, names, im)
|
self.wandb.val_one_image(pred, predn, path, names, im)
|
||||||
|
if self.clearml:
|
||||||
|
self.clearml.log_image_with_boxes(path, pred, names, im)
|
||||||
|
|
||||||
def on_val_end(self):
|
def on_val_end(self):
|
||||||
# Callback runs on val end
|
# Callback runs on val end
|
||||||
if self.wandb:
|
if self.wandb or self.clearml:
|
||||||
files = sorted(self.save_dir.glob('val*.jpg'))
|
files = sorted(self.save_dir.glob('val*.jpg'))
|
||||||
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
|
if self.wandb:
|
||||||
|
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
|
||||||
|
if self.clearml:
|
||||||
|
self.clearml.log_debug_samples(files, title='Validation')
|
||||||
|
|
||||||
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
|
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
|
||||||
# Callback runs at the end of each fit (train+val) epoch
|
# Callback runs at the end of each fit (train+val) epoch
|
||||||
@ -142,6 +182,10 @@ class Loggers():
|
|||||||
if self.tb:
|
if self.tb:
|
||||||
for k, v in x.items():
|
for k, v in x.items():
|
||||||
self.tb.add_scalar(k, v, epoch)
|
self.tb.add_scalar(k, v, epoch)
|
||||||
|
elif self.clearml: # log to ClearML if TensorBoard not used
|
||||||
|
for k, v in x.items():
|
||||||
|
title, series = k.split('/')
|
||||||
|
self.clearml.task.get_logger().report_scalar(title, series, v, epoch)
|
||||||
|
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
if best_fitness == fi:
|
if best_fitness == fi:
|
||||||
@ -151,21 +195,29 @@ class Loggers():
|
|||||||
self.wandb.log(x)
|
self.wandb.log(x)
|
||||||
self.wandb.end_epoch(best_result=best_fitness == fi)
|
self.wandb.end_epoch(best_result=best_fitness == fi)
|
||||||
|
|
||||||
|
if self.clearml:
|
||||||
|
self.clearml.current_epoch_logged_images = set() # reset epoch image limit
|
||||||
|
self.clearml.current_epoch += 1
|
||||||
|
|
||||||
def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
|
def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
|
||||||
# Callback runs on model save event
|
# Callback runs on model save event
|
||||||
if self.wandb:
|
if (epoch + 1) % self.opt.save_period == 0 and not final_epoch and self.opt.save_period != -1:
|
||||||
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
|
if self.wandb:
|
||||||
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
|
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
|
||||||
|
if self.clearml:
|
||||||
|
self.clearml.task.update_output_model(model_path=str(last),
|
||||||
|
model_name='Latest Model',
|
||||||
|
auto_delete_file=False)
|
||||||
|
|
||||||
def on_train_end(self, last, best, plots, epoch, results):
|
def on_train_end(self, last, best, epoch, results):
|
||||||
# Callback runs on training end
|
# Callback runs on training end, i.e. saving best model
|
||||||
if plots:
|
if self.plots:
|
||||||
plot_results(file=self.save_dir / 'results.csv') # save results.png
|
plot_results(file=self.save_dir / 'results.csv') # save results.png
|
||||||
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||||
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
|
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
|
||||||
self.logger.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
self.logger.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||||
|
|
||||||
if self.tb:
|
if self.tb and not self.clearml: # These images are already captured by ClearML by now, we don't want doubles
|
||||||
for f in files:
|
for f in files:
|
||||||
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
|
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
|
||||||
|
|
||||||
@ -180,8 +232,106 @@ class Loggers():
|
|||||||
aliases=['latest', 'best', 'stripped'])
|
aliases=['latest', 'best', 'stripped'])
|
||||||
self.wandb.finish_run()
|
self.wandb.finish_run()
|
||||||
|
|
||||||
def on_params_update(self, params):
|
if self.clearml and not self.opt.evolve:
|
||||||
|
self.clearml.task.update_output_model(model_path=str(best if best.exists() else last), name='Best Model')
|
||||||
|
|
||||||
|
def on_params_update(self, params: dict):
|
||||||
# Update hyperparams or configs of the experiment
|
# Update hyperparams or configs of the experiment
|
||||||
# params: A dict containing {param: value} pairs
|
|
||||||
if self.wandb:
|
if self.wandb:
|
||||||
self.wandb.wandb_run.config.update(params, allow_val_change=True)
|
self.wandb.wandb_run.config.update(params, allow_val_change=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GenericLogger:
|
||||||
|
"""
|
||||||
|
YOLOv5 General purpose logger for non-task specific logging
|
||||||
|
Usage: from utils.loggers import GenericLogger; logger = GenericLogger(...)
|
||||||
|
Arguments
|
||||||
|
opt: Run arguments
|
||||||
|
console_logger: Console logger
|
||||||
|
include: loggers to include
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, opt, console_logger, include=('tb', 'wandb')):
|
||||||
|
# init default loggers
|
||||||
|
self.save_dir = Path(opt.save_dir)
|
||||||
|
self.include = include
|
||||||
|
self.console_logger = console_logger
|
||||||
|
self.csv = self.save_dir / 'results.csv' # CSV logger
|
||||||
|
if 'tb' in self.include:
|
||||||
|
prefix = colorstr('TensorBoard: ')
|
||||||
|
self.console_logger.info(
|
||||||
|
f"{prefix}Start with 'tensorboard --logdir {self.save_dir.parent}', view at http://localhost:6006/")
|
||||||
|
self.tb = SummaryWriter(str(self.save_dir))
|
||||||
|
|
||||||
|
if wandb and 'wandb' in self.include:
|
||||||
|
self.wandb = wandb.init(project=web_project_name(str(opt.project)),
|
||||||
|
name=None if opt.name == "exp" else opt.name,
|
||||||
|
config=opt)
|
||||||
|
else:
|
||||||
|
self.wandb = None
|
||||||
|
|
||||||
|
def log_metrics(self, metrics, epoch):
|
||||||
|
# Log metrics dictionary to all loggers
|
||||||
|
if self.csv:
|
||||||
|
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||||
|
n = len(metrics) + 1 # number of cols
|
||||||
|
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
||||||
|
with open(self.csv, 'a') as f:
|
||||||
|
f.write(s + ('%23.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
|
||||||
|
|
||||||
|
if self.tb:
|
||||||
|
for k, v in metrics.items():
|
||||||
|
self.tb.add_scalar(k, v, epoch)
|
||||||
|
|
||||||
|
if self.wandb:
|
||||||
|
self.wandb.log(metrics, step=epoch)
|
||||||
|
|
||||||
|
def log_images(self, files, name='Images', epoch=0):
|
||||||
|
# Log images to all loggers
|
||||||
|
files = [Path(f) for f in (files if isinstance(files, (tuple, list)) else [files])] # to Path
|
||||||
|
files = [f for f in files if f.exists()] # filter by exists
|
||||||
|
|
||||||
|
if self.tb:
|
||||||
|
for f in files:
|
||||||
|
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
|
||||||
|
|
||||||
|
if self.wandb:
|
||||||
|
self.wandb.log({name: [wandb.Image(str(f), caption=f.name) for f in files]}, step=epoch)
|
||||||
|
|
||||||
|
def log_graph(self, model, imgsz=(640, 640)):
|
||||||
|
# Log model graph to all loggers
|
||||||
|
if self.tb:
|
||||||
|
log_tensorboard_graph(self.tb, model, imgsz)
|
||||||
|
|
||||||
|
def log_model(self, model_path, epoch=0, metadata={}):
|
||||||
|
# Log model to all loggers
|
||||||
|
if self.wandb:
|
||||||
|
art = wandb.Artifact(name=f"run_{wandb.run.id}_model", type="model", metadata=metadata)
|
||||||
|
art.add_file(str(model_path))
|
||||||
|
wandb.log_artifact(art)
|
||||||
|
|
||||||
|
def update_params(self, params):
|
||||||
|
# Update the paramters logged
|
||||||
|
if self.wandb:
|
||||||
|
wandb.run.config.update(params, allow_val_change=True)
|
||||||
|
|
||||||
|
|
||||||
|
def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
|
||||||
|
# Log model graph to TensorBoard
|
||||||
|
try:
|
||||||
|
p = next(model.parameters()) # for device, type
|
||||||
|
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand
|
||||||
|
im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image (WARNING: must be zeros, not empty)
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||||
|
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
|
||||||
|
except Exception as e:
|
||||||
|
print(f'WARNING: TensorBoard graph visualization failure {e}')
|
||||||
|
|
||||||
|
|
||||||
|
def web_project_name(project):
|
||||||
|
# Convert local project name to web project name
|
||||||
|
if not project.startswith('runs/train'):
|
||||||
|
return project
|
||||||
|
suffix = '-Classify' if project.endswith('-cls') else '-Segment' if project.endswith('-seg') else ''
|
||||||
|
return f'YOLOv5{suffix}'
|
||||||
|
222
utils/loggers/clearml/README.md
Normal file
222
utils/loggers/clearml/README.md
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
# ClearML Integration
|
||||||
|
|
||||||
|
<img align="center" src="https://github.com/thepycoder/clearml_screenshots/raw/main/logos_dark.png#gh-light-mode-only" alt="Clear|ML"><img align="center" src="https://github.com/thepycoder/clearml_screenshots/raw/main/logos_light.png#gh-dark-mode-only" alt="Clear|ML">
|
||||||
|
|
||||||
|
## About ClearML
|
||||||
|
|
||||||
|
[ClearML](https://cutt.ly/yolov5-tutorial-clearml) is an [open-source](https://github.com/allegroai/clearml) toolbox designed to save you time ⏱️.
|
||||||
|
|
||||||
|
🔨 Track every YOLOv5 training run in the <b>experiment manager</b>
|
||||||
|
|
||||||
|
🔧 Version and easily access your custom training data with the integrated ClearML <b>Data Versioning Tool</b>
|
||||||
|
|
||||||
|
🔦 <b>Remotely train and monitor</b> your YOLOv5 training runs using ClearML Agent
|
||||||
|
|
||||||
|
🔬 Get the very best mAP using ClearML <b>Hyperparameter Optimization</b>
|
||||||
|
|
||||||
|
🔭 Turn your newly trained <b>YOLOv5 model into an API</b> with just a few commands using ClearML Serving
|
||||||
|
|
||||||
|
<br />
|
||||||
|
And so much more. It's up to you how many of these tools you want to use, you can stick to the experiment manager, or chain them all together into an impressive pipeline!
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
<br />
|
||||||
|
<br />
|
||||||
|
|
||||||
|
## 🦾 Setting Things Up
|
||||||
|
|
||||||
|
To keep track of your experiments and/or data, ClearML needs to communicate to a server. You have 2 options to get one:
|
||||||
|
|
||||||
|
Either sign up for free to the [ClearML Hosted Service](https://cutt.ly/yolov5-tutorial-clearml) or you can set up your own server, see [here](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server). Even the server is open-source, so even if you're dealing with sensitive data, you should be good to go!
|
||||||
|
|
||||||
|
1. Install the `clearml` python package:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install clearml
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Connect the ClearML SDK to the server by [creating credentials](https://app.clear.ml/settings/workspace-configuration) (go right top to Settings -> Workspace -> Create new credentials), then execute the command below and follow the instructions:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
clearml-init
|
||||||
|
```
|
||||||
|
|
||||||
|
That's it! You're done 😎
|
||||||
|
|
||||||
|
<br />
|
||||||
|
|
||||||
|
## 🚀 Training YOLOv5 With ClearML
|
||||||
|
|
||||||
|
To enable ClearML experiment tracking, simply install the ClearML pip package.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install clearml
|
||||||
|
```
|
||||||
|
|
||||||
|
This will enable integration with the YOLOv5 training script. Every training run from now on, will be captured and stored by the ClearML experiment manager. If you want to change the `project_name` or `task_name`, head over to our custom logger, where you can change it: `utils/loggers/clearml/clearml_utils.py`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train.py --img 640 --batch 16 --epochs 3 --data coco128.yaml --weights yolov5s.pt --cache
|
||||||
|
```
|
||||||
|
|
||||||
|
This will capture:
|
||||||
|
- Source code + uncommitted changes
|
||||||
|
- Installed packages
|
||||||
|
- (Hyper)parameters
|
||||||
|
- Model files (use `--save-period n` to save a checkpoint every n epochs)
|
||||||
|
- Console output
|
||||||
|
- Scalars (mAP_0.5, mAP_0.5:0.95, precision, recall, losses, learning rates, ...)
|
||||||
|
- General info such as machine details, runtime, creation date etc.
|
||||||
|
- All produced plots such as label correlogram and confusion matrix
|
||||||
|
- Images with bounding boxes per epoch
|
||||||
|
- Mosaic per epoch
|
||||||
|
- Validation images per epoch
|
||||||
|
- ...
|
||||||
|
|
||||||
|
That's a lot right? 🤯
|
||||||
|
Now, we can visualize all of this information in the ClearML UI to get an overview of our training progress. Add custom columns to the table view (such as e.g. mAP_0.5) so you can easily sort on the best performing model. Or select multiple experiments and directly compare them!
|
||||||
|
|
||||||
|
There even more we can do with all of this information, like hyperparameter optimization and remote execution, so keep reading if you want to see how that works!
|
||||||
|
|
||||||
|
<br />
|
||||||
|
|
||||||
|
## 🔗 Dataset Version Management
|
||||||
|
|
||||||
|
Versioning your data separately from your code is generally a good idea and makes it easy to aqcuire the latest version too. This repository supports supplying a dataset version ID and it will make sure to get the data if it's not there yet. Next to that, this workflow also saves the used dataset ID as part of the task parameters, so you will always know for sure which data was used in which experiment!
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### Prepare Your Dataset
|
||||||
|
|
||||||
|
The YOLOv5 repository supports a number of different datasets by using yaml files containing their information. By default datasets are downloaded to the `../datasets` folder in relation to the repository root folder. So if you downloaded the `coco128` dataset using the link in the yaml or with the scripts provided by yolov5, you get this folder structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
..
|
||||||
|
|_ yolov5
|
||||||
|
|_ datasets
|
||||||
|
|_ coco128
|
||||||
|
|_ images
|
||||||
|
|_ labels
|
||||||
|
|_ LICENSE
|
||||||
|
|_ README.txt
|
||||||
|
```
|
||||||
|
But this can be any dataset you wish. Feel free to use your own, as long as you keep to this folder structure.
|
||||||
|
|
||||||
|
Next, ⚠️**copy the corresponding yaml file to the root of the dataset folder**⚠️. This yaml files contains the information ClearML will need to properly use the dataset. You can make this yourself too, of course, just follow the structure of the example yamls.
|
||||||
|
|
||||||
|
Basically we need the following keys: `path`, `train`, `test`, `val`, `nc`, `names`.
|
||||||
|
|
||||||
|
```
|
||||||
|
..
|
||||||
|
|_ yolov5
|
||||||
|
|_ datasets
|
||||||
|
|_ coco128
|
||||||
|
|_ images
|
||||||
|
|_ labels
|
||||||
|
|_ coco128.yaml # <---- HERE!
|
||||||
|
|_ LICENSE
|
||||||
|
|_ README.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### Upload Your Dataset
|
||||||
|
|
||||||
|
To get this dataset into ClearML as a versionned dataset, go to the dataset root folder and run the following command:
|
||||||
|
```bash
|
||||||
|
cd coco128
|
||||||
|
clearml-data sync --project YOLOv5 --name coco128 --folder .
|
||||||
|
```
|
||||||
|
|
||||||
|
The command `clearml-data sync` is actually a shorthand command. You could also run these commands one after the other:
|
||||||
|
```bash
|
||||||
|
# Optionally add --parent <parent_dataset_id> if you want to base
|
||||||
|
# this version on another dataset version, so no duplicate files are uploaded!
|
||||||
|
clearml-data create --name coco128 --project YOLOv5
|
||||||
|
clearml-data add --files .
|
||||||
|
clearml-data close
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run Training Using A ClearML Dataset
|
||||||
|
|
||||||
|
Now that you have a ClearML dataset, you can very simply use it to train custom YOLOv5 🚀 models!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train.py --img 640 --batch 16 --epochs 3 --data clearml://<your_dataset_id> --weights yolov5s.pt --cache
|
||||||
|
```
|
||||||
|
|
||||||
|
<br />
|
||||||
|
|
||||||
|
## 👀 Hyperparameter Optimization
|
||||||
|
|
||||||
|
Now that we have our experiments and data versioned, it's time to take a look at what we can build on top!
|
||||||
|
|
||||||
|
Using the code information, installed packages and environment details, the experiment itself is now **completely reproducible**. In fact, ClearML allows you to clone an experiment and even change its parameters. We can then just rerun it with these new parameters automatically, this is basically what HPO does!
|
||||||
|
|
||||||
|
To **run hyperparameter optimization locally**, we've included a pre-made script for you. Just make sure a training task has been run at least once, so it is in the ClearML experiment manager, we will essentially clone it and change its hyperparameters.
|
||||||
|
|
||||||
|
You'll need to fill in the ID of this `template task` in the script found at `utils/loggers/clearml/hpo.py` and then just run it :) You can change `task.execute_locally()` to `task.execute()` to put it in a ClearML queue and have a remote agent work on it instead.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# To use optuna, install it first, otherwise you can change the optimizer to just be RandomSearch
|
||||||
|
pip install optuna
|
||||||
|
python utils/loggers/clearml/hpo.py
|
||||||
|
```
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## 🤯 Remote Execution (advanced)
|
||||||
|
|
||||||
|
Running HPO locally is really handy, but what if we want to run our experiments on a remote machine instead? Maybe you have access to a very powerful GPU machine on-site or you have some budget to use cloud GPUs.
|
||||||
|
This is where the ClearML Agent comes into play. Check out what the agent can do here:
|
||||||
|
|
||||||
|
- [YouTube video](https://youtu.be/MX3BrXnaULs)
|
||||||
|
- [Documentation](https://clear.ml/docs/latest/docs/clearml_agent)
|
||||||
|
|
||||||
|
In short: every experiment tracked by the experiment manager contains enough information to reproduce it on a different machine (installed packages, uncommitted changes etc.). So a ClearML agent does just that: it listens to a queue for incoming tasks and when it finds one, it recreates the environment and runs it while still reporting scalars, plots etc. to the experiment manager.
|
||||||
|
|
||||||
|
You can turn any machine (a cloud VM, a local GPU machine, your own laptop ... ) into a ClearML agent by simply running:
|
||||||
|
```bash
|
||||||
|
clearml-agent daemon --queue <queues_to_listen_to> [--docker]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cloning, Editing And Enqueuing
|
||||||
|
|
||||||
|
With our agent running, we can give it some work. Remember from the HPO section that we can clone a task and edit the hyperparameters? We can do that from the interface too!
|
||||||
|
|
||||||
|
🪄 Clone the experiment by right clicking it
|
||||||
|
|
||||||
|
🎯 Edit the hyperparameters to what you wish them to be
|
||||||
|
|
||||||
|
⏳ Enqueue the task to any of the queues by right clicking it
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### Executing A Task Remotely
|
||||||
|
|
||||||
|
Now you can clone a task like we explained above, or simply mark your current script by adding `task.execute_remotely()` and on execution it will be put into a queue, for the agent to start working on!
|
||||||
|
|
||||||
|
To run the YOLOv5 training script remotely, all you have to do is add this line to the training.py script after the clearml logger has been instatiated:
|
||||||
|
```python
|
||||||
|
# ...
|
||||||
|
# Loggers
|
||||||
|
data_dict = None
|
||||||
|
if RANK in {-1, 0}:
|
||||||
|
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
|
||||||
|
if loggers.clearml:
|
||||||
|
loggers.clearml.task.execute_remotely(queue='my_queue') # <------ ADD THIS LINE
|
||||||
|
# Data_dict is either None is user did not choose for ClearML dataset or is filled in by ClearML
|
||||||
|
data_dict = loggers.clearml.data_dict
|
||||||
|
# ...
|
||||||
|
```
|
||||||
|
When running the training script after this change, python will run the script up until that line, after which it will package the code and send it to the queue instead!
|
||||||
|
|
||||||
|
### Autoscaling workers
|
||||||
|
|
||||||
|
ClearML comes with autoscalers too! This tool will automatically spin up new remote machines in the cloud of your choice (AWS, GCP, Azure) and turn them into ClearML agents for you whenever there are experiments detected in the queue. Once the tasks are processed, the autoscaler will automatically shut down the remote machines and you stop paying!
|
||||||
|
|
||||||
|
Check out the autoscalers getting started video below.
|
||||||
|
|
||||||
|
[](https://youtu.be/j4XVMAaUt3E)
|
0
utils/loggers/clearml/__init__.py
Normal file
0
utils/loggers/clearml/__init__.py
Normal file
156
utils/loggers/clearml/clearml_utils.py
Normal file
156
utils/loggers/clearml/clearml_utils.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
"""Main Logger class for ClearML experiment tracking."""
|
||||||
|
import glob
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from utils.plots import Annotator, colors
|
||||||
|
|
||||||
|
try:
|
||||||
|
import clearml
|
||||||
|
from clearml import Dataset, Task
|
||||||
|
assert hasattr(clearml, '__version__') # verify package import not local dir
|
||||||
|
except (ImportError, AssertionError):
|
||||||
|
clearml = None
|
||||||
|
|
||||||
|
|
||||||
|
def construct_dataset(clearml_info_string):
|
||||||
|
"""Load in a clearml dataset and fill the internal data_dict with its contents.
|
||||||
|
"""
|
||||||
|
dataset_id = clearml_info_string.replace('clearml://', '')
|
||||||
|
dataset = Dataset.get(dataset_id=dataset_id)
|
||||||
|
dataset_root_path = Path(dataset.get_local_copy())
|
||||||
|
|
||||||
|
# We'll search for the yaml file definition in the dataset
|
||||||
|
yaml_filenames = list(glob.glob(str(dataset_root_path / "*.yaml")) + glob.glob(str(dataset_root_path / "*.yml")))
|
||||||
|
if len(yaml_filenames) > 1:
|
||||||
|
raise ValueError('More than one yaml file was found in the dataset root, cannot determine which one contains '
|
||||||
|
'the dataset definition this way.')
|
||||||
|
elif len(yaml_filenames) == 0:
|
||||||
|
raise ValueError('No yaml definition found in dataset root path, check that there is a correct yaml file '
|
||||||
|
'inside the dataset root path.')
|
||||||
|
with open(yaml_filenames[0]) as f:
|
||||||
|
dataset_definition = yaml.safe_load(f)
|
||||||
|
|
||||||
|
assert set(dataset_definition.keys()).issuperset(
|
||||||
|
{'train', 'test', 'val', 'nc', 'names'}
|
||||||
|
), "The right keys were not found in the yaml file, make sure it at least has the following keys: ('train', 'test', 'val', 'nc', 'names')"
|
||||||
|
|
||||||
|
data_dict = dict()
|
||||||
|
data_dict['train'] = str(
|
||||||
|
(dataset_root_path / dataset_definition['train']).resolve()) if dataset_definition['train'] else None
|
||||||
|
data_dict['test'] = str(
|
||||||
|
(dataset_root_path / dataset_definition['test']).resolve()) if dataset_definition['test'] else None
|
||||||
|
data_dict['val'] = str(
|
||||||
|
(dataset_root_path / dataset_definition['val']).resolve()) if dataset_definition['val'] else None
|
||||||
|
data_dict['nc'] = dataset_definition['nc']
|
||||||
|
data_dict['names'] = dataset_definition['names']
|
||||||
|
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ClearmlLogger:
|
||||||
|
"""Log training runs, datasets, models, and predictions to ClearML.
|
||||||
|
|
||||||
|
This logger sends information to ClearML at app.clear.ml or to your own hosted server. By default,
|
||||||
|
this information includes hyperparameters, system configuration and metrics, model metrics, code information and
|
||||||
|
basic data metrics and analyses.
|
||||||
|
|
||||||
|
By providing additional command line arguments to train.py, datasets,
|
||||||
|
models and predictions can also be logged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, opt, hyp):
|
||||||
|
"""
|
||||||
|
- Initialize ClearML Task, this object will capture the experiment
|
||||||
|
- Upload dataset version to ClearML Data if opt.upload_dataset is True
|
||||||
|
|
||||||
|
arguments:
|
||||||
|
opt (namespace) -- Commandline arguments for this run
|
||||||
|
hyp (dict) -- Hyperparameters for this run
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.current_epoch = 0
|
||||||
|
# Keep tracked of amount of logged images to enforce a limit
|
||||||
|
self.current_epoch_logged_images = set()
|
||||||
|
# Maximum number of images to log to clearML per epoch
|
||||||
|
self.max_imgs_to_log_per_epoch = 16
|
||||||
|
# Get the interval of epochs when bounding box images should be logged
|
||||||
|
self.bbox_interval = opt.bbox_interval
|
||||||
|
self.clearml = clearml
|
||||||
|
self.task = None
|
||||||
|
self.data_dict = None
|
||||||
|
if self.clearml:
|
||||||
|
self.task = Task.init(
|
||||||
|
project_name='YOLOv5',
|
||||||
|
task_name='training',
|
||||||
|
tags=['YOLOv5'],
|
||||||
|
output_uri=True,
|
||||||
|
auto_connect_frameworks={'pytorch': False}
|
||||||
|
# We disconnect pytorch auto-detection, because we added manual model save points in the code
|
||||||
|
)
|
||||||
|
# ClearML's hooks will already grab all general parameters
|
||||||
|
# Only the hyperparameters coming from the yaml config file
|
||||||
|
# will have to be added manually!
|
||||||
|
self.task.connect(hyp, name='Hyperparameters')
|
||||||
|
|
||||||
|
# Get ClearML Dataset Version if requested
|
||||||
|
if opt.data.startswith('clearml://'):
|
||||||
|
# data_dict should have the following keys:
|
||||||
|
# names, nc (number of classes), test, train, val (all three relative paths to ../datasets)
|
||||||
|
self.data_dict = construct_dataset(opt.data)
|
||||||
|
# Set data to data_dict because wandb will crash without this information and opt is the best way
|
||||||
|
# to give it to them
|
||||||
|
opt.data = self.data_dict
|
||||||
|
|
||||||
|
def log_debug_samples(self, files, title='Debug Samples'):
|
||||||
|
"""
|
||||||
|
Log files (images) as debug samples in the ClearML task.
|
||||||
|
|
||||||
|
arguments:
|
||||||
|
files (List(PosixPath)) a list of file paths in PosixPath format
|
||||||
|
title (str) A title that groups together images with the same values
|
||||||
|
"""
|
||||||
|
for f in files:
|
||||||
|
if f.exists():
|
||||||
|
it = re.search(r'_batch(\d+)', f.name)
|
||||||
|
iteration = int(it.groups()[0]) if it else 0
|
||||||
|
self.task.get_logger().report_image(title=title,
|
||||||
|
series=f.name.replace(it.group(), ''),
|
||||||
|
local_path=str(f),
|
||||||
|
iteration=iteration)
|
||||||
|
|
||||||
|
def log_image_with_boxes(self, image_path, boxes, class_names, image, conf_threshold=0.25):
|
||||||
|
"""
|
||||||
|
Draw the bounding boxes on a single image and report the result as a ClearML debug sample.
|
||||||
|
|
||||||
|
arguments:
|
||||||
|
image_path (PosixPath) the path the original image file
|
||||||
|
boxes (list): list of scaled predictions in the format - [xmin, ymin, xmax, ymax, confidence, class]
|
||||||
|
class_names (dict): dict containing mapping of class int to class name
|
||||||
|
image (Tensor): A torch tensor containing the actual image data
|
||||||
|
"""
|
||||||
|
if len(self.current_epoch_logged_images) < self.max_imgs_to_log_per_epoch and self.current_epoch >= 0:
|
||||||
|
# Log every bbox_interval times and deduplicate for any intermittend extra eval runs
|
||||||
|
if self.current_epoch % self.bbox_interval == 0 and image_path not in self.current_epoch_logged_images:
|
||||||
|
im = np.ascontiguousarray(np.moveaxis(image.mul(255).clamp(0, 255).byte().cpu().numpy(), 0, 2))
|
||||||
|
annotator = Annotator(im=im, pil=True)
|
||||||
|
for i, (conf, class_nr, box) in enumerate(zip(boxes[:, 4], boxes[:, 5], boxes[:, :4])):
|
||||||
|
color = colors(i)
|
||||||
|
|
||||||
|
class_name = class_names[int(class_nr)]
|
||||||
|
confidence_percentage = round(float(conf) * 100, 2)
|
||||||
|
label = f"{class_name}: {confidence_percentage}%"
|
||||||
|
|
||||||
|
if conf > conf_threshold:
|
||||||
|
annotator.rectangle(box.cpu().numpy(), outline=color)
|
||||||
|
annotator.box_label(box.cpu().numpy(), label=label, color=color)
|
||||||
|
|
||||||
|
annotated_image = annotator.result()
|
||||||
|
self.task.get_logger().report_image(title='Bounding Boxes',
|
||||||
|
series=image_path.name,
|
||||||
|
iteration=self.current_epoch,
|
||||||
|
image=annotated_image)
|
||||||
|
self.current_epoch_logged_images.add(image_path)
|
84
utils/loggers/clearml/hpo.py
Normal file
84
utils/loggers/clearml/hpo.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from clearml import Task
|
||||||
|
# Connecting ClearML with the current process,
|
||||||
|
# from here on everything is logged automatically
|
||||||
|
from clearml.automation import HyperParameterOptimizer, UniformParameterRange
|
||||||
|
from clearml.automation.optuna import OptimizerOptuna
|
||||||
|
|
||||||
|
task = Task.init(project_name='Hyper-Parameter Optimization',
|
||||||
|
task_name='YOLOv5',
|
||||||
|
task_type=Task.TaskTypes.optimizer,
|
||||||
|
reuse_last_task_id=False)
|
||||||
|
|
||||||
|
# Example use case:
|
||||||
|
optimizer = HyperParameterOptimizer(
|
||||||
|
# This is the experiment we want to optimize
|
||||||
|
base_task_id='<your_template_task_id>',
|
||||||
|
# here we define the hyper-parameters to optimize
|
||||||
|
# Notice: The parameter name should exactly match what you see in the UI: <section_name>/<parameter>
|
||||||
|
# For Example, here we see in the base experiment a section Named: "General"
|
||||||
|
# under it a parameter named "batch_size", this becomes "General/batch_size"
|
||||||
|
# If you have `argparse` for example, then arguments will appear under the "Args" section,
|
||||||
|
# and you should instead pass "Args/batch_size"
|
||||||
|
hyper_parameters=[
|
||||||
|
UniformParameterRange('Hyperparameters/lr0', min_value=1e-5, max_value=1e-1),
|
||||||
|
UniformParameterRange('Hyperparameters/lrf', min_value=0.01, max_value=1.0),
|
||||||
|
UniformParameterRange('Hyperparameters/momentum', min_value=0.6, max_value=0.98),
|
||||||
|
UniformParameterRange('Hyperparameters/weight_decay', min_value=0.0, max_value=0.001),
|
||||||
|
UniformParameterRange('Hyperparameters/warmup_epochs', min_value=0.0, max_value=5.0),
|
||||||
|
UniformParameterRange('Hyperparameters/warmup_momentum', min_value=0.0, max_value=0.95),
|
||||||
|
UniformParameterRange('Hyperparameters/warmup_bias_lr', min_value=0.0, max_value=0.2),
|
||||||
|
UniformParameterRange('Hyperparameters/box', min_value=0.02, max_value=0.2),
|
||||||
|
UniformParameterRange('Hyperparameters/cls', min_value=0.2, max_value=4.0),
|
||||||
|
UniformParameterRange('Hyperparameters/cls_pw', min_value=0.5, max_value=2.0),
|
||||||
|
UniformParameterRange('Hyperparameters/obj', min_value=0.2, max_value=4.0),
|
||||||
|
UniformParameterRange('Hyperparameters/obj_pw', min_value=0.5, max_value=2.0),
|
||||||
|
UniformParameterRange('Hyperparameters/iou_t', min_value=0.1, max_value=0.7),
|
||||||
|
UniformParameterRange('Hyperparameters/anchor_t', min_value=2.0, max_value=8.0),
|
||||||
|
UniformParameterRange('Hyperparameters/fl_gamma', min_value=0.0, max_value=4.0),
|
||||||
|
UniformParameterRange('Hyperparameters/hsv_h', min_value=0.0, max_value=0.1),
|
||||||
|
UniformParameterRange('Hyperparameters/hsv_s', min_value=0.0, max_value=0.9),
|
||||||
|
UniformParameterRange('Hyperparameters/hsv_v', min_value=0.0, max_value=0.9),
|
||||||
|
UniformParameterRange('Hyperparameters/degrees', min_value=0.0, max_value=45.0),
|
||||||
|
UniformParameterRange('Hyperparameters/translate', min_value=0.0, max_value=0.9),
|
||||||
|
UniformParameterRange('Hyperparameters/scale', min_value=0.0, max_value=0.9),
|
||||||
|
UniformParameterRange('Hyperparameters/shear', min_value=0.0, max_value=10.0),
|
||||||
|
UniformParameterRange('Hyperparameters/perspective', min_value=0.0, max_value=0.001),
|
||||||
|
UniformParameterRange('Hyperparameters/flipud', min_value=0.0, max_value=1.0),
|
||||||
|
UniformParameterRange('Hyperparameters/fliplr', min_value=0.0, max_value=1.0),
|
||||||
|
UniformParameterRange('Hyperparameters/mosaic', min_value=0.0, max_value=1.0),
|
||||||
|
UniformParameterRange('Hyperparameters/mixup', min_value=0.0, max_value=1.0),
|
||||||
|
UniformParameterRange('Hyperparameters/copy_paste', min_value=0.0, max_value=1.0)],
|
||||||
|
# this is the objective metric we want to maximize/minimize
|
||||||
|
objective_metric_title='metrics',
|
||||||
|
objective_metric_series='mAP_0.5',
|
||||||
|
# now we decide if we want to maximize it or minimize it (accuracy we maximize)
|
||||||
|
objective_metric_sign='max',
|
||||||
|
# let us limit the number of concurrent experiments,
|
||||||
|
# this in turn will make sure we do dont bombard the scheduler with experiments.
|
||||||
|
# if we have an auto-scaler connected, this, by proxy, will limit the number of machine
|
||||||
|
max_number_of_concurrent_tasks=1,
|
||||||
|
# this is the optimizer class (actually doing the optimization)
|
||||||
|
# Currently, we can choose from GridSearch, RandomSearch or OptimizerBOHB (Bayesian optimization Hyper-Band)
|
||||||
|
optimizer_class=OptimizerOptuna,
|
||||||
|
# If specified only the top K performing Tasks will be kept, the others will be automatically archived
|
||||||
|
save_top_k_tasks_only=5, # 5,
|
||||||
|
compute_time_limit=None,
|
||||||
|
total_max_jobs=20,
|
||||||
|
min_iteration_per_job=None,
|
||||||
|
max_iteration_per_job=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# report every 10 seconds, this is way too often, but we are testing here
|
||||||
|
optimizer.set_report_period(10 / 60)
|
||||||
|
# You can also use the line below instead to run all the optimizer tasks locally, without using queues or agent
|
||||||
|
# an_optimizer.start_locally(job_complete_callback=job_complete_callback)
|
||||||
|
# set the time limit for the optimization process (2 hours)
|
||||||
|
optimizer.set_time_limit(in_minutes=120.0)
|
||||||
|
# Start the optimization process in the local environment
|
||||||
|
optimizer.start_locally()
|
||||||
|
# wait until process is done (notice we are controlling the optimization process in the background)
|
||||||
|
optimizer.wait()
|
||||||
|
# make sure background optimization stopped
|
||||||
|
optimizer.stop()
|
||||||
|
|
||||||
|
print('We are done, good bye')
|
@ -43,6 +43,9 @@ def check_wandb_config_file(data_config_file):
|
|||||||
def check_wandb_dataset(data_file):
|
def check_wandb_dataset(data_file):
|
||||||
is_trainset_wandb_artifact = False
|
is_trainset_wandb_artifact = False
|
||||||
is_valset_wandb_artifact = False
|
is_valset_wandb_artifact = False
|
||||||
|
if isinstance(data_file, dict):
|
||||||
|
# In that case another dataset manager has already processed it and we don't have to
|
||||||
|
return data_file
|
||||||
if check_file(data_file) and data_file.endswith('.yaml'):
|
if check_file(data_file) and data_file.endswith('.yaml'):
|
||||||
with open(data_file, errors='ignore') as f:
|
with open(data_file, errors='ignore') as f:
|
||||||
data_dict = yaml.safe_load(f)
|
data_dict = yaml.safe_load(f)
|
||||||
@ -121,7 +124,7 @@ class WandbLogger():
|
|||||||
"""
|
"""
|
||||||
- Initialize WandbLogger instance
|
- Initialize WandbLogger instance
|
||||||
- Upload dataset if opt.upload_dataset is True
|
- Upload dataset if opt.upload_dataset is True
|
||||||
- Setup trainig processes if job_type is 'Training'
|
- Setup training processes if job_type is 'Training'
|
||||||
|
|
||||||
arguments:
|
arguments:
|
||||||
opt (namespace) -- Commandline arguments for this run
|
opt (namespace) -- Commandline arguments for this run
|
||||||
@ -170,7 +173,11 @@ class WandbLogger():
|
|||||||
if not opt.resume:
|
if not opt.resume:
|
||||||
self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)
|
self.wandb_artifact_data_dict = self.check_and_upload_dataset(opt)
|
||||||
|
|
||||||
if opt.resume:
|
if isinstance(opt.data, dict):
|
||||||
|
# This means another dataset manager has already processed the dataset info (e.g. ClearML)
|
||||||
|
# and they will have stored the already processed dict in opt.data
|
||||||
|
self.data_dict = opt.data
|
||||||
|
elif opt.resume:
|
||||||
# resume from artifact
|
# resume from artifact
|
||||||
if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
if isinstance(opt.resume, str) and opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
||||||
self.data_dict = dict(self.wandb_run.config.data_dict)
|
self.data_dict = dict(self.wandb_run.config.data_dict)
|
||||||
|
@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from utils import TryExcept, threaded
|
||||||
|
|
||||||
|
|
||||||
def fitness(x):
|
def fitness(x):
|
||||||
# Model fitness as a weighted combination of metrics
|
# Model fitness as a weighted combination of metrics
|
||||||
@ -139,6 +141,12 @@ class ConfusionMatrix:
|
|||||||
Returns:
|
Returns:
|
||||||
None, updates confusion matrix accordingly
|
None, updates confusion matrix accordingly
|
||||||
"""
|
"""
|
||||||
|
if detections is None:
|
||||||
|
gt_classes = labels.int()
|
||||||
|
for gc in gt_classes:
|
||||||
|
self.matrix[self.nc, gc] += 1 # background FN
|
||||||
|
return
|
||||||
|
|
||||||
detections = detections[detections[:, 4] > self.conf]
|
detections = detections[detections[:, 4] > self.conf]
|
||||||
gt_classes = labels[:, 0].int()
|
gt_classes = labels[:, 0].int()
|
||||||
detection_classes = detections[:, 5].int()
|
detection_classes = detections[:, 5].int()
|
||||||
@ -178,35 +186,35 @@ class ConfusionMatrix:
|
|||||||
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
||||||
return tp[:-1], fp[:-1] # remove background class
|
return tp[:-1], fp[:-1] # remove background class
|
||||||
|
|
||||||
|
@TryExcept('WARNING: ConfusionMatrix plot failure: ')
|
||||||
def plot(self, normalize=True, save_dir='', names=()):
|
def plot(self, normalize=True, save_dir='', names=()):
|
||||||
try:
|
import seaborn as sn
|
||||||
import seaborn as sn
|
|
||||||
|
|
||||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
||||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||||
|
|
||||||
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
||||||
nc, nn = self.nc, len(names) # number of classes, names
|
nc, nn = self.nc, len(names) # number of classes, names
|
||||||
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||||
sn.heatmap(array,
|
sn.heatmap(array,
|
||||||
annot=nc < 30,
|
ax=ax,
|
||||||
annot_kws={
|
annot=nc < 30,
|
||||||
"size": 8},
|
annot_kws={
|
||||||
cmap='Blues',
|
"size": 8},
|
||||||
fmt='.2f',
|
cmap='Blues',
|
||||||
square=True,
|
fmt='.2f',
|
||||||
vmin=0.0,
|
square=True,
|
||||||
xticklabels=names + ['background FP'] if labels else "auto",
|
vmin=0.0,
|
||||||
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
xticklabels=names + ['background FP'] if labels else "auto",
|
||||||
fig.axes[0].set_xlabel('True')
|
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
||||||
fig.axes[0].set_ylabel('Predicted')
|
ax.set_ylabel('True')
|
||||||
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
ax.set_ylabel('Predicted')
|
||||||
plt.close()
|
ax.set_title('Confusion Matrix')
|
||||||
except Exception as e:
|
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
||||||
print(f'WARNING: ConfusionMatrix plot failure: {e}')
|
plt.close(fig)
|
||||||
|
|
||||||
def print(self):
|
def print(self):
|
||||||
for i in range(self.nc + 1):
|
for i in range(self.nc + 1):
|
||||||
@ -313,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
|
|||||||
# Plots ----------------------------------------------------------------------------------------------------------------
|
# Plots ----------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@threaded
|
||||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
||||||
# Precision-recall curve
|
# Precision-recall curve
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||||
@ -329,11 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
|||||||
ax.set_ylabel('Precision')
|
ax.set_ylabel('Precision')
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
ax.set_ylim(0, 1)
|
ax.set_ylim(0, 1)
|
||||||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||||
|
ax.set_title('Precision-Recall Curve')
|
||||||
fig.savefig(save_dir, dpi=250)
|
fig.savefig(save_dir, dpi=250)
|
||||||
plt.close()
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
@threaded
|
||||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
||||||
# Metric-confidence curve
|
# Metric-confidence curve
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||||
@ -350,6 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
|
|||||||
ax.set_ylabel(ylabel)
|
ax.set_ylabel(ylabel)
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
ax.set_ylim(0, 1)
|
ax.set_ylim(0, 1)
|
||||||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||||
|
ax.set_title(f'{ylabel}-Confidence Curve')
|
||||||
fig.savefig(save_dir, dpi=250)
|
fig.savefig(save_dir, dpi=250)
|
||||||
plt.close()
|
plt.close(fig)
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
Plotting utils
|
Plotting utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from copy import copy
|
from copy import copy
|
||||||
@ -18,8 +19,9 @@ import seaborn as sn
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
|
from utils import TryExcept, threaded
|
||||||
increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
|
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
|
||||||
|
is_ascii, xywh2xyxy, xyxy2xywh)
|
||||||
from utils.metrics import fitness
|
from utils.metrics import fitness
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
@ -115,10 +117,12 @@ class Annotator:
|
|||||||
# Add rectangle to image (PIL-only)
|
# Add rectangle to image (PIL-only)
|
||||||
self.draw.rectangle(xy, fill, outline, width)
|
self.draw.rectangle(xy, fill, outline, width)
|
||||||
|
|
||||||
def text(self, xy, text, txt_color=(255, 255, 255)):
|
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
|
||||||
# Add text to image (PIL-only)
|
# Add text to image (PIL-only)
|
||||||
w, h = self.font.getsize(text) # text width, height
|
if anchor == 'bottom': # start y from font bottom
|
||||||
self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)
|
w, h = self.font.getsize(text) # text width, height
|
||||||
|
xy[1] += 1 - h
|
||||||
|
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
||||||
|
|
||||||
def result(self):
|
def result(self):
|
||||||
# Return annotated image as array
|
# Return annotated image as array
|
||||||
@ -148,6 +152,7 @@ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detec
|
|||||||
ax[i].axis('off')
|
ax[i].axis('off')
|
||||||
|
|
||||||
LOGGER.info(f'Saving {f}... ({n}/{channels})')
|
LOGGER.info(f'Saving {f}... ({n}/{channels})')
|
||||||
|
plt.title('Features')
|
||||||
plt.savefig(f, dpi=300, bbox_inches='tight')
|
plt.savefig(f, dpi=300, bbox_inches='tight')
|
||||||
plt.close()
|
plt.close()
|
||||||
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
||||||
@ -179,8 +184,7 @@ def output_to_target(output):
|
|||||||
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
|
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
|
||||||
targets = []
|
targets = []
|
||||||
for i, o in enumerate(output):
|
for i, o in enumerate(output):
|
||||||
for *box, conf, cls in o.cpu().numpy():
|
targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy())
|
||||||
targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
|
|
||||||
return np.array(targets)
|
return np.array(targets)
|
||||||
|
|
||||||
|
|
||||||
@ -220,7 +224,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
|
|||||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||||
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
||||||
if paths:
|
if paths:
|
||||||
annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
||||||
if len(targets) > 0:
|
if len(targets) > 0:
|
||||||
ti = targets[targets[:, 0] == i] # image targets
|
ti = targets[targets[:, 0] == i] # image targets
|
||||||
boxes = xywh2xyxy(ti[:, 2:6]).T
|
boxes = xywh2xyxy(ti[:, 2:6]).T
|
||||||
@ -338,8 +342,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
|
|||||||
plt.savefig(f, dpi=300)
|
plt.savefig(f, dpi=300)
|
||||||
|
|
||||||
|
|
||||||
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
|
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||||
@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
|
|
||||||
def plot_labels(labels, names=(), save_dir=Path('')):
|
def plot_labels(labels, names=(), save_dir=Path('')):
|
||||||
# plot dataset labels
|
# plot dataset labels
|
||||||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||||
@ -356,10 +359,8 @@ def plot_labels(labels, names=(), save_dir=Path('')):
|
|||||||
matplotlib.use('svg') # faster
|
matplotlib.use('svg') # faster
|
||||||
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
||||||
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
||||||
try: # color histogram bars by class
|
with contextlib.suppress(Exception): # color histogram bars by class
|
||||||
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
ax[0].set_ylabel('instances')
|
ax[0].set_ylabel('instances')
|
||||||
if 0 < len(names) < 30:
|
if 0 < len(names) < 30:
|
||||||
ax[0].set_xticks(range(len(names)))
|
ax[0].set_xticks(range(len(names)))
|
||||||
@ -387,6 +388,35 @@ def plot_labels(labels, names=(), save_dir=Path('')):
|
|||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
|
||||||
|
# Show classification image grid with labels (optional) and predictions (optional)
|
||||||
|
from utils.augmentations import denormalize
|
||||||
|
|
||||||
|
names = names or [f'class{i}' for i in range(1000)]
|
||||||
|
blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
|
||||||
|
dim=0) # select batch index 0, block by channels
|
||||||
|
n = min(len(blocks), nmax) # number of plots
|
||||||
|
m = min(8, round(n ** 0.5)) # 8 x 8 default
|
||||||
|
fig, ax = plt.subplots(math.ceil(n / m), m) # 8 rows x n/8 cols
|
||||||
|
ax = ax.ravel() if m > 1 else [ax]
|
||||||
|
# plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
||||||
|
for i in range(n):
|
||||||
|
ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
|
||||||
|
ax[i].axis('off')
|
||||||
|
if labels is not None:
|
||||||
|
s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
|
||||||
|
ax[i].set_title(s, fontsize=8, verticalalignment='top')
|
||||||
|
plt.savefig(f, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
if verbose:
|
||||||
|
LOGGER.info(f"Saving {f}")
|
||||||
|
if labels is not None:
|
||||||
|
LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
|
||||||
|
if pred is not None:
|
||||||
|
LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
|
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
|
||||||
# Plot evolve.csv hyp evolution results
|
# Plot evolve.csv hyp evolution results
|
||||||
evolve_csv = Path(evolve_csv)
|
evolve_csv = Path(evolve_csv)
|
||||||
@ -484,6 +514,6 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
|
|||||||
if save:
|
if save:
|
||||||
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
||||||
f = str(increment_path(file).with_suffix('.jpg'))
|
f = str(increment_path(file).with_suffix('.jpg'))
|
||||||
# cv2.imwrite(f, crop) # https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
||||||
Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)).save(f, quality=95, subsampling=0)
|
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
|
||||||
return crop
|
return crop
|
||||||
|
@ -34,6 +34,23 @@ except ImportError:
|
|||||||
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
||||||
|
|
||||||
|
|
||||||
|
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
||||||
|
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
||||||
|
def decorate(fn):
|
||||||
|
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
||||||
|
|
||||||
|
return decorate
|
||||||
|
|
||||||
|
|
||||||
|
def smartCrossEntropyLoss(label_smoothing=0.0):
|
||||||
|
# Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
|
||||||
|
if check_version(torch.__version__, '1.10.0'):
|
||||||
|
return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
||||||
|
if label_smoothing > 0:
|
||||||
|
LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
|
||||||
|
return nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
|
||||||
def smart_DDP(model):
|
def smart_DDP(model):
|
||||||
# Model DDP creation with checks
|
# Model DDP creation with checks
|
||||||
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
||||||
@ -45,6 +62,28 @@ def smart_DDP(model):
|
|||||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_classifier_output(model, n=1000):
|
||||||
|
# Update a TorchVision classification model to class count 'n' if required
|
||||||
|
from models.common import Classify
|
||||||
|
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
||||||
|
if isinstance(m, Classify): # YOLOv5 Classify() head
|
||||||
|
if m.linear.out_features != n:
|
||||||
|
m.linear = nn.Linear(m.linear.in_features, n)
|
||||||
|
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
||||||
|
if m.out_features != n:
|
||||||
|
setattr(model, name, nn.Linear(m.in_features, n))
|
||||||
|
elif isinstance(m, nn.Sequential):
|
||||||
|
types = [type(x) for x in m]
|
||||||
|
if nn.Linear in types:
|
||||||
|
i = types.index(nn.Linear) # nn.Linear index
|
||||||
|
if m[i].out_features != n:
|
||||||
|
m[i] = nn.Linear(m[i].in_features, n)
|
||||||
|
elif nn.Conv2d in types:
|
||||||
|
i = types.index(nn.Conv2d) # nn.Conv2d index
|
||||||
|
if m[i].out_channels != n:
|
||||||
|
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def torch_distributed_zero_first(local_rank: int):
|
def torch_distributed_zero_first(local_rank: int):
|
||||||
# Decorator to make all processes in distributed training wait for each local_master to do something
|
# Decorator to make all processes in distributed training wait for each local_master to do something
|
||||||
@ -78,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True):
|
|||||||
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
||||||
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
|
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
|
||||||
|
|
||||||
if not (cpu or mps) and torch.cuda.is_available(): # prefer GPU if available
|
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||||
n = len(devices) # device count
|
n = len(devices) # device count
|
||||||
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
||||||
@ -97,7 +136,7 @@ def select_device(device='', batch_size=0, newline=True):
|
|||||||
|
|
||||||
if not newline:
|
if not newline:
|
||||||
s = s.rstrip()
|
s = s.rstrip()
|
||||||
LOGGER.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe
|
LOGGER.info(s)
|
||||||
return torch.device(arg)
|
return torch.device(arg)
|
||||||
|
|
||||||
|
|
||||||
@ -109,14 +148,13 @@ def time_sync():
|
|||||||
|
|
||||||
|
|
||||||
def profile(input, ops, n=10, device=None):
|
def profile(input, ops, n=10, device=None):
|
||||||
# YOLOv5 speed/memory/FLOPs profiler
|
""" YOLOv5 speed/memory/FLOPs profiler
|
||||||
#
|
Usage:
|
||||||
# Usage:
|
input = torch.randn(16, 3, 640, 640)
|
||||||
# input = torch.randn(16, 3, 640, 640)
|
m1 = lambda x: x * torch.sigmoid(x)
|
||||||
# m1 = lambda x: x * torch.sigmoid(x)
|
m2 = nn.SiLU()
|
||||||
# m2 = nn.SiLU()
|
profile(input, [m1, m2], n=100) # profile over 100 iterations
|
||||||
# profile(input, [m1, m2], n=100) # profile over 100 iterations
|
"""
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
if not isinstance(device, torch.device):
|
if not isinstance(device, torch.device):
|
||||||
device = select_device(device)
|
device = select_device(device)
|
||||||
@ -199,12 +237,11 @@ def sparsity(model):
|
|||||||
def prune(model, amount=0.3):
|
def prune(model, amount=0.3):
|
||||||
# Prune model to requested global sparsity
|
# Prune model to requested global sparsity
|
||||||
import torch.nn.utils.prune as prune
|
import torch.nn.utils.prune as prune
|
||||||
print('Pruning model... ', end='')
|
|
||||||
for name, m in model.named_modules():
|
for name, m in model.named_modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
prune.l1_unstructured(m, name='weight', amount=amount) # prune
|
prune.l1_unstructured(m, name='weight', amount=amount) # prune
|
||||||
prune.remove(m, 'weight') # make permanent
|
prune.remove(m, 'weight') # make permanent
|
||||||
print(' %.3g global sparsity' % sparsity(model))
|
LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
|
||||||
|
|
||||||
|
|
||||||
def fuse_conv_and_bn(conv, bn):
|
def fuse_conv_and_bn(conv, bn):
|
||||||
@ -230,7 +267,7 @@ def fuse_conv_and_bn(conv, bn):
|
|||||||
return fusedconv
|
return fusedconv
|
||||||
|
|
||||||
|
|
||||||
def model_info(model, verbose=False, img_size=640):
|
def model_info(model, verbose=False, imgsz=640):
|
||||||
# Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
|
# Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
|
||||||
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
||||||
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
|
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
|
||||||
@ -242,12 +279,12 @@ def model_info(model, verbose=False, img_size=640):
|
|||||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
||||||
|
|
||||||
try: # FLOPs
|
try: # FLOPs
|
||||||
from thop import profile
|
p = next(model.parameters())
|
||||||
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
|
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
||||||
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
|
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||||
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
||||||
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
||||||
fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 GFLOPs
|
fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
|
||||||
except Exception:
|
except Exception:
|
||||||
fs = ''
|
fs = ''
|
||||||
|
|
||||||
@ -276,7 +313,7 @@ def copy_attr(a, b, include=(), exclude=()):
|
|||||||
setattr(a, k, v)
|
setattr(a, k, v)
|
||||||
|
|
||||||
|
|
||||||
def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-5):
|
def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||||
# YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
|
# YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
|
||||||
g = [], [], [] # optimizer parameter groups
|
g = [], [], [] # optimizer parameter groups
|
||||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||||
@ -299,13 +336,45 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Optimizer {name} not implemented.')
|
raise NotImplementedError(f'Optimizer {name} not implemented.')
|
||||||
|
|
||||||
optimizer.add_param_group({'params': g[0], 'weight_decay': weight_decay}) # add g0 with weight_decay
|
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
||||||
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
||||||
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
|
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
|
||||||
f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
|
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
|
||||||
|
# YOLOv5 torch.hub.load() wrapper with smart error/issue handling
|
||||||
|
if check_version(torch.__version__, '1.9.1'):
|
||||||
|
kwargs['skip_validation'] = True # validation causes GitHub API rate limit errors
|
||||||
|
if check_version(torch.__version__, '1.12.0'):
|
||||||
|
kwargs['trust_repo'] = True # argument required starting in torch 0.12
|
||||||
|
try:
|
||||||
|
return torch.hub.load(repo, model, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
return torch.hub.load(repo, model, force_reload=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
|
||||||
|
# Resume training from a partially trained checkpoint
|
||||||
|
best_fitness = 0.0
|
||||||
|
start_epoch = ckpt['epoch'] + 1
|
||||||
|
if ckpt['optimizer'] is not None:
|
||||||
|
optimizer.load_state_dict(ckpt['optimizer']) # optimizer
|
||||||
|
best_fitness = ckpt['best_fitness']
|
||||||
|
if ema and ckpt.get('ema'):
|
||||||
|
ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
|
||||||
|
ema.updates = ckpt['updates']
|
||||||
|
if resume:
|
||||||
|
assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
|
||||||
|
f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
|
||||||
|
LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
|
||||||
|
if epochs < start_epoch:
|
||||||
|
LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
|
||||||
|
epochs += ckpt['epoch'] # finetune additional epochs
|
||||||
|
return best_fitness, start_epoch, epochs
|
||||||
|
|
||||||
|
|
||||||
class EarlyStopping:
|
class EarlyStopping:
|
||||||
# YOLOv5 simple early stopper
|
# YOLOv5 simple early stopper
|
||||||
def __init__(self, patience=30):
|
def __init__(self, patience=30):
|
||||||
@ -338,8 +407,6 @@ class ModelEMA:
|
|||||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||||
# Create EMA
|
# Create EMA
|
||||||
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
||||||
# if next(model.parameters()).device.type != 'cpu':
|
|
||||||
# self.ema.half() # FP16 EMA
|
|
||||||
self.updates = updates # number of EMA updates
|
self.updates = updates # number of EMA updates
|
||||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||||
for p in self.ema.parameters():
|
for p in self.ema.parameters():
|
||||||
@ -347,15 +414,15 @@ class ModelEMA:
|
|||||||
|
|
||||||
def update(self, model):
|
def update(self, model):
|
||||||
# Update EMA parameters
|
# Update EMA parameters
|
||||||
with torch.no_grad():
|
self.updates += 1
|
||||||
self.updates += 1
|
d = self.decay(self.updates)
|
||||||
d = self.decay(self.updates)
|
|
||||||
|
|
||||||
msd = de_parallel(model).state_dict() # model state_dict
|
msd = de_parallel(model).state_dict() # model state_dict
|
||||||
for k, v in self.ema.state_dict().items():
|
for k, v in self.ema.state_dict().items():
|
||||||
if v.dtype.is_floating_point:
|
if v.dtype.is_floating_point: # true for FP16 and FP32
|
||||||
v *= d
|
v *= d
|
||||||
v += (1 - d) * msd[k].detach()
|
v += (1 - d) * msd[k].detach()
|
||||||
|
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
|
||||||
|
|
||||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||||
# Update EMA attributes
|
# Update EMA attributes
|
||||||
|
Loading…
x
Reference in New Issue
Block a user