Tensorflow教程:输入管道中的重复混洗

时间:2015-12-19 14:36:25

标签: python tensorflow

Tensorflow reading data tutorial中给出了一个示例输入管道。在该管道中,数据在string_input_producer内以及shuffle batch generator内混洗两次。这是代码:

def input_pipeline(filenames, batch_size, num_epochs=None):
  # Fist shuffle in the input pipeline
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)

  example, label = read_my_file_format(filename_queue)
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  # Second shuffle as part of the batching. 
  # Requiring min_after_dequeue preloaded images
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)

  return example_batch, label_batch

第二次洗牌是否有用?混洗批处理生成器的缺点是min_after_dequeue示例总是预先存储在存储器中以允许有用的混洗。我的图像数据确实很大,内存消耗很大。这就是我考虑使用normal batch generator的原因。将数据混洗两次有什么好处吗?

编辑:附加问题,为什么string_input_producer仅在默认容量为32时初始化?将batch_size的倍数作为容量不是有利的吗?

3 个答案:

答案 0 :(得分:6)

是的 - 这是一种常见的模式,它以最一般的方式显示。 string_input_producer随机播放数据文件的读取顺序。为了提高效率,每个数据文件通常包含许多示例。 (读取一百万个小文件非常慢;最好每个1000个例子读取1000个大文件。)

因此,文件中的示例被读入一个混洗队列,在那里它们以更精细的粒度进行混洗,因此来自同一文件的示例并不总是以相同的顺序进行训练,并且在整个过程中进行混合。输入文件。

有关详细信息,请参阅Getting good mixing with many input datafiles in tensorflow

如果你的文件每个只包含一个输入示例,那么你不需要多次洗牌,只能使用string_input_producer,但请注意,你仍然可以从拥有一个阅读后几张图片,以便您可以重叠网络的输入和培训。 queue_runnerbatch的{​​{1}}将在单独的主题中运行,确保I / O在后台运行,并且图像始终可用于培训。而且,当然,创建微型飞机进行训练的速度通常很不错。

答案 1 :(得分:0)

两种洗牌都有不同的用途,可以改变不同的东西:

  • tf.train.string_input_producer shuffle:Boolean。如果为true,则在每个纪元内随机改组字符串。。因此,如果您有一些文件['file1', 'file2', ..., 'filen'],则会从此列表中随机选择一个文件。如果是false,则文件会一个接一个地跟着。
  • tf.train.shuffle_batch 通过随机调整张量来创建批量。因此,您的队列batch_size需要read_my_file_format个张量并对其进行随机播放。

因为两个shuffle做不同的事情,所以将数据洗牌两次是有利的。即使您使用一批256个图像,并且每个图像是256x256像素,您将消耗少于100 Mb的内存。如果在某些时候您会看到内存问题,可以尝试减小批量大小。

关于默认容量 - 它是model specific。让它比batch_size更大是有意义的,并确保它在训练期间永远不会是空的。

答案 2 :(得分:0)

要回答其他问题, RecyclerView lstMovsWallet = (RecyclerView) findViewById(R.id.lstMovsWallet); lstMovsWallet.setLayoutManager(new LinearLayoutManager(MovsMobileWallet.this)); AdapterCobrosPendientesListado adapter = new AdapterCobrosPendientesListado(MovsMobileWallet.this, items); lstMovsWallet.setAdapter(adapter); Adapter for de RecyclerView : public class AdapterCobrosPendientesListado extends RecyclerView.Adapter<AdapterCobrosPendientesListado.ViewHolder> { private LayoutInflater mInflater; protected List<MovimientoCuenta> items; public AdapterCobrosPendientesListado(Context context, List<MovimientoCuenta> data) { this.mInflater = LayoutInflater.from(context); this.items = data; } @Override public AdapterCobrosPendientesListado.ViewHolder onCreateViewHolder(ViewGroup parent, int viewType) { View view = mInflater.inflate(R.layout.activity_adapter_billings_listhistory, parent, false); ViewHolder viewHolder = new ViewHolder(view); return viewHolder; } @Override public void onBindViewHolder(AdapterCobrosPendientesListado.ViewHolder holder, int position) { DecimalFormat formater = new DecimalFormat("###.00"); String numero = items.get(position).getNumber(); String cantidad = items.get(position).getMonto(); String fecha = items.get(position).getFecha(); String referencia = items.get(position).getReferencia(); String debitoCredito = items.get(position).getDebitoCredito(); holder.number.setText(numero); holder.mount.setText(cantidad); holder.date.setText(fecha); holder.ref.setText(referencia); if(debitoCredito.compareTo("DBT")==0){ holder.title.setText("Pago"); holder.auxBilling.setImageResource(R.mipmap.signonegativo); } else { holder.title.setText("Cobro"); holder.auxBilling.setImageResource(R.mipmap.signomas); } } @Override public int getItemCount() { return items.size(); } public class ViewHolder extends RecyclerView.ViewHolder implements View.OnClickListener { public TextView number; public TextView mount; public TextView date; public ImageView auxBilling; public TextView ref; public TextView title ; public ViewHolder(View itemView) { super(itemView); number = itemView.findViewById(R.id.txtNumberPhoneBilling); mount = itemView.findViewById(R.id.txtMountBillingNotifications); date = itemView.findViewById(R.id.txtDateBillingNotifications); auxBilling = itemView.findViewById(R.id.btnCancelBillingNotifications); ref = itemView.findViewById(R.id.txtDateBillingRef); title = itemView.findViewById(R.id.TitleMovs); itemView.setOnClickListener(this); } @Override public void onClick(View view) { // if (mClickListener != null) mClickListener.onItemClick(view, getAdapterPosition()); } } /* // convenience method for getting data at click position public String getItem(int id) { return mData.get(id); } // allows clicks events to be caught public void setClickListener(ItemClickListener itemClickListener) { this.mClickListener = itemClickListener; } // parent activity will implement this method to respond to click events public interface ItemClickListener { void onItemClick(View view, int position); }*/ } 会返回一个包含文件名的队列,其中包含样本,而不是样本本身。然后public class MovimientoCuenta { private String number; private String monto; private String moneda; private String fecha; private String ID; private String referencia ; private String filtro ; private String debitoCredito ; private String nombreMov; public MovimientoCuenta(String number, String monto, String moneda, String fecha, String ID, String referencia, String filtro, String debitoCredito,String nombreMov) { this.number = number; this.monto = monto; this.moneda = moneda; this.fecha = fecha; this.ID = ID ; this.filtro =filtro; this.referencia=referencia; this.debitoCredito =debitoCredito; this.nombreMov =nombreMov; } 使用此文件名加载数据。因此,加载的样本数与 string_input_producer 函数的shuffle_batch参数相关,而不是capacity