Trikang
3D Gaussian Splatting - FPS 측정하기 본문
3D-GS 계열 전체에 걸쳐 쓸 수 있는 FPS 측정 스크립트를 야매로 구성해봤다. EAGLES의 코드를 많이 참고함.
사용법은 3D-GS 코드의 render.py와 거의 유사하다.
아래 소스코드를 복사한 후, 3D-GS 디렉토리에 새 Python 파일을 하나 만들어 붙여넣은 다음 사용하면 된다. 나는 measure_fps.py로 이름 짓고 사용 중이다.
import os
import json
import torch
import numpy as np
import subprocess as sp
from gaussian_renderer import render
import torch.utils.benchmark as benchmark
from gaussian_renderer import GaussianModel
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from scene import Scene
def render_fn(views, gaussians, pipeline, background):
    """Render every camera in ``views`` once, discarding the outputs.

    The whole pass runs inside a CUDA float16 autocast context. Nothing is
    returned; the function exists so a timer can measure one full sweep
    over the camera list.

    Args:
        views: Iterable of camera objects accepted by ``render``.
        gaussians: Gaussian model passed through to ``render``.
        pipeline: Pipeline parameters passed through to ``render``.
        background: Background colour tensor passed through to ``render``.
    """
    half_precision = torch.autocast(device_type='cuda', dtype=torch.float16)
    with half_precision:
        for camera in views:
            render(camera, gaussians, pipeline, background)
def measure_fps(scene, gaussians, pipeline, background):
    """Benchmark rendering speed over all train and test cameras.

    One timed "run" is a full sweep of ``render_fn`` over every camera in
    the scene. The sweep is repeated 50 times with
    ``torch.utils.benchmark.Timer`` and the frames-per-second figure is
    derived from the per-run time reported by the measurement.

    Args:
        scene: Object exposing ``getTrainCameras()`` and
            ``getTestCameras()``; their results are concatenated with ``+``.
        gaussians: Gaussian model forwarded to ``render_fn``.
        pipeline: Pipeline parameters forwarded to ``render_fn``.
        background: Background colour tensor forwarded to ``render_fn``.

    Returns:
        The measured frames per second (number of views divided by the
        median time of one sweep).
    """
    with torch.no_grad():
        views = scene.getTrainCameras() + scene.getTestCameras()
        # Hand render_fn to the timer through ``globals`` rather than via
        # ``from __main__ import render_fn``: the latter only resolves when
        # this file is the entry script and breaks when it is imported.
        timer = benchmark.Timer(
            stmt='render_fn(views, gaussians, pipeline, background)',
            globals={
                'render_fn': render_fn,
                'views': views,
                'gaussians': gaussians,
                'pipeline': pipeline,
                'background': background,
            },
        )
        # Named ``measurement`` (not ``time``) to avoid shadowing the
        # standard-library module name.
        measurement = timer.timeit(50)
        fps = len(views) / measurement.median
        print("Rendering FPS: ", fps)
        return fps
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams):
    """Load a trained scene and measure its rendering FPS.

    Args:
        dataset: Model parameters; ``sh_degree`` and ``white_background``
            are read here, and the object is handed to ``Scene``.
        iteration: Checkpoint iteration passed to ``Scene`` as
            ``load_iteration``.
        pipeline: Pipeline parameters forwarded to ``measure_fps``.

    Returns:
        The FPS value produced by ``measure_fps``. Previously this value
        was computed and dropped; returning it lets callers log or
        aggregate results, while callers that ignore it are unaffected.
    """
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        # White or black background, placed on the GPU.
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        return measure_fps(scene, gaussians, pipeline, background)
if __name__ == "__main__":
    # Build the command-line interface: model and pipeline parameter groups
    # register their own options on the parser, plus an --iteration flag.
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", type=int, default=-1)
    args = get_combined_args(parser)

    print("Rendering " + args.model_path)

    # Echo the resolved settings before benchmarking, then reuse the
    # extracted parameter groups for the actual run.
    model_args = model.extract(args)
    print(f"model_extract: {model_args}")
    print(f"args_iteration: {args.iteration}")
    pipeline_args = pipeline.extract(args)
    print(f"pipeline_extract: {pipeline_args}")

    render_sets(model_args, args.iteration, pipeline_args)
실행 예시는 아래와 같다.
python measure_fps.py -m output/fern -s data/fern
'공부 > ML' 카테고리의 다른 글
Comments