|
29 | 29 | )
|
30 | 30 | from monarch.debugger import init_debugging
|
31 | 31 |
|
32 |
| -from monarch.mesh_controller import spawn_tensor_engine |
| 32 | +from monarch._rust_bindings import has_tensor_engine |
| 33 | + |
| 34 | +if has_tensor_engine(): |
| 35 | + from monarch.mesh_controller import spawn_tensor_engine |
| 36 | +else: |
| 37 | + spawn_tensor_engine = None |
33 | 38 |
|
34 | 39 | from monarch.proc_mesh import local_proc_mesh, proc_mesh
|
35 | 40 | from monarch.rdma import RDMABuffer
|
@@ -114,6 +119,10 @@ async def get_buffer(self):
|
114 | 119 | return self.buffer
|
115 | 120 |
|
116 | 121 |
|
| 122 | +@pytest.mark.skipif( |
| 123 | + not torch.cuda.is_available(), |
| 124 | + reason="CUDA not available", |
| 125 | +) |
117 | 126 | async def test_proc_mesh_rdma():
|
118 | 127 | proc = await proc_mesh(gpus=1)
|
119 | 128 | server = await proc.spawn("server", ParameterServer)
|
@@ -282,6 +291,10 @@ async def update_weights(self):
|
282 | 291 | ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}"
|
283 | 292 |
|
284 | 293 |
|
| 294 | +@pytest.mark.skipif( |
| 295 | + not torch.cuda.is_available(), |
| 296 | + reason="CUDA not available", |
| 297 | +) |
285 | 298 | async def test_gpu_trainer_generator():
|
286 | 299 | trainer_proc = await proc_mesh(gpus=1)
|
287 | 300 | gen_proc = await proc_mesh(gpus=1)
|
@@ -311,6 +324,10 @@ async def test_sync_actor():
|
311 | 324 | assert r == 5
|
312 | 325 |
|
313 | 326 |
|
| 327 | +@pytest.mark.skipif( |
| 328 | + not torch.cuda.is_available(), |
| 329 | + reason="CUDA not available", |
| 330 | +) |
314 | 331 | def test_gpu_trainer_generator_sync() -> None:
|
315 | 332 | trainer_proc = proc_mesh(gpus=1).get()
|
316 | 333 | gen_proc = proc_mesh(gpus=1).get()
|
@@ -391,6 +408,10 @@ def check(module, path):
|
391 | 408 | check(bindings, "monarch._rust_bindings")
|
392 | 409 |
|
393 | 410 |
|
| 411 | +@pytest.mark.skipif( |
| 412 | + not has_tensor_engine(), |
| 413 | + reason="Tensor engine not available", |
| 414 | +) |
394 | 415 | @pytest.mark.skipif(
|
395 | 416 | torch.cuda.device_count() < 2,
|
396 | 417 | reason="Not enough GPUs, this test requires at least 2 GPUs",
|
|
0 commit comments