{ "cells": [ { "cell_type": "code", "execution_count": 37, "id": "hourly-fever", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "plt.rcParams['figure.figsize'] = (16.0, 9.0)\n", "\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.metrics import mean_squared_error, make_scorer\n", "from sklearn.model_selection import GridSearchCV\n", "\n", "from sklearn.linear_model import LinearRegression, HuberRegressor, SGDRegressor\n", "from sklearn.cross_decomposition import PLSRegression\n", "from sklearn.decomposition import PCA\n", "# from sklearn.tree import DecisionTreeRegressor\n", "from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor\n", "from sklearn.base import BaseEstimator, RegressorMixin, TransformerMixin\n", "from sklearn.pipeline import Pipeline\n", "\n", "# import lightgbm as lgb\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "markdown", "id": "under-restaurant", "metadata": {}, "source": [ "- [Data](#data)\n", " - [NA 值处理](#na-值处理)\n", " - [Use rank instead of numerical values](#use-rank-instead-of-numerical-values)\n", "- [Train, Validation, Test split](#train-validation-test-split)\n", "- [Evaluation metrics](#evaluation-metrics)\n", "- [Models](#models)\n", " - [Linear regression](#linear-regression)\n", " - [Huber regressor](#huber-regressor)\n", " - [Random Forest](#random-forest)\n", " - [Partial Least Squares](#partial-least-squares)\n", " - [Principal Component Regression](#principal-component-regression)\n", " - [PCA transform](#pca-transform)\n", " - [PCA regression](#pca-regression)\n", " - [Pipeline](#pipeline)\n", " - [Elastic Net](#elastic-net)\n", " - [Gradient Boosted Regression Trees](#gradient-boosted-regression-trees)\n", " - [Neural Nets](#neural-nets)\n", " - [GridSeachCV Neural Nets](#gridseachcv-neural-nets)\n", "- [Transformation pipeline 
example](#transformation-pipeline-example)" ] }, { "cell_type": "markdown", "id": "intended-belize", "metadata": {}, "source": [ "# Data" ] }, { "cell_type": "code", "execution_count": 38, "id": "broken-matthew", "metadata": {}, "outputs": [], "source": [ "# Load the factor exposure table (the path is relative to the notebook's working directory).\n", "# NOTE(review): unpickling can execute arbitrary code; only read pickle files from a trusted source.\n", "df = pd.read_pickle('../../data/factor_exposure/all_exposure.pkl')" ] }, { "cell_type": "code", "execution_count": 39, "id": "postal-medicaid", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateretrfexretymmktcapsizerevbetabmilliqilliq_12mmom_datemomvolivolvol_clipivol_clip
0000001.XSHE2007-072007-06-290.3164970.0024810.3140162007-064.266117e+1024.476555NaN0.46140.123739NaNNaNNaTNaNNaNNaNNaNNaN
1000001.XSHE2007-082007-07-310.0488550.0024040.0464512007-075.616330e+1024.7515290.3140160.64230.0939920.000040NaN2007-06NaN0.042521NaN0.042521NaN
2000001.XSHE2007-092007-08-310.0521050.0026210.0494842007-085.890714e+1024.7992280.0464510.77220.0970850.000020NaN2007-07NaN0.033926NaN0.033926NaN
3000001.XSHE2007-102007-09-280.2018510.0030950.1987562007-096.197651e+1024.8500210.0494840.75960.0922760.000025NaN2007-08NaN0.023872NaN0.023872NaN
4000001.XSHE2007-112007-10-31-0.2491160.003780-0.2528962007-107.448652e+1025.0338840.1987560.79880.0834110.000030NaN2007-09NaN0.035921NaN0.035921NaN
...............................................................
504875900957.XSHG2021-122021-11-300.0358310.0020260.0338052021-111.120560e+0818.534509-0.042588NaNNaN0.0700560.0628842021-100.2167300.0096390.0070460.0096390.007046
504876900957.XSHG2022-012021-12-31-0.0220130.002014-0.0240272021-121.161040e+0818.5699970.033805NaNNaN0.0780370.0596722021-110.2110450.0109610.0086920.0109610.008692
504877900957.XSHG2022-022022-01-28-0.0112540.001921-0.0131752022-011.135280e+0818.547560-0.024027NaNNaN0.0445150.0585022021-12-0.0591720.0105590.0084090.0105590.008409
504878900957.XSHG2022-032022-02-28-0.0341460.001919-0.0360662022-021.122400e+0818.536150-0.013175NaNNaN0.0572180.0602082022-01-0.1571820.0065170.0041950.0065170.004195
504879900957.XSHGNaT2022-03-14NaNNaNNaN2022-031.083760e+0818.501117-0.036066NaNNaNNaN0.0624422022-02-0.117647NaNNaNNaNNaN
\n", "

504880 rows × 20 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate ret rf exret \\\n", "0 000001.XSHE 2007-07 2007-06-29 0.316497 0.002481 0.314016 \n", "1 000001.XSHE 2007-08 2007-07-31 0.048855 0.002404 0.046451 \n", "2 000001.XSHE 2007-09 2007-08-31 0.052105 0.002621 0.049484 \n", "3 000001.XSHE 2007-10 2007-09-28 0.201851 0.003095 0.198756 \n", "4 000001.XSHE 2007-11 2007-10-31 -0.249116 0.003780 -0.252896 \n", "... ... ... ... ... ... ... \n", "504875 900957.XSHG 2021-12 2021-11-30 0.035831 0.002026 0.033805 \n", "504876 900957.XSHG 2022-01 2021-12-31 -0.022013 0.002014 -0.024027 \n", "504877 900957.XSHG 2022-02 2022-01-28 -0.011254 0.001921 -0.013175 \n", "504878 900957.XSHG 2022-03 2022-02-28 -0.034146 0.001919 -0.036066 \n", "504879 900957.XSHG NaT 2022-03-14 NaN NaN NaN \n", "\n", " ym mktcap size rev beta bm \\\n", "0 2007-06 4.266117e+10 24.476555 NaN 0.4614 0.123739 \n", "1 2007-07 5.616330e+10 24.751529 0.314016 0.6423 0.093992 \n", "2 2007-08 5.890714e+10 24.799228 0.046451 0.7722 0.097085 \n", "3 2007-09 6.197651e+10 24.850021 0.049484 0.7596 0.092276 \n", "4 2007-10 7.448652e+10 25.033884 0.198756 0.7988 0.083411 \n", "... ... ... ... ... ... ... \n", "504875 2021-11 1.120560e+08 18.534509 -0.042588 NaN NaN \n", "504876 2021-12 1.161040e+08 18.569997 0.033805 NaN NaN \n", "504877 2022-01 1.135280e+08 18.547560 -0.024027 NaN NaN \n", "504878 2022-02 1.122400e+08 18.536150 -0.013175 NaN NaN \n", "504879 2022-03 1.083760e+08 18.501117 -0.036066 NaN NaN \n", "\n", " illiq illiq_12m mom_date mom vol ivol vol_clip \\\n", "0 NaN NaN NaT NaN NaN NaN NaN \n", "1 0.000040 NaN 2007-06 NaN 0.042521 NaN 0.042521 \n", "2 0.000020 NaN 2007-07 NaN 0.033926 NaN 0.033926 \n", "3 0.000025 NaN 2007-08 NaN 0.023872 NaN 0.023872 \n", "4 0.000030 NaN 2007-09 NaN 0.035921 NaN 0.035921 \n", "... ... ... ... ... ... ... ... 
\n", "504875 0.070056 0.062884 2021-10 0.216730 0.009639 0.007046 0.009639 \n", "504876 0.078037 0.059672 2021-11 0.211045 0.010961 0.008692 0.010961 \n", "504877 0.044515 0.058502 2021-12 -0.059172 0.010559 0.008409 0.010559 \n", "504878 0.057218 0.060208 2022-01 -0.157182 0.006517 0.004195 0.006517 \n", "504879 NaN 0.062442 2022-02 -0.117647 NaN NaN NaN \n", "\n", " ivol_clip \n", "0 NaN \n", "1 NaN \n", "2 NaN \n", "3 NaN \n", "4 NaN \n", "... ... \n", "504875 0.007046 \n", "504876 0.008692 \n", "504877 0.008409 \n", "504878 0.004195 \n", "504879 NaN \n", "\n", "[504880 rows x 20 columns]" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "markdown", "id": "about-wesley", "metadata": {}, "source": [ "## NA 值处理" ] }, { "cell_type": "code", "execution_count": 40, "id": "47a01179-9a08-475c-875c-30c150c8ac49", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "secID 0\n", "ret_date 4853\n", "tradeDate 0\n", "ret 31888\n", "rf 4853\n", "exret 31888\n", "ym 0\n", "mktcap 23011\n", "size 23011\n", "rev 30586\n", "beta 40704\n", "bm 21512\n", "illiq 41579\n", "illiq_12m 91808\n", "mom_date 3547\n", "mom 49225\n", "vol 30782\n", "ivol 54678\n", "vol_clip 30782\n", "ivol_clip 54678\n" ] } ], "source": [ "# Print the number of missing values per column (all columns counted in one pass).\n", "for col, n_missing in df.isna().sum().items():\n", "    print(col, n_missing)" ] }, { "cell_type": "markdown", "id": "equipped-meaning", "metadata": {}, "source": [ "ret_date 为 NA 的删除,已到最新数据处" ] }, { "cell_type": "code", "execution_count": 41, "id": "periodic-london", "metadata": {}, "outputs": [], "source": [ "# Keep only the rows that have a ret_date.\n", "df = df.dropna(subset=['ret_date']).copy()" ] }, { "cell_type": "code", "execution_count": 42, "id": "cubic-console", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateretrfexretymmktcapsizerevbetabmilliqilliq_12mmom_datemomvolivolvol_clipivol_clip
0000001.XSHE2007-072007-06-290.3164970.0024810.3140162007-064.266117e+1024.476555NaN0.46140.123739NaNNaNNaTNaNNaNNaNNaNNaN
1000001.XSHE2007-082007-07-310.0488550.0024040.0464512007-075.616330e+1024.7515290.3140160.64230.0939920.000040NaN2007-06NaN0.042521NaN0.042521NaN
2000001.XSHE2007-092007-08-310.0521050.0026210.0494842007-085.890714e+1024.7992280.0464510.77220.0970850.000020NaN2007-07NaN0.033926NaN0.033926NaN
3000001.XSHE2007-102007-09-280.2018510.0030950.1987562007-096.197651e+1024.8500210.0494840.75960.0922760.000025NaN2007-08NaN0.023872NaN0.023872NaN
4000001.XSHE2007-112007-10-31-0.2491160.003780-0.2528962007-107.448652e+1025.0338840.1987560.79880.0834110.000030NaN2007-09NaN0.035921NaN0.035921NaN
...............................................................
504874900957.XSHG2021-112021-10-29-0.0406250.001963-0.0425882021-101.168400e+0818.576316-0.042478NaNNaN0.0584570.0676462021-090.2851640.0116630.0077000.0116630.007700
504875900957.XSHG2021-122021-11-300.0358310.0020260.0338052021-111.120560e+0818.534509-0.042588NaNNaN0.0700560.0628842021-100.2167300.0096390.0070460.0096390.007046
504876900957.XSHG2022-012021-12-31-0.0220130.002014-0.0240272021-121.161040e+0818.5699970.033805NaNNaN0.0780370.0596722021-110.2110450.0109610.0086920.0109610.008692
504877900957.XSHG2022-022022-01-28-0.0112540.001921-0.0131752022-011.135280e+0818.547560-0.024027NaNNaN0.0445150.0585022021-12-0.0591720.0105590.0084090.0105590.008409
504878900957.XSHG2022-032022-02-28-0.0341460.001919-0.0360662022-021.122400e+0818.536150-0.013175NaNNaN0.0572180.0602082022-01-0.1571820.0065170.0041950.0065170.004195
\n", "

500027 rows × 20 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate ret rf exret \\\n", "0 000001.XSHE 2007-07 2007-06-29 0.316497 0.002481 0.314016 \n", "1 000001.XSHE 2007-08 2007-07-31 0.048855 0.002404 0.046451 \n", "2 000001.XSHE 2007-09 2007-08-31 0.052105 0.002621 0.049484 \n", "3 000001.XSHE 2007-10 2007-09-28 0.201851 0.003095 0.198756 \n", "4 000001.XSHE 2007-11 2007-10-31 -0.249116 0.003780 -0.252896 \n", "... ... ... ... ... ... ... \n", "504874 900957.XSHG 2021-11 2021-10-29 -0.040625 0.001963 -0.042588 \n", "504875 900957.XSHG 2021-12 2021-11-30 0.035831 0.002026 0.033805 \n", "504876 900957.XSHG 2022-01 2021-12-31 -0.022013 0.002014 -0.024027 \n", "504877 900957.XSHG 2022-02 2022-01-28 -0.011254 0.001921 -0.013175 \n", "504878 900957.XSHG 2022-03 2022-02-28 -0.034146 0.001919 -0.036066 \n", "\n", " ym mktcap size rev beta bm \\\n", "0 2007-06 4.266117e+10 24.476555 NaN 0.4614 0.123739 \n", "1 2007-07 5.616330e+10 24.751529 0.314016 0.6423 0.093992 \n", "2 2007-08 5.890714e+10 24.799228 0.046451 0.7722 0.097085 \n", "3 2007-09 6.197651e+10 24.850021 0.049484 0.7596 0.092276 \n", "4 2007-10 7.448652e+10 25.033884 0.198756 0.7988 0.083411 \n", "... ... ... ... ... ... ... \n", "504874 2021-10 1.168400e+08 18.576316 -0.042478 NaN NaN \n", "504875 2021-11 1.120560e+08 18.534509 -0.042588 NaN NaN \n", "504876 2021-12 1.161040e+08 18.569997 0.033805 NaN NaN \n", "504877 2022-01 1.135280e+08 18.547560 -0.024027 NaN NaN \n", "504878 2022-02 1.122400e+08 18.536150 -0.013175 NaN NaN \n", "\n", " illiq illiq_12m mom_date mom vol ivol vol_clip \\\n", "0 NaN NaN NaT NaN NaN NaN NaN \n", "1 0.000040 NaN 2007-06 NaN 0.042521 NaN 0.042521 \n", "2 0.000020 NaN 2007-07 NaN 0.033926 NaN 0.033926 \n", "3 0.000025 NaN 2007-08 NaN 0.023872 NaN 0.023872 \n", "4 0.000030 NaN 2007-09 NaN 0.035921 NaN 0.035921 \n", "... ... ... ... ... ... ... ... 
\n", "504874 0.058457 0.067646 2021-09 0.285164 0.011663 0.007700 0.011663 \n", "504875 0.070056 0.062884 2021-10 0.216730 0.009639 0.007046 0.009639 \n", "504876 0.078037 0.059672 2021-11 0.211045 0.010961 0.008692 0.010961 \n", "504877 0.044515 0.058502 2021-12 -0.059172 0.010559 0.008409 0.010559 \n", "504878 0.057218 0.060208 2022-01 -0.157182 0.006517 0.004195 0.006517 \n", "\n", " ivol_clip \n", "0 NaN \n", "1 NaN \n", "2 NaN \n", "3 NaN \n", "4 NaN \n", "... ... \n", "504874 0.007700 \n", "504875 0.007046 \n", "504876 0.008692 \n", "504877 0.008409 \n", "504878 0.004195 \n", "\n", "[500027 rows x 20 columns]" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "markdown", "id": "biological-wayne", "metadata": {}, "source": [ "momentum 从 2008-01 开始。简单起见,把所有数据调整为从2008-01开始。" ] }, { "cell_type": "code", "execution_count": 43, "id": "crazy-flash", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Period('2008-01', 'M')" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Earliest ret_date among rows whose momentum value is present.\n", "df.loc[df['mom'].notna(), 'ret_date'].min()" ] }, { "cell_type": "code", "execution_count": 44, "id": "little-evaluation", "metadata": {}, "outputs": [], "source": [ "# Restrict the sample to ret_date >= 2008-01, the month reported by the previous cell.\n", "df = df.loc[df['ret_date'] >= '2008-01'].copy()" ] }, { "cell_type": "code", "execution_count": 45, "id": "likely-estimate", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "secID 0\n", "ret_date 0\n", "tradeDate 0\n", "ret 26378\n", "rf 0\n", "exret 26378\n", "ym 0\n", "mktcap 22483\n", "size 22483\n", "rev 29845\n", "beta 39159\n", "bm 20363\n", "illiq 36007\n", "illiq_12m 79832\n", "mom_date 3381\n", "mom 36211\n", "vol 25455\n", "ivol 37368\n", "vol_clip 25455\n", "ivol_clip 37368\n" ] } ], "source": [ "# Print the number of missing values per column (all columns counted in one pass).\n", "for col, n_missing in df.isna().sum().items():\n", "    print(col, n_missing)" ] }, { "cell_type": "markdown", "id": "liable-agreement", "metadata": {}, "source": [ "剩余的NA值有至少三个来源:\n", "- 由于停牌日期填充造成,\n", "- 由于计算时要求最低样本数造成,\n", "- 由优矿直接给出了NA值" ] }, { "cell_type": "markdown", "id": "minor-pressing", "metadata": {}, "source": [ "return 的 NA 值直接删除" ] }, { "cell_type": "code", "execution_count": 46, "id": "elementary-sixth", "metadata": {}, "outputs": [], "source": [ "# Keep only the rows that have a return value.\n", "df = df.dropna(subset=['ret']).copy()" ] }, { "cell_type": "code", "execution_count": 47, "id": "agreed-poison", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateretrfexretymmktcapsizerevbetabmilliqilliq_12mmom_datemomvolivolvol_clipivol_clip
6000001.XSHE2008-012007-12-28-0.1373060.002949-0.1402552007-126.574629e+1024.9090690.0668340.94680.0944760.000025NaN2007-11NaN0.027254NaN0.027254NaN
7000001.XSHE2008-022008-01-31-0.0045040.002946-0.0074502008-015.850212e+1024.792329-0.1402550.96540.1095130.000039NaN2007-12NaN0.0377220.0132660.0377220.013266
8000001.XSHE2008-032008-02-29-0.1493210.002746-0.1520682008-025.823860e+1024.787814-0.0074501.02920.1100090.000064NaN2008-01NaN0.0414480.0094740.0414480.009474
9000001.XSHE2008-042008-03-310.0503550.0028620.0474932008-034.954234e+1024.626093-0.1520681.02380.2011020.000043NaN2008-02NaN0.0451090.0217460.0451090.021746
10000001.XSHE2008-052008-04-30-0.1482110.002953-0.1511642008-045.203702e+1024.6752210.0474931.02120.2067010.0000510.0000382008-03NaN0.0463230.0144740.0463230.014474
...............................................................
504874900957.XSHG2021-112021-10-29-0.0406250.001963-0.0425882021-101.168400e+0818.576316-0.042478NaNNaN0.0584570.0676462021-090.2851640.0116630.0077000.0116630.007700
504875900957.XSHG2021-122021-11-300.0358310.0020260.0338052021-111.120560e+0818.534509-0.042588NaNNaN0.0700560.0628842021-100.2167300.0096390.0070460.0096390.007046
504876900957.XSHG2022-012021-12-31-0.0220130.002014-0.0240272021-121.161040e+0818.5699970.033805NaNNaN0.0780370.0596722021-110.2110450.0109610.0086920.0109610.008692
504877900957.XSHG2022-022022-01-28-0.0112540.001921-0.0131752022-011.135280e+0818.547560-0.024027NaNNaN0.0445150.0585022021-12-0.0591720.0105590.0084090.0105590.008409
504878900957.XSHG2022-032022-02-28-0.0341460.001919-0.0360662022-021.122400e+0818.536150-0.013175NaNNaN0.0572180.0602082022-01-0.1571820.0065170.0041950.0065170.004195
\n", "

461021 rows × 20 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate ret rf exret \\\n", "6 000001.XSHE 2008-01 2007-12-28 -0.137306 0.002949 -0.140255 \n", "7 000001.XSHE 2008-02 2008-01-31 -0.004504 0.002946 -0.007450 \n", "8 000001.XSHE 2008-03 2008-02-29 -0.149321 0.002746 -0.152068 \n", "9 000001.XSHE 2008-04 2008-03-31 0.050355 0.002862 0.047493 \n", "10 000001.XSHE 2008-05 2008-04-30 -0.148211 0.002953 -0.151164 \n", "... ... ... ... ... ... ... \n", "504874 900957.XSHG 2021-11 2021-10-29 -0.040625 0.001963 -0.042588 \n", "504875 900957.XSHG 2021-12 2021-11-30 0.035831 0.002026 0.033805 \n", "504876 900957.XSHG 2022-01 2021-12-31 -0.022013 0.002014 -0.024027 \n", "504877 900957.XSHG 2022-02 2022-01-28 -0.011254 0.001921 -0.013175 \n", "504878 900957.XSHG 2022-03 2022-02-28 -0.034146 0.001919 -0.036066 \n", "\n", " ym mktcap size rev beta bm \\\n", "6 2007-12 6.574629e+10 24.909069 0.066834 0.9468 0.094476 \n", "7 2008-01 5.850212e+10 24.792329 -0.140255 0.9654 0.109513 \n", "8 2008-02 5.823860e+10 24.787814 -0.007450 1.0292 0.110009 \n", "9 2008-03 4.954234e+10 24.626093 -0.152068 1.0238 0.201102 \n", "10 2008-04 5.203702e+10 24.675221 0.047493 1.0212 0.206701 \n", "... ... ... ... ... ... ... \n", "504874 2021-10 1.168400e+08 18.576316 -0.042478 NaN NaN \n", "504875 2021-11 1.120560e+08 18.534509 -0.042588 NaN NaN \n", "504876 2021-12 1.161040e+08 18.569997 0.033805 NaN NaN \n", "504877 2022-01 1.135280e+08 18.547560 -0.024027 NaN NaN \n", "504878 2022-02 1.122400e+08 18.536150 -0.013175 NaN NaN \n", "\n", " illiq illiq_12m mom_date mom vol ivol vol_clip \\\n", "6 0.000025 NaN 2007-11 NaN 0.027254 NaN 0.027254 \n", "7 0.000039 NaN 2007-12 NaN 0.037722 0.013266 0.037722 \n", "8 0.000064 NaN 2008-01 NaN 0.041448 0.009474 0.041448 \n", "9 0.000043 NaN 2008-02 NaN 0.045109 0.021746 0.045109 \n", "10 0.000051 0.000038 2008-03 NaN 0.046323 0.014474 0.046323 \n", "... ... ... ... ... ... ... ... 
\n", "504874 0.058457 0.067646 2021-09 0.285164 0.011663 0.007700 0.011663 \n", "504875 0.070056 0.062884 2021-10 0.216730 0.009639 0.007046 0.009639 \n", "504876 0.078037 0.059672 2021-11 0.211045 0.010961 0.008692 0.010961 \n", "504877 0.044515 0.058502 2021-12 -0.059172 0.010559 0.008409 0.010559 \n", "504878 0.057218 0.060208 2022-01 -0.157182 0.006517 0.004195 0.006517 \n", "\n", " ivol_clip \n", "6 NaN \n", "7 0.013266 \n", "8 0.009474 \n", "9 0.021746 \n", "10 0.014474 \n", "... ... \n", "504874 0.007700 \n", "504875 0.007046 \n", "504876 0.008692 \n", "504877 0.008409 \n", "504878 0.004195 \n", "\n", "[461021 rows x 20 columns]" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 48, "id": "enhanced-garden", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "secID 0\n", "ret_date 0\n", "tradeDate 0\n", "ret 0\n", "rf 0\n", "exret 0\n", "ym 0\n", "mktcap 0\n", "size 0\n", "rev 7328\n", "beta 25845\n", "bm 16422\n", "illiq 11127\n", "illiq_12m 62624\n", "mom_date 3381\n", "mom 35755\n", "vol 2799\n", "ivol 12482\n", "vol_clip 2799\n", "ivol_clip 12482\n" ] } ], "source": [ "# Print the number of missing values per column (all columns counted in one pass).\n", "for col, n_missing in df.isna().sum().items():\n", "    print(col, n_missing)" ] }, { "cell_type": "code", "execution_count": 49, "id": "based-advertiser", "metadata": {}, "outputs": [], "source": [ "# Remove the momentum date, market cap and clipped volatility columns.\n", "df = df.drop(columns=['mom_date', 'mktcap', 'vol_clip', 'ivol_clip'])" ] }, { "cell_type": "code", "execution_count": 50, "id": "backed-sharing", "metadata": {}, "outputs": [], "source": [ "# Remove the raw return and risk-free rate columns; exret stays in the frame.\n", "df = df.drop(columns=['ret', 'rf'])" ] }, { "cell_type": "code", "execution_count": 51, "id": "appropriate-minority", "metadata": {}, "outputs": [], "source": [ "# Renumber the rows from 0 and discard the old index left over from the row filters.\n", "df = df.reset_index(drop=True)" ] }, { "cell_type": "code", "execution_count": 52, "id": "drawn-prompt", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateexretymsizerevbetabmilliqilliq_12mmomvolivol
0000001.XSHE2008-012007-12-28-0.1402552007-1224.9090690.0668340.94680.0944760.000025NaNNaN0.027254NaN
1000001.XSHE2008-022008-01-31-0.0074502008-0124.792329-0.1402550.96540.1095130.000039NaNNaN0.0377220.013266
2000001.XSHE2008-032008-02-29-0.1520682008-0224.787814-0.0074501.02920.1100090.000064NaNNaN0.0414480.009474
3000001.XSHE2008-042008-03-310.0474932008-0324.626093-0.1520681.02380.2011020.000043NaNNaN0.0451090.021746
4000001.XSHE2008-052008-04-30-0.1511642008-0424.6752210.0474931.02120.2067010.0000510.000038NaN0.0463230.014474
.............................................
461016900957.XSHG2021-112021-10-29-0.0425882021-1018.576316-0.042478NaNNaN0.0584570.0676460.2851640.0116630.007700
461017900957.XSHG2021-122021-11-300.0338052021-1118.534509-0.042588NaNNaN0.0700560.0628840.2167300.0096390.007046
461018900957.XSHG2022-012021-12-31-0.0240272021-1218.5699970.033805NaNNaN0.0780370.0596720.2110450.0109610.008692
461019900957.XSHG2022-022022-01-28-0.0131752022-0118.547560-0.024027NaNNaN0.0445150.058502-0.0591720.0105590.008409
461020900957.XSHG2022-032022-02-28-0.0360662022-0218.536150-0.013175NaNNaN0.0572180.060208-0.1571820.0065170.004195
\n", "

461021 rows × 14 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate exret ym size \\\n", "0 000001.XSHE 2008-01 2007-12-28 -0.140255 2007-12 24.909069 \n", "1 000001.XSHE 2008-02 2008-01-31 -0.007450 2008-01 24.792329 \n", "2 000001.XSHE 2008-03 2008-02-29 -0.152068 2008-02 24.787814 \n", "3 000001.XSHE 2008-04 2008-03-31 0.047493 2008-03 24.626093 \n", "4 000001.XSHE 2008-05 2008-04-30 -0.151164 2008-04 24.675221 \n", "... ... ... ... ... ... ... \n", "461016 900957.XSHG 2021-11 2021-10-29 -0.042588 2021-10 18.576316 \n", "461017 900957.XSHG 2021-12 2021-11-30 0.033805 2021-11 18.534509 \n", "461018 900957.XSHG 2022-01 2021-12-31 -0.024027 2021-12 18.569997 \n", "461019 900957.XSHG 2022-02 2022-01-28 -0.013175 2022-01 18.547560 \n", "461020 900957.XSHG 2022-03 2022-02-28 -0.036066 2022-02 18.536150 \n", "\n", " rev beta bm illiq illiq_12m mom vol \\\n", "0 0.066834 0.9468 0.094476 0.000025 NaN NaN 0.027254 \n", "1 -0.140255 0.9654 0.109513 0.000039 NaN NaN 0.037722 \n", "2 -0.007450 1.0292 0.110009 0.000064 NaN NaN 0.041448 \n", "3 -0.152068 1.0238 0.201102 0.000043 NaN NaN 0.045109 \n", "4 0.047493 1.0212 0.206701 0.000051 0.000038 NaN 0.046323 \n", "... ... ... ... ... ... ... ... \n", "461016 -0.042478 NaN NaN 0.058457 0.067646 0.285164 0.011663 \n", "461017 -0.042588 NaN NaN 0.070056 0.062884 0.216730 0.009639 \n", "461018 0.033805 NaN NaN 0.078037 0.059672 0.211045 0.010961 \n", "461019 -0.024027 NaN NaN 0.044515 0.058502 -0.059172 0.010559 \n", "461020 -0.013175 NaN NaN 0.057218 0.060208 -0.157182 0.006517 \n", "\n", " ivol \n", "0 NaN \n", "1 0.013266 \n", "2 0.009474 \n", "3 0.021746 \n", "4 0.014474 \n", "... ... 
\n", "461016 0.007700 \n", "461017 0.007046 \n", "461018 0.008692 \n", "461019 0.008409 \n", "461020 0.004195 \n", "\n", "[461021 rows x 14 columns]" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "markdown", "id": "actual-standard", "metadata": {}, "source": [ "- reversal 的 NA 是由于在对应的return date,上个月停牌所以没有上个月的return。\n", "- beta, bm 是优矿的NA。可以用当月的横截面上的中值填充\n", "- illiq, ivol, vol 也可用当月的横截面上的中值填充." ] }, { "cell_type": "code", "execution_count": 53, "id": "nuclear-chassis", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "secID 0\n", "ret_date 0\n", "tradeDate 0\n", "exret 0\n", "ym 0\n", "size 0\n", "rev 7328\n", "beta 25845\n", "bm 16422\n", "illiq 11127\n", "illiq_12m 62624\n", "mom 35755\n", "vol 2799\n", "ivol 12482\n" ] } ], "source": [ "# Print the number of missing values per column (all columns counted in one pass).\n", "for col, n_missing in df.isna().sum().items():\n", "    print(col, n_missing)" ] }, { "cell_type": "code", "execution_count": 54, "id": "declared-blake", "metadata": {}, "outputs": [], "source": [ "# Drop the rows with a missing reversal value; the other factors are to be filled with the median.\n", "df = df.dropna(subset=['rev']).copy()" ] }, { "cell_type": "code", "execution_count": 55, "id": "functional-finland", "metadata": {}, "outputs": [], "source": [ "# Factor columns that showed missing values in the count above.\n", "# NOTE(review): presumably the columns to be median-filled -- confirm against the cells that use `cols`.\n", "cols = ['mom', 'beta', 'bm', 'illiq', 'illiq_12m', 'vol', 'ivol']" ] }, { "cell_type": "code", "execution_count": 56, "id": "personal-stylus", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateexretymsizerevbetabmilliqilliq_12mmomvolivol
0000001.XSHE2008-012007-12-28-0.1402552007-1224.9090690.0668340.94680.0944760.000025NaNNaN0.027254NaN
1000001.XSHE2008-022008-01-31-0.0074502008-0124.792329-0.1402550.96540.1095130.000039NaNNaN0.0377220.013266
2000001.XSHE2008-032008-02-29-0.1520682008-0224.787814-0.0074501.02920.1100090.000064NaNNaN0.0414480.009474
3000001.XSHE2008-042008-03-310.0474932008-0324.626093-0.1520681.02380.2011020.000043NaNNaN0.0451090.021746
4000001.XSHE2008-052008-04-30-0.1511642008-0424.6752210.0474931.02120.2067010.0000510.000038NaN0.0463230.014474
.............................................
461016900957.XSHG2021-112021-10-29-0.0425882021-1018.576316-0.042478NaNNaN0.0584570.0676460.2851640.0116630.007700
461017900957.XSHG2021-122021-11-300.0338052021-1118.534509-0.042588NaNNaN0.0700560.0628840.2167300.0096390.007046
461018900957.XSHG2022-012021-12-31-0.0240272021-1218.5699970.033805NaNNaN0.0780370.0596720.2110450.0109610.008692
461019900957.XSHG2022-022022-01-28-0.0131752022-0118.547560-0.024027NaNNaN0.0445150.058502-0.0591720.0105590.008409
461020900957.XSHG2022-032022-02-28-0.0360662022-0218.536150-0.013175NaNNaN0.0572180.060208-0.1571820.0065170.004195
\n", "

453693 rows × 14 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate exret ym size \\\n", "0 000001.XSHE 2008-01 2007-12-28 -0.140255 2007-12 24.909069 \n", "1 000001.XSHE 2008-02 2008-01-31 -0.007450 2008-01 24.792329 \n", "2 000001.XSHE 2008-03 2008-02-29 -0.152068 2008-02 24.787814 \n", "3 000001.XSHE 2008-04 2008-03-31 0.047493 2008-03 24.626093 \n", "4 000001.XSHE 2008-05 2008-04-30 -0.151164 2008-04 24.675221 \n", "... ... ... ... ... ... ... \n", "461016 900957.XSHG 2021-11 2021-10-29 -0.042588 2021-10 18.576316 \n", "461017 900957.XSHG 2021-12 2021-11-30 0.033805 2021-11 18.534509 \n", "461018 900957.XSHG 2022-01 2021-12-31 -0.024027 2021-12 18.569997 \n", "461019 900957.XSHG 2022-02 2022-01-28 -0.013175 2022-01 18.547560 \n", "461020 900957.XSHG 2022-03 2022-02-28 -0.036066 2022-02 18.536150 \n", "\n", " rev beta bm illiq illiq_12m mom vol \\\n", "0 0.066834 0.9468 0.094476 0.000025 NaN NaN 0.027254 \n", "1 -0.140255 0.9654 0.109513 0.000039 NaN NaN 0.037722 \n", "2 -0.007450 1.0292 0.110009 0.000064 NaN NaN 0.041448 \n", "3 -0.152068 1.0238 0.201102 0.000043 NaN NaN 0.045109 \n", "4 0.047493 1.0212 0.206701 0.000051 0.000038 NaN 0.046323 \n", "... ... ... ... ... ... ... ... \n", "461016 -0.042478 NaN NaN 0.058457 0.067646 0.285164 0.011663 \n", "461017 -0.042588 NaN NaN 0.070056 0.062884 0.216730 0.009639 \n", "461018 0.033805 NaN NaN 0.078037 0.059672 0.211045 0.010961 \n", "461019 -0.024027 NaN NaN 0.044515 0.058502 -0.059172 0.010559 \n", "461020 -0.013175 NaN NaN 0.057218 0.060208 -0.157182 0.006517 \n", "\n", " ivol \n", "0 NaN \n", "1 0.013266 \n", "2 0.009474 \n", "3 0.021746 \n", "4 0.014474 \n", "... ... 
# Cross-sectional median imputation: within each `ret_date`, a missing value
# in one of `cols` is replaced by that date's median of the same column.
# `transform('median')` computes all group medians in one vectorized pass
# (instead of calling a Python lambda per group and column) and returns a
# frame aligned row-for-row with `df`, so it can be passed straight to fillna.
temp = df[cols].fillna(df.groupby('ret_date')[cols].transform('median'))

# A date on which a column has no observed value at all has a NaN median;
# those remaining gaps fall back to 0.
temp = temp.fillna(0)

df[cols] = temp

# Re-check the per-column missing counts after imputation.
for col in df.columns:
    print(col, df[col].isna().sum())
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateexretymsizerevbetabmilliqilliq_12mmomvolivol
0000001.XSHE2008-012007-12-28-0.1402552007-1224.9090690.0668340.946800.0944760.0000250.0005360.7778140.0272540.000000
1000001.XSHE2008-022008-01-31-0.0074502008-0124.792329-0.1402550.965400.1095130.0000390.0005241.1191020.0377220.013266
2000001.XSHE2008-032008-02-29-0.1520682008-0224.787814-0.0074501.029200.1100090.0000640.0005270.6561200.0414480.009474
3000001.XSHE2008-042008-03-310.0474932008-0324.626093-0.1520681.023800.2011020.0000430.0005650.5452600.0451090.021746
4000001.XSHE2008-052008-04-30-0.1511642008-0424.6752210.0474931.021200.2067010.0000510.000038-0.0558890.0463230.014474
.............................................
461016900957.XSHG2021-112021-10-29-0.0425882021-1018.576316-0.0424780.470100.3754320.0584570.0676460.2851640.0116630.007700
461017900957.XSHG2021-122021-11-300.0338052021-1118.534509-0.0425880.469800.3324030.0700560.0628840.2167300.0096390.007046
461018900957.XSHG2022-012021-12-31-0.0240272021-1218.5699970.0338050.469100.3243540.0780370.0596720.2110450.0109610.008692
461019900957.XSHG2022-022022-01-28-0.0131752022-0118.547560-0.0240270.558300.3567160.0445150.058502-0.0591720.0105590.008409
461020900957.XSHG2022-032022-02-28-0.0360662022-0218.536150-0.0131750.635150.3426070.0572180.060208-0.1571820.0065170.004195
\n", "

453693 rows × 14 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate exret ym size \\\n", "0 000001.XSHE 2008-01 2007-12-28 -0.140255 2007-12 24.909069 \n", "1 000001.XSHE 2008-02 2008-01-31 -0.007450 2008-01 24.792329 \n", "2 000001.XSHE 2008-03 2008-02-29 -0.152068 2008-02 24.787814 \n", "3 000001.XSHE 2008-04 2008-03-31 0.047493 2008-03 24.626093 \n", "4 000001.XSHE 2008-05 2008-04-30 -0.151164 2008-04 24.675221 \n", "... ... ... ... ... ... ... \n", "461016 900957.XSHG 2021-11 2021-10-29 -0.042588 2021-10 18.576316 \n", "461017 900957.XSHG 2021-12 2021-11-30 0.033805 2021-11 18.534509 \n", "461018 900957.XSHG 2022-01 2021-12-31 -0.024027 2021-12 18.569997 \n", "461019 900957.XSHG 2022-02 2022-01-28 -0.013175 2022-01 18.547560 \n", "461020 900957.XSHG 2022-03 2022-02-28 -0.036066 2022-02 18.536150 \n", "\n", " rev beta bm illiq illiq_12m mom vol \\\n", "0 0.066834 0.94680 0.094476 0.000025 0.000536 0.777814 0.027254 \n", "1 -0.140255 0.96540 0.109513 0.000039 0.000524 1.119102 0.037722 \n", "2 -0.007450 1.02920 0.110009 0.000064 0.000527 0.656120 0.041448 \n", "3 -0.152068 1.02380 0.201102 0.000043 0.000565 0.545260 0.045109 \n", "4 0.047493 1.02120 0.206701 0.000051 0.000038 -0.055889 0.046323 \n", "... ... ... ... ... ... ... ... \n", "461016 -0.042478 0.47010 0.375432 0.058457 0.067646 0.285164 0.011663 \n", "461017 -0.042588 0.46980 0.332403 0.070056 0.062884 0.216730 0.009639 \n", "461018 0.033805 0.46910 0.324354 0.078037 0.059672 0.211045 0.010961 \n", "461019 -0.024027 0.55830 0.356716 0.044515 0.058502 -0.059172 0.010559 \n", "461020 -0.013175 0.63515 0.342607 0.057218 0.060208 -0.157182 0.006517 \n", "\n", " ivol \n", "0 0.000000 \n", "1 0.013266 \n", "2 0.009474 \n", "3 0.021746 \n", "4 0.014474 \n", "... ... 
def csrank(df):
    """Rescale values to ranks in the open interval (-1, 1).

    Implements ``c = 2 / (N + 1) * rank(c_raw) - 1`` where ``rank`` is the
    pandas average rank (ties share the mean rank) and ``N`` is the number of
    non-missing values being ranked.

    Parameters
    ----------
    df : pandas.Series or pandas.DataFrame
        Values to rank. A DataFrame is ranked column by column.

    Returns
    -------
    Same type as ``df``
        Scaled ranks. Missing inputs stay missing in the output.
    """
    ranks = df.rank()
    # rank() leaves NaN inputs as NaN, so N must count only the values that
    # actually received a rank; using len(df) would shrink the output range
    # whenever some values are missing. With no NaNs this equals len(df).
    n_valid = ranks.count()
    return ranks * 2 / (n_valid + 1) - 1
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ret_datesizerevbetabmilliqilliq_12mmomvolivol
00.00.970696-0.8534800.123810-0.696703-0.9794870.0000000.0000000.3216120.000000
10.00.972444-0.412618-0.219724-0.641769-0.9608410.0000000.000000-0.601160-0.718637
20.00.968481-0.6647560.375358-0.588825-0.9584530.0000000.0000000.415473-0.616046
30.00.9699360.483178-0.079456-0.176807-0.9756620.0000000.0000000.2183250.400143
40.00.9655670.522238-0.301291-0.318508-0.971306-0.9770440.000000-0.707317-0.519369
.................................
4610160.0-0.990901-0.011374-0.0002270.0000000.9895360.9899910.556415-0.897179-0.702457
4610170.0-0.990967-0.7831980.0002260.0000000.9887080.9896120.585818-0.900181-0.775068
4610180.0-0.9910330.1217220.0000000.0000000.9932750.9892400.392961-0.862811-0.724277
4610190.0-0.9911250.6560910.0000000.0000000.9884620.989350-0.645440-0.958731-0.672953
4610200.0-0.991184-0.5463960.0000000.0000000.9889800.990302-0.648667-0.990743-0.902138
\n", "

453693 rows × 10 columns

\n", "
# Replace the raw factor columns with their rank-transformed counterparts.
# Rows are matched on the index labels of both frames.
df_rank = (
    df.drop(columns=num_X_cols)
      .merge(temp.drop(columns='ret_date'), left_index=True, right_index=True)
)

# The intermediate rank frame is no longer needed.
del temp
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateexretymsizerevbetabmilliqilliq_12mmomvolivol
0000001.XSHE2008-012007-12-28-0.1402552007-120.970696-0.8534800.123810-0.696703-0.9794870.0000000.0000000.3216120.000000
1000001.XSHE2008-022008-01-31-0.0074502008-010.972444-0.412618-0.219724-0.641769-0.9608410.0000000.000000-0.601160-0.718637
2000001.XSHE2008-032008-02-29-0.1520682008-020.968481-0.6647560.375358-0.588825-0.9584530.0000000.0000000.415473-0.616046
3000001.XSHE2008-042008-03-310.0474932008-030.9699360.483178-0.079456-0.176807-0.9756620.0000000.0000000.2183250.400143
4000001.XSHE2008-052008-04-30-0.1511642008-040.9655670.522238-0.301291-0.318508-0.971306-0.9770440.000000-0.707317-0.519369
.............................................
461016900957.XSHG2021-112021-10-29-0.0425882021-10-0.990901-0.011374-0.0002270.0000000.9895360.9899910.556415-0.897179-0.702457
461017900957.XSHG2021-122021-11-300.0338052021-11-0.990967-0.7831980.0002260.0000000.9887080.9896120.585818-0.900181-0.775068
461018900957.XSHG2022-012021-12-31-0.0240272021-12-0.9910330.1217220.0000000.0000000.9932750.9892400.392961-0.862811-0.724277
461019900957.XSHG2022-022022-01-28-0.0131752022-01-0.9911250.6560910.0000000.0000000.9884620.989350-0.645440-0.958731-0.672953
461020900957.XSHG2022-032022-02-28-0.0360662022-02-0.991184-0.5463960.0000000.0000000.9889800.990302-0.648667-0.990743-0.902138
\n", "

453693 rows × 14 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate exret ym size \\\n", "0 000001.XSHE 2008-01 2007-12-28 -0.140255 2007-12 0.970696 \n", "1 000001.XSHE 2008-02 2008-01-31 -0.007450 2008-01 0.972444 \n", "2 000001.XSHE 2008-03 2008-02-29 -0.152068 2008-02 0.968481 \n", "3 000001.XSHE 2008-04 2008-03-31 0.047493 2008-03 0.969936 \n", "4 000001.XSHE 2008-05 2008-04-30 -0.151164 2008-04 0.965567 \n", "... ... ... ... ... ... ... \n", "461016 900957.XSHG 2021-11 2021-10-29 -0.042588 2021-10 -0.990901 \n", "461017 900957.XSHG 2021-12 2021-11-30 0.033805 2021-11 -0.990967 \n", "461018 900957.XSHG 2022-01 2021-12-31 -0.024027 2021-12 -0.991033 \n", "461019 900957.XSHG 2022-02 2022-01-28 -0.013175 2022-01 -0.991125 \n", "461020 900957.XSHG 2022-03 2022-02-28 -0.036066 2022-02 -0.991184 \n", "\n", " rev beta bm illiq illiq_12m mom vol \\\n", "0 -0.853480 0.123810 -0.696703 -0.979487 0.000000 0.000000 0.321612 \n", "1 -0.412618 -0.219724 -0.641769 -0.960841 0.000000 0.000000 -0.601160 \n", "2 -0.664756 0.375358 -0.588825 -0.958453 0.000000 0.000000 0.415473 \n", "3 0.483178 -0.079456 -0.176807 -0.975662 0.000000 0.000000 0.218325 \n", "4 0.522238 -0.301291 -0.318508 -0.971306 -0.977044 0.000000 -0.707317 \n", "... ... ... ... ... ... ... ... \n", "461016 -0.011374 -0.000227 0.000000 0.989536 0.989991 0.556415 -0.897179 \n", "461017 -0.783198 0.000226 0.000000 0.988708 0.989612 0.585818 -0.900181 \n", "461018 0.121722 0.000000 0.000000 0.993275 0.989240 0.392961 -0.862811 \n", "461019 0.656091 0.000000 0.000000 0.988462 0.989350 -0.645440 -0.958731 \n", "461020 -0.546396 0.000000 0.000000 0.988980 0.990302 -0.648667 -0.990743 \n", "\n", " ivol \n", "0 0.000000 \n", "1 -0.718637 \n", "2 -0.616046 \n", "3 0.400143 \n", "4 -0.519369 \n", "... ... 
# Year component of `ret_date`, used to cut the sample into yearly blocks.
df_rank['year'] = df_rank['ret_date'].dt.year

# One entry per year, in ascending year order; each entry holds the index
# labels of that year's rows.
rows_by_year = df_rank.groupby('year').groups
time_idx = [rows_by_year[yr] for yr in sorted(rows_by_year)]
460915, 460916, 460917, 460918, 460919,\n", " 460920, 460921],\n", " dtype='int64', length=28885),\n", " Int64Index([ 69, 70, 71, 72, 73, 74, 75, 76,\n", " 77, 78,\n", " ...\n", " 460924, 460925, 460926, 460927, 460928, 460929, 460930, 460931,\n", " 460932, 460933],\n", " dtype='int64', length=28408),\n", " Int64Index([ 81, 82, 83, 84, 85, 86, 87, 88,\n", " 89, 90,\n", " ...\n", " 460936, 460937, 460938, 460939, 460940, 460941, 460942, 460943,\n", " 460944, 460945],\n", " dtype='int64', length=28331),\n", " Int64Index([ 93, 94, 95, 96, 97, 98, 99, 100,\n", " 101, 102,\n", " ...\n", " 460948, 460949, 460950, 460951, 460952, 460953, 460954, 460955,\n", " 460956, 460957],\n", " dtype='int64', length=31459),\n", " Int64Index([ 105, 106, 107, 108, 109, 110, 111, 112,\n", " 113, 114,\n", " ...\n", " 460960, 460961, 460962, 460963, 460964, 460965, 460966, 460967,\n", " 460968, 460969],\n", " dtype='int64', length=36050),\n", " Int64Index([ 117, 118, 119, 120, 121, 122, 123, 124,\n", " 125, 126,\n", " ...\n", " 460972, 460973, 460974, 460975, 460976, 460977, 460978, 460979,\n", " 460980, 460981],\n", " dtype='int64', length=40026),\n", " Int64Index([ 129, 130, 131, 132, 133, 134, 135, 136,\n", " 137, 138,\n", " ...\n", " 460984, 460985, 460986, 460987, 460988, 460989, 460990, 460991,\n", " 460992, 460993],\n", " dtype='int64', length=43017),\n", " Int64Index([ 141, 142, 143, 144, 145, 146, 147, 148,\n", " 149, 150,\n", " ...\n", " 460996, 460997, 460998, 460999, 461000, 461001, 461002, 461003,\n", " 461004, 461005],\n", " dtype='int64', length=45124),\n", " Int64Index([ 153, 154, 155, 156, 157, 158, 159, 160,\n", " 161, 162,\n", " ...\n", " 461008, 461009, 461010, 461011, 461012, 461013, 461014, 461015,\n", " 461016, 461017],\n", " dtype='int64', length=50192),\n", " Int64Index([ 165, 166, 167, 329, 330, 331, 449, 450,\n", " 451, 751,\n", " ...\n", " 459999, 460408, 460409, 460410, 460564, 460565, 460566, 461018,\n", " 461019, 461020],\n", " dtype='int64', length=13502)]" 
def list_flat(list_):
    """Flatten one level of nesting.

    Takes an iterable of iterables and returns a single list containing the
    items of every inner iterable, in their original order.
    """
    flattened = []
    for sublist in list_:
        flattened.extend(sublist)
    return flattened
"text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_datetradeDateexretymsizerevbetabmilliqilliq_12mmomvolivolyear
0000001.XSHE2008-012007-12-28-0.1402552007-120.970696-0.8534800.123810-0.696703-0.9794870.0000000.0000000.3216120.0000002008
1000001.XSHE2008-022008-01-31-0.0074502008-010.972444-0.412618-0.219724-0.641769-0.9608410.0000000.000000-0.601160-0.7186372008
2000001.XSHE2008-032008-02-29-0.1520682008-020.968481-0.6647560.375358-0.588825-0.9584530.0000000.0000000.415473-0.6160462008
3000001.XSHE2008-042008-03-310.0474932008-030.9699360.483178-0.079456-0.176807-0.9756620.0000000.0000000.2183250.4001432008
4000001.XSHE2008-052008-04-30-0.1511642008-040.9655670.522238-0.301291-0.318508-0.971306-0.9770440.000000-0.707317-0.5193692008
................................................
461016900957.XSHG2021-112021-10-29-0.0425882021-10-0.990901-0.011374-0.0002270.0000000.9895360.9899910.556415-0.897179-0.7024572021
461017900957.XSHG2021-122021-11-300.0338052021-11-0.990967-0.7831980.0002260.0000000.9887080.9896120.585818-0.900181-0.7750682021
461018900957.XSHG2022-012021-12-31-0.0240272021-12-0.9910330.1217220.0000000.0000000.9932750.9892400.392961-0.862811-0.7242772022
461019900957.XSHG2022-022022-01-28-0.0131752022-01-0.9911250.6560910.0000000.0000000.9884620.989350-0.645440-0.958731-0.6729532022
461020900957.XSHG2022-032022-02-28-0.0360662022-02-0.991184-0.5463960.0000000.0000000.9889800.990302-0.648667-0.990743-0.9021382022
\n", "

453693 rows × 15 columns

\n", "
" ], "text/plain": [ " secID ret_date tradeDate exret ym size \\\n", "0 000001.XSHE 2008-01 2007-12-28 -0.140255 2007-12 0.970696 \n", "1 000001.XSHE 2008-02 2008-01-31 -0.007450 2008-01 0.972444 \n", "2 000001.XSHE 2008-03 2008-02-29 -0.152068 2008-02 0.968481 \n", "3 000001.XSHE 2008-04 2008-03-31 0.047493 2008-03 0.969936 \n", "4 000001.XSHE 2008-05 2008-04-30 -0.151164 2008-04 0.965567 \n", "... ... ... ... ... ... ... \n", "461016 900957.XSHG 2021-11 2021-10-29 -0.042588 2021-10 -0.990901 \n", "461017 900957.XSHG 2021-12 2021-11-30 0.033805 2021-11 -0.990967 \n", "461018 900957.XSHG 2022-01 2021-12-31 -0.024027 2021-12 -0.991033 \n", "461019 900957.XSHG 2022-02 2022-01-28 -0.013175 2022-01 -0.991125 \n", "461020 900957.XSHG 2022-03 2022-02-28 -0.036066 2022-02 -0.991184 \n", "\n", " rev beta bm illiq illiq_12m mom vol \\\n", "0 -0.853480 0.123810 -0.696703 -0.979487 0.000000 0.000000 0.321612 \n", "1 -0.412618 -0.219724 -0.641769 -0.960841 0.000000 0.000000 -0.601160 \n", "2 -0.664756 0.375358 -0.588825 -0.958453 0.000000 0.000000 0.415473 \n", "3 0.483178 -0.079456 -0.176807 -0.975662 0.000000 0.000000 0.218325 \n", "4 0.522238 -0.301291 -0.318508 -0.971306 -0.977044 0.000000 -0.707317 \n", "... ... ... ... ... ... ... ... \n", "461016 -0.011374 -0.000227 0.000000 0.989536 0.989991 0.556415 -0.897179 \n", "461017 -0.783198 0.000226 0.000000 0.988708 0.989612 0.585818 -0.900181 \n", "461018 0.121722 0.000000 0.000000 0.993275 0.989240 0.392961 -0.862811 \n", "461019 0.656091 0.000000 0.000000 0.988462 0.989350 -0.645440 -0.958731 \n", "461020 -0.546396 0.000000 0.000000 0.988980 0.990302 -0.648667 -0.990743 \n", "\n", " ivol year \n", "0 0.000000 2008 \n", "1 -0.718637 2008 \n", "2 -0.616046 2008 \n", "3 0.400143 2008 \n", "4 -0.519369 2008 \n", "... ... ... 
def r2_oos(y_true, y_pred):
    """Out-of-sample R-squared measured against a zero forecast.

    Computes ``1 - sum((y_true - y_pred)**2) / sum(y_true**2)``. The
    denominator is the raw sum of squares of ``y_true`` (no demeaning).

    Parameters
    ----------
    y_true : array-like
        Realized values.
    y_pred : array-like
        Predicted values, matched to ``y_true`` by position.

    Returns
    -------
    float
        The score. If every ``y_true`` is zero the denominator is zero and
        the result follows NumPy's division semantics (nan or -inf).
    """
    # Work on flat float arrays so that the comparison is strictly positional:
    # - two pandas Series with different index labels would otherwise align by
    #   label, produce all-NaN differences and report a perfect score of 1;
    # - an (n, 1) prediction against an (n,) target would otherwise broadcast
    #   to an (n, n) matrix;
    # - plain Python lists would fail on the subtraction.
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()

    squared_error = np.sum((y_true - y_pred) ** 2)
    total_squares = np.sum(y_true ** 2)
    return 1 - squared_error / total_squares
"superb-standard", "metadata": {}, "outputs": [], "source": [ "cols = ['size','rev','illiq','ivol']" ] }, { "cell_type": "code", "execution_count": 56, "id": "raising-denver", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test year 2016 : -0.011951287753103168\n", "Test year 2017 : -0.08295671126910364\n", "Test year 2018 : -0.045758754149365366\n", "Test year 2019 : 0.006047261470831122\n", "Test year 2020 : -0.0011655147730460502\n", "Test year 2021 : -0.020928820173429674\n" ] } ], "source": [ "for i in range(len(fulltrain_idx)):\n", " X_fulltrain = df_rank.loc[fulltrain_idx[i], cols]\n", " y_fulltrain = df_rank.loc[fulltrain_idx[i], 'exret']\n", " X_test = df_rank.loc[test_idx[i], cols]\n", " y_test = df_rank.loc[test_idx[i], 'exret']\n", " \n", " model.fit(X=X_fulltrain, y=y_fulltrain)\n", " y_pred = model.predict(X=X_test)\n", " \n", " print(\"Test year\", test_years[i],\":\",r2_oos(y_true=y_test, y_pred=y_pred))" ] }, { "cell_type": "markdown", "id": "prescription-campus", "metadata": {}, "source": [ "## Huber regressor" ] }, { "cell_type": "code", "execution_count": 57, "id": "bright-wound", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['size', 'rev', 'mom', 'beta', 'bm', 'illiq', 'ivol']" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols = [col for col in num_X_cols if col != 'illiq_12m' and col!='vol']\n", "cols" ] }, { "cell_type": "code", "execution_count": 58, "id": "cheap-paste", "metadata": {}, "outputs": [], "source": [ "model = HuberRegressor(alpha=0.01,epsilon=1.05)" ] }, { "cell_type": "code", "execution_count": 59, "id": "printable-farming", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test year 2016 : 0.0014038503898485821\n", "Test year 2017 : -0.04070584873067329\n", "Test year 2018 : 0.0016444781377967788\n", "Test year 2019 : -0.016957932821754618\n", "Test year 2020 : -0.013372340329569798\n", "Test 
year 2021 : 0.008736696659550569\n" ] } ], "source": [ "for i in range(len(fulltrain_idx)):\n", " X_fulltrain = df_rank.loc[fulltrain_idx[i], cols]\n", " y_fulltrain = df_rank.loc[fulltrain_idx[i], 'exret']\n", " X_test = df_rank.loc[test_idx[i], cols]\n", " y_test = df_rank.loc[test_idx[i], 'exret']\n", " \n", " model.fit(X=X_fulltrain, y=y_fulltrain)\n", " y_pred = model.predict(X=X_test)\n", " \n", " print(\"Test year\", test_years[i],\":\",r2_oos(y_true=y_test, y_pred=y_pred))" ] }, { "cell_type": "markdown", "id": "actual-brunswick", "metadata": {}, "source": [ "## Random Forest" ] }, { "cell_type": "code", "execution_count": 60, "id": "elder-cheese", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['size', 'rev', 'mom', 'beta', 'bm', 'illiq', 'illiq_12m', 'vol', 'ivol']" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols = num_X_cols\n", "cols" ] }, { "cell_type": "code", "execution_count": 66, "id": "liable-corps", "metadata": {}, "outputs": [], "source": [ "hyperparam_grid = [\n", " {'n_estimators': [100], 'max_depth': [1,3,7], \n", " 'max_features': [3,5,len(cols)]}\n", "]" ] }, { "cell_type": "code", "execution_count": 67, "id": "tight-helena", "metadata": {}, "outputs": [], "source": [ "model = RandomForestRegressor(random_state=42)" ] }, { "cell_type": "code", "execution_count": 68, "id": "historic-chorus", "metadata": {}, "outputs": [], "source": [ "# Cross validation for period 0, i.e.\n", "# train: [2008-2011], val: [2012-2015], test: [2016]\n", "grid_search = GridSearchCV(model, hyperparam_grid, cv=[cv_idx[0]], \n", " scoring=r2_oos_scorer,\n", " return_train_score=True)" ] }, { "cell_type": "code", "execution_count": 69, "id": "chinese-terry", "metadata": {}, "outputs": [], "source": [ "X_fulltrain = df_rank.loc[fulltrain_idx[0], cols]\n", "y_fulltrain = df_rank.loc[fulltrain_idx[0], 'exret']\n", "X_test = df_rank.loc[test_idx[0], cols]\n", "y_test = df_rank.loc[test_idx[0], 'exret']" ] }, { 
"cell_type": "code", "execution_count": 59, "id": "wanted-aruba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1min 48s, sys: 385 ms, total: 1min 49s\n", "Wall time: 1min 49s\n" ] }, { "data": { "text/plain": [ "GridSearchCV(cv=[(array([ 0, 1, 2, ..., 86282, 86283, 86284]),\n", " array([ 86285, 86286, 86287, ..., 203045, 203046, 203047]))],\n", " estimator=RandomForestRegressor(random_state=42),\n", " param_grid=[{'max_depth': [1, 3, 7], 'max_features': [3, 5, 9],\n", " 'n_estimators': [100]}],\n", " return_train_score=True, scoring=make_scorer(r2_oos))" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "grid_search.fit(X_fulltrain, y_fulltrain)" ] }, { "cell_type": "code", "execution_count": 60, "id": "special-translation", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'max_depth': 3, 'max_features': 3, 'n_estimators': 100}" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.best_params_" ] }, { "cell_type": "code", "execution_count": 61, "id": "after-construction", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.05127266281889475 {'max_depth': 1, 'max_features': 3, 'n_estimators': 100}\n", "0.0534381698663029 {'max_depth': 1, 'max_features': 5, 'n_estimators': 100}\n", "0.05949245846715574 {'max_depth': 1, 'max_features': 9, 'n_estimators': 100}\n", "0.06188021686062709 {'max_depth': 3, 'max_features': 3, 'n_estimators': 100}\n", "0.05804954897281175 {'max_depth': 3, 'max_features': 5, 'n_estimators': 100}\n", "0.06081729711883632 {'max_depth': 3, 'max_features': 9, 'n_estimators': 100}\n", "0.05614456342229886 {'max_depth': 7, 'max_features': 3, 'n_estimators': 100}\n", "0.043036296006267676 {'max_depth': 7, 'max_features': 5, 'n_estimators': 100}\n", "0.04356663192832121 {'max_depth': 7, 'max_features': 9, 'n_estimators': 100}\n" ] } ], "source": 
[ "# Validation R2_oos (`mean_test_score`) of every hyper-parameter combination.\n", "# The score is already an R2 and may be negative, so it is printed as-is:\n", "# np.sqrt() would distort the values and turn negative scores into NaN.\n", "cv_results = grid_search.cv_results_\n", "for mean_score, params in zip(cv_results['mean_test_score'],\n", "                              cv_results['params']):\n", "    print(mean_score, params)" ] }, { "cell_type": "code", "execution_count": 62, "id": "tired-abraham", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
featuresfeature_importance
5illiq0.302548
0size0.211411
1rev0.198549
6illiq_12m0.113615
8ivol0.092814
7vol0.052963
2mom0.010744
4bm0.010535
3beta0.006820
\n", "
" ], "text/plain": [ " features feature_importance\n", "5 illiq 0.302548\n", "0 size 0.211411\n", "1 rev 0.198549\n", "6 illiq_12m 0.113615\n", "8 ivol 0.092814\n", "7 vol 0.052963\n", "2 mom 0.010744\n", "4 bm 0.010535\n", "3 beta 0.006820" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame({\"features\":num_X_cols,\"feature_importance\":grid_search.best_estimator_.feature_importances_}).sort_values('feature_importance',\n", " ascending=False)" ] }, { "cell_type": "code", "execution_count": 63, "id": "polished-sequence", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.012832337545194417" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = grid_search.predict(X_test)\n", "r2_oos(y_true=y_test, y_pred=y_pred)" ] }, { "cell_type": "code", "execution_count": 64, "id": "satisfactory-bonus", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test year 2016 : -0.012832337545194417\n", "Test year 2017 : -0.0820561067657255\n", "Test year 2018 : -0.04409182586886584\n", "Test year 2019 : 0.007830509088117443\n", "Test year 2020 : 0.003591662917594607\n", "Test year 2021 : -0.015391690896670474\n", "CPU times: user 22min 37s, sys: 4.99 s, total: 22min 42s\n", "Wall time: 22min 50s\n" ] } ], "source": [ "%%time\n", "for i in range(len(fulltrain_idx)):\n", " X_fulltrain = df_rank.loc[fulltrain_idx[i], cols]\n", " y_fulltrain = df_rank.loc[fulltrain_idx[i], 'exret']\n", " X_test = df_rank.loc[test_idx[i], cols]\n", " y_test = df_rank.loc[test_idx[i], 'exret']\n", " \n", " grid_search = GridSearchCV(model, hyperparam_grid, cv=[cv_idx[i]], \n", " scoring=r2_oos_scorer,\n", " return_train_score=True)\n", " grid_search.fit(X_fulltrain, y_fulltrain)\n", " y_pred = grid_search.predict(X=X_test)\n", " \n", " print(\"Test year\", test_years[i],\":\",r2_oos(y_true=y_test, y_pred=y_pred))" ] }, { "cell_type": "markdown", "id": 
"homeless-civilian", "metadata": {}, "source": [ "## Partial Least Squares" ] }, { "cell_type": "code", "execution_count": 61, "id": "trying-assistant", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['size', 'rev', 'mom', 'beta', 'bm', 'illiq', 'illiq_12m', 'vol', 'ivol']" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols = num_X_cols\n", "cols" ] }, { "cell_type": "code", "execution_count": 62, "id": "cultural-routine", "metadata": {}, "outputs": [], "source": [ "model = PLSRegression(n_components=4)" ] }, { "cell_type": "code", "execution_count": 64, "id": "rising-transcription", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(8378,)" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred.reshape(-1).shape" ] }, { "cell_type": "code", "execution_count": 76, "id": "greatest-falls", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test year 2016 : -0.01123119284077867\n", "Test year 2017 : -0.08973582085167675\n", "Test year 2018 : -0.0454224712689455\n", "Test year 2019 : 0.00484391775863835\n", "Test year 2020 : -0.0007496252971606054\n", "Test year 2021 : -0.02132040728407225\n", "CPU times: user 7.04 s, sys: 224 ms, total: 7.27 s\n", "Wall time: 2.41 s\n" ] } ], "source": [ "%%time\n", "for i in range(len(fulltrain_idx)):\n", " X_fulltrain = df_rank.loc[fulltrain_idx[i], cols]\n", " y_fulltrain = df_rank.loc[fulltrain_idx[i], 'exret']\n", " X_test = df_rank.loc[test_idx[i], cols]\n", " y_test = df_rank.loc[test_idx[i], 'exret']\n", " \n", " model.fit(X_fulltrain, y_fulltrain)\n", " y_pred = model.predict(X=X_test)\n", " y_pred = y_pred.reshape(-1)\n", " print(\"Test year\", test_years[i],\":\",r2_oos(y_true=y_test, y_pred=y_pred))" ] }, { "cell_type": "markdown", "id": "enabling-worker", "metadata": {}, "source": [ "## Principal Component Regression" ] }, { "cell_type": "markdown", "id": "derived-insurance", 
"metadata": {}, "source": [ "### PCA transform" ] }, { "cell_type": "code", "execution_count": 65, "id": "favorite-display", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['size', 'rev', 'mom', 'beta', 'bm', 'illiq', 'illiq_12m', 'vol', 'ivol']" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols = num_X_cols\n", "cols" ] }, { "cell_type": "code", "execution_count": 66, "id": "limited-february", "metadata": {}, "outputs": [], "source": [ "X_fulltrain = df_rank.loc[fulltrain_idx[0], cols]\n", "y_fulltrain = df_rank.loc[fulltrain_idx[0], 'exret']\n", "X_test = df_rank.loc[test_idx[0],cols]\n", "y_test = df_rank.loc[test_idx[0],'exret']" ] }, { "cell_type": "code", "execution_count": 67, "id": "reliable-gabriel", "metadata": {}, "outputs": [], "source": [ "pca = PCA(3, random_state=42)" ] }, { "cell_type": "code", "execution_count": 68, "id": "advisory-table", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PCA(n_components=3, random_state=42)" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pca.fit(X_fulltrain)" ] }, { "cell_type": "code", "execution_count": 69, "id": "removed-swiss", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.54714595, -0.01385498, 0.04325358, 0.17598213, 0.04193536,\n", " -0.57860615, -0.57126752, 0.06380787, 0.02524075],\n", " [-0.06032312, 0.30002583, 0.26579743, -0.01213907, -0.39538595,\n", " -0.04921141, 0.06139761, 0.56294733, 0.59675692],\n", " [ 0.12342608, -0.20637066, 0.50242834, -0.61520441, -0.44211959,\n", " -0.07749277, -0.02191892, -0.30373145, -0.13060767]])" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pca.components_" ] }, { "cell_type": "code", "execution_count": 70, "id": "committed-costa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3, 9)" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "pca.components_.shape" ] }, { "cell_type": "code", "execution_count": 71, "id": "equivalent-equipment", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(203048, 9)" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_fulltrain.shape" ] }, { "cell_type": "code", "execution_count": 84, "id": "annoying-cincinnati", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(9, 3)" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pca.components_.T.shape" ] }, { "cell_type": "code", "execution_count": 72, "id": "underlying-reducing", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 1.5880652 , -0.1593774 , 0.92155463],\n", " [ 1.74711817, 0.03463427, 0.4903537 ],\n", " [ 1.72960462, 0.78025716, 0.22764192],\n", " ...,\n", " [-1.74314491, -0.50257968, -0.34461461],\n", " [-1.66920021, 0.2734256 , -0.03379416],\n", " [-1.65589407, 0.43992714, -0.01309807]])" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.matmul(X_fulltrain.values,pca.components_.T)" ] }, { "cell_type": "code", "execution_count": 73, "id": "aerial-mills", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 1.5880652 , -0.1593774 , 0.92155463],\n", " [ 1.74711817, 0.03463427, 0.4903537 ],\n", " [ 1.72960462, 0.78025716, 0.22764192],\n", " ...,\n", " [-1.74314491, -0.50257968, -0.34461461],\n", " [-1.66920021, 0.2734256 , -0.03379416],\n", " [-1.65589407, 0.43992714, -0.01309807]])" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pca.fit_transform(X_fulltrain)" ] }, { "cell_type": "markdown", "id": "connected-awareness", "metadata": {}, "source": [ "### PCA regression" ] }, { "cell_type": "markdown", "id": "cubic-drilling", "metadata": {}, "source": [ "sklearn 是 duck typing,因此无需继承,只需在定义类的时候包括对应的方法,`fit()`(return self),`transform()`,`fit_transform()`即可。\n", "\n", 
"但直接用继承,可以更方便。\n", "- `BaseEstimator`是sklearn里最基本的类,其他的类都从这个类继承而来,包括了`set_params()`和`get_params()`的方法。\n", "- `TransformerMixin`包括了`fit_transform()`方法。因此由这个类继承而来的话,就不用自定义 `fit_transform` 了\n", "- 类似的,`RegressorMixin`包括了`score()`方法(默认返回 $R^2$),但 `predict()` 仍需自己定义" ] }, { "cell_type": "code", "execution_count": 74, "id": "suitable-shareware", "metadata": {}, "outputs": [], "source": [ "class PCARegressor(BaseEstimator, RegressorMixin):\n", "    \"\"\"Principal component regression: PCA projection followed by OLS.\n", "\n", "    Parameters\n", "    ----------\n", "    n_components : int, default=3\n", "        Number of principal components kept by the PCA step.\n", "    random_state : int, RandomState instance or None, default=None\n", "        Forwarded to PCA; makes the components reproducible whenever\n", "        sklearn picks a randomized SVD solver.\n", "    \"\"\"\n", "\n", "    def __init__(self, n_components=3, random_state=None):\n", "        # Only store hyper-parameters here so that get_params() / clone()\n", "        # (used by GridSearchCV) can rebuild the estimator.\n", "        self.n_components = n_components\n", "        self.random_state = random_state\n", "\n", "    def fit(self, X, y):\n", "        \"\"\"Fit the PCA on X, then regress y on the projected features.\"\"\"\n", "        self.pca_ = PCA(n_components=self.n_components,\n", "                        random_state=self.random_state).fit(X)\n", "        # Projected training features, kept so they can be inspected.\n", "        self.X_ = self.pca_.transform(X)\n", "        self.reg_ = LinearRegression().fit(self.X_, y)\n", "        return self\n", "\n", "    def predict(self, X):\n", "        \"\"\"Project X with the fitted PCA and predict with the fitted OLS.\"\"\"\n", "        # The latest predictions are also kept on the estimator as pred_.\n", "        self.pred_ = self.reg_.predict(self.pca_.transform(X))\n", "        return self.pred_" ] }, { "cell_type": "code", "execution_count": 75, "id": "proprietary-morrison", "metadata": {}, "outputs": [], "source": [ "model = PCARegressor(random_state=42)" ] }, { "cell_type": "code", "execution_count": 76, "id": "passing-jackson", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PCARegressor()" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(X=X_fulltrain, y=y_fulltrain)" ] }, { "cell_type": "code", "execution_count": 77, "id": "communist-burton", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 1.5880652 , -0.1593774 , 0.92155463],\n", " [ 1.74711817, 0.03463427, 0.4903537 ],\n", " [ 1.72960462, 0.78025716, 0.22764192],\n", " ...,\n", " [-1.74314491, -0.50257968, -0.34461461],\n", " [-1.66920021, 0.2734256 , -0.03379416],\n", " [-1.65589407, 0.43992714, -0.01309807]])" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.X_" ] }, { "cell_type": "code", "execution_count": 78, "id": "passing-ecuador", "metadata": {}, "outputs": [], "source": [ "hyperparam_grid = [\n", " {'n_components': range(1, len(cols)+1)}\n", "]" ] }, { "cell_type": "code", 
"execution_count": 79, "id": "duplicate-cambridge", "metadata": {}, "outputs": [], "source": [ "grid_search = GridSearchCV(model, hyperparam_grid, cv=[cv_idx[0]], \n", " scoring=r2_oos_scorer,\n", " return_train_score=True)" ] }, { "cell_type": "code", "execution_count": 80, "id": "norman-backup", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=[(array([ 0, 1, 2, ..., 86282, 86283, 86284]),\n", " array([ 86285, 86286, 86287, ..., 203045, 203046, 203047]))],\n", " estimator=PCARegressor(),\n", " param_grid=[{'n_components': range(1, 10)}],\n", " return_train_score=True, scoring=make_scorer(r2_oos))" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.fit(X=X_fulltrain, y=y_fulltrain)" ] }, { "cell_type": "code", "execution_count": 81, "id": "egyptian-bathroom", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'n_components': 9}" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.best_params_" ] }, { "cell_type": "code", "execution_count": 82, "id": "vocal-enforcement", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.06047645010846184 {'n_components': 1}\n", "0.05648442569524177 {'n_components': 2}\n", "0.05505885643926543 {'n_components': 3}\n", "0.06827088360625753 {'n_components': 4}\n", "0.07112802051912726 {'n_components': 5}\n", "0.07107450847527214 {'n_components': 6}\n", "0.06906175137901698 {'n_components': 7}\n", "0.07715717389016781 {'n_components': 8}\n", "0.07815893155696621 {'n_components': 9}\n" ] } ], "source": [ "# Validation R2_oos (`mean_test_score`) for every number of components.\n", "# The score is already an R2 and may be negative, so it is printed as-is:\n", "# np.sqrt() would distort the values and turn negative scores into NaN.\n", "cv_results = grid_search.cv_results_\n", "for mean_score, params in zip(cv_results['mean_test_score'],\n", "                              cv_results['params']):\n", "    print(mean_score, params)" ] }, { "cell_type": "code", "execution_count": 83, "id": "fixed-washington", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.01117901387512199" ] }, "execution_count": 83, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = grid_search.predict(X_test)\n", "r2_oos(y_true=y_test,y_pred=y_pred)" ] }, { "cell_type": "code", "execution_count": 84, "id": "dietary-static", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test year 2016 : -0.01117901387512199\n", "Test year 2017 : -0.08825633439790259\n", "Test year 2018 : -0.04484009526192567\n", "Test year 2019 : 0.005339549758387907\n", "Test year 2020 : -0.00026464526848823944\n", "Test year 2021 : -0.021433913684697492\n", "CPU times: user 53.8 s, sys: 4.24 s, total: 58.1 s\n", "Wall time: 15.5 s\n" ] } ], "source": [ "%%time\n", "for i in range(len(fulltrain_idx)):\n", " X_fulltrain = df_rank.loc[fulltrain_idx[i], cols]\n", " y_fulltrain = df_rank.loc[fulltrain_idx[i], 'exret']\n", " X_test = df_rank.loc[test_idx[i], cols]\n", " y_test = df_rank.loc[test_idx[i], 'exret']\n", " \n", " grid_search = GridSearchCV(model, hyperparam_grid, cv=[cv_idx[i]], \n", " scoring=r2_oos_scorer,\n", " return_train_score=True)\n", " grid_search.fit(X_fulltrain, y_fulltrain)\n", " y_pred = grid_search.predict(X=X_test)\n", " y_pred = y_pred.reshape(-1)\n", " print(\"Test year\", test_years[i],\":\",r2_oos(y_true=y_test, y_pred=y_pred))" ] }, { "cell_type": "markdown", "id": "rocky-portfolio", "metadata": {}, "source": [ "## Pipeline" ] }, { "cell_type": "code", "execution_count": 85, "id": "provincial-bridal", "metadata": {}, "outputs": [], "source": [ "pca = PCA()\n", "linear_reg = LinearRegression()\n", "pipeline = Pipeline(steps=[('pca',pca),\n", " ('linear_regression', linear_reg)])\n", "hyperparam_grid = {'pca__n_components': range(1,len(cols)+1)}\n", "grid_search = GridSearchCV(pipeline, hyperparam_grid, cv=[cv_idx[0]],\n", " scoring=r2_oos_scorer,\n", " return_train_score=True)" ] }, { "cell_type": "code", "execution_count": 86, "id": "likely-moscow", "metadata": {}, "outputs": [], "source": [ "X_fulltrain = df_rank.loc[fulltrain_idx[0], 
cols]\n", "y_fulltrain = df_rank.loc[fulltrain_idx[0], 'exret']\n", "X_test = df_rank.loc[test_idx[0],cols]\n", "y_test = df_rank.loc[test_idx[0],'exret']" ] }, { "cell_type": "code", "execution_count": 87, "id": "union-jamaica", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4 s, sys: 345 ms, total: 4.34 s\n", "Wall time: 1.13 s\n" ] }, { "data": { "text/plain": [ "GridSearchCV(cv=[(array([ 0, 1, 2, ..., 86282, 86283, 86284]),\n", " array([ 86285, 86286, 86287, ..., 203045, 203046, 203047]))],\n", " estimator=Pipeline(steps=[('pca', PCA()),\n", " ('linear_regression',\n", " LinearRegression())]),\n", " param_grid={'pca__n_components': range(1, 10)},\n", " return_train_score=True, scoring=make_scorer(r2_oos))" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "grid_search.fit(X=X_fulltrain,y=y_fulltrain)" ] }, { "cell_type": "code", "execution_count": 88, "id": "beginning-compromise", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'pca__n_components': 9}" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.best_params_" ] }, { "cell_type": "code", "execution_count": 89, "id": "written-treaty", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.06047645010846184 {'pca__n_components': 1}\n", "0.05648442569524177 {'pca__n_components': 2}\n", "0.05505885643926543 {'pca__n_components': 3}\n", "0.06827088360625753 {'pca__n_components': 4}\n", "0.07112802051912726 {'pca__n_components': 5}\n", "0.07107450847527291 {'pca__n_components': 6}\n", "0.06906175137901698 {'pca__n_components': 7}\n", "0.07715717389016709 {'pca__n_components': 8}\n", "0.07815893155696621 {'pca__n_components': 9}\n" ] } ], "source": [ "cv_results = grid_search.cv_results_\n", "for mean_score, params in zip(cv_results['mean_test_score'],\n", " cv_results['params']):\n", " 
print(mean_score, params)" ] }, { "cell_type": "code", "execution_count": 90, "id": "impressed-guide", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.01117901387512199" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = grid_search.predict(X_test)\n", "r2_oos(y_true=y_test, y_pred=y_pred)" ] }, { "cell_type": "markdown", "id": "adverse-settlement", "metadata": {}, "source": [ "## Elastic Net" ] }, { "cell_type": "code", "execution_count": 98, "id": "rapid-matrix", "metadata": {}, "outputs": [], "source": [ "X_fulltrain = df_rank.loc[fulltrain_idx[0], cols]\n", "y_fulltrain = df_rank.loc[fulltrain_idx[0], 'exret']\n", "X_test = df_rank.loc[test_idx[0],cols]\n", "y_test = df_rank.loc[test_idx[0],'exret']" ] }, { "cell_type": "code", "execution_count": 99, "id": "conservative-synthesis", "metadata": {}, "outputs": [], "source": [ "# SGD shuffles the training data between epochs; fix the seed so the\n", "# elastic-net scores are reproducible from run to run.\n", "model = SGDRegressor(penalty='elasticnet', random_state=42)" ] }, { "cell_type": "code", "execution_count": 100, "id": "retained-values", "metadata": {}, "outputs": [], "source": [ "hyperparam_grid = [{'alpha':[0.0001, 0.001, 0.01, 0.1],\n", " 'l1_ratio':[0.15, 0.30, 0.5]}]" ] }, { "cell_type": "code", "execution_count": 101, "id": "hairy-water", "metadata": {}, "outputs": [], "source": [ "grid_search = GridSearchCV(model, hyperparam_grid, cv=[cv_idx[0]], \n", " scoring=r2_oos_scorer,\n", " return_train_score=True)" ] }, { "cell_type": "code", "execution_count": 102, "id": "brutal-alarm", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=[(array([ 0, 1, 2, ..., 86282, 86283, 86284]),\n", " array([ 86285, 86286, 86287, ..., 203045, 203046, 203047]))],\n", " estimator=SGDRegressor(penalty='elasticnet'),\n", " param_grid=[{'alpha': [0.0001, 0.001, 0.01, 0.1],\n", " 'l1_ratio': [0.15, 0.3, 0.5]}],\n", " return_train_score=True, scoring=make_scorer(r2_oos))" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"grid_search.fit(X=X_fulltrain, y=y_fulltrain)" ] }, { "cell_type": "code", "execution_count": 103, "id": "integral-stability", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'alpha': 0.001, 'l1_ratio': 0.15}" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.best_params_" ] }, { "cell_type": "code", "execution_count": 104, "id": "chubby-replica", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.010007949151975781" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = grid_search.predict(X_test)\n", "r2_oos(y_true=y_test, y_pred=y_pred)" ] }, { "cell_type": "code", "execution_count": 105, "id": "impaired-meditation", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test year 2016 : -0.01183600331821233\n", "Test year 2017 : -0.0754301490157141\n", "Test year 2018 : -0.045485803118642476\n", "Test year 2019 : 0.003638036962217317\n", "Test year 2020 : 0.006039585416782511\n", "Test year 2021 : -0.013272185902106548\n", "CPU times: user 1min 3s, sys: 1.49 s, total: 1min 4s\n", "Wall time: 22.7 s\n" ] } ], "source": [ "%%time\n", "for i in range(len(fulltrain_idx)):\n", " X_fulltrain = df_rank.loc[fulltrain_idx[i], cols]\n", " y_fulltrain = df_rank.loc[fulltrain_idx[i], 'exret']\n", " X_test = df_rank.loc[test_idx[i], cols]\n", " y_test = df_rank.loc[test_idx[i], 'exret']\n", " \n", " grid_search = GridSearchCV(model, hyperparam_grid, cv=[cv_idx[i]], \n", " scoring=r2_oos_scorer,\n", " return_train_score=True)\n", " grid_search.fit(X_fulltrain, y_fulltrain)\n", " y_pred = grid_search.predict(X=X_test)\n", " y_pred = y_pred.reshape(-1)\n", " print(\"Test year\", test_years[i],\":\",r2_oos(y_true=y_test, y_pred=y_pred))" ] }, { "cell_type": "markdown", "id": "friendly-plymouth", "metadata": {}, "source": [ "## Gradient Boosted Regression Trees" ] }, { "cell_type": "code", "execution_count": 113, "id": 
"thousand-object", "metadata": {}, "outputs": [], "source": [ "X_fulltrain = df_rank.loc[fulltrain_idx[0], cols]\n", "y_fulltrain = df_rank.loc[fulltrain_idx[0], 'exret']\n", "X_test = df_rank.loc[test_idx[0],cols]\n", "y_test = df_rank.loc[test_idx[0],'exret']" ] }, { "cell_type": "code", "execution_count": 105, "id": "phantom-dance", "metadata": {}, "outputs": [], "source": [ "hyperparam_grid = [\n", " {'max_depth': [1,2,3,4,5,6], \n", " 'learning_rate': [0.1, 0.05, 0.01]}\n", "]" ] }, { "cell_type": "code", "execution_count": 106, "id": "aboriginal-healthcare", "metadata": {}, "outputs": [], "source": [ "model = GradientBoostingRegressor()" ] }, { "cell_type": "code", "execution_count": 107, "id": "unusual-division", "metadata": {}, "outputs": [], "source": [ "grid_search = GridSearchCV(model, hyperparam_grid, cv=[cv_idx[0]], \n", " scoring=r2_oos_scorer,\n", " return_train_score=True)" ] }, { "cell_type": "code", "execution_count": 108, "id": "according-elizabeth", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 8min 28s, sys: 1.86 s, total: 8min 30s\n", "Wall time: 8min 33s\n" ] }, { "data": { "text/plain": [ "GridSearchCV(cv=[(array([ 0, 1, 2, ..., 86282, 86283, 86284]),\n", " array([ 86285, 86286, 86287, ..., 203045, 203046, 203047]))],\n", " estimator=GradientBoostingRegressor(),\n", " param_grid=[{'learning_rate': [0.1, 0.05, 0.01],\n", " 'max_depth': [1, 2, 3, 4, 5, 6]}],\n", " return_train_score=True, scoring=make_scorer(r2_oos))" ] }, "execution_count": 108, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "grid_search.fit(X=X_fulltrain, y=y_fulltrain)" ] }, { "cell_type": "code", "execution_count": 109, "id": "cutting-description", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'learning_rate': 0.1, 'max_depth': 1}" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.best_params_" ] }, { "cell_type": 
"code", "execution_count": 110, "id": "established-night", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.06543277039809901 {'learning_rate': 0.1, 'max_depth': 1}\n", "0.02000841809692018 {'learning_rate': 0.1, 'max_depth': 2}\n", "nan {'learning_rate': 0.1, 'max_depth': 3}\n", "nan {'learning_rate': 0.1, 'max_depth': 4}\n", "nan {'learning_rate': 0.1, 'max_depth': 5}\n", "nan {'learning_rate': 0.1, 'max_depth': 6}\n", "0.06399724146424032 {'learning_rate': 0.05, 'max_depth': 1}\n", "0.05965646878537504 {'learning_rate': 0.05, 'max_depth': 2}\n", "0.04302709526703145 {'learning_rate': 0.05, 'max_depth': 3}\n", "nan {'learning_rate': 0.05, 'max_depth': 4}\n", "nan {'learning_rate': 0.05, 'max_depth': 5}\n", "nan {'learning_rate': 0.05, 'max_depth': 6}\n", "0.05053331785726358 {'learning_rate': 0.01, 'max_depth': 1}\n", "0.05778650929940647 {'learning_rate': 0.01, 'max_depth': 2}\n", "0.05718740382416022 {'learning_rate': 0.01, 'max_depth': 3}\n", "0.057024598216398666 {'learning_rate': 0.01, 'max_depth': 4}\n", "0.05472527792527635 {'learning_rate': 0.01, 'max_depth': 5}\n", "0.05257540418160582 {'learning_rate': 0.01, 'max_depth': 6}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":4: RuntimeWarning: invalid value encountered in sqrt\n", " print(np.sqrt(mean_score), params)\n" ] } ], "source": [ "cv_results = grid_search.cv_results_\n", "for mean_score, params in zip(cv_results['mean_test_score'],\n", " cv_results['params']):\n", " print(mean_score, params)" ] }, { "cell_type": "code", "execution_count": 111, "id": "senior-apparatus", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.01074611143132742" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = grid_search.predict(X_test)\n", "r2_oos(y_true=y_test, y_pred=y_pred)" ] }, { "cell_type": "markdown", "id": "solar-helena", "metadata": {}, "source": [ "## Neural Nets" ] }, { "cell_type": "code", 
"execution_count": 106, "id": "alive-newton", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.4.1'" ] }, "execution_count": 106, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.__version__" ] }, { "cell_type": "code", "execution_count": 107, "id": "reverse-trick", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'2.4.0'" ] }, "execution_count": 107, "metadata": {}, "output_type": "execute_result" } ], "source": [ "keras.__version__" ] }, { "cell_type": "code", "execution_count": 108, "id": "adult-classroom", "metadata": {}, "outputs": [], "source": [ "X_fulltrain = df_rank.loc[fulltrain_idx[0], cols]\n", "y_fulltrain = df_rank.loc[fulltrain_idx[0], 'exret']\n", "X_train = X_fulltrain.values[cv_idx[0][0]]\n", "y_train = y_fulltrain.values[cv_idx[0][0]]\n", "X_val = X_fulltrain.values[cv_idx[0][1]]\n", "y_val = y_fulltrain.values[cv_idx[0][1]]\n", "X_test = df_rank.loc[test_idx[0],cols]\n", "y_test = df_rank.loc[test_idx[0],'exret']" ] }, { "cell_type": "code", "execution_count": 109, "id": "developing-mexican", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(86285, 9)" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape" ] }, { "cell_type": "code", "execution_count": 110, "id": "wrong-bishop", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(116763, 9)" ] }, "execution_count": 110, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_val.shape" ] }, { "cell_type": "code", "execution_count": 114, "id": "ignored-clinton", "metadata": {}, "outputs": [], "source": [ "nn_model = keras.models.Sequential()\n", "nn_model.add(keras.layers.InputLayer(input_shape=[X_fulltrain.shape[1]]))\n", "nn_model.add(keras.layers.Dense(32, activation='relu'))\n", "nn_model.add(keras.layers.Dense(16, activation='relu'))\n", "nn_model.add(keras.layers.Dense(1))" ] }, { "cell_type": "code", "execution_count": 115, "id": "coupled-colon", "metadata": {}, 
"outputs": [], "source": [
"# Mean-squared-error loss, optimised with the stock 'sgd' optimizer.\n",
"nn_model.compile(optimizer='sgd', loss='mse')"
] }, { "cell_type": "code", "execution_count": 116, "id": "after-intranet", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Epoch 1/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0334 - val_loss: 0.0259\n",
"Epoch 2/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0237 - val_loss: 0.0252\n",
"Epoch 3/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0234 - val_loss: 0.0249\n",
"Epoch 4/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0228 - val_loss: 0.0247\n",
"Epoch 5/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0225 - val_loss: 0.0246\n",
"Epoch 6/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0225 - val_loss: 0.0242\n",
"Epoch 7/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0224 - val_loss: 0.0243\n",
"Epoch 8/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0224 - val_loss: 0.0244\n",
"Epoch 9/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0225 - val_loss: 0.0239\n",
"Epoch 10/10\n",
"2697/2697 [==============================] - 3s 1ms/step - loss: 0.0224 - val_loss: 0.0246\n"
] }, { "data": { "text/plain": [ "" ] }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Train for 10 epochs, reporting the loss on the validation split after each one.\n",
"nn_model.fit(x=X_train, y=y_train,\n",
"             validation_data=(X_val, y_val),\n",
"             epochs=10)"
] }, { "cell_type": "code", "execution_count": 117, "id": "exciting-cisco", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.010190283642224518" ] }, "execution_count": 117, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Flatten the network output to 1-D before scoring it on the test rows.\n",
"y_pred = nn_model.predict(X_test).ravel()\n",
"r2_oos(y_pred=y_pred, y_true=y_test)"
] }, { "cell_type": "markdown", "id": "dominant-cache", "metadata": {}, "source": [ "### 
GridSeachCV Neural Nets" ] }, { "cell_type": "code", "execution_count": 119, "id": "tribal-reunion", "metadata": {}, "outputs": [], "source": [
"def build_model(learning_rate=0.003, n_features=9):\n",
"    \"\"\"Build and compile the feed-forward regressor tuned by the grid search.\n",
"\n",
"    The network is input -> Dense(32, relu) -> Dense(16, relu) -> Dense(1),\n",
"    compiled with MSE loss and an SGD optimizer.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    learning_rate : float, default 0.003\n",
"        Learning rate handed to the SGD optimizer.\n",
"    n_features : int, default 9\n",
"        Width of the input layer. The default keeps the previously\n",
"        hard-coded value (X_train has 9 columns).\n",
"\n",
"    Returns\n",
"    -------\n",
"    keras.models.Sequential\n",
"        The compiled, untrained model.\n",
"    \"\"\"\n",
"    model = keras.models.Sequential()\n",
"    model.add(keras.layers.InputLayer(input_shape=[n_features]))\n",
"    model.add(keras.layers.Dense(32, activation='relu'))\n",
"    model.add(keras.layers.Dense(16, activation='relu'))\n",
"    model.add(keras.layers.Dense(1))\n",
"    # Pass `learning_rate`: `lr` is only a deprecated alias in tf.keras optimizers.\n",
"    optimizer = keras.optimizers.SGD(learning_rate=learning_rate)\n",
"    model.compile(loss=\"mse\", optimizer=optimizer)\n",
"    return model"
] }, { "cell_type": "code", "execution_count": 120, "id": "italic-blind", "metadata": {}, "outputs": [], "source": [
"# Wrap the builder so scikit-learn can treat the network as a regressor.\n",
"keras_reg = keras.wrappers.scikit_learn.KerasRegressor(build_model)"
] }, { "cell_type": "code", "execution_count": 121, "id": "uniform-estonia", "metadata": {}, "outputs": [], "source": [
"hyperparams_grid = {\n",
"    'learning_rate':[0.003,0.001]\n",
"}"
] }, { "cell_type": "code", "execution_count": 122, "id": "mysterious-carter", "metadata": {}, "outputs": [], "source": [
"# cv=[cv_idx[0]]: one predefined (train, validation) split instead of k-fold.\n",
"nn_search_cv = GridSearchCV(keras_reg, hyperparams_grid, cv=[cv_idx[0]])"
] }, { "cell_type": "code", "execution_count": null, "id": "differential-sector", "metadata": {}, "outputs": [], "source": [
"# Extra keyword arguments are forwarded to the Keras model's fit().\n",
"nn_search_cv.fit(X_fulltrain, y_fulltrain, epochs=10,\n",
"                 validation_data=(X_val,y_val))"
] }, { "cell_type": "code", "execution_count": null, "id": "novel-edinburgh", "metadata": {}, "outputs": [], "source": [
"# Needs a completed search: best_estimator_ does not exist if fit() was interrupted.\n",
"y_pred = nn_search_cv.predict(X_test).reshape(-1)\n",
"r2_oos(y_true=y_test, y_pred=y_pred)"
] }, { "cell_type": "markdown", "id": "brave-directory", "metadata": {}, "source": [ "# Transformation pipeline example" ] }, { "cell_type": "code", "execution_count": 125, "id": "conditional-frank", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
secIDret_dateexretymsizerevmombetabmilliqilliq_12mvolivolyear
0000001.XSHE2008-02-0.0074502008-010.974984-0.4154050.242923-0.133641-0.608953-0.964450-0.980250-0.4799210.0000002008
1000001.XSHE2008-03-0.1520682008-020.971391-0.6579970.4629390.431730-0.560468-0.962289-0.9804940.465540-0.4993502008
2000001.XSHE2008-040.0474932008-030.9727450.4847500.5872810.161583-0.162881-0.977936-0.9792340.2809860.5509412008
3000001.XSHE2008-05-0.1511642008-040.9687090.5436770.601043-0.211213-0.294654-0.973924-0.979140-0.580183-0.8305082008
4000001.XSHE2008-06-0.2369612008-050.967617-0.8005180.6696890.246114-0.137306-0.950777-0.979275-0.155440-0.6049222008
.............................................
419502900957.XSHG2020-100.0035732020-09-0.9906150.225488-0.6646090.0000000.0000000.9866630.987157-0.0456900.0940972020
419503900957.XSHG2020-110.0112022020-10-0.9893150.238951-0.6556580.0002430.0000000.9868870.987373-0.600291-0.6682862020
419504900957.XSHG2020-12-0.0383732020-11-0.989375-0.100700-0.4904610.0000000.0000000.9893750.987443-0.325767-0.3204542020
419505900957.XSHG2021-010.3309732020-12-0.9899140.217099-0.4913540.0000000.0000000.9894330.987512-0.973103-0.8477432021
419506900957.XSHG2021-020.1012842021-01-0.9881400.974858-0.4454460.0000000.0000000.9805500.9871920.6930740.7722962021
\n", "

412946 rows × 14 columns

\n", "
" ], "text/plain": [ " secID ret_date exret ym size rev mom \\\n", "0 000001.XSHE 2008-02 -0.007450 2008-01 0.974984 -0.415405 0.242923 \n", "1 000001.XSHE 2008-03 -0.152068 2008-02 0.971391 -0.657997 0.462939 \n", "2 000001.XSHE 2008-04 0.047493 2008-03 0.972745 0.484750 0.587281 \n", "3 000001.XSHE 2008-05 -0.151164 2008-04 0.968709 0.543677 0.601043 \n", "4 000001.XSHE 2008-06 -0.236961 2008-05 0.967617 -0.800518 0.669689 \n", "... ... ... ... ... ... ... ... \n", "419502 900957.XSHG 2020-10 0.003573 2020-09 -0.990615 0.225488 -0.664609 \n", "419503 900957.XSHG 2020-11 0.011202 2020-10 -0.989315 0.238951 -0.655658 \n", "419504 900957.XSHG 2020-12 -0.038373 2020-11 -0.989375 -0.100700 -0.490461 \n", "419505 900957.XSHG 2021-01 0.330973 2020-12 -0.989914 0.217099 -0.491354 \n", "419506 900957.XSHG 2021-02 0.101284 2021-01 -0.988140 0.974858 -0.445446 \n", "\n", " beta bm illiq illiq_12m vol ivol year \n", "0 -0.133641 -0.608953 -0.964450 -0.980250 -0.479921 0.000000 2008 \n", "1 0.431730 -0.560468 -0.962289 -0.980494 0.465540 -0.499350 2008 \n", "2 0.161583 -0.162881 -0.977936 -0.979234 0.280986 0.550941 2008 \n", "3 -0.211213 -0.294654 -0.973924 -0.979140 -0.580183 -0.830508 2008 \n", "4 0.246114 -0.137306 -0.950777 -0.979275 -0.155440 -0.604922 2008 \n", "... ... ... ... ... ... ... ... 
\n", "419502 0.000000 0.000000 0.986663 0.987157 -0.045690 0.094097 2020 \n", "419503 0.000243 0.000000 0.986887 0.987373 -0.600291 -0.668286 2020 \n", "419504 0.000000 0.000000 0.989375 0.987443 -0.325767 -0.320454 2020 \n", "419505 0.000000 0.000000 0.989433 0.987512 -0.973103 -0.847743 2021 \n", "419506 0.000000 0.000000 0.980550 0.987192 0.693074 0.772296 2021 \n", "\n", "[412946 rows x 14 columns]" ] }, "execution_count": 125, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_rank" ] }, { "cell_type": "code", "execution_count": 126, "id": "chronic-california", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5" ] }, "execution_count": 126, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_fulltrain.columns.tolist().index('illiq')" ] }, { "cell_type": "code", "execution_count": 127, "id": "fitted-fairy", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 127, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_fulltrain.columns.tolist().index('illiq_12m')" ] }, { "cell_type": "code", "execution_count": 128, "id": "unauthorized-person", "metadata": {}, "outputs": [], "source": [ "# Column positions of 'illiq' and 'illiq_12m' in X_fulltrain (values from the two lookups above)\n", "illiq_idx = 5\n", "illiq_12m_idx = 6" ] }, { "cell_type": "code", "execution_count": 129, "id": "indian-spokesman", "metadata": {}, "outputs": [], "source": [ "class FeatureAdder(BaseEstimator, TransformerMixin):\n", "    \"\"\"Optionally append the average of the `illiq` and `illiq_12m` columns.\n", "\n", "    Parameters\n", "    ----------\n", "    add_avg_illiq : bool, default=True\n", "        When True, `transform` appends one extra column holding the mean of\n", "        columns `illiq_idx` and `illiq_12m_idx`. When False, the input is\n", "        passed through without an extra column.\n", "    \"\"\"\n", "\n", "    def __init__(self, add_avg_illiq=True):\n", "        self.add_avg_illiq = add_avg_illiq\n", "\n", "    def fit(self, X, y=None):\n", "        # Stateless transformer: nothing is estimated from the data.\n", "        return self\n", "\n", "    def transform(self, X, y=None):\n", "        # Positional slicing below needs an ndarray, so accept DataFrames too.\n", "        X = np.asarray(X)\n", "        # Honour the hyperparameter; previously it was stored but never read.\n", "        if not self.add_avg_illiq:\n", "            return X\n", "        avg_illiq = (X[:, illiq_idx] + X[:, illiq_12m_idx]) / 2\n", "        return np.c_[X, avg_illiq]\n", "\n", "feature_adder = FeatureAdder()" ] }, { "cell_type": "code", "execution_count": 130, "id": "certified-fiction", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(203048, 9)" ] }, "execution_count": 130, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_fulltrain.values.shape" ] }, { 
"cell_type": "code", "execution_count": 131, "id": "continuing-desperate", "metadata": {}, "outputs": [], "source": [ "X_fulltrain_new = feature_adder.transform(X_fulltrain.values)" ] }, { "cell_type": "code", "execution_count": 132, "id": "completed-creator", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.97498354, -0.41540487, 0.24292298, ..., -0.479921 ,\n", " 0. , -0.97235023],\n", " [ 0.97139142, -0.6579974 , 0.46293888, ..., 0.46553966,\n", " -0.4993498 , -0.97139142],\n", " [ 0.97274497, 0.48475016, 0.58728099, ..., 0.28098637,\n", " 0.55094095, -0.97858533],\n", " ...,\n", " [-0.98812049, 0.42299533, -0.4900297 , ..., -0.46457361,\n", " -0.51887993, 0.98515062],\n", " [-0.98844884, 0.3019802 , 0.42079208, ..., -0.20379538,\n", " 0.19059406, 0.98638614],\n", " [-0.98793727, 0.68234821, 0.61640531, ..., -0.20305589,\n", " 0.19018898, 0.977885 ]])" ] }, "execution_count": 132, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_fulltrain_new" ] }, { "cell_type": "code", "execution_count": 133, "id": "wooden-driver", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(203048, 10)" ] }, "execution_count": 133, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_fulltrain_new.shape" ] }, { "cell_type": "code", "execution_count": 134, "id": "sporting-celebrity", "metadata": {}, "outputs": [], "source": [ "# FeatureAdder can be chained with other transformers inside a Pipeline\n", "pipeline = Pipeline([\n", " ('feature_adder', FeatureAdder()),\n", " ('std_scaler', StandardScaler())\n", "])" ] }, { "cell_type": "code", "execution_count": 135, "id": "magnetic-strip", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 1.68951131e+00, -7.19839143e-01, 4.22125848e-01, ...,\n", " -8.31636558e-01, -1.94790357e-18, -1.76971464e+00],\n", " [ 1.68328665e+00, -1.14021841e+00, 8.04446213e-01, ...,\n", " 8.06715692e-01, -8.68622094e-01, -1.76796957e+00],\n", " [ 1.68563218e+00, 8.40004933e-01, 1.02051477e+00, ...,\n", " 4.86910428e-01, 
9.58365208e-01, -1.78106277e+00],\n", " ...,\n", " [-1.71227582e+00, 7.32992362e-01, -8.51521769e-01, ...,\n", " -8.05041659e-01, -9.02594872e-01, 1.79301184e+00],\n", " [-1.71284481e+00, 5.23289884e-01, 7.31207958e-01, ...,\n", " -3.53149139e-01, 3.31539553e-01, 1.79526054e+00],\n", " [-1.71195833e+00, 1.18241500e+00, 1.07112393e+00, ...,\n", " -3.51867707e-01, 3.30834919e-01, 1.77978814e+00]])" ] }, "execution_count": 135, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.fit_transform(X_fulltrain.values)" ] }, { "cell_type": "code", "execution_count": null, "id": "promising-newport", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" }, "toc-autonumbering": true }, "nbformat": 4, "nbformat_minor": 5 }