策略分享

分享一份用交易数据预测股票涨跌的代码

由 ypyu 创建，最终由 ypyu 更新，已被 272 位用户浏览

PS：非本人代码，来自网络上的开源分享。这里分享的是科赛网《公开新闻预测A股行业板块动向》比赛第三名的开源方案：比赛使用的是 tushare 的免费数据，个人可以复现。

import datetime
import os
import sys
from multiprocessing.pool import Pool

import numpy as np
import pandas as pd
import talib
from loguru import logger
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
from sklearn.svm import SVC
In [6]:
def channel_index_cross(close_list, lowerband_list, upperband_list):
    """
    Band-crossing signal sequence for a channel indicator (e.g. Bollinger bands).

    Day 0 is always 0. For every later day i the value is
      * 1  when the close moved from below the lower band (day i-1) to above it (day i),
      * -1 when the close moved from above the upper band (day i-1) to below it (day i),
      * 0  otherwise.

    :param close_list: daily close prices
    :param lowerband_list: lower band, aligned with close_list
    :param upperband_list: upper band, aligned with close_list
    :return: list of ints in {-1, 0, 1}; day 0 is always 0 (an empty input yields ``[0]``)
    """
    def _signal(day):
        if close_list[day - 1] < lowerband_list[day - 1] and close_list[day] > lowerband_list[day]:
            return 1
        if close_list[day - 1] > upperband_list[day - 1] and close_list[day] < upperband_list[day]:
            return -1
        return 0

    # the first day has no predecessor, so it is padded with 0
    return [0] + [_signal(day) for day in range(1, len(close_list))]

def index_cross(short_index_values, long_index_values):
    """
    Compute the golden-cross / death-cross signal sequence of two indicator lines.

    A golden cross (1) is recorded on day i when the short-period line was below
    the long-period line on day i-1 and is above it on day i; a death cross (-1)
    is the mirror case. Every other day, including day 0, is 0. Both sides of
    the comparison use the same day, as ``channel_index_cross`` does.

    :param short_index_values: short-period indicator values (indexable sequence)
    :param long_index_values: long-period indicator values, same length
    :return: list of ints in {-1, 0, 1} with the same length as the inputs
    """
    size = len(short_index_values)
    if size == 0:
        return []
    result = [0]  # day 0 has no previous day to compare against
    for i in range(1, size):
        prev_short, prev_long = short_index_values[i - 1], long_index_values[i - 1]
        cur_short, cur_long = short_index_values[i], long_index_values[i]
        if prev_short < prev_long and cur_short > cur_long:
            result.append(1)
        elif prev_short > prev_long and cur_short < cur_long:
            result.append(-1)
        else:
            result.append(0)
    return result
    
def index_oversold_overbought(day_index_values, min_index_value=20, max_index_value=80):
    """
    Flag every day whose indicator value lies in the oversold or overbought zone.

    :param day_index_values: daily indicator values; must support element-wise
        comparison returning an object with ``tolist()`` (numpy array / pandas Series)
    :param min_index_value: values strictly below this are oversold and flagged 1
    :param max_index_value: values strictly above this are overbought and flagged -1
    :return: list of ints in {-1, 0, 1}; the oversold test takes precedence
    """
    below = (day_index_values < min_index_value).tolist()
    above = (day_index_values > max_index_value).tolist()
    return [1 if is_low else (-1 if is_high else 0) for is_low, is_high in zip(below, above)]
    
def calc_oversold_overbought_value(n, bins):
    """
    Derive oversold / overbought thresholds from a histogram.

    The oversold threshold is the right edge of the first non-empty bin and the
    overbought threshold is the left edge of the last non-empty bin. Both stay 0
    when no bin has a positive count.

    :param n: per-bin counts
    :param bins: per-bin interval objects exposing ``left`` / ``right`` (e.g. pandas Interval)
    :return: (oversold, overbought)
    """
    occupied = [idx for idx in range(len(n)) if n[idx] > 0]
    if not occupied:
        return 0, 0
    return bins[occupied[0]].right, bins[occupied[-1]].left
    
def calculate_oversold_overbought(index_values, min_value=0, max_value=100, step=10):
    """
    Estimate the oversold / overbought thresholds of an indicator from its histogram.

    The values are counted in bins of width ``step`` between ``min_value`` and
    ``max_value``; values outside that range fall outside every bin and are not
    counted. The thresholds are then read off the first and last non-empty bins.

    :param index_values: indicator values
    :param min_value: lower edge of the first bin
    :param max_value: upper edge of the last bin
    :param step: bin width
    :return: (oversold, overbought)
    """
    edges = list(range(min_value, max_value + step, step))
    histogram = pd.Series(index_values).value_counts(normalize=False, sort=False, ascending=True, bins=edges)
    counts = histogram.values.tolist()
    intervals = histogram.index.values
    return calc_oversold_overbought_value(counts, intervals)
In [7]:
pd.set_option('mode.chained_assignment', None)  # silence pandas' SettingWithCopy warnings
logger.add("data_hive.log", format="{time} {level} {line} {message}", level='DEBUG', rotation="5 MB")

# Feature columns fed to every model of the ensemble (each model receives its own copy).
FEATURE_COLUMNS = [
    "boll_upper", "boll_middle", "boll_lower", "k", "d", "j", "macd", "macd_signal", "macd_hist",
    "cci6", "rsi6", "willr6", "cci10", "rsi10", "willr10",
    'boll_cross', 'kd_cross', 'kj_cross', 'cci_cross', 'rsi_cross', 'willr_cross',
    'over_cci6', 'over_cci10', 'over_willr6', 'over_willr10', 'over_j',
]

# Each entry is [estimator, feature_columns]; get_pre_value takes a majority vote over them.
algorithms = [
    [RandomForestClassifier(random_state=1, n_estimators=100, min_samples_split=4, min_samples_leaf=2),
     list(FEATURE_COLUMNS)],
    [LogisticRegression(random_state=1, solver='liblinear'), list(FEATURE_COLUMNS)],
    # Only predict() is called on the models (get_pre_value_by_alg), and SVC.predict does not
    # use probability calibration; probability=True would only add an internal 5-fold
    # cross-validation to every fit.
    [SVC(C=1.0, kernel='linear'), list(FEATURE_COLUMNS)],
]


def __get_n_day_str():
    """
    Return the three calendar days after today as ``YYYYMMDD`` integers.

    Plain calendar arithmetic is used, so weekends and holidays are not skipped.

    :return: list of three ints, e.g. ``[20190802, 20190803, 20190804]`` when today is 2019-08-01
    """
    base_day = datetime.date.today()
    one_day = datetime.timedelta(days=1)
    upcoming = []
    for offset in range(1, 4):
        upcoming.append(int((base_day + offset * one_day).strftime("%Y%m%d")))
    return upcoming


# Prediction dates attached to each stock's results (today + 1, + 2, + 3 calendar days).
date_list = __get_n_day_str()


def load_data(csv_data_fullname):
    """
    Read the training CSV, run the per-stock predictions and write them to ``result.csv``.

    ``result.csv`` is created in the current working directory with the header
    ``ts_code,trade_date,p``; each ``ts_code`` is truncated to its first 6 characters.

    :param csv_data_fullname: path of the input CSV
    :return: None
    """
    predictions = method_name2(pd.read_csv(csv_data_fullname))
    output_path = 'result.csv'
    if os.path.exists(output_path):
        os.remove(output_path)
    row_template = '{0},{1},{2}\n'
    with open(output_path, 'w') as out:
        out.write(row_template.format('ts_code', 'trade_date', 'p'))
        out.writelines(row_template.format(item[0][0:6], item[1], item[2]) for item in predictions)


def method_name2(dataframe):
    """
    Run ``task`` for every stock of ``dataframe`` in a process pool and flatten the results.

    :param dataframe: rows of all stocks, with at least ``ts_code`` and ``trade_date`` columns
    :return: flat list of the tuples returned by ``task`` for every stock
    """
    # One (ts_code, that stock's rows sorted by trade_date) item per stock.
    work_items = [(stock_id, day_dataframe.sort_values(by=['trade_date']))
                  for stock_id, day_dataframe in dataframe.groupby('ts_code')]
    # The context manager terminates the pool on exit, so worker processes are released
    # even when a task raises; on success map() has already collected every result.
    with Pool() as pool:
        per_stock_results = pool.map(task, work_items)
    return [row for stock_rows in per_stock_results for row in stock_rows]


def task(tuple_stockid_daydataframe):
    """
    Worker run once per stock: build targets and features, then predict three horizons.

    :param tuple_stockid_daydataframe: ``(ts_code, that stock's DataFrame sorted by trade_date)``
    :return: list of ``(ts_code, date, prediction)`` tuples, one per entry of ``date_list``
        (predictions for targets y1, y2, y3 respectively)
    """
    stock_id, day_dataframe = tuple_stockid_daydataframe
    logger.debug('板块ID:{0}'.format(stock_id))

    day_open_series = day_dataframe['open']
    day_high_series = day_dataframe['high']
    day_low_series = day_dataframe['low']
    day_close_series = day_dataframe['close']

    logger.debug('初始化结果值开始:{0}'.format(stock_id))
    y1, y2, y3 = init_target_values(day_close_series, day_open_series)
    day_dataframe['y1'] = y1
    day_dataframe['y2'] = y2
    day_dataframe['y3'] = y3
    logger.debug('初始化结果值完成:{0}'.format(stock_id))

    logger.debug('初始化指标值开始:{0}'.format(stock_id))
    init_index_values(day_high_series, day_low_series, day_close_series, day_dataframe)
    logger.debug('初始化指标值完成:{0}'.format(stock_id))

    # Columns absent from the input (e.g. 'pe', 'pb') are created as NaN by the reindex;
    # only the feature columns listed in ``algorithms`` are actually fed to the models.
    day_data_dataframe = pd.DataFrame(day_dataframe,
                                      columns=['open', 'high', 'low', 'close', 'change', 'vol', 'amount',
                                               'pe', 'pb',
                                               'y1', 'y2', 'y3',
                                               'boll_upper', 'boll_middle', 'boll_lower',
                                               'k', 'd', 'j',
                                               'macd', 'macd_signal', 'macd_hist',
                                               'cci6', 'rsi6', 'willr6',
                                               'cci10', 'rsi10', 'willr10',
                                               'boll_cross', 'kd_cross', 'kj_cross', 'cci_cross', 'rsi_cross', 'willr_cross',
                                               'over_cci6', 'over_cci10', 'over_willr6', 'over_willr10', 'over_j'
                                               ])
    logger.debug('开始预测:{0}'.format(stock_id))
    predictions = [get_pre_value(algorithms=algorithms, data_dataframe=day_data_dataframe, target_label=label)
                   for label in ('y1', 'y2', 'y3')]
    logger.debug('预测完成:{0}'.format(stock_id))

    return [(stock_id, trade_date, prediction) for trade_date, prediction in zip(date_list, predictions)]


def get_pre_value(algorithms, data_dataframe, target_label='y1'):
    """
    Majority vote of every configured classifier for ``target_label``.

    Each entry of ``algorithms`` is ``[estimator, feature_columns]``. Every
    estimator predicts the label of the latest row; the vote is 1 ("up") only
    when strictly more than half of the estimators predict 1 (2 of 3 for a
    three-model ensemble), otherwise 0 ("down"), which includes ties.

    :param algorithms: list of ``[estimator, feature_columns]`` pairs
    :param data_dataframe: one stock's feature/label frame
    :param target_label: label column to predict
    :return: 1 or 0
    """
    votes = [get_pre_value_by_alg(alg, predictors, target_label, data_dataframe)
             for alg, predictors in algorithms]
    for position, vote in enumerate(votes, start=1):
        logger.debug('alg{0}:{1}'.format(position, vote))
    # integer form of "sum(votes) > len(votes) / 2"
    return 1 if sum(votes) * 2 > len(votes) else 0


def get_pre_value_by_alg(alg, predictors, target_label, data_dataframe, horizon=None):
    """
    Fit ``alg`` on the rows with a known outcome and predict the label of the most recent row.

    Labels ``yN`` are built by looking N rows ahead, so the last N rows carry
    placeholder labels (0) instead of real outcomes. Those rows are excluded from
    training; otherwise the model is fitted on the very row it is asked to
    predict, with a fake label.

    :param alg: estimator exposing ``fit(X, y)`` and ``predict(X)`` (sklearn style)
    :param predictors: feature column names
    :param target_label: label column name
    :param data_dataframe: one stock's frame, rows ordered oldest to newest
    :param horizon: number of trailing rows whose label is unknown; when None it is
        inferred as N from labels named ``y<N>``, and 0 for any other label
    :return: predicted label of the last row; when the training rows contain a single
        class that class is returned without fitting, and 0 when there are no training rows
    """
    if horizon is None:
        horizon = 0
        if isinstance(target_label, str) and target_label.startswith('y') and target_label[1:].isdigit():
            horizon = int(target_label[1:])
    features = data_dataframe[predictors]
    target = data_dataframe[target_label]
    train_size = max(len(data_dataframe) - horizon, 0)
    train_target = target.iloc[:train_size]

    classes = train_target.unique()
    if len(classes) == 0:
        # no row with a known outcome: fall back to the "down" default used for padding
        return 0
    if len(classes) == 1:
        # LogisticRegression / SVC refuse to fit a single class; the answer is that class
        return classes[0]

    alg.fit(features.iloc[:train_size], train_target)
    # a one-row DataFrame keeps the feature names the model was fitted with
    return alg.predict(features.iloc[[-1]])[0]


def init_index_values(high_series, low_series, close_series, group):
    """
    Compute one stock's technical-indicator features and add them to ``group`` in place.

    Three groups of columns are written, in this order:
      1. raw indicator values (Bollinger bands, KDJ, MACD, and CCI/RSI/WILLR over 6 and 10 bars);
      2. crossing signals (-1/0/1) from ``channel_index_cross`` / ``index_cross``;
      3. oversold/overbought flags (-1/0/1) whose thresholds are estimated by
         ``calculate_oversold_overbought`` over the range [-100, 100].

    :param high_series: daily high prices
    :param low_series: daily low prices
    :param close_series: daily close prices
    :param group: the stock's DataFrame; new columns are assigned on it
    :return: None
    """
    lower, middle, upper = calc_boll_values(close_series)
    k, d, j = calc_kdj_values(close_series, high_series, low_series)
    macd, macd_hist, macd_signal = calc_macd_values(close_series)
    cci6 = calc_cci_values(close_series, high_series, low_series, timeperiod=6)
    rsi6 = calc_rsi_values(close_series, timeperiod=6)
    willr6 = calc_willr_values(close_series, high_series, low_series, timeperiod=6)
    cci10 = calc_cci_values(close_series, high_series, low_series, timeperiod=10)
    rsi10 = calc_rsi_values(close_series, timeperiod=10)
    willr10 = calc_willr_values(close_series, high_series, low_series, timeperiod=10)

    raw_features = {
        'boll_upper': upper, 'boll_middle': middle, 'boll_lower': lower,
        'k': k, 'd': d, 'j': j,
        'macd': macd, 'macd_signal': macd_signal, 'macd_hist': macd_hist,
        'cci6': cci6, 'rsi6': rsi6, 'willr6': willr6,
        'cci10': cci10, 'rsi10': rsi10, 'willr10': willr10,
    }
    for column, values in raw_features.items():
        group[column] = values

    cross_features = {
        'boll_cross': channel_index_cross(close_series.values.tolist(), lower.tolist(), upper.tolist()),
        'kd_cross': index_cross(k.tolist(), d.tolist()),
        'kj_cross': index_cross(k.tolist(), j.tolist()),
        'cci_cross': index_cross(cci6.tolist(), cci10.tolist()),
        'rsi_cross': index_cross(rsi6.tolist(), rsi10.tolist()),
        'willr_cross': index_cross(willr6.tolist(), willr10.tolist()),
    }
    for column, values in cross_features.items():
        group[column] = values

    over_features = {}
    for column, values in (('over_cci6', cci6), ('over_cci10', cci10), ('over_willr6', willr6),
                           ('over_willr10', willr10), ('over_j', j)):
        oversold, overbought = calculate_oversold_overbought(values, min_value=-100, max_value=100)
        over_features[column] = index_oversold_overbought(values, oversold, overbought)
    for column, flags in over_features.items():
        group[column] = flags


def calc_willr_values(close_series, high_series, low_series, timeperiod=6):
    """
    Williams %R over ``timeperiod`` bars; NaN values reported by talib are replaced by -50.

    :return: numpy array aligned with the input series
    """
    raw_willr = talib.WILLR(high_series, low_series, close_series, timeperiod=timeperiod)
    missing = np.isnan(raw_willr)
    return np.where(missing, -50, raw_willr)


def calc_rsi_values(close_series, timeperiod=6):
    """
    RSI of the close prices over ``timeperiod`` bars; NaN values reported by talib become 50.

    :return: numpy array aligned with ``close_series``
    """
    raw_rsi = talib.RSI(close_series, timeperiod=timeperiod)
    missing = np.isnan(raw_rsi)
    return np.where(missing, 50, raw_rsi)


def calc_macd_values(close_series):
    """
    MACD(12, 26, 9) of the close prices, with NaN values reported by talib replaced by 0.

    :return: (macd, macd_hist, macd_signal) numpy arrays; note the histogram
        comes before the signal line
    """
    macd_line, signal_line, histogram = (
        np.where(np.isnan(values), 0, values)
        for values in talib.MACD(close_series, fastperiod=12, slowperiod=26, signalperiod=9)
    )
    return macd_line, histogram, signal_line


def calc_kdj_values(close_series, high_series, low_series):
    """
    KDJ indicator from talib's slow stochastic (9, 3, 3).

    NaN values of K and D are replaced by 50; J is derived as ``3 * K - 2 * D``.

    :return: (k, d, j) numpy arrays
    """
    stoch_lines = talib.STOCH(high_series, low_series, close_series, fastk_period=9, slowk_period=3,
                              slowk_matype=0, slowd_period=3, slowd_matype=0)
    slow_k, slow_d = (np.where(np.isnan(line), 50, line) for line in stoch_lines)
    return slow_k, slow_d, 3 * slow_k - 2 * slow_d


def calc_cci_values(close_series, high_series, low_series, timeperiod=6):
    """
    Commodity Channel Index over ``timeperiod`` bars; NaN values reported by talib become 0.

    :return: numpy array aligned with the input series
    """
    raw_cci = talib.CCI(high_series, low_series, close_series, timeperiod=timeperiod)
    missing = np.isnan(raw_cci)
    return np.where(missing, 0, raw_cci)


def calc_boll_values(close_series):
    """
    Bollinger bands of the close prices (20 bars, +/- 2 standard deviations, matype 0).

    The warm-up values that talib reports as NaN are filled with the close price
    itself rather than 0. With a zero upper band, ``close[i-1] > upper[i-1]`` is
    trivially true, so ``channel_index_cross`` would report a spurious
    upper-band cross (-1) on the first bar that has a real band. Zero also makes
    the band features extreme outliers for the models.

    :param close_series: daily close prices
    :return: (lowerband, middleband, upperband) numpy arrays
    """
    upperband_series, middleband_series, lowerband_series = talib.BBANDS(close_series, timeperiod=20, nbdevup=2,
                                                                         nbdevdn=2, matype=0)
    closes = np.asarray(close_series, dtype=float)
    upperband = np.where(np.isnan(upperband_series), closes, upperband_series)
    middleband = np.where(np.isnan(middleband_series), closes, middleband_series)
    lowerband = np.where(np.isnan(lowerband_series), closes, lowerband_series)
    return lowerband, middleband, upperband


def init_target_values(close_series, open_series):
    """
    Build the 1-, 2- and 3-day-ahead binary targets for one stock.

    For horizon k, row i is labelled 1 when the close k rows later is strictly
    greater than the open of row i, else 0. The last k rows have no such future
    row and are padded with 0. Every list has the same length as the input, also
    when the series is shorter than k (the previous version produced lists of
    different lengths then).

    :param close_series: pandas Series of daily close prices, oldest first
    :param open_series: pandas Series of daily open prices, aligned with close_series
    :return: (y1, y2, y3) lists of ints in {0, 1}
    """
    closes = close_series.tolist()
    opens = open_series.tolist()
    size = len(closes)

    def _target(horizon):
        # rows that have a close ``horizon`` rows ahead get a real label
        known = [1 if closes[i + horizon] > opens[i] else 0 for i in range(max(size - horizon, 0))]
        return known + [0] * (size - len(known))

    y1, y2, y3 = (_target(horizon) for horizon in (1, 2, 3))
    return y1, y2, y3
In [8]:
if __name__ == '__main__':
    # method_name2 uses multiprocessing.Pool; under the "spawn" start method the
    # workers re-import the main module, and this guard stops them from
    # re-running the whole pipeline.
    load_data('/home/kesci/input/input9882/TRAINSET_STOCK.csv')

详情请看链接：https://www.kesci.com/home/project/5d353597cf76a60036f579ad 。非常欢迎讨论与分享，希望可以在 BigQuant 平台上做成策略。QQ：2260068285，希望可以认识更多大佬。

\

标签

交易数据