728x90

문제 : def load_model의 model.load_state_dict에서 에러 발생

 

 

학습된 모델을 가지고 터미널에서 inference.py를 실행했더니 아래와 같은 에러가 발생했다.

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BaseModel:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3.weight", "conv3.bias", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "model.features.0.0.weight", "model.features.0.1.weight",
...
...
...

 

요지는 다음과 같은 부분이었다.

 

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for BaseModel: Missing key(s) in state_dict:

 

다음 링크를 참고하여 문제를 해결할 수 있었다.

 

inference.py 에서 load_state_dict 안에 stric=False를 추가해주었더니 해결되었다. 이 경우 파이토치의 버전에 따른 문제이며 출처에서는 torch 0.4.0를 사용하기 때문에 아래와 같이 strict=False를 추가해주어야했다고 한다. 따라서 torch 0.4.1 미만(=0.4.0 이하)을 사용할 경우 나타나는 문제로 판단된다.

model.load_state_dict(checkpoint['state_dict'], strict=False)

 

 

 

728x90