import torch import yaml import argparse import sys from packaging import version sys.path.append('.') import models import os import pdb model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) parser = argparse.ArgumentParser(description='PyTorch Kinematics') parser.add_argument('--config', required=True) parser.add_argument('--iter', type=int, required=True) args = parser.parse_args() def main(): exp_dir = os.path.dirname(args.config) with open(args.config) as f: if version.parse(yaml.version >= "5.1"): config = yaml.load(f, Loader=yaml.FullLoader) else: config = yaml.load(f) for k, v in config.items(): setattr(args, k, v) model = models.modules.__dict__[args.model['module']['arch']](args.model['module']) model = torch.nn.DataParallel(model) ckpt_path = exp_dir + '/checkpoints/ckpt_iter_{}.pth.tar'.format(args.iter) save_path = exp_dir + '/checkpoints/convert_iter_{}.pth.tar'.format(args.iter) ckpt = torch.load(ckpt_path) weight = ckpt['state_dict'] model.load_state_dict(weight, strict=True) model = model.module.image_encoder torch.save(model.state_dict(), save_path) if __name__ == "__main__": main()