问答交流

如何在模拟交易时实现训练数据和预测数据跟随交易日期滚动?

由bqvpogtl创建,最终由bqvpogtl 被浏览 2 用户

from bigmodule import M

# <aistudiograph>

# @param(id="m7", name="initialize")
# 交易引擎:初始化函数,只执行一次
def m7_initialize_bigquant_run(context):
    import math
    import numpy as np

    from bigtrader.finance.commission import PerOrder

    # 系统已经设置了默认的交易手续费和滑点, 要修改手续费可使用如下函数
    context.set_commission(PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    # 预测数据, 通过 options 传入进来, 使用 read_df 函数, 加载到内存 (DataFrame)
    context.data2 = context.options['data'].read()
    # 设置买入的股票数量, 这里买入预测股票列表排名靠前的10只
    stock_count = 10
    # 每只的股票的权重, 如下的权重分配会使得靠前的股票分配多一点的资金, [0.339160, 0.213986, 0.169580, ..]
    context.options["hold_days"] = 5
# @param(id="m7", name="before_trading_start")
# 交易引擎:每个单位时间开盘前调用一次。
def m7_before_trading_start_bigquant_run(context, data):
    # 盘前处理,订阅行情等
    pass

# @param(id="m7", name="handle_tick")
# 交易引擎:tick数据处理函数,每个tick执行一次
def m7_handle_tick_bigquant_run(context, tick):
    pass

# @param(id="m7", name="handle_data")
def m7_handle_data_bigquant_run(context, data):
    import pandas as pd

    max_hold_days = 20
    max_per_industry = 10

    # 初始化或读取卖出限制和持仓追踪表
    if not hasattr(context, 'sell_restrictions'):
        context.sell_restrictions = {}  # {instrument: 卖出日期}
    if not hasattr(context, 'hold_records'):
        context.hold_records = {}  # {instrument: {'buy_price', 'buy_atr', 'max_price', 'hold_days', 'max_profit'}}

    current_date = data.current_dt.strftime("%Y-%m-%d")
    today_data = context.data[context.data["date"] == current_date]
    current_day_data2 = context.data2[context.data2["date"] == current_date]

    # 风控模块
    if not current_day_data2.empty and 'market_risk_indicator' in current_day_data2.columns:
        current_day_market_risk_indicator = current_day_data2['market_risk_indicator'].iloc[0]
    else:
        current_day_market_risk_indicator = 0

    current_hold_instruments = set(context.get_account_positions().keys())

    # 大盘风控,遇到风控信号则清仓
    if current_day_market_risk_indicator != 0:
        for ins in current_hold_instruments:
            context.order_target_percent(ins, 0)
        return

    # -- 卖出策略部分 --
    remove_ins_list = []
    for ins in current_hold_instruments:
        position = context.get_position(ins)
        if position is None:
            continue
        price = position.last_price
        cost = position.cost_price

        # 获取 today's 行业和 ATR(14)
        stock_info = today_data[today_data['instrument'] == ins]
        industry = stock_info['sw2021_level1'].iloc[0] if not stock_info.empty and "sw2021_level1" in stock_info.columns else None
        atr_14 = stock_info['atr_14'].iloc[0] if not stock_info.empty and "atr_14" in stock_info.columns else None
        if pd.isnull(atr_14) or atr_14 is None:
            # 若当天没有atr_14,回溯上一天(可选),否则用0跳过此止损
            atr_14 = 0

        # --- 追踪持仓最高价、持仓天数、ATR ---
        rec = context.hold_records.get(ins)
        if rec is None:
            # 首日初始化
            rec = {
                'buy_price': cost,
                'buy_atr': atr_14,
                'max_price': price,
                'hold_days': 1,
                'max_profit': (price - cost) / cost if cost > 0 else 0
            }
        else:
            rec['hold_days'] += 1
            if price > rec['max_price']:
                rec['max_price'] = price
            cur_profit = (price - cost) / cost if cost > 0 else 0
            if cur_profit > rec.get('max_profit', 0):
                rec['max_profit'] = cur_profit

        # 止盈1:历史最高点回撤5%
        if rec['max_price'] > 0 and (rec['max_price'] - price) / rec['max_price'] > 0.05:
            context.order_target_percent(ins, 0)
            context.sell_restrictions[ins] = data.current_dt.date()
            remove_ins_list.append(ins)
            continue
        # 止盈2:历史最大盈利超33%
        if rec['max_profit'] >= 0.33:
            context.order_target_percent(ins, 0)
            context.sell_restrictions[ins] = data.current_dt.date()
            remove_ins_list.append(ins)
            continue
        # 止损:价格<买入价-2*ATR(14)
        if atr_14 > 0 and price < rec['buy_price'] - 2 * rec['buy_atr']:
            context.order_target_percent(ins, 0)
            context.sell_restrictions[ins] = data.current_dt.date()
            remove_ins_list.append(ins)
            continue
        # 最大20日持仓
        if rec['hold_days'] >= max_hold_days:
            context.order_target_percent(ins, 0)
            context.sell_restrictions[ins] = data.current_dt.date()
            remove_ins_list.append(ins)
            continue

        context.hold_records[ins] = rec  # 更新/写回

    for ins in remove_ins_list:
        if ins in context.hold_records:
            del context.hold_records[ins]

    # 卖出记录只保留30天冷却期
    context.sell_restrictions = {ins: date for ins, date in context.sell_restrictions.items()
                                 if (data.current_dt.date() - date).days <= 30}

    # -- 非调仓日不买入 --
    if not context.rebalance_period.is_signal_date(data.current_dt.date()):
        return

    # 卖出不在今日目标的股票
    for instrument in sorted(current_hold_instruments - set(today_data["instrument"])):
        context.order_target_percent(instrument, 0)
        # 清理记录
        if instrument in context.hold_records:
            del context.hold_records[instrument]

    # 买入策略(排除卖出限制/冷却期并确保行业分散)
    # 统计现持行业分布
    industry_count = {}
    for ins in context.get_account_positions():
        stock_info = today_data[today_data['instrument'] == ins]
        industry = stock_info['sw2021_level1'].iloc[0] if not stock_info.empty and "sw2021_level1" in stock_info.columns else None
        industry_count[industry] = industry_count.get(industry, 0) + 1

    today_data = today_data.sort_values(['score'], ascending=False)  # 按score高优先
    # 买入
    max_buy_num = 30  # 按实际业务设定
    buy_num = 0
    for _, row in today_data.iterrows():
        ins = row['instrument']
        ind = row.get('sw2021_level1', None)
        if ins in current_hold_instruments:
            continue
        if ins in context.sell_restrictions:
            continue
        # 行业控制
        if industry_count.get(ind, 0) >= max_per_industry:
            continue
        # 下单
        pos = row.get('position', 1.0 / max_buy_num)
        context.order_target_percent(ins, pos)
        industry_count[ind] = industry_count.get(ind, 0) + 1
        # 初始化持仓追踪
        price = data.current(ins, "close")
        atr_14 = row.get('atr_14', None)
        if pd.isnull(atr_14) or atr_14 is None:
            atr_14 = data.current(ins, "m_ta_atr(high,low,close,14)")
            if pd.isnull(atr_14) or atr_14 is None:
                atr_14 = 0
        context.hold_records[ins] = {
            'buy_price': price,
            'buy_atr': atr_14,
            'max_price': price,
            'hold_days': 1,
            'max_profit': 0
        }
        buy_num += 1
        if buy_num >= max_buy_num:
            break
# @param(id="m7", name="handle_trade")
# 交易引擎:成交回报处理函数,每个成交发生时执行一次
def m7_handle_trade_bigquant_run(context, trade):
    pass

# @param(id="m7", name="handle_order")
# 交易引擎:委托回报处理函数,每个委托变化时执行一次
def m7_handle_order_bigquant_run(context, order):
    pass

# @param(id="m7", name="after_trading")
# 交易引擎:盘后处理函数,每日盘后执行一次
def m7_after_trading_bigquant_run(context, data):
    pass

# @module(position="-96,-281", comment="""""", comment_collapsed=True)
m8 = M.cn_stock_basic_selector.v8(
    exchanges=["""上交所""", """深交所"""],
    list_sectors=["""主板""", """创业板""", """科创板"""],
    indexes=["""中证500""", """中证A500""", """上证指数""", """创业板指""", """深证成指""", """上证50""", """沪深300""", """中证1000""", """中证100""", """深证100"""],
    st_statuses=["""正常"""],
    margin_tradings=["""两融标的""", """非两融标的"""],
    sw2021_industries=["""农林牧渔""", """采掘""", """基础化工""", """钢铁""", """有色金属""", """建筑建材""", """机械设备""", """电子""", """汽车""", """交运设备""", """信息设备""", """家用电器""", """食品饮料""", """纺织服饰""", """轻工制造""", """医药生物""", """公用事业""", """交通运输""", """房地产""", """金融服务""", """商贸零售""", """社会服务""", """信息服务""", """银行""", """非银金融""", """综合""", """建筑材料""", """建筑装饰""", """电力设备""", """国防军工""", """计算机""", """传媒""", """通信""", """煤炭""", """石油石化""", """环保""", """美容护理"""],
    drop_suspended=True,
    m_name="""m8"""
)

# @module(position="-103,-178", comment="""因子特征,用表达式构建因子""")
m1 = M.input_features_dai.v30(
    input_1=m8.data,
    mode="""表达式""",
    expr="""c_zscore(c_neutralize((3.0-(ROW_NUMBER() OVER (PARTITION BY date ORDER BY ps_ttm ASC)) / 1000.0), sw2021_level1, float_market_cap)) AS f10
-c_zscore(c_group_avg(sw2021_level1, (close / m_lag(close, 20) - 1))) AS f28
-c_zscore((m_sum(net_active_buy_amount_main, 20) / NULLIF(m_sum(abs(net_active_buy_amount_main), 20), 0)) * 6) AS f20
-c_zscore((net_profit_to_parent_shareholders_ttm / m_lag(net_profit_to_parent_shareholders_ttm, 60) - 1) * 5) AS f9
-c_zscore(m_nanstd(daily_return, 20) / m_nanstd(daily_return, 60)) AS f29
c_zscore(debt_to_asset_lf) AS f4
c_zscore(c_neutralize(dividend_yield_ratio, sw2021_level1, float_market_cap)) AS f15""",
    expr_filters="""-- DAI SQL 算子/函数: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-%E5%87%BD%E6%95%B0
-- 数据&字段: 数据文档 https://bigquant.com/data/home
-- 表达式模式的过滤都是放在 QUALIFY 里, 即数据查询、计算, 最后才到过滤条件
st_status = 0
list_days > 260
float_market_cap > 500000000 -- 流通市值大于5亿
net_profit_yoy_lf > 0 -- 净利润同比增长率(最新一期)
operating_revenue_yoy_lf > 0 -- 营业收入同比增长率(最新一期)
pe_ttm > 0 
pe_ttm < 50
pb < 10 -- 市净率
pe_ttm < 1.5*c_avg(pe_ttm) -- 股票的市盈率不高于市场平均值的1.5 倍
pb < 1.5*c_avg(pb) -- 股票的市净率不高于市场平均值的1.5 倍
-- c_pct_rank(-return_90) <= 0.3
-- c_pct_rank(return_30) <= 0.3
-- cn_stock_bar1d.turn > 0.02
""",
    expr_tables="""cn_stock_prefactors;cn_stock_money_flow;cn_stock_factors_alpha_101""",
    extra_fields="""date, instrument""",
    order_by="""date, instrument""",
    expr_drop_na=True,
    sql="""-- 使用DAI SQL获取数据, 构建因子等, 如下是一个例子作为参考
-- DAI SQL 语法: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-sql%E5%85%A5%E9%97%A8%E6%95%99%E7%A8%8B
-- 使用数据输入1/2/3里的字段: e.g. input_1.close, input_1.* EXCLUDE(date, instrument)

SELECT
    -- 在这里输入因子表达式
    -- DAI SQL 算子/函数: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-%E5%87%BD%E6%95%B0
    -- 数据&字段: 数据文档 https://bigquant.com/data/home
    sw2021_level1,                  -- 申万一级行业代码,作为行业分散依据
    m_ta_atr(high, low, close, 14) AS atr_14,
    m_lag(close, 90) / close AS return_90,
    m_lag(close, 30) / close AS return_30,
    -- 下划线开始命名的列是中间变量, 不会在最终结果输出 (e.g. _rank_return_90)
    c_pct_rank(-return_90) AS _rank_return_90,
    c_pct_rank(return_30) AS _rank_return_30,

    c_rank(volume) AS rank_volume,
    close / m_lag(close, 1) as return_0,

    -- 日期和股票代码
    date, instrument
FROM
    -- 预计算因子 cn_stock_bar1d https://bigquant.com/data/datasources/cn_stock_bar1d
    cn_stock_prefactors
    -- SQL 模式不会自动join输入数据源, 可以根据需要自由灵活的使用
    -- JOIN input_1 USING(date, instrument)
WHERE
    -- WHERE 过滤, 在窗口等计算算子之前执行
    -- 剔除ST股票
    st_status = 0
QUALIFY
    -- QUALIFY 过滤, 在窗口等计算算子之后执行, 比如 m_lag(close, 3) AS close_3, 对于 close_3 的过滤需要放到这里
    -- 去掉有空值的行
    COLUMNS(*) IS NOT NULL
    -- _rank_return_90 是窗口函数结果,需要放在 QUALIFY 里
    AND _rank_return_90 > 0.1
    AND _rank_return_30 < 0.1
-- 按日期和股票代码排序, 从小到大
ORDER BY date, instrument
""",
    extract_data=False,
    m_name="""m1"""
)

# @module(position="-253,-35", comment="""加数据标注""")
m2 = M.input_features_dai.v30(
    input_1=m1.data,
    mode="""表达式""",
    expr="""-- DAI SQL 算子/函数: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-%E5%87%BD%E6%95%B0
-- 数据&字段: 数据文档 https://bigquant.com/data/home
-- 数据使用: 表名.字段名, 对于没有指定表名的列, 会从 expr_tables 推断, 如果同名字段在多个表中出现, 需要显式的给出表名

input_1.* EXCLUDE(date, instrument)
m_lead(close, 5) / m_lead(open, 1) AS _future_return
c_quantile_cont(_future_return, 0.01) AS _future_return_1pct
c_quantile_cont(_future_return, 0.99) AS _future_return_99pct
clip(_future_return, _future_return_1pct, _future_return_99pct) AS _clipped_return
c_cbins(_clipped_return, 20) AS label

-- cn_stock_bar1d.close / cn_stock_bar1d.open
-- cn_stock_prefactors https://bigquant.com/data/datasources/cn_stock_prefactors 是常用因子表(VIEW), JOIN了很多数据表, 性能会比直接用相关表慢一点, 但使用简单
-- cn_stock_prefactors.pe_ttm

-- 表达式模式下, 会自动join输入数据1/2/3, 可以在表达式里直接使用其字段。包括 input_1 的所有列但去掉 date, instrument。注意字段不能有重复的, 否则会报错
-- input_1.* EXCLUDE(date, instrument)
-- input_1.close
-- input_2.close / input_1.close
""",
    expr_filters="""-- DAI SQL 算子/函数: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-%E5%87%BD%E6%95%B0
-- 数据&字段: 数据文档 https://bigquant.com/data/home
-- 表达式模式的过滤都是放在 QUALIFY 里, 即数据查询、计算, 最后才到过滤条件

-- c_pct_rank(-return_90) <= 0.3
-- c_pct_rank(return_30) <= 0.3
-- cn_stock_bar1d.turn > 0.02
""",
    expr_tables="""cn_stock_prefactors""",
    extra_fields="""date, instrument""",
    order_by="""date, instrument""",
    expr_drop_na=True,
    sql="""-- 使用DAI SQL获取数据, 构建因子等, 如下是一个例子作为参考
-- DAI SQL 语法: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-sql%E5%85%A5%E9%97%A8%E6%95%99%E7%A8%8B
-- 使用数据输入1/2/3里的字段: e.g. input_1.close, input_1.* EXCLUDE(date, instrument)

SELECT
    -- 在这里输入因子表达式
    -- DAI SQL 算子/函数: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-%E5%87%BD%E6%95%B0
    -- 数据&字段: 数据文档 https://bigquant.com/data/home

    m_lag(close, 90) / close AS return_90,
    m_lag(close, 30) / close AS return_30,
    -- 下划线开始命名的列是中间变量, 不会在最终结果输出 (e.g. _rank_return_90)
    c_pct_rank(-return_90) AS _rank_return_90,
    c_pct_rank(return_30) AS _rank_return_30,

    c_rank(volume) AS rank_volume,
    close / m_lag(close, 1) as return_0,

    -- 日期和股票代码
    date, instrument
FROM
    -- 预计算因子 cn_stock_bar1d https://bigquant.com/data/datasources/cn_stock_bar1d
    cn_stock_prefactors
    -- SQL 模式不会自动join输入数据源, 可以根据需要自由灵活的使用
    -- JOIN input_1 USING(date, instrument)
WHERE
    -- WHERE 过滤, 在窗口等计算算子之前执行
    -- 剔除ST股票
    st_status = 0
QUALIFY
    -- QUALIFY 过滤, 在窗口等计算算子之后执行, 比如 m_lag(close, 3) AS close_3, 对于 close_3 的过滤需要放到这里
    -- 去掉有空值的行
    COLUMNS(*) IS NOT NULL
    -- _rank_return_90 是窗口函数结果,需要放在 QUALIFY 里
    AND _rank_return_90 > 0.1
    AND _rank_return_30 < 0.1
-- 按日期和股票代码排序, 从小到大
ORDER BY date, instrument
""",
    extract_data=False,
    m_name="""m2"""
)

# @module(position="-254,76", comment="""抽取训练数据""")
m3 = M.extract_data_dai.v19(
    sql=m2.data,
    start_date="""2020-1-1""",
    start_date_bound_to_trading_date=True,
    end_date="""2023-1-1""",
    end_date_bound_to_trading_date=True,
    before_start_days=360,
    keep_before=False,
    debug=False,
    m_name="""m3"""
)

# @module(position="106,-20", comment="""抽取预测数据""")
m4 = M.extract_data_dai.v19(
    sql=m1.data,
    start_date="""2023-01-01""",
    start_date_bound_to_trading_date=True,
    end_date="""2025-5-30""",
    end_date_bound_to_trading_date=True,
    before_start_days=360,
    keep_before=False,
    debug=False,
    m_name="""m4"""
)

# @module(position="-124,209", comment="""模型训练""")
m5 = M.stockranker.v9(
    train_data=m3.data,
    predict_data=m4.data,
    learning_algorithm="""排序""",
    number_of_leaves=30,
    min_docs_per_leaf=1000,
    number_of_trees=20,
    learning_rate=0.2,
    max_bins=1023,
    feature_fraction=1,
    data_row_fraction=1,
    sort_by="""date,instrument""",
    plot_charts=True,
    ndcg_discount_base=1,
    m_name="""m5"""
)

# @module(position="-80,331", comment="""等权分配""")
m6 = M.score_to_position.v4(
    input_1=m5.predictions,
    score_field="""score DESC""",
    hold_count=6,
    position_expr="""-- DAI SQL 算子/函数: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-%E5%87%BD%E6%95%B0
-- 在这里输入表达式, 每行一个表达式, 输出仓位字段必须命名为 position, 模块会进一步做归一化
-- 排序倒数: 1 / score_rank AS position
-- 对数下降: 1 / log2(score_rank + 1) AS position
-- TODO 拟合、最优化 ..

-- 等权重分配
1 AS position
""",
    total_position=1,
    extract_data=True,
    m_name="""m6"""
)

# @module(position="353.44732666015625,156.6826171875", comment="""准备大盘风控数据""", comment_collapsed=True)
m9 = M.input_features_dai.v30(
    mode="""表达式""",
    expr="""-- 设置一个风控指标,如果大盘连续5天跌超0.03,风控指标则为1,否则为0
IF(close / m_lag(close, 5) - 1 < -0.03, 1, 0) AS market_risk_indicator

 """,
    expr_filters="""-- 指数选择上证指数
instrument = '000001.SH'""",
    expr_tables="""cn_stock_index_bar1d""",
    extra_fields="""date""",
    order_by="""date""",
    expr_drop_na=True,
    sql="""-- 使用DAI SQL获取数据, 构建因子等, 如下是一个例子作为参考
-- DAI SQL 语法: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-sql%E5%85%A5%E9%97%A8%E6%95%99%E7%A8%8B
-- 使用数据输入1/2/3里的字段: e.g. input_1.close, input_1.* EXCLUDE(date, instrument)

SELECT
    -- 在这里输入因子表达式
    -- DAI SQL 算子/函数: https://bigquant.com/wiki/doc/dai-PLSbc1SbZX#h-%E5%87%BD%E6%95%B0
    -- 数据&字段: 数据文档 https://bigquant.com/data/home

    m_lag(close, 90) / close AS return_90,
    m_lag(close, 30) / close AS return_30,
    -- 下划线开始命名的列是中间变量, 不会在最终结果输出 (e.g. _rank_return_90)
    c_pct_rank(-return_90) AS _rank_return_90,
    c_pct_rank(return_30) AS _rank_return_30,

    c_rank(volume) AS rank_volume,
    close / m_lag(close, 1) as return_0,

    -- 日期和股票代码
    date, instrument
FROM
    -- 预计算因子 cn_stock_bar1d https://bigquant.com/data/datasources/cn_stock_bar1d
    cn_stock_factors
    -- SQL 模式不会自动join输入数据源, 可以根据需要自由灵活的使用
    -- JOIN input_1 USING(date, instrument)
WHERE
    -- WHERE 过滤, 在窗口等计算算子之前执行
    -- 剔除ST股票
    st_status = 0
QUALIFY
    -- QUALIFY 过滤, 在窗口等计算算子之后执行, 比如 m_lag(close, 3) AS close_3, 对于 close_3 的过滤需要放到这里
    -- 去掉有空值的行
    COLUMNS(*) IS NOT NULL
    -- _rank_return_90 是窗口函数结果,需要放在 QUALIFY 里
    AND _rank_return_90 > 0.1
    AND _rank_return_30 < 0.1
-- 按日期和股票代码排序, 从小到大
ORDER BY date, instrument
""",
    extract_data=False,
    m_name="""m9"""
)

# @module(position="309,264", comment="""抽取预测数据""", comment_collapsed=True)
m10 = M.extract_data_dai.v19(
    sql=m9.data,
    start_date="""2022-01-01""",
    start_date_bound_to_trading_date=True,
    end_date="""2025-5-30""",
    end_date_bound_to_trading_date=True,
    before_start_days=360,
    keep_before=False,
    debug=False,
    m_name="""m10"""
)

# @module(position="-101,449", comment="""交易,日线,设置初始化函数和K线处理函数,以及初始化资金、基准等""")
m7 = M.bigtrader.v38(
    data=m6.data,
    options_data=m10.data,
    start_date="""2025-1-1""",
    end_date="""2025-5-30""",
    initialize=m7_initialize_bigquant_run,
    before_trading_start=m7_before_trading_start_bigquant_run,
    handle_tick=m7_handle_tick_bigquant_run,
    handle_data=m7_handle_data_bigquant_run,
    handle_trade=m7_handle_trade_bigquant_run,
    handle_order=m7_handle_order_bigquant_run,
    after_trading=m7_after_trading_bigquant_run,
    capital_base=200000,
    frequency="""daily""",
    product_type="""股票""",
    rebalance_period_type="""交易日""",
    rebalance_period_days="""5""",
    rebalance_period_roll_forward=True,
    backtest_engine_mode="""标准模式""",
    before_start_days=0,
    volume_limit=0.25,
    order_price_field_buy="""open""",
    order_price_field_sell="""open""",
    benchmark="""沪深300指数""",
    plot_charts=True,
    debug=False,
    backtest_only=False,
    m_name="""m7"""
)
# </aistudiograph>

训练时间段和预测时间段是需要分开的,在策略模板中训练时间和预测时间段需要直接选定,这样回测是没问题的。但是在模拟交易中随着交易时间的发展,训练集并不会变化,如何才能让训练集和预测集随交易时间变化而继续向前滚动?

{link}