keras训练数据的方式有哪些,是怎样的
Admin 2022-07-21 群英技术资讯 528 次浏览
model.train_on_batch(batchX, batchY)
train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型,大部分情况下我们不需要这么精细,99%情况下使用fit_generator训练方式即可,下面会介绍。
model.fit(x_train, y_train, batch_size=32, epochs=10)
fit的方式是一次把训练数据全部加载到内存中,然后每次批处理batch_size个数据来更新模型参数,epochs就不用多介绍了。这种训练方式只适合训练数据量比较小的情况下使用。
利用Python的生成器,逐个生成数据的batch并进行训练,不占用大量内存,同时生成器与模型将并行执行以提高效率。例如,该函数允许我们在CPU上进行实时的数据提升,同时在GPU上进行模型训练
接口如下:
fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)
generator
:生成器函数
steps_per_epoch
:整数,当生成器返回steps_per_epoch次数据时,计一个epoch结束,执行下一个epoch。也就是一个epoch下执行多少次batch_size。
epochs
:整数,控制数据迭代的轮数,到了就结束训练。
callbacks=None, list,list中的元素为keras.callbacks.Callback对象,在训练过程中会调用list中的回调函数
def generate_arrays_from_file(path): while True: with open(path) as f: for line in f: # create numpy arrays of input data # and labels, from each line in the file x1, x2, y = process_line(line) yield ({'input_1': x1, 'input_2': x2}, {'output': y}) model.fit_generator(generate_arrays_from_file('./my_folder'), steps_per_epoch=10000, epochs=10)
补充:keras.fit_generator()属性及取值
fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
通过Python generator产生一批批的数据用于训练模型。generator可以和模型并行运行,例如,可以使用CPU生成批数据同时在GPU上训练模型。
generator
:一个generator或Sequence实例,为了避免在使用multiprocessing时直接复制数据。
steps_per_epoch
:从generator产生的步骤的总数(样本批次总数)。通常情况下,应该等于数据集的样本数量除以批量的大小。
epochs
:整数,在数据集上迭代的总数。
works
:在使用基于进程的线程时,最多需要启动的进程数量。
use_multiprocessing
:布尔值。当为True时,使用基于基于过程的线程。
datagen = ImageDataGenator(...) model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size), epochs=epochs, validation_data=(x_test, y_test), workers=4)
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
DQN算法是DeepMind团队提出的一种深度强化学习算法,在许多电动游戏中达到人类玩家甚至超越人类玩家的水准,本文就带领大家了解一下这个算法,快来跟随小编学习一下
这篇文章主要给大家介绍了关于Pytorch实现简单自定义网络层的相关资料,文中通过实例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
我们在学习Python语言时会遇到各种各样的字符串方法处理,下面这篇文章主要给大家介绍了关于Python基础篇之字符串的最全常用操作方法的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考下
Pandas 是 Python 语言的一个扩展程序库,能用来数据分析。而且pandas还提供了大量能帮助我们快速便捷地处理数据的函数和方法。我们有时候需要对excel表的列做操作,使用pandas就是能实现我们想要的功能。下面我们就一起来看看使用pandas如何调整列的顺序。
大家好,本篇文章主要讲的是Python处理excel与txt文件详解,感兴趣的同学赶快来看一看吧,对你有帮助的话记得收藏一下,方便下次浏览
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008