TabNet: Attentive Interpretable Tabular Learning
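基于TabNet模型的量化选股方案。抽取了98个量价因子,以2010年到2018年的数据训练TabNet模型,并将模型的预测结果应用在2018年到2021年9月的数据上进行了回测。(注:下方代码中 m5 实际只配置了16个因子,与 TabNet 的 input_dim=16 一致;训练区间为2010-01-01至2017-12-31,回测区间为2018-01-01至2021-12-31。)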
TabNet核心参数
# 本代码由可视化策略环境自动生成 2022年12月1日 12:36
# 本代码单元只能在可视化模式下编辑。您也可以拷贝代码,粘贴到新建的代码单元或者策略,然后修改。
# Python entry point: input_1/2/3 are this module's three input ports, data_1/2/3 its three output ports.
def m12_run_bigquant_run(input_1, input_2, input_3):
    """Split the binned training set into train / validation pickles for TabNet.

    ``input_1`` is read as a mapping with ``'x'`` (features) and ``'y'``
    (labels). The split uses ``train_test_split`` defaults (25% held out,
    shuffled) with a fixed seed of 2021. Labels are reshaped to a column
    vector ``(n, 1)`` before being written out.

    NOTE(review): the split is random, not chronological, so validation rows
    share dates with training rows; validation scores are likely optimistic
    for this time-series task.
    """
    from sklearn.model_selection import train_test_split

    payload = input_1.read()
    features_train, features_val, labels_train, labels_val = train_test_split(
        payload["x"], payload['y'], random_state=2021)

    def _to_pickle(features, labels):
        # TabNet expects a 2-D target, hence the (n, 1) reshape.
        return DataSource.write_pickle({'x': features, 'y': labels.reshape(-1, 1)})

    return Outputs(
        data_1=_to_pickle(features_train, labels_train),
        data_2=_to_pickle(features_val, labels_val),
        data_3=None,
    )
# Optional post-run hook: receives the run function's outputs and may reshape them; its result is not cached.
def m12_post_run_bigquant_run(outputs):
    """Identity post-run hook: hand the module outputs back unchanged."""
    unchanged = outputs
    return unchanged
# Python entry point: input_1/2/3 are this module's three input ports, data_1/2/3 its three output ports.
def m20_run_bigquant_run(input_1, input_2, input_3):
    """Attach TabNet scores to their (instrument, date) rows and rank them per day.

    ``input_1`` holds the prediction array; its first column is used as the
    score. ``input_2`` is the frame the predictions were made on. Scores and
    rows are paired positionally (assumes the predictor preserves row order;
    TODO confirm). The result is sorted by date ascending and, within a date,
    by score descending, so each day's best-ranked stock comes first.
    """
    scores = input_1.read_pickle()
    frame = input_2.read_df()
    ranked = pd.DataFrame({
        'pred_label': scores[:, 0],
        'instrument': frame.instrument,
        'date': frame.date,
    })
    ranked = ranked.sort_values(['date', 'pred_label'], ascending=[True, False])
    return Outputs(data_1=DataSource.write_df(ranked), data_2=None, data_3=None)
# Optional post-run hook: receives the run function's outputs and may reshape them; its result is not cached.
def m20_post_run_bigquant_run(outputs):
    """Identity post-run hook: return the module outputs as-is."""
    passthrough = outputs
    return passthrough
# Backtest engine: initialization, runs once.
def m29_initialize_bigquant_run(context):
    """Configure the backtest: load rankings, costs, position sizing and fill prices.

    Sets on ``context``:
      * ``ranker_prediction``: the ranking DataFrame passed in via ``options['data']``.
      * commission: ``PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5)``
        (presumably fractions of traded value with a per-order minimum).
      * ``stock_weights``: normalized ``1 / log(rank + 2)`` weights for the top 5
        picks, roughly [0.339160, 0.213986, 0.169580, ...], so higher-ranked
        names get more capital.
      * ``max_cash_per_instrument``: 0.1, a per-stock cap as a fraction of
        portfolio value.
      * ``options['hold_days']``: 5, the rolling-allocation horizon.
      * a fixed-price slippage model: buys fill at the open, sells at the close.
    """
    # Ranking produced upstream, delivered through options and loaded into memory.
    context.ranker_prediction = context.options['data'].read_df()
    # The platform has default commission/slippage; override the commission here.
    context.set_commission(PerOrder(buy_cost=0.0003, sell_cost=0.0013, min_cost=5))

    n_picks = 5  # buy the top-5 names of each day's ranking
    raw_weights = [1 / math.log(rank + 2) for rank in range(n_picks)]
    context.stock_weights = T.norm(raw_weights)
    context.max_cash_per_instrument = 0.1
    context.options['hold_days'] = 5

    from zipline.finance.slippage import SlippageModel

    class FixedPriceSlippage(SlippageModel):
        """Fill every order in full at a chosen bar field instead of modelling slippage.

        Market orders read ``_price_field_buy`` when buying (amount > 0) and
        ``_price_field_sell`` otherwise. These attributes are presumably set by
        ``SlippageModel.__init__`` from the ``price_field_*`` kwargs; confirm.
        NOTE(review): limit orders always use the buy field, even for sells.
        handle_data places no limit orders, so that branch is unused here.
        """

        def process_order(self, data, order, bar_volume=0, trigger_check_price=0):
            if order.limit is not None:
                field = self._price_field_buy
            elif order.amount > 0:
                field = self._price_field_buy
            else:
                field = self._price_field_sell
            # (fill price, fill amount): the whole order at that bar field's price.
            return (data.current(order.asset, field), order.amount)

    # Buy at the open, sell at the close.
    context.fix_slippage = FixedPriceSlippage(price_field_buy='open', price_field_sell='close')
    context.set_slippage(us_equities=context.fix_slippage)
# Backtest engine: per-bar trading logic, runs once per trading day.
def m29_handle_data_bigquant_run(context, data):
    """Run one trading day: index-timing risk control, exits, then rank-based buys.

    Steps:
      0. Index timing. ``context.dataprediction`` (built in ``prepare``) holds a
         direction per date. On a negative day every tradable holding is
         closed and the day ends with no new buys. A date missing from the
         signal, or a NaN signal, means "no risk signal".
      1. Cash budget. About ``1/hold_days`` of portfolio value is deployed per
         day (up to 1.5x after the staging period). ``cash_for_sell`` is how
         much must be freed by selling to fund that.
      2. Exits:
         - stop-win at +50%;
         - stop-loss at -5%;
         - a 22-calendar-day holding limit (after staging);
         - holdings whose name now contains ST or 退.
      3. Rotation. After staging, sell the lowest-ranked holdings until
         ``cash_for_sell`` is covered.
      4. Buys. Top-ranked names that were not sold today and are not ST,
         weighted by ``context.stock_weights``, capped at
         ``context.max_cash_per_instrument`` of portfolio value per stock,
         in round lots of 100 shares.
    """
    today = data.current_dt.strftime('%Y-%m-%d')
    portfolio_positions = context.portfolio.positions
    # Market value of every holding, keyed by symbol string.
    positions = {e.symbol: p.amount * p.last_sale_price
                 for e, p in portfolio_positions.items()}
    # Today's ranking rows. m20 sorted them by date, then pred_label descending.
    ranker_prediction = context.ranker_prediction[
        context.ranker_prediction.date == today]

    # ---- 0. Index-timing risk control ------------------------------------
    dataprediction = context.dataprediction
    today_direction = dataprediction[dataprediction.date == today].direction.values
    if len(today_direction) > 0 and today_direction[0] < 0:
        # Go flat: close every tradable holding and place no buys today.
        for symbol in positions:
            asset = context.symbol(symbol)
            if data.can_trade(asset):
                context.order_target_percent(asset, 0)
        if positions:
            print('风控执行', today)
        return

    # ---- 1. Cash budget --------------------------------------------------
    hold_days = context.options['hold_days']
    # Staging = the first hold_days trading days, while the book is being built.
    is_staging = context.trading_day_index < hold_days
    cash_avg = context.portfolio.portfolio_value / hold_days
    # Fills deviate from targets, so after staging allow up to 1.5x the slice.
    cash_for_buy = min(context.portfolio.cash, (1 if is_staging else 1.5) * cash_avg)
    # Cash still to raise by selling so that one cash_avg slice is available.
    cash_for_sell = cash_avg - (context.portfolio.cash - cash_for_buy)

    # Names already ordered out today. A second sell for the same name could open a short.
    stock_sold = []

    # ---- 2a. Stop-win / stop-loss (also during staging) --------------------
    # (A volume-spike stop was sketched here earlier but disabled; removed.)
    positions_cost = {e.symbol: p.cost_basis for e, p in portfolio_positions.items()}
    current_stopwin_stock = []
    current_stoploss_stock = []
    for instrument in positions:
        asset = context.symbol(instrument)
        if not data.can_trade(asset):
            continue
        pnl = data.current(asset, 'price') / positions_cost[instrument] - 1
        if pnl >= 0.5:  # take profit at +50%
            context.order_target_percent(asset, 0)
            cash_for_sell -= positions[instrument]
            current_stopwin_stock.append(instrument)
        elif pnl <= -0.05:  # cut losses at -5%
            context.order_target_percent(asset, 0)
            cash_for_sell -= positions[instrument]
            current_stoploss_stock.append(instrument)
    if current_stopwin_stock:
        print(today, '止盈股票列表', current_stopwin_stock)
        stock_sold += current_stopwin_stock
    if current_stoploss_stock:
        print(today, '止损股票列表', current_stoploss_stock)
        stock_sold += current_stoploss_stock

    # ---- 2b. Holding-period exit (not during staging) ----------------------
    # NOTE: the limit is a fixed 22 calendar days, not options['hold_days'].
    # NOTE(review): measured from each position's last_sale_date. Confirm this
    # field marks the last trade rather than the last price update; otherwise
    # the rule never fires.
    max_holding = datetime.timedelta(22)
    current_stopdays_stock = []
    if not is_staging:
        positions_lastdate = {e.symbol: p.last_sale_date for e, p in portfolio_positions.items()}
        for instrument in positions:
            if instrument in stock_sold:
                continue
            asset = context.symbol(instrument)
            if (data.current_dt - positions_lastdate[instrument] >= max_holding
                    and data.can_trade(asset)):
                context.order_target_percent(asset, 0)
                current_stopdays_stock.append(instrument)
                cash_for_sell -= positions[instrument]
    if current_stopdays_stock:
        print(today, '固定天数卖出列表', current_stopdays_stock)
        stock_sold += current_stopdays_stock

    # ---- 2c. Sell holdings that turned ST or are being delisted ------------
    st_stock_list = []
    for instrument in positions:
        if instrument in stock_sold:
            continue
        names = ranker_prediction.loc[ranker_prediction.instrument == instrument, 'name'].values
        # Holdings with no row today, or a missing (NaN) name, are left alone.
        if len(names) == 0 or not isinstance(names[0], str):
            continue
        instrument_name = names[0]
        if (('ST' in instrument_name or '退' in instrument_name)
                and data.can_trade(context.symbol(instrument))):
            context.order_target(context.symbol(instrument), 0)
            st_stock_list.append(instrument)
            cash_for_sell -= positions[instrument]
    if st_stock_list:
        print(today, '持仓出现st股/退市股', st_stock_list, '进行卖出处理')
        stock_sold += st_stock_list

    # ---- 3. Rotation sells: drop the worst-ranked holdings -----------------
    if not is_staging and cash_for_sell > 0:
        held_by_rank = [x for x in ranker_prediction.instrument if x in positions]
        for instrument in reversed(held_by_rank):  # worst-ranked first
            if cash_for_sell <= 0:
                break
            if instrument in stock_sold:
                continue
            context.order_target(context.symbol(instrument), 0)
            cash_for_sell -= positions[instrument]
            stock_sold.append(instrument)

    # ---- 4. Buys: top-ranked names, excluding today's sells and ST names ----
    # na=False: a NaN name (left-join miss) counts as non-ST instead of
    # breaking the boolean mask.
    is_st = ranker_prediction.name.str.contains('ST|退', regex=True, na=False)
    banned = set(stock_sold) | set(ranker_prediction.instrument[is_st])
    buy_cash_weights = context.stock_weights
    buy_instruments = [k for k in ranker_prediction.instrument
                       if k not in banned][:len(buy_cash_weights)]
    max_cash_per_instrument = context.portfolio.portfolio_value * context.max_cash_per_instrument
    for weight, instrument in zip(buy_cash_weights, buy_instruments):
        # Never let one stock exceed its capital cap, counting what is already held.
        cash = min(cash_for_buy * weight, max_cash_per_instrument - positions.get(instrument, 0))
        if cash <= 0:
            continue
        asset = context.symbol(instrument)
        current_price = data.current(asset, 'price')
        if not current_price > 0:  # NaN or non-positive price: cannot size the order
            continue
        # Round down to whole 100-share lots. Integer division avoids float-modulo drift.
        amount = math.floor(cash / current_price / 100) * 100
        if amount > 0:
            context.order(asset, amount)
# Backtest engine: data preparation, runs once before the backtest.
def m29_prepare_bigquant_run(context):
    """Build the CSI 300 timing signal used by handle_data's risk control.

    Trains a small LSTM on the ``train_length`` trading days just before
    ``context.start_date``. For every backtest day it then predicts whether
    the index's forward return is positive. The forward return is
    ``close[t + 5] / open[t + 1] - 1``. Results go to
    ``context.dataprediction``, with columns ``date`` and ``direction``
    (+1 / -1, or NaN if the model emitted NaN).

    Training labels are restricted to rows whose 5-bar-ahead close is still on
    or before ``split_date``, so no backtest-period prices leak into training.

    Also loads ``context.status_df``: ST and price-limit status for
    ``context.instruments`` over the backtest window.

    Raises:
        ValueError: if there are fewer than ``train_length`` trading days in the
            look-back window.
    """
    context.status_df = D.features(instruments=context.instruments,
                                   start_date=context.start_date, end_date=context.end_date,
                                   fields=['st_status_0', 'price_limit_status_0', 'price_limit_status_1'])

    from sklearn.preprocessing import scale
    from tensorflow.keras.layers import Input, Dense, LSTM
    from tensorflow.keras.models import Model

    seq_len = 5                  # length of each LSTM input window, in bars
    horizon = 5                  # label looks this many bars ahead (close[t+5] / open[t+1])
    instrument = '000300.SHA'    # CSI 300 index
    train_length = seq_len * 10  # training window, in trading days
    fields = ['close', 'open', 'high', 'low', 'amount', 'volume']
    batch = 100                  # samples per gradient step

    # ---- Locate the training window ----------------------------------------
    # 2 * train_length calendar days of look-back is expected to contain at
    # least train_length trading days.
    start_date_temp = (pd.to_datetime(context.start_date)
                       - datetime.timedelta(days=2 * train_length)).strftime('%Y-%m-%d')
    trade_day = D.trading_days(start_date=start_date_temp, end_date=context.end_date)
    n_backtest_days = len(D.trading_days(start_date=context.start_date, end_date=context.end_date))
    distance = len(trade_day) - n_backtest_days  # trading days strictly before start_date
    if distance < train_length:
        raise ValueError('not enough trading days before start_date for the LSTM training window')
    start_date = trade_day.iloc[distance - train_length, 0].strftime('%Y-%m-%d')
    split_date = trade_day.iloc[distance - 1, 0].strftime('%Y-%m-%d')  # last pre-backtest day
    print(start_date, split_date)

    # ---- Load and label the index history -----------------------------------
    data1 = D.history_data(instrument, start_date, context.end_date, fields)
    # Forward return: buy at tomorrow's open, sell at the close `horizon` bars ahead.
    data1['return'] = data1['close'].shift(-horizon) / data1['open'].shift(-1) - 1
    data1 = data1[data1.amount > 0].copy()
    datatime = data1['date'][data1.date > split_date]  # backtest dates to predict for
    data1['return'] = data1['return'] * 10  # widen the target range to help LSTM training
    data1.reset_index(drop=True, inplace=True)

    # ---- Build (seq_len x len(fields)) windows, standardized per window -----
    scaledata = data1[fields]
    returns = data1['return'].to_numpy()
    n_train = int((data1.date <= split_date).sum())  # rows are date-ordered
    # The last usable training row is n_train - 1 - horizon: its label's close
    # is the last pre-backtest bar.
    train_rows = range(seq_len - 1, n_train - horizon)
    train_x = np.array([scale(scaledata.iloc[i + 1 - seq_len:i + 1]) for i in train_rows])
    train_y = returns[seq_len - 1:n_train - horizon]
    test_x = np.array([scale(scaledata.iloc[j + 1 - seq_len:j + 1])
                       for j in range(n_train, len(data1))])

    # ---- Model: 1 LSTM layer + 3 Dense layers --------------------------------
    lstm_input = Input(shape=(seq_len, len(fields)), name='lstm_input')
    lstm_output = LSTM(32)(lstm_input)
    dense_1 = Dense(16, activation='linear')(lstm_output)
    dense_2 = Dense(4, activation='linear')(dense_1)
    output = Dense(1)(dense_2)
    model = Model(inputs=lstm_input, outputs=output)
    model.compile(optimizer='adam', loss='mse', metrics=['mse'])
    model.fit(train_x, train_y, batch_size=batch, epochs=5, verbose=0)

    # ---- Predict and map to a direction signal ------------------------------
    raw = model.predict(test_x).reshape(-1)
    # +1 for a positive forecast, -1 otherwise. NaN forecasts stay NaN, so
    # they never trigger the risk control.
    direction = np.where(raw > 0, 1.0, np.where(raw <= 0, -1.0, np.nan))
    context.dataprediction = pd.DataFrame({'date': datatime.values, 'direction': direction})
def m29_before_trading_start_bigquant_run(context, data):
    """Cancel open sell orders for stocks whose price-limit status today is limit-up.

    Status comes from ``context.status_df``, which ``prepare`` loads with the
    ``price_limit_status_0`` field. A value greater than 2 is treated as
    limit-up (3 denotes limit-up). Instruments with no status row, or a NaN
    status, keep their orders.

    NOTE(review): this reads today's final limit status before the session.
    In a backtest that is same-day information, intended to emulate
    cancelling a sell when the stock is sealed at limit-up near the close.
    """
    today = data.current_dt.strftime('%Y-%m-%d')
    # Materialize first so cancelling does not mutate what is being iterated.
    sell_orders = [o for orders in get_open_orders().values() for o in orders if o.amount < 0]
    if not sell_orders:
        return
    status_df = context.status_df
    today_status = status_df[status_df.date == today]
    limit_status = dict(zip(today_status.instrument, today_status.price_limit_status_0))
    for _order in sell_orders:
        ins = str(_order.sid.symbol)
        if limit_status.get(ins, 0) > 2:
            cancel_order(_order)
            print(today, '尾盘涨停取消卖单', ins)
# Index-level feature evaluated on CSI 300 by m24:
# bm_0 = 1 when MACD DIF (params 2, 4, 4) is below its DEA, else 0.
# NOTE(review): bm_0 is joined into the ranking frame (m28 -> m30), but handle_data never reads it.
m23 = M.input_features.v1(
    features="""
# #号开始的表示注释,注释需单独一行
# 多个特征,每行一个,可以包含基础特征和衍生特征,特征须为本平台特征
#bm_0 = where(close/shift(close,5)-1<-0.05,1,0)
# 如果macd中的dif下穿macd中的dea,则bm_0等于1,否则等于0
bm_0=where(ta_macd_dif(close,2,4,4)-ta_macd_dea(close,2,4,4)<0,1,0)"""
)
# Security name field, loaded by m27 and used to detect ST / delisting (退) names.
m26 = M.input_features.v1(
    features="""
# #号开始的表示注释,注释需单独一行
# 多个特征,每行一个,可以包含基础特征和衍生特征,特征须为本平台特征
name"""
)
# Training universe: all CN A-shares, 2010-01-01 .. 2017-12-31.
m32 = M.instruments.v2(
    start_date='2010-01-01',
    end_date='2017-12-31',
    market='CN_STOCK_A',
    instrument_list='',
    max_count=0
)
# Training label, in order:
#   1. forward return = close 5 bars ahead / next day's open;
#   2. clip at the 1% / 99% quantiles;
#   3. bucket into 20 classes;
#   4. set the label to NaN when the next day is a one-price bar (high == low),
#      dropped via drop_na_label=True.
m3 = M.advanced_auto_labeler.v2(
    instruments=m32.data,
    label_expr="""# #号开始的表示注释
# 0. 每行一个,顺序执行,从第二个开始,可以使用label字段
# 1. 可用数据字段见 https://bigquant.com/docs/develop/datasource/deprecated/history_data.html
# 添加benchmark_前缀,可使用对应的benchmark数据
# 2. 可用操作符和函数见 `表达式引擎 <https://bigquant.com/docs/develop/bigexpr/usage.html>`_
# 计算收益:5日收盘价(作为卖出价格)除以明日开盘价(作为买入价格)
shift(close, -5) / shift(open, -1)
# 极值处理:用1%和99%分位的值做clip
clip(label, all_quantile(label, 0.01), all_quantile(label, 0.99))
# 将分数映射到分类,这里使用20个分类
all_wbins(label, 20)
# 过滤掉一字涨停的情况 (设置label为NaN,在后续处理和训练中会忽略NaN的label)
where(shift(high, -1) == shift(low, -1), NaN, label)
""",
    start_date='',
    end_date='',
    benchmark='000300.SHA',
    drop_na_label=True,
    cast_label_int=True,
    user_functions={}
)
# Standardize the integer class label; presumably TabNet then fits it as a regression target.
m17 = M.standardlize.v8(
    input_1=m3.data,
    columns_input='label'
)
# Prediction / backtest universe: all CN A-shares, 2018-01-01 .. 2021-12-31.
m1 = M.instruments.v2(
    start_date='2018-01-01',
    end_date='2021-12-31',
    market='CN_STOCK_A',
    instrument_list='',
    max_count=0
)
# Security names (the m26 `name` field) for the backtest universe.
m27 = M.use_datasource.v1(
    instruments=m1.data,
    features=m26.data,
    datasource_id='instruments_CN_STOCK_A',
    start_date='',
    end_date=''
)
# Evaluate bm_0 (m23) on the CSI 300 index 000300.HIX.
# before_days=100 presumably provides warm-up history for the MACD.
m24 = M.index_feature_extract.v3(
    input_1=m1.data,
    input_2=m23.data,
    before_days=100,
    index='000300.HIX'
)
# Keep only the date and bm_0 columns of the index features.
m25 = M.select_columns.v3(
    input_ds=m24.data_1,
    columns='date,bm_0',
    reverse_select=False
)
# Attach the per-date index signal bm_0 to every (date, instrument) name row.
m28 = M.join.v3(
    data1=m27.data,
    data2=m25.data,
    on='date',
    how='left',
    sort=True
)
# The 16 model input features. m18 sets input_dim=16, which presumably must equal this count:
#   - 5 daily highs and 5 daily lows;
#   - 5-day average turnover;
#   - 5-day average amplitude;
#   - PE (LYR);
#   - 5/10/20-day net active buy amount.
m5 = M.input_features.v1(
    features="""high_0
high_1
high_2
high_3
high_4
low_0
low_1
low_2
low_3
low_4
# 5日平均换手率
avg_turn_5
# 5日平均振幅
(high_0-low_0+high_1-low_1+high_2-low_2+high_3-low_3+high_4-low_4)/5
# 市盈率LYR
pe_lyr_0
# 5日净主动买入额
mf_net_amount_5
# 10日净主动买入额
mf_net_amount_10
# 20日净主动买入额
mf_net_amount_20
"""
)
# Extraction feature list: presumably the m5 features plus the fields below.
# st_status_0 is required later by the non-ST filters (m21 / m22).
# high_1 and low_0 already appear in m5.
m33 = M.input_features.v1(
    features_ds=m5.data,
    features="""
# #号开始的表示注释,注释需单独一行
# 多个特征,每行一个,可以包含基础特征和衍生特征,特征须为本平台特征
close_0
high_1
open_0
low_0
st_status_0
"""
)
# ---- Feature extraction: training period (m32 universe) ----
# Base features from m33. The 10 extra days before the start presumably give
# lagged fields (e.g. high_4) their history.
m6 = M.general_feature_extractor.v7(
    instruments=m32.data,
    features=m33.data,
    start_date='',
    end_date='',
    before_start_days=10
)
# Evaluate derived expressions (e.g. the 5-day amplitude). drop_na=True drops incomplete rows.
m7 = M.derived_feature_extractor.v3(
    input_data=m6.data,
    features=m33.data,
    date_col='date',
    instrument_col='instrument',
    drop_na=True,
    remove_extra_columns=False
)
# ---- Feature extraction: backtest period (m1 universe), same settings ----
m8 = M.general_feature_extractor.v7(
    instruments=m1.data,
    features=m33.data,
    start_date='',
    end_date='',
    before_start_days=10
)
m9 = M.derived_feature_extractor.v3(
    input_data=m8.data,
    features=m33.data,
    date_col='date',
    instrument_col='instrument',
    drop_na=True,
    remove_extra_columns=False
)
# Standardize the backtest features. The m5 feature list is passed as
# input_2, exactly as m13 does for training, so the model sees identically
# preprocessed inputs at prediction time.
m16 = M.standardlize.v8(
    input_1=m9.data,
    input_2=m5.data,
    columns_input='[]'
)
# Standardize the training features listed in m5.
m13 = M.standardlize.v8(
    input_1=m7.data,
    input_2=m5.data,
    columns_input='[]'
)
# Training: fill remaining NaNs in the m5 features with 0.
m14 = M.fillnan.v1(
    input_data=m13.data,
    features=m5.data,
    fill_value='0.0'
)
# Training: pair each (date, instrument) label with its features (inner join).
m4 = M.join.v3(
    data1=m17.data,
    data2=m14.data,
    on='date,instrument',
    how='inner',
    sort=False
)
# Training: keep non-ST rows only.
m21 = M.filter.v3(
    input_data=m4.data,
    expr='st_status_0==0 ',
    output_left_data=False
)
# Backtest: fill NaNs with 0, then keep non-ST rows only.
m15 = M.fillnan.v1(
    input_data=m16.data,
    features=m5.data,
    fill_value='0.0'
)
m22 = M.filter.v3(
    input_data=m15.data,
    expr='st_status_0==0 ',
    output_left_data=False
)
# Convert to model input: one row per sample (window_size=1, flattened).
# feature_clip=3 presumably clips standardized values to +/-3.
m11 = M.dl_convert_to_bin.v2(  # backtest-period inputs
    input_data=m22.data,
    features=m5.data,
    window_size=1,
    feature_clip=3,
    flatten=True,
    window_along_col='instrument'
)
m10 = M.dl_convert_to_bin.v2(  # training-period inputs
    input_data=m21.data,
    features=m5.data,
    window_size=1,
    feature_clip=3,
    flatten=True,
    window_along_col='instrument'
)
# Randomly split the training bins (m10) into train / validation sets via m12_run_bigquant_run.
m12 = M.cached.v3(
    input_1=m10.data,
    run=m12_run_bigquant_run,
    post_run=m12_post_run_bigquant_run,
    input_ports='',
    params='{}',
    output_ports=''
)
# Train TabNet on the 16 m5 features.
# Parameter names follow pytorch-tabnet (presumably same meaning):
#   n_d / n_a = decision / attention widths; n_steps = decision steps;
#   gamma = feature-reuse relaxation; virtual_batch_size = ghost batch-norm size.
# Only a single epoch, on CPU.
m18 = M.dl_models_tabnet_train.v1(
    training_data=m12.data_1,
    validation_data=m12.data_2,
    input_dim=16,
    n_steps=3,
    n_d=16,
    n_a=16,
    gamma=1.3,
    momentum=0.02,
    batch_size=5120,
    virtual_batch_size=512,
    epochs=1,
    num_workers=4,
    device_name='cpu:使用cpu训练',
    verbose='1:输出进度条记录'
)
# Score the backtest-period inputs (m11).
m19 = M.dl_models_tabnet_predict.v1(
    trained_model=m18.data,
    input_data=m11.data,
    m_cached=False
)
# Attach instrument/date to each score and sort each day best-first (m20_run_bigquant_run).
m20 = M.cached.v3(
    input_1=m19.data,
    input_2=m22.data,
    run=m20_run_bigquant_run,
    post_run=m20_post_run_bigquant_run,
    input_ports='',
    params='{}',
    output_ports=''
)
# Left-join security names and the index bm_0 signal (m28) onto the ranking.
# The backtest reads this frame as context.options['data'].
m30 = M.join.v3(
    data1=m20.data_1,
    data2=m28.data,
    on='date,instrument',
    how='left',
    sort=False
)
# Run the backtest on the 2018-2021 universe (m1), using the ranked predictions
# (m30) and the m29_* callbacks. Buys are priced at the open and sells at the
# close, on back-adjusted (后复权) prices. volume_limit=0.025 presumably caps
# each fill at 2.5% of bar volume; confirm in the M.trade docs.
m29 = M.trade.v4(
    instruments=m1.data,
    options_data=m30.data,
    start_date='',
    end_date='',
    initialize=m29_initialize_bigquant_run,
    handle_data=m29_handle_data_bigquant_run,
    prepare=m29_prepare_bigquant_run,
    before_trading_start=m29_before_trading_start_bigquant_run,
    volume_limit=0.025,
    order_price_field_buy='open',
    order_price_field_sell='close',
    capital_base=1000001,
    auto_cancel_non_tradable_orders=True,
    data_frequency='daily',
    price_type='后复权',
    product_type='股票',
    plot_charts=True,
    backtest_only=False,
    benchmark=''
)
# Export the ranked TabNet predictions (m20 output: pred_label, instrument, date) to CSV.
predict_df = m20.data_1.read()
predict_df.head()  # result is discarded here (not the cell's last expression), so this has no effect
predict_df.to_csv("tabnet_predict.csv")