M.linear_sgd_predict

定义

M.linear_sgd_predict.v2(self, model, data)

使用线性随机梯度下降模型对数据进行预测。如果是回归任务,结果是数值;如果是分类任务,结果是类别。

参数:

  • model: 训练好的线性随机梯度下降模型,例如示例中 M.linear_sgd_train 输出的 .model
  • data: 待预测的数据,例如示例中评估数据集的 .data

返回:

  • .predictions: 预测结果
  • .start_date: 预测数据开始日期
  • .end_date: 预测数据结束日期
  • .instruments: 预测数据里的所有证券代码

返回类型:

Outputs

示例代码

In [2]:
# Basic parameter configuration shared by the whole example.
class conf:
    # Date range of the data used by the example.
    start_date = '2009-01-01'
    end_date='2017-06-21'
    # Data before split_date is used for training; data from split_date on is used for evaluation.
    split_date = '2015-01-01'
    # D.instruments: https://bigquant.com/docs/data_instruments.html
    instruments = D.instruments(start_date, end_date)

    # Labeling expressions for the machine-learning target.
    # With the values below they are equivalent to
    #     min(max(return * 30, -1), 1) + 1
    # i.e. the scaled holding-period return is clipped to [-1, 1] and then shifted by +1
    # so the score is non-negative and lies in [0, 2].
    # Per the original note, M.fast_auto_labeler rounds the score to an integer afterwards.
    label_expr = ['return* 30', 'where(label > {0}, {0}, where(label < -{0}, -{0}, label)) + {0}'.format(1)]
    # Holding period in days, used to compute the `return` value in label_expr.
    hold_days = 10
    # Feature expressions: ratios of simple moving averages plus ATR and RSI indicators
    # (meaning inferred from the ta_* names; confirm against the platform's feature list).
    features = [
        'ta_sma_10_0/ta_sma_20_0',
        'ta_sma_20_0/ta_sma_30_0',
        'ta_sma_30_0/ta_sma_60_0',
        'ta_atr_14_0',
        'ta_atr_28_0',
        'ta_rsi_14_0',
        'ta_rsi_28_0',
    ]


# Label the data: score every row (sample); generally a higher score is better.
m1 = M.fast_auto_labeler.v6(
    instruments=conf.instruments, start_date=conf.start_date, end_date=conf.end_date,
    label_expr=conf.label_expr, hold_days=conf.hold_days,
    benchmark='000300.SHA', sell_at='close', buy_at='open', is_regression=False)
# Compute the feature data.
m2 = M.general_feature_extractor.v5(
    instruments=conf.instruments, start_date=conf.start_date, end_date=conf.end_date,
    features=conf.features)
# Preprocessing: add_columns is given the feature expressions as eval_list (presumably it
# evaluates each one into a column -- confirm against the module docs). M.transform is then
# called with no transform list (transforms=None): it drops nulls, casts to float32 except
# for 'date' and 'instrument', and is given clip bounds of 0 and 200000000.
m3=M.add_columns.v1(data=m2.data, eval_list=conf.features)
m4 = M.transform.v2(
    data=m3.data, transforms=None,
    drop_null=True, astype='float32', except_columns=['date', 'instrument'],
    clip_lower=0, clip_upper=200000000)
# Join the feature data with the labels on (date, instrument).
m5 = M.join.v2(data1=m4.data, data2=m1.data, on=['date', 'instrument'], sort=True)

# Training set: rows dated before split_date.
m6_training = M.filter.v2(data=m5.data, expr='date < "%s"' % conf.split_date)
# Evaluation set: rows dated on or after split_date.
m6_evaluation = M.filter.v2(data=m5.data, expr='"%s" <= date' % conf.split_date)

# Train the linear SGD model on the training set as a classification task (is_regression=False).
m7 = M.linear_sgd_train.v1(training_ds=m6_training.data, features=conf.features, is_regression=False)
[2017-06-23 20:05:56.535378] INFO: bigquant: fast_auto_labeler.v6 start ..
[2017-06-23 20:05:56.539597] INFO: bigquant: hit cache
bigcharts-data-start/{"title":{"text":"label"},"xAxis":{"title":{"text":"label"}},"chart":{"renderTo":"bigchart-5bc59ceac9e04b55b52f529917e060fa","type":"column","height":400},"stock":false,"legend":{"enabled":true},"series":[{"yAxis":0,"name":"count","data":[[0,2080292],[1,746240],[2,1672796]]}]}/bigcharts-data-end
[2017-06-23 20:05:56.552951] INFO: bigquant: fast_auto_labeler.v6 end [0.017601s].
[2017-06-23 20:05:56.561282] INFO: bigquant: general_feature_extractor.v5 start ..
[2017-06-23 20:05:56.563612] INFO: bigquant: hit cache
[2017-06-23 20:05:56.566138] INFO: bigquant: general_feature_extractor.v5 end [0.004859s].
[2017-06-23 20:05:56.570499] INFO: bigquant: add_columns.v1 start ..
[2017-06-23 20:05:56.573136] INFO: bigquant: hit cache
[2017-06-23 20:05:56.574253] INFO: bigquant: add_columns.v1 end [0.003752s].
[2017-06-23 20:05:56.578815] INFO: bigquant: transform.v2 start ..
[2017-06-23 20:05:56.580711] INFO: bigquant: hit cache
[2017-06-23 20:05:56.581536] INFO: bigquant: transform.v2 end [0.002705s].
[2017-06-23 20:05:56.586636] INFO: bigquant: join.v2 start ..
[2017-06-23 20:05:56.588247] INFO: bigquant: hit cache
[2017-06-23 20:05:56.588948] INFO: bigquant: join.v2 end [0.002311s].
[2017-06-23 20:05:56.592538] INFO: bigquant: filter.v2 start ..
[2017-06-23 20:05:56.594089] INFO: bigquant: hit cache
[2017-06-23 20:05:56.595005] INFO: bigquant: filter.v2 end [0.002464s].
[2017-06-23 20:05:56.597819] INFO: bigquant: filter.v2 start ..
[2017-06-23 20:05:56.600107] INFO: bigquant: hit cache
[2017-06-23 20:05:56.601036] INFO: bigquant: filter.v2 end [0.003219s].
[2017-06-23 20:05:56.890670] INFO: bigquant: linear_sgd_train.v1 start ..
[2017-06-23 20:05:56.893002] INFO: bigquant: hit cache
[2017-06-23 20:05:56.893766] INFO: bigquant: linear_sgd_train.v1 end [0.003125s].
In [3]:
# Number of portfolio slots: handle_data buys each stock at a 1/stock_num weight
# and stops buying once stock_num positions are held.
stock_num=40
# 3. Strategy functions
# Set up the simulated account; runs only on the first trading day.
def initialize(context):
    """Configure commissions and precompute the ranked predictions.

    Runs the trained model (``context.options['model']``) on the evaluation
    data set and stores the result on ``context.pred_df``, grouped by date
    with each date's rows ordered by ``pred_label`` from highest to lowest.
    """
    # Fees: 0.03% on buys, 0.13% on sells, at least 5 yuan per order.
    fee_model = PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5)
    context.set_commission(fee_model)

    def rank_by_score(day_rows):
        # Best-scored rows first within a single date.
        return day_rows.sort_values('pred_label', ascending=False)

    prediction = M.linear_sgd_predict.v1(
        model=context.options['model'],
        data=m6_evaluation.data,
    )
    scores = prediction.predictions.read_df()
    context.pred_df = scores.groupby('date').apply(rank_by_score)

# Trading logic; runs once per trading day.
def handle_data(context,data):
    """Rebalance the portfolio for the current trading day.

    First places exit orders for held positions that have reached the
    rebalance period, then fills the remaining free slots (up to
    ``stock_num`` positions in total) from today's ranked predictions,
    giving every new position an equal ``1 / stock_num`` weight.

    Args:
        context: backtest context; reads ``portfolio.positions``,
            ``options['rebalance_period']`` and ``pred_df`` (set in
            ``initialize``), and places orders via ``order_target_percent``.
        data: bar data for the current day; provides ``current_dt`` and
            ``can_trade``.
    """
    today = data.current_dt
    today_str = str(today.date())

    # Positions currently held (amount > 0), keyed by symbol.
    equities = {e.symbol: p for e, p in context.portfolio.positions.items() if p.amount > 0}

    # Sell: exit every tradable position whose last_sale_date is at least
    # rebalance_period days in the past. The threshold is loop-invariant.
    hold_period = datetime.timedelta(days=context.options['rebalance_period'])
    for instrument, position in equities.items():
        asset = context.symbol(instrument)
        # A suspended stock cannot be sold today; it is retried on a later run.
        if data.can_trade(asset) and today - position.last_sale_date >= hold_period:
            context.order_target_percent(asset, 0)

    # Buy: nothing to do when there are no predictions for today.
    if today_str not in context.pred_df.index:
        return
    # NOTE(review): .ix is deprecated since pandas 0.20 and removed in 1.0; switch to
    # .loc once the platform's pandas version and index behavior are confirmed.
    instruments_to_buy = context.pred_df.ix[today_str].instrument
    if len(instruments_to_buy) == 0:
        return

    # Equal-weight allocation across all slots.
    weight = 1.0 / stock_num
    # Positions just ordered to sell still occupy their slot until they are filled.
    can_buy_num = stock_num - len(equities)
    for instrument in instruments_to_buy:
        if can_buy_num <= 0:
            # All slots are filled: stop instead of scanning the rest of the ranking.
            break
        if instrument in equities:
            continue
        asset = context.symbol(instrument)
        if data.can_trade(asset):
            context.order_target_percent(asset, weight)
            can_buy_num -= 1

# 4. Backtest the strategy: https://bigquant.com/docs/module_trade.html
m = M.trade.v1(
    instruments=conf.instruments,
    start_date=conf.split_date,
    end_date=conf.end_date,
    initialize=initialize,
    handle_data=handle_data,
    # Buy orders are filled at the opening price.
    order_price_field_buy='open',
    # Sell orders are filled at the closing price.
    order_price_field_sell='close',
    capital_base=1000000,
    benchmark='000300.SHA',
    # Data handed to the backtest module: everything the backtest callbacks use should be
    # passed in here and read through context.options, otherwise caching problems may occur.
    # NOTE(review): initialize reads m6_evaluation and handle_data reads stock_num as
    # globals rather than through options -- consider passing them here as well.
    options={'rebalance_period': conf.hold_days, 'model':m7.model},
)
[2017-06-23 20:05:56.987479] INFO: bigquant: backtest.v6 start ..
[2017-06-23 20:05:56.989636] INFO: bigquant: hit cache
bigcharts-data-start/{"__id":"bigchart-56db0fc701e645e9a90e7b1aa39318fd","__type":"tabs"}/bigcharts-data-end
[2017-06-23 20:05:58.547311] INFO: bigquant: backtest.v6 end [1.559814s].