Tensorflow实现部分参数梯度更新操作-创新互联

在深度学习中,迁移学习经常被使用,在大数据集上预训练的模型迁移到特定的任务,往往需要保持模型参数不变,而微调与任务相关的模型层。

创新互联建站成立以来不断整合自身及行业资源、不断突破观念以使企业策略得到完善和成熟,建立了一套“以技术为基点,以客户需求中心、市场为导向”的快速反应体系。对公司的主营项目,如中高端企业网站企划 / 设计、行业 / 企业门户设计推广、行业门户平台运营、成都app软件开发成都做手机网站、微信网站制作、软件开发、遂宁服务器托管等实行标准化操作,让客户可以直观的预知到从创新互联建站可以获得的服务效果。

本文主要介绍,使用tensorflow部分更新模型参数的方法。

1. 根据Variable scope剔除需要固定参数的变量

def get_variable_via_scope(scope_lst):
  vars = []
  for sc in scope_lst:
    sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope)
    vars.extend(sc_variable)
  return vars
 
trainable_vars = tf.trainable_variables()
no_change_scope = ['your_unchange_scope_name']
 
no_change_vars = get_variable_via_scope(no_change_scope)
 
for v in no_change_vars:
  trainable_vars.remove(v)
 
grads, _ = tf.gradients(loss, trainable_vars)
 
optimizer = tf.train.AdamOptimizer(lr)
 
train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)

另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。


分享文章:Tensorflow实现部分参数梯度更新操作-创新互联
URL链接:http://azwzsj.com/article/dpgsdg.html