huggingface/transformersのBERTで固有表現抽出

914
huggingface/transformersのBERTで固有表現抽出

はじめに

前回の記事では、huggingface/transformersBertForSequenceClassificationを使って、文の分類を行いました。

他のBERTモデルを見てみると、その中にBertForTokenClassification
というモデルがあり、これはトークン(単語)を分類するためのモデルらしいです。
サンプルプログラムを見ると、固有表現抽出タスクで学習済みのモデルも提供されているようなので、試してみました。

[参考] BertForTokenClassification

サンプルプログラム

これが、huggigfaceで公開されている固有表現抽出のサンプルプログラムです。
BertForSequenceClassificationの時にも感じましたが、非常にシンプルです。


from transformers import BertTokenizer, BertForTokenClassification
import torch

tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")

inputs = tokenizer(
    "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)

with torch.no_grad():
    logits = model(**inputs).logits

predicted_token_class_ids = logits.argmax(-1)

# Note that tokens are classified rather then input words which means that
# there might be more predicted token classes than words.
# Multiple token classes might account for the same word
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
print(predicted_tokens_classes)

このプログラムでは、

HuggingFace is a company based in Paris and New York

という文を固有表現抽出して、各トークンに固有表現の9種類のラベル('O'、'B-ORG'、'I-ORG'、'B-LOC'、'I-LOC'、'B-PER'、'I-PER'、'B-MISC'、'I-MISC')のいずれかを付与しています。
プログラム最後のprint文の出力は以下のようになります(詳細は後述)。


['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']

プログラムの内容

ここからは、このプログラムの内容に関する説明をしていきます。

トークナイザー、モデルの読み込み

まず初めに、学習済みのトークナイザー、モデル(dbmdz/bert-large-cased-finetuned-conll03-english)を読み込みます。


from transformers import BertTokenizer, BertForTokenClassification
import torch

tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")

トークナイズ

次に、固有表現抽出をしたい文をトークンに分割します。

  • add_special_tokens: 文頭に[SEP]、文末に[CLS]という特殊なトークンを付与するかどうか
  • return_tensors: 出力形式("pt"の場合はPyTorchのTensor)

inputs = tokenizer(
    "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)

この出力結果のinputsは、以下のようになります。

{'input_ids': tensor([[20164, 10932,  2271,  7954,  1110,   170,  1419,  1359,  1107,  2123,
          1105,  1203,  1365]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

input_idsが分割されてできたトークンのIDです。このままだとわかりずらいので、

tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])

でトークンのラベルに変換すると以下のようになります。基本は単語=トークンですが、'HuggingFace'という単語だけは4つのサブトークンに分割されていることがわかります。

['Hu', '##gging', '##F', '##ace', 'is', 'a', 'company', 'based', 'in', 'Paris', 'and', 'New', 'York']

推論

トークナイズされた入力inputsを使って、各トークンの分類を行います。
勾配の更新(学習)が行われないように、with torch.no_grad()で囲む必要があります。

with torch.no_grad():
    logits = model(**inputs).logits

分類結果であるlogitsは、以下のようになっています。

tensor([[[ 8.8331e+00, -2.6947e+00, -1.4335e+00, -2.0760e+00, -1.9929e+00,
          -1.5367e+00, -1.3373e-01, -2.4921e+00,  1.1420e+00],
         [ 1.8645e+00, -1.9180e+00,  6.6708e-01, -2.8445e+00, -1.8780e-01,
          -1.2113e+00,  4.3467e+00, -2.1351e+00, -6.0643e-01],
         [ 1.6579e+00, -2.9294e+00, -9.1255e-01, -2.8349e+00, -7.5518e-01,
          -9.7249e-01,  7.4999e+00, -2.6189e+00, -2.2595e-01],
         [ 1.7139e+00, -2.8050e+00, -6.7317e-01, -3.0918e+00, -2.3132e-02,
          -1.3786e+00,  7.2262e+00, -3.0614e+00, -5.5431e-01],
         [ 9.9437e+00, -2.3024e+00, -7.7941e-01, -2.6059e+00, -1.2730e+00,
          -1.4824e+00,  1.6615e+00, -2.5419e+00, -8.0900e-01],
         [ 1.0505e+01, -2.4093e+00, -8.5326e-01, -2.7351e+00, -1.2257e+00,
          -1.5433e+00,  1.2509e+00, -2.4653e+00, -8.2943e-01],
         [ 9.9881e+00, -2.7991e+00, -6.8294e-01, -2.8792e+00, -1.1324e+00,
          -1.8547e+00,  2.4081e+00, -2.6043e+00, -5.8434e-01],
         [ 1.0769e+01, -2.2194e+00, -6.5605e-01, -2.7579e+00, -1.3175e+00,
          -1.7480e+00,  8.2793e-01, -2.3371e+00, -5.2890e-01],
         [ 1.0750e+01, -2.5078e+00, -5.1507e-01, -2.5354e+00, -1.4459e+00,
          -1.5453e+00,  5.2378e-01, -2.1810e+00, -3.8061e-01],
         [-6.0714e-02, -2.5348e+00, -1.0712e+00, -3.0533e+00, -4.0452e-01,
          -2.4849e+00,  9.8940e-01, -2.4244e+00,  7.3672e+00],
         [ 9.4568e+00, -2.5034e+00, -8.3296e-01, -2.6379e+00, -1.3871e+00,
          -1.1879e+00,  8.6052e-01, -2.2480e+00,  4.0924e-01],
         [ 6.8505e-01, -2.2629e+00, -1.1475e+00, -2.5290e+00, -1.0518e+00,
          -2.2181e+00,  5.0383e-01, -1.7780e+00,  7.7758e+00],
         [-9.9010e-03, -2.2016e+00, -1.4453e+00, -2.7184e+00, -8.2106e-01,
          -2.2796e+00,  7.2668e-01, -1.6339e+00,  7.7748e+00]]])

この結果は、それぞれのトークンが9個の分類それぞれに割り当てられる確率を表しています。
わかりやすく表で表すとこうなります。

トークン 分類0 分類1 分類2 分類3 分類4 分類5 分類6 分類7 分類8
Hu 8.8331e+00 -2.6947e+00 -1.4335e+00 -2.0760e+00 -1.9929e+00 -1.5367e+00 -1.3373e-01 -2.4921e+00 1.1420e+00
##gging 1.8645e+00 -1.9180e+00 6.6708e-01 -2.8445e+00 -1.8780e-01 -1.2113e+00 4.3467e+00 -2.1351e+00 -6.0643e-01
##F 1.6579e+00 -2.9294e+00 -9.1255e-01 -2.8349e+00 -7.5518e-01 -9.7249e-01 7.4999e+00 -2.6189e+00 -2.2595e-01
##ace 1.7139e+00 -2.8050e+00 -6.7317e-01 -3.0918e+00 -2.3132e-02 -1.3786e+00 7.2262e+00 -3.0614e+00 -5.5431e-01
is 9.9437e+00 -2.3024e+00 -7.7941e-01 -2.6059e+00 -1.2730e+00 -1.4824e+00 1.6615e+00 -2.5419e+00 -8.0900e-01
a 1.0505e+01 -2.4093e+00 -8.5326e-01 -2.7351e+00 -1.2257e+00 -1.5433e+00 1.2509e+00 -2.4653e+00 -8.2943e-01
company 9.9881e+00 -2.7991e+00 -6.8294e-01 -2.8792e+00 -1.1324e+00 -1.8547e+00 2.4081e+00 -2.6043e+00 -5.8434e-01
based 1.0769e+01 -2.2194e+00 -6.5605e-01 -2.7579e+00 -1.3175e+00 -1.7480e+00 8.2793e-01 -2.3371e+00 -5.2890e-01
in 1.0750e+01 -2.5078e+00 -5.1507e-01 -2.5354e+00 -1.4459e+00 -1.5453e+00 5.2378e-01 -2.1810e+00 -3.8061e-01
Paris -6.0714e-02 -2.5348e+00 -1.0712e+00 -3.0533e+00 -4.0452e-01 -2.4849e+00 9.8940e-01 -2.4244e+00 7.3672e+00
and 9.4568e+00 -2.5034e+00 -8.3296e-01 -2.6379e+00 -1.3871e+00 -1.1879e+00 8.6052e-01 -2.2480e+00 4.0924e-01
New 6.8505e-01 -2.2629e+00 -1.1475e+00 -2.5290e+00 -1.0518e+00 -2.2181e+00 5.0383e-01 -1.7780e+00 7.7758e+00
York -9.9010e-03 -2.2016e+00 -1.4453e+00 -2.7184e+00 -8.2106e-01 -2.2796e+00 7.2668e-01 -1.6339e+00 7.7748e+00

この分類結果logitsから、各トークンに対して最大の確率となる分類を抽出します。

predicted_token_class_ids = logits.argmax(-1)

結果は以下のようになります。

tensor([[0, 6, 6, 6, 0, 0, 0, 0, 0, 8, 0, 8, 8]])

分類がIDになっていてわかりづらいので、ラベルに変換します。

predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]

結果はこうなります。

['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']

トークンと分類との対応を表にすると、このようになります。'Hu'が
ここでI-ORGは「組織」、I-LOCは「地名」、Oは固有表現ではないことを表しています。
HuggingFaceの'Hu'以外は、正しく分類されています(後述しますが、正確な表現をすると正しくありません)。

トークン 分類
Hu O
##gging I-ORG
##F I-ORG
##ace I-ORG
is O
a O
company O
based O
in O
Paris I-LOC
and O
New I-LOC
York I-LOC

結果に対する疑問

疑問1:文頭が必ず'O'に分類される

これはすぐにピンときました。tokenizerの引数add_special_tokensをFalseにして、文頭の[CLS]、文末の[SEP]を省きましたが、BERTで学習する際にはこれらをつけているはずなので、省略したことによって結果が正しくなくなったのだと思います。

add_special_tokensをTrueにしてみたところ、文頭のトークンも正しく分類されるようになりました。


// トークン
['[CLS]', 'Hu', '##gging', '##F', '##ace', 'is', 'a', 'company', 'based', 'in', 'Paris', 'and', 'New', 'York', '[SEP]']
// 分類
['O', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC', 'O']

疑問2:固有表現の先頭が「B-XXX」になっていない

組織名であるHuggingFaceは、正しく組織(ORG)に分類されました。
しかし、通常の固有表現抽出の場合、トークンに分割された単語に対する固有表現ラベルの先頭は「I-XXX」ではなく「B-XXX」となるはずです(後続するトークンは「I-XXX」)。
例えばHuggingFaceの場合は、以下のようになるはずです。

トークン 分類
Hu B-ORG
##gging I-ORG
##F I-ORG
##ace I-ORG

学習データ自体が「B-ORG」ではなくて「I-ORG」になっていることが考えられますが、情報が得られなかったため正しいことは不明です。

まとめ

今回の学習済みモデルでは、英語の固有表現抽出しかできないため、次回は日本語での固有表現抽出について説明したいと思います。