Python sklearn.tree.export_graphviz() Examples

The following are 24 code examples of sklearn.tree.export_graphviz(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module sklearn.tree, or try the search function.
Example #1
Source File: 23_DecisionTree.py    From Python with MIT License 7 votes vote down vote up
def visualize_tree(data, clf, clf_name):
    """Render a decision tree with Graphviz and open the rendered file.

    :param data: DataFrame-like object; every column except the last is
        treated as a feature and the last column holds the class labels.
    :param clf: tree estimator handed to ``tree.export_graphviz``.
    :param clf_name: suffix appended to ``'dtree_render_'`` to build the
        name of the rendered output.
    """
    # Every column but the last one is a feature.
    features = data.columns[:-1]
    # export_graphviz expects the class names in ascending label order and
    # as strings; a bare set() iterates in arbitrary order and could attach
    # the wrong name to a leaf.
    class_names = [str(label) for label in sorted(set(data.iloc[:, -1]))]
    dot_data = tree.export_graphviz(
        clf, out_file=None,
        feature_names=features, class_names=class_names,
        filled=True, rounded=True, special_characters=True)
    graph = graphviz.Source(dot_data)
    graph.render('dtree_render_' + clf_name, view=True)

# Function to perform training with giniIndex. 
Example #2
Source File: predict_enriched_decision_tree.py    From PIDGINv2 with MIT License 7 votes vote down vote up
def createTree(matrix, label):
    """Cluster *matrix* with k-means, fit a tree on the clusters, save a JPG.

    The k-means cluster ids are used as the target of a decision tree; the
    tree is exported to DOT, relabelled, and written to the module-level
    ``output_name_tree`` path.

    :param matrix: feature matrix passed to both ``KMeans`` and the tree.
    :param label: feature names passed to ``export_graphviz``.
    """
    kmeans = KMeans(n_clusters=moa_clusters, random_state=0).fit(matrix)
    # Build a real list: on Python 3 ``map`` is a lazy iterator, which
    # cannot be used as a target vector or as a class-name sequence.
    vector = [int(cluster) for cluster in kmeans.labels_]
    clf = tree.DecisionTreeClassifier(min_samples_split=min_sampsplit,
                                      min_samples_leaf=min_leafsplit,
                                      max_depth=max_d)
    clf.fit(matrix, vector)
    # Class names have to follow ascending label order, so sort *after*
    # de-duplicating; set() does not keep the order of a sorted input.
    class_names = [str(cluster) for cluster in sorted(set(vector))]
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data,
                         feature_names=label,
                         class_names=class_names,
                         filled=True, rounded=True,
                         special_characters=True,
                         proportion=False,
                         impurity=True)
    out_tree = dot_data.getvalue()
    # With special_characters=True the threshold is emitted as the HTML
    # entity "&le;", so that is the text that has to be stripped.
    out_tree = (out_tree.replace('True', 'Inactive')
                .replace('False', 'Active')
                .replace(' &le; 0.5', '')
                .replace('class', 'Predicted MoA'))
    graph = pydot.graph_from_dot_data(str(out_tree))
    # Newer pydot releases return a list of graphs instead of one graph.
    if isinstance(graph, (list, tuple)):
        graph = graph[0]
    graph.write_jpg(output_name_tree)
    return

#initializer for the pool 
Example #3
Source File: DTSklearn.py    From AiLearning with GNU General Public License v3.0 6 votes vote down vote up
def show_pdf(clf):
    '''
    Visualise a decision tree by writing its structure to a PDF file.
    Reference: http://sklearn.lzjqsdd.com/modules/tree.html

    Error on Mac: pydotplus.graphviz.InvocationException: GraphViz's executables not found
    Fix: sudo brew install graphviz
    See also:  http://www.jianshu.com/p/59b510bafb4d

    :param clf: tree estimator handed to ``tree.export_graphviz``.
    '''
    import pydotplus

    # out_file=None makes export_graphviz return the DOT source as a string,
    # which removes the need for sklearn.externals.six (no longer shipped
    # with recent scikit-learn releases).
    dot_source = tree.export_graphviz(clf, out_file=None)
    graph = pydotplus.graph_from_dot_data(dot_source)
    graph.write_pdf("../../../output/3.DecisionTree/tree.pdf")

    # To display the tree inline in a notebook instead:
    #     from IPython.display import Image
    #     Image(graph.create_png())
Example #4
Source File: visualize_tree.py    From kaggle-tools with MIT License 6 votes vote down vote up
def visualize_tree(clf, feature_names, class_names, output_file,
                   method='pdf'):
    """Export a decision tree to Graphviz and render it.

    :param clf: tree estimator handed to ``tree.export_graphviz``.
    :param feature_names: names of the features the tree was trained on.
    :param class_names: names of the target classes.
    :param output_file: output path without extension; ``".pdf"`` is
        appended when ``method == 'pdf'``.
    :param method: ``'pdf'`` writes a PDF file, ``'inline'`` builds a PNG
        ``Image`` object; any other value only builds the graph.
    :return: the pydotplus graph built from the exported DOT source.
    """
    dot_data = StringIO()
    # Use the caller-supplied names; the previous version silently ignored
    # both arguments and read them from a global ``iris`` object instead.
    tree.export_graphviz(clf, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_names,
                         filled=True, rounded=True,
                         special_characters=True,
                         impurity=False)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    if method == 'pdf':
        graph.write_pdf(output_file + ".pdf")
    elif method == 'inline':
        # NOTE(review): the Image object is created but neither displayed
        # nor returned -- confirm whether an explicit display() is wanted.
        Image(graph.create_png())

    return graph

# An example using the iris dataset 
Example #5
Source File: Decision_Tree.py    From ml_code with Apache License 2.0 6 votes vote down vote up
def dt_classification():
    """Fit a decision tree on two iris features, export it and plot it.

    The fitted tree is written to ``./tree_iris.png``; the predicted class
    over a grid of the two features and the training points are then drawn
    with matplotlib.
    """
    iris = datasets.load_iris()
    # Only the first two feature columns are used for training.
    X = iris.data[:, 0:2]
    y = iris.target

    clf = tree.DecisionTreeClassifier()
    clf.fit(X, y)

    # feature_names must have one entry per training column; passing all
    # four iris names for this two-feature model raises a ValueError.
    dot_data = tree.export_graphviz(clf, out_file=None,
                                    feature_names=iris.feature_names[:2],
                                    class_names=iris.target_names,
                                    filled=True, rounded=True,
                                    special_characters=True
                                    )
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_png("./tree_iris.png")

    # plot result: predict on a dense grid padded by 1 around the data
    xmin, xmax = X[:, 0].min() - 1, X[:, 0].max() + 1
    ymin, ymax = X[:, 1].min() - 1, X[:, 1].max() + 1
    plot_step = 0.02
    xx, yy = np.meshgrid(np.arange(xmin, xmax, plot_step),
                         np.arange(ymin, ymax, plot_step))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)

    # Plot the training points, one colour per class
    n_classes = 3
    plot_colors = "bry"
    for i, color in zip(range(n_classes), plot_colors):
        idx = np.where(y == i)
        plt.scatter(X[idx, 0], X[idx, 1], c=color,
                    label=iris.target_names[i],
                    cmap=plt.cm.Paired)
Example #6
Source File: DTSklearn.py    From AiLearning with GNU General Public License v3.0 6 votes vote down vote up
def show_pdf(clf):
    '''
    Visualise a decision tree by writing its structure to a PDF file.
    Reference: http://sklearn.lzjqsdd.com/modules/tree.html

    Error on Mac: pydotplus.graphviz.InvocationException: GraphViz's executables not found
    Fix: sudo brew install graphviz
    See also:  http://www.jianshu.com/p/59b510bafb4d

    :param clf: tree estimator handed to ``tree.export_graphviz``.
    '''
    import pydotplus

    # out_file=None makes export_graphviz return the DOT source as a string,
    # which removes the need for sklearn.externals.six (no longer shipped
    # with recent scikit-learn releases).
    dot_source = tree.export_graphviz(clf, out_file=None)
    graph = pydotplus.graph_from_dot_data(dot_source)
    graph.write_pdf("output/3.DecisionTree/tree.pdf")

    # To display the tree inline in a notebook instead:
    #     from IPython.display import Image
    #     Image(graph.create_png())
Example #7
Source File: test_export.py    From Mastering-Elasticsearch-7.0 with MIT License 6 votes vote down vote up
def test_plot_tree(pyplot):
    """Smoke-test ``plot_tree``: three annotations with the expected text."""
    classifier = DecisionTreeClassifier(criterion="gini",
                                        max_depth=3,
                                        min_samples_split=2,
                                        random_state=2)
    classifier.fit(X, y)

    annotations = plot_tree(classifier,
                            feature_names=['first feat', 'sepal_width'])

    # Expected label of each drawn node, in the order plot_tree returns them.
    expected_texts = (
        "first feat <= 0.0\nentropy = 0.5\nsamples = 6\nvalue = [3, 3]",
        "entropy = 0.0\nsamples = 3\nvalue = [3, 0]",
        "entropy = 0.0\nsamples = 3\nvalue = [0, 3]",
    )
    assert len(annotations) == 3
    for annotation, expected in zip(annotations, expected_texts):
        assert annotation.get_text() == expected
Example #8
Source File: dtree.py    From Python with MIT License 5 votes vote down vote up
def visualize_tree(data, clf, clf_name):
    """Render a decision tree with Graphviz and open the rendered file.

    :param data: DataFrame-like object; every column except the last is
        treated as a feature and the last column holds the class labels.
    :param clf: tree estimator handed to ``tree.export_graphviz``.
    :param clf_name: suffix appended to ``'dtree_render_'`` to build the
        name of the rendered output.
    """
    # Every column but the last one is a feature.
    features = data.columns[:-1]
    # export_graphviz expects the class names in ascending label order and
    # as strings; a bare set() iterates in arbitrary order and could attach
    # the wrong name to a leaf.
    class_names = [str(label) for label in sorted(set(data.iloc[:, -1]))]
    dot_data = tree.export_graphviz(
        clf, out_file=None,
        feature_names=features, class_names=class_names,
        filled=True, rounded=True, special_characters=True)
    graph = graphviz.Source(dot_data)
    graph.render('dtree_render_' + clf_name, view=True)

# Function to perform training with giniIndex. 
Example #9
Source File: predict_enriched_two_libraries_decision_tree.py    From PIDGINv2 with MIT License 5 votes vote down vote up
def createTree(matrix, label):
    """Fit a tree separating the two query libraries and save it as a JPG.

    Rows belonging to ``querymatrix1`` are labelled 1 ('File1') and rows
    belonging to ``querymatrix2`` are labelled 0 ('File2'). The fitted tree
    is exported to DOT, relabelled, and written to ``output_name_tree``.

    :param matrix: feature matrix the tree is fitted on.
    :param label: feature names passed to ``export_graphviz``.
    """
    vector = [1] * len(querymatrix1) + [0] * len(querymatrix2)
    n_file1 = sum(vector)
    ratio = float(len(vector) - n_file1) / float(n_file1)
    # NOTE(review): sw and pc_10 are computed but never used below --
    # confirm whether sw was meant to be passed to fit() as sample_weight.
    sw = np.array([ratio if flag == 1 else 1 for flag in vector])
    pc_10 = int(len(querymatrix1) * 0.01)

    clf = tree.DecisionTreeClassifier(
        min_samples_split=min_sampsplit,
        min_samples_leaf=min_leafsplit,
        max_depth=max_d)
    clf.fit(matrix, vector)

    buffer = StringIO()
    tree.export_graphviz(
        clf, out_file=buffer,
        feature_names=label,
        class_names=['File2', 'File1'],
        filled=True, rounded=True,
        special_characters=True,
        proportion=False,
        impurity=True)

    # Relabel the boolean branches and drop the constant "<= 0.5" threshold.
    dot_text = buffer.getvalue()
    dot_text = dot_text.replace('True', 'Inactive')
    dot_text = dot_text.replace('False', 'Active')
    dot_text = dot_text.replace(' &le; 0.5', '')

    graph = pydot.graph_from_dot_data(str(dot_text))
    try:
        graph.write_jpg(output_name_tree)
    except AttributeError:
        # Some pydot versions return a list of graphs; use the first one.
        graph = pydot.graph_from_dot_data(str(dot_text))[0]
        graph.write_jpg(output_name_tree)
    return

#initializer for the pool 
Example #10
Source File: test_export.py    From twitter-stock-recommendation with MIT License 5 votes vote down vote up
def test_friedman_mse_in_graphviz():
    """Check that every exported node label names the friedman_mse criterion.

    A ``friedman_mse`` regressor and the first tree of each boosting stage
    of a ``GradientBoostingClassifier`` are exported into one buffer; every
    bracketed label containing "samples" must mention "friedman_mse".
    """
    clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
    clf.fit(X, y)
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data)

    clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
    clf.fit(X, y)
    for estimator in clf.estimators_:
        export_graphviz(estimator[0], out_file=dot_data)

    # Raw string: "\[" is not a valid escape in a normal string literal and
    # triggers an invalid-escape warning on current Python versions.
    for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
        assert_in("friedman_mse", finding.group())
Example #11
Source File: test_export.py    From twitter-stock-recommendation with MIT License 5 votes vote down vote up
def test_graphviz_errors():
    """Exercise the argument validation of ``export_graphviz``."""
    tree_clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

    # An estimator that has not been fitted must be rejected.
    assert_raises(NotFittedError, export_graphviz, tree_clf, StringIO())

    tree_clf.fit(X, y)

    # feature_names with the wrong length (too short, then too long)
    # must be rejected with a message naming both lengths.
    for bad_names in (["a"], ["a", "b", "c"]):
        expected = ("Length of feature_names, "
                    "%d does not match number of features, 2"
                    % len(bad_names))
        assert_raise_message(ValueError, expected, export_graphviz,
                             tree_clf, None, feature_names=bad_names)

    # An empty class_names list cannot be indexed.
    assert_raises(IndexError, export_graphviz, tree_clf, StringIO(),
                  class_names=[])

    # precision has to be a non-negative integer.
    sink = StringIO()
    assert_raises_regex(ValueError, "should be greater or equal",
                        export_graphviz, tree_clf, sink, precision=-1)
    assert_raises_regex(ValueError, "should be an integer",
                        export_graphviz, tree_clf, sink, precision="1")
Example #12
Source File: test_export.py    From Mastering-Elasticsearch-7.0 with MIT License 5 votes vote down vote up
def test_graphviz_errors():
    """Exercise the argument validation of ``export_graphviz``."""
    tree_clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

    # An estimator that has not been fitted must be rejected.
    assert_raises(NotFittedError, export_graphviz, tree_clf, StringIO())

    tree_clf.fit(X, y)

    # feature_names with the wrong length (too short, then too long)
    # must be rejected with a message naming both lengths.
    for bad_names in (["a"], ["a", "b", "c"]):
        expected = ("Length of feature_names, "
                    "%d does not match number of features, 2"
                    % len(bad_names))
        assert_raise_message(ValueError, expected, export_graphviz,
                             tree_clf, None, feature_names=bad_names)

    # The low-level tree_ object is not an estimator and must be rejected.
    assert_raise_message(TypeError, "is not an estimator instance",
                         export_graphviz, tree_clf.fit(X, y).tree_)

    # An empty class_names list cannot be indexed.
    assert_raises(IndexError, export_graphviz, tree_clf, StringIO(),
                  class_names=[])

    # precision has to be a non-negative integer.
    sink = StringIO()
    assert_raises_regex(ValueError, "should be greater or equal",
                        export_graphviz, tree_clf, sink, precision=-1)
    assert_raises_regex(ValueError, "should be an integer",
                        export_graphviz, tree_clf, sink, precision="1")
Example #13
Source File: learn.py    From uta with Apache License 2.0 5 votes vote down vote up
def write_pdf(clf, fn):
    """Export a decision tree to Graphviz and write it to a PDF file.

    :param clf: tree estimator handed to ``tree.export_graphviz``.
    :param fn: path of the PDF file to write.
    """
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    # Current pydot releases return a list of graphs; older ones returned a
    # single graph. Accept both so write_pdf is always called on a graph.
    if isinstance(graph, (list, tuple)):
        graph = graph[0]
    graph.write_pdf(fn)
Example #14
Source File: models_classification.py    From easyML with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def export_model(self, IDcol):
        """Export the model and create a submission with the model index.

        Delegates to ``export_model_base`` with the ``'decision_tree'``
        model name; the exported files are meant to be used later for
        building an ensemble.

        :param IDcol: forwarded unchanged to ``export_model_base``.
        """
        self.export_model_base(IDcol, 'decision_tree')

    ## UNDER DEVELOPMENT CODE FOR PRINTING TREES
    # def get_tree(self):
    #     return self.alg.tree_
    # Print the tree in visual format
    # Inputs:
    #     export_pdf - if True, a pdf will be exported with the 
    #     filename as specified in pdf_name argument
    #     pdf_name - name of the pdf file if export_pdf is True
    # def printTree(self, export_pdf=True, file_name="Decision_Tree.pdf"):
    #     dot_data = StringIO() 
    #     export_graphviz(
    #             self.alg, out_file=dot_data, feature_names=self.predictors,
    #             filled=True, rounded=True, special_characters=True)

    #     export_graphviz(
    #         self.alg, out_file='data.dot', feature_names=self.predictors,  
    #         filled=True, rounded=True, special_characters=True
    #         ) 
    #     graph = pydot.graph_from_dot_data(dot_data.getvalue())
        
    #     if export_pdf:
    #         graph.write_pdf(file_name)

    #     return graph

#####################################################################
##### RANDOM FOREST
##################################################################### 
Example #15
Source File: c10.py    From abu with GNU General Public License v3.0 5 votes vote down vote up
def sample_1033_1():
    """
    10.3.3 Classify with a decision tree and draw the decision diagram.
    The Graphviz ``dot`` tool must be installed so that
    os.system("dot -T png graphviz.dot -o graphviz.png") can produce the png.
    :return:
    """
    from sklearn.tree import DecisionTreeClassifier
    from sklearn import tree
    import os

    estimator = DecisionTreeClassifier(max_depth=2, random_state=1)

    # noinspection PyShadowingNames
    def graphviz_tree(estimator, features, x, y):
        """Fit *estimator* on (x, y) and render it to graphviz.png."""
        estimator.fit(x, y)
        # Check for tree_ only after fitting: current scikit-learn creates
        # the attribute in fit(), so testing it beforehand always failed.
        if not hasattr(estimator, 'tree_'):
            print('only tree can graphviz!')
            return

        # Export the decision model to graphviz.dot. The estimator itself is
        # passed; export_graphviz rejects the low-level tree_ object.
        tree.export_graphviz(estimator, out_file='graphviz.dot',
                             feature_names=features)
        # Let dot draw the decision diagram and save it as a png
        os.system("dot -T png graphviz.dot -o graphviz.png")

    global g_with_date_week_noise
    g_with_date_week_noise = True
    train_x, train_y_regress, train_y_classification, pig_three_feature, \
    test_x, test_y_regress, test_y_classification, kl_another_word_feature_test = sample_1031_1()

    # The feature names come from the columns pig_three_feature.columns[1:]
    graphviz_tree(estimator, pig_three_feature.columns[1:], train_x,
                  train_y_classification)

    import PIL.Image
    PIL.Image.open('graphviz.png').show()
Example #16
Source File: test_tree.py    From pandas-ml with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def test_objectmapper(self):
        """The ``tree`` accessor of a ModelFrame exposes sklearn.tree objects."""
        frame = pdml.ModelFrame([])
        exposed_names = (
            'DecisionTreeClassifier',
            'DecisionTreeRegressor',
            'ExtraTreeClassifier',
            'ExtraTreeRegressor',
            'export_graphviz',
        )
        # Each accessor attribute must be the very same object as in sklearn.
        for name in exposed_names:
            self.assertIs(getattr(frame.tree, name), getattr(tree, name))
Example #17
Source File: test_export.py    From Mastering-Elasticsearch-7.0 with MIT License 5 votes vote down vote up
def test_friedman_mse_in_graphviz():
    """Every exported node label must name the friedman_mse criterion."""
    buffer = StringIO()

    # A regressor explicitly configured with friedman_mse ...
    regressor = DecisionTreeRegressor(criterion="friedman_mse",
                                      random_state=0)
    regressor.fit(X, y)
    export_graphviz(regressor, out_file=buffer)

    # ... and the first tree of every gradient-boosting stage.
    booster = GradientBoostingClassifier(n_estimators=2, random_state=0)
    booster.fit(X, y)
    for stage in booster.estimators_:
        export_graphviz(stage[0], out_file=buffer)

    node_label = r"\[.*?samples.*?\]"
    for match in finditer(node_label, buffer.getvalue()):
        assert_in("friedman_mse", match.group())
Example #18
Source File: train.py    From tree-regularization-public with MIT License 5 votes vote down vote up
def visualize(tree, save_path):
    """Build a minimal Graphviz rendering of a decision tree.

    @param tree: DecisionTreeClassifier instance
    @param save_path: string or None
                      where to save the tree PDF; nothing is written
                      when it is None
    """
    dot_source = export_graphviz(tree, out_file=None,
                                 filled=True, rounded=True)
    # make_graph_minimal removes extra text from the parsed graph
    graph = make_graph_minimal(pydotplus.graph_from_dot_data(dot_source))

    if save_path is not None:
        graph.write_pdf(save_path)
Example #19
Source File: pigaios_ml.py    From pigaios with GNU General Public License v3.0 5 votes vote down vote up
def graphviz(self):
    """Export the classifier to ``pigaios.dot`` and display it with dot.

    When no classifier is set on the instance, it is first loaded from
    ``clf.pkl`` via joblib.
    """
    if self.clf is None:
        log("Loading model...")
        self.clf = joblib.load("clf.pkl")

    export_options = dict(filled=True, rounded=True, special_characters=True)
    # With an out_file path given, the DOT source goes to that file.
    tree.export_graphviz(self.clf, out_file="pigaios.dot", **export_options)
    os.system("dot -Tx11 pigaios.dot")

#------------------------------------------------------------------------------- 
Example #20
Source File: ABuMLExecute.py    From abu with GNU General Public License v3.0 4 votes vote down vote up
def graphviz_tree(estimator, features, x, y):
    """
    Draw the decision diagram of a decision tree or another tree-based
    classification/regression estimator.

    The estimator is cloned, the clone is fitted, and the fitted clone is
    checked for a tree_ attribute. It is then exported with sklearn's
    tree.export_graphviz to graphviz.dot, which the third-party dot tool
    converts to graphviz.png by running the command line:
                os.system("dot -T png graphviz.dot -o graphviz.png")
    Finally the decision diagram is read back and displayed.

    :param estimator: learner object; it is cloned, never fitted in place
    :param x: training set x matrix, numpy matrix
    :param y: training set y sequence, numpy sequence
    :param features: names of the column features of the x matrix,
                     iterable sequence
    """
    # every operation that calls fit works on a freshly cloned estimator
    estimator = clone(estimator)
    estimator.fit(x, y)
    # Check for tree_ only after fitting: current scikit-learn creates the
    # attribute in fit(), so testing it on an unfitted estimator always failed.
    if not hasattr(estimator, 'tree_'):
        logging.info('only tree can graphviz!')
        return

    # TODO move the out_file path into the cache directory
    # The estimator itself is passed; export_graphviz rejects the low-level
    # tree_ object.
    tree.export_graphviz(estimator, out_file='graphviz.dot', feature_names=features)
    os.system("dot -T png graphviz.dot -o graphviz.png")

    # In a notebook, opening the file directly (!open $path) looks better,
    # because the size of the plt.show output is hard to adjust.
    png_path = os.path.join(os.path.abspath('.'), 'graphviz.png')

    if not file_exist(png_path):
        logging.info('{} not exist! please install dot util!'.format(png_path))
        return

    image_file = cbook.get_sample_data(png_path)
    image = plt.imread(image_file)
    image_file.close()
    plt.imshow(image)
    plt.axis('off')  # clear x- and y-axes
    plt.show()
Example #21
Source File: test_export.py    From Mastering-Elasticsearch-7.0 with MIT License 4 votes vote down vote up
def test_precision():
    """Check the number of decimals ``export_graphviz`` prints.

    With the current random states, impurity and threshold are printed
    with exactly ``precision`` decimals, so they are checked with strict
    equality. The reported value may have fewer decimals, so only a
    less-or-equal comparison is made for it.
    """
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)
    cases = (
        (rng_reg.random_sample((5, 2)),
         rng_reg.random_sample((5, )),
         DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
                               max_depth=1)),
        (rng_clf.random_sample((1000, 4)),
         rng_clf.randint(2, size=(1000, )),
         DecisionTreeClassifier(max_depth=1, random_state=0)),
    )
    decimals = r"\.\d+"

    for features, targets, model in cases:
        model.fit(features, targets)
        # The impurity label depends on the kind of estimator.
        if is_classifier(model):
            impurity_pattern = r"gini = \d+\.\d+"
        else:
            impurity_pattern = r"friedman_mse = \d+\.\d+"

        for precision in (4, 3):
            dot_data = export_graphviz(model, out_file=None,
                                       precision=precision, proportion=True)
            # "+ 1" accounts for the decimal point matched by the pattern.
            max_len = precision + 1

            # check value: at most `precision` decimals
            for finding in finditer(r"value = \d+\.\d+", dot_data):
                assert_less_equal(
                    len(search(decimals, finding.group()).group()),
                    max_len)

            # check impurity, then threshold: exactly `precision` decimals
            for pattern in (impurity_pattern, r"<= \d+\.\d+"):
                for finding in finditer(pattern, dot_data):
                    assert_equal(
                        len(search(decimals, finding.group()).group()),
                        max_len)
Example #22
Source File: sklearn_tree.py    From android-malware-analysis with GNU General Public License v3.0 4 votes vote down vote up
def train_tree_classifer(features, labels, model_output_path):
    """
    train_tree_classifer will train a DecisionTree, save the best model and
    write the best tree out to a pdf file ('best_tree.pdf')

    features: 2D array of each input feature for each sample
    labels: array of string labels classifying each sample
    model_output_path: path for storing the trained tree model
    """
    # save 20% of data for performance evaluation
    X_train, X_test, y_train, y_test = cross_validation.train_test_split(features, labels, test_size=0.2)

    param = [
        {
            "max_depth": [None, 10, 100, 1000, 10000]
        }
    ]

    dtree = tree.DecisionTreeClassifier(random_state=0)

    # 10-fold cross validation; folds and parameter sets are independent,
    # so they are trained in parallel (up to n_jobs=20 workers)
    clf = grid_search.GridSearchCV(dtree, param,
            cv=10, n_jobs=20, verbose=3)

    clf.fit(X_train, y_train)

    # The model file itself does not have to exist yet; only the directory
    # it is written into does. Testing the file path would skip every
    # first-time save.
    output_dir = os.path.dirname(os.path.abspath(model_output_path))
    if os.path.isdir(output_dir):
        joblib.dump(clf.best_estimator_, model_output_path)
    else:
        print("Cannot save trained tree model to {0}.".format(model_output_path))

    dot_data = tree.export_graphviz(clf.best_estimator_, out_file=None)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf('best_tree.pdf')

    print("\nBest parameters set:")
    print(clf.best_params_)

    y_predict = clf.predict(X_test)

    # unique labels in sorted order fix the row/column order of the matrix
    labels = sorted(set(labels))
    print("\nConfusion matrix:")
    print("Labels: {0}\n".format(",".join(labels)))
    print(confusion_matrix(y_test, y_predict, labels=labels))

    print("\nClassification report:")
    print(classification_report(y_test, y_predict))
Example #23
Source File: SKLProcessors.py    From Ossian with Apache License 2.0 4 votes vote down vote up
def do_training(self, speech_corpus, text_corpus):
        """Train a decision tree on features dumped from *speech_corpus*.

        Does nothing when ``self.model`` is already set. Otherwise the
        features of every utterance are collected, vectorised with
        ``DictVectorizer``, used to fit a ``DecisionTreeClassifier``, and
        the vectorisers plus the model are pickled to ``self.model_file``.
        A Graphviz export is written next to it as ``<model_file>.dot``.

        :param speech_corpus: iterable of utterances providing
            ``dump_features``.
        :param text_corpus: not used by this method.
        """
        if self.model:  ## if already trained...
            return

        ## 1) get data:
        #### [Added dump_features method to Utterance class, use that: ]
        x_data = []
        y_data = []
        for utterance in speech_corpus:

            utt_feats = utterance.dump_features(self.target_nodes, \
                                                self.context_list, return_dict=True)

            for example in utt_feats:
                assert 'response' in example,example
                # Split each example into its response (target) and the
                # remaining predictor features.
                y_data.append({'response': example.pop('response')})
                x_data.append(example)

        ## Handle categorical features (strings) but to keep numerical ones
        ## as they are:

        x_vectoriser = DictVectorizer()
        x_data = x_vectoriser.fit_transform(x_data).toarray()

        y_vectoriser = DictVectorizer()
        y_data = y_vectoriser.fit_transform(y_data).toarray()

        ## 2) train classifier:
        model = tree.DecisionTreeClassifier(min_samples_leaf=self.min_samples_leaf)

        model.fit(x_data, y_data)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('\n Trained classifier: ')
        print(model)
        print('\n Trained x vectoriser:')
        print(x_vectoriser)
        print('Feature names:')
        print(x_vectoriser.get_feature_names())
        print('\n Trained y vectoriser:')
        print(y_vectoriser)
        print('Feature names:')
        print(y_vectoriser.get_feature_names())

        ## 3) Save classifier by pickling; the with-block closes the file
        ## even if pickling fails:
        with open(self.model_file, 'wb') as output:
            pickle.dump([x_vectoriser, y_vectoriser, model], output)

        ## Write ASCII tree representation (which can be plotted):
        tree.export_graphviz(model, out_file=self.model_file + '.dot',  \
                                     feature_names=x_vectoriser.get_feature_names())

        self.verify(self.voice_resources) # ## reload -- get self.model etc
Example #24
Source File: test_export.py    From twitter-stock-recommendation with MIT License 4 votes vote down vote up
def test_precision():
    """Check the number of decimals ``export_graphviz`` prints.

    A regressor and a classifier are fitted on random data and exported
    with ``precision`` 4 and 3; the decimals of value, impurity and
    threshold in the DOT output are then checked.
    """
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)
    for X, y, clf in zip(
            (rng_reg.random_sample((5, 2)),
             rng_clf.random_sample((1000, 4))),
            (rng_reg.random_sample((5, )),
             rng_clf.randint(2, size=(1000, ))),
            (DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
                                   max_depth=1),
             DecisionTreeClassifier(max_depth=1, random_state=0))):

        clf.fit(X, y)
        for precision in (4, 3):
            dot_data = export_graphviz(clf, out_file=None, precision=precision,
                                       proportion=True)

            # With the current random state, the impurity and the threshold
            # will have the number of precision set in the export_graphviz
            # function. We will check the number of precision with a strict
            # equality. The value reported will have only 2 precision and
            # therefore, only a less equal comparison will be done.

            # All patterns are raw strings: "\d" and "\." are not valid
            # escapes in normal string literals and trigger warnings on
            # current Python versions.

            # check value
            for finding in finditer(r"value = \d+\.\d+", dot_data):
                assert_less_equal(
                    len(search(r"\.\d+", finding.group()).group()),
                    precision + 1)
            # pick the impurity label matching the estimator type
            if is_classifier(clf):
                pattern = r"gini = \d+\.\d+"
            else:
                pattern = r"friedman_mse = \d+\.\d+"

            # check impurity
            for finding in finditer(pattern, dot_data):
                assert_equal(len(search(r"\.\d+", finding.group()).group()),
                             precision + 1)
            # check threshold
            for finding in finditer(r"<= \d+\.\d+", dot_data):
                assert_equal(len(search(r"\.\d+", finding.group()).group()),
                             precision + 1)