Java Code Examples for org.nd4j.autodiff.samediff.SameDiff#variables()

The following examples show how to use org.nd4j.autodiff.samediff.SameDiff#variables(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You can also check out related API usage in the sidebar.
Example 1
Source File: IntegrationTestRunner.java — from deeplearning4j (Apache License 2.0), 6 votes
/**
 * Asserts that two SameDiff graphs match.
 * <p>
 * The following must be equal:
 * <ul>
 *   <li>variable names</li>
 *   <li>op names</li>
 *   <li>inputs</li>
 *   <li>the type of every variable</li>
 *   <li>the arrays of every constant or {@code VariableType.VARIABLE} variable</li>
 *   <li>the op class behind every op name</li>
 * </ul>
 *
 * @param sd1 reference graph
 * @param sd2 graph expected to match {@code sd1}
 */
public static void assertSameDiffEquals(SameDiff sd1, SameDiff sd2){
    // Structural identity first: same variable names, op names and inputs
    assertEquals(sd1.variableMap().keySet(), sd2.variableMap().keySet());
    assertEquals(sd1.getOps().keySet(), sd2.getOps().keySet());
    assertEquals(sd1.inputs(), sd2.inputs());

    // Per-variable checks: type always, array content only for constants and trainable variables
    for (SDVariable expected : sd1.variables()) {
        String varName = expected.name();
        SDVariable actual = sd2.getVariable(varName);
        assertEquals(varName, expected.getVariableType(), actual.getVariableType());

        boolean holdsArray = expected.isConstant() || expected.getVariableType() == VariableType.VARIABLE;
        if (!holdsArray) {
            continue;
        }
        assertEquals(varName, expected.getArr(), actual.getArr());
    }

    // Op names were already asserted equal above, so the lookup in sd2 always finds a match
    for (SameDiffOp op : sd1.getOps().values()) {
        SameDiffOp counterpart = sd2.getOps().get(op.getName());
        assertEquals(op.getOp().getClass(), counterpart.getOp().getClass());
    }
}
 
Example 2
Source File: IntegrationTestRunner.java — from deeplearning4j (Apache License 2.0), 5 votes
/**
 * Snapshots the constants of {@code sd}.
 * <p>
 * Each constant's name maps to a duplicate ({@code dup()}) of its array rather than the array
 * reference itself. A later in-place modification of a constant therefore does not alter the
 * snapshot, so comparing two snapshots can reveal such modifications.
 *
 * @param sd graph whose constants are captured
 * @return new mutable map from constant name to a copy of its array (null arrays are kept as null)
 */
private static Map<String,INDArray> getConstantCopies(SameDiff sd){
    Map<String,INDArray> out = new HashMap<>();
    for(SDVariable v : sd.variables()){
        if(v.isConstant()){
            INDArray arr = v.getArr();
            // Store a copy: storing the live reference made the "copy" alias the graph's own array
            out.put(v.name(), arr == null ? null : arr.dup());
        }
    }
    return out;
}
 
Example 3
Source File: UIListener.java — from deeplearning4j (Apache License 2.0), 4 votes
/**
 * Checks that the graph structure recorded in the existing log file matches {@code sd}.
 * <p>
 * Two things are compared:
 * <ul>
 *   <li>input (placeholder) names recorded in the file's graph-structure record versus
 *       {@code sd.inputs()}</li>
 *   <li>variable names in that record versus the names of {@code sd.variables()}</li>
 * </ul>
 * If the file has no static info, or no graph-structure record, nothing is compared. When no
 * exception is thrown, the {@code checkStructureForRestore} flag is set to false.
 *
 * @param sd the current model
 * @throws RuntimeException      if reading the static info from the log file fails
 * @throws IllegalStateException if the placeholders or variable names differ from the file
 */
protected void checkStructureForRestore(SameDiff sd){
    LogFileWriter.StaticInfo si;
    try {
        si = writer.readStatic();
    } catch (IOException e){
        throw new RuntimeException("Error restoring existing log file, static info at path: " + logFile.getAbsolutePath(), e);
    }

    // Null-check si BEFORE dereferencing it (getData() used to be called ahead of this check)
    if(si != null) {
        List<Pair<UIStaticInfoRecord, Table>> staticList = si.getData();
        UIGraphStructure structure = null;
        for (Pair<UIStaticInfoRecord, Table> p : staticList) {
            if (p.getFirst().infoType() == UIInfoType.GRAPH_STRUCTURE){
                structure = (UIGraphStructure) p.getSecond();
                break;
            }
        }

        if(structure != null){
            //Check placeholders:
            int nInFile = structure.inputsLength();
            List<String> phs = new ArrayList<>(nInFile);
            for( int i=0; i<nInFile; i++ ){
                phs.add(structure.inputs(i));
            }

            List<String> actPhs = sd.inputs();
            if(actPhs.size() != phs.size() || !actPhs.containsAll(phs)){
                throw new IllegalStateException("Error continuing collection of UI stats in existing model file " + logFile.getAbsolutePath() +
                        ": Model structure differs. Existing (file) model placeholders: " + phs + " vs. current model placeholders: " + actPhs +
                        ". To disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");
            }

            //Check variables:
            int nVarsFile = structure.variablesLength();
            List<String> vars = new ArrayList<>(nVarsFile);
            for( int i=0; i<nVarsFile; i++ ){
                vars.add(structure.variables(i).name());
            }
            List<SDVariable> sdVars = sd.variables();
            List<String> varNames = new ArrayList<>(sdVars.size());
            for(SDVariable v : sdVars){
                varNames.add(v.name());
            }

            // Hash sets give constant-time membership tests instead of a linear List scan per name
            java.util.Set<String> fileVarSet = new java.util.HashSet<>(vars);
            java.util.Set<String> newVarSet = new java.util.HashSet<>(varNames);

            if(varNames.size() != vars.size() || !newVarSet.containsAll(vars)){
                // Report differences in both directions. Previously only new-model-only names were
                // counted, so a model missing file variables reported "0 variable names differ".
                List<String> newOnly = new ArrayList<>();
                int newOnlyCount = collectMissingNames(varNames, fileVarSet, newOnly, 10);
                List<String> fileOnly = new ArrayList<>();
                int fileOnlyCount = collectMissingNames(vars, newVarSet, fileOnly, 10);

                StringBuilder msg = new StringBuilder();
                msg.append("Error continuing collection of UI stats in existing model file ")
                        .append(logFile.getAbsolutePath())
                        .append(": Current model structure differs vs. model structure in file - ").append(newOnlyCount + fileOnlyCount).append(" variable names differ.");
                if(newOnlyCount > 0){
                    msg.append(newOnly.size() == newOnlyCount
                            ? "\nVariables in new model not present in existing (file) model: "
                            : "\nFirst 10 variables in new model not present in existing (file) model: ").append(newOnly);
                }
                if(fileOnlyCount > 0){
                    msg.append(fileOnly.size() == fileOnlyCount
                            ? "\nVariables in existing (file) model not present in new model: "
                            : "\nFirst 10 variables in existing (file) model not present in new model: ").append(fileOnly);
                }
                msg.append("\nTo disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");

                throw new IllegalStateException(msg.toString());
            }
        }
    }

    checkStructureForRestore = false;
}

/**
 * Counts the entries of {@code names} that are absent from {@code reference}.
 * Up to {@code maxListed} of those entries are also added to {@code out}, in iteration order.
 *
 * @param names     names to test
 * @param reference set of names considered present
 * @param out       receives at most {@code maxListed} missing names
 * @param maxListed maximum number of missing names to record in {@code out}
 * @return total number of entries in {@code names} not contained in {@code reference}
 */
private static int collectMissingNames(List<String> names, java.util.Set<String> reference, List<String> out, int maxListed){
    int count = 0;
    for(String s : names){
        if(!reference.contains(s)){
            count++;
            if(out.size() < maxListed){
                out.add(s);
            }
        }
    }
    return count;
}