本文共 678 字,大约阅读时间需要 2 分钟。
tensorflow mini batch 训练中线程和队列数据输入的问题
实际学习和使用tensorflow的时候,面临大数据量训练的场景,几乎很少使用Session.run中的feed_dict来批量导入数据。tensorflow利用多线程和队列方法异步实现大批量数据的输入,大大节省了数据输入引起的资源浪费。大致的流程如下:
- 先创建一个“先入先出”的队列(FIFOQueue)
- 创建enqueue方法和dequeue对象: enqueue_operation = queue.enqueue inputs = queue.dequeue
- 利用inputs对象作为神经网络图的输入构建神经网络
- 创建cord=tf.train.Coordinator()管理多线程同步
- 启动队列工作tf.train.start_queue_runners(sess=sess, coord=cord)
- 启动线程写入队列start_threads(sess=sess, coord=coord, n_threads=8)
- session.run(),启动训练,这时候是不用通过feed_dict来喂数据的,最关键的理解是第3步,我曾经在运行代码的时候一直不理解,因为enqueue是可以很容易设置断点监控到了,但dequeue方法却没看到,所以痛苦了一段时间。最后单步跟,看到input的提示是这样一个tensor(Tensor(“fifo_queue_Dequeue:0”, dtype=float32, device=/device:GPU:0)),算是想通了这一点。
转载地址:http://cuyrf.baihongyu.com/