Git integration now available for the Amazon SageMaker Python SDK
Git integration is now available in the Amazon SageMaker Python SDK. You no longer have to download scripts from a Git repository for training jobs and hosting models. With this new feature, you can use training scripts stored in Git repos directly when training a model in the Python SDK. You can also use hosting scripts stored in Git repos when hosting a model. The scripts can be hosted in GitHub, in an AWS CodeCommit repo, or in another Git-based repo.
This post describes in detail how to use Git integration with the Amazon SageMaker Python SDK.
Overview
When you train a model with the Amazon SageMaker Python SDK, you need a training script that does the following:
- Loads data from the input channels
- Configures training with hyperparameters
- Trains a model
- Saves the model
You specify the script as the value of the entry_point argument when you create an estimator object.
Previously, when you constructed an Estimator or Model object in the Python SDK, the script that you provided as the entry_point value had to be a path in the local file system. This requirement was inconvenient when your training scripts were in Git repos, because you had to download them locally first.
If multiple developers were contributing to the Git repo, you had to keep track of any updates to the repo. If your local version was out of date, you needed to pull the latest version before every training job. This also made scheduling periodic training jobs more challenging.
With Git integration, you point the SDK at the repo and no longer have to download the scripts or keep a local copy up to date yourself.
Walkthrough
Enable the Git integration feature by passing a dict parameter named git_config when you create the Estimator or Model object. The git_config parameter provides information about the location of the Git repo that contains the scripts and the authentication for accessing that repo.
Locate the Git repo
To locate the repo that contains the scripts, use the repo, branch, and commit fields in git_config. The repo field is required; the other two fields are optional. If you provide only the repo field, the latest commit in the master branch is used by default:
git_config = {'repo': 'your-git-repo-url'}
To specify a branch, use both the repo and branch fields. The latest commit in that branch is used by default:
git_config = {'repo': 'your-git-repo-url', 'branch': 'your-branch'}
To specify a commit of a specific branch in a repo, use all three fields in git_config:
git_config = {'repo': 'your-git-repo-url',
              'branch': 'your-branch',
              'commit': 'a-commit-sha-under-this-branch'}
You can also provide only the repo and commit fields. This works when the commit is under the master branch, and that commit is used. If the commit is not under the master branch, however, the repo is not found. For example:
git_config = {'repo': 'your-git-repo-url', 'commit': 'a-commit-sha-under-master'}
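The three location fields can also be assembled programmatically. The following is a minimal sketch, assuming a hypothetical helper named build_git_config that is not part of the SDK; it only builds the dict described above and enforces that repo is present.

```python
def build_git_config(repo, branch=None, commit=None):
    """Hypothetical helper (not part of the SDK): build the location part of git_config."""
    if not repo:
        # 'repo' is the only required field.
        raise ValueError("'repo' is required in git_config")
    config = {'repo': repo}
    # 'branch' and 'commit' are optional and are only added when given.
    if branch is not None:
        config['branch'] = branch
    if commit is not None:
        config['commit'] = commit
    return config


print(build_git_config('your-git-repo-url'))
print(build_git_config('your-git-repo-url', branch='your-branch'))
print(build_git_config('your-git-repo-url',
                       branch='your-branch',
                       commit='a-commit-sha-under-this-branch'))
```

Leaving out the optional keys, rather than setting them to None, keeps the result identical to the hand-written dicts shown earlier.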
Get access to the Git repo
If the Git repo is private (all CodeCommit repos are private), you need authentication information to access it.
For CodeCommit repos, first make sure that you set up your authentication method. For more information, see Setting Up for AWS CodeCommit. The topic lists the following ways by which you can authenticate:
- SSH connections
- Git credentials
- AWS CLI Credential Helper
Authentication for SSH URLs
For SSH URLs, you must configure the SSH key pair. This applies to GitHub, CodeCommit, and other Git-based repos.
- For CodeCommit SSH key configuration, see:
- Setup Steps for SSH Connections to AWS CodeCommit Repositories on Linux, macOS, or Unix
- Setup Steps for SSH Connections to AWS CodeCommit Repositories on Windows
- For GitHub SSH key configuration, see Connecting to GitHub with SSH. The SSH key configuration for other Git-based VCSs is similar to that of GitHub.
Do not set an SSH key passphrase for the SSH key pairs. If you do, access to the repo fails.
After the SSH key pair is configured, Git integration works with SSH URLs without further authentication information:
# for GitHub repos
git_config = {'repo': 'git@github.com:your-git-account/your-git-repo.git'}
# for CodeCommit repos
git_config = {'repo': 'ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/your-repo/'}
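Because SSH and HTTPS URLs are authenticated differently, it can help to check which kind of URL a git_config holds before deciding whether to add credentials. The following is an illustrative sketch only; the function name is_ssh_url is hypothetical and this is not the SDK's own detection logic.

```python
def is_ssh_url(repo_url):
    """Illustrative check (an assumption, not SDK code): does the URL use SSH?"""
    # Covers the two SSH forms shown above: 'git@host:path' and 'ssh://host/path'.
    return repo_url.startswith(('git@', 'ssh://'))


for url in ('git@github.com:your-git-account/your-git-repo.git',
            'ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/your-repo/',
            'https://github.com/your-account/your-private-repo.git'):
    print(url, is_ssh_url(url))
```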
Authentication for HTTPS URLs
For HTTPS URLs, there are two ways to deal with authentication:
- Have it configured locally.
- Configure it by providing extra fields in git_config, namely 2FA_enabled, username, password, and token. The details differ slightly between CodeCommit, GitHub, and other Git-based repos.
Authenticating using Git credentials
If you authenticate with Git credentials, you can do one of the following:
- Provide the credentials in git_config:
  git_config = {'repo': 'https://git-codecommit.us-west-2.amazonaws.com/v1/repos/your-repo/', 'username': 'your-username', 'password': 'your-password'}
- Have the credentials stored in local credential storage. Typically, the credentials are stored automatically after you provide them with the AWS CLI. For example, macOS stores credentials in Keychain Access.
With the Git credentials stored locally, you can specify the git_config parameter without providing the credentials, to avoid showing them in scripts:
git_config = {'repo': 'https://git-codecommit.us-west-2.amazonaws.com/v1/repos/your-repo/'}
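If local credential storage is not an option, another way to keep credentials out of the script itself is to read them from environment variables. This is a sketch under assumptions: the variable names GIT_USERNAME and GIT_PASSWORD and the helper codecommit_git_config are hypothetical and are not defined by the SDK.

```python
import os


def codecommit_git_config(repo_url, environ=os.environ):
    """Hypothetical helper: add username/password to git_config only if both are set."""
    config = {'repo': repo_url}
    # GIT_USERNAME and GIT_PASSWORD are assumed names chosen for this sketch.
    username = environ.get('GIT_USERNAME')
    password = environ.get('GIT_PASSWORD')
    if username and password:
        config['username'] = username
        config['password'] = password
    return config


repo_url = 'https://git-codecommit.us-west-2.amazonaws.com/v1/repos/your-repo/'
# With no credentials in the environment, only 'repo' is set and
# locally stored credentials (if any) are relied on instead.
print(codecommit_git_config(repo_url, environ={}))
```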
Authenticating using AWS CLI Credential Helper
If you follow the setup documentation mentioned earlier to configure AWS CLI Credential Helper, you don’t have to provide any authentication information.
For GitHub and other Git-based repos, check whether two-factor authentication (2FA) is enabled for your account. (Two-factor authentication is disabled by default and must be enabled manually.) For more information, see Securing your account with two-factor authentication (2FA).
If 2FA is enabled for your account, provide 2FA_enabled when specifying git_config and set it to True. Otherwise, set it to False. If 2FA_enabled is not provided, it is set to False by default. Usually, you can use either username+password or a personal access token to authenticate for GitHub and other Git-based repos. However, when 2FA is enabled, you can only use a personal access token.
To use username+password for authentication:
git_config = {'repo': 'https://github.com/your-account/your-private-repo.git',
              'username': 'your-username',
              'password': 'your-password'}
Again, you can store the credentials in local credential storage to avoid showing them in the script.
To use a personal access token for authentication:
git_config = {'repo': 'https://github.com/your-account/your-private-repo.git',
              'token': 'your-token'}
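The rules above for GitHub and other Git-based repos can be summarized as a small check. This is a sketch of the rules as stated in this post, assuming a hypothetical function named https_auth_method; it is not how the SDK itself validates git_config.

```python
def https_auth_method(git_config):
    """Hypothetical check: which authentication method does this git_config imply?"""
    # 2FA_enabled defaults to False when it is not provided.
    two_factor = git_config.get('2FA_enabled', False)
    has_token = 'token' in git_config
    has_password = 'username' in git_config and 'password' in git_config

    if has_token:
        # A personal access token is accepted with or without 2FA.
        return 'token'
    if has_password:
        if two_factor:
            # With 2FA enabled, only a personal access token can be used.
            raise ValueError('2FA is enabled: use a personal access token')
        return 'username+password'
    # Nothing in git_config: rely on authentication configured locally.
    return 'local'


repo = 'https://github.com/your-account/your-private-repo.git'
print(https_auth_method({'repo': repo, 'token': 'your-token', '2FA_enabled': True}))
print(https_auth_method({'repo': repo, 'username': 'your-username', 'password': 'your-password'}))
print(https_auth_method({'repo': repo}))
```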
Create the estimator or model with Git integration
After you correctly specify git_config, pass it as a parameter when you create the estimator or model object to enable Git integration. Then, make sure that entry_point, source_dir, and dependencies are all relative paths under the Git repo.
As usual, if source_dir is provided, entry_point should be a relative path from the source directory. The same is true with Git integration.
For example, with the following structure of the Git repo ‘amazon-sagemaker-examples’ under branch ‘training-scripts’:
amazon-sagemaker-examples
|
|-------------char-rnn-tensorflow
| |----------train.py
| |----------utils.py
| |----------other files
|
|-------------pytorch-rnn-scripts
|-------------.gitignore
|-------------README.md
You can create the estimator object as follows:
import sagemaker
from sagemaker.tensorflow import TensorFlow

git_config = {'repo': 'https://github.com/awslabs/amazon-sagemaker-examples.git', 'branch': 'training-scripts'}
# train_instance_type is assumed to be defined earlier, for example as 'ml.c4.xlarge'
estimator = TensorFlow(entry_point='train.py',
                       source_dir='char-rnn-tensorflow',
                       git_config=git_config,
                       train_instance_type=train_instance_type,
                       train_instance_count=1,
                       role=sagemaker.get_execution_role(),  # Passes to the container the AWS role that you are using on this notebook
                       framework_version='1.13',
                       py_version='py3',
                       script_mode=True)
In this example, source_dir 'char-rnn-tensorflow' is a relative path inside the Git repo, while entry_point 'train.py' is a relative path under ‘char-rnn-tensorflow’.
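To see how the two relative paths combine, the following sketch computes where the entry point sits relative to the repo root. The helper script_path_in_repo is hypothetical and only illustrates the path rule described above; it does not clone anything or call the SDK.

```python
import posixpath


def script_path_in_repo(entry_point, source_dir=None):
    """Hypothetical helper: path of the entry point relative to the repo root."""
    for path in (entry_point, source_dir):
        # Both values must be relative paths under the Git repo.
        if path is not None and posixpath.isabs(path):
            raise ValueError('%r must be a relative path' % path)
    if source_dir is None:
        # Without source_dir, entry_point is relative to the repo root.
        return posixpath.normpath(entry_point)
    # With source_dir, entry_point is relative to source_dir.
    return posixpath.normpath(posixpath.join(source_dir, entry_point))


print(script_path_in_repo('train.py', 'char-rnn-tensorflow'))
# prints: char-rnn-tensorflow/train.py
```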
Git integration example
Now let’s look at a complete example of using Git integration. This example trains a multi-layer LSTM RNN model on a language modeling task, based on a PyTorch example. By default, the training script uses the Wikitext-2 dataset. We train a model on SageMaker, deploy it, and then use the deployed model to generate new text.
Run the commands in a Python script, except for those that start with a ‘!’, which are bash commands.
First let’s do the setup:
import sagemaker
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/DEMO-pytorch-rnn-lstm'
role = sagemaker.get_execution_role()
Next, get the dataset. This data is from Wikipedia and is licensed CC-BY-SA-3.0. Before you use this data for any purpose other than this example, you should understand the data license, described at https://creativecommons.org/licenses/by-sa/3.0/:
!wget http://research.metamind.io.s3.amazonaws.com/wikitext/wikitext-2-raw-v1.zip
!unzip -n wikitext-2-raw-v1.zip
!cd wikitext-2-raw && mv wiki.test.raw test && mv wiki.train.raw train && mv wiki.valid.raw valid
Upload the data to S3:
inputs = sagemaker_session.upload_data(path='wikitext-2-raw', bucket=bucket, key_prefix=prefix)
Specify git_config and create the estimator with it:
from sagemaker.pytorch import PyTorch
git_config = {'repo': 'https://github.com/awslabs/amazon-sagemaker-examples.git', 'branch': 'training-scripts'}
estimator = PyTorch(entry_point='train.py',
                    role=role,
                    framework_version='1.1.0',
                    train_instance_count=1,
                    train_instance_type='ml.c4.xlarge',
                    source_dir='pytorch-rnn-scripts',
                    git_config=git_config,
                    hyperparameters={
                        'epochs': 6,
                        'tied': True
                    })
Train the model:
estimator.fit({'training': inputs})
Next, let’s host the model. We are going to provide a custom implementation of the model_fn, input_fn, output_fn, and predict_fn hosting functions in a separate file, ‘generate.py’, which is in the same Git repo. The PyTorch model uses an npy serializer and deserializer by default. For this example, since we have a custom implementation of all the hosting functions and plan on using JSON instead, we need a predictor that can serialize and deserialize JSON:
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer

class JSONPredictor(RealTimePredictor):
    def __init__(self, endpoint_name, sagemaker_session):
        super(JSONPredictor, self).__init__(endpoint_name, sagemaker_session, json_serializer, json_deserializer)
Create the model object:
from sagemaker.pytorch import PyTorchModel
training_job_name = estimator.latest_training_job.name
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
trained_model_location = desc['ModelArtifacts']['S3ModelArtifacts']
model = PyTorchModel(model_data=trained_model_location,
                     role=role,
                     framework_version='1.0.0',
                     entry_point='generate.py',
                     source_dir='pytorch-rnn-scripts',
                     git_config=git_config,
                     predictor_cls=JSONPredictor)
Create the hosting endpoint:
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')
Now we are going to use our deployed model to generate text by providing a random seed, a temperature (a higher value increases diversity), and the number of words we would like to get:
# Named input_data to avoid shadowing the built-in input()
input_data = {
    'seed': 111,
    'temperature': 2.0,
    'words': 100
}
response = predictor.predict(input_data)
print(response)
You get the following results:
acids west 'igan 1232 keratinous Andrews argue cancel mauling even incorporating Jewish
centimetres Fang Andres cyclic logjams filth nullity Homarinus pilaris Emperors whoops punts
followed Reichsgau envisaged Invisible alcohols are osteoarthritis twilight Alexandre Odes Bucanero Genesis
crimson Hutchison genus Brighton 1532 0226284301 Harikatha p Assault Vaisnava plantie 1829
Totals established outcast hurricane herbs revel Lebens Metoposaurids Pajaka initialize frond discarding
walking Unusually Ľubomír Springboks reviewing leucocythemia blistered kinder Nowels arriving 1350 Weymouth
Saigon cantonments genealogy alleging Upright typists termini doodle conducts parallelisms cypresses consults
others estate cover passioned recognition channelled breathed straighter Visibly dug blanche motels
Barremian quickness constrictor reservist
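The temperature value in the request controls how varied the generated text is. The following sketch shows one common way temperature is applied when sampling from a language model: the scores are divided by the temperature before being turned into probabilities. This is an assumption for illustration; the function softmax_with_temperature and the example scores are hypothetical, and the actual logic in ‘generate.py’ may differ.

```python
import math


def softmax_with_temperature(scores, temperature):
    """Turn raw scores into probabilities after scaling them by 1/temperature."""
    if temperature <= 0:
        raise ValueError('temperature must be positive')
    scaled = [s / temperature for s in scores]
    # Subtract the maximum before exponentiating for numerical stability.
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.0]  # made-up scores for three candidate words
low = softmax_with_temperature(scores, 0.5)
high = softmax_with_temperature(scores, 2.0)
# A higher temperature flattens the distribution, so the top word is
# chosen less often and the output becomes more diverse.
print(max(low) > max(high))  # prints: True
```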
Finally, delete the endpoint after you are done using it:
sagemaker_session.delete_endpoint(predictor.endpoint)
Conclusion
In this post, I walked through how to use Git integration with the Amazon SageMaker Python SDK. With Git integration, you no longer have to download scripts from Git repos for training jobs and hosting models. Now you can use scripts in Git repos directly, simply by passing an additional parameter git_config when creating the Estimator or Model object.
If you have questions or suggestions, please leave them in the comments.
About the Authors
Yue Tu is a summer intern on the AWS SageMaker ML Frameworks team. He works on Git integration for the SageMaker Python SDK during his internship. Outside of work he likes playing basketball; his favorite basketball teams are the Golden State Warriors and the Duke basketball team. He also likes paying attention to nothing for some time.
Chuyang Deng is a software development engineer on the AWS SageMaker ML Frameworks team. She enjoys playing LEGO alone.