
Object detection with TensorFlow + OpenCV


I. Preparation

Dataset: COCO

Models: three models are commonly used for object detection: SSD, Faster R-CNN, and YOLO.

The training step is skipped here; ready-made pretrained models are available for download: [GitHub link]

Environment: TensorFlow 1.14.0, OpenCV 4.1.1

II. Detection

1. List the class names

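Save the names below as coco/classes.txt (the path the detection script reads), one name per line. The empty lines are deliberate: the models return COCO category IDs from 1 to 90, of which only 80 are used, so line N of the file must hold the name for ID N, and the unused IDs (12, 26, 29, 30, 45, 66, 68, 69, 71, 83) are left blank.

```text
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant

stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe

backpack
umbrella


handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle

wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
potted plant
bed

dining table


toilet

tv monitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator

book
clock
vase
scissors
teddy bear
hair drier
toothbrush
```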
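Because the script looks names up as class_names[classId - 1], the position of every line in classes.txt matters. Below is a minimal sketch of that lookup, assuming the file layout above. `load_class_names` and `class_name_for` are hypothetical helper names, and the sample only reproduces the first 13 lines of the file.

```python
import io


def load_class_names(fp):
    """Read one class name per line; empty placeholder lines are kept as ''."""
    return fp.read().splitlines()


def class_name_for(class_names, class_id):
    """Map a 1-based COCO category ID to its name ('' for unused IDs)."""
    if not 1 <= class_id <= len(class_names):
        raise ValueError(f"class id {class_id} out of range")
    return class_names[class_id - 1]


# Usage: the first 13 lines of classes.txt, including the blank line for ID 12.
# The final line deliberately has no trailing newline.
sample = io.StringIO(
    "person\nbicycle\ncar\nmotorbike\naeroplane\nbus\ntrain\ntruck\nboat\n"
    "traffic light\nfire hydrant\n\nstop sign"
)
names = load_class_names(sample)
print(class_name_for(names, 1))         # person
print(class_name_for(names, 13))        # stop sign
print(repr(class_name_for(names, 12)))  # ''
```

Using `splitlines()` instead of `readlines()` followed by `[:-1]` avoids chopping the last character off a final line that has no trailing newline ("stop sign" would otherwise become "stop sig").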

2. Download a model

For example, download faster_rcnn_inception_v2_coco_2018_01_28 and extract it into the current directory.
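The extraction step can also be scripted with the standard library. This is a minimal sketch. It assumes a `.tar.gz` archive that unpacks into a folder named after the archive, with `frozen_inference_graph.pb` (the file the detection script loads) inside that folder. The helper `extract_model` is hypothetical. Only extract archives from sources you trust.

```python
import os
import tarfile


def extract_model(archive_path, dest_dir="."):
    """Unpack a model archive and return the path to its frozen graph.

    Assumption: the archive unpacks into a folder named after the archive
    (e.g. faster_rcnn_inception_v2_coco_2018_01_28/) that contains
    frozen_inference_graph.pb.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest_dir)

    # Derive the folder name by stripping the ".tar.gz" suffix.
    model_name = os.path.basename(archive_path)
    if model_name.endswith(".tar.gz"):
        model_name = model_name[: -len(".tar.gz")]

    pb_path = os.path.join(dest_dir, model_name, "frozen_inference_graph.pb")
    if not os.path.isfile(pb_path):
        raise FileNotFoundError(pb_path)
    return pb_path


# Usage (assumes the archive has already been downloaded to the current directory):
# frozen_pb_file = extract_model("faster_rcnn_inception_v2_coco_2018_01_28.tar.gz")
```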

Prepare an image to run detection on (the script below reads pic/class.jpg).

```python
import os

import cv2
import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.GraphDef)

# Load the COCO-trained model
model_path = "faster_rcnn_inception_v2_coco_2018_01_28"
frozen_pb_file = os.path.join(model_path, 'frozen_inference_graph.pb')

# To use your own exported model instead, point these at it:
# model_path = ""
# frozen_pb_file = os.path.join(model_path, 'model.pb')

# Load the COCO class names. splitlines() keeps the empty placeholder lines
# and does not cut off the last character of a final line without "\n".
with open("coco/classes.txt", "r") as f:
    class_names = f.read().splitlines()

score_threshold = 0.3

img_file = 'pic/class.jpg'

# Read the graph.
with tf.gfile.FastGFile(frozen_pb_file, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Import the frozen graph into the default graph used by this session
    tf.import_graph_def(graph_def, name='')

    # Uncomment to list the graph's operations (e.g. to find tensor names)
    # for op in sess.graph.get_operations():
    #     print(op)

    # Read and preprocess an image.
    img_cv2 = cv2.imread(img_file)
    if img_cv2 is None:
        raise FileNotFoundError(img_file)
    img_height, img_width, _ = img_cv2.shape

    # Boxes come back normalized to [0, 1], so they can be scaled back to
    # the original image size even though the input is resized here
    img_in = cv2.resize(img_cv2, (300, 300))
    img_in = img_in[:, :, [2, 1, 0]]  # BGR2RGB

    # Run the model
    outputs = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                        sess.graph.get_tensor_by_name('detection_scores:0'),
                        sess.graph.get_tensor_by_name('detection_boxes:0'),
                        sess.graph.get_tensor_by_name('detection_classes:0')],
                       feed_dict={
                           'image_tensor:0': img_in.reshape(1,
                                                            img_in.shape[0],
                                                            img_in.shape[1],
                                                            3)})

    # Visualize detected bounding boxes.
    num_detections = int(outputs[0][0])
    for i in range(num_detections):
        classId = int(outputs[3][0][i])
        score = float(outputs[1][0][i])
        # Each box is [ymin, xmin, ymax, xmax], normalized to [0, 1]
        bbox = [float(v) for v in outputs[2][0][i]]
        if score > score_threshold:
            x = bbox[1] * img_width
            y = bbox[0] * img_height
            right = bbox[3] * img_width
            bottom = bbox[2] * img_height
            # Draw the box
            cv2.rectangle(img_cv2,
                          (int(x), int(y)),
                          (int(right), int(bottom)),
                          (125, 255, 51),
                          thickness=3)
            # Label text "class_name,score"
            class_name = class_names[classId - 1]
            cv2.putText(img_cv2,
                        class_name + "," + "%.2f" % score,
                        (int(x), int(y)),
                        cv2.FONT_HERSHEY_DUPLEX, 3, (0, 0, 255), 3)
            print(str(classId) + ",class:" + class_name + ",score:%.2f" % score)

plt.figure(figsize=(10, 8))
plt.imshow(img_cv2[:, :, ::-1])  # BGR -> RGB for matplotlib
plt.title("TensorFlow Faster R-CNN Inception v2")
plt.axis("off")
plt.show()
```
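The box-scaling and thresholding in the loop above can be pulled out into a pure function and checked without loading a model. This is a minimal sketch. It assumes the same output layout the script relies on: a batch of size 1, boxes normalized to [0, 1] in [ymin, xmin, ymax, xmax] order, and float class IDs. The function name `to_pixel_detections` is hypothetical.

```python
import numpy as np


def to_pixel_detections(num_detections, scores, boxes, classes,
                        img_width, img_height, score_threshold=0.3):
    """Convert detection outputs (batch of 1) to pixel-space detections.

    Returns a list of (class_id, score, (x, y, right, bottom)) tuples for
    detections whose score is above score_threshold.
    """
    results = []
    for i in range(int(num_detections[0])):
        score = float(scores[0][i])
        if score <= score_threshold:
            continue
        ymin, xmin, ymax, xmax = (float(v) for v in boxes[0][i])
        box = (int(xmin * img_width), int(ymin * img_height),
               int(xmax * img_width), int(ymax * img_height))
        results.append((int(classes[0][i]), score, box))
    return results


# Usage with hand-made arrays shaped like the sess.run outputs
dets = to_pixel_detections(
    num_detections=np.array([2.0]),
    scores=np.array([[0.9, 0.2]]),
    boxes=np.array([[[0.25, 0.25, 0.5, 0.75], [0.0, 0.0, 1.0, 1.0]]]),
    classes=np.array([[1.0, 3.0]]),
    img_width=640, img_height=480,
)
print(dets)  # prints [(1, 0.9, (160, 120, 480, 240))]
```

Keeping the geometry separate from the cv2.rectangle and cv2.putText calls makes it easy to test. It can also be reused unchanged with any model that exposes the same four output tensors the script fetches.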

III. Future work

Fine-tune the models on newly added datasets, with cross-task pre-training (for the SSD or Faster R-CNN models), so they can be applied to specific domains.

