Answer: there is no single best algorithm, only a better fit for the task at hand. The answer depends on many factors, such as market conditions, the quality of the dataset, and the effectiveness of feature engineering. Let's look at the strengths and weaknesses of these algorithms.
Neural networks: suited to complex nonlinear problems; they can effectively capture the market's nonlinear characteristics and complex relationships.
Decision trees: suited to smaller datasets with fewer features; the model's decision process is easy to interpret.
Random forests: suited to high-dimensional, complex datasets; robust and accurate.
Support vector machines: suited to smaller datasets with high feature dimensionality; they handle both linearly separable and nonlinear problems effectively.
That said, deep learning is generally more likely than classical machine learning to deliver better returns, for the reasons below (see also the sketch after this list):
Deep learning fits nonlinear relationships better, and in stock selection nonlinear models are closer to reality.
Deep learning handles more complex data structures: for stock data it can better mine and process time series, natural-language and image data.
Deep learning supports end-to-end learning: it can learn directly from raw data without manual feature engineering, which helps uncover the latent information in the data.
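To make the comparison concrete, here is a minimal scikit-learn sketch that cross-validates several of these model families side by side. It runs on a synthetic feature matrix as a stand-in for a real factor dataset, so the numbers are purely illustrative and none of the names below come from the strategy code later in this post.

from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.neural_network import MLPRegressor
from sklearn.datasets import make_regression

# Synthetic stand-in for a 13-factor matrix with a forward-return label
X, y = make_regression(n_samples=2000, n_features=13, noise=0.5, random_state=0)

models = [
    ('DecisionTree', DecisionTreeRegressor(max_depth=5, random_state=0)),
    ('RandomForest', RandomForestRegressor(n_estimators=100, random_state=0)),
    ('SVR', SVR(kernel='rbf')),
    ('MLP', MLPRegressor(hidden_layer_sizes=(64, 32), max_iter=500, random_state=0)),
]
cv = KFold(n_splits=5, shuffle=True, random_state=0)
for name, model in models:
    scores = cross_val_score(model, X, y, scoring='r2', cv=cv)
    print(f'{name}: mean R2 = {scores.mean():.3f}, std = {scores.std():.3f}')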
Backtest setup — training set: 2014-01-01 to 2018-01-14; test set: 2018-01-15 to 2019-01-10; daily rebalancing, rotating a single stock each day at half position. The returns of the various algorithms are shown in the figure.
# This code was auto-generated by the visual strategy environment on 2023-04-07 13:29
# This code cell can only be edited in visual mode. You can also copy the code into a new code cell or strategy and modify it there.
# Backtest engine: initialization function, executed only once
def m5_initialize_bigquant_run(context):
    # Load the prediction data (passed in via options and read into memory as a DataFrame)
    context.ranker_prediction = context.options['data'].read_df()
    # The system presets default commission and slippage; use the call below to override the commission
    context.set_commission(PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    # Number of stocks to buy: here only the top-ranked stock in the prediction list is held
    stock_count = 1
    # Per-stock weights; with a single stock it receives the full allocation
    context.stock_weights = [1]
    # Maximum fraction of capital a single stock may occupy
    context.max_cash_per_instrument = 1
    context.options['hold_days'] = 1
# Backtest engine: daily data-handling function, executed once per day
def m5_handle_data_bigquant_run(context, data):
    # Filter the predictions down to today's date
    ranker_prediction = context.ranker_prediction[
        context.ranker_prediction.date == data.current_dt.strftime('%Y-%m-%d')]
    # Budget for buying: the lesser of total portfolio value and available cash
    cash_for_buy = min(context.portfolio.portfolio_value/1, context.portfolio.cash)
    # Alternative budgets kept for reference:
    #cash_for_buy = context.portfolio.portfolio_value
    #cash_for_buy = context.portfolio.cash
    #print(ranker_prediction)
    buy_instruments = list(ranker_prediction.instrument)
    sell_instruments = [instrument.symbol for instrument in context.portfolio.positions.keys()]
    # Rotate into today's top-ranked stock: sell anything no longer ranked first, buy the new leader
    to_buy = set(buy_instruments[:1]) - set(sell_instruments)
    to_sell = set(sell_instruments) - set(buy_instruments[:1])
    for instrument in to_sell:
        context.order_target(context.symbol(instrument), 0)
    for instrument in to_buy:
        context.order_value(context.symbol(instrument), cash_for_buy)
def m5_prepare_bigquant_run(context):
    # Fetch ST status and limit-up/limit-down status
    context.status_df = D.features(instruments=context.instruments, start_date=context.start_date, end_date=context.end_date,
        fields=['st_status_0', 'price_limit_status_0', 'price_limit_status_1'])
def m5_before_trading_start_bigquant_run(context, data):
    pass
    # Fetch limit-up/limit-down status data
    # df_price_limit_status = context.status_df.set_index('date')
    # today = data.current_dt.strftime('%Y-%m-%d')
    # # Get the currently unfilled orders
    # for orders in get_open_orders().values():
    #     # Loop over them and cancel where needed
    #     for _order in orders:
    #         ins = str(_order.sid.symbol)
    #         try:
    #             # If the stock is limit-up today, cancel the sell order
    #             if df_price_limit_status[df_price_limit_status.instrument == ins].price_limit_status_0.loc[today] > 2 and _order.amount < 0:
    #                 cancel_order(_order)
    #                 print(today, 'limit-up at close, sell order cancelled', ins)
    #         except:
    #             continue
m1 = M.instruments.v2(
start_date='2014-01-01',
end_date='2018-01-14',
market='CN_STOCK_A',
instrument_list='',
max_count=0
)
m2 = M.advanced_auto_labeler.v2(
instruments=m1.data,
label_expr="""# #号开始的表示注释
# 0. 每行一个,顺序执行,从第二个开始,可以使用label字段
# 1. 可用数据字段见 https://bigquant.com/docs/develop/datasource/deprecated/history_data.html
# 添加benchmark_前缀,可使用对应的benchmark数据
# 2. 可用操作符和函数见 `表达式引擎 <https://bigquant.com/docs/develop/bigexpr/usage.html>`_
# 计算收益:5日收盘价(作为卖出价格)除以明日开盘价(作为买入价格)
shift(close, -2) / shift(open, -1)
# 极值处理:用1%和99%分位的值做clip
clip(label, all_quantile(label, 0.01), all_quantile(label, 0.99))
# 将分数映射到分类,这里使用20个分类
all_wbins(label, 20)
# 过滤掉一字涨停的情况 (设置label为NaN,在后续处理和训练中会忽略NaN的label)
where(shift(high, -1) == shift(low, -1), NaN, label)
""",
start_date='',
end_date='',
benchmark='000300.HIX',
drop_na_label=True,
cast_label_int=True
)
m3 = M.input_features.v1(
features="""# #号开始的表示注释
# 多个特征,每行一个,可以包含基础特征和衍生特征
return_5
return_10
return_20
avg_amount_0/avg_amount_5
avg_amount_5/avg_amount_20
rank_avg_amount_0/rank_avg_amount_5
rank_avg_amount_5/rank_avg_amount_10
rank_return_0
rank_return_5
rank_return_10
rank_return_0/rank_return_5
rank_return_5/rank_return_10
pe_ttm_0
#Main-force net inflow amount (disabled)
#mf_net_amount_main_0
"""
)
m15 = M.general_feature_extractor.v7(
instruments=m1.data,
features=m3.data,
start_date='',
end_date='',
before_start_days=90
)
m16 = M.derived_feature_extractor.v3(
input_data=m15.data,
features=m3.data,
date_col='date',
instrument_col='instrument',
drop_na=False,
remove_extra_columns=False
)
m7 = M.join.v3(
data1=m2.data,
data2=m16.data,
on='date,instrument',
how='inner',
sort=False
)
m13 = M.dropnan.v1(
input_data=m7.data
)
m6 = M.stock_ranker_train.v6(
training_ds=m13.data,
features=m3.data,
learning_algorithm='排序',
number_of_leaves=30,
minimum_docs_per_leaf=1000,
number_of_trees=20,
learning_rate=0.1,
max_bins=1023,
feature_fraction=1,
data_row_fraction=1,
plot_charts=True,
ndcg_discount_base=1,
m_lazy_run=False
)
m9 = M.instruments.v2(
start_date=T.live_run_param('trading_date', '2018-01-15'),
end_date=T.live_run_param('trading_date', '2019-01-10'),
market='CN_STOCK_A',
instrument_list='',
max_count=0
)
m17 = M.general_feature_extractor.v7(
instruments=m9.data,
features=m3.data,
start_date='',
end_date='',
before_start_days=90
)
m18 = M.derived_feature_extractor.v3(
input_data=m17.data,
features=m3.data,
date_col='date',
instrument_col='instrument',
drop_na=False,
remove_extra_columns=False
)
m14 = M.dropnan.v1(
input_data=m18.data
)
m8 = M.stock_ranker_predict.v5(
model=m6.model,
data=m14.data,
m_lazy_run=False
)
m5 = M.trade.v4(
instruments=m9.data,
options_data=m8.predictions,
start_date='',
end_date='',
initialize=m5_initialize_bigquant_run,
handle_data=m5_handle_data_bigquant_run,
prepare=m5_prepare_bigquant_run,
before_trading_start=m5_before_trading_start_bigquant_run,
volume_limit=0,
order_price_field_buy='open',
order_price_field_sell='close',
capital_base=100000,
auto_cancel_non_tradable_orders=True,
data_frequency='daily',
price_type='真实价格',
product_type='股票',
plot_charts=True,
backtest_only=False,
benchmark='000300.SHA'
)
m4 = M.random_forest_regressor.v1(
iterations=10,
feature_fraction=1,
max_depth=30,
min_samples_per_leaf=200,
key_cols='date,instrument',
workers=1,
random_state=0,
other_train_parameters={}
)
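The label built by m2 above deserves a closer look: it is the return from buying at the next day's open and selling at the close two days later, clipped at the 1%/99% quantiles, binned into 20 ordinal classes for the ranker, with one-tick limit-up days masked out. Below is a rough pandas equivalent for a single stock's bar data — a sketch only, assuming a DataFrame with open/close/high/low columns sorted by date; the platform module additionally handles benchmark fields and computes quantiles and bins across the whole dataset rather than per stock.

import numpy as np
import pandas as pd

def make_label(df, n_bins=20):
    # df: one stock's daily bars with columns open/close/high/low, sorted by date
    # Return of buying at tomorrow's open and selling at the close two days ahead
    label = df['close'].shift(-2) / df['open'].shift(-1)
    # Clip extremes at the 1% / 99% quantiles
    label = label.clip(label.quantile(0.01), label.quantile(0.99))
    # Map the clipped return into ordinal classes (the ranking target)
    label = pd.qcut(label, q=n_bins, labels=False, duplicates='drop')
    # Mask out one-tick limit-up days (next day's high == low), which are untradable
    return label.where(df['high'].shift(-1) != df['low'].shift(-1))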
[2023-04-07 11:46:44.221232] INFO: moduleinvoker: instruments.v2 开始运行..
[2023-04-07 11:46:44.321154] INFO: moduleinvoker: instruments.v2 运行完成[0.099928s].
[2023-04-07 11:46:44.334505] INFO: moduleinvoker: advanced_auto_labeler.v2 开始运行..
[2023-04-07 11:46:46.917145] INFO: 自动标注(股票): 加载历史数据: 2553771 行
[2023-04-07 11:46:46.918601] INFO: 自动标注(股票): 开始标注 ..
[2023-04-07 11:46:49.929262] INFO: moduleinvoker: advanced_auto_labeler.v2 运行完成[5.594751s].
[2023-04-07 11:46:49.937408] INFO: moduleinvoker: input_features.v1 开始运行..
[2023-04-07 11:46:49.944666] INFO: moduleinvoker: 命中缓存
[2023-04-07 11:46:49.945780] INFO: moduleinvoker: input_features.v1 运行完成[0.00838s].
[2023-04-07 11:46:49.961670] INFO: moduleinvoker: general_feature_extractor.v7 开始运行..
[2023-04-07 11:46:50.484447] INFO: 基础特征抽取: 年份 2013, 特征行数=143272
[2023-04-07 11:46:51.722472] INFO: 基础特征抽取: 年份 2014, 特征行数=569948
[2023-04-07 11:46:53.024806] INFO: 基础特征抽取: 年份 2015, 特征行数=569698
[2023-04-07 11:46:54.368705] INFO: 基础特征抽取: 年份 2016, 特征行数=641546
[2023-04-07 11:46:55.923325] INFO: 基础特征抽取: 年份 2017, 特征行数=743233
[2023-04-07 11:46:56.308862] INFO: 基础特征抽取: 年份 2018, 特征行数=29346
[2023-04-07 11:46:56.357962] INFO: 基础特征抽取: 总行数: 2697043
[2023-04-07 11:46:56.362529] INFO: moduleinvoker: general_feature_extractor.v7 运行完成[6.400856s].
[2023-04-07 11:46:56.376685] INFO: moduleinvoker: derived_feature_extractor.v3 开始运行..
[2023-04-07 11:47:00.351664] INFO: derived_feature_extractor: 提取完成 avg_amount_0/avg_amount_5, 0.008s
[2023-04-07 11:47:00.360965] INFO: derived_feature_extractor: 提取完成 avg_amount_5/avg_amount_20, 0.008s
[2023-04-07 11:47:00.366306] INFO: derived_feature_extractor: 提取完成 rank_avg_amount_0/rank_avg_amount_5, 0.004s
[2023-04-07 11:47:00.371506] INFO: derived_feature_extractor: 提取完成 rank_avg_amount_5/rank_avg_amount_10, 0.004s
[2023-04-07 11:47:00.376654] INFO: derived_feature_extractor: 提取完成 rank_return_0/rank_return_5, 0.004s
[2023-04-07 11:47:00.381576] INFO: derived_feature_extractor: 提取完成 rank_return_5/rank_return_10, 0.004s
[2023-04-07 11:47:01.425580] INFO: derived_feature_extractor: /y_2013, 143272
[2023-04-07 11:47:02.157687] INFO: derived_feature_extractor: /y_2014, 569948
[2023-04-07 11:47:03.099019] INFO: derived_feature_extractor: /y_2015, 569698
[2023-04-07 11:47:04.265013] INFO: derived_feature_extractor: /y_2016, 641546
[2023-04-07 11:47:05.614496] INFO: derived_feature_extractor: /y_2017, 743233
[2023-04-07 11:47:06.116316] INFO: derived_feature_extractor: /y_2018, 29346
[2023-04-07 11:47:06.322695] INFO: moduleinvoker: derived_feature_extractor.v3 运行完成[9.946011s].
[2023-04-07 11:47:06.335135] INFO: moduleinvoker: join.v3 开始运行..
[2023-04-07 11:47:10.838386] INFO: join: /y_2013, 行数=0/143272, 耗时=0.83365s
[2023-04-07 11:47:12.563680] INFO: join: /y_2014, 行数=567883/569948, 耗时=1.723598s
[2023-04-07 11:47:14.367103] INFO: join: /y_2015, 行数=560441/569698, 耗时=1.799924s
[2023-04-07 11:47:16.439750] INFO: join: /y_2016, 行数=637478/641546, 耗时=2.069026s
[2023-04-07 11:47:18.777982] INFO: join: /y_2017, 行数=738013/743233, 耗时=2.334258s
[2023-04-07 11:47:19.362681] INFO: join: /y_2018, 行数=22655/29346, 耗时=0.578735s
[2023-04-07 11:47:19.440970] INFO: join: 最终行数: 2526470
[2023-04-07 11:47:19.476974] INFO: moduleinvoker: join.v3 运行完成[13.141833s].
[2023-04-07 11:47:19.489690] INFO: moduleinvoker: dropnan.v1 开始运行..
[2023-04-07 11:47:20.175099] INFO: dropnan: /y_2013, 0/0
[2023-04-07 11:47:21.346008] INFO: dropnan: /y_2014, 566044/567883
[2023-04-07 11:47:22.540603] INFO: dropnan: /y_2015, 558165/560441
[2023-04-07 11:47:23.965755] INFO: dropnan: /y_2016, 635618/637478
[2023-04-07 11:47:25.514994] INFO: dropnan: /y_2017, 732356/738013
[2023-04-07 11:47:25.685114] INFO: dropnan: /y_2018, 22566/22655
[2023-04-07 11:47:25.760240] INFO: dropnan: 行数: 2514749/2526470
[2023-04-07 11:47:25.764485] INFO: moduleinvoker: dropnan.v1 运行完成[6.274802s].
[2023-04-07 11:47:25.781307] INFO: moduleinvoker: stock_ranker_train.v6 开始运行..
[2023-04-07 11:47:31.141840] INFO: StockRanker: 特征预处理 ..
[2023-04-07 11:47:34.344162] INFO: StockRanker: prepare data: training ..
[2023-04-07 11:47:37.548028] INFO: StockRanker: sort ..
[2023-04-07 11:48:03.742389] INFO: StockRanker训练: e9348934 准备训练: 2514749 行数
[2023-04-07 11:48:03.743742] INFO: StockRanker训练: AI模型训练,将在2514749*13=3269.17万数据上对模型训练进行20轮迭代训练。预计将需要10~21分钟。请耐心等待。
[2023-04-07 11:48:03.982205] INFO: StockRanker训练: 正在训练 ..
[2023-04-07 11:48:04.045304] INFO: StockRanker训练: 任务状态: Pending
[2023-04-07 11:48:14.089955] INFO: StockRanker训练: 任务状态: Running
[2023-04-07 11:49:34.432851] INFO: StockRanker训练: 00:01:17.7669755, finished iteration 1
[2023-04-07 11:49:44.473526] INFO: StockRanker训练: 00:01:36.3993067, finished iteration 2
[2023-04-07 11:50:05.035488] INFO: StockRanker训练: 00:01:55.5013713, finished iteration 3
[2023-04-07 11:50:25.121920] INFO: StockRanker训练: 00:02:13.1006714, finished iteration 4
[2023-04-07 11:50:45.200042] INFO: StockRanker训练: 00:02:29.7968547, finished iteration 5
[2023-04-07 11:50:55.241537] INFO: StockRanker训练: 00:02:45.4767774, finished iteration 6
[2023-04-07 11:51:15.326406] INFO: StockRanker训练: 00:03:01.6437048, finished iteration 7
[2023-04-07 11:51:35.420521] INFO: StockRanker训练: 00:03:18.4522216, finished iteration 8
[2023-04-07 11:51:45.471146] INFO: StockRanker训练: 00:03:35.9467368, finished iteration 9
[2023-04-07 11:52:05.615544] INFO: StockRanker训练: 00:03:57.3881220, finished iteration 10
[2023-04-07 11:52:35.738796] INFO: StockRanker训练: 00:04:20.3825184, finished iteration 11
[2023-04-07 11:52:55.827003] INFO: StockRanker训练: 00:04:39.2380031, finished iteration 12
[2023-04-07 11:53:05.868087] INFO: StockRanker训练: 00:04:57.8872321, finished iteration 13
[2023-04-07 11:53:25.955220] INFO: StockRanker训练: 00:05:13.5194859, finished iteration 14
[2023-04-07 11:53:46.055354] INFO: StockRanker训练: 00:05:29.6420988, finished iteration 15
[2023-04-07 11:53:56.098350] INFO: StockRanker训练: 00:05:46.5631741, finished iteration 16
[2023-04-07 11:54:16.185765] INFO: StockRanker训练: 00:06:04.2975979, finished iteration 17
[2023-04-07 11:54:36.269577] INFO: StockRanker训练: 00:06:20.9218187, finished iteration 18
[2023-04-07 11:54:46.317239] INFO: StockRanker训练: 00:06:37.5677569, finished iteration 19
[2023-04-07 11:55:06.513043] INFO: StockRanker训练: 00:06:54.2133379, finished iteration 20
[2023-04-07 11:55:06.514494] INFO: StockRanker训练: 任务状态: Succeeded
[2023-04-07 11:55:06.745975] INFO: moduleinvoker: stock_ranker_train.v6 运行完成[460.964661s].
[2023-04-07 11:55:06.750634] INFO: moduleinvoker: instruments.v2 开始运行..
[2023-04-07 11:55:06.803806] INFO: moduleinvoker: instruments.v2 运行完成[0.053178s].
[2023-04-07 11:55:06.816072] INFO: moduleinvoker: general_feature_extractor.v7 开始运行..
[2023-04-07 11:55:07.378186] INFO: 基础特征抽取: 年份 2017, 特征行数=174303
[2023-04-07 11:55:09.108555] INFO: 基础特征抽取: 年份 2018, 特征行数=816987
[2023-04-07 11:55:09.482863] INFO: 基础特征抽取: 年份 2019, 特征行数=24884
[2023-04-07 11:55:09.551917] INFO: 基础特征抽取: 总行数: 1016174
[2023-04-07 11:55:09.556387] INFO: moduleinvoker: general_feature_extractor.v7 运行完成[2.740321s].
[2023-04-07 11:55:09.562904] INFO: moduleinvoker: derived_feature_extractor.v3 开始运行..
[2023-04-07 11:55:11.034551] INFO: derived_feature_extractor: 提取完成 avg_amount_0/avg_amount_5, 0.004s
[2023-04-07 11:55:11.039135] INFO: derived_feature_extractor: 提取完成 avg_amount_5/avg_amount_20, 0.003s
[2023-04-07 11:55:11.043053] INFO: derived_feature_extractor: 提取完成 rank_avg_amount_0/rank_avg_amount_5, 0.002s
[2023-04-07 11:55:11.046418] INFO: derived_feature_extractor: 提取完成 rank_avg_amount_5/rank_avg_amount_10, 0.002s
[2023-04-07 11:55:11.050044] INFO: derived_feature_extractor: 提取完成 rank_return_0/rank_return_5, 0.002s
[2023-04-07 11:55:11.053957] INFO: derived_feature_extractor: 提取完成 rank_return_5/rank_return_10, 0.002s
[2023-04-07 11:55:11.533263] INFO: derived_feature_extractor: /y_2017, 174303
[2023-04-07 11:55:12.620865] INFO: derived_feature_extractor: /y_2018, 816987
[2023-04-07 11:55:13.185824] INFO: derived_feature_extractor: /y_2019, 24884
[2023-04-07 11:55:13.304987] INFO: moduleinvoker: derived_feature_extractor.v3 运行完成[3.742064s].
[2023-04-07 11:55:13.315481] INFO: moduleinvoker: dropnan.v1 开始运行..
[2023-04-07 11:55:13.937766] INFO: dropnan: /y_2017, 172549/174303
[2023-04-07 11:55:15.789686] INFO: dropnan: /y_2018, 814562/816987
[2023-04-07 11:55:15.864906] INFO: dropnan: /y_2019, 24834/24884
[2023-04-07 11:55:16.022767] INFO: dropnan: 行数: 1011945/1016174
[2023-04-07 11:55:16.027815] INFO: moduleinvoker: dropnan.v1 运行完成[2.712338s].
[2023-04-07 11:55:16.044394] INFO: moduleinvoker: stock_ranker_predict.v5 开始运行..
[2023-04-07 11:55:16.710921] INFO: StockRanker预测: /y_2017 ..
[2023-04-07 11:55:18.081123] INFO: StockRanker预测: /y_2018 ..
[2023-04-07 11:55:19.604900] INFO: StockRanker预测: /y_2019 ..
[2023-04-07 11:55:20.616865] INFO: moduleinvoker: stock_ranker_predict.v5 运行完成[4.572465s].
[2023-04-07 11:55:23.539914] INFO: moduleinvoker: backtest.v8 开始运行..
[2023-04-07 11:55:23.547757] INFO: backtest: biglearning backtest:V8.6.3
[2023-04-07 11:55:24.182559] INFO: backtest: product_type:stock by specified
[2023-04-07 11:55:24.235704] INFO: moduleinvoker: cached.v2 开始运行..
[2023-04-07 11:55:30.125344] INFO: backtest: 读取股票行情完成:1891518
[2023-04-07 11:55:31.479080] INFO: moduleinvoker: cached.v2 运行完成[7.243369s].
[2023-04-07 11:55:38.926676] INFO: backtest: algo history_data=DataSource(366ebd136cd247ddb4318d2b2ea185ceT)
[2023-04-07 11:55:38.928073] INFO: algo: TradingAlgorithm V1.8.9
[2023-04-07 11:55:40.564024] INFO: algo: trading transform...
[2023-04-07 11:55:42.824180] INFO: Performance: Simulated 241 trading days out of 241.
[2023-04-07 11:55:42.825521] INFO: Performance: first open: 2018-01-15 09:30:00+00:00
[2023-04-07 11:55:42.826564] INFO: Performance: last close: 2019-01-10 15:00:00+00:00
[2023-04-07 11:55:45.203081] INFO: moduleinvoker: backtest.v8 运行完成[21.663158s].
[2023-04-07 11:55:45.204731] INFO: moduleinvoker: trade.v4 运行完成[24.571791s].
# Python entry function: input_1/2/3 map to the three input ports, data_1/2/3 to the three output ports
def m10_run_bigquant_run(input_1, input_2, input_3):
    # Split the converted data into training and validation sets
from sklearn.model_selection import train_test_split
data = input_1.read()
x_train, x_val, y_train, y_val = train_test_split(data["x"], data['y'], random_state=5)
data_1 = DataSource.write_pickle({'x': x_train, 'y': y_train})
data_2 = DataSource.write_pickle({'x': x_val, 'y': y_val})
return Outputs(data_1=data_1, data_2=data_2, data_3=None)
# Post-processing function (optional). Its input is the main function's output; you can transform the data here or return a friendlier outputs format. Its output is not cached.
def m10_post_run_bigquant_run(outputs):
return outputs
from tensorflow.keras.callbacks import EarlyStopping
m5_earlystop_bigquant_run=EarlyStopping(monitor='val_mse', min_delta=0.0001, patience=5)
# Custom user layers must be registered in this dict, e.g.
# {
# "MyLayer": MyLayer
# }
m5_custom_objects_bigquant_run = {
}
# Python entry function: input_1/2/3 map to the three input ports, data_1/2/3 to the three output ports
def m24_run_bigquant_run(input_1, input_2, input_3):
    # Attach instrument and date to the raw predictions and sort each day by predicted score
pred_label = input_1.read_pickle()
df = input_2.read_df()
df = pd.DataFrame({'pred_label':pred_label[:,0], 'instrument':df.instrument, 'date':df.date})
df.sort_values(['date','pred_label'],inplace=True, ascending=[True,False])
return Outputs(data_1=DataSource.write_df(df), data_2=None, data_3=None)
# Post-processing function (optional). Its input is the main function's output; you can transform the data here or return a friendlier outputs format. Its output is not cached.
def m24_post_run_bigquant_run(outputs):
return outputs
# Backtest engine: initialization function, executed only once
def m19_initialize_bigquant_run(context):
    # Load the prediction data (passed in via options and read into memory as a DataFrame)
    context.ranker_prediction = context.options['data'].read_df()
    # The system presets default commission and slippage; use the call below to override the commission
    context.set_commission(PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    # Number of stocks to buy: the top 20 in the prediction ranking
    stock_count = 20
    # Per-stock weights; this allocation gives higher-ranked stocks slightly more capital
    context.stock_weights = T.norm([1 / math.log(i + 2) for i in range(0, stock_count)])
    # Maximum fraction of capital a single stock may occupy
    context.max_cash_per_instrument = 0.2
    context.options['hold_days'] = 5
# Backtest engine: daily data-handling function, executed once per day
def m19_handle_data_bigquant_run(context, data):
    # Filter the predictions down to today's date
    ranker_prediction = context.ranker_prediction[
        context.ranker_prediction.date == data.current_dt.strftime('%Y-%m-%d')]
    # 1. Capital allocation
    # The average holding period is hold_days and stocks are bought every day, so roughly 1/hold_days of the capital is expected to be used per day
    # In practice there is some buying slippage, so during the first hold_days days capital is used evenly; after that, remaining cash is used as far as possible (capped here at 1.5x the even share)
    is_staging = context.trading_day_index < context.options['hold_days']  # still in the position-building phase (first hold_days days)?
    cash_avg = context.portfolio.portfolio_value / context.options['hold_days']
    cash_for_buy = min(context.portfolio.cash, (1 if is_staging else 1.5) * cash_avg)
    cash_for_sell = cash_avg - (context.portfolio.cash - cash_for_buy)
    positions = {e.symbol: p.amount * p.last_sale_price
                 for e, p in context.perf_tracker.position_tracker.positions.items()}
    # 2. Generate sell orders: selling only starts after hold_days; held stocks ranked worst by the model are dropped first
    if not is_staging and cash_for_sell > 0:
        equities = {e.symbol: e for e, p in context.perf_tracker.position_tracker.positions.items()}
        instruments = list(reversed(list(ranker_prediction.instrument[ranker_prediction.instrument.apply(
            lambda x: x in equities and not context.has_unfinished_sell_order(equities[x]))])))
        # print('rank order for sell %s' % instruments)
        for instrument in instruments:
            context.order_target(context.symbol(instrument), 0)
            cash_for_sell -= positions[instrument]
            if cash_for_sell <= 0:
                break
    # 3. Generate buy orders: buy the top stock_count stocks as ranked by the model
    buy_cash_weights = context.stock_weights
    buy_instruments = list(ranker_prediction.instrument[:len(buy_cash_weights)])
    max_cash_per_instrument = context.portfolio.portfolio_value * context.max_cash_per_instrument
    for i, instrument in enumerate(buy_instruments):
        cash = cash_for_buy * buy_cash_weights[i]
        if cash > max_cash_per_instrument - positions.get(instrument, 0):
            # Make sure a single position never exceeds its maximum capital allocation
            cash = max_cash_per_instrument - positions.get(instrument, 0)
        if cash > 0:
            context.order_value(context.symbol(instrument), cash)
# Backtest engine: data-preparation function, executed only once
def m19_prepare_bigquant_run(context):
    pass
# Backtest engine: initialization function, executed only once
def m12_initialize_bigquant_run(context):
    # Load the prediction data (passed in via options and read into memory as a DataFrame)
    context.ranker_prediction = context.options['data'].read_df()
    # The system presets default commission and slippage; use the call below to override the commission
    context.set_commission(PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))
    # Number of stocks to buy: here only the top-ranked stock in the prediction list is held
    stock_count = 1
    # Per-stock weights; with a single stock it receives the full allocation
    context.stock_weights = [1]
    # Maximum fraction of capital a single stock may occupy
    context.max_cash_per_instrument = 1
    context.options['hold_days'] = 1
# Backtest engine: daily data-handling function, executed once per day
def m12_handle_data_bigquant_run(context, data):
    # Filter the predictions down to today's date
    ranker_prediction = context.ranker_prediction[
        context.ranker_prediction.date == data.current_dt.strftime('%Y-%m-%d')]
    # Budget for buying: the lesser of total portfolio value and available cash
    cash_for_buy = min(context.portfolio.portfolio_value/1, context.portfolio.cash)
    # Alternative budgets kept for reference:
    #cash_for_buy = context.portfolio.portfolio_value
    #cash_for_buy = context.portfolio.cash
    #print(ranker_prediction)
    buy_instruments = list(ranker_prediction.instrument)
    sell_instruments = [instrument.symbol for instrument in context.portfolio.positions.keys()]
    # Rotate into today's top-ranked stock: sell anything no longer ranked first, buy the new leader
    to_buy = set(buy_instruments[:1]) - set(sell_instruments)
    to_sell = set(sell_instruments) - set(buy_instruments[:1])
    for instrument in to_sell:
        context.order_target(context.symbol(instrument), 0)
    for instrument in to_buy:
        context.order_value(context.symbol(instrument), cash_for_buy)
def m12_prepare_bigquant_run(context):
    # Fetch ST status and limit-up/limit-down status
    context.status_df = D.features(instruments=context.instruments, start_date=context.start_date, end_date=context.end_date,
        fields=['st_status_0', 'price_limit_status_0', 'price_limit_status_1'])
def m12_before_trading_start_bigquant_run(context, data):
    pass
    # Fetch limit-up/limit-down status data
    # df_price_limit_status = context.status_df.set_index('date')
    # today = data.current_dt.strftime('%Y-%m-%d')
    # # Get the currently unfilled orders
    # for orders in get_open_orders().values():
    #     # Loop over them and cancel where needed
    #     for _order in orders:
    #         ins = str(_order.sid.symbol)
    #         try:
    #             # If the stock is limit-up today, cancel the sell order
    #             if df_price_limit_status[df_price_limit_status.instrument == ins].price_limit_status_0.loc[today] > 2 and _order.amount < 0:
    #                 cancel_order(_order)
    #                 print(today, 'limit-up at close, sell order cancelled', ins)
    #         except:
    #             continue
m1 = M.instruments.v2(
start_date='2014-01-01',
end_date='2018-01-14',
market='CN_STOCK_A',
instrument_list='',
max_count=0
)
m2 = M.advanced_auto_labeler.v2(
instruments=m1.data,
label_expr="""# #号开始的表示注释
# 0. 每行一个,顺序执行,从第二个开始,可以使用label字段
# 1. 可用数据字段见 https://bigquant.com/docs/data_history_data.html
# 添加benchmark_前缀,可使用对应的benchmark数据
# 2. 可用操作符和函数见 `表达式引擎 <https://bigquant.com/docs/big_expr.html>`_
# 计算收益:5日收盘价(作为卖出价格)除以明日开盘价(作为买入价格)
shift(close, -2) / shift(open, -1)-1
# 极值处理:用1%和99%分位的值做clip
clip(label, all_quantile(label, 0.01), all_quantile(label, 0.99))
# 过滤掉一字涨停的情况 (设置label为NaN,在后续处理和训练中会忽略NaN的label)
where(shift(high, -1) == shift(low, -1), NaN, label)
""",
start_date='',
end_date='',
benchmark='000300.SHA',
drop_na_label=True,
cast_label_int=False
)
m29 = M.standardlize.v8(
input_1=m2.data,
columns_input='label'
)
m3 = M.input_features.v1(
features="""return_5
return_10
return_20
avg_amount_0/avg_amount_5
avg_amount_5/avg_amount_20
rank_avg_amount_0/rank_avg_amount_5
rank_avg_amount_5/rank_avg_amount_10
rank_return_0
rank_return_5
rank_return_10
rank_return_0/rank_return_5
rank_return_5/rank_return_10
pe_ttm_0"""
)
m15 = M.general_feature_extractor.v7(
instruments=m1.data,
features=m3.data,
start_date='',
end_date='',
before_start_days=0
)
m16 = M.derived_feature_extractor.v3(
input_data=m15.data,
features=m3.data,
date_col='date',
instrument_col='instrument',
drop_na=True,
remove_extra_columns=False
)
m28 = M.standardlize.v8(
input_1=m16.data,
input_2=m3.data,
columns_input='[]'
)
m13 = M.fillnan.v1(
input_data=m28.data,
features=m3.data,
fill_value='0.0'
)
m7 = M.join.v3(
data1=m29.data,
data2=m13.data,
on='date,instrument',
how='inner',
sort=False
)
m26 = M.dl_convert_to_bin.v2(
input_data=m7.data,
features=m3.data,
window_size=2,
feature_clip=3,
flatten=True,
window_along_col='instrument'
)
m10 = M.cached.v3(
input_1=m26.data,
run=m10_run_bigquant_run,
post_run=m10_post_run_bigquant_run,
input_ports='',
params='{}',
output_ports=''
)
m9 = M.instruments.v2(
start_date=T.live_run_param('trading_date', '2018-01-15'),
end_date=T.live_run_param('trading_date', '2019-01-10'),
market='CN_STOCK_A',
instrument_list='',
max_count=0
)
m17 = M.general_feature_extractor.v7(
instruments=m9.data,
features=m3.data,
start_date='',
end_date='',
before_start_days=0
)
m18 = M.derived_feature_extractor.v3(
input_data=m17.data,
features=m3.data,
date_col='date',
instrument_col='instrument',
drop_na=True,
remove_extra_columns=False
)
m25 = M.standardlize.v8(
input_1=m18.data,
input_2=m3.data,
columns_input='[]'
)
m14 = M.fillnan.v1(
input_data=m25.data,
features=m3.data,
fill_value='0.0'
)
m27 = M.dl_convert_to_bin.v2(
input_data=m14.data,
features=m3.data,
window_size=2,
feature_clip=3,
flatten=True,
window_along_col='instrument'
)
m6 = M.dl_layer_input.v1(
shape='26',
batch_shape='',
dtype='float32',
sparse=False,
name=''
)
m8 = M.dl_layer_dense.v1(
inputs=m6.data,
units=256,
activation='relu',
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='Zeros',
kernel_regularizer='None',
kernel_regularizer_l1=0,
kernel_regularizer_l2=0,
bias_regularizer='None',
bias_regularizer_l1=0,
bias_regularizer_l2=0,
activity_regularizer='None',
activity_regularizer_l1=0,
activity_regularizer_l2=0,
kernel_constraint='None',
bias_constraint='None',
name=''
)
m21 = M.dl_layer_dropout.v1(
inputs=m8.data,
rate=0.1,
noise_shape='',
name=''
)
m20 = M.dl_layer_dense.v1(
inputs=m21.data,
units=128,
activation='relu',
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='Zeros',
kernel_regularizer='None',
kernel_regularizer_l1=0,
kernel_regularizer_l2=0,
bias_regularizer='None',
bias_regularizer_l1=0,
bias_regularizer_l2=0,
activity_regularizer='None',
activity_regularizer_l1=0,
activity_regularizer_l2=0,
kernel_constraint='None',
bias_constraint='None',
name=''
)
m22 = M.dl_layer_dropout.v1(
inputs=m20.data,
rate=0.1,
noise_shape='',
name=''
)
m23 = M.dl_layer_dense.v1(
inputs=m22.data,
units=1,
activation='linear',
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='Zeros',
kernel_regularizer='None',
kernel_regularizer_l1=0,
kernel_regularizer_l2=0,
bias_regularizer='None',
bias_regularizer_l1=0,
bias_regularizer_l2=0,
activity_regularizer='None',
activity_regularizer_l1=0,
activity_regularizer_l2=0,
kernel_constraint='None',
bias_constraint='None',
name=''
)
m4 = M.dl_model_init.v1(
inputs=m6.data,
outputs=m23.data
)
m5 = M.dl_model_train.v1(
input_model=m4.data,
training_data=m10.data_1,
validation_data=m10.data_2,
optimizer='Adam',
loss='mean_squared_error',
metrics='mse',
batch_size=1024,
epochs=30,
earlystop=m5_earlystop_bigquant_run,
custom_objects=m5_custom_objects_bigquant_run,
n_gpus=0,
verbose='2:每个epoch输出一行记录',
m_cached=False
)
m11 = M.dl_model_predict.v1(
trained_model=m5.data,
input_data=m27.data,
batch_size=1024,
n_gpus=0,
verbose='2:每个epoch输出一行记录'
)
m24 = M.cached.v3(
input_1=m11.data,
input_2=m18.data,
run=m24_run_bigquant_run,
post_run=m24_post_run_bigquant_run,
input_ports='',
params='{}',
output_ports=''
)
m19 = M.trade.v4(
instruments=m9.data,
options_data=m24.data_1,
start_date='',
end_date='',
initialize=m19_initialize_bigquant_run,
handle_data=m19_handle_data_bigquant_run,
prepare=m19_prepare_bigquant_run,
volume_limit=0.025,
order_price_field_buy='open',
order_price_field_sell='close',
capital_base=1000000,
auto_cancel_non_tradable_orders=True,
data_frequency='daily',
price_type='后复权',
product_type='股票',
plot_charts=True,
backtest_only=False,
benchmark='000300.SHA'
)
m12 = M.trade.v4(
instruments=m9.data,
options_data=m24.data_1,
start_date='',
end_date='',
initialize=m12_initialize_bigquant_run,
handle_data=m12_handle_data_bigquant_run,
prepare=m12_prepare_bigquant_run,
before_trading_start=m12_before_trading_start_bigquant_run,
volume_limit=0,
order_price_field_buy='open',
order_price_field_sell='close',
capital_base=100000,
auto_cancel_non_tradable_orders=True,
data_frequency='daily',
price_type='真实价格',
product_type='股票',
plot_charts=True,
backtest_only=False,
benchmark='000300.SHA'
)
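For reference, the network assembled by m6–m23 and trained by m5 above corresponds roughly to the following standalone Keras model: the 13 features with window_size=2 are flattened to 26 inputs, followed by Dense(256, relu), Dropout(0.1), Dense(128, relu), Dropout(0.1) and a linear output, trained with Adam on mean squared error. This is only a sketch of the equivalent architecture, not the platform's internal implementation; x_train/y_train/x_val/y_val stand for the arrays produced by the m10 split and are not defined here.

from tensorflow.keras import layers, models, optimizers, callbacks

def build_model(input_dim=26):
    # 26 = 13 features x window_size 2 (see m3 and dl_convert_to_bin above)
    inputs = layers.Input(shape=(input_dim,), dtype='float32')
    x = layers.Dense(256, activation='relu')(inputs)
    x = layers.Dropout(0.1)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(1, activation='linear')(x)
    model = models.Model(inputs, outputs)
    model.compile(optimizer=optimizers.Adam(), loss='mean_squared_error', metrics=['mse'])
    return model

# model = build_model()
# early_stop = callbacks.EarlyStopping(monitor='val_mse', min_delta=0.0001, patience=5)
# model.fit(x_train, y_train, validation_data=(x_val, y_val),
#           batch_size=1024, epochs=30, callbacks=[early_stop], verbose=2)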
[2023-04-07 13:19:32.348296] INFO: moduleinvoker: instruments.v2 开始运行..
[2023-04-07 13:19:32.518343] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.519889] INFO: moduleinvoker: instruments.v2 运行完成[0.171616s].
[2023-04-07 13:19:32.529215] INFO: moduleinvoker: advanced_auto_labeler.v2 开始运行..
[2023-04-07 13:19:32.537842] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.539256] INFO: moduleinvoker: advanced_auto_labeler.v2 运行完成[0.010031s].
[2023-04-07 13:19:32.546975] INFO: moduleinvoker: standardlize.v8 开始运行..
[2023-04-07 13:19:32.553421] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.554698] INFO: moduleinvoker: standardlize.v8 运行完成[0.00772s].
[2023-04-07 13:19:32.560117] INFO: moduleinvoker: input_features.v1 开始运行..
[2023-04-07 13:19:32.565387] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.566504] INFO: moduleinvoker: input_features.v1 运行完成[0.006393s].
[2023-04-07 13:19:32.579978] INFO: moduleinvoker: general_feature_extractor.v7 开始运行..
[2023-04-07 13:19:32.587171] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.588297] INFO: moduleinvoker: general_feature_extractor.v7 运行完成[0.008325s].
[2023-04-07 13:19:32.595659] INFO: moduleinvoker: derived_feature_extractor.v3 开始运行..
[2023-04-07 13:19:32.603502] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.604673] INFO: moduleinvoker: derived_feature_extractor.v3 运行完成[0.009013s].
[2023-04-07 13:19:32.610506] INFO: moduleinvoker: standardlize.v8 开始运行..
[2023-04-07 13:19:32.619009] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.620288] INFO: moduleinvoker: standardlize.v8 运行完成[0.009783s].
[2023-04-07 13:19:32.632965] INFO: moduleinvoker: fillnan.v1 开始运行..
[2023-04-07 13:19:32.641427] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.643078] INFO: moduleinvoker: fillnan.v1 运行完成[0.010108s].
[2023-04-07 13:19:32.652028] INFO: moduleinvoker: join.v3 开始运行..
[2023-04-07 13:19:32.660273] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.661639] INFO: moduleinvoker: join.v3 运行完成[0.009592s].
[2023-04-07 13:19:32.683137] INFO: moduleinvoker: dl_convert_to_bin.v2 开始运行..
[2023-04-07 13:19:32.697400] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.698829] INFO: moduleinvoker: dl_convert_to_bin.v2 运行完成[0.015707s].
[2023-04-07 13:19:32.710964] INFO: moduleinvoker: cached.v3 开始运行..
[2023-04-07 13:19:32.718922] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.720067] INFO: moduleinvoker: cached.v3 运行完成[0.009109s].
[2023-04-07 13:19:32.724179] INFO: moduleinvoker: instruments.v2 开始运行..
[2023-04-07 13:19:32.731193] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.732261] INFO: moduleinvoker: instruments.v2 运行完成[0.008082s].
[2023-04-07 13:19:32.748686] INFO: moduleinvoker: general_feature_extractor.v7 开始运行..
[2023-04-07 13:19:32.755146] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.756574] INFO: moduleinvoker: general_feature_extractor.v7 运行完成[0.007893s].
[2023-04-07 13:19:32.763397] INFO: moduleinvoker: derived_feature_extractor.v3 开始运行..
[2023-04-07 13:19:32.771675] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.772795] INFO: moduleinvoker: derived_feature_extractor.v3 运行完成[0.009397s].
[2023-04-07 13:19:32.776833] INFO: moduleinvoker: standardlize.v8 开始运行..
[2023-04-07 13:19:32.782973] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.784040] INFO: moduleinvoker: standardlize.v8 运行完成[0.007206s].
[2023-04-07 13:19:32.790966] INFO: moduleinvoker: fillnan.v1 开始运行..
[2023-04-07 13:19:32.797062] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.798137] INFO: moduleinvoker: fillnan.v1 运行完成[0.007171s].
[2023-04-07 13:19:32.809512] INFO: moduleinvoker: dl_convert_to_bin.v2 开始运行..
[2023-04-07 13:19:32.816433] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.817662] INFO: moduleinvoker: dl_convert_to_bin.v2 运行完成[0.008155s].
[2023-04-07 13:19:32.832780] INFO: moduleinvoker: dl_layer_input.v1 运行完成[0.005652s].
[2023-04-07 13:19:32.857753] INFO: moduleinvoker: dl_layer_dense.v1 运行完成[0.017815s].
[2023-04-07 13:19:32.867708] INFO: moduleinvoker: dl_layer_dropout.v1 运行完成[0.003602s].
[2023-04-07 13:19:32.883309] INFO: moduleinvoker: dl_layer_dense.v1 运行完成[0.00983s].
[2023-04-07 13:19:32.891117] INFO: moduleinvoker: dl_layer_dropout.v1 运行完成[0.002227s].
[2023-04-07 13:19:32.902611] INFO: moduleinvoker: dl_layer_dense.v1 运行完成[0.006627s].
[2023-04-07 13:19:32.928175] INFO: moduleinvoker: cached.v3 开始运行..
[2023-04-07 13:19:32.949369] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:19:32.951192] INFO: moduleinvoker: cached.v3 运行完成[0.023018s].
[2023-04-07 13:19:32.953882] INFO: moduleinvoker: dl_model_init.v1 运行完成[0.046052s].
[2023-04-07 13:19:32.962172] INFO: moduleinvoker: dl_model_train.v1 开始运行..
[2023-04-07 13:19:33.672260] INFO: dl_model_train: 准备训练,训练样本个数:1886061,迭代次数:30
[2023-04-07 13:21:52.349120] INFO: dl_model_train: 训练结束,耗时:138.68s
[2023-04-07 13:21:52.387717] INFO: moduleinvoker: dl_model_train.v1 运行完成[139.425537s].
[2023-04-07 13:21:52.395568] INFO: moduleinvoker: dl_model_predict.v1 开始运行..
[2023-04-07 13:21:53.481779] INFO: moduleinvoker: dl_model_predict.v1 运行完成[1.086215s].
[2023-04-07 13:21:53.497152] INFO: moduleinvoker: cached.v3 开始运行..
[2023-04-07 13:21:56.001320] INFO: moduleinvoker: cached.v3 运行完成[2.504172s].
[2023-04-07 13:21:56.069463] INFO: moduleinvoker: backtest.v8 开始运行..
[2023-04-07 13:21:56.075917] INFO: backtest: biglearning backtest:V8.6.3
[2023-04-07 13:21:56.077351] INFO: backtest: product_type:stock by specified
[2023-04-07 13:21:56.136741] INFO: moduleinvoker: cached.v2 开始运行..
[2023-04-07 13:22:02.484713] INFO: backtest: 读取股票行情完成:1891518
[2023-04-07 13:22:03.672283] INFO: moduleinvoker: cached.v2 运行完成[7.535547s].
[2023-04-07 13:22:10.562575] INFO: backtest: algo history_data=DataSource(d740f1d22e9549c28c0154d82cd63343T)
[2023-04-07 13:22:10.564382] INFO: algo: TradingAlgorithm V1.8.9
[2023-04-07 13:22:11.882305] INFO: algo: trading transform...
[2023-04-07 13:22:26.326845] WARNING: Performance: maybe_close_position no price for asset:Equity(1011 [000979.SZA]), field:price, dt:2018-12-28 15:00:00+00:00
[2023-04-07 13:22:26.702788] INFO: Performance: Simulated 241 trading days out of 241.
[2023-04-07 13:22:26.704178] INFO: Performance: first open: 2018-01-15 09:30:00+00:00
[2023-04-07 13:22:26.705297] INFO: Performance: last close: 2019-01-10 15:00:00+00:00
[2023-04-07 13:22:30.958653] INFO: moduleinvoker: backtest.v8 运行完成[34.889193s].
[2023-04-07 13:22:30.959975] INFO: moduleinvoker: trade.v4 运行完成[34.951239s].
[2023-04-07 13:22:31.003324] INFO: moduleinvoker: backtest.v8 开始运行..
[2023-04-07 13:22:31.008378] INFO: backtest: biglearning backtest:V8.6.3
[2023-04-07 13:22:31.715410] INFO: backtest: product_type:stock by specified
[2023-04-07 13:22:31.796545] INFO: moduleinvoker: cached.v2 开始运行..
[2023-04-07 13:22:31.808365] INFO: moduleinvoker: 命中缓存
[2023-04-07 13:22:31.809719] INFO: moduleinvoker: cached.v2 运行完成[0.013191s].
[2023-04-07 13:22:39.343467] INFO: backtest: algo history_data=DataSource(366ebd136cd247ddb4318d2b2ea185ceT)
[2023-04-07 13:22:39.345009] INFO: algo: TradingAlgorithm V1.8.9
[2023-04-07 13:22:40.759510] INFO: algo: trading transform...
[2023-04-07 13:22:43.150118] INFO: Performance: Simulated 241 trading days out of 241.
[2023-04-07 13:22:43.151534] INFO: Performance: first open: 2018-01-15 09:30:00+00:00
[2023-04-07 13:22:43.152823] INFO: Performance: last close: 2019-01-10 15:00:00+00:00
[2023-04-07 13:22:45.388723] INFO: moduleinvoker: backtest.v8 运行完成[14.38539s].
[2023-04-07 13:22:45.390098] INFO: moduleinvoker: trade.v4 运行完成[14.422541s].
#Multi-factor models come in two flavors, regression and ranking: regression emphasizes explanation, while ranking aims directly at stock-selection returns.
startdate = '20140101'
enddate = '20190123'
data=m13.data.read()
data
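As noted in the comment above, the two flavors are evaluated differently: a regression model is judged by the error of its predicted returns, while a ranking model is judged by how well its predicted order matches the realized order (the rank IC). A minimal, self-contained sketch with made-up numbers:

import numpy as np
from scipy.stats import spearmanr

# Hypothetical predicted scores and realized forward returns for five stocks
y_true = np.array([0.02, -0.01, 0.03, 0.00, -0.02])
y_pred = np.array([0.015, -0.005, 0.02, 0.001, -0.01])

mse = np.mean((y_pred - y_true) ** 2)      # regression view: magnitude error
rank_ic, _ = spearmanr(y_pred, y_true)     # ranking view: agreement of the ordering
print(f'MSE = {mse:.6f}, rank IC = {rank_ic:.3f}')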
# import pandas as pd
# import numpy as np
# from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
# from sklearn.linear_model import LogisticRegression
# from sklearn.tree import DecisionTreeClassifier
# from sklearn.neighbors import KNeighborsClassifier
# from sklearn.preprocessing import LabelEncoder, StandardScaler
# # Read the training data produced upstream
# data = m13.data.read()
# # One-hot encode the 'instrument' column into numeric features
# # (note: with the full A-share universe this creates thousands of columns)
# instrument_col = data['instrument']
# instrument_df = pd.get_dummies(instrument_col, drop_first=True)
# # Merge the one-hot columns back into the original DataFrame
# data = pd.concat([data.drop('instrument', axis=1), instrument_df], axis=1)
# # Split features and label; 'date' is not a model feature and is dropped
# X = data.drop(['label', 'date'], axis=1)
# y = data['label']
# train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.3, random_state=1)
# # Encode the (binned) labels and standardize the features
# le = LabelEncoder()
# train_y = le.fit_transform(train_y)
# scaler = StandardScaler()
# train_X = scaler.fit_transform(train_X)
# # Candidate models and their parameters
# models = [('LR', LogisticRegression(solver='liblinear', class_weight='balanced', random_state=1)),
#           ('KNN', KNeighborsClassifier()),
#           ('CART', DecisionTreeClassifier())]
# # Cross-validation
# for name, model in models:
#     # Fit the model
#     print(name, model)
#     model.fit(train_X, train_y)
#     # Evaluation metric
#     kfold = StratifiedKFold(n_splits=10, random_state=1, shuffle=True)
#     scores = cross_val_score(model, train_X, train_y, scoring='accuracy', cv=kfold)
#     print(f'{name}: {scores.mean()}, {scores.std()}')