Object Detection with TensorFlow + OpenCV
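I. Preparation
Dataset: COCO
Models: three commonly used object detection models are SSD, Faster R-CNN, and YOLO
To skip the training step, pretrained models are available for download: [GitHub link]
Environment: TensorFlow 1.14.0, OpenCV 4.1.1
II. Detection
1. List the class names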
person
bicycle
car
motorbike
aeroplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
sofa
potted plant
bed
dining table
toilet
tv monitor
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
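One caveat when pairing this list with a model: the COCO models from the TensorFlow Object Detection API report the original COCO category ids, which run from 1 to 90 with ten ids unused, whereas the list above has 80 consecutive lines. Indexing the list directly with `id - 1` would therefore only be correct for ids 1 to 11. The sketch below builds an id-to-name lookup instead. The names `UNUSED_COCO_IDS` and `build_id_to_name` are illustrative and not part of any library, and the list is assumed to be in COCO order, as the one above is.

```python
# Ids that exist in the original COCO numbering (1-90) but are not used
# by the 80-class detection label set.
UNUSED_COCO_IDS = frozenset({12, 26, 29, 30, 45, 66, 68, 69, 71, 83})


def build_id_to_name(class_names):
    """Map original COCO category ids to names from an 80-entry list.

    class_names is assumed to hold the 80 names in COCO order,
    for example the stripped lines of the class list above.
    """
    used_ids = [i for i in range(1, 91) if i not in UNUSED_COCO_IDS]
    if len(class_names) != len(used_ids):
        raise ValueError("expected %d class names, got %d"
                         % (len(used_ids), len(class_names)))
    return dict(zip(used_ids, class_names))
```

With the list above, id 13 then resolves to "stop sign" (line 12 of the list), and id 90 resolves to "toothbrush" (line 80).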
2. Download the model
For example, download faster_rcnn_inception_v2_coco_2018_01_28 and extract it into the current directory.
Prepare the image to run detection on.
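For the extraction step, the standard library is enough if the model arrives as a `.tar.gz` archive. This is a minimal sketch under two assumptions that are not stated in the original: the archive is named after the model, and it unpacks to a folder of the same name containing `frozen_inference_graph.pb`. The helper name `extract_model` is hypothetical.

```python
import os
import tarfile


def extract_model(archive_path, dest_dir="."):
    """Extract a model archive and return the path to its frozen graph.

    Assumes the archive unpacks to
    <dest_dir>/<model_name>/frozen_inference_graph.pb, where <model_name>
    is the archive file name without the ".tar.gz" suffix.
    Only extract archives from a source you trust.
    """
    model_name = os.path.basename(archive_path)
    if model_name.endswith(".tar.gz"):
        model_name = model_name[:-len(".tar.gz")]
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest_dir)
    pb_path = os.path.join(dest_dir, model_name, "frozen_inference_graph.pb")
    if not os.path.isfile(pb_path):
        raise FileNotFoundError("frozen graph not found at " + pb_path)
    return pb_path
```

A call such as `extract_model("faster_rcnn_inception_v2_coco_2018_01_28.tar.gz")` would then return the path of the frozen graph to load, provided the archive follows the assumed layout.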
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (the environment above uses 1.14.0)

# Load the model pretrained on the COCO dataset
model_path = "faster_rcnn_inception_v2_coco_2018_01_28"
frozen_pb_file = os.path.join(model_path, 'frozen_inference_graph.pb')
# model_path = ""
# frozen_pb_file = os.path.join(model_path, 'model.pb')

# Load the COCO class names, one per line
with open("coco/classes.txt", "r") as f:
    class_names = [line.strip() for line in f if line.strip()]

# The model reports original COCO category ids (1-90, ten of them unused),
# so map each used id to the matching line of the 80-name list
unused_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
used_ids = [i for i in range(1, 91) if i not in unused_ids]
id_to_name = dict(zip(used_ids, class_names))

score_threshold = 0.3
img_file = 'pic/class.jpg'

# Read the frozen graph
with tf.gfile.GFile(frozen_pb_file, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Import the graph definition into the session's (default) graph
    tf.import_graph_def(graph_def, name='')
    # for op in sess.graph.get_operations():
    #     print(op)

    # Read and preprocess the image
    img_cv2 = cv2.imread(img_file)
    if img_cv2 is None:
        raise FileNotFoundError("Cannot read image: " + img_file)
    img_height, img_width, _ = img_cv2.shape
    img_in = cv2.resize(img_cv2, (300, 300))
    img_in = img_in[:, :, [2, 1, 0]]  # BGR -> RGB

    # Run the model on a batch containing the single image
    outputs = sess.run(
        [sess.graph.get_tensor_by_name('num_detections:0'),
         sess.graph.get_tensor_by_name('detection_scores:0'),
         sess.graph.get_tensor_by_name('detection_boxes:0'),
         sess.graph.get_tensor_by_name('detection_classes:0')],
        feed_dict={'image_tensor:0': img_in[np.newaxis, ...]})

# Visualize the detected bounding boxes
num_detections = int(outputs[0][0])
for i in range(num_detections):
    classId = int(outputs[3][0][i])
    score = float(outputs[1][0][i])
    # Each box is [ymin, xmin, ymax, xmax], normalized to the range 0..1
    bbox = [float(v) for v in outputs[2][0][i]]
    if score > score_threshold:
        x = bbox[1] * img_width
        y = bbox[0] * img_height
        right = bbox[3] * img_width
        bottom = bbox[2] * img_height
        class_name = id_to_name.get(classId, str(classId))
        # Draw the box
        cv2.rectangle(img_cv2,
                      (int(x), int(y)),
                      (int(right), int(bottom)),
                      (125, 255, 51),
                      thickness=3)
        # Draw the label "class_name,score"
        cv2.putText(img_cv2,
                    class_name + "," + "%.2f" % score,
                    (int(x), int(y)),
                    cv2.FONT_HERSHEY_DUPLEX, 3, (0, 0, 255), 3)
        print(str(classId) + ",class:" + class_name + ",score:%.2f" % score)

plt.figure(figsize=(10, 8))
plt.imshow(img_cv2[:, :, ::-1])  # BGR -> RGB for matplotlib
plt.title("TensorFlow Faster R-CNN Inception v2 (COCO)")
plt.axis("off")
plt.show()
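The drawing step above depends on two conventions: `detection_boxes` stores each box as `[ymin, xmin, ymax, xmax]` normalized to the range 0 to 1, and only detections scoring above the threshold are kept. The standalone sketch below isolates that step so it can be checked without a model. The function name `to_pixel_boxes` and the sample values are made up for illustration.

```python
def to_pixel_boxes(boxes, scores, img_width, img_height, score_threshold=0.3):
    """Convert normalized [ymin, xmin, ymax, xmax] boxes to pixel corners.

    Returns a list of (left, top, right, bottom, score) tuples for the
    detections whose score exceeds score_threshold.
    """
    results = []
    for (ymin, xmin, ymax, xmax), score in zip(boxes, scores):
        if score > score_threshold:
            results.append((int(xmin * img_width), int(ymin * img_height),
                            int(xmax * img_width), int(ymax * img_height),
                            score))
    return results


# Made-up detections for a 640x480 image
boxes = [[0.25, 0.5, 0.75, 1.0], [0.0, 0.0, 0.5, 0.5]]
scores = [0.9, 0.1]
print(to_pixel_boxes(boxes, scores, img_width=640, img_height=480))
# prints [(320, 120, 640, 360, 0.9)]
```

Note that the x coordinates scale with the image width and the y coordinates with the image height, and that the scaling uses the size of the original image, not the 300x300 input fed to the network.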

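III. Future Work
Fine-tune the pretrained model (SSD or Faster R-CNN) on a newly collected dataset, that is, apply transfer learning from the pretrained weights, so that it can be used in a specific domain.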
