Debugging in TensorFlow
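The official guide, Guide – Debugging – TensorFlow Debugger, already covers tfdbg in detail, but it leaves out a few points that are worth filling in.
- The section "Debugging Model Training with tfdbg" mentions that you can pick out the tensors you care about with a custom filter: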
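from tensorflow.python import debug as tf_debug

def my_filter_callable(datum, tensor):
    # A filter that detects zero-valued scalars.
    return len(tensor.shape) == 0 and tensor == 0.0

# add_tensor_filter exists only on the wrapped session, so wrap sess first.
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
sess.add_tensor_filter('my_filter', my_filter_callable)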
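Because tfdbg hands the filter the dumped tensor value as a NumPy array (together with a `datum` carrying metadata such as the node name), the callable can be checked with plain NumPy before it is attached to any session. The sketch below passes `None` for `datum`, which these filters never read. `safe_zero_scalar_filter` is a hypothetical defensive variant; it assumes that any value which is not a NumPy array or scalar (for example a placeholder for a tensor that could not be converted) should simply not match.

```python
import numpy as np


def my_filter_callable(datum, tensor):
    # Same logic as the filter above: match zero-valued scalars.
    return len(tensor.shape) == 0 and tensor == 0.0


def safe_zero_scalar_filter(datum, tensor):
    # Hypothetical defensive variant: non-NumPy values are treated as "no match".
    if not isinstance(tensor, (np.ndarray, np.generic)):
        return False
    return tensor.ndim == 0 and bool(tensor == 0.0)


# datum is unused by these filters, so None is enough for a quick check.
print(bool(my_filter_callable(None, np.array(0.0))))     # True: zero scalar
print(bool(my_filter_callable(None, np.array(1.5))))     # False: non-zero scalar
print(bool(my_filter_callable(None, np.zeros((2, 2)))))  # False: not a scalar
print(safe_zero_scalar_filter(None, object()))           # False: not an array
```

The `bool(...)` calls are there because `tensor == 0.0` yields a NumPy boolean rather than a Python `bool`; tfdbg only needs a truthy result, but plain booleans are easier to compare in tests.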
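This filter takes effect because it is registered on the Session. An Estimator, however, manages its Session internally, so there is no way to add a filter explicitly like this. For Estimators, the guide only shows how to attach a debug hook: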
# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py"
# (You don't need to worry about the BUILD dependency if you are using a pip
# install of open-source TensorFlow.)
from tensorflow.python import debug as tf_debug
# Create a LocalCLIDebugHook and pass it as a hook when calling train().
hooks = [tf_debug.LocalCLIDebugHook()]
# To debug `train`:
classifier.train(input_fn, steps=1000, hooks=hooks)
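For now, the correct approach appears to be calling the add_tensor_filter method on the LocalCLIDebugHook instance itself.
Once sess is wrapped in tf_debug.LocalCLIDebugWrapperSession, the wrapper gains an add_tensor_filter method. Its source is: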
def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    Args:
      filter_name: (`str`) name of the filter.
      tensor_filter: (`callable`) the filter callable. See the doc string of
        `DebugDumpDir.find()` for more details about its signature.
    """
    self._tensor_filters[filter_name] = tensor_filter
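The method above just stores the callable in a dict keyed by its name, and the docstring points to `DebugDumpDir.find()`, which applies a predicate to every dumped `(datum, tensor)` pair. The pure-Python sketch below models that bookkeeping so the consequences are easy to see. `ToyFilterRegistry`, its `find()` method and the `dumped` dict are hypothetical stand-ins, not tfdbg classes, and using the tensor name string as `datum` is an assumption made only for this sketch.

```python
import numpy as np


class ToyFilterRegistry:
    """Hypothetical stand-in for the wrapper's filter bookkeeping.

    Filters live in a dict keyed by name, so registering a second filter
    under an existing name silently replaces the first one.
    """

    def __init__(self):
        self._tensor_filters = {}

    def add_tensor_filter(self, filter_name, tensor_filter):
        self._tensor_filters[filter_name] = tensor_filter

    def find(self, filter_name, dumped):
        # Apply the named predicate to each (datum, tensor) pair and keep the
        # names it accepts. Here datum is just the tensor name (an assumption).
        predicate = self._tensor_filters[filter_name]
        return [name for name, value in dumped.items() if predicate(name, value)]


def is_zero_scalar(datum, tensor):
    return len(tensor.shape) == 0 and tensor == 0.0


def has_negative(datum, tensor):
    return bool(np.any(tensor < 0))


# Made-up dumped values, keyed by tensor name.
dumped = {
    "loss:0": np.array(0.0),
    "weights:0": np.array([[0.5, -1.0], [2.0, 3.0]]),
    "bias:0": np.array([0.1, 0.2]),
}

registry = ToyFilterRegistry()
registry.add_tensor_filter("my_filter", is_zero_scalar)
print(registry.find("my_filter", dumped))  # ['loss:0']

# Re-using the name replaces the earlier filter.
registry.add_tensor_filter("my_filter", has_negative)
print(registry.find("my_filter", dumped))  # ['weights:0']
```

The practical takeaway is that filter names act as keys: give each custom filter a distinct name unless you intend to replace one.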
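The add_tensor_filter method of LocalCLIDebugHook takes the same arguments and forwards to the method above. If the hook has not yet created its underlying LocalCLIDebugWrapperSession, it keeps the filter as pending instead: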
def add_tensor_filter(self, filter_name, tensor_filter):
    """Add a tensor filter.

    See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    Override default behavior to accommodate the possibility of this method being
    called prior to the initialization of the underlying
    `LocalCLIDebugWrapperSession` object.

    Args:
      filter_name: See doc of `LocalCLIDebugWrapperSession.add_tensor_filter()`
        for details.
      tensor_filter: See doc of
        `LocalCLIDebugWrapperSession.add_tensor_filter()` for details.
    """
    if self._session_wrapper:
        self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
    else:
        self._pending_tensor_filters[filter_name] = tensor_filter
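That else branch is what makes the hook usable with an Estimator: you can create the hook, call `hook.add_tensor_filter('my_filter', my_filter_callable)`, and then run `classifier.train(input_fn, steps=1000, hooks=[hook])`, even though no session exists at registration time. Once the CLI starts, the filter should be selectable with `run -f my_filter`, like the built-in `has_inf_or_nan`. The pure-Python sketch below models only the deferral. `ToyWrapperSession`, `ToyDebugHook` and `create_wrapper()` are hypothetical names, and the sketch assumes (the excerpt above does not show it) that the real hook replays `_pending_tensor_filters` into the wrapper right after creating it.

```python
class ToyWrapperSession:
    """Hypothetical stand-in for LocalCLIDebugWrapperSession's filter storage."""

    def __init__(self):
        self._tensor_filters = {}

    def add_tensor_filter(self, filter_name, tensor_filter):
        self._tensor_filters[filter_name] = tensor_filter


class ToyDebugHook:
    """Hypothetical stand-in for LocalCLIDebugHook's deferred registration."""

    def __init__(self):
        self._session_wrapper = None
        self._pending_tensor_filters = {}

    def add_tensor_filter(self, filter_name, tensor_filter):
        # Same branching as the excerpt above.
        if self._session_wrapper:
            self._session_wrapper.add_tensor_filter(filter_name, tensor_filter)
        else:
            self._pending_tensor_filters[filter_name] = tensor_filter

    def create_wrapper(self):
        # Hypothetical: once the wrapper exists, replay the pending filters.
        self._session_wrapper = ToyWrapperSession()
        for name, tensor_filter in self._pending_tensor_filters.items():
            self._session_wrapper.add_tensor_filter(name, tensor_filter)
        self._pending_tensor_filters.clear()


def my_filter_callable(datum, tensor):
    return len(tensor.shape) == 0 and tensor == 0.0


hook = ToyDebugHook()
hook.add_tensor_filter("my_filter", my_filter_callable)  # no session yet: kept pending
print(list(hook._pending_tensor_filters))                # ['my_filter']

hook.create_wrapper()                                     # training starts
print(list(hook._session_wrapper._tensor_filters))       # ['my_filter']
```

Clearing the pending dict after replaying it is a choice made in this sketch, so that later registrations clearly go straight to the wrapper.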