Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


in Technique by (71.8m points)

k-fold cross validation: how to filter data based on a randomly generated integer variable in Stata

The following seems obvious, yet it does not behave as I would expect. I want to do k-fold cross validation without using SSC packages, and thought I could just filter my data and run my own regressions on the subsets.

First I generate a variable with a random integer between 1 and 5 (5-fold cross validation), then I loop over each fold number. I want to filter the data by the fold number, but using a boolean filter fails to filter anything. Why?

Bonus: what would be the best way to capture all of the test MSEs and average them? In Python I would just make a list or a numpy array and take the average.

gen randint = floor((6-1)*runiform()+1)   // random integer from 1 to 5

recast int randint

forval b = 1(1)5 {
    // training set
    xtreg c.DepVar        ///
        c.IndVar1         ///
        c.IndVar2         ///
        if randint != `b' ///
        , fe vce(cluster uuid)

    // test set: needs to be evaluated with the model above, not a new model...
    xtreg c.DepVar        ///
        c.IndVar1         ///
        c.IndVar2         ///
        if randint == `b' ///
        , fe vce(cluster uuid)
}

EDIT: The test set needs to be evaluated with the model fit on the training set. I changed the comment in my code to reflect this.

Ultimately, the filtering issue turned out to be that I was putting a scalar inside macro quotes when defining the bounds. I had:

replace randint = floor((`varscalar'-1)*runiform()+1)

instead of just

replace randint = floor((varscalar-1)*runiform()+1)

When and where to use the quotes in Stata is confusing to me. I cannot just use varscalar in a loop; I have to use `=varscalar'. Yet for some reason I can use varscalar - 1 and get the expected result. Interestingly, I cannot use

replace randint = floor((`varscalar')*runiform()+1)

I suppose I should just use

replace randint = floor((`=varscalar')*runiform()+1)

So why is it OK to use the version with the minus one, without the equals sign?
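A minimal sketch of the difference, assuming varscalar is a scalar created with the scalar command and that no local macro with the same name exists. In Stata, `name' always refers to a local macro, so `varscalar' silently expands to nothing; a bare scalar name, by contrast, is looked up when an expression is evaluated; and `=exp' evaluates an expression and pastes the resulting value into the command line, which is what commands that expect literal numbers, such as forvalues, need. With the empty expansion, (`varscalar'-1) becomes (-1), so floor() returns 0 and no observation matches any fold number. The version without the minus one becomes (), which should fail as invalid syntax.

```stata
* Sketch: scalar name vs. local-macro quotes vs. `=exp' (illustration only)
clear all
set seed 1
set obs 5
scalar varscalar = 6

// A bare scalar name is evaluated inside an expression: prints 5
display varscalar - 1

// `varscalar' is a LOCAL macro reference; no such local exists,
// so it expands to nothing and this prints (-1)
display "(`varscalar'-1)"

// `=varscalar - 1' is evaluated when the macro is expanded, so
// forvalues receives the literal range 1/5
forvalues b = 1/`=varscalar - 1' {
    display `b'
}

// The original bug: the expression becomes floor((-1)*runiform()+1),
// which is 0 for every observation, so randint never equals 1, ..., 5
gen randint = floor((`varscalar'-1)*runiform()+1)
tab randint
```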

The answer below is still extremely helpful and I learned much from it.


He who fights with dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back into you…

1 Reply

by (71.8m points)

Two different things are going on here, and they are not necessarily directly related: 1) how to filter data with a randomly generated integer variable, and 2) the k-fold cross-validation procedure.

For the first, I leave an example below that should help you work things out in Stata, using tools that transfer easily to other problems (such as creating and manipulating a matrix to store the metrics). However, I would call neither your code sketch nor my example "k-fold cross-validation", mainly because both fit the model separately on the testing data and on the training data. Strictly speaking, the model should be trained on the training data only, and the performance of those fitted parameters should then be assessed on the testing data.
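A minimal sketch of that train-then-test procedure, assuming the same simulated data as the example further below; the variable name fold, the matrix cv and the scalars mse_total and cv_mse are my own illustrative names. The model is fit once per fold on the training folds only, predict applies those coefficients to every observation, and the test MSE is the mean squared prediction error on the held-out fold. It uses regress; with xtreg, fe as in the question, predict, xb leaves out the estimated unit effects, so how to treat the fixed effects of test observations is a modelling choice you would need to check.

```stata
* Sketch: 5-fold cross-validation with the MSE computed on held-out data
clear all
set seed 4
set obs 100
gen x1 = rnormal()
gen x2 = rnormal()
gen y = 1 + 0.5*x1 + 1.5*x2 + rnormal()
gen byte fold = runiformint(1, 5)   // fold id, 1 to 5

matrix cv = J(5, 1, .)              // one test MSE per fold
matrix colnames cv = test_MSE
scalar mse_total = 0

forvalues b = 1/5 {
    // 1) fit on the training folds only
    quietly regress y x1 x2 if fold != `b'
    // 2) predict every observation with the training coefficients
    capture drop yhat
    capture drop sqerr
    predict double yhat, xb
    // 3) squared prediction errors, kept for the held-out fold only
    gen double sqerr = (y - yhat)^2 if fold == `b'
    quietly summarize sqerr, meanonly
    matrix cv[`b', 1] = r(mean)
    scalar mse_total = mse_total + r(mean)
}

matrix list cv
scalar cv_mse = mse_total / 5       // average test MSE across the folds
display "5-fold CV estimate of test MSE: " cv_mse
```

The running scalar plays the role of the Python list in the question: each fold's test MSE is added as it is computed, and the average is taken once the loop ends.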

For further reference on the procedure, the scikit-learn documentation explains cross-validation very well, with several visualizations included.

That being said, here is something that could be helpful.

clear all
set seed 4
set obs 100
*Simulate model
gen x1 = rnormal()
gen x2 = rnormal()
gen y = 1 + 0.5 * x1 + 1.5 *x2 + rnormal()
gen byte randint = runiformint(1, 5)
tab randint
/*
    randint |      Freq.     Percent        Cum.
------------+-----------------------------------
          1 |         17       17.00       17.00
          2 |         18       18.00       35.00
          3 |         21       21.00       56.00
          4 |         19       19.00       75.00
          5 |         25       25.00      100.00
------------+-----------------------------------
      Total |        100      100.00 
*/
// create a matrix to store results
// (e(rmse) is the root MSE, so those columns are labeled RMSE)
matrix res = J(5,4,.)
matrix colnames res = "R2_fold" "RMSE_fold" "R2_hold" "RMSE_hold"
matrix rownames res = "1" "2" "3" "4" "5"
// show the formatted empty matrix
matrix li res
/*
res[5,4]
     R2_fold  RMSE_fold    R2_hold  RMSE_hold
1          .          .          .          .
2          .          .          .          .
3          .          .          .          .
4          .          .          .          .
5          .          .          .          .
*/

// loop over the five folds
forvalues b = 1/5 {
    // fit the model on fold `b' only
    qui reg y x1 x2 if randint == `b'
    // R squared on fold `b'
    matrix res[`b', 1] = e(r2)
    // root MSE on fold `b'
    matrix res[`b', 2] = e(rmse)

    // refit the model on the remaining folds (randint != `b')
    qui reg y x1 x2 if randint != `b'
    // R squared on the remaining folds
    matrix res[`b', 3] = e(r2)
    // root MSE on the remaining folds
    matrix res[`b', 4] = e(rmse)
}

// show the matrix with the stored metrics
mat li res
/*
res[5,4]
     R2_fold  RMSE_fold    R2_hold  RMSE_hold
1  .50949187  1.2877728  .74155365  1.0070531
2  .89942838  .71776458  .66401888   1.089422
3  .75542004  1.0870525  .68884359  1.0517139
4  .68140328  1.1103964  .71990589  1.0329239
5  .68816084  1.0017175  .71229925  1.0596865
*/

// some matrix algebra to obtain the mean of each metric
mat U = J(rowsof(res),1,1)
mat sum = U'*res
/* divide the column sums by the number of rows to get the column means */
mat mean_res = sum/rowsof(res)
// show the average of each metric across the folds
mat li mean_res
/*
mean_res[1,4]
      R2_fold  RMSE_fold    R2_hold  RMSE_hold
c1  .70678088  1.0409408  .70532425  1.0481599
*/
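As an alternative to the U'*res matrix algebra, Mata's mean() returns the column means of a matrix in one step. A minimal, self-contained sketch on a small hypothetical matrix toy that stands in for res (the names toy and toy_mean are illustrative only):

```stata
* Sketch: column means of a Stata matrix via Mata (toy stands in for res)
clear all
matrix toy = (1, 2 \ 3, 4)
// st_matrix("toy") reads the Stata matrix into Mata, mean() averages each
// column, and the outer st_matrix() writes the 1 x 2 result back to Stata
mata: st_matrix("toy_mean", mean(st_matrix("toy")))
matrix list toy_mean
```

Applied to res, the same one-liner would give the four fold averages without building the vector of ones.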




...