复制链接
克隆策略

    {"description":"实验创建于2017/8/26","graph":{"edges":[{"to_node_id":"-635:input_ds","from_node_id":"-179:data_1"},{"to_node_id":"-670:predictions","from_node_id":"-635:sorted_data"}],"nodes":[{"node_id":"-670","module_id":"BigQuantSpace.metrics_classification.metrics_classification-v1","parameters":[],"input_ports":[{"name":"predictions","node_id":"-670"}],"output_ports":[{"name":"data","node_id":"-670"}],"cacheable":true,"seq_num":12,"comment":"","comment_collapsed":true},{"node_id":"-179","module_id":"BigQuantSpace.cached.cached-v3","parameters":[{"name":"run","value":"import random\nimport datetime\nimport math\nimport pandas as pd\n\nclass XDebug:\n @staticmethod\n def Check(bFlag, strOutput = None):\n if bFlag:\n return\n if strOutput is not None:\n raise Exception(strOutput)\n else:\n raise\n\nclass XNum:\n @staticmethod\n def IsNA(x):\n if x is None:\n return(True)\n if pd.isna(x):\n return(True)\n return(False)\n\n @staticmethod\n def IsNum(x):\n if XNum.IsNA(x):\n return False\n\n if type(x) == type(1.0) or type(x) == type(1):\n return(True)\n\n\n return(False)\n\n @staticmethod\n def Test():\n XDebug.Check( XNum.IsNum(100))\n\n\n\n\nclass XStable:\n @staticmethod\n def WOE(df, col, y='xxx'):\n \"\"\"\n df:数据集\n col:特征名\n y:样本定义根据的列名(1:黑样本,0:白样本)\n \"\"\"\n # 黑样本\n black_cnt = df.groupby(col)[y].sum()\n # 白样本\n white_cnt = df.groupby(col)[y].sum()\n\n # 所有黑样本\n black_cnt_total = df[y].sum()\n # 所有白样本\n white_cnt_total = df.shape(0) - df[y].sum()\n\n # pyi\n pyi = black_cnt / black_cnt_total\n # pni\n pni = white_cnt / white_cnt_total\n\n # woe\n woe = (pyi / pni).map(lambda x: math.log(x))\n\n return woe, pyi, pni\n\n @staticmethod\n # 逻辑代码\n def IV(df, col, y='xxx'):\n \"\"\"\n df:数据集\n col:特征名\n y:样本定义根据的列名(1:黑样本,0:白样本)\n \"\"\"\n # 获取woe、pyi、pni\n woe, pyi, pni = WOE(df, col, y)\n # 计算特征每个分箱的iv值\n iv_i = (pyi - pni) * woe\n # 返回该特征的iv值\n return iv_i.sum()\n\n @staticmethod\n def PSI(arrReal, arrExpect)->float:\n ##psi = sum((实际占比-预期占比)* ln(实际占比/预期占比))\n arrSplit = XList.CalcBinsByFreq(arrExpect, nBins = 10)\n arrBins = XList.Value2Bins(arrExpect,arrSplit )\n arrBinsReal = XList.Value2Bins(arrReal, arrSplit)\n\n mFreq = XList.BinsFreq(arrBins)\n mFreqReal = XList.BinsFreq(arrBinsReal)\n\n s = 0\n for x in mFreq.keys():\n freqReal = mFreqReal.get(x, 0)\n freq = mFreq.get(x, 0 )\n if freq > 0.0001:\n part0 = freqReal / freq\n part1 = (freqReal - freq)\n if part0 > 0:\n part2= math.log( part0 , math.e)\n v = part1 * part2\n assert(v >= 0)\n s+= v\n else:\n s+= 0.2\n\n return(s)\n\n @staticmethod\n def Test(n=1000):\n arr = [None] * n\n arr = list(map(lambda x:random.random(), arr))\n\n arr2 = [None] * n\n arr2 = list(map(lambda x: random.random(), arr2))\n\n print(XStable.PSI(arr, arr2))\n\n\nclass XList:\n @staticmethod\n def IsNumber(series):\n arr = list(map(lambda x: XNum.IsNum(x) or XNum.IsNA(x), series))\n arrTrue = XLambda.FilterTrueIndex(arr)\n if len(arrTrue) == len(series):\n return (True)\n return False\n\n @staticmethod\n def CalcBinsByFreq(arr, nBins=10, bNeedSort=True, ascending=True):\n \"\"\"\n 按频次切 bin\n :param arr:\n :param nBins:\n :param bNeedSort:\n :param ascending:\n :return:\n \"\"\"\n\n nBins += 1\n\n arr2 = list(arr)\n if bNeedSort:\n arr2 = sorted(arr2, reverse=not ascending)\n\n N = len(arr2)\n step = math.floor(N / nBins)\n\n currPos = 0\n\n ret = []\n\n for i in range(1, nBins):\n currPos = i * step\n\n value = arr2[currPos]\n\n if ascending:\n if len(ret) == 0 or value > ret[-1]:\n ret.append(value)\n else:\n if len(ret) == 0 or value < ret[-1]:\n ret.append(value)\n return (ret)\n\n @staticmethod\n def BinsFreq(arr):\n m = {}\n for x in arr:\n m[x] = m.get(x, 0) + 1\n\n for x in m.keys():\n m[x] = m[x] / len(arr)\n return m\n\n @staticmethod\n def Value2Bins(arr, arrSplit , bNeedSort=True, ascending=True):\n \"\"\"\n 按频次切 bin, 并把值映射到bin上\n :param arr:\n :param nBins:\n :param bNeedSort:\n :param ascending:\n :return:\n \"\"\"\n\n ##arrSplit = XList.CalcBinsByFreq(arr, nBins, bNeedSort=bNeedSort, ascending=ascending)\n\n ret = []\n N = len(arrSplit)\n for x in arr:\n i = 0\n if ascending:\n while i < N:\n if x > arrSplit[i]:\n i += 1\n else:\n break\n else:\n while i < N:\n if x < arrSplit[i]:\n i += 1\n else:\n break\n\n ret.append(i)\n return ret\n\n @staticmethod\n def Test():\n a = [None] * 1000\n a = list(map(lambda x: random.randint(100,1000), a))\n arrSplit = XList.CalcBinsByFreq(a)\n print(arrSplit)\n c = XList.Value2Bins(a, arrSplit)\n print(c)\n\nclass XPickle:\n @staticmethod\n def WriteCsv(df, fileName):\n path = \"/home/bigquant/work/userlib/\"\n ##pd.DataFrame([input_1.read()]).to_pickle(path + fileName)\n df.to_pickle(path + fileName)\n\n @staticmethod\n def ReadCsv(fileName):\n path = \"/home/bigquant/work/userlib/\"\n df = pd.read_pickle(path + fileName) ##(\"/home/bigquant/work/userlib/data_more_feature/2018-01-01.csv\")\n return df\n\n\nclass XRandom:\n @staticmethod\n def SampleRows(df, nRowSample):\n nRow = df.shape[0]\n arrSample = random.sample(range(0, nRow), nRowSample)\n arrRemain = list(set(range(0, nRow)) - set(arrSample))\n return df.iloc[arrSample, :], df.iloc[arrRemain, :]\n\n\nclass XTime:\n @staticmethod\n def AddMonths(currDay, nMonth):\n date = datetime.datetime.strptime(currDay, \"%Y-%m-%d\")\n print(date)\n date2 = date + datetime.timedelta(days=nMonth * 31)\n ret = datetime.datetime.strftime(date2, \"%Y-%m-01\")\n return ret\n\n @staticmethod\n def ToDate(strDate1):\n date1 = datetime.datetime.strptime(strDate1, \"%Y-%m-%d\")\n return(date1)\n\n @staticmethod\n def PDTime2Str(date1, formatStr = \"%Y-%m-%d\"):\n return date1.strftime(formatStr)\n\n @staticmethod\n def DateSmallerThan(strDate1, strDate2):\n \"\"\"\n return (date1 < date2)\n\n :param strDate1:\n :param strDate2:\n :return:\n \"\"\"\n date1 = XTime.ToDate(strDate1)\n date2 = XTime.ToDate(strDate2)\n return (date1 < date2)\n\n @staticmethod\n def Test():\n XDebug.Check(XTime.DateSmallerThan(\"2018-01-01\", \"2018-01-02\"))\n\n\n\n\nclass XPandas:\n @staticmethod\n def FilterColumns(df, arrColNames):\n colNames = XPandas.GetColumnNames(df)\n arr = []\n nFound = 0\n for i in range(0, len(colNames)):\n if colNames[i] in arrColNames:\n nFound += 1\n arr.append(i)\n XDebug.Check(nFound == len(arrColNames))\n return df.iloc[:, arr].copy()\n\n @staticmethod\n def SortByColumn(df, byColNames = [], inplace = True, ascending = True):\n if inplace:\n df = df.sort_values(by = byColNames, inplace = inplace, ascending = ascending)\n return(df)\n else:\n df2 = df.sort_values(by=byColNames, inplace=inplace, ascending=ascending)\n return(df2)\n\n @staticmethod\n def SampleRows(df, nRow):\n N = df.shape[0]\n sampleArr = random.sample(list(range(0,N)) , nRow)\n remainArr = list(set(list(range(0,N))) - set(sampleArr))\n return df.iloc[sampleArr, :], df.iloc[remainArr, :]\n\n @staticmethod\n def GetColumnNames(df):\n return df.columns\n\n @staticmethod\n def MyJoin(data1, data2, onKey=[\"instrument\", \"date\"], how=\"left\"):\n name1 = list(data1.columns)\n name2 = list(data2.columns)\n\n nameDiff = list(set(name2) - set(name1))\n\n ##print(nameDiff)\n\n nameDiff.append(\"instrument\")\n nameDiff.append(\"date\")\n\n data3 = data2[nameDiff]\n\n ##data4 = data1.join(data3, on =[\"instrument\", \"date\"] ,how = \"left\") ## join\n data4 = pd.merge(data1, data3, on=onKey, how=how)\n\n return (data4)\n\n\nclass XHash:\n @staticmethod\n def MyHash(s: str) -> int:\n return int(hash(s))\n\n\nclass XLambda:\n @staticmethod\n def FilterTrueIndex(arr: list) -> list:\n\n ret = []\n for i in range(0, len(arr)):\n if arr[i] == True:\n ret.append(i)\n return (ret)\n\n\nclass XSample:\n @staticmethod\n def Sample10Pct(data1):\n arr1 = zip(list(data1[\"instrument\"]), list(data1[\"date\"]))\n arr2 = list(map(lambda x: True if XUtil.MyHash(str(x)) % 10 == 1 else False, arr1))\n arr3 = XLambda.FilterTrueIndex(arr2)\n data2 = data1.iloc[arr3, :]\n return (data2)\n\n\n\n# Python 代码入口函数,input_1/2/3 对应三个输入端,data_1/2/3 对应三个输出端\ndef bigquant_run(input_1, input_2, input_3):\n \n \n pd3 = XPickle.ReadCsv(\"error_check2.csv\")\n\n ds3 = DataSource.write_pickle(pd3)\n print(\"end222..\")\n return Outputs(data_1 = ds3 , data_2 = None, data_3 = None)\n\n","type":"Literal","bound_global_parameter":null},{"name":"post_run","value":"# 后处理函数,可选。输入是主函数的输出,可以在这里对数据做处理,或者返回更友好的outputs数据格式。此函数输出不会被缓存。\ndef bigquant_run(outputs):\n return outputs\n","type":"Literal","bound_global_parameter":null},{"name":"input_ports","value":"","type":"Literal","bound_global_parameter":null},{"name":"params","value":"{}","type":"Literal","bound_global_parameter":null},{"name":"output_ports","value":"","type":"Literal","bound_global_parameter":null}],"input_ports":[{"name":"input_1","node_id":"-179"},{"name":"input_2","node_id":"-179"},{"name":"input_3","node_id":"-179"}],"output_ports":[{"name":"data_1","node_id":"-179"},{"name":"data_2","node_id":"-179"},{"name":"data_3","node_id":"-179"}],"cacheable":true,"seq_num":21,"comment":"","comment_collapsed":true},{"node_id":"-635","module_id":"BigQuantSpace.sort.sort-v4","parameters":[{"name":"sort_by","value":"pred_label","type":"Literal","bound_global_parameter":null},{"name":"group_by","value":"date","type":"Literal","bound_global_parameter":null},{"name":"keep_columns","value":"--","type":"Literal","bound_global_parameter":null},{"name":"ascending","value":"False","type":"Literal","bound_global_parameter":null}],"input_ports":[{"name":"input_ds","node_id":"-635"},{"name":"sort_by_ds","node_id":"-635"}],"output_ports":[{"name":"sorted_data","node_id":"-635"}],"cacheable":true,"seq_num":22,"comment":"","comment_collapsed":true}],"node_layout":"<node_postions><node_position Node='-670' Position='1401.8892822265625,1466.3917236328125,200,200'/><node_position Node='-179' Position='1105.8525390625,1112.7852783203125,200,200'/><node_position Node='-635' Position='1200.6519165039062,1308.4397583007812,200,200'/></node_postions>"},"nodes_readonly":false,"studio_version":"v2"}
    In [53]:
    # 本代码由可视化策略环境自动生成 2023年1月13日 19:58
    # 本代码单元只能在可视化模式下编辑。您也可以拷贝代码,粘贴到新建的代码单元或者策略,然后修改。
    
    
    import random
    import datetime
    import math
    import pandas as pd
    
    class XDebug:
        @staticmethod
        def Check(bFlag, strOutput = None):
            if bFlag:
                return
            if strOutput is not None:
                raise Exception(strOutput)
            else:
                raise
    
    class XNum:
        @staticmethod
        def IsNA(x):
            if x  is None:
                return(True)
            if pd.isna(x):
                return(True)
            return(False)
    
        @staticmethod
        def IsNum(x):
            if XNum.IsNA(x):
                return False
    
            if type(x) == type(1.0) or type(x) == type(1):
                return(True)
    
    
            return(False)
    
        @staticmethod
        def Test():
            XDebug.Check( XNum.IsNum(100))
    
    
    
    
    class XStable:
        @staticmethod
        def WOE(df, col, y='xxx'):
            """
            df:数据集
            col:特征名
            y:样本定义根据的列名(1:黑样本,0:白样本)
            """
            # 黑样本
            black_cnt = df.groupby(col)[y].sum()
            # 白样本
            white_cnt = df.groupby(col)[y].sum()
    
            # 所有黑样本
            black_cnt_total = df[y].sum()
            # 所有白样本
            white_cnt_total = df.shape(0) - df[y].sum()
    
            # pyi
            pyi = black_cnt / black_cnt_total
            # pni
            pni = white_cnt / white_cnt_total
    
            # woe
            woe = (pyi / pni).map(lambda x: math.log(x))
    
            return woe, pyi, pni
    
        @staticmethod
        # 逻辑代码
        def IV(df, col, y='xxx'):
            """
            df:数据集
            col:特征名
            y:样本定义根据的列名(1:黑样本,0:白样本)
            """
            # 获取woe、pyi、pni
            woe, pyi, pni = WOE(df, col, y)
            # 计算特征每个分箱的iv值
            iv_i = (pyi - pni) * woe
            # 返回该特征的iv值
            return iv_i.sum()
    
        @staticmethod
        def PSI(arrReal, arrExpect)->float:
            ##psi = sum((实际占比-预期占比)* ln(实际占比/预期占比))
            arrSplit = XList.CalcBinsByFreq(arrExpect, nBins = 10)
            arrBins = XList.Value2Bins(arrExpect,arrSplit )
            arrBinsReal = XList.Value2Bins(arrReal, arrSplit)
    
            mFreq = XList.BinsFreq(arrBins)
            mFreqReal = XList.BinsFreq(arrBinsReal)
    
            s = 0
            for x in mFreq.keys():
                freqReal = mFreqReal.get(x, 0)
                freq = mFreq.get(x, 0 )
                if freq > 0.0001:
                    part0 = freqReal / freq
                    part1 = (freqReal - freq)
                    if part0 > 0:
                        part2= math.log( part0 , math.e)
                        v = part1 * part2
                        assert(v >= 0)
                        s+= v
                else:
                    s+= 0.2
    
            return(s)
    
        @staticmethod
        def Test(n=1000):
            arr = [None] * n
            arr = list(map(lambda x:random.random(), arr))
    
            arr2 = [None] * n
            arr2 = list(map(lambda x: random.random(), arr2))
    
            print(XStable.PSI(arr, arr2))
    
    
    class XList:
        @staticmethod
        def IsNumber(series):
            arr = list(map(lambda x: XNum.IsNum(x) or XNum.IsNA(x), series))
            arrTrue = XLambda.FilterTrueIndex(arr)
            if len(arrTrue) == len(series):
                return (True)
            return False
    
        @staticmethod
        def CalcBinsByFreq(arr, nBins=10, bNeedSort=True, ascending=True):
            """
            按频次切 bin
            :param arr:
            :param nBins:
            :param bNeedSort:
            :param ascending:
            :return:
            """
    
            nBins += 1
    
            arr2 = list(arr)
            if bNeedSort:
                arr2 = sorted(arr2, reverse=not ascending)
    
            N = len(arr2)
            step = math.floor(N / nBins)
    
            currPos = 0
    
            ret = []
    
            for i in range(1, nBins):
                currPos = i * step
    
                value = arr2[currPos]
    
                if ascending:
                    if len(ret) == 0 or value > ret[-1]:
                        ret.append(value)
                else:
                    if len(ret) == 0 or value < ret[-1]:
                        ret.append(value)
            return (ret)
    
        @staticmethod
        def BinsFreq(arr):
            m = {}
            for x in arr:
                m[x] = m.get(x, 0) + 1
    
            for x in m.keys():
                m[x] = m[x] / len(arr)
            return m
    
        @staticmethod
        def Value2Bins(arr, arrSplit ,  bNeedSort=True, ascending=True):
            """
            按频次切 bin, 并把值映射到bin上
            :param arr:
            :param nBins:
            :param bNeedSort:
            :param ascending:
            :return:
            """
    
            ##arrSplit = XList.CalcBinsByFreq(arr, nBins, bNeedSort=bNeedSort, ascending=ascending)
    
            ret = []
            N = len(arrSplit)
            for x in arr:
                i = 0
                if ascending:
                    while i < N:
                        if x > arrSplit[i]:
                            i += 1
                        else:
                            break
                else:
                    while i < N:
                        if x < arrSplit[i]:
                            i += 1
                        else:
                            break
    
                ret.append(i)
            return ret
    
        @staticmethod
        def Test():
            a = [None] * 1000
            a = list(map(lambda x: random.randint(100,1000), a))
            arrSplit = XList.CalcBinsByFreq(a)
            print(arrSplit)
            c = XList.Value2Bins(a, arrSplit)
            print(c)
    
    class XPickle:
        @staticmethod
        def WriteCsv(df, fileName):
            path = "/home/bigquant/work/userlib/"
            ##pd.DataFrame([input_1.read()]).to_pickle(path + fileName)
            df.to_pickle(path + fileName)
    
        @staticmethod
        def ReadCsv(fileName):
            path = "/home/bigquant/work/userlib/"
            df = pd.read_pickle(path + fileName)  ##("/home/bigquant/work/userlib/data_more_feature/2018-01-01.csv")
            return df
    
    
    class XRandom:
        @staticmethod
        def SampleRows(df, nRowSample):
            nRow = df.shape[0]
            arrSample = random.sample(range(0, nRow), nRowSample)
            arrRemain = list(set(range(0, nRow)) - set(arrSample))
            return df.iloc[arrSample, :], df.iloc[arrRemain, :]
    
    
    class XTime:
        @staticmethod
        def AddMonths(currDay, nMonth):
            date = datetime.datetime.strptime(currDay, "%Y-%m-%d")
            print(date)
            date2 = date + datetime.timedelta(days=nMonth * 31)
            ret = datetime.datetime.strftime(date2, "%Y-%m-01")
            return ret
    
        @staticmethod
        def ToDate(strDate1):
            date1 = datetime.datetime.strptime(strDate1, "%Y-%m-%d")
            return(date1)
    
        @staticmethod
        def PDTime2Str(date1, formatStr = "%Y-%m-%d"):
            return date1.strftime(formatStr)
    
        @staticmethod
        def DateSmallerThan(strDate1, strDate2):
            """
            return (date1 < date2)
    
            :param strDate1:
            :param strDate2:
            :return:
            """
            date1 = XTime.ToDate(strDate1)
            date2 = XTime.ToDate(strDate2)
            return (date1 < date2)
    
        @staticmethod
        def Test():
            XDebug.Check(XTime.DateSmallerThan("2018-01-01", "2018-01-02"))
    
    
    
    
    class XPandas:
        @staticmethod
        def FilterColumns(df, arrColNames):
            colNames = XPandas.GetColumnNames(df)
            arr = []
            nFound = 0
            for i in range(0, len(colNames)):
                if colNames[i] in arrColNames:
                    nFound += 1
                    arr.append(i)
            XDebug.Check(nFound == len(arrColNames))
            return df.iloc[:, arr].copy()
    
        @staticmethod
        def SortByColumn(df, byColNames = [], inplace = True, ascending = True):
            if inplace:
                df = df.sort_values(by = byColNames, inplace = inplace, ascending = ascending)
                return(df)
            else:
                df2 = df.sort_values(by=byColNames, inplace=inplace, ascending=ascending)
                return(df2)
    
        @staticmethod
        def SampleRows(df, nRow):
            N = df.shape[0]
            sampleArr = random.sample(list(range(0,N)) , nRow)
            remainArr =  list(set(list(range(0,N))) - set(sampleArr))
            return df.iloc[sampleArr, :], df.iloc[remainArr, :]
    
        @staticmethod
        def GetColumnNames(df):
            return df.columns
    
        @staticmethod
        def MyJoin(data1, data2, onKey=["instrument", "date"], how="left"):
            name1 = list(data1.columns)
            name2 = list(data2.columns)
    
            nameDiff = list(set(name2) - set(name1))
    
            ##print(nameDiff)
    
            nameDiff.append("instrument")
            nameDiff.append("date")
    
            data3 = data2[nameDiff]
    
            ##data4 = data1.join(data3, on =["instrument", "date"] ,how = "left")  ## join
            data4 = pd.merge(data1, data3, on=onKey, how=how)
    
            return (data4)
    
    
    class XHash:
        @staticmethod
        def MyHash(s: str) -> int:
            return int(hash(s))
    
    
    class XLambda:
        @staticmethod
        def FilterTrueIndex(arr: list) -> list:
    
            ret = []
            for i in range(0, len(arr)):
                if arr[i] == True:
                    ret.append(i)
            return (ret)
    
    
    class XSample:
        @staticmethod
        def Sample10Pct(data1):
            arr1 = zip(list(data1["instrument"]), list(data1["date"]))
            arr2 = list(map(lambda x: True if XUtil.MyHash(str(x)) % 10 == 1 else False, arr1))
            arr3 = XLambda.FilterTrueIndex(arr2)
            data2 = data1.iloc[arr3, :]
            return (data2)
    
    
    
    # Python 代码入口函数,input_1/2/3 对应三个输入端,data_1/2/3 对应三个输出端
    def m21_run_bigquant_run(input_1, input_2, input_3):
        
        
        pd3 = XPickle.ReadCsv("error_check2.csv")
    
        ds3 = DataSource.write_pickle(pd3)
        print("end222..")
        return Outputs(data_1 = ds3 , data_2 = None, data_3 = None)
    
    
    # 后处理函数,可选。输入是主函数的输出,可以在这里对数据做处理,或者返回更友好的outputs数据格式。此函数输出不会被缓存。
    def m21_post_run_bigquant_run(outputs):
        return outputs
    
    
    m21 = M.cached.v3(
        run=m21_run_bigquant_run,
        post_run=m21_post_run_bigquant_run,
        input_ports='',
        params='{}',
        output_ports=''
    )
    
    m22 = M.sort.v4(
        input_ds=m21.data_1,
        sort_by='pred_label',
        group_by='date',
        keep_columns='--',
        ascending=False
    )
    
    m12 = M.metrics_classification.v1(
        predictions=m22.sorted_data
    )
    
    ---------------------------------------------------------------------------
    HDF5ExtError                              Traceback (most recent call last)
    HDF5ExtError: HDF5 error back trace
    
      File "H5F.c", line 509, in H5Fopen
        unable to open file
      File "H5Fint.c", line 1400, in H5F__open
        unable to open file
      File "H5Fint.c", line 1700, in H5F_open
        unable to read superblock
      File "H5Fsuper.c", line 411, in H5F__super_read
        file signature not found
    
    End of HDF5 error back trace
    
    Unable to open/create file '/tmp/data.h5'
    
    During handling of the above exception, another exception occurred:
    
    IndexError                                Traceback (most recent call last)
    <ipython-input-53-0d8a821327f7> in <module>
        390 )
        391 
    --> 392 m22 = M.sort.v4(
        393     input_ds=m21.data_1,
        394     sort_by='pred_label',
    
    IndexError: tuple index out of range