天天看点

Tensorflow训练模型报错:must be from the same graph as Tensor

在训练wide&deep这个模型时报了错误,下面是错误详情:

ValueError: Tensor("num_parallel_calls:0", shape=(), dtype=int32, device=/device:CPU:0) 
must be from the same graph as Tensor("FlatMapDataset_1:0", shape=(), dtype=variant).
           

查了很多方法,大部分都是tf.Graph()类型的解决方案,并不适合我代码,因为代码中并没有引入session部分。

直到我看到代码中有:

def input_fn(data_file,num_epochs,shuffle,batch_size):
    """
    input function for the Estimator
    :param dataset:
    :param num_epochs:
    :param shuffle:
    :param batch_size:
    :return:
    """
    def parse_csv(value):

        columns=tf.decode_csv(value,record_defaults=_CSV_COLUMN_DEFAULTS)
        features=dict(zip(_CSV_COLUMNS_NAME,columns))
        labels=features.pop('target')

        return features,labels

    dataset=tf.data.TextLineDataset(data_file)
    if shuffle:
        dataset=dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

def tf_read_file(file):
    assert tf.gfile.Exists(file),print('{} is not found'.format(file))
    dataset=tf.data.TextLineDataset(file)   #每一个元素对应一行
    return dataset

train=tf_read_file(args.train_data)

model.train(input_fn=lambda :input_fn(train,args.epoch_per_eval,False,args.batch_size))
           

我才明白可能是train数据放入input_fn()方法时有问题,尝试了一下果然如此。

解决方法:

  • 在input_fn()参数中将训练数据路径填进去,在该方法内读取数据,不要直接传数据过来

继续阅读