天天看點

PyTorch 學習系列(3):兩種 TensorBoard 可視化方式(1. tensorboardX 可視化;2. 利用腳本實作 TensorBoard)

1.tensorboardX可視化

這種方式直接使用tensorboardX中的SummaryWriter來儲存變量。

from tensorboardX import SummaryWriter

# Example of logging with tensorboardX's SummaryWriter.
# Assumes `torch`, `net` (the module that defines NET) and `vutils`
# (presumably `import torchvision.utils as vutils`) are imported, and that
# log_dir / opt / channels / your_loss / epoch / image are defined elsewhere.

# log_dir is the directory the event files are written to.
writer = SummaryWriter(log_dir=log_dir)
train()
# Dummy input of shape (batch_size, channels, image_size, image_size); it is
# only used to trace the model so its structure can be visualized.
dummy_input = torch.rand(opt.batch_size, channels, opt.image_size, opt.image_size)
model = net.NET()
# Pass the model instance, not the `net` module that defines it: add_graph
# traces the graph by running the model on dummy_input.
writer.add_graph(model, (dummy_input,))
# Log a scalar, typically a loss or an accuracy value.
writer.add_scalar('train/loss', your_loss, epoch)
# Tile the intermediate images into a single grid image and log it.
x = vutils.make_grid(image, normalize=True, scale_each=True)
writer.add_image('image', x, epoch)
# Log a histogram of every parameter tensor (detached from autograd, on CPU).
for name, param in model.named_parameters():
    writer.add_histogram(name, param.detach().cpu().numpy(), epoch)
# Flush any pending events and release the event file.
writer.close()
           

2.利用腳本實作tensorboard

複製如下代碼到新檔案 logger.py 中,並將其放入項目目錄。注意:此腳本使用 TensorFlow 1.x 的 API(tf.summary.FileWriter、tf.Summary),在 TensorFlow 2.x 中無法直接使用。

import tensorflow as tf
import numpy as np
import scipy.misc
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x


class Logger(object):
    """Write TensorBoard scalar, image and histogram summaries to an event file.

    NOTE: relies on the TensorFlow 1.x API (``tf.summary.FileWriter``,
    ``tf.Summary``, ``tf.HistogramProto``). Under TensorFlow 2.x these live
    under ``tf.compat.v1`` and ``FileWriter`` is not usable in eager mode.
    """

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable.

        Args:
            tag: name of the series in TensorBoard.
            value: the scalar to record.
            step: step the value is associated with.
        """
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images, each under the tag ``'<tag>/<index>'``.

        Args:
            tag: base name for the images.
            images: iterable of arrays shaped (H, W), (H, W, 3) or (H, W, 4).
                uint8 arrays are written unchanged; any other dtype is
                min-max scaled to 0..255 per image.
            step: step the images are associated with.

        Raises:
            ValueError: if an image is empty or has an unsupported shape.
        """
        img_summaries = []
        for i, img in enumerate(images):
            img = np.asarray(img)
            # scipy.misc.toimage no longer exists (removed in SciPy 1.2), so
            # the PNG is produced by the numpy/stdlib encoder below.
            img_sum = tf.Summary.Image(encoded_image_string=self._encode_png(img),
                                       height=img.shape[0],
                                       width=img.shape[1])
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Write all images of this call as one Summary.
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values.

        Args:
            tag: name of the histogram in TensorBoard.
            values: array-like of real numbers, any shape (it is flattened).
            step: step the histogram is associated with.
            bins: ``bins`` argument forwarded to ``np.histogram``.

        Raises:
            ValueError: if ``values`` is empty, or (from ``np.histogram``)
                contains NaN/inf.
        """
        hist = tf.HistogramProto(**self._histogram_fields(values, bins))
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

    @staticmethod
    def _histogram_fields(values, bins):
        """Return the HistogramProto field values for *values* as a dict."""
        # float64 accumulation avoids overflow/precision loss in sum and
        # sum_squares for float16/float32/int inputs.
        flat = np.asarray(values, dtype=np.float64).ravel()
        if flat.size == 0:
            raise ValueError('histo_summary: cannot build a histogram of an empty array')
        counts, bin_edges = np.histogram(flat, bins=bins)
        return {
            'min': float(flat.min()),
            'max': float(flat.max()),
            'num': int(flat.size),
            'sum': float(flat.sum()),
            'sum_squares': float(np.dot(flat, flat)),
            # HistogramProto stores each bucket's upper limit, so the lower
            # edge of the first bin is dropped.
            'bucket_limit': bin_edges[1:].tolist(),
            'bucket': counts.astype(np.float64).tolist(),
        }

    @staticmethod
    def _to_uint8(img):
        """Convert *img* to uint8 the way scipy.misc.bytescale did.

        uint8 input is returned as-is. Anything else is linearly mapped so its
        minimum becomes 0 and its maximum 255; a constant image maps to 0.
        """
        arr = np.asarray(img)
        if arr.dtype == np.uint8:
            return arr
        lo = float(arr.min())
        span = float(arr.max()) - lo
        scale = 255.0 / (span if span else 1.0)
        return (np.clip((arr - lo) * scale, 0, 255) + 0.5).astype(np.uint8)

    @classmethod
    def _encode_png(cls, img):
        """Encode *img* as an 8-bit PNG using only numpy and the standard library.

        Args:
            img: array shaped (H, W) (grayscale), (H, W, 3) (RGB) or
                (H, W, 4) (RGBA).

        Returns:
            The PNG file contents as ``bytes``.

        Raises:
            ValueError: if the image is empty or has any other shape.
        """
        import struct
        import zlib

        arr = np.asarray(img)
        if arr.size == 0:
            raise ValueError('image_summary: cannot encode an empty image')
        pixels = cls._to_uint8(arr)
        if pixels.ndim == 2:
            color_type = 0  # grayscale
        elif pixels.ndim == 3 and pixels.shape[2] in (3, 4):
            color_type = 2 if pixels.shape[2] == 3 else 6  # RGB / RGBA
        else:
            raise ValueError('image_summary: expected an (H, W), (H, W, 3) or (H, W, 4) '
                             'image, got shape %r' % (pixels.shape,))
        height, width = pixels.shape[:2]
        # Each scanline is prefixed with a filter-type byte; 0 means "no filter".
        raw = b''.join(b'\x00' + row.tobytes() for row in pixels.reshape(height, -1))

        def chunk(kind, data):
            # PNG chunk = length, type, data, CRC-32 over type + data.
            body = kind + data
            crc = zlib.crc32(body) & 0xFFFFFFFF
            return struct.pack('>I', len(data)) + body + struct.pack('>I', crc)

        # IHDR: width, height, bit depth 8, colour type, then default
        # compression / filter / interlace methods.
        header = struct.pack('>IIBBBBB', width, height, 8, color_type, 0, 0, 0)
        return (b'\x89PNG\r\n\x1a\n'
                + chunk(b'IHDR', header)
                + chunk(b'IDAT', zlib.compress(raw))
                + chunk(b'IEND', b''))
           

在自己的train程式中導入logger.py

from logger import Logger

# Create the logger; log_dir is the directory the event files are written to.
logger = Logger(log_dir=log_dir)
           

在訓練過程中以如下方式記錄

if idx % 20 == 0:
    # NOTE(review): if idx restarts every epoch, pass a global step instead,
    # otherwise points from different epochs land on the same step.

    # 1. Log scalar values (scalar summary).
    info = {'loss': mse.item()}  # the loss value

    for tag, value in info.items():
        logger.scalar_summary(tag, value, idx)

    # 2. Log values and gradients of the parameters (histogram summary).
    for tag, value in net.named_parameters():
        tag = tag.replace('.', '/')
        logger.histo_summary(tag, value.detach().cpu().numpy(), idx)
        # .grad stays None for parameters that received no gradient (e.g.
        # frozen or unused layers); skip those instead of crashing.
        if value.grad is not None:
            logger.histo_summary(tag + '/grad', value.grad.detach().cpu().numpy(), idx)

    # 3. Log training images (image summary).
    # The example model does image compression, hence both the reconstruction
    # ('recon') and the original input ('orign'); [:5] keeps the first five images.
    info = {
        'recon': output.view(-1, opt.image_size, opt.image_size)[:5].cpu().detach().numpy(),
        'orign': input.view(-1, opt.image_size, opt.image_size)[:5].cpu().detach().numpy(),
    }

    for tag, images in info.items():
        logger.image_summary(tag, images, idx)
           

我認為還是第一種方式更簡單明了:只需安裝 tensorboardX,不必另外依賴 TensorFlow。

繼續閱讀