Keras扩展嵌入层输入

时间:2019-02-06 10:25:19

标签: python tensorflow keras word-embedding

需要从当前已知的权重开始重新训练具有嵌入的keras顺序模型。

在提供的(文本)训练数据上训练Keras顺序模型。训练数据由(定制的)分词器进行分词。模型中第一层(嵌入层)的输入维度是分词器已知的单词数。

几天后,又获得了新的训练数据。由于其中可能包含新单词,因此需要用这些新数据重新拟合分词器。这意味着嵌入层的输入维度会发生变化,因此先前训练的模型无法再直接使用。

package de.phs.issues.model;

import com.fasterxml.jackson.annotation.JsonIdentityInfo;
import com.fasterxml.jackson.annotation.JsonIdentityReference;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.ObjectIdGenerators;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import lombok.Data;
import lombok.EqualsAndHashCode;

import javax.persistence.*;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * JPA entity for a user, stored in the {@code user} table.
 *
 * <p>A user has a unique, non-null name, a sequence-generated id, and two
 * eagerly fetched collections: the issues it created and the comments it wrote.
 * Both collections are the inverse ({@code mappedBy}) side of their relation and
 * are written to JSON as id references only.
 *
 * <p>{@code issues} and {@code comments} are excluded from the generated
 * {@code equals}/{@code hashCode}.
 */
@Entity
@Data
@EqualsAndHashCode(exclude = {"issues", "comments"})
@Table(name="`user`")
@JsonIgnoreProperties({"hibernateLazyInitializer", "handler"})
@JsonIdentityInfo(generator = ObjectIdGenerators.PropertyGenerator.class, property = "id", scope = User.class)
public class User {

    // Mappers are configured once and reused instead of being rebuilt on every call.
    // Static fields are not part of the persistent or serialized state.
    private static final ObjectMapper JSON_WRITER = new ObjectMapper()
            .enable(SerializationFeature.INDENT_OUTPUT);

    private static final ObjectMapper JSON_READER = new ObjectMapper()
            .configure(DeserializationFeature.USE_LONG_FOR_INTS, true);

    @Column(nullable = false, unique = true)
    private String name;

    @Id
    @GeneratedValue(strategy = GenerationType.SEQUENCE)
    private long id;

    @OneToMany(fetch = FetchType.EAGER, mappedBy = "creator", cascade = CascadeType.PERSIST)
    @JsonIdentityInfo(generator=ObjectIdGenerators.PropertyGenerator.class, property="id")
    @JsonIdentityReference(alwaysAsId=true)
    private Set<Issue> issues = new HashSet<>();

    @OneToMany(fetch = FetchType.EAGER, mappedBy = "user", cascade = CascadeType.PERSIST)
    @JsonIdentityInfo(generator=ObjectIdGenerators.PropertyGenerator.class, property="id")
    @JsonIdentityReference(alwaysAsId=true)
    private Set<Comment> comments = new HashSet<>();

    /**
     * Serializes this user to indented JSON.
     *
     * @return the JSON representation of this user
     * @throws JsonProcessingException if Jackson cannot serialize this object
     */
    public String toJson() throws JsonProcessingException {
        return JSON_WRITER.writeValueAsString(this);
    }

    /**
     * Deserializes a user from JSON, reading integral numbers as {@code long}.
     *
     * @param json the JSON text to parse
     * @return the deserialized user
     * @throws IOException if the JSON cannot be parsed or mapped to a {@code User}
     */
    public static User fromJson(String json) throws IOException {
        return JSON_READER.readValue(json, User.class);
    }

    /**
     * Sets this user as the creator of {@code issue} and adds it to this user's issues.
     *
     * @param issue the issue to attach
     * @return {@code true} if the issue was not already in the set
     */
    public boolean addIssue(Issue issue) {
        issue.setCreator(this);
        return issues.add(issue);
    }

    /**
     * Clears the creator of {@code issue} and removes it from this user's issues.
     *
     * @param issue the issue to detach
     * @return {@code true} if the issue was present in the set
     */
    public boolean removeIssue(Issue issue) {
        issue.setCreator(null);
        return issues.remove(issue);
    }

    /**
     * Sets this user as the author of {@code comment} and adds it to this user's comments.
     *
     * @param comment the comment to attach
     * @return {@code true} if the comment was not already in the set
     */
    public boolean addComment(Comment comment) {
        comment.setUser(this);
        return comments.add(comment);
    }

    /**
     * Clears the author of {@code comment} and removes it from this user's comments.
     *
     * @param comment the comment to detach
     * @return {@code true} if the comment was present in the set
     */
    public boolean removeComment(Comment comment) {
        comment.setUser(null);
        return comments.remove(comment);
    }

    /**
     * Returns a compact description that lists related entities by id only,
     * so printing a user never walks back into the issue/comment object graph.
     *
     * @return a string of the form {@code User{id=.., name=.., issues=[..], comments=[..]}}
     */
    @Override
    public String toString() {
        // A lambda is used because Long::toString would be an ambiguous method reference.
        String issueIds = issues.stream()
                .map(Issue::getId)
                .map(issueId -> Long.toString(issueId))
                .collect(Collectors.joining(","));
        String commentIds = comments.stream()
                .map(Comment::getId)
                .map(commentId -> Long.toString(commentId))
                .collect(Collectors.joining(","));
        return "User{id=" + id
                + ", name=" + name
                + ", issues=[" + issueIds
                + "], comments=[" + commentIds
                + "]}";
    }
}

package de.phs.issues.model;

import com.fasterxml.jackson.annotation.JsonIdentityInfo;
import com.fasterxml.jackson.annotation.JsonIdentityReference;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.ObjectIdGenerators;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import lombok.Data;
import lombok.EqualsAndHashCode;

import javax.persistence.*;
import java.io.IOException;

/**
 * JPA entity for a comment: a text body written by a {@link User} on an issue.
 *
 * <p>The {@code user} and {@code issue} associations are fetched eagerly, are
 * written to JSON as id references only, and are excluded from the generated
 * {@code equals}/{@code hashCode}.
 */
@Data
@EqualsAndHashCode(exclude = {"user", "issue"})
@Entity
@JsonIgnoreProperties({"hibernateLazyInitializer", "handler"})
@JsonIdentityInfo(generator = ObjectIdGenerators.PropertyGenerator.class, property = "id", scope = Comment.class)
public class Comment {

    // Mappers are configured once and reused instead of being rebuilt on every call.
    // Static fields are not part of the persistent or serialized state.
    private static final ObjectMapper JSON_WRITER = new ObjectMapper()
            .enable(SerializationFeature.INDENT_OUTPUT);

    private static final ObjectMapper JSON_READER = new ObjectMapper()
            .configure(DeserializationFeature.USE_LONG_FOR_INTS, true);

    @Id
    @GeneratedValue(strategy = GenerationType.SEQUENCE)
    private long id;
    private String body;

    @ManyToOne(fetch = FetchType.EAGER, cascade = CascadeType.ALL)
    @JsonIdentityInfo(generator=ObjectIdGenerators.PropertyGenerator.class, property="id")
    @JsonIdentityReference(alwaysAsId=true)
    private User user;

    @ManyToOne(fetch = FetchType.EAGER)
    @JsonIdentityInfo(generator=ObjectIdGenerators.PropertyGenerator.class, property="id")
    @JsonIdentityReference(alwaysAsId=true)
    private Issue issue;

    /**
     * Serializes this comment to indented JSON.
     *
     * @return the JSON representation of this comment
     * @throws JsonProcessingException if Jackson cannot serialize this object
     */
    public String toJson() throws JsonProcessingException {
        return JSON_WRITER.writeValueAsString(this);
    }

    /**
     * Deserializes a comment from JSON, reading integral numbers as {@code long}.
     *
     * @param json the JSON text to parse
     * @return the deserialized comment
     * @throws IOException if the JSON cannot be parsed or mapped to a {@code Comment}
     */
    public static Comment fromJson(String json) throws IOException {
        return JSON_READER.readValue(json, Comment.class);
    }

    /**
     * Returns a compact description that refers to the author and the issue by id only.
     *
     * <p>Either association may be unset (for example, the user is set to
     * {@code null} when the comment is detached from its author); an unset
     * association is printed as {@code null} instead of throwing.
     *
     * @return a string of the form {@code Comment{body=.., id=.., user=.., issue=..}}
     */
    @Override
    public String toString() {
        String userId = (user == null) ? "null" : String.valueOf(user.getId());
        String issueId = (issue == null) ? "null" : String.valueOf(issue.getId());
        return "Comment{body=" + body
                + ", id=" + id
                + ", user=" + userId
                + ", issue=" + issueId
                + "}";
    }
}

我想用先前训练好的模型来初始化新的训练。对于分词器中新增的单词,嵌入层应仅使用随机初始化;对于分词器已经知道的单词,则应使用先前训练好的嵌入。

1 个答案:

答案 0 :(得分:0)

您可以使用 weights = model.layers[0].get_weights() 和 model.layers[0].set_weights(weights) 之类的代码,以numpy数组的形式直接访问(获取和设置)层的权重,其中 model.layers[0] 是您的嵌入层。这样,您可以单独保存嵌入,并通过从保存的数据中复制已知单词的嵌入来设置新嵌入层的权重。