|
16 | 16 | "---\n",
|
17 | 17 | "Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use JumpStart to solve many Machine Learning tasks through one-click in SageMaker Studio, or through [SageMaker JumpStart API](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart). \n",
|
18 | 18 | "\n",
|
19 |
| - "In this demo notebook, we demonstrate how to use the JumpStart API for Text-to-Image. Text-to-Image is the task of generating realistic image given any text input. Here, we show how to use state-of-the-art pre-trained Stable Diffusion models for generating image from text.\n", |
| 19 | + "In this demo notebook, we demonstrate how to use the JumpStart API for Text-to-Image. Text-to-Image is the task of generating realistic images given any text input. Here, we show how to use state-of-the-art pre-trained Stable Diffusion models for generating image from text.\n", |
20 | 20 | "\n",
|
21 | 21 | "---"
|
22 | 22 | ]
|
|
27 | 27 | "metadata": {},
|
28 | 28 | "source": [
|
29 | 29 | "1. [Set Up](#1.-Set-Up)\n",
|
30 |
| - "2. [Select a model](#2.-Select-a-model)\n", |
31 |
| - "3. [Retrieve JumpStart Artifacts & Deploy an Endpoint](#3.-Retrieve-JumpStart-Artifacts-&-Deploy-an-Endpoint)\n", |
32 |
| - "4. [Query endpoint and parse response](#4.-Query-endpoint-and-parse-response)\n", |
33 |
| - "5. [Advanced features](#5.-Advanced-features)\n", |
34 |
| - "6. [Clean up the endpoint](#6.-Clean-up-the-endpoint)" |
| 30 | + "3. [Retrieve JumpStart Artifacts & Deploy an Endpoint](#2.-Retrieve-JumpStart-Artifacts-&-Deploy-an-Endpoint)\n", |
| 31 | + "4. [Query endpoint and parse response](#3.-Query-endpoint-and-parse-response)\n", |
| 32 | + "5. [Advanced features](#4.-Advanced-features)\n", |
| 33 | + "6. [Clean up the endpoint](#5.-Clean-up-the-endpoint)" |
35 | 34 | ]
|
36 | 35 | },
|
37 | 36 | {
|
|
65 | 64 | "cell_type": "code",
|
66 | 65 | "execution_count": null,
|
67 | 66 | "id": "25293522",
|
68 |
| - "metadata": {}, |
| 67 | + "metadata": { |
| 68 | + "tags": [] |
| 69 | + }, |
69 | 70 | "outputs": [],
|
70 | 71 | "source": [
|
71 | 72 | "!pip install sagemaker ipywidgets --upgrade --quiet"
|
|
88 | 89 | "cell_type": "code",
|
89 | 90 | "execution_count": null,
|
90 | 91 | "id": "90518e45",
|
91 |
| - "metadata": {}, |
| 92 | + "metadata": { |
| 93 | + "tags": [] |
| 94 | + }, |
92 | 95 | "outputs": [],
|
93 | 96 | "source": [
|
94 | 97 | "import sagemaker, boto3, json\n",
|
|
99 | 102 | "sess = sagemaker.Session()"
|
100 | 103 | ]
|
101 | 104 | },
|
102 |
| - { |
103 |
| - "cell_type": "markdown", |
104 |
| - "id": "d2c1a623", |
105 |
| - "metadata": {}, |
106 |
| - "source": [ |
107 |
| - "### 2. Select a model\n", |
108 |
| - "\n", |
109 |
| - "***\n", |
110 |
| - "Here, we download jumpstart model_manifest file from the jumpstart s3 bucket, filter-out all the Text Generation models and select a model for inference. \n", |
111 |
| - "***" |
112 |
| - ] |
113 |
| - }, |
114 |
| - { |
115 |
| - "cell_type": "code", |
116 |
| - "execution_count": null, |
117 |
| - "id": "deecb929", |
118 |
| - "metadata": {}, |
119 |
| - "outputs": [], |
120 |
| - "source": [ |
121 |
| - "from ipywidgets import Dropdown\n", |
122 |
| - "\n", |
123 |
| - "# download JumpStart model_manifest file.\n", |
124 |
| - "boto3.client(\"s3\").download_file(\n", |
125 |
| - " f\"jumpstart-cache-prod-{aws_region}\", \"models_manifest.json\", \"models_manifest.json\"\n", |
126 |
| - ")\n", |
127 |
| - "with open(\"models_manifest.json\", \"rb\") as json_file:\n", |
128 |
| - " model_list = json.load(json_file)\n", |
129 |
| - "\n", |
130 |
| - "# filter-out all the Text Generation models from the manifest list.\n", |
131 |
| - "txt2img_models = []\n", |
132 |
| - "for model in model_list:\n", |
133 |
| - " model_id = model[\"model_id\"]\n", |
134 |
| - " if \"-txt2img-\" in model_id and model_id not in txt2img_models:\n", |
135 |
| - " txt2img_models.append(model_id)\n", |
136 |
| - "\n", |
137 |
| - "# display the model-ids in a dropdown to select a model for inference.\n", |
138 |
| - "model_dropdown = Dropdown(\n", |
139 |
| - " options=txt2img_models,\n", |
140 |
| - " value=\"huggingface-txt2img-stable-diffusion-v1-4\",\n", |
141 |
| - " description=\"Select a model\",\n", |
142 |
| - " style={\"description_width\": \"initial\"},\n", |
143 |
| - " layout={\"width\": \"max-content\"},\n", |
144 |
| - ")" |
145 |
| - ] |
146 |
| - }, |
147 |
| - { |
148 |
| - "cell_type": "markdown", |
149 |
| - "id": "a821a4cf", |
150 |
| - "metadata": {}, |
151 |
| - "source": [ |
152 |
| - "#### Chose a model for Inference" |
153 |
| - ] |
154 |
| - }, |
155 |
| - { |
156 |
| - "cell_type": "code", |
157 |
| - "execution_count": null, |
158 |
| - "id": "01cc6c00", |
159 |
| - "metadata": {}, |
160 |
| - "outputs": [], |
161 |
| - "source": [ |
162 |
| - "display(model_dropdown)" |
163 |
| - ] |
164 |
| - }, |
165 |
| - { |
166 |
| - "cell_type": "code", |
167 |
| - "execution_count": null, |
168 |
| - "id": "2ff82d42", |
169 |
| - "metadata": {}, |
170 |
| - "outputs": [], |
171 |
| - "source": [ |
172 |
| - "# model_version=\"*\" fetches the latest version of the model\n", |
173 |
| - "model_id, model_version = model_dropdown.value, \"*\"" |
174 |
| - ] |
175 |
| - }, |
176 | 105 | {
|
177 | 106 | "cell_type": "markdown",
|
178 | 107 | "id": "8f3ab601",
|
179 | 108 | "metadata": {},
|
180 | 109 | "source": [
|
181 |
| - "### 3. Retrieve JumpStart Artifacts & Deploy an Endpoint\n", |
| 110 | + "### 2. Retrieve JumpStart Artifacts & Deploy an Endpoint\n", |
182 | 111 | "\n",
|
183 | 112 | "***\n",
|
184 | 113 | "\n",
|
|
191 | 120 | "cell_type": "code",
|
192 | 121 | "execution_count": null,
|
193 | 122 | "id": "a8a79ec9",
|
194 |
| - "metadata": {}, |
| 123 | + "metadata": { |
| 124 | + "tags": [] |
| 125 | + }, |
195 | 126 | "outputs": [],
|
196 | 127 | "source": [
|
197 | 128 | "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
|
198 | 129 | "from sagemaker.model import Model\n",
|
199 | 130 | "from sagemaker.predictor import Predictor\n",
|
200 | 131 | "from sagemaker.utils import name_from_base\n",
|
201 | 132 | "\n",
|
| 133 | + "# model_version=\"*\" fetches the latest version of the model\n", |
| 134 | + "model_id, model_version = \"model-txt2img-stabilityai-stable-diffusion-v1-4\", \"*\"\n", |
202 | 135 | "\n",
|
203 | 136 | "endpoint_name = name_from_base(f\"jumpstart-example-infer-{model_id}\")\n",
|
204 | 137 | "\n",
|
|
252 | 185 | "id": "b2e0fd36",
|
253 | 186 | "metadata": {},
|
254 | 187 | "source": [
|
255 |
| - "### 4. Query endpoint and parse response\n", |
| 188 | + "### 3. Query endpoint and parse response\n", |
256 | 189 | "\n",
|
257 | 190 | "---\n",
|
258 | 191 | "Input to the endpoint is any string of text dumped in json and encoded in `utf-8` format. Output of the endpoint is a `json` with generated text.\n",
|
|
264 | 197 | "cell_type": "code",
|
265 | 198 | "execution_count": null,
|
266 | 199 | "id": "84fb30d0",
|
267 |
| - "metadata": {}, |
| 200 | + "metadata": { |
| 201 | + "tags": [] |
| 202 | + }, |
268 | 203 | "outputs": [],
|
269 | 204 | "source": [
|
270 | 205 | "import matplotlib.pyplot as plt\n",
|
|
320 | 255 | "metadata": {
|
321 | 256 | "pycharm": {
|
322 | 257 | "is_executing": true
|
323 |
| - } |
| 258 | + }, |
| 259 | + "tags": [] |
324 | 260 | },
|
325 | 261 | "outputs": [],
|
326 | 262 | "source": [
|
|
339 | 275 | }
|
340 | 276 | },
|
341 | 277 | "source": [
|
342 |
| - "### 5. Advanced features\n", |
| 278 | + "### 4. Advanced features\n", |
343 | 279 | "\n",
|
344 | 280 | "***\n",
|
345 | 281 | "This model also supports many advanced parameters while performing inference. They include:\n",
|
|
362 | 298 | "metadata": {
|
363 | 299 | "pycharm": {
|
364 | 300 | "is_executing": true
|
365 |
| - } |
| 301 | + }, |
| 302 | + "tags": [] |
366 | 303 | },
|
367 | 304 | "outputs": [],
|
368 | 305 | "source": [
|
|
412 | 349 | "id": "870d1173",
|
413 | 350 | "metadata": {},
|
414 | 351 | "source": [
|
415 |
| - "### 6. Clean up the endpoint" |
| 352 | + "### 5. Clean up the endpoint" |
416 | 353 | ]
|
417 | 354 | },
|
418 | 355 | {
|
|
0 commit comments