TensorFlow/Command Line Args
Passing Command Line Arguments
Why
First, why would you want to pass arguments into a TensorFlow model from the command line?
- While it's possible to use a notebook to train models, eventually the models will probably go into production
- Training models in Cloud ML Engine requires passing arguments via command line
- Parameterizing models with command line arguments makes them easier to run from other scripts and to use for hyperparameter optimization
How
There are a couple of methods for passing arguments into a TensorFlow model:
Recommended:
- argparse - this is the most straightforward way: parse arguments with an argparse.ArgumentParser(), collect them into a dictionary, and pass that dictionary into the model's run/train function call using the double-splat operator (see the sketch after this list)
Other methods:
- gflags - this is the Google command line argument parsing library; more full-featured than argparse
- tf.app.flags - built-in and convenient, but mainly intended for keeping demos pithy; mimics gflags behavior
- tf.app.run - combines argument parser with app runner, passes arguments on to main()
- tf.app.run and tf.app.flags in combination - parses arguments with tf.app.flags and passes unparsed arguments through tf.app.run on to main() (sketched below)
- sys.argv - get raw command line arguments, split at spaces, manually parse
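For example, here is a minimal sketch of the argparse approach; the train_model() function and its parameters are hypothetical stand-ins for your model's actual train/run call:

import argparse

def train_model(train_data_paths, num_epochs=10, batch_size=512):
    # hypothetical training routine standing in for the real model train/run call
    print(train_data_paths, num_epochs, batch_size)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_data_paths', required=True,
                        help='GCS or local path to training data')
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=512)
    args = parser.parse_args()
    arguments = args.__dict__   # equivalently, vars(args)
    train_model(**arguments)    # double-splat unpacks the dict into keyword arguments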
If you take a look at the Cloud ML Engine examples repo on GitHub (GoogleCloudPlatform/cloudml-samples), you'll see that they primarily use the argparse library to parse input arguments.
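For comparison, here is a minimal sketch of the tf.app.flags / tf.app.run combination listed above, using the TensorFlow 1.x API (the flag names are just for illustration):

import tensorflow as tf

tf.app.flags.DEFINE_string('train_data_paths', None, 'GCS or local path to training data')
tf.app.flags.DEFINE_integer('num_epochs', 10, 'Number of training epochs')
FLAGS = tf.app.flags.FLAGS

def main(argv):
    # argv holds whatever the flags module did not parse
    print(FLAGS.train_data_paths, FLAGS.num_epochs)

if __name__ == '__main__':
    tf.app.run()   # parses the flags, then calls main() with the remaining arguments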
Cloud ML Engine
Using command line arguments is particularly important when bundling a TensorFlow model for Cloud ML Engine.
For an example of how to do this, see the taxicab fare prediction model in the GCP training-data-analyst repository:
- https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/cloudmle/taxifare/trainer/task.py - contains the functions that run the "task" (training or prediction)
- task.py parses the input arguments using an argparse.ArgumentParser() object
- each argument parsed is added to a dictionary, and the arguments in the dictionary are passed into the training call using the double-splat operator, which unpacks them
Here's the relevant code in the task.py file:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Input Arguments
    parser.add_argument(
        '--train_data_paths',
        help='GCS or local path to training data',
        required=True
    )
    parser.add_argument(
        '--num_epochs',
        help="""\
        Maximum number of training data epochs on which to train.
        If both --max-steps and --num-epochs are specified,
        the training job will run for --max-steps or --num-epochs,
        whichever occurs first. If unspecified will run for --max-steps.\
        """,
        type=int,
    )
Several other arguments are read from the command line in the same way. Next, the parsed arguments are collected into a dict:
args = parser.parse_args()
arguments = args.__dict__
Finally, these are passed to the model's training run using the double-splat:
# Run the training job
learn_runner.run(generate_experiment_fn(**arguments), output_dir)
These arguments are passed to generate_experiment_fn(), which in turn passes them along to the model object it creates.
Here is the generate_experiment_fn() header with its parameters:
def generate_experiment_fn(train_data_paths, eval_data_paths, format,
                           num_epochs=None,
                           train_batch_size=512,
                           eval_batch_size=512,
                           hidden_units=None,
                           **experiment_args):
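The keys of the arguments dictionary line up with these parameter names, so **arguments binds train_data_paths, num_epochs, and so on by keyword, while any keys without a matching parameter are swept into **experiment_args and forwarded later. Here is a small sketch of that mechanism (the function and values are illustrative, not the repo's code):

def generate_experiment(train_data_paths, num_epochs=None, **experiment_args):
    # named keys bind to parameters; leftover keys land in experiment_args
    print(train_data_paths, num_epochs, experiment_args)

arguments = {
    'train_data_paths': 'gs://my-bucket/train.csv',
    'num_epochs': 5,
    'min_eval_frequency': 1000,   # no matching parameter, so it ends up in experiment_args
}
generate_experiment(**arguments)
# prints: gs://my-bucket/train.csv 5 {'min_eval_frequency': 1000}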
Within the body of generate_experiment_fn(), these parameters are passed on to the model's input function, input_fn():
    def _experiment_fn(output_dir):
        input_fn = model.generate_csv_input_fn
        train_input = input_fn(
            train_data_paths, num_epochs=num_epochs, batch_size=train_batch_size)
        eval_input = input_fn(
            eval_data_paths, batch_size=eval_batch_size, mode=tf.contrib.learn.ModeKeys.EVAL)
        return Experiment(
            model.build_estimator(
                output_dir,
                hidden_units=hidden_units
            ),
            train_input_fn=train_input,
            eval_input_fn=eval_input,
            export_strategies=[saved_model_export_utils.make_export_strategy(
                model.serving_input_fn,
                default_output_alternative_key=None,
                exports_to_keep=1
            )],
            eval_metrics=model.get_eval_metrics(),
            #min_eval_frequency = 1000,  # change this to speed up training on large datasets
            **experiment_args
        )
    return _experiment_fn
Note that this example is fairly complicated, as many TensorFlow examples are, because there are so many nested function calls and so many parameters being passed from one function to another. The key point is that every input parameter enters the program from the command line, and each one is threaded through the chain of function calls until it reaches the call that builds the model, ending up in model.py, where the actual TensorFlow model is defined.