TensorFlow Tutorial : Debug

Debug in TensorFlow

官方介绍Guide – Debugging – TensorFlow Debugger中已经介绍的很详细了, 但是有几个疏漏的点需要补充一下.

  1. 使用 tfdbg 调试模型训练 章节中提到了可以通过自定义filter进行相关张量的筛选.
def my_filter_callable(datum, tensor):
  # A filter that detects zero-valued scalars.
  return len(tensor.shape) == 0 and tensor == 0.0

sess.add_tensor_filter('my_filter', my_filter_callable)


该filter是通过添加到Session中发挥作用的. 而Estimator则是在内部管理Session, 无法通过这种方式显式的增加filter.

# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py"
# (You don't need to worry about the BUILD dependency if you are using a pip
#  install of open-source TensorFlow.)
from tensorflow.python import debug as tf_debug

# Create a LocalCLIDebugHook and use it as a monitor when calling fit().
hooks = [tf_debug.LocalCLIDebugHook()]

# To debug `train`:
classifier.train(input_fn, steps=1000,  hooks=hooks)

目前看起来正确的方式应该是通过LocalCLIDebugHook实例的add_tensor_filter方法进行添加.

sess通过tf_debug.LocalCLIDebugWrapperSession包裹之后, 其新添了add_tensor_filter方法, 源码是

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    Args:
      filter_name: (`str`) name of the filter.
      tensor_filter: (`callable`) the filter callable. See the doc string of
        `DebugDumpDir.find()` for more details about its signature.
    """

    self._tensor_filters[filter_name] = tensor_filter

而LocalCLIDebugHook中的add_tensor_filter方法和上一种是相同的.

  def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    Override default behavior to accommodate the possibility of this method being
    called prior to the initialization of the underlying
    `LocalCLIDebugWrapperSession` object.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """

    if self._session_wrapper:
      self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
    else:
      self._pending_tensor_filters[filter_name] = tensor_filter

发表评论

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据