M.random_forest_predict

Definition

M.random_forest_predict.v2(self, model, data, date_col='date', instrument_col='instrument', sort=True)

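Predicts on a dataset with a trained random forest model. For a regression task the result is a numeric value; for a classification task the result is the predicted class together with the probability of each class.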

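Parameters:
  • model (string) – the trained model; see the random forest training module M.random_forest_train.
  • data (DataSource) – the dataset to predict on.
  • date_col (str) – name of the date column; used when an expression calls cross-sectional functions such as rank. Default: date.
  • instrument_col (str) – name of the instrument (security code) column; used when an expression calls time-series functions such as shift. Default: instrument.
  • sort (bool) – whether to sort the result. Default: True.
Returns:

  • .predictions: the prediction results

Return type:

Outputs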

Example code

In [1]:
# Basic configuration
class conf:
    start_date = '2009-01-01'
    end_date = '2017-06-21'
    # Data before split_date is used for training; data on or after it is used for evaluation
    split_date = '2015-01-01'
    # D.instruments: https://bigquant.com/docs/data_instruments.html
    instruments = D.instruments(start_date, end_date)

    # Labeling expressions for the machine-learning target
    # The expressions below are equivalent to min(max(return * 30, -3), 3) + 3 (M.fast_auto_labeler then converts the result to an integer)
    # Note: max/min limit the score to the interval [-3, 3]; adding 3 makes it non-negative (StockRanker requires non-negative integer labels)
    label_expr = ['return* 30', 'where(label > {0}, {0}, where(label < -{0}, -{0}, label)) + {0}'.format(3)]
    # Holding period in days, used to compute the return value in label_expr
    hold_days = 10
    features = [
        'ta_sma_10_0/ta_sma_20_0',
        'ta_sma_20_0/ta_sma_30_0',
        'ta_sma_30_0/ta_sma_60_0',
        'ta_atr_14_0',
        'ta_atr_28_0',
        'ta_rsi_14_0',
        'ta_rsi_28_0',
    ]


# Label the data: score every row (sample); a higher score generally means a better outcome
m1 = M.fast_auto_labeler.v6(
    instruments=conf.instruments, start_date=conf.start_date, end_date=conf.end_date,
    label_expr=conf.label_expr, hold_days=conf.hold_days,
    benchmark='000300.SHA', sell_at='close', buy_at='open', is_regression=False)
# Compute the feature data
m2 = M.general_feature_extractor.v5(
    instruments=conf.instruments, start_date=conf.start_date, end_date=conf.end_date,
    features=conf.features)
# Data preprocessing: add the derived feature columns, drop rows with missing values, cast to float32 and clip.
# T.get_stock_ranker_default_transforms does this preprocessing for the StockRanker model; it is not used here (transforms=None)
m3 = M.add_columns.v1(data=m2.data, eval_list=conf.features)
m4 = M.transform.v2(
    data=m3.data, transforms=None,
    drop_null=True, astype='float32', except_columns=['date', 'instrument'],
    clip_lower=0, clip_upper=200000000)
# Join the labels with the feature data
m5 = M.join.v2(data1=m4.data, data2=m1.data, on=['date', 'instrument'], sort=True)

# Training dataset
m6_training = M.filter.v2(data=m5.data, expr='date < "%s"' % conf.split_date)
# Evaluation dataset
m6_evaluation = M.filter.v2(data=m5.data, expr='"%s" <= date' % conf.split_date)

m7 = M.random_forest_train.v1(training_ds=m6_training.data, features=conf.features, is_regression=False, n_jobs=4)
[2017-07-14 21:10:22.846254] INFO: bigquant: fast_auto_labeler.v6 start ..
[2017-07-14 21:10:22.851121] INFO: bigquant: hit cache
[Chart: column chart "label" (x-axis: label, series: count): 0: 826198, 1: 554623, 2: 699471, 3: 746240, 4: 589132, 5: 397243, 6: 686421]
[2017-07-14 21:10:22.866956] INFO: bigquant: fast_auto_labeler.v6 end [0.02073s].
[2017-07-14 21:10:22.876641] INFO: bigquant: general_feature_extractor.v5 start ..
[2017-07-14 21:10:22.937423] INFO: bigquant: hit cache
[2017-07-14 21:10:22.938567] INFO: bigquant: general_feature_extractor.v5 end [0.061939s].
[2017-07-14 21:10:22.944853] INFO: bigquant: add_columns.v1 start ..
[2017-07-14 21:10:22.947686] INFO: bigquant: hit cache
[2017-07-14 21:10:22.948540] INFO: bigquant: add_columns.v1 end [0.003695s].
[2017-07-14 21:10:22.955116] INFO: bigquant: transform.v2 start ..
[2017-07-14 21:10:22.956762] INFO: bigquant: hit cache
[2017-07-14 21:10:22.957499] INFO: bigquant: transform.v2 end [0.002387s].
[2017-07-14 21:10:22.963824] INFO: bigquant: join.v2 start ..
[2017-07-14 21:10:22.965771] INFO: bigquant: hit cache
[2017-07-14 21:10:22.966513] INFO: bigquant: join.v2 end [0.002689s].
[2017-07-14 21:10:22.972433] INFO: bigquant: filter.v2 start ..
[2017-07-14 21:10:23.036872] INFO: bigquant: hit cache
[2017-07-14 21:10:23.038329] INFO: bigquant: filter.v2 end [0.065882s].
[2017-07-14 21:10:23.043643] INFO: bigquant: filter.v2 start ..
[2017-07-14 21:10:23.045657] INFO: bigquant: hit cache
[2017-07-14 21:10:23.046459] INFO: bigquant: filter.v2 end [0.00282s].
[2017-07-14 21:10:23.559363] INFO: bigquant: random_forest_train.v1 start ..
[2017-07-14 21:10:23.637344] INFO: bigquant: hit cache
[2017-07-14 21:10:23.638511] INFO: bigquant: random_forest_train.v1 end [0.079171s].
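The arithmetic behind label_expr in the configuration above can be checked outside the platform. The following is a minimal sketch in plain NumPy, not the labeler's own implementation: the function name clip_and_shift and the input values are hypothetical, and the integer conversion that M.fast_auto_labeler applies afterwards is deliberately left out because its exact rounding rule is not stated here.

```python
import numpy as np

def clip_and_shift(raw, scale=30.0, bound=3.0):
    """Scale the raw values, clip them to [-bound, bound], then shift to [0, 2 * bound]."""
    scaled = np.asarray(raw, dtype=float) * scale
    return np.clip(scaled, -bound, bound) + bound

# Hypothetical raw values standing in for `return` in label_expr
sample = [-0.25, -0.05, 0.0, 0.04, 0.50]
scores = clip_and_shift(sample)
print(scores)
```

With scale 30 and bound 3, every score falls in [0, 6], which matches the seven label buckets (0 to 6) shown in the label chart of the output above.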
In [2]:
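import datetime

stock_num = 40
# 3. Main strategy functions
# Initialize the virtual account state; runs only on the first trading day
def initialize(context):
    # Set commissions: 0.03% on buys, 0.13% on sells, with a minimum of 5 yuan per order
    context.set_commission(PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    m8 = M.random_forest_predict.v1(model=context.options['model'], data=m6_evaluation.data)
    context.pred_df = m8.predictions.read_df()
    # Within each date, rank instruments by predicted label, highest first
    context.pred_df = context.pred_df.groupby('date').apply(lambda x: x.sort_values('pred_label', ascending=False))

# Trading logic; runs once per trading day
def handle_data(context, data):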
    today = data.current_dt
    today_str=str(today.date())

    equities = {e.symbol: p for e, p in context.portfolio.positions.items() if p.amount>0}

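    # Rebalance: sell holdings once rebalance_period days have elapsed
    for instrument in equities:
        # Suspended stocks cannot be sold; they are handled at the next rebalance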
        if data.can_trade(context.symbol(instrument)) and today-equities[instrument].last_sale_date>=datetime.timedelta(context.options['rebalance_period']):
            context.order_target_percent(context.symbol(instrument), 0)

    # Rebalance: buy new stocks
    if today_str not in context.pred_df.index:
        return
    instruments_to_buy = context.pred_df.loc[today_str].instrument
    if len(instruments_to_buy) == 0:
        return
    # Allocate capital equally across the stocks to buy
    weight = 1.0 / stock_num
    can_buy_num = stock_num - len(equities)
    for instrument in instruments_to_buy:
        if can_buy_num>0 and data.can_trade(context.symbol(instrument)) and instrument not in equities:
            context.order_target_percent(context.symbol(instrument), weight)
            can_buy_num -= 1

# 4. Strategy backtest: https://bigquant.com/docs/module_trade.html
m = M.trade.v1(
    instruments=conf.instruments,
    start_date=conf.split_date,
    end_date=conf.end_date,
    initialize=initialize,
    handle_data=handle_data,
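    # Buy orders are filled at the open price
    order_price_field_buy='open',
    # Sell orders are filled at the close price
    order_price_field_sell='close',
    capital_base=1000000,
    benchmark='000300.SHA',
    # Data passed to the backtest module: everything the backtest functions use should be passed in here and read via context.options, otherwise cache problems may occur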
    options={'rebalance_period': conf.hold_days, 'model':m7.model}
)
[2017-07-14 21:10:24.458770] INFO: bigquant: backtest.v6 start ..
[2017-07-14 21:10:52.309326] INFO: bigquant: random_forest_predict.v1 start ..
[2017-07-14 21:11:05.012010] INFO: bigquant: random_forest_predict.v1 end [12.702671s].
[2017-07-14 21:11:20.833604] INFO: Performance: Simulated 600 trading days out of 600.
[2017-07-14 21:11:20.834713] INFO: Performance: first open: 2015-01-05 14:30:00+00:00
[2017-07-14 21:11:20.835517] INFO: Performance: last close: 2017-06-21 19:00:00+00:00
[2017-07-14 21:11:23.369218] INFO: bigquant: backtest.v6 end [58.910415s].
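The per-day ranking that initialize applies to the predictions can be tried on a toy table. This is a rough sketch under stated assumptions: the column names date, instrument and pred_label follow the example above, while the instrument codes, the values and the helper top_n_per_day are made up for illustration. It uses sort_values followed by groupby().head() rather than the groupby().apply() call in the example, which keeps a flat index.

```python
import pandas as pd

# Toy predictions; values and instrument codes are invented for illustration
pred_df = pd.DataFrame({
    'date': pd.to_datetime(['2015-01-05'] * 3 + ['2015-01-06'] * 3),
    'instrument': ['A', 'B', 'C', 'A', 'B', 'C'],
    'pred_label': [2, 5, 3, 6, 1, 4],
})

def top_n_per_day(df, n):
    """Return the n rows with the highest pred_label for each date."""
    ranked = df.sort_values(['date', 'pred_label'], ascending=[True, False])
    return ranked.groupby('date').head(n)

top2 = top_n_per_day(pred_df, 2)
# One list of instruments per date, best prediction first
picks = top2.groupby('date')['instrument'].apply(list)
print(picks)
```

In the strategy above, the buy loop walks such a per-day list in order and stops once stock_num positions are filled, so the ordering decides which instruments are bought first.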