You may use Dataset.take()
and Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
For more generality, I gave an example using a 70/15/15 train/val/test split but if you don't need a test or a val set, just ignore the last 2 lines.
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
You may also want to look into Dataset.shard()
:
Creates a Dataset that includes only 1/num_shards of this dataset.
与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…