
Trikang


Study/ML

[Camp-ZipNeRF/Troubleshooting] "unexpected keyword argument" error when saving a checkpoint while training ZipNeRF on a new dataset

Trikang 2024. 6. 2. 17:45

Problem

After preparing a dataset, I tried running ZipNeRF first, before moving on to CamP. Training failed with the error below in the code that saves a checkpoint every 10,000 steps.

 

"TypeError: PyTreeCheckpointHandler.__init__() got an unexpected keyword argument 'restore_with_serialized_types'"

.
.
.
I0602 16:26:44.424313 140658210706048 train.py:360]    9800/200000: loss=0.00852, psnr=32.213, lr=7.77e-04 | data=0.00726,dist=1.5e-06, inte=1.7e-05, inte=1.5e-05, regu=6.7e-06, regu=3.7e-05, regu=0.00118, 97948 r/s
I0602 16:26:52.930353 140658210706048 train.py:360]    9900/200000: loss=0.00850, psnr=32.282, lr=7.82e-04 | data=0.00725,dist=1.5e-06, inte=1.7e-05, inte=1.7e-05, regu=6.9e-06, regu=3.6e-05, regu=0.00118, 98327 r/s
I0602 16:27:01.891378 140658210706048 train.py:360]   10000/200000: loss=0.00848, psnr=32.268, lr=7.88e-04 | data=0.00722,dist=1.4e-06, inte=1.6e-05, inte=2.1e-05, regu=6.6e-06, regu=3.8e-05, regu=0.00118, 92510 r/s
I0602 16:27:08.685611 140658210706048 train.py:428] Model visualized in 6.794s
I0602 16:27:09.068377 140658210706048 checkpoints.py:567] Saving checkpoint at step: 10000
I0602 16:27:09.068469 140658210706048 checkpoints.py:790] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 555, in <module>
    app.run(main)
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 436, in main
    checkpoints.save_checkpoint_multiprocess(
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/flax/training/checkpoints.py", line 804, in save_checkpoint_multiprocess
    ocp.PyTreeCheckpointHandler(restore_with_serialized_types=False)
TypeError: PyTreeCheckpointHandler.__init__() got an unexpected keyword argument 'restore_with_serialized_types'
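Because the failure happens inside flax's Orbax checkpointing path, it helps to record exactly which package versions are installed before trying any fix. The sketch below uses only the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the package list (flax, orbax-checkpoint, jax, jaxlib) is my assumption about which distributions matter for this traceback.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version string, or None if missing}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    # Assumed list of the distributions involved in the traceback above.
    for pkg, ver in installed_versions(["flax", "orbax-checkpoint", "jax", "jaxlib"]).items():
        print(f"{pkg}: {ver}")
```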

 

Attempted fixes

1. Checked the model's GitHub issues

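I found an issue reporting the same error. However, it occurred in a CUDA 11.6 environment, and the suggested fix was to downgrade flax. I was already running the model fine on CUDA 11.8, and I confirmed that the official repo uses flax 0.7.5, the same version I have, so I skipped this approach.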

https://github.com/jonbarron/camp_zipnerf/issues/2

 

Jax Environment for CUDA 11.6 · Issue #2 · jonbarron/camp_zipnerf


 

2. Checked the flax GitHub issues

The issue suggested upgrading the orbax-checkpoint package, so I upgraded it and retried training.
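The first traceback shows flax's `save_checkpoint_multiprocess` calling `ocp.PyTreeCheckpointHandler(restore_with_serialized_types=False)`, a keyword the installed orbax-checkpoint did not accept. Before and after upgrading, one way to check that mismatch without waiting 10,000 steps is to inspect the constructor signature. The helper `accepts_kwarg` below is a hypothetical sketch built on `inspect.signature`; it is not part of flax or Orbax.

```python
import inspect


def accepts_kwarg(func, name):
    """Return True if `func` can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all also accepts the name at call time
    # (though the callee might still reject it internally).
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


if __name__ == "__main__":
    try:
        import orbax.checkpoint as ocp  # same alias flax uses in the traceback
    except ImportError:
        print("orbax-checkpoint is not installed")
    else:
        print(accepts_kwarg(ocp.PyTreeCheckpointHandler.__init__,
                            "restore_with_serialized_types"))
```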

https://github.com/google/flax/issues/3417

 

checkpoints.save_checkpoint error · Issue #3417 · google/flax


 

Ran into a new problem

I0602 17:18:17.219559 140427076002432 train.py:360]   10000/200000: loss=0.00948, psnr=32.126, lr=7.88e-04 | data=0.00779,dist=1.2e-06, inte=3.6e-05, inte=3.7e-05, regu=1.2e-05, regu=6.8e-05, regu=0.00153, 92750 r/s
I0602 17:18:23.966293 140427076002432 train.py:428] Model visualized in 6.747s
I0602 17:18:24.349967 140427076002432 checkpoints.py:567] Saving checkpoint at step: 10000
I0602 17:18:24.350065 140427076002432 checkpoints.py:790] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
W0602 17:18:24.350355 140427076002432 type_handlers.py:302] SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([], {}, None) values then use PyTreeCheckpointHandler(..., write_tree_metadata=True, ...)or use StandardCheckpointHandler to avoid TypeHandler Registry error. Please note that PyTreeCheckpointHandler.write_tree_metadata default value is already set to True.
I0602 17:18:24.350629 140427076002432 checkpointer.py:137] Saving item to /home/user/3D_survey/models/camp_zipnerf/output/nerf_synthetic/lego/checkpoint_10000.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 555, in <module>
    app.run(main)
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 436, in main
    checkpoints.save_checkpoint_multiprocess(
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/flax/training/checkpoints.py", line 821, in save_checkpoint_multiprocess
    orbax_checkpointer.save(
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 145, in save
    tmpdir = utils.create_tmp_directory(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/utils.py", line 517, in create_tmp_directory
    tmp_dir = get_tmp_directory(
              ^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/utils.py", line 450, in get_tmp_directory
    timestamp = multihost.broadcast_one_to_some(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 90,in broadcast_one_to_some
    in_tree = jax.tree.map(pre_jit, in_tree)
              ^^^^^^^^
  File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'
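The second traceback fails inside the upgraded orbax-checkpoint at `jax.tree.map`, and the installed jax raises `AttributeError: module 'jax' has no attribute 'tree'`. In other words, the newer orbax-checkpoint expects a newer jax than the one in this environment. A quick way to check for that is to test whether the attribute exists at all; `module_has_attr` below is a hypothetical helper, not part of jax.

```python
import importlib


def module_has_attr(module_name, attr):
    """Return True if `module_name` imports cleanly and exposes `attr`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    # hasattr() returns False when a module-level __getattr__ raises
    # AttributeError, which is what the log above shows for jax.tree.
    return hasattr(module, attr)


if __name__ == "__main__":
    # False here corresponds to the AttributeError in the log above.
    print("jax.tree available:", module_has_attr("jax", "tree"))
```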

 

3. Re-pinned the orbax-checkpoint version

After going through the code carefully, it looked like a lot changed between orbax-checkpoint 0.4 and 0.5. So I stayed on the 0.4 line that camp_zipnerf uses, but switched to 0.4.8, the highest 0.4.x release.

pip install orbax-checkpoint==0.4.8
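The fix keeps orbax-checkpoint inside the 0.4 release line (at least 0.4, below 0.5). If you want a training script to fail fast when the environment drifts out of that range, a standard-library sketch like the one below is enough. The helper names are hypothetical, and the simple numeric parsing ignores pre-release tags such as `rc1`; `packaging.version` would be more robust if it is available.

```python
import re


def release_tuple(version_string):
    """Parse the leading numeric release segment, e.g. '0.4.8' -> (0, 4, 8)."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"unrecognized version: {version_string!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def in_range(version_string, lower, upper):
    """True if lower <= version < upper, comparing numeric release tuples."""
    return lower <= release_tuple(version_string) < upper


if __name__ == "__main__":
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed = version("orbax-checkpoint")
    except PackageNotFoundError:
        print("orbax-checkpoint is not installed")
    else:
        ok = in_range(installed, (0, 4), (0, 5))
        print(f"orbax-checkpoint {installed}: {'OK' if ok else 'outside 0.4.x'}")
```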

 

Problem solved!
