# Source code for dltk.io.abstract_reader
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import tensorflow as tf
import traceback
class IteratorInitializerHook(tf.train.SessionRunHook):
    """Session hook that runs a data-iterator initializer after session creation.

    The initializer callable is not known when the hook is built; the input
    function stores it in ``iterator_initializer_func`` while it constructs
    the dataset graph.
    """

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        # Callable taking the created session; assigned later by the input fn.
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Run the stored initializer against the newly created ``session``.

        Args:
            session: The session that was just created.
            coord: The coordinator supplied by the hook API (unused here).
        """
        initialize = self.iterator_initializer_func
        initialize(session)
class Reader(object):
    """Wrapper for dataset generation given a read function."""

    def __init__(self, read_fn, dtypes):
        """Constructs a Reader instance.

        Args:
            read_fn: Callable invoked as ``read_fn(file_references, mode,
                params)``. It must return an iterable (typically a generator)
                of example dictionaries, each holding a ``'features'`` entry
                and optionally a ``'labels'`` entry. Reading ends when the
                iterable is exhausted (``StopIteration``) or when it raises
                ``tf.errors.OutOfRangeError``.
            dtypes: A nested structure of tf.DType objects corresponding to
                each component of an example yielded by ``read_fn``. Example
                entries whose keys are absent from ``dtypes`` are dropped.
        """
        self.dtypes = dtypes
        self.read_fn = read_fn

    @staticmethod
    def _clean_example(ex, compare):
        """Recursively delete entries of ``ex`` whose keys are not in ``compare``.

        ``ex`` is modified in place and also returned. Nested dictionaries are
        cleaned recursively; dictionaries nested inside lists are NOT
        inspected.

        Args:
            ex: Example dictionary to clean.
            compare: Reference structure (e.g. the reader's ``dtypes``).

        Returns:
            dict: The cleaned ``ex``.

        Raises:
            ValueError: If an entry is a dict on only one side, a list on only
                one side, paired lists differ in length, or a key of
                ``compare`` is missing from ``ex``.
        """
        # Iterate over a copy of the keys because entries are deleted below.
        for k in list(ex):
            if k not in compare:
                del ex[k]
                continue
            ex_is_dict = isinstance(ex[k], dict)
            cmp_is_dict = isinstance(compare[k], dict)
            ex_is_list = isinstance(ex[k], list)
            cmp_is_list = isinstance(compare[k], list)
            if ex_is_dict and cmp_is_dict:
                Reader._clean_example(ex[k], compare[k])
            elif ex_is_dict != cmp_is_dict:
                raise ValueError('Entries between example and '
                                 'dtypes incompatible for key {}'
                                 ''.format(k))
            elif ex_is_list != cmp_is_list or \
                    (ex_is_list and len(ex[k]) != len(compare[k])):
                raise ValueError('Entries between example and '
                                 'dtypes incompatible for key {}'
                                 ''.format(k))
        for k in compare:
            if k not in ex:
                raise ValueError('Key {} not found in ex but is '
                                 'present in dtypes. Found keys: '
                                 '{}'.format(k, ex.keys()))
        return ex

    @staticmethod
    def _report_error(error):
        """Print ``error`` and the traceback currently being handled.

        Must be called from inside an ``except`` block, before re-raising.
        """
        print('got error `{}` from `read_fn`:'.format(error))
        print(traceback.format_exc())

    def _generate_examples(self, file_references, mode, params):
        """Yield cleaned example dictionaries produced by ``read_fn``.

        Args:
            file_references: Passed through to ``read_fn``.
            mode: Passed through to ``read_fn``.
            params: Passed through to ``read_fn``.

        Yields:
            dict: Each example with ``'labels'`` defaulted to ``None`` when
            absent, then cleaned against ``self.dtypes``.

        Raises:
            ValueError: If ``read_fn`` yields a non-dict or an example is
                incompatible with ``self.dtypes``. Any error is printed with
                its traceback before being re-raised.
        """
        examples = iter(self.read_fn(file_references, mode, params))
        while True:
            try:
                ex = next(examples)
            except StopIteration:
                # PEP 479: re-raising StopIteration inside a generator turns
                # into a RuntimeError on Python 3.7+, so finish explicitly.
                return
            except tf.errors.OutOfRangeError:
                # read_fn signalled end-of-input the TensorFlow way.
                return
            except Exception as e:
                self._report_error(e)
                raise
            try:
                # Validate the type BEFORE touching dict methods on it.
                if not isinstance(ex, dict):
                    raise ValueError('The read_fn has to return '
                                     'dictionaries')
                # Missing (or None) labels are normalised to an explicit None.
                ex.setdefault('labels', None)
                ex = self._clean_example(ex, self.dtypes)
            except Exception as e:
                self._report_error(e)
                raise
            yield ex

    def get_inputs(self,
                   file_references,
                   mode,
                   example_shapes=None,
                   shuffle_cache_size=100,
                   batch_size=4,
                   params=None):
        """Function to provide the input_fn for a tf.Estimator.

        Args:
            file_references: An array like structure that holds the reference
                to the file to read. It can also be None if not needed.
            mode: A tf.estimator.ModeKeys. It is passed on to `read_fn` to
                trigger specific functions there.
            example_shapes (optional): A nested structure of lists or tuples
                corresponding to the shape of each component of an element
                yielded by the read_fn.
            shuffle_cache_size (int, optional): An `int` determining the
                number of examples that are held in the shuffle queue.
            batch_size (int, optional): An `int` specifying the number of
                examples returned in a batch.
            params (dict, optional): A `dict` passed on to the `read_fn`.

        Returns:
            function: a handle to the `input_fn` to be passed the relevant
                tf estimator functions. It returns a
                ``(features, labels)`` tuple; ``labels`` is ``None`` when the
                dataset elements have no ``'labels'`` entry.
            tf.train.SessionRunHook: A hook to initialize the iterator within
                the dataset once the session has been created.
        """
        iterator_initializer_hook = IteratorInitializerHook()

        def train_inputs():
            def generator():
                # Called by tf.data each time the dataset is (re)iterated.
                return self._generate_examples(file_references, mode, params)

            dataset = tf.data.Dataset.from_generator(
                generator, self.dtypes, example_shapes)
            dataset = dataset.repeat(None)
            dataset = dataset.shuffle(shuffle_cache_size)
            dataset = dataset.batch(batch_size)
            dataset = dataset.prefetch(1)

            iterator = dataset.make_initializable_iterator()
            next_dict = iterator.get_next()

            # The iterator only exists now, so hand its initializer to the
            # hook, which runs it after the session is created.
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(iterator.initializer)

            return next_dict['features'], next_dict.get('labels')

        return train_inputs, iterator_initializer_hook

    def serving_input_receiver_fn(self, placeholder_shapes):
        """Build the serving inputs.

        Args:
            placeholder_shapes: A nested structure of lists or tuples
                corresponding to the shape of each component of the feature
                elements yielded by the read_fn. A leading ``None`` batch
                dimension is prepended to each shape.

        Returns:
            function: A function to be passed to the tf.estimator.Estimator
                instance when exporting a saved model with
                estimator.export_savedmodel.
        """
        def f():
            inputs = {
                k: tf.placeholder(
                    shape=[None] + list(placeholder_shapes['features'][k]),
                    dtype=dtype)
                for k, dtype in self.dtypes['features'].items()}
            return tf.estimator.export.ServingInputReceiver(inputs, inputs)

        return f