import pandas as pd
import numpy as np
# Load the training split and keep only its closing-price column.
df = pd.read_csv("datas/zgpa_train.csv")
df.head()
price = df["close"]
# Normalise by the training maximum, so the largest training price maps to 1.
# The same divisor (max(price)) is used later to bring values back to prices.
price_norm = price.div(max(price))
import matplotlib.pyplot as plt

# `%matplotlib inline` is IPython-only syntax and a SyntaxError in a plain .py
# file. Issue the same magic through the IPython API when running under
# IPython/Jupyter (where `get_ipython` is injected as a builtin).
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    pass  # Not running under IPython: nothing to configure.

# Plot the raw (un-normalised) training closing prices.
plt.figure(figsize=(5, 3))
plt.plot(price)
plt.xlabel("time")
plt.ylabel("price")
plt.show()
# Extract input windows (X) and next-step targets (y).
def extract_data(data, time_step):
    """Build supervised (window, next value) pairs from a 1-D sequence.

    Each sample takes ``time_step`` consecutive values as input and the value
    immediately following them as the target.

    Args:
        data: 1-D sequence of numbers (list, NumPy array or pandas Series).
            It is indexed by position, so a Series with any index works.
        time_step: Window length; must be >= 1.

    Returns:
        A tuple ``(X, y)``. ``X`` is an array of shape
        ``(n_samples, time_step, 1)``; ``y`` is a list of the ``n_samples``
        target values, where ``n_samples = max(len(data) - time_step, 0)``.

    Raises:
        ValueError: If ``time_step`` is smaller than 1.
    """
    if time_step < 1:
        raise ValueError(f"time_step must be >= 1, got {time_step}")
    # Plain array => positional indexing. Indexing a Series with an integer
    # would be a label lookup and fail for non-0-based indexes.
    values = np.asarray(data)
    n_samples = max(len(values) - time_step, 0)
    X = np.array([values[i:i + time_step] for i in range(n_samples)])
    # Explicit dimensions so that too-short input yields an empty
    # (0, time_step, 1) array instead of failing on a 1-D empty array.
    X = X.reshape(n_samples, time_step, 1)
    y = [values[i + time_step] for i in range(n_samples)]
    return X, y
# Window length: each sample uses the previous 8 normalised closing prices
# as input and the 9th as the target.
time_step = 8
X, y = extract_data(price_norm, time_step)
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN
# Model: one SimpleRNN layer (5 units, ReLU) reading windows of shape
# (time_step, 1), followed by a linear Dense layer that emits one value.
model = Sequential(
    [
        SimpleRNN(units=5, input_shape=(time_step, 1), activation="relu"),
        Dense(units=1, activation="linear"),
    ]
)
# Adam optimiser, mean-squared-error loss.
model.compile(optimizer="adam", loss="mean_squared_error")
# Train. If the loss stops changing, rebuild the model and train again
# (presumably a poor random initialisation; not verified here).
model.fit(X, np.array(y), batch_size=30, epochs=200)
# Predict on the training windows and undo the normalisation, so both curves
# are in price units. The scale factor is computed once; calling max(price)
# per element would rescan the whole Series for every value (O(n^2)).
price_max = max(price)
y_train_predict = model.predict(X) * price_max
y_train = [v * price_max for v in y]
plt.figure(figsize=(5, 3))
plt.plot(y_train_predict, label="predict price")
plt.plot(y_train, label="true price")
plt.xlabel("time")
plt.ylabel("price")
plt.legend()
plt.show()
# Predict on the held-out test split.
test_data = pd.read_csv("datas/zgpa_test.csv")
test_data.head()
price_test = test_data["close"]
# Normalise with the *training* maximum, so train and test share one scale.
# Computed once and reused below, rather than per element.
price_max = max(price)
price_test_norm = price_test / price_max
x_test_norm, y_test_norm = extract_data(price_test_norm, time_step)
# Predict, then map both predictions and targets back to price units.
y_test_predict = model.predict(x_test_norm) * price_max
y_test = [v * price_max for v in y_test_norm]
plt.figure(figsize=(5, 3))
plt.plot(y_test_predict, label="test predict price")
plt.plot(y_test, label="test true price")
plt.xlabel("time")
plt.ylabel("price")
plt.legend()
plt.show()
# Save real and predicted test prices side by side as two CSV columns.
result_y_test = np.asarray(y_test).reshape(-1, 1)
# model.predict output with a 1-unit Dense head: already one column.
result_y_test_predict = y_test_predict
print(result_y_test.shape, result_y_test_predict.shape)
# Stack the two single-column arrays horizontally into an (n, 2) table.
result = pd.DataFrame(
    np.hstack((result_y_test, result_y_test_predict)),
    columns=["real_price_test", "predict_price_test"],
)
result.to_csv("zgpa_predict_test.csv")
# Observation: the predicted curve lags the true price by roughly one time step.