Trikang
[Camp-ZipNeRF/Troubleshooting] unexpected keyword argument error when saving a checkpoint while training ZipNeRF on a new dataset
Trikang 2024. 6. 2. 17:45
Problem
After preparing a dataset, I tried to train ZipNeRF first, before running CamP. The code that saves a checkpoint every 10,000 steps then failed with the following error:
"TypeError: PyTreeCheckpointHandler.__init__() got an unexpected keyword argument 'restore_with_serialized_types'"
.
.
.
I0602 16:26:44.424313 140658210706048 train.py:360] 9800/200000: loss=0.00852, psnr=32.213, lr=7.77e-04 | data=0.00726,dist=1.5e-06, inte=1.7e-05, inte=1.5e-05, regu=6.7e-06, regu=3.7e-05, regu=0.00118, 97948 r/s
I0602 16:26:52.930353 140658210706048 train.py:360] 9900/200000: loss=0.00850, psnr=32.282, lr=7.82e-04 | data=0.00725,dist=1.5e-06, inte=1.7e-05, inte=1.7e-05, regu=6.9e-06, regu=3.6e-05, regu=0.00118, 98327 r/s
I0602 16:27:01.891378 140658210706048 train.py:360] 10000/200000: loss=0.00848, psnr=32.268, lr=7.88e-04 | data=0.00722,dist=1.4e-06, inte=1.6e-05, inte=2.1e-05, regu=6.6e-06, regu=3.8e-05, regu=0.00118, 92510 r/s
I0602 16:27:08.685611 140658210706048 train.py:428] Model visualized in 6.794s
I0602 16:27:09.068377 140658210706048 checkpoints.py:567] Saving checkpoint at step: 10000
I0602 16:27:09.068469 140658210706048 checkpoints.py:790] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 555, in <module>
app.run(main)
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 436, in main
checkpoints.save_checkpoint_multiprocess(
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/flax/training/checkpoints.py", line 804, in save_checkpoint_multiprocess
ocp.PyTreeCheckpointHandler(restore_with_serialized_types=False)
TypeError: PyTreeCheckpointHandler.__init__() got an unexpected keyword argument 'restore_with_serialized_types'
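The TypeError above means flax is passing a keyword that the installed orbax PyTreeCheckpointHandler does not accept, which usually points to mismatched package versions. Below is a minimal sketch for checking this kind of mismatch up front with the standard inspect module. accepts_keyword, OldHandler and NewHandler are hypothetical names used only for illustration; they are not part of flax or orbax.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` can be called with `name` as a keyword argument."""
    for param in inspect.signature(func).parameters.values():
        # A **kwargs parameter accepts any keyword.
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Hypothetical stand-ins for an older and a newer handler class.
class OldHandler:
    def __init__(self, aggregate_filename=None):
        self.aggregate_filename = aggregate_filename


class NewHandler:
    def __init__(self, aggregate_filename=None, restore_with_serialized_types=True):
        self.aggregate_filename = aggregate_filename
        self.restore_with_serialized_types = restore_with_serialized_types


if __name__ == "__main__":
    for cls in (OldHandler, NewHandler):
        print(cls.__name__, accepts_keyword(cls, "restore_with_serialized_types"))
```

In the real environment the same check could be pointed at the class named in the traceback (ocp.PyTreeCheckpointHandler, with orbax.checkpoint imported as ocp), assuming orbax-checkpoint is importable.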
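Attempted fixes
1. Checked the model's GitHub issues
I found a report of the same issue. However, that problem occurred in a CUDA 11.6 environment and the suggested fix was to downgrade flax. I had everything running fine on CUDA 11.8, and I confirmed that the official repo also uses the same flax version as mine, 0.7.5, so I skipped this approach.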
https://github.com/jonbarron/camp_zipnerf/issues/2
2. Checked the flax GitHub issues
The advice there was to upgrade the orbax-checkpoint package, so I updated it and retried training.
https://github.com/google/flax/issues/3417
This led to a new problem:
I0602 17:18:17.219559 140427076002432 train.py:360] 10000/200000: loss=0.00948, psnr=32.126, lr=7.88e-04 | data=0.00779,dist=1.2e-06, inte=3.6e-05, inte=3.7e-05, regu=1.2e-05, regu=6.8e-05, regu=0.00153, 92750 r/s
I0602 17:18:23.966293 140427076002432 train.py:428] Model visualized in 6.747s
I0602 17:18:24.349967 140427076002432 checkpoints.py:567] Saving checkpoint at step: 10000
I0602 17:18:24.350065 140427076002432 checkpoints.py:790] Using Orbax as backend to save Flax checkpoints. For potential troubleshooting see: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html#orbax-as-backend-troubleshooting
W0602 17:18:24.350355 140427076002432 type_handlers.py:302] SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([], {}, None) values then use PyTreeCheckpointHandler(..., write_tree_metadata=True, ...)or use StandardCheckpointHandler to avoid TypeHandler Registry error. Please note that PyTreeCheckpointHandler.write_tree_metadata default value is already set to True.
I0602 17:18:24.350629 140427076002432 checkpointer.py:137] Saving item to /home/user/3D_survey/models/camp_zipnerf/output/nerf_synthetic/lego/checkpoint_10000.
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 555, in <module>
app.run(main)
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/user/3D_survey/models/camp_zipnerf/camp_zipnerf/train.py", line 436, in main
checkpoints.save_checkpoint_multiprocess(
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/flax/training/checkpoints.py", line 821, in save_checkpoint_multiprocess
orbax_checkpointer.save(
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 145, in save
tmpdir = utils.create_tmp_directory(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/utils.py", line 517, in create_tmp_directory
tmp_dir = get_tmp_directory(
^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/utils.py", line 450, in get_tmp_directory
timestamp = multihost.broadcast_one_to_some(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 90, in broadcast_one_to_some
in_tree = jax.tree.map(pre_jit, in_tree)
^^^^^^^^
File "/home/user/anaconda3/envs/camp_zipnerf/lib/python3.11/site-packages/jax/_src/deprecations.py", line 53, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'
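This second traceback suggests the opposite mismatch: the upgraded orbax-checkpoint calls jax.tree, which the jax installed in this environment does not provide. Since flax, orbax-checkpoint and jax have to agree with each other, it helps to print what is actually installed before changing anything. Here is a rough sketch using only the standard library. installed_versions is a hypothetical helper, and the package list is an assumption about what is relevant here.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed list of the packages involved in the two tracebacks.
    relevant = ["flax", "orbax-checkpoint", "jax", "jaxlib"]
    for name, version in installed_versions(relevant).items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing this output with the versions pinned in the camp_zipnerf repo shows which package has drifted.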
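3. Re-adjusted the orbax-checkpoint version
After reading through the code carefully, it looked like a lot changed when orbax-checkpoint went from 0.4 to 0.5. So I stayed on the 0.4 series that camp_zipnerf uses, but switched to 0.4.8, the highest version within 0.4.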
pip install orbax-checkpoint==0.4.8
Problem solved!
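To keep the pin from silently drifting after a later pip install, a small check could be run before starting a long training job. This is a sketch that assumes plain dotted version strings such as 0.4.8 (pre-release tags are not handled). release_series and in_series are hypothetical helper names.

```python
from importlib import metadata


def release_series(version):
    """Return the (major, minor) pair of a plain dotted version like '0.4.8'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def in_series(version, series):
    """Return True if `version` belongs to the given 'major.minor' series."""
    return release_series(version) == release_series(series)


if __name__ == "__main__":
    try:
        installed = metadata.version("orbax-checkpoint")
    except metadata.PackageNotFoundError:
        print("orbax-checkpoint is not installed")
    else:
        status = "ok" if in_series(installed, "0.4") else "outside the 0.4 series"
        print(f"orbax-checkpoint {installed}: {status}")
```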