はじめに
前回の記事では、huggingface/transformersのBertForSequenceClassification
を使って、文の分類を行いました。
他のBERTモデルを見てみると、その中にBertForTokenClassification
というモデルがあり、これはトークン(単語)を分類するためのモデルらしいです。
サンプルプログラムを見ると、固有表現抽出タスクで学習済みのモデルも提供されているようなので、試してみました。
[参考] BertForTokenClassification
サンプルプログラム
これが、huggigfaceで公開されている固有表現抽出のサンプルプログラムです。
BertForSequenceClassification
の時にも感じましたが、非常にシンプルです。
from transformers import BertTokenizer, BertForTokenClassification
import torch
tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
inputs = tokenizer(
"HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)
with torch.no_grad():
logits = model(**inputs).logits
predicted_token_class_ids = logits.argmax(-1)
# Note that tokens are classified rather then input words which means that
# there might be more predicted token classes than words.
# Multiple token classes might account for the same word
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
print(predicted_tokens_classes)
このプログラムでは、
HuggingFace is a company based in Paris and New York
という文を固有表現抽出して、各トークンに固有表現の9種類のラベル('O'、'B-ORG'、'I-ORG'、'B-LOC'、'I-LOC'、'B-PER'、'I-PER'、'B-MISC'、'I-MISC')のいずれかを付与しています。
プログラム最後のprint文の出力は以下のようになります(詳細は後述)。
['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']
プログラムの内容
ここからは、このプログラムの内容に関する説明をしていきます。
トークナイザー、モデルの読み込み
まず初めに、学習済みのトークナイザー、モデル(dbmdz/bert-large-cased-finetuned-conll03-english)を読み込みます。
from transformers import BertTokenizer, BertForTokenClassification
import torch
tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
トークナイズ
次に、固有表現抽出をしたい文をトークンに分割します。
add_special_tokens
: 文頭に[SEP]、文末に[CLS]という特殊なトークンを付与するかどうかreturn_tensors
: 出力形式("pt"の場合はPyTorchのTensor)
inputs = tokenizer(
"HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
)
この出力結果のinputs
は、以下のようになります。
{'input_ids': tensor([[20164, 10932, 2271, 7954, 1110, 170, 1419, 1359, 1107, 2123,
1105, 1203, 1365]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids
が分割されてできたトークンのIDです。このままだとわかりずらいので、
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
でトークンのラベルに変換すると以下のようになります。基本は単語=トークンですが、'HuggingFace'という単語だけは4つのサブトークンに分割されていることがわかります。
['Hu', '##gging', '##F', '##ace', 'is', 'a', 'company', 'based', 'in', 'Paris', 'and', 'New', 'York']
推論
トークナイズされた入力inputs
を使って、各トークンの分類を行います。
勾配の更新(学習)が行われないように、with torch.no_grad()
で囲む必要があります。
with torch.no_grad():
logits = model(**inputs).logits
分類結果であるlogits
は、以下のようになっています。
tensor([[[ 8.8331e+00, -2.6947e+00, -1.4335e+00, -2.0760e+00, -1.9929e+00,
-1.5367e+00, -1.3373e-01, -2.4921e+00, 1.1420e+00],
[ 1.8645e+00, -1.9180e+00, 6.6708e-01, -2.8445e+00, -1.8780e-01,
-1.2113e+00, 4.3467e+00, -2.1351e+00, -6.0643e-01],
[ 1.6579e+00, -2.9294e+00, -9.1255e-01, -2.8349e+00, -7.5518e-01,
-9.7249e-01, 7.4999e+00, -2.6189e+00, -2.2595e-01],
[ 1.7139e+00, -2.8050e+00, -6.7317e-01, -3.0918e+00, -2.3132e-02,
-1.3786e+00, 7.2262e+00, -3.0614e+00, -5.5431e-01],
[ 9.9437e+00, -2.3024e+00, -7.7941e-01, -2.6059e+00, -1.2730e+00,
-1.4824e+00, 1.6615e+00, -2.5419e+00, -8.0900e-01],
[ 1.0505e+01, -2.4093e+00, -8.5326e-01, -2.7351e+00, -1.2257e+00,
-1.5433e+00, 1.2509e+00, -2.4653e+00, -8.2943e-01],
[ 9.9881e+00, -2.7991e+00, -6.8294e-01, -2.8792e+00, -1.1324e+00,
-1.8547e+00, 2.4081e+00, -2.6043e+00, -5.8434e-01],
[ 1.0769e+01, -2.2194e+00, -6.5605e-01, -2.7579e+00, -1.3175e+00,
-1.7480e+00, 8.2793e-01, -2.3371e+00, -5.2890e-01],
[ 1.0750e+01, -2.5078e+00, -5.1507e-01, -2.5354e+00, -1.4459e+00,
-1.5453e+00, 5.2378e-01, -2.1810e+00, -3.8061e-01],
[-6.0714e-02, -2.5348e+00, -1.0712e+00, -3.0533e+00, -4.0452e-01,
-2.4849e+00, 9.8940e-01, -2.4244e+00, 7.3672e+00],
[ 9.4568e+00, -2.5034e+00, -8.3296e-01, -2.6379e+00, -1.3871e+00,
-1.1879e+00, 8.6052e-01, -2.2480e+00, 4.0924e-01],
[ 6.8505e-01, -2.2629e+00, -1.1475e+00, -2.5290e+00, -1.0518e+00,
-2.2181e+00, 5.0383e-01, -1.7780e+00, 7.7758e+00],
[-9.9010e-03, -2.2016e+00, -1.4453e+00, -2.7184e+00, -8.2106e-01,
-2.2796e+00, 7.2668e-01, -1.6339e+00, 7.7748e+00]]])
この結果は、それぞれのトークンが9個の分類それぞれに割り当てられる確率を表しています。
わかりやすく表で表すとこうなります。
トークン | 分類0 | 分類1 | 分類2 | 分類3 | 分類4 | 分類5 | 分類6 | 分類7 | 分類8 |
---|---|---|---|---|---|---|---|---|---|
Hu | 8.8331e+00 | -2.6947e+00 | -1.4335e+00 | -2.0760e+00 | -1.9929e+00 | -1.5367e+00 | -1.3373e-01 | -2.4921e+00 | 1.1420e+00 |
##gging | 1.8645e+00 | -1.9180e+00 | 6.6708e-01 | -2.8445e+00 | -1.8780e-01 | -1.2113e+00 | 4.3467e+00 | -2.1351e+00 | -6.0643e-01 |
##F | 1.6579e+00 | -2.9294e+00 | -9.1255e-01 | -2.8349e+00 | -7.5518e-01 | -9.7249e-01 | 7.4999e+00 | -2.6189e+00 | -2.2595e-01 |
##ace | 1.7139e+00 | -2.8050e+00 | -6.7317e-01 | -3.0918e+00 | -2.3132e-02 | -1.3786e+00 | 7.2262e+00 | -3.0614e+00 | -5.5431e-01 |
is | 9.9437e+00 | -2.3024e+00 | -7.7941e-01 | -2.6059e+00 | -1.2730e+00 | -1.4824e+00 | 1.6615e+00 | -2.5419e+00 | -8.0900e-01 |
a | 1.0505e+01 | -2.4093e+00 | -8.5326e-01 | -2.7351e+00 | -1.2257e+00 | -1.5433e+00 | 1.2509e+00 | -2.4653e+00 | -8.2943e-01 |
company | 9.9881e+00 | -2.7991e+00 | -6.8294e-01 | -2.8792e+00 | -1.1324e+00 | -1.8547e+00 | 2.4081e+00 | -2.6043e+00 | -5.8434e-01 |
based | 1.0769e+01 | -2.2194e+00 | -6.5605e-01 | -2.7579e+00 | -1.3175e+00 | -1.7480e+00 | 8.2793e-01 | -2.3371e+00 | -5.2890e-01 |
in | 1.0750e+01 | -2.5078e+00 | -5.1507e-01 | -2.5354e+00 | -1.4459e+00 | -1.5453e+00 | 5.2378e-01 | -2.1810e+00 | -3.8061e-01 |
Paris | -6.0714e-02 | -2.5348e+00 | -1.0712e+00 | -3.0533e+00 | -4.0452e-01 | -2.4849e+00 | 9.8940e-01 | -2.4244e+00 | 7.3672e+00 |
and | 9.4568e+00 | -2.5034e+00 | -8.3296e-01 | -2.6379e+00 | -1.3871e+00 | -1.1879e+00 | 8.6052e-01 | -2.2480e+00 | 4.0924e-01 |
New | 6.8505e-01 | -2.2629e+00 | -1.1475e+00 | -2.5290e+00 | -1.0518e+00 | -2.2181e+00 | 5.0383e-01 | -1.7780e+00 | 7.7758e+00 |
York | -9.9010e-03 | -2.2016e+00 | -1.4453e+00 | -2.7184e+00 | -8.2106e-01 | -2.2796e+00 | 7.2668e-01 | -1.6339e+00 | 7.7748e+00 |
この分類結果logits
から、各トークンに対して最大の確率となる分類を抽出します。
predicted_token_class_ids = logits.argmax(-1)
結果は以下のようになります。
tensor([[0, 6, 6, 6, 0, 0, 0, 0, 0, 8, 0, 8, 8]])
分類がIDになっていてわかりづらいので、ラベルに変換します。
predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
結果はこうなります。
['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']
トークンと分類との対応を表にすると、このようになります。'Hu'が
ここでI-ORGは「組織」、I-LOCは「地名」、Oは固有表現ではないことを表しています。
HuggingFaceの'Hu'以外は、正しく分類されています(後述しますが、正確な表現をすると正しくありません)。
トークン | 分類 |
---|---|
Hu | O |
##gging | I-ORG |
##F | I-ORG |
##ace | I-ORG |
is | O |
a | O |
company | O |
based | O |
in | O |
Paris | I-LOC |
and | O |
New | I-LOC |
York | I-LOC |
結果に対する疑問
疑問1:文頭が必ず'O'に分類される
これはすぐにピンときました。tokenizerの引数add_special_tokens
をFalseにして、文頭の[CLS]、文末の[SEP]を省きましたが、BERTで学習する際にはこれらをつけているはずなので、省略したことによって結果が正しくなくなったのだと思います。
add_special_tokens
をTrueにしてみたところ、文頭のトークンも正しく分類されるようになりました。
// トークン
['[CLS]', 'Hu', '##gging', '##F', '##ace', 'is', 'a', 'company', 'based', 'in', 'Paris', 'and', 'New', 'York', '[SEP]']
// 分類
['O', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC', 'O']
疑問2:固有表現の先頭が「B-XXX」になっていない
組織名であるHuggingFaceは、正しく組織(ORG)に分類されました。
しかし、通常の固有表現抽出の場合、トークンに分割された単語に対する固有表現ラベルの先頭は「I-XXX」ではなく「B-XXX」となるはずです(後続するトークンは「I-XXX」)。
例えばHuggingFaceの場合は、以下のようになるはずです。
トークン | 分類 |
---|---|
Hu | B-ORG |
##gging | I-ORG |
##F | I-ORG |
##ace | I-ORG |
学習データ自体が「B-ORG」ではなくて「I-ORG」になっていることが考えられますが、情報が得られなかったため正しいことは不明です。
まとめ
今回の学習済みモデルでは、英語の固有表現抽出しかできないため、次回は日本語での固有表現抽出について説明したいと思います。