package info.semanticanalyzer.classifiers.weka; import info.semanticanalyzer.classifiers.weka.threeway.ThreeWayMNBTrainer; import junit.framework.Assert; import org.apache.commons.io.IOUtils; import org.junit.Test; import java.io.FileInputStream; /** * Created by Dmitry Kan on 27.04.2014. */ public class ThreeWayMNBTrainerTest { ThreeWayMNBTrainer threeWayMnbTrainer; String modelFile = "models/three-way-sentiment-mnb.model"; private static final String PERFOMRANCE_TEST_CONTENT_FILE = "src/test/resources/en_imdb_sentences.txt"; @org.junit.Before public void setUp() throws Exception { threeWayMnbTrainer = new ThreeWayMNBTrainer(modelFile); } @org.junit.Test public void testAddTrainingInstance() throws Exception { threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.NEGATIVE, new String[] {"dislike"}); threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.POSITIVE, new String[] {"like"}); threeWayMnbTrainer.showInstances(); } @org.junit.Test public void testTrainModel() throws Exception { threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.NEGATIVE, new String[] {"dislike"}); threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.POSITIVE, new String[] {"like"}); threeWayMnbTrainer.trainModel(); threeWayMnbTrainer.testModel(); } @org.junit.Test public void testSaveModel() throws Exception { threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.NEGATIVE, new String[] {"dislike"}); threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.POSITIVE, new String[] {"like"}); threeWayMnbTrainer.trainModel(); threeWayMnbTrainer.testModel(); threeWayMnbTrainer.saveModel(); System.out.println("===== Loading and testing model ===="); threeWayMnbTrainer.loadModel(modelFile); threeWayMnbTrainer.testModel(); } @org.junit.Test public void testExistingModel() throws Exception { threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.NEGATIVE, new String[] {"dislike"}); threeWayMnbTrainer.addTrainingInstance(SentimentClass.ThreeWayClazz.POSITIVE, new String[] {"like"}); threeWayMnbTrainer.loadModel(modelFile); threeWayMnbTrainer.testModel(); } @org.junit.Test public void testArbitraryTextPositive() throws Exception { threeWayMnbTrainer.loadModel(modelFile); Assert.assertEquals(SentimentClass.ThreeWayClazz.POSITIVE, threeWayMnbTrainer.classify("I like this weather")); } @org.junit.Test public void testArbitraryTextNegative() throws Exception { threeWayMnbTrainer.loadModel(modelFile); Assert.assertEquals(SentimentClass.ThreeWayClazz.NEGATIVE, threeWayMnbTrainer.classify("I dislike this weather")); } @org.junit.Test public void testArbitraryTextMixed() throws Exception { threeWayMnbTrainer.loadModel(modelFile); Assert.assertEquals(SentimentClass.ThreeWayClazz.NEUTRAL, threeWayMnbTrainer.classify("I really don't know whether I like or dislike this weather")); } @Test public void testPerformance() throws Exception { String content = IOUtils.toString(new FileInputStream(PERFOMRANCE_TEST_CONTENT_FILE), "UTF-8"); String[] lines = content.split("\n"); int wordsCount = getWordsCount(lines); threeWayMnbTrainer.loadModel(modelFile); test(lines, wordsCount, content.length()); // warm up test(lines, wordsCount, content.length()); // test } private int getWordsCount(String[] texts) { int count = 0; for (String str : texts) { count += str.split("\\s+").length; } return count; } private void test(String[] texts, int wordsCount, int totalLength) throws Exception { System.out.println("Testing on " + texts.length + " samples, " + wordsCount + " words, " + totalLength + " characters..."); long startTime = System.currentTimeMillis(); for (String str : texts) { // to print out the predicted labels, uncomment the line: //System.out.println(threeWayMnbTrainer.classify(str).name()); threeWayMnbTrainer.classify(str).name(); } long elapsedTime = System.currentTimeMillis() - startTime; System.out.println("Time " + elapsedTime + " ms."); System.out.println("Speed " + ((double) totalLength / elapsedTime) + " chars/ms"); System.out.println("Speed " + ((double) wordsCount / elapsedTime) + " words/ms"); System.out.println("+++++++++="); } }