Python matplotlib.pyplot.plot_date() Examples

The following are 30 code examples of matplotlib.pyplot.plot_date(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module matplotlib.pyplot, or try the search function. Note that plot_date() has been discouraged since Matplotlib 3.5 and is deprecated as of Matplotlib 3.9; new code should pass datetime values directly to plot() instead.
Example #1
Source File: visualizer.py (from Load-Forecasting, MIT License; 6 votes)
def comparisonPlot(year, month, day, seriesList, nameList,
                   plotName="Comparison of Values over Time", yAxisName="Predicted"):
    """Plot several daily series on a shared date axis and show the figure.

    Args:
        year, month, day: Date of the first sample; each following sample is
            one day later.
        seriesList: Sequence of y-value sequences. The date axis length is taken
            from the first series, so the series are presumably equal-length
            (plot_date fails on a mismatch).
        nameList: Legend labels, matched positionally to ``seriesList``.
        plotName: Figure title.
        yAxisName: Y-axis label.
    """
    start = datetime.date(year, month, day)
    dateList = [start + datetime.timedelta(days=offset)
                for offset in range(len(seriesList[0]))]

    # White ("w") is intentionally not in the palette: on the default white
    # axes background a white line is invisible.
    colors = ["b", "g", "r", "c", "m", "y", "k"]
    legendVars = []
    for index, series in enumerate(seriesList):
        # Cycle through the palette when there are more series than colours.
        line, = plt.plot_date(x=dateList, y=series,
                              color=colors[index % len(colors)],
                              linestyle="-", marker=".")
        legendVars.append(line)

    plt.legend(legendVars, nameList)
    plt.title(plotName)
    plt.ylabel(yAxisName)
    plt.xlabel("Date")
    plt.show()
Example #2
Source File: QCreport.py (from geoist, MIT License; 6 votes)
def graph_event_types(catalog, prefix):
    """Graph number of cumulative events by type of event.

    One line is drawn per distinct value of ``catalog['type']``. Each line is
    the running count of that type, plotted against ``catalog['convtime']`` on
    a log y-scale. The figure is saved as ``<prefix>_cumuleventtypes.png`` at
    300 dpi and then closed.

    Args:
        catalog: table-like object (presumably a pandas DataFrame) with
            'type' and 'convtime' columns.
        prefix: filename prefix for the saved image.
    """
    times = catalog['convtime']
    running_totals = {
        event_kind: (catalog['type'] == event_kind).cumsum()
        for event_kind in catalog['type'].unique()
    }

    plt.figure(figsize=(12, 6))
    for event_kind, totals in running_totals.items():
        plt.plot_date(times, totals, marker=None, linestyle='-', label=event_kind)

    plt.yscale('log')
    plt.legend()
    plt.xlim(min(times), max(times))

    plt.xlabel('Date', fontsize=14)
    plt.ylabel('Cumulative number of events', fontsize=14)
    plt.title('Cumulative Event Type', fontsize=20)

    plt.savefig('%s_cumuleventtypes.png' % prefix, dpi=300)
    plt.close()
Example #3
Source File: DailyDifferenceAverageSpark.py (from incubator-sdap-nexus, Apache License 2.0; 6 votes)
def toImage(self):
    """Render the daily mean differences as a PNG image and return its bytes.

    Each entry of ``self.results()`` is indexed as ``entry[0]['time']``
    (converted with ``datetime.fromtimestamp(..., pytz.utc)``, so presumably
    epoch seconds) and ``entry[0]['mean']``.

    Returns:
        The PNG file contents as bytes.
    """
    # BytesIO works on both Python 2 and 3; savefig writes binary PNG data,
    # which the Python-2-only StringIO module cannot provide on Python 3.
    from io import BytesIO
    import matplotlib.pyplot as plt
    from matplotlib.dates import date2num

    # Evaluate results() once instead of once per extracted field.
    rows = [dayavglistdict[0] for dayavglistdict in self.results()]
    times = [date2num(datetime.fromtimestamp(row['time'], pytz.utc).date()) for row in rows]
    means = [row['mean'] for row in rows]

    # Draw on a dedicated figure and always close it. Plotting on pyplot's
    # implicit current figure would accumulate lines across calls and leak
    # figures.
    fig = plt.figure()
    try:
        plt.plot_date(times, means, '|g-')

        plt.xlabel('Date')
        plt.xticks(rotation=70)
        plt.ylabel(u'Difference from 5-Day mean (\u00B0C)')
        plt.title('Sea Surface Temperature (SST) Anomalies')
        plt.grid(True)
        plt.tight_layout()

        buffer = BytesIO()
        fig.savefig(buffer, format='png')
    finally:
        plt.close(fig)
    return buffer.getvalue()
Example #4
Source File: QCreport.py (from geoist, MIT License; 6 votes)
def graph_mag_time(catalog, prefix):
    """Plot magnitudes vs. origin time.

    Rows whose 'mag' is null (per ``pd.notnull``) are dropped. The remaining
    magnitudes are plotted against 'convtime'. The figure is saved as
    ``<prefix>_magvtime.png`` at 300 dpi and then closed.
    """
    has_magnitude = pd.notnull(catalog['mag'])
    catalog = catalog[has_magnitude]

    origin_times, magnitudes = catalog['convtime'].copy(), catalog['mag'].copy()

    plt.figure(figsize=(10, 6))
    plt.xlabel('Date', fontsize=14)
    plt.ylabel('Magnitude', fontsize=14)
    plt.plot_date(origin_times, magnitudes, alpha=0.7, markersize=2, c='b')
    plt.xlim(min(origin_times), max(origin_times))
    plt.title('Magnitude vs. Time', fontsize=20)

    plt.savefig('%s_magvtime.png' % prefix, dpi=300)
    plt.close()
Example #5
Source File: test_axes.py (from twitter-stock-recommendation, MIT License; 5 votes)
def test_date_timezone_x_and_y():
    # Regression test for issue 5575: tz-aware datetimes on both axes, drawn
    # once with the data's own timezone (top) and once with another (bottom).
    utc = pytz.timezone('UTC')
    stamps = [utc.localize(datetime.datetime(year=2016, month=2, day=22, hour=h))
              for h in range(3)]

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('UTC', 'US/Eastern'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date(stamps, stamps, tz=zone, ydate=True)
Example #6
Source File: ch_591_water_balance.py (from hydrology, GNU General Public License v3.0; 5 votes)
def plot_date(dataframe, column_name):
    """
    Plot one column of ``dataframe`` against its index, with a red line at y=0.

    :param dataframe: frame whose index supplies the x values (presumably
        dates, given ``autofmt_xdate`` -- confirm with callers)
    :param column_name: column to plot; also used as the legend label
    :type column_name: str
    :return: the list of lines returned by ``plt.plot``
    """
    figure = plt.figure(figsize=(11.69, 8.27))  # 297 x 210 mm, i.e. A4 landscape
    x_values = dataframe.index
    drawn = plt.plot(x_values, dataframe[column_name], 'b-', label=r"%s" % column_name)
    # Horizontal zero reference spanning the whole index range.
    plt.hlines(0, min(x_values), max(x_values), 'r')
    plt.legend(loc='best')
    figure.autofmt_xdate(rotation=90)
    return drawn
Example #7
Source File: test_axes.py (from twitter-stock-recommendation, MIT License; 5 votes)
def test_date_timezone_y():
    # Regression test for issue 5575: tz-aware datetimes on the y axis only,
    # drawn with the data's own timezone (top) and with UTC (bottom).
    eastern = pytz.timezone('Canada/Eastern')
    stamps = [eastern.localize(datetime.datetime(year=2016, month=2, day=22, hour=h))
              for h in range(3)]

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('Canada/Eastern', 'UTC'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date([3] * 3, stamps, tz=zone, xdate=False, ydate=True)
Example #8
Source File: test_axes.py (from twitter-stock-recommendation, MIT License; 5 votes)
def test_date_timezone_x():
    # Regression test for issue 5575: tz-aware datetimes on the x axis, drawn
    # with the data's own timezone (top) and with UTC (bottom).
    eastern = pytz.timezone('Canada/Eastern')
    stamps = [eastern.localize(datetime.datetime(year=2016, month=2, day=22, hour=h))
              for h in range(3)]

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('Canada/Eastern', 'UTC'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date(stamps, [3] * 3, tz=zone)
Example #9
Source File: test_axes.py (from twitter-stock-recommendation, MIT License; 5 votes)
def test_single_date():
    # Draw the same single point twice, via plot_date (top) and via plain
    # plot (bottom), so the two renderings can be compared.
    x_vals = [721964.0]  # interpreted as a matplotlib date number by plot_date
    y_vals = [-65.54]

    plt.figure()
    for position, draw in ((211, plt.plot_date), (212, plt.plot)):
        plt.subplot(position)
        draw(x_vals, y_vals, 'o', color='r')
Example #10
Source File: log_analyzer.py (from satellite, GNU General Public License v3.0; 5 votes)
def _plot_dvb(ds, ds_name):
    """Save one time-series PNG per metric found in the DVB log records.

    Args:
        ds: list of dict records. The keys of ``ds[0]`` other than "time" and
            "date" are treated as metrics. Both keys must be present in
            ``ds[0]``; ``list.remove`` raises ValueError otherwise. For each
            metric, only records containing it are plotted, with ``r["date"]``
            as x.
        ds_name: sub-directory of ``figs`` that receives the images, named
            ``dvb-<metric>.png`` with "/" replaced by "_".
    """
    keys = list(ds[0].keys())
    keys.remove("time")
    keys.remove("date")

    formatter = DateFormatter('%H:%M')

    path = os.path.join("figs", ds_name)
    # exist_ok avoids the race between a separate isdir() check and makedirs().
    os.makedirs(path, exist_ok=True)

    for key in keys:
        print("Plotting {}".format(key))
        records = [r for r in ds if key in r]
        values = [r[key] for r in records]
        t = [r["date"] for r in records]
        n = os.path.join(path, "dvb-" + key.replace("/", "_") + ".png")

        # Derive the axis label from the metric name, e.g. "snr_dB" -> "snr (dB)".
        if "_" in key:
            kelems = key.split("_")
            ylabel = "{} ({})".format(kelems[0], kelems[1])
        elif key[-1] == "%":
            ylabel = "{} %".format(key[:-1])
        else:
            ylabel = key

        # Work on the explicit figure/axes instead of mixing in pyplot state.
        fig, ax = plt.subplots()
        ax.plot_date(t, values, ms=1)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("Time")
        if key == "postBER":
            ax.set_yscale('log')
        ax.grid()
        ax.xaxis.set_major_formatter(formatter)
        ax.xaxis.set_tick_params(rotation=30, labelsize=10)
        fig.tight_layout()
        fig.savefig(n, dpi=300)
        plt.close(fig)
Example #11
Source File: graph.py (from Quant_stock, MIT License; 5 votes)
def graph(models):
    """Restore saved models, plot their predictions next to stock prices, and show the figure.

    Args:
        models: iterable of model names, each 'feedforward' or 'recurrent'.
            Checkpoints are read from ``data/model/<name>/``.

    Raises:
        ValueError: for an unrecognised model name. Previously this surfaced
            as a NameError on ``x``.
    """
    # Per model: input tensor, output tensor, line format, legend label.
    model_specs = {
        'feedforward': ('input:0', 'output:0', 'b-', "Feedforward Predictions"),
        'recurrent': ('input_recurrent:0', 'output_recurrent:0', 'g-', "RNN Predictions"),
    }
    for model in models:
        if model not in model_specs:
            raise ValueError("Unknown model: {}".format(model))
        input_name, output_name, line_fmt, line_label = model_specs[model]

        print("Loading pre-trained model...")
        # The context manager closes the session once predictions are computed.
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("data/model/"+str(model)+'/'+str(model)+'.ckpt.meta')
            saver.restore(sess, tf.train.latest_checkpoint('data/model/'+str(model)))
            print("Model loaded...")

            tf_graph = tf.get_default_graph()
            x = tf_graph.get_tensor_by_name(input_name)
            prediction = tf_graph.get_tensor_by_name(output_name)
            _, _, _, _, oil_price, stock_price = dp.create_data()

            predictions = []
            if model == 'feedforward':
                # One prediction per oil price sample.
                for i in oil_price:
                    predictions.append(sess.run(prediction, feed_dict={x: [[i]]})[0][0])
            else:
                # The recurrent model consumes whole chunks; trailing samples
                # that do not fill a chunk get no prediction.
                for index in range(int(len(oil_price.values) / total_chunk_size)):
                    start = index * total_chunk_size
                    x_in = oil_price.values[start:start + total_chunk_size].reshape(
                        (1, n_chunks, chunk_size))
                    predictions += sess.run(prediction, feed_dict={x: x_in})[0].reshape(total_chunk_size).tolist()

        # Align the dates and reference prices with the predictions actually
        # produced. The original never set date_labels for 'recurrent'.
        n_predicted = len(predictions)
        date_labels = matplotlib.dates.date2num(oil_price.index.to_pydatetime())[:n_predicted]
        plt.plot_date(date_labels, predictions, line_fmt, label=line_label)
        plt.plot_date(date_labels, stock_price.values[:n_predicted], 'r-', label='Stock Prices')
        plt.legend()
    plt.ylabel('Price')
    plt.xlabel('Year')
    plt.show()
Example #12
Source File: test_axes.py (from coffeegrindsize, MIT License; 5 votes)
def test_date_timezone_x_and_y():
    # Regression test for issue 5575: UTC-aware datetimes on both axes, drawn
    # once with the same timezone (top) and once with another (bottom).
    utc_zone = datetime.timezone.utc
    stamps = [datetime.datetime(2016, 2, 22, hour=h, tzinfo=utc_zone) for h in range(3)]

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('UTC', 'US/Eastern'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date(stamps, stamps, tz=zone, ydate=True)
Example #13
Source File: test_axes.py (from coffeegrindsize, MIT License; 5 votes)
def test_date_timezone_y():
    # Regression test for issue 5575: tz-aware datetimes on the y axis only,
    # drawn with the data's own timezone (top) and with UTC (bottom).
    stamps = []
    for h in range(3):
        stamps.append(datetime.datetime(2016, 2, 22, hour=h,
                                        tzinfo=dutz.gettz('Canada/Eastern')))

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('Canada/Eastern', 'UTC'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date([3] * 3, stamps, tz=zone, xdate=False, ydate=True)
Example #14
Source File: test_axes.py (from coffeegrindsize, MIT License; 5 votes)
def test_date_timezone_x():
    # Regression test for issue 5575: tz-aware datetimes on the x axis, drawn
    # with the data's own timezone (top) and with UTC (bottom).
    stamps = []
    for h in range(3):
        stamps.append(datetime.datetime(2016, 2, 22, hour=h,
                                        tzinfo=dutz.gettz('Canada/Eastern')))

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('Canada/Eastern', 'UTC'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date(stamps, [3] * 3, tz=zone)
Example #15
Source File: test_axes.py (from coffeegrindsize, MIT License; 5 votes)
def test_single_date():
    # Draw the same single point twice, via plot_date (top) and via plain
    # plot (bottom), so the two renderings can be compared.
    x_vals = [721964.0]  # interpreted as a matplotlib date number by plot_date
    y_vals = [-65.54]

    plt.figure()
    for position, draw in ((211, plt.plot_date), (212, plt.plot)):
        plt.subplot(position)
        draw(x_vals, y_vals, 'o', color='r')
Example #16
Source File: test_axes.py (from ImageFusion, MIT License; 5 votes)
def test_single_date():
    # Draw the same single point twice, via plot_date (top) and via plain
    # plot (bottom), so the two renderings can be compared.
    x_vals = [721964.0]  # interpreted as a matplotlib date number by plot_date
    y_vals = [-65.54]

    plt.figure()
    for position, draw in ((211, plt.plot_date), (212, plt.plot)):
        plt.subplot(position)
        draw(x_vals, y_vals, 'o', color='r')
Example #17
Source File: visualizer.py (from Load-Forecasting, MIT License; 5 votes)
def yearlyPlot(ySeries, year, month, day, plotName="Plot", yAxisName="yData"):
    """Plot ySeries as a red line over consecutive days, then show the figure.

    Args:
        ySeries: sequence of values, one per day.
        year, month, day: date of the first value.
        plotName: figure title.
        yAxisName: y-axis label.
    """
    first_day = datetime.date(year, month, day)
    dateList = [first_day + datetime.timedelta(days=step) for step in range(len(ySeries))]

    plt.plot_date(x=dateList, y=ySeries, fmt="r-")
    plt.title(plotName)
    plt.ylabel(yAxisName)
    plt.xlabel("Date")
    plt.grid(True)
    plt.show()

# Plots autocorrelation factors against varying time lags for ySeries 
Example #18
Source File: temperature_data.py (from python_primer, MIT License; 5 votes)
def plot_city_data(*city_data_dicts):
    """Plot temperature against date for each city as dots, with a legend, and show it.

    Each positional argument is a mapping with 'name', 'date' and
    'temperature' entries; 'name' becomes the legend label.
    """
    legend_names = []
    for record in city_data_dicts:
        legend_names.append(record['name'])
        plt.plot_date(record['date'], record['temperature'], '.')
    plt.ylabel('Temperature (C)')
    plt.xlabel('Date')
    # NOTE(review): a 0-110 range looks like Fahrenheit rather than Celsius;
    # confirm the unit in the label against the data source.
    plt.ylim([0, 110])
    plt.legend(legend_names)
    plt.show()
Example #19
Source File: display.py (from diogenes, MIT License; 5 votes)
def plot_on_timeline(col, verbose=True):
    """Plots points on a timeline

    Every value of ``col`` is drawn at y=0 against its date.

    Parameters
    ----------
    col : np.array
        Values to plot. They are first passed through ``utils.check_col``;
        when ``is_nd`` reports true they are cast with ``astype(datetime)``,
        and finally converted with ``matplotlib.dates.date2num``.
    verbose : boolean
        iff True, display the graph

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing plot
    """
    checked = utils.check_col(col)
    # http://stackoverflow.com/questions/1574088/plotting-time-in-python-with-matplotlib
    as_dates = checked.astype(datetime) if is_nd(checked) else checked
    x_values = matplotlib.dates.date2num(as_dates)

    figure = plt.figure()
    plt.plot_date(x_values, [0] * len(x_values))
    if verbose:
        plt.show()
    return figure
Example #20
Source File: test_axes.py (from python3_ios, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_date_timezone_x_and_y():
    # Regression test for issue 5575: UTC-aware datetimes on both axes, drawn
    # once with the same timezone (top) and once with another (bottom).
    utc_zone = datetime.timezone.utc
    stamps = [datetime.datetime(2016, 2, 22, hour=h, tzinfo=utc_zone) for h in range(3)]

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('UTC', 'US/Eastern'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date(stamps, stamps, tz=zone, ydate=True)
Example #21
Source File: test_axes.py (from python3_ios, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_date_timezone_y():
    # Regression test for issue 5575: tz-aware datetimes on the y axis only,
    # drawn with the data's own timezone (top) and with UTC (bottom).
    stamps = []
    for h in range(3):
        stamps.append(datetime.datetime(2016, 2, 22, hour=h,
                                        tzinfo=dutz.gettz('Canada/Eastern')))

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('Canada/Eastern', 'UTC'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date([3] * 3, stamps, tz=zone, xdate=False, ydate=True)
Example #22
Source File: test_axes.py (from python3_ios, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_date_timezone_x():
    # Regression test for issue 5575: tz-aware datetimes on the x axis, drawn
    # with the data's own timezone (top) and with UTC (bottom).
    stamps = []
    for h in range(3):
        stamps.append(datetime.datetime(2016, 2, 22, hour=h,
                                        tzinfo=dutz.gettz('Canada/Eastern')))

    plt.figure(figsize=(20, 12))
    for position, zone in enumerate(('Canada/Eastern', 'UTC'), start=1):
        plt.subplot(2, 1, position)
        plt.plot_date(stamps, [3] * 3, tz=zone)
Example #23
Source File: test_axes.py (from python3_ios, BSD 3-Clause "New" or "Revised" License; 5 votes)
def test_single_date():
    # Draw the same single point twice, via plot_date (top) and via plain
    # plot (bottom), so the two renderings can be compared.
    x_vals = [721964.0]  # interpreted as a matplotlib date number by plot_date
    y_vals = [-65.54]

    plt.figure()
    for position, draw in ((211, plt.plot_date), (212, plt.plot)):
        plt.subplot(position)
        draw(x_vals, y_vals, 'o', color='r')
Example #24
Source File: process_debug.py (from scoop, GNU Lesser General Public License v3.0; 5 votes)
def _plot_broker_column(dataTask, column):
    """Plot field ``column`` of every list-valued entry in ``dataTask`` against its timestamp."""
    for fichier, vals in dataTask.items():
        # Only list-valued entries hold broker data; anything else is skipped.
        if not isinstance(vals, list):
            continue
        # Transpose the rows once: fields[0] is the timestamp column.
        fields = list(zip(*vals))
        # Timestamps are truncated to whole seconds before conversion.
        timestamps = [datetime.fromtimestamp(int(stamp)) for stamp in fields[0]]
        # Data is from broker
        plt.plot_date(timestamps, fields[column],
                      linewidth=1.0,
                      marker='o',
                      markersize=2,
                      label=fichier)


def plotBrokerQueue(dataTask, filename):
    """Generates the broker queue length graphic.

    The top subplot draws field 2 of each broker row (queue length) and the
    bottom subplot draws field 3 (pending requests), both against the row's
    timestamp in field 0. The figure is saved to ``filename`` and closed.

    Args:
        dataTask: mapping of source name -> value; only list values (rows of
            broker samples) are plotted.
        filename: output image path.
    """
    print("Plotting broker queue length for {0}.".format(filename))
    fig = plt.figure()

    # Queue length
    plt.subplot(211)
    _plot_broker_column(dataTask, 2)
    plt.title('Broker queue length')
    plt.ylabel('Tasks')

    # Requests received
    plt.subplot(212)
    _plot_broker_column(dataTask, 3)
    plt.title('Broker pending requests')
    # NOTE(review): the x values are datetimes, not seconds; the label may be stale.
    plt.xlabel('time (s)')
    plt.ylabel('Requests')

    plt.savefig(filename)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
Example #25
Source File: plot.py (from quantified-self, MIT License; 5 votes)
def make_efficiency_date(
        total_data,
        avg_data,
        f_name,
        title=None,
        x_label=None,
        y_label=None,
        x_ticks=None,
        y_ticks=None,
    ):
    """Save a time-of-day plot of individual samples plus an average line.

    Args:
        total_data: iterable of records where ``record[0]`` is a "%H:%M" time
            string and ``record[1]`` a value. Each record is drawn as a single
            point coloured by ``plt.cm.brg(round(float(record[1])))``.
        avg_data: iterable of records in the same layout, drawn as a yellow
            line with triangle markers labelled "Average".
        f_name: output image path passed to ``plt.savefig``.
        title: optional figure title.
        x_label: optional x-axis label.
        y_label: optional y-axis label.
        x_ticks, y_ticks: accepted for interface compatibility; currently unused.
    """
    fig = plt.figure()

    if title is not None:
        plt.title(title, fontsize=16)
    # Each label now goes to its own axis; they were previously swapped.
    if x_label is not None:
        plt.xlabel(x_label)
    if y_label is not None:
        plt.ylabel(y_label)

    for data in total_data:
        when = dt.date2num(datetime.datetime.strptime(data[0], "%H:%M"))
        to_int = round(float(data[1]))
        plt.plot_date(when, data[1], color=plt.cm.brg(to_int))

    v_date = []
    v_val = []
    for data in avg_data:
        v_date.append(dt.date2num(datetime.datetime.strptime(data[0], "%H:%M")))
        v_val.append(data[1])

    plt.plot_date(v_date, v_val, "^y-", label="Average")
    plt.legend()
    plt.savefig(f_name)
    plt.close(fig)
Example #26
Source File: test_axes.py (from neural-network-animation, MIT License; 5 votes)
def test_single_date():
    # Draw the same single point twice, via plot_date (top) and via plain
    # plot (bottom), so the two renderings can be compared.
    x_vals = [721964.0]  # interpreted as a matplotlib date number by plot_date
    y_vals = [-65.54]

    plt.figure()
    for position, draw in ((211, plt.plot_date), (212, plt.plot)):
        plt.subplot(position)
        draw(x_vals, y_vals, 'o', color='r')
Example #27
Source File: DailyDifferenceAverage.py (from incubator-sdap-nexus, Apache License 2.0; 4 votes)
def calc(self, request, **args):
    """Compute the daily average difference between two datasets over a bounding box.

    The bounding box, the ``ds1``/``ds2`` dataset names and the time range
    are read from ``request``. Results are sorted by their first element
    (the day).

    Returns:
        If the ``simple`` argument is present: ``(averagebyday, None, None)``.
        A plot is also written to ``test.png`` in the current working
        directory. Otherwise: a ``NexusResults`` wrapping the same data.
    """
    min_lat, max_lat, min_lon, max_lon = (request.get_min_lat(), request.get_max_lat(),
                                          request.get_min_lon(), request.get_max_lon())
    dataset1 = request.get_argument("ds1", None)
    dataset2 = request.get_argument("ds2", None)
    start_time = request.get_start_time()
    end_time = request.get_end_time()

    simple = request.get_argument("simple", None) is not None

    averagebyday = self.get_daily_difference_average_for_box(min_lat, max_lat, min_lon, max_lon, dataset1, dataset2,
                                                             start_time, end_time)

    averagebyday = sorted(averagebyday, key=lambda dayavg: dayavg[0])

    if simple:
        import matplotlib.pyplot as plt
        from matplotlib.dates import date2num

        times = [date2num(self.date_from_ms(dayavg[0])) for dayavg in averagebyday]
        means = [dayavg[1] for dayavg in averagebyday]

        # Use a dedicated figure and always close it. Otherwise every request
        # would keep drawing onto (and leaking) pyplot's current figure.
        fig = plt.figure()
        try:
            plt.plot_date(times, means, ls='solid')

            plt.xlabel('Date')
            plt.xticks(rotation=70)
            plt.ylabel(u'Difference from 5-Day mean (\u00B0C)')
            plt.title('Sea Surface Temperature (SST) Anomalies')
            plt.grid(True)
            plt.tight_layout()
            plt.savefig("test.png")
        finally:
            plt.close(fig)

        return averagebyday, None, None

    result = NexusResults(
        results=[[{'time': dayms, 'mean': avg, 'ds': 0}] for dayms, avg in averagebyday],
        stats={},
        meta=self.get_meta())

    result.extendMeta(min_lat, max_lat, min_lon, max_lon, "", start_time, end_time)
    result.meta()['label'] = u'Difference from 5-Day mean (\u00B0C)'

    return result
Example #28
Source File: todaychart.py (from raspi-sump, MIT License; 4 votes)
def graph(csv_file, filename, bytes2str):
    """Create a line graph from a two column csv file.

    Column 0 is converted with ``bytes2str`` and used as the x values of
    ``plot_date``; column 1 is the water level. The chart is saved to
    ``filename`` at 72 dpi and the figure is closed afterwards.
    """
    unit = configs["unit"]
    date, value = np.loadtxt(
        csv_file, delimiter=",", unpack=True, converters={0: bytes2str}
    )
    fig = plt.figure(figsize=(10, 3.5))

    # axisbg is deprecated in matplotlib 2.x. Maintain 1.x compatibility
    if MPL_VERSION > 1:
        fig.add_subplot(111, facecolor="white", frameon=False)
    else:
        fig.add_subplot(111, axisbg="white", frameon=False)

    rcParams.update({"font.size": 9})
    # fmt="-" alone yields the solid, marker-less line. The previous
    # fmt=":" plus ls="solid" defined the linestyle twice, which matplotlib
    # resolves in favour of the keyword but warns about.
    plt.plot_date(
        x=date,
        y=value,
        fmt="-",
        linewidth=2,
        color="#" + configs["line_color"],
    )
    title = "Sump Pit Water Level {}".format(time.strftime("%Y-%m-%d %H:%M"))
    title_set = plt.title(title)
    title_set.set_y(1.09)
    plt.subplots_adjust(top=0.86)

    if unit == "imperial":
        plt.ylabel("inches")
    if unit == "metric":
        plt.ylabel("centimeters")

    plt.xlabel("Time of Day")
    plt.xticks(rotation=30)
    plt.grid(True, color="#ECE5DE", linestyle="solid")
    if MPL_VERSION < 3:
        plt.tick_params(axis="x", bottom="off", top="off")
        plt.tick_params(axis="y", left="off", right="off")
    else:
        plt.tick_params(axis="x", bottom=False, top=False)
        plt.tick_params(axis="y", left=False, right=False)
    plt.savefig(filename, dpi=72)
    # Release the figure so repeated calls in one process do not leak memory.
    plt.close(fig)
Example #29
Source File: feedforward_nn.py (from Quant_stock, MIT License; 4 votes)
def feedforward_neural_network(inputs):
    """Train the feedforward oil->stock model, report accuracy, save it and plot predictions.

    Relies on module-level names not visible in this block: ``y``,
    ``learning_rate``, ``hm_epoch``, ``neural_network_model``,
    ``refine_input_with_lag``, ``tf``, ``plt`` and ``matplotlib``.

    Args:
        inputs: 6-tuple ``(oil_train, stock_train, oil_test, stock_test,
            oil_price, stock_price)``. The splits expose ``.values``, and
            ``oil_price.index`` supports ``to_pydatetime()`` (presumably pandas
            objects -- confirm with the caller).

    Side effects:
        Prints per-epoch loss and accuracy, saves a checkpoint to
        ``data/model/feedforward/feedforward.ckpt`` and shows a matplotlib
        window.

    Note:
        ``cor/total`` raises ZeroDivisionError if the test split is empty.
    """
    # Input placeholder. The explicit name 'input' presumably lets the
    # restored graph be queried for the tensor by name later.
    x = tf.placeholder('float', name='input')
    oil_train, stock_train, oil_test, stock_test, oil_price, stock_price = inputs
    prediction = neural_network_model(x)
    # Mean squared error between the (transposed) prediction and target y.
    cost = tf.reduce_mean(tf.square(tf.transpose(prediction)-y))
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    #oil_train, stock_train, oil_test, stock_test = inputs

    # Re-derive the splits via refine_input_with_lag (defined elsewhere;
    # presumably shifts the series by a lag -- confirm).
    oil_train, stock_train, oil_test, stock_test = refine_input_with_lag(oil_train, stock_train, oil_test, stock_test)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Training: one optimizer step per (oil, stock) sample, hm_epoch times.
        for epoch in range(hm_epoch):
            epoch_loss = 0
            for (X, Y) in zip(oil_train.values, stock_train.values):
                _, c = sess.run([optimizer, cost], feed_dict={x: [[X]], y: [[Y]]})
                epoch_loss += c
            print('Epoch', epoch, 'completed out of', hm_epoch, 'loss:', epoch_loss)
        # A test sample counts as correct when |prediction - target| < 5.
        correct = tf.subtract(prediction, y)
        total = 0
        cor = 0
        for (X,Y) in zip(oil_test.values, stock_test.values):
            total += 1
            if abs(correct.eval({x: [[X]], y: [[Y]]})) < 5:
                cor += 1
        print('Accuracy:', cor/total)
        save_path = saver.save(sess, "data/model/feedforward/feedforward.ckpt")
        print("Model saved in file: %s" % save_path)

        # Convert the oil price index to matplotlib date numbers for plot_date.
        date_labels = oil_price.index
        date_labels = matplotlib.dates.date2num(date_labels.to_pydatetime())

        # One prediction per oil price sample over the full series.
        predictions = []
        for i in oil_price:
            predictions.append(sess.run(prediction, feed_dict={x: [[i]]})[0][0])
        plt.plot_date(date_labels, predictions, 'b-', label="Feedforward Predictions")
        plt.plot_date(date_labels, stock_price.values, 'r-', label='Stock Prices')
        plt.legend()
        plt.ylabel('Price')
        plt.xlabel('Year')
        plt.show()
Example #30
Source File: recurrent_lstm.py (from Quant_stock, MIT License; 4 votes)
def recurrent_neural_network(inputs):
    """Train the recurrent oil->stock model, report accuracy, save it and plot predictions.

    Relies on module-level names not visible in this block: ``x``, ``y``,
    ``prediction``, ``learning_rate``, ``hm_epoch``, ``total_chunk_size``,
    ``n_chunks``, ``chunk_size``, ``refine_input_with_lag``, ``tf``, ``plt``
    and ``matplotlib``.

    Args:
        inputs: 6-tuple ``(oil_train, stock_train, oil_test, stock_test,
            oil_price, stock_price)``.

    Side effects:
        Prints per-epoch loss and accuracy, saves a checkpoint to
        ``data/model/recurrent/recurrent.ckpt`` and shows a matplotlib window.

    Note:
        Data is consumed in whole chunks of ``total_chunk_size`` samples;
        trailing samples that do not fill a chunk are ignored. ``cor/total``
        raises ZeroDivisionError if the test split has no full chunk.
    """
    oil_train, stock_train, oil_test, stock_test, oil_price, stock_price = inputs
    cost = tf.reduce_mean(tf.square(prediction-y))
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

    oil_train, stock_train, oil_test, stock_test = refine_input_with_lag(oil_train, stock_train, oil_test, stock_test)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Training: one optimizer step per full chunk, hm_epoch times.
        for epoch in range(hm_epoch):
            epoch_loss = 0
            for index in range(int(len(oil_train.values) / total_chunk_size)):
                x_in = oil_train.values[index * total_chunk_size:index * total_chunk_size + total_chunk_size].reshape((1, n_chunks, chunk_size))
                y_in = stock_train.values[index * total_chunk_size:index * total_chunk_size + total_chunk_size].reshape((1, n_chunks, chunk_size))
                _, c = sess.run([optimizer, cost], feed_dict={x: x_in, y: y_in})
                epoch_loss += c
            print('Epoch', epoch, 'completed out of', hm_epoch, 'loss:', epoch_loss)
        # A test chunk counts as correct (for all its samples) when its MSE < 5.
        correct = tf.reduce_mean(tf.square(tf.subtract(prediction, y)))
        total = 0
        cor = 0
        for index in range(int(len(oil_test.values) / total_chunk_size)):
            x_in = oil_test.values[index * total_chunk_size:index * total_chunk_size + total_chunk_size].reshape((1, n_chunks, chunk_size))
            y_in = stock_test.values[index * total_chunk_size:index * total_chunk_size + total_chunk_size].reshape((1, n_chunks, chunk_size))
            total += total_chunk_size
            if abs(correct.eval(feed_dict={x: x_in, y: y_in})) < 5:
                cor += total_chunk_size

        saver = tf.train.Saver()
        print('Accuracy:', cor/total)
        save_path = saver.save(sess, "data/model/recurrent/recurrent.ckpt")
        print("Model saved in file: %s" % save_path)

        predictions = []
        for index in range(int(len(oil_price.values) / total_chunk_size)):
            x_in = oil_price.values[index * total_chunk_size:index * total_chunk_size + total_chunk_size].reshape((1, n_chunks, chunk_size))
            predictions += sess.run(prediction, feed_dict={x: x_in})[0].reshape(total_chunk_size).tolist()

        # Trim the dates and reference prices to the number of predictions
        # actually produced. The previous hard-coded [:-4] matched only when
        # len(oil_price) % total_chunk_size == 4; otherwise the plot failed
        # on mismatched lengths.
        n_predicted = len(predictions)
        date_labels = matplotlib.dates.date2num(oil_price.index.to_pydatetime())[:n_predicted]
        print(len(predictions), len(date_labels))
        plt.plot_date(date_labels, predictions, 'b-', label="RNN Predictions")
        plt.plot_date(date_labels, stock_price.values[:n_predicted], 'r-', label='Stock Prices')
        plt.legend()
        plt.ylabel('Price')
        plt.xlabel('Year')
        plt.show()