在训练wide&deep这个模型时报了错误,下面是错误详情:
ValueError: Tensor("num_parallel_calls:0", shape=(), dtype=int32, device=/device:CPU:0)
must be from the same graph as Tensor("FlatMapDataset_1:0", shape=(), dtype=variant).
查了很多方法,大部分都是tf.Graph()类型的解决方案,并不适合我代码,因为代码中并没有引入session部分。
直到我看到代码中有:
def input_fn(data_file,num_epochs,shuffle,batch_size):
"""
input function for the Estimator
:param dataset:
:param num_epochs:
:param shuffle:
:param batch_size:
:return:
"""
def parse_csv(value):
columns=tf.decode_csv(value,record_defaults=_CSV_COLUMN_DEFAULTS)
features=dict(zip(_CSV_COLUMNS_NAME,columns))
labels=features.pop('target')
return features,labels
dataset=tf.data.TextLineDataset(data_file)
if shuffle:
dataset=dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])
def tf_read_file(file):
assert tf.gfile.Exists(file),print('{} is not found'.format(file))
dataset=tf.data.TextLineDataset(file) #每一个元素对应一行
return dataset
train=tf_read_file(args.train_data)
model.train(input_fn=lambda :input_fn(train,args.epoch_per_eval,False,args.batch_size))
我才明白可能是train数据放入input_fn()方法时有问题,尝试了一下果然如此。
解决方法:
- 在input_fn()参数中将训练数据路径填进去,在该方法内读取数据,不要直接传数据过来