Advertisement

python+opencv实现目标跟踪

阅读量:

OpenCV 3.0 版本新增了几种实用的追踪算法,本文在官方示例的基础上用 Python 编写了这个实例。

运行本程序需要安装了 OpenCV 3.0 及以上版本以及相应官方 contrib 包的 Python 解释器。

复制代码
 #encoding=utf-8

    
  
    
 import cv2
    
 from items import MessageItem
    
 import time
    
 import numpy as np
    
 '''
    
 监视者模块,负责入侵检测,目标跟踪
    
 '''
    
 class WatchDog(object):
     """Intrusion detector that compares each frame against a stored background.

     The background is a grayscale, Gaussian-blurred copy of a reference
     frame. The detector is considered "working" while a background is set.
     """

     def __init__(self, frame=None):
         """Create the detector.

         :param frame: optional BGR frame used as the initial background;
             when ``None`` the detector starts idle.
         """
         self._background = None
         if frame is not None:
             self._background = self._preprocess(frame)
         # Structuring element used to dilate the thresholded difference mask.
         self.es = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10))

     @staticmethod
     def _preprocess(frame):
         """Return the grayscale, 21x21 Gaussian-blurred version of a BGR frame."""
         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
         return cv2.GaussianBlur(gray, (21, 21), 0)

     def isWorking(self):
         """Return ``True`` while a background frame is set."""
         return self._background is not None

     def startWorking(self, frame):
         """Use ``frame`` as the new background; a ``None`` frame is ignored."""
         if frame is not None:
             self._background = self._preprocess(frame)

     def stopWorking(self):
         """Drop the background, which puts the detector back into the idle state."""
         self._background = None

     def analyze(self, frame):
         """Detect the largest changed region between ``frame`` and the background.

         The bounding box of the largest qualifying contour (contour area of
         at least 1500, largest bounding-box area wins) is drawn onto
         ``frame`` in place.

         :param frame: BGR frame to analyze.
         :return: ``None`` when ``frame`` is ``None`` or no background is set;
             otherwise a ``MessageItem`` wrapping ``frame`` and the dict
             ``{"coord": [box], "msg": None}`` where ``box`` is
             ``((x1, y1), (x2, y2))`` or ``None`` when nothing was detected.
         """
         if frame is None or self._background is None:
             return None

         sample_frame = self._preprocess(frame)
         diff = cv2.absdiff(self._background, sample_frame)
         diff = cv2.threshold(diff, 25, 255, cv2.THRESH_BINARY)[1]
         diff = cv2.dilate(diff, self.es, iterations=2)

         # findContours returns (image, contours, hierarchy) on OpenCV 3.x but
         # (contours, hierarchy) on 2.x/4.x; the contours are always the
         # second-to-last element, so index from the end.
         cnts = cv2.findContours(
             diff.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

         bigC = None
         bigMulti = 0
         for c in cnts:
             # Skip small contours.
             if cv2.contourArea(c) < 1500:
                 continue
             x, y, w, h = cv2.boundingRect(c)
             if w * h > bigMulti:
                 bigMulti = w * h
                 bigC = ((x, y), (x + w, y + h))

         if bigC is not None:
             cv2.rectangle(frame, bigC[0], bigC[1], (255, 0, 0), 2, 1)

         # The coordinate list always has exactly one entry (possibly None),
         # which keeps the message shape stable for consumers.
         message = {"coord": [bigC], "msg": None}
         return MessageItem(frame, message)
    
  
    
 class Tracker(object):
     """Single-target tracker wrapping one of OpenCV's tracking algorithms.

     Attributes:
         tracker_types: names of the supported algorithms.
         tracker_type: the algorithm name requested at construction.
         tracker: the underlying OpenCV tracker, or ``None`` when
             ``tracker_type`` is not a supported name.
         isWorking: ``True`` once ``initWorking`` has succeeded.
         draw_coord: whether ``track`` draws the tracked box on the frame.
     """

     # Maps each supported algorithm name to its OpenCV factory function name.
     _FACTORY_NAMES = {
         'BOOSTING': 'TrackerBoosting_create',
         'MIL': 'TrackerMIL_create',
         'KCF': 'TrackerKCF_create',
         'TLD': 'TrackerTLD_create',
         'MEDIANFLOW': 'TrackerMedianFlow_create',
         'GOTURN': 'TrackerGOTURN_create',
     }

     def __init__(self, tracker_type="BOOSTING", draw_coord=True):
         """Select and construct the tracking algorithm.

         :param tracker_type: one of ``tracker_types``.
         :param draw_coord: draw the tracked rectangle onto frames in ``track``.
         """
         # Only major/minor matter; slicing tolerates version strings that
         # carry more (or fewer) than three dot-separated components.
         version_parts = cv2.__version__.split('.')
         version = (int(version_parts[0]), int(version_parts[1]))

         self.tracker_types = ['BOOSTING', 'MIL', 'KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
         self.tracker_type = tracker_type
         self.isWorking = False
         self.draw_coord = draw_coord
         self.tracker = self._create_tracker(tracker_type, version)

     @staticmethod
     def _create_tracker(tracker_type, version):
         """Build the OpenCV tracker for ``tracker_type``; ``None`` if unsupported."""
         # The generic factory is used for releases before 3.3. Comparing
         # (major, minor) together keeps 4.0-4.2 off that path; checking the
         # minor number alone would send them to the old factory.
         if version < (3, 3):
             return cv2.Tracker_create(tracker_type)

         factory_name = Tracker._FACTORY_NAMES.get(tracker_type)
         if factory_name is None:
             # Leave the tracker unset so initWorking reports it explicitly.
             return None

         factory = getattr(cv2, factory_name, None)
         if factory is None:
             # Newer OpenCV 4.x builds expose some of these factories under
             # cv2.legacy; an AttributeError still surfaces if neither exists.
             factory = getattr(cv2.legacy, factory_name)
         return factory()

     def initWorking(self, frame, box):
         """Initialize tracking.

         :param frame: the frame on which tracking starts.
         :param box: the region to track.
         :raises Exception: if no tracker was constructed, or the tracker
             reports that initialization failed.
         """
         if not self.tracker:
             raise Exception("追踪器未初始化")

         status = self.tracker.init(frame, box)
         # Some bindings return a bool, others return None on success, so
         # only an explicit False is treated as failure.
         if status is False:
             raise Exception("追踪器工作初始化失败")

         self.coord = box
         self.isWorking = True

     def track(self, frame):
         """Update the tracker with ``frame``.

         :param frame: the next frame; the tracked box is drawn on it in
             place when ``draw_coord`` is set.
         :return: a ``MessageItem`` wrapping ``frame`` and either ``None``
             (not initialized, or the update failed) or the dict
             ``{"coord": [((x1, y1), (x2, y2))], "msg": "is tracking"}``.
         """
         message = None
         if self.isWorking:
             status, self.coord = self.tracker.update(frame)
             if status:
                 box = self.coord
                 p1 = (int(box[0]), int(box[1]))
                 p2 = (int(box[0] + box[2]), int(box[1] + box[3]))
                 # 'msg' is always present so the message has the same keys
                 # whether or not drawing is enabled.
                 message = {"coord": [(p1, p2)], "msg": "is tracking"}
                 if self.draw_coord:
                     cv2.rectangle(frame, p1, p2, (255, 0, 0), 2, 1)
         return MessageItem(frame, message)
    
  
    
 class ObjectTracker(object):
     """Detector that boxes every cascade-classifier hit on a frame."""

     def __init__(self, dataSet, scale_factor=1.03, min_neighbors=5):
         """Load the cascade classifier.

         :param dataSet: path passed to ``cv2.CascadeClassifier``.
         :param scale_factor: second argument of ``detectMultiScale``.
         :param min_neighbors: third argument of ``detectMultiScale``.
         """
         self.cascade = cv2.CascadeClassifier(dataSet)
         self.scale_factor = scale_factor
         self.min_neighbors = min_neighbors

     def track(self, frame):
         """Draw a rectangle around each detection and return the frame.

         :param frame: BGR frame, modified in place; ``None`` is returned
             unchanged so a failed capture does not crash the caller.
         :return: the same ``frame`` object.
         """
         if frame is None:
             return frame

         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
         faces = self.cascade.detectMultiScale(
             gray, self.scale_factor, self.min_neighbors)
         for (x, y, w, h) in faces:
             cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
         return frame
    
  
    
 if __name__ == '__main__':
     # Demo: pick a region on the first camera frame and follow it with KCF.
     # Press Esc in the preview window to quit.
     tracker = Tracker(tracker_type="KCF")
     video = cv2.VideoCapture(0)
     try:
         ok, frame = video.read()
         if not ok:
             # Without a first frame there is nothing to select a region on.
             raise RuntimeError("cannot read a frame from camera 0")

         bbox = cv2.selectROI(frame, False)
         tracker.initWorking(frame, bbox)

         while True:
             ok, frame = video.read()
             if not ok:
                 # Stop instead of spinning forever once the capture fails.
                 break
             item = tracker.track(frame)
             cv2.imshow("track", item.getFrame())
             k = cv2.waitKey(1) & 0xff
             if k == 27:
                 break
     finally:
         # Always free the camera and close the preview window.
         video.release()
         cv2.destroyAllWindows()
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/sqoEytn1IgHz48pamS2V5LOd3fZN.png)
复制代码
 #encoding=utf-8

    
 import json
    
 from utils import IOUtil
    
 '''
    
 信息封装类
    
 '''
    
 class MessageItem(object):
     """Bundle of an image frame and the message that accompanies it."""

     def __init__(self, frame, message):
         """Store ``frame`` and ``message`` without copying them."""
         self._message = message
         self._frame = frame

     def getFrame(self):
         """Return the stored frame."""
         return self._frame

     def getMessage(self):
         """Return the stored message object unchanged."""
         return self._message

     def getBinaryFrame(self):
         """Return the frame, last axis reversed, encoded by ``IOUtil.array_to_bytes``.

         The original comments describe the axis reversal as a BGR -> RGB
         conversion.
         """
         flipped = self._frame[..., ::-1]
         return IOUtil.array_to_bytes(flipped)

     def getBase64Frame(self):
         """Return the encoded frame as base64 via ``IOUtil.bytes_to_base64``."""
         return IOUtil.bytes_to_base64(self.getBinaryFrame())

     def getBase64FrameByte(self):
         """Return the base64-encoded frame wrapped in ``bytes``."""
         encoded = self.getBase64Frame()
         return bytes(encoded)

     def getJson(self):
         """Return a JSON string with keys ``frame`` (base64 text) and ``message``."""
         payload = {
             "frame": self.getBase64Frame().decode(),
             "message": self.getMessage(),
         }
         return json.dumps(payload)
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/wvFQl8GrBN5XZLsDAC6JWgVecSHa.png)

随后,在初始帧图像中选择了需要追踪的具体区域,并对KCF算法构建的追踪系统进行了初步测试。

更新:忘记放utils,给大家造成的困扰深表歉意

复制代码
 #encoding=utf-8

    
 import time
    
 import numpy
    
 import base64
    
 import os
    
 import logging
    
 import sys
    
 from settings import *
    
 from PIL import Image
    
 from io import BytesIO
    
  
    
 #工具类
    
 class IOUtil(object):
     """Static helpers for image encoding, packet framing and box geometry."""

     @staticmethod
     def array_to_bytes(pic, formatter="jpeg", quality=70):
         """Encode a numpy image array into an in-memory image file.

         :param pic: numpy array accepted by ``PIL.Image.fromarray``.
         :param formatter: image format name passed to ``Image.save``.
         :param quality: encoder quality passed to ``Image.save``; for JPEG a
             higher value gives better quality and therefore more bytes.
         :return: the encoded image as ``bytes``.
         """
         # The context manager closes the buffer even if encoding fails.
         with BytesIO() as stream:
             Image.fromarray(pic).save(stream, format=formatter, quality=quality)
             return stream.getvalue()

     @staticmethod
     def bytes_to_base64(byte):
         """Return ``byte`` encoded as base64 ``bytes``."""
         return base64.b64encode(byte)

     @staticmethod
     def transport_rgb(frame):
         """Reverse the last axis of ``frame`` (BGR -> RGB or RGB -> BGR)."""
         return frame[..., ::-1]

     @staticmethod
     def byte_to_package(bytes, cmd, var=1):
         """Prefix a binary payload with a packet header.

         The header is three unsigned 32-bit integers in network byte order:
         ``var``, the payload length, and ``cmd``.

         :param bytes: payload to send (the parameter name shadows the
             builtin; it is kept so existing keyword callers still work).
         :param cmd: command code stored in the header.
         :param var: first header field.
         :return: header followed by the payload.
         """
         # Imported here because the module's import list does not include it.
         import struct

         head_pack = struct.pack("!3I", var, len(bytes), cmd)
         return head_pack + bytes

     @staticmethod
     def mkdir(filePath):
         """Create directory ``filePath`` (with missing parents) if nothing exists there."""
         if not os.path.exists(filePath):
             # exist_ok covers another process creating it in between.
             os.makedirs(filePath, exist_ok=True)

     @staticmethod
     def countCenter(box):
         """Return the center of a rectangle given as two corner points.

         :param box: ``((x1, y1), (x2, y2))``; the corners may be given in
             either order.
         :return: ``(cx, cy)``, each the smaller coordinate plus half of the
             side length truncated to an int.
         """
         (x1, y1), (x2, y2) = box[0], box[1]
         # Offset from the smaller coordinate so that reversed corners still
         # yield a point inside the rectangle.
         center_x = int(abs(x1 - x2) * 0.5) + min(x1, x2)
         center_y = int(abs(y1 - y2) * 0.5) + min(y1, y2)
         return (center_x, center_y)

     @staticmethod
     def countBox(center):
         """Convert two corner points into ``(x, y, width, height)``.

         :param center: ``((x1, y1), (x2, y2))``; width and height are the
             second corner minus the first.
         """
         (x1, y1), (x2, y2) = center[0], center[1]
         return (x1, y1, x2 - x1, y2 - y1)

     @staticmethod
     def getImageFileName():
         """Return a ``.png`` file name built from the current local time."""
         return time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.png'
    
  
    
 # Build the module logger: the same formatter is applied to a UTF-8 file
 # handler under LOG_DIR and to a console handler writing to stdout.
 logger = logging.getLogger(LOG_NAME)
 formatter = logging.Formatter(LOG_FORMATTER)

 # The log directory must exist before the file handler opens its file.
 IOUtil.mkdir(LOG_DIR)
 file_handler = logging.FileHandler(LOG_DIR + LOG_FILE, encoding='utf-8')
 console_handler = logging.StreamHandler(sys.stdout)

 for _handler in (file_handler, console_handler):
     _handler.setFormatter(formatter)
     logger.addHandler(_handler)

 logger.setLevel(logging.INFO)
    
    
    
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/jVnEbiPqUdLrNMGgI9S6sHXe07aY.png)

全部评论 (0)

还没有任何评论哟~