728x90
문제 : def load_model의 model.load_state_dict에서 에러 발생
학습된 모델을 가지고 터미널에서 inference.py를 실행했더니 아래와 같은 에러가 발생했다.
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BaseModel:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3.weight", "conv3.bias", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "model.features.0.0.weight", "model.features.0.1.weight",
...
...
...
요지는 다음과 같은 부분이었다.
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for BaseModel: Missing key(s) in state_dict:
다음 링크를 참고하여 문제를 해결할 수 있었다.
inference.py 에서 load_state_dict 안에 stric=False를 추가해주었더니 해결되었다. 이 경우 파이토치의 버전에 따른 문제이며 출처에서는 torch 0.4.0를 사용하기 때문에 아래와 같이 strict=False를 추가해주어야했다고 한다. 따라서 torch 0.4.1 미만(=0.4.0 이하)을 사용할 경우 나타나는 문제로 판단된다.
model.load_state_dict(checkpoint['state_dict'], strict=False)
728x90