TensorFlow创建自定义类继承tf.layers.Layer创建新的layer层,自定义类继承keras.Model创建自定义model

时间:3年前   阅读:9548

tf.layers.Layer类是tf.layers里所有层都继承的基类,实现了通用的基础功能。用户只需要实例化它,就可以直接调用得到的实例。

Layer的子类一般这样子实现:

__init__():先初始化父类。然后在成员变量中保存配置。

build():一般用于初始化层内的参数和变量。在调用call()方法前,类会自动调用该方法。在该方法末尾需要设置self.built = True,保证build()方法只被调用一次。

call():用于定义层对输入张量的实际操作。

下面是我们自定义一个全连接层的例子。(self.add_weight的参数name一定要定义,否则model.save_weights("./weight/07/07.weight")会报错,我错误找了好久)

class MyDense(keras.layers.Layer):
    def __init__(self, outdim):
        super().__init__()
        self.outdim = outdim
        
    def build(self, input_shape):
        self.indim = int(input_shape[-1])
            
        self.kernel = self.add_weight(
            name="w", 
            shape=[self.indim, self.outdim], 
            dtype=tf.float32, 
            initializer=tf.random_normal_initializer()
        )
        self.built = True 
        
    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        return inputs@self.kernel
class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.f1 = MyDense(256)
        self.f2 = MyDense(256)
        self.f3 = MyDense(128)
        self.f4 = MyDense(32)
        self.f5 = MyDense(10)
    def call(self, inputs):
        inputs = tf.reshape(inputs, [-1, 32*32*3])
        out = self.f1(inputs)
        out = tf.nn.relu(out)
        out = self.f2(out)
        out = tf.nn.relu(out)
        out = self.f3(out)
        out = tf.nn.relu(out)
        out = self.f4(out)
        out = tf.nn.relu(out)
        out = self.f5(out)
        
        self.out = out
        return out
model = MyModel()
model.build([None, 32*32*3])
model.summary()
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)
model.fit(db, epochs=5, validation_data=db_test)

版权声明:本文为期权记的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

原文链接:https://www.qiquanji.com/post/9673.html

微信扫码关注

更新实时通知

上一篇:(tensorflow)tf.keras.callbacks.ModelCheckpoint在训练期间保存模型

下一篇:deepfacelab时用FFmpeg 将大量图片合成为视频 video作为data_src

网友评论

请先 登录 再评论,若不是会员请先 注册