天天看點

keras_yolo3閱讀kmens.pytrain.py

源碼位址 https://github.com/qqwweee/keras-yolo3

春節期間仔細看了看yolov3的kears源碼,這個源碼畢竟不是作者寫的,有點寒酸,可能大道至簡也是這麼個理。我在看源碼的時候,參照了一些部落格進行補充,主要是,作者公布的代碼有點淩亂和我熟悉的代碼風格不同的緣故吧。。。。。

看到大神的優秀部落格,感覺自己的筆記有點炒冷飯的味道。。。?

1.目錄結構:

keras_yolo3閱讀kmens.pytrain.py

如下:這個就是直接從github上down下來的

.
├── coco_annotation.py
├── convert.py ├── darknet53.cfg ├── font │ ├── FiraMono-Medium.otf │ └── SIL Open Font License.txt ├── .gitignore ├── kmeans.py ├── LICENSE ├── model_data │ ├── coco_classes.txt │ ├── tiny_yolo_anchors.txt │ ├── voc_classes.txt │ └── yolo_anchors.txt ├── README.md ├── train_bottleneck.py ├── train.py ├── voc_annotation.py ├── yolo3 │ ├── __init__.py │ ├── model.py │ └── utils.py ├── yolo.py ├── yolov3.cfg ├── yolov3-tiny.cfg └── yolo_video.py                 
  1. font是字型目錄
  2. model_data:

    是各個資料庫對應的模型的檔案:

  • coco_classes檔案: 就是coco檔案的類别檔案

    如下:

    keras_yolo3閱讀kmens.pytrain.py
  • yolo_anchors檔案:就是yolo3所需要的anchors大小

    如下

    keras_yolo3閱讀kmens.pytrain.py
    這裡的兩檔案可以根據資料不同改變,改成你所需要的類别。而anchors可以通過k-means進行聚類直接獲得。
  1. yolo3:

    這裡有model.py和utils.py檔案。

  • model.py 就是建構yolo3的主要子產品檔案,這裡一共有14個函數/

    如下:

keras_yolo3閱讀kmens.pytrain.py
  • utils.py 是在模型訓練時進行資料處理的工具檔案,一共有3個函數:
keras_yolo3閱讀kmens.pytrain.py
  1. *_annoataion.py 對資料進行轉換的檔案,把原始的檔案轉換為txt檔案。
  • coco_annoataion.py 把json檔案轉換為txt檔案
  • voc_annoataion.py 把xml檔案轉換為txt
  1. convert.py 把原始權重轉換為kares的能讀取的原始h5檔案
  2. kmeans.py 輸入上面得到的txt檔案,通過聚類得到資料最佳anchors。
  3. train.py 進行yolov3訓練的檔案
  4. yolo.py 建構以yolov3為底層構件的yolo檢測模型,因為上面的yolov3還是分開的單個函數,功能并沒有融合在一起,即使在訓練的時候所有的yolov3元件還是分開的功能,并沒有統一接口,供在模型訓練完成之後,直接使用。通過yolo.py融合所有的元件。
  5. yolo_video.py 使用yolo.py檔案中的yolo檢測模型,并且對視訊中的物體進行檢測。
  6. yolov3.cfg 建構yolov3檢測模型的整個超參檔案。

在閱讀源碼的時候主要參考:

https://github.com/SpikeKing/keras-yolo3-detection的幾篇博文,但是為了更好了解keras-yolo3的代碼,這幾篇博文的對應檔案如下:

  • 探索 YOLO v3 源碼 - 第1篇 訓練---在train.py中
  • 探索 YOLO v3 源碼 - 第2篇 模型---在train.py中
  • 探索 YOLO v3 源碼 - 第3篇 網絡---在yolo3/model.py中
  • 探索 YOLO v3 源碼 - 第4篇 真值---在yolo3/utils.py和yolo3/model.py中
  • 探索 YOLO v3 源碼 - 第5篇 Loss---在yolo3/model.py中
  • 探索 YOLO v3 源碼 - 完結篇 預測---在yolo.py中

kmens.py

import numpy as np


class YOLO_Kmeans:

    def __init__(self, cluster_number, filename): # 讀取kmeans的中心數 self.cluster_number = cluster_number # 标簽檔案的檔案名 self.filename = "2012_train.txt" def iou(self, boxes, clusters): # 1 box -> k clusters # boxes : 所有的[width, height] # clusters : 9個随機的中心點[width, height] n = boxes.shape[                

k-means拿到資料裡所有的目标框,得到所有的寬和高,在這裡面随機取得9個随即中心,之後以9個點為中心得到9個族,不斷計算其他點到中點的距離調整每個點所歸屬的族和中心,直到9個中心不再變即可。這9個中心的x,y就是整個資料的9個合适的anchors==框的寬和高。

train.py

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2018. All rights reserved.
Created by C. L. Wang on 2018/7/4
"""
import os import numpy as np import tensorflow as tf import keras.backend as K from keras.backend import mean from keras.layers import Input, Lambda from keras.models import Model from keras.optimizers import Adam from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping from keras.utils import plot_model from yolo3.model import preprocess_true_boxes, yolo_body, tiny_yolo_body, yolo_loss from yolo3.utils import get_random_data def _main(): import os os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" from keras import backend as K config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) K.set_session(sess) annotation_path = 'dataset/WIDER_train.txt' # 資料 classes_path = 'configs/wider_classes.txt' # 類别 log_dir = 'logs/004/' # 日志檔案夾 # pretrained_path = 'model_data/yolo_weights.h5' # 預訓練模型 pretrained_path = 'logs/003/ep074-loss26.535-val_loss27.370.h5' # 預訓練模型 anchors_path = 'configs/yolo_anchors.txt' # anchors class_names = get_classes(classes_path) # 類别清單 num_classes = len(class_names) # 類别數 anchors = get_anchors(anchors_path) # anchors清單 input_shape = (                

轉載于:https://www.cnblogs.com/shuimuqingyang/p/10680389.html