This script extracts a pretrained PyTorch model for MNIST digit classification from a tar.gz file, loads it, and uses it to perform inference on a given input image. Ensure you have all required dependencies installed:
```bash
pip install Pillow torch torchvision
```
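Package names on pip do not always match the names you import: Pillow, for example, is imported as `PIL`. The following is a minimal, standard-library-only sketch that reports which of the script's imports are missing before you run it. The helper name `missing_modules` is hypothetical and is not part of the original script.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    # find_spec returns None when a top-level module is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names used by inference.py; the Pillow package provides "PIL"
required = ["PIL", "torch", "torchvision"]
missing = missing_modules(required)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All dependencies are installed")
```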
```python
# content of the inference.py file
import argparse
import os
import tarfile

import torch
import torchvision.transforms as transforms
from PIL import Image


class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, 5)
        self.conv2 = torch.nn.Conv2d(10, 20, 5)
        self.fc1 = torch.nn.Linear(320, 50)
        self.fc2 = torch.nn.Linear(50, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        output = torch.log_softmax(x, dim=1)
        return output


def extract_tar_gz(file_path, output_dir):
    # Extract every member of the gzip-compressed archive into output_dir
    with tarfile.open(file_path, "r:gz") as tar:
        tar.extractall(path=output_dir)


# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--tar_gz_file_path", type=str, required=True, help="Path to the tar.gz file")
parser.add_argument("--output_directory", type=str, required=True, help="Output directory to extract the tar.gz file")
parser.add_argument("--image_path", type=str, required=True, help="Path to the input image file")
args = parser.parse_args()

# Extract the tar.gz file
tar_gz_file_path = args.tar_gz_file_path
output_directory = args.output_directory
extract_tar_gz(tar_gz_file_path, output_directory)

# Load the model weights (a state_dict) onto the CPU
model_path = os.path.join(output_directory, "model.pth")
model = CustomModel()
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()

# Transformations for the MNIST dataset
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


# Function to run inference on an image
def run_inference(image, model):
    # Apply transformations and add a batch dimension: shape (1, 1, 28, 28)
    image_tensor = transform(image).unsqueeze(0)
    # Perform inference without tracking gradients
    with torch.no_grad():
        output = model(image_tensor)
    # The predicted class is the index of the largest log-probability
    _, predicted = torch.max(output, 1)
    return predicted.item()


# Example usage
image_path = args.image_path
image = Image.open(image_path)
predicted_class = run_inference(image, model)
print(f"Predicted class: {predicted_class}")
```
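The `320` in `fc1` is not arbitrary: it is the flattened size of the feature maps after the two convolution and pooling stages for a 28x28 input, which is why the transform resizes every image to 28x28. The short sketch below reproduces that arithmetic. It assumes stride-1 convolutions without padding (the `Conv2d` defaults used above) and 2x2 max pooling. It also shows the per-pixel normalization that `transforms.Normalize` applies after `ToTensor` has scaled pixels to [0, 1].

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a Conv2d along one spatial dimension
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    # max_pool2d(x, 2) uses stride 2 by default, so the size is halved (floored)
    return (size - kernel) // kernel + 1


def flattened_features(size=28, channels=20):
    size = pool_out(conv_out(size, 5))  # conv1 (5x5): 28 -> 24, pool: 24 -> 12
    size = pool_out(conv_out(size, 5))  # conv2 (5x5): 12 -> 8, pool: 8 -> 4
    return channels * size * size       # 20 * 4 * 4


def normalize(pixel, mean=0.1307, std=0.3081):
    # transforms.Normalize computes (x - mean) / std for each channel
    return (pixel - mean) / std


print(flattened_features())  # 320
print(round(normalize(0.0), 4), round(normalize(1.0), 4))
```

An input of any other size would produce a different flattened length, and `fc1` would then fail with a shape mismatch.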
To use this script, provide three arguments: the path to the tar.gz file containing the pretrained model (`--tar_gz_file_path`), the directory to extract the model into (`--output_directory`), and the input image to classify (`--image_path`). The script prints the predicted digit (class) for that image.
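The script looks for `model.pth` directly inside the output directory, so it assumes the archive stores the weights at its root rather than inside a subfolder. The following standard-library-only sketch builds a small `model.tar.gz` with that layout and confirms where `extract_tar_gz` places the file. The placeholder bytes stand in for real weights, and the helper `make_archive` is hypothetical.

```python
import os
import tarfile
import tempfile


def make_archive(archive_path, member_name, payload):
    # Write payload to a temporary file and add it to a gzip-compressed tar
    # under member_name (the path it will have inside the archive)
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "payload")
        with open(src, "wb") as f:
            f.write(payload)
        with tarfile.open(archive_path, "w:gz") as tar:
            tar.add(src, arcname=member_name)


def extract_tar_gz(file_path, output_dir):
    # Same logic as in inference.py
    with tarfile.open(file_path, "r:gz") as tar:
        tar.extractall(path=output_dir)


with tempfile.TemporaryDirectory() as work:
    archive = os.path.join(work, "model.tar.gz")
    make_archive(archive, "model.pth", b"placeholder weights")
    # List the member names stored in the archive
    with tarfile.open(archive, "r:gz") as tar:
        members = tar.getnames()
    out_dir = os.path.join(work, "model-pth")
    extract_tar_gz(archive, out_dir)
    found = os.path.isfile(os.path.join(out_dir, "model.pth"))
    print(members, found)
```

If your archive stores the file under a subdirectory instead, `tar.getnames()` will show that prefix, and `model_path` in the script would need to include it.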
- `export JOB_ID=$( ... )`: Exports the output of the enclosed command as the environment variable `JOB_ID`.
- `-w /inputs`: Sets the current working directory in the container to `/inputs`.
- `-i src=s3://sagemaker-sample-files/datasets/image/MNIST/model/pytorch-training-2020-11-21-22-02-56-203/model.tar.gz,dst=/model/,opt=region=us-east-1`: Mounts the model archive from S3 at the destination path `/model/`. The option `opt=region=us-east-1` specifies the region where the bucket is located.
- `-i git://github.com/js-ts/mnist-test.git`: Mounts the source-code repository from GitHub. The repo is mounted at `/inputs/js-ts/mnist-test`, and in this case it also contains the test image.
- `pytorch/pytorch`: The name of the Docker image.
- `-- python3 /inputs/js-ts/mnist-test/inference.py --tar_gz_file_path /model/model.tar.gz --output_directory /model-pth --image_path /inputs/js-ts/mnist-test/image.png`: The command that runs inference on the model. Its arguments are:
  - `/model/model.tar.gz`: the path to the model archive
  - `/model-pth`: the output directory for the extracted model
  - `/inputs/js-ts/mnist-test/image.png`: the path to the input image
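The command only works if its paths fall under the mounts described above. The archive must be under the S3 mount's `dst` (`/model/`), and the image must be under the repository's mount point (`/inputs/js-ts/mnist-test`). The following standard-library-only sketch checks both conditions in three steps:

1. It splits the `-i` spec on commas into `key=value` pairs, splitting each pair only on its first `=` because `opt` itself contains one.
2. It rebuilds the script's argument parser.
3. It tests whether each path lies under the matching mount.

The helper names are hypothetical, and this is not how Bacalhau itself parses inputs.

```python
import argparse
import posixpath

S3_SPEC = (
    "src=s3://sagemaker-sample-files/datasets/image/MNIST/model/"
    "pytorch-training-2020-11-21-22-02-56-203/model.tar.gz,"
    "dst=/model/,opt=region=us-east-1"
)
# Mount point of the git repository, as stated in the flag description
REPO_MOUNT = "/inputs/js-ts/mnist-test"
SCRIPT_ARGS = [
    "--tar_gz_file_path", "/model/model.tar.gz",
    "--output_directory", "/model-pth",
    "--image_path", "/inputs/js-ts/mnist-test/image.png",
]


def parse_input_spec(spec):
    # Hypothetical helper: "k1=v1,k2=v2" -> {"k1": "v1", "k2": "v2"},
    # splitting each pair on its first "=" only
    return dict(part.split("=", 1) for part in spec.split(","))


def build_parser():
    # Mirrors the argument parser in inference.py
    parser = argparse.ArgumentParser()
    parser.add_argument("--tar_gz_file_path", type=str, required=True)
    parser.add_argument("--output_directory", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)
    return parser


def is_under(path, directory):
    # True when path lies inside directory (POSIX paths in the container)
    directory = posixpath.normpath(directory)
    return posixpath.commonpath([path, directory]) == directory


spec = parse_input_spec(S3_SPEC)
args = build_parser().parse_args(SCRIPT_ARGS)
print(spec["dst"], spec["opt"])
print(is_under(args.tar_gz_file_path, spec["dst"]))
print(is_under(args.image_path, REPO_MOUNT))
```

Comparing with `commonpath` rather than a plain string prefix avoids treating a sibling such as `/modelx` as being inside `/model`.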