TensorFlow/Command Line Args
From charlesreid1
Passing Command Line Arguments
Why
First, why would you want to pass arguments into a TensorFlow model from the command line?
- While it's possible to use a notebook to train models, eventually the models will probably go into production
- Training models in Cloud ML Engine requires passing arguments via the command line
- Parameterizing models with command line arguments makes them easier to run from other scripts and to use in hyperparameter optimization
How
There are several ways to pass arguments into a TensorFlow model:
- argparse - this is the most straightforward way; parse the arguments using argparse.ArgumentParser() and convert the result into a dictionary. That dictionary can then be passed into the model run/train function call using the double-splat operator (**)
- gflags - this is the Google command line argument parsing library; more full-featured than argparse
- tf.app.flags - built-in and convenient, but mainly used for keeping demos pithy; mimics gflags behavior
- tf.app.run - combines argument parser with app runner, passes arguments on to main()
- tf.app.run and tf.app.flags in combination - parses arguments with tf.app.flags, and passes un-parsed arguments to app.run (and thus on to main())
- sys.argv - get the raw command line arguments as a list (already split on whitespace) and parse them manually
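The argparse approach from the list above can be sketched as follows. This is a minimal illustration, not code from any TensorFlow example: the flag names and the train function are hypothetical stand-ins, and an explicit argument list is used in place of the real command line so the result is predictable.

```python
import argparse


def train(learning_rate, batch_size, num_epochs):
    # Hypothetical stand-in for a model's run/train function.
    return "lr={} batch={} epochs={}".format(learning_rate, batch_size, num_epochs)


def parse_args(argv=None):
    # With argv=None, argparse reads the real command line (sys.argv[1:]).
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--num_epochs', type=int, default=1)
    args = parser.parse_args(argv)
    # Convert the Namespace to a plain dict so it can be unpacked with **.
    return vars(args)


# Explicit list for illustration; call parse_args() to read the command line.
arguments = parse_args(['--learning_rate', '0.1', '--num_epochs', '5'])

# The double-splat unpacks the dict into keyword arguments.
print(train(**arguments))
```

Because every key in the dictionary must match a parameter of train, the flag names and the function signature have to be kept in sync; a catch-all **kwargs parameter relaxes that requirement.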
Cloud ML Engine
Using command line arguments is particularly important when bundling a TensorFlow model for Cloud ML Engine.
For an example of how to do this, see the GCP training-data-analyst repository, the taxicab fare prediction model:
- https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/cloudmle/taxifare/trainer/task.py - contains the functions that run the "task" (training or prediction)
- task.py parses the input arguments using an argparse.ArgumentParser() object
- each argument parsed is added to a dictionary, and the arguments in the dictionary are passed into the training call using the double-splat operator, which unpacks them
Here's the relevant code in the task.py file:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Input Arguments
    parser.add_argument(
        '--train_data_paths',
        help='GCS or local path to training data',
        required=True
    )
    parser.add_argument(
        '--num_epochs',
        help="""\
        Maximum number of training data epochs on which to train.
        If both --max-steps and --num-epochs are specified,
        the training job will run for --max-steps or --num-epochs,
        whichever occurs first. If unspecified will run for --max-steps.\
        """,
        type=int,
    )
Several other arguments are defined in the same way. Next, the parsed arguments are converted to a dict:
    args = parser.parse_args()
    arguments = args.__dict__
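To make the conversion step concrete, here is a small sketch with a made-up argument list (the path is illustrative and not taken from task.py). It assumes only the two arguments shown earlier. vars(args) returns the same dictionary as args.__dict__, and an optional argument that was not supplied still appears as a key, with the value None.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_data_paths', required=True)
parser.add_argument('--num_epochs', type=int)

# Explicit list instead of the real command line, for illustration only.
args = parser.parse_args(['--train_data_paths', 'data/train.csv'])

# vars(args) and args.__dict__ are the same dictionary object.
arguments = vars(args)
print(arguments)
```

Since the omitted --num_epochs is present in the dictionary as None, unpacking with ** passes num_epochs=None explicitly to the receiving function.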
Finally, these are passed to the model's training run using the double-splat:
    # Run the training job
    learn_runner.run(generate_experiment_fn(**arguments), output_dir)
generate_experiment_fn receives these arguments and forwards them to the input functions, the estimator, and the Experiment object that it creates:
def generate_experiment_fn(train_data_paths,
                           eval_data_paths,
                           format,
                           num_epochs=None,
                           train_batch_size=512,
                           eval_batch_size=512,
                           hidden_units=None,
                           **experiment_args):
    def _experiment_fn(output_dir):
        input_fn = model.generate_csv_input_fn
        train_input = input_fn(
            train_data_paths, num_epochs=num_epochs, batch_size=train_batch_size)
        eval_input = input_fn(
            eval_data_paths, batch_size=eval_batch_size, mode=tf.contrib.learn.ModeKeys.EVAL)
        return Experiment(
            model.build_estimator(
                output_dir,
                hidden_units=hidden_units
            ),
            train_input_fn=train_input,
            eval_input_fn=eval_input,
            export_strategies=[saved_model_export_utils.make_export_strategy(
                model.serving_input_fn,
                default_output_alternative_key=None,
                exports_to_keep=1
            )],
            eval_metrics=model.get_eval_metrics(),
            #min_eval_frequency = 1000, # change this to speed up training on large datasets
            **experiment_args
        )
    return _experiment_fn
Note that this example is fairly complicated, as are many of the TensorFlow examples, because of the many nested function calls and the many parameters passed from one function call to the next.
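The forwarding pattern itself can be reduced to a few lines. The sketch below uses hypothetical function and argument names and no TensorFlow at all: named parameters pick out the arguments a function needs, and a ** catch-all collects the rest so they can be passed down to the next call.

```python
def build_experiment(hidden_units=None, **experiment_args):
    # Innermost call: receives whatever the outer function did not consume.
    return {'hidden_units': hidden_units, 'extra': experiment_args}


def generate_fn(train_data_paths, num_epochs=None, hidden_units=None,
                **experiment_args):
    # Named parameters are consumed here; the remainder is forwarded.
    def _fn(output_dir):
        return build_experiment(hidden_units=hidden_units,
                                output_dir=output_dir,
                                **experiment_args)
    return _fn


arguments = {
    'train_data_paths': 'data/train.csv',
    'num_epochs': 3,
    'hidden_units': [64, 32],
    'train_steps': 1000,  # not a named parameter, so it lands in experiment_args
}

experiment_fn = generate_fn(**arguments)
result = experiment_fn('out/')
print(result)
```

Here train_steps is never named in generate_fn, yet it still reaches build_experiment through the catch-all. That is convenient, but it also makes it harder to see where a given command line argument ends up.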