
Reproduce fine tuning but score poorly on the evaluation dataset #20

BenjaminChuang00 opened this issue Apr 1, 2025 · 5 comments


@BenjaminChuang00

Hi,
Thanks to the authors for their contribution, but I ran into some problems reproducing the results.
Why do I get poor scores on the evaluation datasets when reproducing the fine-tuning on my RTX 3090?
The scores are as follows:

Dataset   AbsRel ↓   Delta_1 ↑
NYUv2     0.056      0.963
KITTI     0.092      0.928
ETH3D     0.064      0.961
ScanNet   0.062      0.956
DIODE     0.299      0.780

The following is the training script configuration I used, based on train_marigold_e2e_ft_depth.sh:

Note: the following parts were modified relative to the original script:
--checkpointing_steps 500 => to store the best checkpoint
--dataloader_num_workers 4 => speed up training time
--mixed_precision "bf16" => reduce memory usage
--seed 1234 => fixed seed

The complete script is as follows:

#!/bin/bash

accelerate launch training/train.py \
--pretrained_model_name_or_path "prs-eth/marigold-v1-0" \
--modality "depth" \
--noise_type "zeros" \
--max_train_steps 20000 \
--checkpointing_steps 500 \
--train_batch_size 2 \
--gradient_accumulation_steps 16 \
--gradient_checkpointing \
--learning_rate 3e-05 \
--lr_total_iter_length 20000 \
--lr_exp_warmup_steps 100 \
--dataloader_num_workers 4 \
--mixed_precision "bf16" \
--output_dir "model-finetuned/marigold_e2e_ft_depth_bf16" \
--enable_xformers_memory_efficient_attention \
--seed 1234 \
"$@"
@GonzaloMartinGarcia
Collaborator

Hi,

It looks like you’re following the training instructions correctly. However, please note that all of our fine-tuning experiments were conducted using full precision. We haven't tested or validated the E2E fine-tuning setup using mixed precision (FP16 or BF16), which could explain the discrepancies you're seeing in the evaluation scores.
To reproduce the reported performance, we recommend training using full precision.

Let us know if switching to full precision resolves the issue or if you continue to encounter problems!

@BenjaminChuang00
Author

Thanks for your reply!

Could you confirm how much GPU VRAM is required for training with the provided settings?

My hardware is an RTX 3090 with 24 GB of VRAM, and training with --mixed_precision "no" runs out of GPU memory.

+++++++++++

I also found another problem.
In train.py, the predicted depth is decoded from latent space to pixel space as follows:

current_estimate = decode_image(vae, current_latent_estimate)

This causes memory usage to keep accumulating, even though the VAE is frozen according to the original paper.
Can I add @torch.no_grad() to this function to stop the memory from growing? Would it affect the training results?

My proposed rewrite is as follows:

########################
# VAE Helper Functions
########################

# Apply VAE Encoder to image
@torch.no_grad()
def encode_image(vae, image):
    h = vae.encoder(image)
    moments = vae.quant_conv(h)
    latent, _ = torch.chunk(moments, 2, dim=1)
    return latent

# Apply VAE Decoder to latent
@torch.no_grad()
def decode_image(vae, latent):
z = vae.post_quant_conv(latent)
image = vae.decoder(z)
return image

@GonzaloMartinGarcia
Collaborator

GonzaloMartinGarcia commented Apr 7, 2025

Hi,

Using torch.no_grad() on the VAE encoder and decoder is not a good approach. Gradients need to flow from the decoded depth maps/normals through the (frozen) decoder and back into the UNet. torch.no_grad() would block any gradients from reaching the UNet.
For the VAE encoder, gradients are never tracked since it is frozen and only provides the RGB latent to the UNet.

A possible option is to try a batch size of 1 with gradient accumulation set to 32.
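
To illustrate the gradient-flow point, here is a minimal, self-contained sketch with toy modules standing in for the trainable UNet and the frozen VAE decoder (illustrative names only, not the repository's actual code):

import torch
import torch.nn as nn

# Toy stand-ins: "unet" is trainable, "decoder" plays the role of the frozen VAE decoder.
unet    = nn.Conv2d(4, 4, 1)
decoder = nn.Conv2d(4, 3, 1)
decoder.requires_grad_(False)  # frozen: its weights accumulate no gradients
x = torch.randn(1, 4, 8, 8)

# Freezing alone is enough: gradients still flow *through* the decoder into the unet.
estimate = decoder(unet(x))
estimate.mean().backward()
print(unet.weight.grad is not None)  # True  -> the UNet would still be trained
print(decoder.weight.grad is None)   # True  -> the decoder stays frozen

# Wrapping the decoder call in torch.no_grad() instead cuts the graph entirely,
# so a loss computed on the estimate could no longer update the UNet at all.
with torch.no_grad():
    detached_estimate = decoder(unet(x))
print(detached_estimate.requires_grad)  # False -> backward() here would raise an error

In other words, requires_grad_(False) keeps the VAE weights fixed while still letting the loss on the decoded depth reach the UNet, which is what the end-to-end fine-tuning relies on.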

@Z-Pengcheng

(Quoting @BenjaminChuang00's original post and training script above.)

Hi, how did you get bf16 or fp16 training working with this code? I set --mixed_precision="bf16", but it does not work:

Traceback (most recent call last):
  File "/home/zpc/Workspaces/diffusion-e2e-ft/training/train.py", line 637, in <module>
    main()
  File "/home/zpc/Workspaces/diffusion-e2e-ft/training/train.py", line 529, in main
    current_estimate = decode_image(vae, current_latent_estimate)
  File "/home/zpc/Workspaces/diffusion-e2e-ft/training/train.py", line 241, in decode_image
    z = vae.post_quant_conv(latent)
  File "/home/zpc/.miniconda3/envs/diffusion-e2e-ft/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zpc/.miniconda3/envs/diffusion-e2e-ft/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zpc/.miniconda3/envs/diffusion-e2e-ft/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/zpc/.miniconda3/envs/diffusion-e2e-ft/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

Can you teach me? Thank you!

@BenjaminChuang00
Author

(Quoting @Z-Pengcheng's comment above.)

Hi,
The way I did it is to add the following section in train.py, which casts current_latent_estimate to the training dtype (weight_dtype) so that bf16 or fp16 training works; see the ### new ### block below:

# clip and threshold prediction (only here for completeness, not used by SD2 or our models with v_prediction)
if noise_scheduler.config.thresholding:
    pred_original_sample = noise_scheduler._threshold_sample(pred_original_sample)
elif noise_scheduler.config.clip_sample:
    pred_original_sample = pred_original_sample.clamp(
        -noise_scheduler.config.clip_sample_range, noise_scheduler.config.clip_sample_range
    )
### new ###
if args.mixed_precision != "no":
    current_latent_estimate = current_latent_estimate.to(dtype=weight_dtype)
### new ###

Good luck and enjoy!

If it works for you, I look forward to hearing about your results and discussing further.
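
As a side note on the cast above, an alternative (just a sketch, assuming the VAE object exposes post_quant_conv and decoder as in the helpers quoted earlier, and not the repository's actual code) is to cast the latent to whatever dtype the VAE's own weights currently have, so the same helper works for fp32, fp16, and bf16 alike:

# Hypothetical variant of decode_image: match the input dtype to the VAE's weights.
def decode_image(vae, latent):
    vae_dtype = next(vae.post_quant_conv.parameters()).dtype
    z = vae.post_quant_conv(latent.to(dtype=vae_dtype))
    image = vae.decoder(z)
    return image

The .to(dtype=...) cast is differentiable, so gradients from the decoded depth still flow back into the UNet.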
