package cupdnn.layer; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import cupdnn.Network; import cupdnn.data.Blob; public class RecurrentLayer extends Layer { Network mNetwork; Cell mCell; int seqLen; int inSize; int hidenSize; public static final String TYPE = "RecurrentLayer"; public static enum RecurrentType{ RNN, LSTM, GRU } RecurrentType type; public RecurrentLayer(Network network) { super(network); // TODO Auto-generated constructor stub this.mNetwork = network; } public RecurrentLayer(Network network,RecurrentType type,int seqLen,int inSize,int hidenSize) { this(network); this.seqLen = seqLen; this.inSize = inSize; this.hidenSize = hidenSize; this.type = type; } @Override public Blob createOutBlob() { // TODO Auto-generated method stub return new Blob(seqLen,mNetwork.getBatch()/seqLen,hidenSize); } @Override public Blob createDiffBlob() { // TODO Auto-generated method stub return new Blob(seqLen,mNetwork.getBatch()/seqLen,hidenSize); } @Override public String getType() { // TODO Auto-generated method stub return TYPE; } @Override public void prepare() { // TODO Auto-generated method stub switch(type) { case RNN: if(mCell==null) { mCell = new RnnCell(mNetwork,this,inSize,hidenSize); } mCell.prepare(); break; case LSTM: break; case GRU: break; } } @Override public void forward() { // TODO Auto-generated method stub Blob input = mNetwork.getDatas().get(id-1); Blob output = mNetwork.getDatas().get(id); float[] inputData = input.getData(); float[] outputData = output.getData(); Blob tmpIn = new Blob(mNetwork.getBatch()/seqLen,inSize); Blob tmpOut = new Blob(mNetwork.getBatch()/seqLen,hidenSize); for(int i=0;i<seqLen;i++) { if(i==0) { mCell.resetState(); } float[] tmpInData = tmpIn.getData(); int tmpInSize = tmpIn.getSize(); //ÿ��ȥȡһ����� for(int j=0;j<tmpInSize;j++) { tmpInData[j] = inputData[i*tmpInSize+j]; } mCell.forward(tmpIn,tmpOut); float[] tmpOutData = tmpOut.getData(); int tmpOutSize = tmpOut.getSize(); for(int j=0;j<tmpOutSize;j++) { outputData[i*tmpOutSize+j] = tmpOutData[j]; } } } @Override public void backward() { // TODO Auto-generated method stub Blob inputDiff = mNetwork.getDiffs().get(id); Blob outputDiff = mNetwork.getDiffs().get(id-1); Blob input = mNetwork.getDatas().get(id-1); Blob output = mNetwork.getDatas().get(id); Blob tmpIn = new Blob(mNetwork.getBatch()/seqLen,inSize); Blob tmpOut = new Blob(mNetwork.getBatch()/seqLen,hidenSize); Blob tmpInDiff = new Blob(mNetwork.getBatch()/seqLen,hidenSize); Blob tmpOutDiff = new Blob(mNetwork.getBatch()/seqLen,inSize); float[] inputData = input.getData(); float[] outputData = output.getData(); float[] inputDiffData = inputDiff.getData(); float[] outputDiffData = outputDiff.getData(); for(int i=0;i<seqLen;i++) { //һ��ȡ����е�һ�� float[] tmpInData = tmpIn.getData(); int tmpInSize = tmpIn.getSize(); for(int j=0;j<tmpInSize;j++) { tmpInData[j] = inputData[i*tmpInSize+j]; } float[] tmpOutData = tmpOut.getData(); int tmpOutSize = tmpOut.getSize(); for(int j=0;j<tmpOutSize;j++) { tmpOutData[j] = outputData[i*tmpOutSize+j]; } float[] tmpInDiffData = tmpInDiff.getData(); int tmpInDiffSize = tmpInDiff.getSize(); for(int j=0;j<tmpInDiffSize;j++) { tmpInDiffData[j] = inputDiffData[i*tmpInDiffSize+j]; } mCell.backward(tmpIn,tmpOut,tmpInDiff,tmpOutDiff); //������Ľ���˳����outputDiffBlob float[] tmpOutDiffData = tmpOutDiff.getData(); int tmpOutDifftSize = tmpOutDiff.getSize(); for(int j=0;j<tmpOutDifftSize;j++) { outputDiffData[i*tmpOutDifftSize+j] = tmpOutDiffData[j]; } } } @Override public void saveModel(ObjectOutputStream out) { // TODO Auto-generated method stub try { out.writeUTF(getType()); out.writeInt(seqLen); out.writeInt(inSize); out.writeInt(hidenSize); mCell.saveModel(out); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } @Override public void loadModel(ObjectInputStream in) { // TODO Auto-generated method stub try { seqLen = in.readInt(); inSize = in.readInt(); hidenSize = in.readInt(); String type = in.readUTF(); if(type.equals("RNN")) { this.type = RecurrentType.RNN; mCell = new RnnCell(mNetwork,RecurrentLayer.this); mCell.loadModel(in); } } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } public int getSeqLen() { return seqLen; } }