CNN-LSTM-Attention in Keras

Posted at 2017-12-31

I'm writing this down mostly as a memo for myself. Also, I'm not entirely sure this is even correct in the first place...

Input and output

The input is three kinds of text, and the output is binary.
Here, each text is first converted into a vector representation with an Embedding layer; the three are then concatenated and passed through CNN → LSTM → attention.
For the Embedding I use a pretrained fastText model, downloaded from the article below. Many thanks to the author.
fastTextの学習済みモデルを公開しました (the article publishing the pretrained fastText model)
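
The script below reads everything from a single CSV. Its layout is not spelled out in this post, but judging from the code it should look roughly like this (the column names text1, text2, text3, target and the file path come from the script; the rest is a hedged guess):

# Assumed layout of data/crossing/test.csv (inferred from the code below)
#   text1,text2,text3,target
#   "一つ目のテキスト","二つ目のテキスト","三つ目のテキスト",1
#   "...","...","...",0
import pandas as pd

df = pd.read_csv("data/crossing/test.csv", header=0)
print(df.columns.tolist())  # expected: ['text1', 'text2', 'text3', 'target']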

Module versions, etc.

  • keras...2.1.2
  • tensorflow...1.4.1
  • MeCab...with the NEologd dictionary

Code


# coding: utf-8
from keras.layers import Input, Embedding, Dense, Bidirectional, Conv1D, MaxPooling1D, Multiply, Permute, Reshape, Concatenate, Flatten, Lambda, RepeatVector
from keras.layers.recurrent import LSTM
from keras.models import Model
from keras import backend as K
import numpy as np
import pandas as pd
from keras.preprocessing.text import Tokenizer
from keras.preprocessing import sequence
from keras.callbacks import EarlyStopping,LambdaCallback,TensorBoard
from sklearn.metrics import roc_curve, auc, accuracy_score, f1_score, recall_score,confusion_matrix,precision_recall_fscore_support
from sklearn.model_selection import train_test_split
import argparse
from keras.utils import np_utils
import MeCab
from gensim.models import KeyedVectors
import warnings
warnings.filterwarnings("ignore")
import logging
from slack_logger import SlackHandler, SlackFormatter

def build_slack_logger():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    sh = SlackHandler(icon_url='icon_url',
                        username='logger',
                        url='url')
    sh.setLevel(logging.DEBUG)
    f = SlackFormatter()
    sh.setFormatter(f)
    logger.addHandler(sh)
    return logger

def tokenize(text):
    wakati = MeCab.Tagger("-O wakati")
    wakati.parse("")
    return wakati.parse(text)

def load_pretrained_model():
    print("========use pre-trained model=========")
    embeddings_model = KeyedVectors.load_word2vec_format('data/fasttext.model.vec', binary=False)
    return embeddings_model

def get_embed(embeddings_model,tokenizer):
    index2word = {v:k for k, v in tokenizer.word_index.items()}
    word_index = tokenizer.word_index
    num_words = len(word_index)
    embedding_matrix = np.zeros((num_words+1, 300))
    for word, i in word_index.items():
        if word in embeddings_model.index2word:
            embedding_matrix[i] = embeddings_model[word]
    return embedding_matrix

def input_data(pre):
    df = pd.read_csv("data/crossing/test.csv",header=0)
    if pre:
        df = df[:1000]
    tokenized_text1 = df["text1"].apply(lambda x: tokenize(x).replace("\n", "")).tolist()
    tokenized_text2 = df["text2"].apply(lambda x: tokenize(x).replace("\n", "")).tolist()
    tokenized_text3 = df["text3"].apply(lambda x: tokenize(x).replace("\n", "")).tolist()
    y = create_target(df)

    tokenizer_1 = Tokenizer()
    tokenizer_2 = Tokenizer()
    tokenizer_3 = Tokenizer()

    tokenized_texts = [tokenized_text1,tokenized_text2,tokenized_text3]
    tokenizers = [tokenizer_1,tokenizer_2,tokenizer_3]
    X_list = []
    shapes = []

    for text,tokenizer in zip(tokenized_texts,tokenizers):
        tokenizer.fit_on_texts(text)
        seq = tokenizer.texts_to_sequences(text)
        maxlen = max([len(x) for x in seq])
        X = sequence.pad_sequences(seq, maxlen=maxlen)
        shapes.append(X.shape[1])
        X_list.append(X)

    df_train,df_test,X_1_train,X_1_test,X_2_train,X_2_test,X_3_train,X_3_test,y_train,y_test = train_test_split(df,X_list[0],X_list[1],X_list[2],y,test_size=0.3,random_state=42)

    return shapes,tokenizers,df_train,df_test,X_1_train,X_1_test,X_2_train,X_2_test,X_3_train,X_3_test,y_train,y_test

def create_target(df):
    y = df["target"].as_matrix()
    y[y > 0] = 1.0
    y[y <= 0] = 0.0
    y = np_utils.to_categorical(y)
    return y

def build_network_CNN_bi_lstm_attention(shapes,words_list,embed_weights,num_class):
    lstm_dim = 300

    '''Concatenate the embeddings, then CNN -> LSTM -> attention'''
    embeddings = []
    input_list = []

    for X_shape,num_words,embedding_matrix in zip(shapes,words_list,embed_weights):
        inputs = Input(shape=(X_shape,))
        x = Embedding(input_dim=num_words+1,output_dim=lstm_dim,weights=[embedding_matrix],trainable=False)(inputs)
        embeddings.append(x)
        input_list.append(inputs)

    embedding_concat = Concatenate(axis=1)(embeddings)
    CNN_out = Conv1D(filters=32, kernel_size=3, padding='same', activation='relu')(embedding_concat)
    pool_out = MaxPooling1D(pool_size=2)(CNN_out)
    lstm_out = Bidirectional(LSTM(lstm_dim, dropout=0.2, recurrent_dropout=0.2,return_sequences=True))(pool_out)
    attention_mul = attention_3d_block(lstm_out,int(pool_out.shape[1]))
    attention_flatten_mul = Flatten()(attention_mul)
    output = Dense(num_class, activation='sigmoid')(attention_flatten_mul)
    model = Model(inputs=input_list, outputs=output)
    print(model.summary())
    return model

def create_words_embeds(tokenizers):
    words_list = []
    embeddings = []
    embeddings_model = load_pretrained_model()

    for tokenizer in tokenizers:
        word_index = tokenizer.word_index
        num_words = len(word_index)
        embedding_matrix = get_embed(embeddings_model,tokenizer)
        words_list.append(num_words)
        embeddings.append(embedding_matrix)

    return words_list,embeddings

def attention_3d_block(inputs,time_steps):
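    # For each feature dimension of the input, a Dense layer with softmax over the
    # time axis produces attention weights, which are then multiplied element-wise
    # back onto the inputs (taken from philipperemy/keras-attention-mechanism).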
    TIME_STEPS = time_steps
    # if True, the attention vector is shared across the input_dimensions where the attention is applied.
    SINGLE_ATTENTION_VECTOR = False

    input_dim = int(inputs.shape[2])
    a = Permute((2, 1))(inputs)
    a = Reshape((input_dim,TIME_STEPS))(a)
    a = Dense(TIME_STEPS, activation='softmax')(a)
    if SINGLE_ATTENTION_VECTOR:
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(input_dim)(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)

    output_attention_mul = Multiply(name='attention_mul')([inputs, a_probs])
    return output_attention_mul

def test_and_save(logger,network_name,model,df_train,df_test,X_ans_test,X_question_test,X_section_test,y_test):
    y_pred = model.predict([X_ans_test,X_question_test,X_section_test],batch_size=32)
    trues = []
    preds = []
    preds2 = []
    for (t,p) in zip(y_test,y_pred):
        if t[1] == 1:
            trues.append(1.0)
        else:
            trues.append(0.0)
        if p[1] >=0.5:
            preds2.append(1.0)
        else:
            preds2.append(0.0)
        preds.append(p[1])
    accuracy = accuracy_score(trues,preds2)
    recall = recall_score(trues,preds2)
    f1 = f1_score(trues,preds2)
    fpr, tpr, thresholds = roc_curve(trues, preds)
    roc_auc = auc(fpr, tpr)
    cmx_data = confusion_matrix(y_test.argmax(1),y_pred.argmax(1))
    print("==========={}===========".format(network_name))
    print(cmx_data)
    logger.info("accuracy: {}".format(accuracy))
    logger.info("recall: {}".format(recall))
    logger.info("f1: {}".format(f1))
    logger.info("auc: {}".format(roc_auc))

    # Save the model and the predictions
    model.save('output/crossing/{}.h5'.format(network_name))
    df_test['predict']=preds2
    df_test.to_csv('output/crossing/{}_test.csv'.format(network_name), sep="\t", index=None)
    df_train.to_csv('output/crossing/{}_train.csv'.format(network_name), sep="\t", index=None)

def build_callback(logger,network_name):
    es_cb = EarlyStopping(monitor='val_loss', patience=5, verbose=1, mode='auto')
    slack_command = '{}, epoch:{:03d}, loss:{:.7f}, val_loss:{:.7f}'
    slack_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: logger.info(slack_command.format(network_name,epoch, logs['loss'], logs['val_loss'])))
    tb_cb = TensorBoard(log_dir="log/crossing/{}/".format(network_name), histogram_freq=1)

    callbacks = []
    callbacks.append(es_cb)
    callbacks.append(slack_callback)
    callbacks.append(tb_cb)
    return callbacks

def build_network_wrapper(network_name,shapes,words_list,embed_weights,num_class):
    # I also tried other networks, but they are omitted here
    if network_name == "CNN-bi-lstm-attention":
        model = build_network_CNN_bi_lstm_attention(shapes,words_list,embed_weights,num_class)
    return model

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--network_name",
                    help="network in this script")
    parser.add_argument("--pre_trial",
                    help="pre trial or not",action='store_true')
    args = parser.parse_args()

    logger = build_slack_logger()
    logger.info("train and test start. Network is {}".format(args.network_name))
    shapes,tokenizers,df_train,df_test,X_1_train,X_1_test,X_2_train,X_2_test,X_3_train,X_3_test,y_train,y_test = input_data(args.pre_trial)
    words_list,embeddings = create_words_embeds(tokenizers)

    model = build_network_wrapper(args.network_name,shapes,words_list,embeddings,2)
    model.compile(optimizer='nadam', loss='binary_crossentropy', metrics=['accuracy'])

    callback = build_callback(logger,args.network_name)
    model.fit([X_1_train,X_2_train,X_3_train],y_train,
              batch_size=32,
              epochs=30,
              callbacks=callback,
              validation_data=([X_1_test,X_2_test,X_3_test], y_test))

    test_and_save(logger,args.network_name,model,df_train,df_test,X_1_test,X_2_test,X_3_test,y_test)
    logger.info("End")

if __name__ == '__main__':
    main()
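
How the script is launched is not shown in the original post, but given the argparse flags above the invocation would be something like this (the file name train.py is just a placeholder):

python train.py --network_name CNN-bi-lstm-attention --pre_trial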

The attention block is taken from philipperemy/keras-attention-mechanism.
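
As a side note (my own addition, not something shown in that repo or in the script above): because the Permute layer is named attention_vec, the attention weights of a trained model can be pulled out with a sub-model, for example:

# Minimal sketch (my addition): inspect the attention weights of the trained model.
# Assumes `model` and the padded test inputs X_1_test/X_2_test/X_3_test from the script above.
from keras.models import Model

attention_model = Model(inputs=model.input,
                        outputs=model.get_layer('attention_vec').output)
attention_weights = attention_model.predict([X_1_test, X_2_test, X_3_test], batch_size=32)
print(attention_weights.shape)  # e.g. (n_samples, 39, 600)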

model.summary()

The code above gives something like this:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 32)           0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 28)           0
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 18)           0
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 32, 300)      181500      input_1[0][0]
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 28, 300)      18600       input_2[0][0]
__________________________________________________________________________________________________
embedding_3 (Embedding)         (None, 18, 300)      7500        input_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 78, 300)      0           embedding_1[0][0]
                                                                 embedding_2[0][0]
                                                                 embedding_3[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 78, 32)       28832       concatenate_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D)  (None, 39, 32)       0           conv1d_1[0][0]
__________________________________________________________________________________________________
bidirectional_1 (Bidirectional) (None, 39, 600)      799200      max_pooling1d_1[0][0]
__________________________________________________________________________________________________
permute_1 (Permute)             (None, 600, 39)      0           bidirectional_1[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 600, 39)      0           permute_1[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 600, 39)      1560        reshape_1[0][0]
__________________________________________________________________________________________________
attention_vec (Permute)         (None, 39, 600)      0           dense_1[0][0]
__________________________________________________________________________________________________
attention_mul (Multiply)        (None, 39, 600)      0           bidirectional_1[0][0]
                                                                 attention_vec[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 23400)        0           attention_mul[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2)            46802       flatten_1[0][0]
==================================================================================================
Total params: 1,083,994
Trainable params: 876,394
Non-trainable params: 207,600
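
As a rough sanity check (my own back-of-the-envelope arithmetic, not from the original post), the larger parameter counts above can be reproduced by hand:

# conv1d_1: kernel_size * in_channels * filters + bias
3 * 300 * 32 + 32                # = 28,832
# bidirectional_1: 2 directions * 4 gates * (input_dim + units + bias) * units
2 * 4 * (32 + 300 + 1) * 300     # = 799,200
# dense_1: a Dense(39) applied over the 39 time steps
39 * 39 + 39                     # = 1,560
# dense_2: the final classifier on the flattened (39 * 600 = 23,400) vector
23400 * 2 + 2                    # = 46,802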

I wonder how long training will take with this...
