这篇文章主要介绍了python+opencv实现目标跟踪过程,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
python-opencv3.0新增了一些比较有用的追踪器算法
程序只能运行在安装有opencv3.0以上版本和对应的contrib模块的python解释器
- #encoding=utf-8
-
- import cv2
- from items import MessageItem
- import time
- import numpy as np
- '''
- 监视者模块,负责入侵检测,目标跟踪
- '''
- class WatchDog(object):
- #入侵检测者模块,用于入侵检测
- def __init__(self,frame=None):
- #运动检测器构造函数
- self._background = None
- if frame is not None:
- self._background = cv2.GaussianBlur(cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY),(21,21),0)
- self.es = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))
- def isWorking(self):
- #运动检测器是否工作
- return self._background is not None
- def startWorking(self,frame):
- #运动检测器开始工作
- if frame is not None:
- self._background = cv2.GaussianBlur(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY), (21, 21), 0)
- def stopWorking(self):
- #运动检测器结束工作
- self._background = None
- def analyze(self,frame):
- #运动检测
- if frame is None or self._background is None:
- return
- sample_frame = cv2.GaussianBlur(cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY),(21,21),0)
- diff = cv2.absdiff(self._background,sample_frame)
- diff = cv2.threshold(diff, 25, 255, cv2.THRESH_BINARY)[1]
- diff = cv2.dilate(diff, self.es, iterations=2)
- image, cnts, hierarchy = cv2.findContours(diff.copy(),cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
- coordinate = []
- bigC = None
- bigMulti = 0
- for c in cnts:
- if cv2.contourArea(c) < 1500:
- continue
- (x,y,w,h) = cv2.boundingRect(c)
- if w * h > bigMulti:
- bigMulti = w * h
- bigC = ((x,y),(x+w,y+h))
- if bigC:
- cv2.rectangle(frame, bigC[0],bigC[1], (255,0,0), 2, 1)
- coordinate.append(bigC)
- message = {"coord":coordinate}
- message['msg'] = None
- return MessageItem(frame,message)
-
- class Tracker(object):
- '''
- 追踪者模块,用于追踪指定目标
- '''
- def __init__(self,tracker_type = "BOOSTING",draw_coord = True):
- '''
- 初始化追踪器种类
- '''
- #获得opencv版本
- (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
- self.tracker_types = ['BOOSTING', 'MIL','KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
- self.tracker_type = tracker_type
- self.isWorking = False
- self.draw_coord = draw_coord
- #构造追踪器
- if int(minor_ver) < 3:
- self.tracker = cv2.Tracker_create(tracker_type)
- else:
- if tracker_type == 'BOOSTING':
- self.tracker = cv2.TrackerBoosting_create()
- if tracker_type == 'MIL':
- self.tracker = cv2.TrackerMIL_create()
- if tracker_type == 'KCF':
- self.tracker = cv2.TrackerKCF_create()
- if tracker_type == 'TLD':
- self.tracker = cv2.TrackerTLD_create()
- if tracker_type == 'MEDIANFLOW':
- self.tracker = cv2.TrackerMedianFlow_create()
- if tracker_type == 'GOTURN':
- self.tracker = cv2.TrackerGOTURN_create()
- def initWorking(self,frame,box):
- '''
- 追踪器工作初始化
- frame:初始化追踪画面
- box:追踪的区域
- '''
- if not self.tracker:
- raise Exception("追踪器未初始化")
- status = self.tracker.init(frame,box)
- if not status:
- raise Exception("追踪器工作初始化失败")
- self.coord = box
- self.isWorking = True
-
- def track(self,frame):
- '''
- 开启追踪
- '''
- message = None
- if self.isWorking:
- status,self.coord = self.tracker.update(frame)
- if status:
- message = {"coord":[((int(self.coord[0]), int(self.coord[1])),(int(self.coord[0] + self.coord[2]), int(self.coord[1] + self.coord[3])))]}
- if self.draw_coord:
- p1 = (int(self.coord[0]), int(self.coord[1]))
- p2 = (int(self.coord[0] + self.coord[2]), int(self.coord[1] + self.coord[3]))
- cv2.rectangle(frame, p1, p2, (255,0,0), 2, 1)
- message['msg'] = "is tracking"
- return MessageItem(frame,message)
-
- class ObjectTracker(object):
- def __init__(self,dataSet):
- self.cascade = cv2.CascadeClassifier(dataSet)
- def track(self,frame):
- gray = cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)
- faces = self.cascade.detectMultiScale(gray,1.03,5)
- for (x,y,w,h) in faces:
- cv2.rectangle(frame,(x,y),(x+w,y+h),(255,0,0),2)
- return frame
-
- if __name__ == '__main__' :
- a = ['BOOSTING', 'MIL','KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
- tracker = Tracker(tracker_type="KCF")
- video = cv2.VideoCapture(0)
- ok, frame = video.read()
- bbox = cv2.selectROI(frame, False)
- tracker.initWorking(frame,bbox)
- while True:
- _,frame = video.read();
- if(_):
- item = tracker.track(frame);
- cv2.imshow("track",item.getFrame())
- k = cv2.waitKey(1) & 0xff
- if k == 27:
- break
- #encoding=utf-8
- import json
- from utils import IOUtil
- '''
- 信息封装类
- '''
- class MessageItem(object):
- #用于封装信息的类,包含图片和其他信息
- def __init__(self,frame,message):
- self._frame = frame
- self._message = message
- def getFrame(self):
- #图片信息
- return self._frame
- def getMessage(self):
- #文字信息,json格式
- return self._message
- def getBase64Frame(self):
- #返回base64格式的图片,将BGR图像转化为RGB图像
- jepg = IOUtil.array_to_bytes(self._frame[...,::-1])
- return IOUtil.bytes_to_base64(jepg)
- def getBase64FrameByte(self):
- #返回base64格式图片的bytes
- return bytes(self.getBase64Frame())
- def getJson(self):
- #获得json数据格式
- dicdata = {"frame":self.getBase64Frame().decode(),"message":self.getMessage()}
- return json.dumps(dicdata)
- def getBinaryFrame(self):
- return IOUtil.array_to_bytes(self._frame[...,::-1])
运行之后在第一帧图像上选择要追踪的部分,这里测试了一下使用KCF算法的追踪器
更新:忘记放utils,给大家造成的困扰深表歉意
- #encoding=utf-8
- import time
- import numpy
- import base64
- import os
- import logging
- import sys
- from settings import *
- from PIL import Image
- from io import BytesIO
-
- #工具类
- class IOUtil(object):
- #流操作工具类
- @staticmethod
- def array_to_bytes(pic,formatter="jpeg",quality=70):
- '''
- 静态方法,将numpy数组转化二进制流
- :param pic: numpy数组
- :param format: 图片格式
- :param quality:压缩比,压缩比越高,产生的二进制数据越短
- :return:
- '''
- stream = BytesIO()
- picture = Image.fromarray(pic)
- picture.save(stream,format=formatter,quality=quality)
- jepg = stream.getvalue()
- stream.close()
- return jepg
- @staticmethod
- def bytes_to_base64(byte):
- '''
- 静态方法,bytes转base64编码
- :param byte:
- :return:
- '''
- return base64.b64encode(byte)
- @staticmethod
- def transport_rgb(frame):
- '''
- 将bgr图像转化为rgb图像,或者将rgb图像转化为bgr图像
- '''
- return frame[...,::-1]
- @staticmethod
- def byte_to_package(bytes,cmd,var=1):
- '''
- 将每一帧的图片流的二进制数据进行分包
- :param byte: 二进制文件
- :param cmd:命令
- :return:
- '''
- head = [ver,len(byte),cmd]
- headPack = struct.pack("!3I", *head)
- senddata = headPack+byte
- return senddata
- @staticmethod
- def mkdir(filePath):
- '''
- 创建文件夹
- '''
- if not os.path.exists(filePath):
- os.mkdir(filePath)
- @staticmethod
- def countCenter(box):
- '''
- 计算一个矩形的中心
- '''
- return (int(abs(box[0][0] - box[1][0])*0.5) + box[0][0],int(abs(box[0][1] - box[1][1])*0.5) +box[0][1])
- @staticmethod
- def countBox(center):
- '''
- 根据两个点计算出,x,y,c,r
- '''
- return (center[0][0],center[0][1],center[1][0]-center[0][0],center[1][1]-center[0][1])
- @staticmethod
- def getImageFileName():
- return time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())+'.png'
-
- #构造日志
- logger = logging.getLogger(LOG_NAME)
- formatter = logging.Formatter(LOG_FORMATTER)
- IOUtil.mkdir(LOG_DIR);
- file_handler = logging.FileHandler(LOG_DIR + LOG_FILE,encoding='utf-8')
- file_handler.setFormatter(formatter)
- console_handler = logging.StreamHandler(sys.stdout)
- console_handler.setFormatter(formatter)
- logger.addHandler(file_handler)
- logger.addHandler(console_handler)
- logger.setLevel(logging.INFO)
以上为个人经验,希望能给大家一个参考,也希望大家多多支持城东书院。