import sys
#sys.path.insert(0,'..')
import pickle
from functools import partial
#from twod.hg_3d import *
from model.Pose3D import *
from model.DepthRegressor3D import *
from model.HourGlassNet3D import *
from model.HourGlass3D import *
from model.Layers3D import *
import inflation.Inflate as Inflate
import torch
import ref



def inflate(opt = None):
	"""Build a 3D hourglass network and inflate its weights from a 2D checkpoint.

	The 3D network is created with ``HourglassNet3D``. The 2D model is loaded
	with ``torch.load`` and then passed to ``Inflate.inflateHourglassNet``
	together with the 3D network. The resulting 3D model is also saved to
	``inflatedModel.pth`` in the current working directory.

	Args:
		opt: Optional options object. When given, its ``nChannels``, ``nStack``,
			``nModules`` and ``numReductions`` attributes (plus ``ref.nJoints``)
			configure the 3D network. ``nChannels``, ``nStack``, ``nModules``,
			``nRegFrames`` and ``ref.nJoints`` are copied onto the ``Inflate``
			module. ``opt.Model2D`` is the checkpoint path. When ``None``, the
			network uses its default constructor arguments and the checkpoint is
			read from ``models/hgreg-3d.pth``.

	Returns:
		The inflated 3D model (the same object that was saved to disk).
	"""
	if opt is not None:
		model3d = HourglassNet3D(opt.nChannels, opt.nStack, opt.nModules, opt.numReductions, ref.nJoints)
		# Module-level settings on Inflate; presumably read by
		# inflateHourglassNet below (NOTE(review): confirm in inflation.Inflate).
		Inflate.nChannels = opt.nChannels
		Inflate.nStack = opt.nStack
		Inflate.nModules = opt.nModules
		Inflate.nRegFrames = opt.nRegFrames
		Inflate.nJoints = ref.nJoints
		checkpoint_path = opt.Model2D
	else:
		model3d = HourglassNet3D()
		checkpoint_path = 'models/hgreg-3d.pth'

	# latin1 decoding lets Python 3 unpickle objects written by Python 2 (see
	# the pickle docs). Presumably the 2D checkpoint predates Python 3; confirm.
	# The encoding is forwarded to pickle through torch.load's pickle_load_args
	# instead of monkey-patching pickle.Unpickler/pickle.load globally. Replacing
	# them with functools.partial objects leaks into the whole process and breaks
	# torch.load, which subclasses pickle_module.Unpickler internally.
	model = torch.load(checkpoint_path, encoding="latin1")

	Inflate.inflateHourglassNet(model3d, model)

	# Use a context manager so the file is flushed and closed deterministically
	# (the original passed an open() handle that was never closed).
	with open('inflatedModel.pth', 'wb') as out_file:
		torch.save(model3d, out_file)

	return model3d