Skip to content

Commit e374074

Browse files
authored
Windows test (huggingface#896)
* add generate_sharktank for stable_diffusion model defaults * add windows test for sd --------- Co-authored-by: dan <[email protected]>
1 parent 81e3d1c commit e374074

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

.github/workflows/test-models.yml

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
strategy:
3030
fail-fast: true
3131
matrix:
32-
os: [icelake, a100, MacStudio, ubuntu-latest]
32+
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
3333
suite: [cpu,cuda,vulkan]
3434
python-version: ["3.10"]
3535
include:
@@ -52,13 +52,19 @@ jobs:
5252
suite: cuda
5353
- os: a100
5454
suite: cpu
55+
- os: 7950x
56+
suite: cpu
57+
- os: 7950x
58+
suite: cuda
5559

5660
runs-on: ${{ matrix.os }}
5761

5862
steps:
5963
- uses: actions/checkout@v3
64+
if: matrix.os != '7950x'
6065

6166
- name: Set Environment Variables
67+
if: matrix.os != '7950x'
6268
run: |
6369
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
6470
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
@@ -78,6 +84,9 @@ jobs:
7884
#cache-dependency-path: |
7985
# **/requirements-importer.txt
8086
# **/requirements.txt
87+
88+
- uses: actions/checkout@v2
89+
if: matrix.os == '7950x'
8190

8291
- name: Install dependencies
8392
if: matrix.suite == 'lint'
@@ -130,10 +139,17 @@ jobs:
130139
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
131140
132141
- name: Validate Vulkan Models (a100)
133-
if: matrix.suite == 'vulkan' && matrix.os != 'MacStudio'
142+
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
134143
run: |
135144
cd $GITHUB_WORKSPACE
136145
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
137146
source shark.venv/bin/activate
138147
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="${GITHUB_WORKSPACE}/shark_tmp/shark_cache" -k vulkan
139148
python build_tools/stable_diffusion_testing.py --device=vulkan
149+
150+
151+
- name: Validate Stable Diffusion Models (Windows)
152+
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
153+
run: |
154+
./setup_venv.ps1
155+
python build_tools/stable_diffusion_testing.py --device=vulkan

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,7 @@ def compile_through_fx(
115115

116116
save_dir = os.path.join(args.local_tank_cache, model_name)
117117

118-
(
119-
mlir_module,
120-
func_name,
121-
) = import_with_fx(
118+
mlir_module, func_name, = import_with_fx(
122119
model=model,
123120
inputs=inputs,
124121
is_f16=is_f16,

0 commit comments

Comments
 (0)