package cupdnn.layer;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Vector;

import cupdnn.Network;
import cupdnn.active.ReluActivationFunc;
import cupdnn.active.SigmodActivationFunc;
import cupdnn.active.TanhActivationFunc;
import cupdnn.data.Blob;
import cupdnn.util.MathFunctions;
import cupdnn.util.Task;
import cupdnn.util.ThreadPoolManager;

/**
 * Two-dimensional convolution layer.
 *
 * <p>The layer owns a kernel blob of shape
 * {@code (inChannel * outChannel, kernelSize, kernelSize)} and a bias blob of
 * shape {@code (outChannel)}. Its output and diff blobs have shape
 * {@code (batch, outChannel, height, width)}; the convolution itself is
 * delegated to {@code MathFunctions.conv2dBlobSame}.
 *
 * <p>{@code stride} is stored and serialized with the model but is not
 * otherwise used by this class.
 */
public class Conv2dLayer extends Layer {
    public static final String TYPE = "Conv2dLayer";

    /**
     * Tag written to the model stream when the layer has no activation
     * function. {@link #loadModel(ObjectInputStream)} matches it against no
     * known activation type and therefore leaves the activation unset.
     */
    private static final String NO_ACTIVATION = "none";

    private Blob kernel;
    private Blob bias;
    private Blob kernelGradient;
    private Blob biasGradient;
    /** Pre-activation convolution result, kept for the backward pass. */
    private Blob z;
    private int width;
    private int height;
    private int inChannel;
    private int outChannel;
    private int kernelSize;
    private int stride;

    /**
     * Creates an unconfigured layer; the configuration is expected to be
     * filled in by {@link #loadModel(ObjectInputStream)}.
     *
     * @param network the network this layer belongs to
     */
    public Conv2dLayer(Network network) {
        super(network);
    }

    /**
     * Creates a configured layer.
     *
     * @param network    the network this layer belongs to
     * @param width      width of the output feature map
     * @param height     height of the output feature map
     * @param inChannel  number of input channels
     * @param outChannel number of output channels
     * @param kernelSize side length of the square kernel
     * @param stride     stride value (stored and serialized only)
     */
    public Conv2dLayer(Network network, int width, int height, int inChannel, int outChannel,
            int kernelSize, int stride) {
        super(network);
        this.width = width;
        this.height = height;
        this.inChannel = inChannel;
        this.outChannel = outChannel;
        this.kernelSize = kernelSize;
        this.stride = stride;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Allocates the working blobs. Kernel and bias are created and initialized
     * only when neither exists yet, so parameters restored by
     * {@link #loadModel(ObjectInputStream)} are preserved.
     */
    @Override
    public void prepare() {
        if (kernel == null && bias == null) {
            kernel = new Blob(inChannel * outChannel, kernelSize, kernelSize);
            bias = new Blob(outChannel);
            // Initial parameters: Gaussian kernel weights, small constant bias.
            MathFunctions.gaussianInitData(kernel.getData());
            MathFunctions.constantInitData(bias.getData(), 0.001f);
        }
        z = new Blob(mNetwork.getBatch(), outChannel, height, width);
        kernelGradient = new Blob(inChannel * outChannel, kernelSize, kernelSize);
        biasGradient = new Blob(outChannel);
    }

    /**
     * Forward pass. Without an activation function the convolution result is
     * written straight into the output blob; otherwise it is written into
     * {@code z} and the activation of every element of {@code z} is stored in
     * the output blob.
     */
    @Override
    public void forward() {
        Blob input = mNetwork.getDatas().get(id - 1);
        Blob output = mNetwork.getDatas().get(id);
        float[] outputData = output.getData();
        float[] zData = z.getData();

        if (activationFunc == null) {
            output.fillValue(0);
            MathFunctions.conv2dBlobSame(mNetwork, input, kernel, bias, output);
            return;
        }

        z.fillValue(0);
        MathFunctions.conv2dBlobSame(mNetwork, input, kernel, bias, z);

        // One task per batch sample. Inside the task, `n` resolves to the field
        // inherited from Task (presumably the value given to its constructor).
        Vector<Task<Object>> workers = new Vector<Task<Object>>();
        for (int i = 0; i < output.getNumbers(); i++) {
            workers.add(new Task<Object>(i) {
                @Override
                public Object call() throws Exception {
                    for (int c = 0; c < output.getChannels(); c++) {
                        for (int h = 0; h < output.getHeight(); h++) {
                            for (int w = 0; w < output.getWidth(); w++) {
                                outputData[output.getIndexByParams(n, c, h, w)] =
                                        activationFunc.active(zData[z.getIndexByParams(n, c, h, w)]);
                            }
                        }
                    }
                    return null;
                }
            });
        }
        ThreadPoolManager.getInstance(mNetwork).dispatchTask(workers);
    }

    /**
     * Backward pass: scales this layer's diff by the activation derivative,
     * computes the batch-averaged kernel and bias gradients, propagates the
     * diff to the previous layer when there is one to propagate to
     * ({@code id > 1}), and finally applies the parameter update.
     *
     * <p>The parameter update is performed for every layer, including the
     * first one ({@code id <= 1}); only the propagation step is skipped there.
     */
    @Override
    public void backward() {
        Blob input = mNetwork.getDatas().get(id - 1);
        Blob inputDiff = mNetwork.getDiffs().get(id);
        Blob outputDiff = mNetwork.getDiffs().get(id - 1);

        if (activationFunc != null) {
            scaleDiffByActivationDerivative(inputDiff);
        }
        accumulateKernelGradient(input, inputDiff, outputDiff);
        accumulateBiasGradient(inputDiff);

        if (id > 1) {
            // Propagate using the kernel as it was during the forward pass,
            // i.e. before the update below.
            // NOTE(review): the kernel is passed un-rotated; an earlier,
            // disabled line rotated it by 180 degrees first. Confirm that this
            // conv2dBlobSame overload accounts for that.
            outputDiff.fillValue(0);
            MathFunctions.conv2dBlobSame(mNetwork, inputDiff, kernel, outputDiff);
        }

        mNetwork.updateW(kernel, kernelGradient);
        mNetwork.updateW(bias, biasGradient);
    }

    /** Multiplies every element of {@code inputDiff} by the activation derivative at {@code z}. */
    private void scaleDiffByActivationDerivative(Blob inputDiff) {
        float[] inputDiffData = inputDiff.getData();
        float[] zData = z.getData();

        Vector<Task<Object>> workers = new Vector<Task<Object>>();
        for (int i = 0; i < inputDiff.getNumbers(); i++) {
            workers.add(new Task<Object>(i) {
                @Override
                public Object call() throws Exception {
                    for (int c = 0; c < inputDiff.getChannels(); c++) {
                        for (int h = 0; h < inputDiff.getHeight(); h++) {
                            for (int w = 0; w < inputDiff.getWidth(); w++) {
                                inputDiffData[inputDiff.getIndexByParams(n, c, h, w)] *=
                                        activationFunc.diffActive(zData[z.getIndexByParams(n, c, h, w)]);
                            }
                        }
                    }
                    return null;
                }
            });
        }
        ThreadPoolManager.getInstance(mNetwork).dispatchTask(workers);
    }

    /**
     * Fills {@code kernelGradient} with the batch-averaged gradient of the kernel.
     *
     * <p>Each batch sample is handled by its own task. Because every task
     * contributes to the same gradient cells, a task sums into a private
     * buffer and merges it into the shared array under a lock, so concurrent
     * tasks cannot lose updates.
     */
    private void accumulateKernelGradient(Blob input, Blob inputDiff, Blob outputDiff) {
        float[] inputData = input.getData();
        float[] inputDiffData = inputDiff.getData();
        float[] kernelGradientData = kernelGradient.getData();

        kernelGradient.fillValue(0);
        Vector<Task<Object>> workers = new Vector<Task<Object>>();
        for (int i = 0; i < inputDiff.getNumbers(); i++) {
            workers.add(new Task<Object>(i) {
                @Override
                public Object call() throws Exception {
                    float[] partial = new float[kernelGradientData.length];
                    int kernelH = kernelGradient.getHeight();
                    int kernelW = kernelGradient.getWidth();
                    int prevChannels = outputDiff.getChannels();

                    // ci walks this layer's output channels, co the previous
                    // layer's channels (the input channels of the convolution).
                    for (int ci = 0; ci < inputDiff.getChannels(); ci++) {
                        for (int co = 0; co < prevChannels; co++) {
                            for (int h = 0; h < inputDiff.getHeight(); h++) {
                                for (int w = 0; w < inputDiff.getWidth(); w++) {
                                    float delta = inputDiffData[inputDiff.getIndexByParams(n, ci, h, w)];
                                    // Top-left input position covered by a kernel centred on (h, w).
                                    int inStartX = w - kernelW / 2;
                                    int inStartY = h - kernelH / 2;
                                    for (int kh = 0; kh < kernelH; kh++) {
                                        int inY = inStartY + kh;
                                        if (inY < 0 || inY >= input.getHeight()) {
                                            continue; // outside the input: contributes nothing
                                        }
                                        for (int kw = 0; kw < kernelW; kw++) {
                                            int inX = inStartX + kw;
                                            if (inX < 0 || inX >= input.getWidth()) {
                                                continue;
                                            }
                                            partial[kernelGradient.getIndexByParams(0, ci * prevChannels + co, kh, kw)] +=
                                                    inputData[input.getIndexByParams(n, co, inY, inX)] * delta;
                                        }
                                    }
                                }
                            }
                        }
                    }

                    synchronized (kernelGradientData) {
                        for (int k = 0; k < partial.length; k++) {
                            kernelGradientData[k] += partial[k];
                        }
                    }
                    return null;
                }
            });
        }
        ThreadPoolManager.getInstance(mNetwork).dispatchTask(workers);

        // Average over the batch.
        MathFunctions.dataDivConstant(kernelGradientData, inputDiff.getNumbers());
    }

    /** Fills {@code biasGradient} with the batch-averaged sum of {@code inputDiff} per channel. */
    private void accumulateBiasGradient(Blob inputDiff) {
        float[] inputDiffData = inputDiff.getData();
        float[] biasGradientData = biasGradient.getData();

        biasGradient.fillValue(0);
        for (int n = 0; n < inputDiff.getNumbers(); n++) {
            for (int c = 0; c < inputDiff.getChannels(); c++) {
                int biasIndex = biasGradient.getIndexByParams(0, 0, 0, c);
                for (int h = 0; h < inputDiff.getHeight(); h++) {
                    for (int w = 0; w < inputDiff.getWidth(); w++) {
                        biasGradientData[biasIndex] += inputDiffData[inputDiff.getIndexByParams(n, c, h, w)];
                    }
                }
            }
        }
        // Average over the batch.
        MathFunctions.dataDivConstant(biasGradientData, inputDiff.getNumbers());
    }

    /**
     * Writes the layer type, configuration, parameters and activation tag.
     *
     * <p>An activation tag is always written ({@code "none"} when the layer has
     * no activation function) because {@link #loadModel(ObjectInputStream)}
     * unconditionally reads one; omitting it would desynchronize the stream.
     * I/O failures are reported on standard error and not rethrown.
     *
     * @param out stream to write the model to
     */
    @Override
    public void saveModel(ObjectOutputStream out) {
        try {
            out.writeUTF(getType());
            out.writeInt(width);
            out.writeInt(height);
            out.writeInt(inChannel);
            out.writeInt(outChannel);
            out.writeInt(kernelSize);
            out.writeInt(stride);
            out.writeObject(kernel);
            out.writeObject(bias);
            out.writeUTF(activationFunc != null ? activationFunc.getType() : NO_ACTIVATION);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads the configuration, parameters and activation tag in the order
     * {@link #saveModel(ObjectOutputStream)} writes them, except for the
     * leading layer type, which is not read here. An activation tag that
     * matches no known activation type leaves the activation function
     * untouched. Failures are reported on standard error and not rethrown.
     *
     * <p>NOTE(review): this uses Java native deserialization
     * ({@code readObject}); only load model files from trusted sources.
     *
     * @param in stream to read the model from
     */
    @Override
    public void loadModel(ObjectInputStream in) {
        try {
            width = in.readInt();
            height = in.readInt();
            inChannel = in.readInt();
            outChannel = in.readInt();
            kernelSize = in.readInt();
            stride = in.readInt();
            kernel = (Blob) in.readObject();
            bias = (Blob) in.readObject();

            String activationType = in.readUTF();
            if (activationType.equals(ReluActivationFunc.TYPE)) {
                setActivationFunc(new ReluActivationFunc());
            } else if (activationType.equals(SigmodActivationFunc.TYPE)) {
                setActivationFunc(new SigmodActivationFunc());
            } else if (activationType.equals(TanhActivationFunc.TYPE)) {
                setActivationFunc(new TanhActivationFunc());
            }
        } catch (ClassNotFoundException | IOException e) {
            e.printStackTrace();
        }
    }

    /** @return a new blob of shape {@code (batch, outChannel, height, width)} for this layer's output */
    @Override
    public Blob createOutBlob() {
        return new Blob(mNetwork.getBatch(), outChannel, height, width);
    }

    /** @return a new blob of shape {@code (batch, outChannel, height, width)} for this layer's diff */
    @Override
    public Blob createDiffBlob() {
        return new Blob(mNetwork.getBatch(), outChannel, height, width);
    }
}