# 天天看點
#
# 垃圾分類(加入資料增強和通道機制)
# Garbage classification: a small Keras CNN trained with image data augmentation.

import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
from keras.layers import Conv2D,Flatten,MaxPooling2D,Dense
from keras.models import Sequential,load_model
import glob,os,random
import time
import keras

# Root directory of the image dataset. Images are read one level deep,
# from one subdirectory per class (see the "*/*.jpg" glob and the
# flow_from_directory calls below).
base_path = "datasets"

def look_dataset_num():
    """Print how many ``.jpg`` images are in the dataset and plot a random few.

    Images are looked up one directory level below ``base_path``
    (``base_path/<class>/*.jpg``). Up to six randomly chosen images are
    drawn on a 2x3 grid of subplots.
    """
    img_list = glob.glob(os.path.join(base_path, "*/*.jpg"))
    print(len(img_list))
    if not img_list:
        # Nothing to plot; random.sample would also raise on an empty list.
        return

    # random.sample requires k <= population size, so cap at what was found.
    sample = random.sample(img_list, min(6, len(img_list)))
    for i, img_path in enumerate(sample):
        img = img_to_array(load_img(img_path), dtype=np.uint8)

        # Subplot positions are 1-based.
        plt.subplot(2, 3, i + 1)
        plt.imshow(img.squeeze())
    plt.show()

def crate_model(epochs=100, batch_size=16, model_path='rubbish/rubbish_model.h5'):
    """Train a small CNN on the images under ``base_path``, save it, and preview predictions.

    The class subdirectories of ``base_path`` are read with Keras
    ``flow_from_directory`` using a 90% / 10% training / validation split.
    Training images are augmented; validation images are only rescaled.
    The model is a stack of four conv + max-pool stages, a 64-unit dense
    layer and a softmax over the classes found. After training, the model
    is saved and one validation batch is plotted with predicted vs. true
    class names. The wall-clock run time is printed at the end.

    The function keeps its original name ("crate") so existing callers work.

    Args:
        epochs: Number of training epochs.
        batch_size: Images per batch for both the training and validation
            generators.
        model_path: File the trained model is saved to. Missing parent
            directories are created.
    """
    start = time.time()

    # Only the training data is augmented. Both generators must use the same
    # 1/255 scaling; the earlier 1/225 made training inputs range up to ~1.13
    # while validation/inference inputs stayed in [0, 1].
    train_datagen = ImageDataGenerator(
        rescale=1. / 255, shear_range=0.1, zoom_range=0.1,
        width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True,
        vertical_flip=True, validation_split=0.1)
    test_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.1)

    train_generator = train_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=batch_size,
        class_mode='categorical', subset='training', seed=0)
    validation_generator = test_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=batch_size,
        class_mode='categorical', subset='validation', seed=0)

    # class_indices maps class name -> index; invert it to name predictions.
    labels = {v: k for k, v in train_generator.class_indices.items()}
    validation_labels = {v: k for k, v in validation_generator.class_indices.items()}
    print('validation labels', validation_labels)
    print('labels', labels)

    # 4. Build and train the model.
    model = Sequential([
        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(300, 300, 3)),
        MaxPooling2D(pool_size=2),

        Conv2D(filters=64, kernel_size=3, padding='same', activation='relu', kernel_regularizer=keras.regularizers.l2(0.001)),
        MaxPooling2D(pool_size=2),

        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', kernel_regularizer=keras.regularizers.l2(0.001)),
        MaxPooling2D(pool_size=2),

        Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', kernel_regularizer=keras.regularizers.l2(0.001)),
        MaxPooling2D(pool_size=2),

        Flatten(),

        Dense(64, activation='relu'),

        # One output per class directory instead of a hard-coded 6.
        Dense(len(labels), activation='softmax')
    ])

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])

    # One full pass over each subset per epoch. The old hard-coded
    # 2276 // 32 and 251 // 32 assumed 32-image batches, so with 16-image
    # batches only about half of each subset was used every epoch.
    model.fit_generator(
        train_generator, epochs=epochs,
        steps_per_epoch=len(train_generator),
        validation_data=validation_generator,
        validation_steps=len(validation_generator))

    save_dir = os.path.dirname(model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    model.save(model_path)

    # 5. Show results: plot up to 16 images of one validation batch, each
    # titled with the predicted and the true class name.
    # Use the second batch as before when it exists, otherwise the first.
    batch_index = min(1, len(validation_generator) - 1)
    test_x, test_y = validation_generator[batch_index]

    preds = model.predict(test_x)

    # The last batch may hold fewer images than the 4x4 grid.
    n_show = min(16, len(test_x))
    plt.figure(figsize=(16, 16))
    for i in range(n_show):
        plt.subplot(4, 4, i + 1)
        plt.title('pred:%s / truth:%s' % (labels[np.argmax(preds[i])], labels[np.argmax(test_y[i])]))
        plt.imshow(test_x[i])

    plt.show()

    end = time.time()
    t = end - start
    print('運作time', t)

def use_model(model_path='rubbish/rubbish_model.h5', batch_size=36):
    """Load the saved model and plot its predictions for one validation batch.

    Validation images are read from ``base_path`` with the same 1/255
    scaling and 10% validation split that ``flow_from_directory`` used for
    validation during training. Each plotted image is titled with the
    predicted and the true class name.

    Args:
        model_path: Path of the saved Keras model to load.
        batch_size: Images per validation batch; at most 36 of them are
            plotted on a 6x6 grid.
    """
    test_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.1)
    validation_generator = test_datagen.flow_from_directory(
        base_path, target_size=(300, 300), batch_size=batch_size,
        class_mode='categorical', subset='validation', seed=0)

    # class_indices maps class name -> index; invert it to name predictions.
    # Both subsets read the same class subdirectories, so building an
    # augmenting training generator only for its class_indices is unneeded.
    labels = {v: k for k, v in validation_generator.class_indices.items()}

    model = load_model(model_path)

    # Second batch as before when it exists, otherwise the first.
    batch_index = min(1, len(validation_generator) - 1)
    test_x, test_y = validation_generator[batch_index]
    # Printing every pixel floods the console; the batch shape is enough.
    print('test_x shape', test_x.shape)
    preds = model.predict(test_x)

    # The last batch may hold fewer images than the 6x6 grid.
    n_show = min(36, len(test_x))
    plt.figure(figsize=(36, 36))
    for i in range(n_show):
        plt.subplot(6, 6, i + 1)
        plt.title('pred:%s / truth:%s' % (labels[np.argmax(preds[i])], labels[np.argmax(test_y[i])]))
        plt.imshow(test_x[i])

    plt.show()

if __name__ == '__main__':
    import sys

    # Pick the step to run from the command line instead of commenting
    # calls in and out:
    #   look  - print the image count and plot a few samples
    #   train - train, save and preview the model (crate_model)
    #   use   - load the saved model and preview predictions (default,
    #           matching the previous behavior)
    actions = {'look': look_dataset_num, 'train': crate_model, 'use': use_model}
    action = sys.argv[1] if len(sys.argv) > 1 else 'use'
    if action not in actions:
        sys.exit('usage: %s [%s]' % (sys.argv[0], '|'.join(actions)))
    actions[action]()