BERTによる日本語固有表現抽出

1387
BERTによる日本語固有表現抽出

はじめに

Hugging Face TransformersのBERTを使い、固有表現抽出タスクでモデルのファインチューニングを行いました。
事前学習は行わず、東北大学乾研究室が公開している日本語の事前学習モデルを使いました。

モデル作成の手順は大まかには下記の通りとなります。

  1. BertJapaneseTokenizerで文をトークンに分割し、トークンにラベルとして固有表現のタイプを付与(学習用データの作成)
  2. 上記データを入力として、BertForTokenClassificationを使って固有表現抽出を学習(ファインチューニング)

本稿では、この手順について説明します。

なお、手順は一部省略をしているため、詳細はGitHubで公開しているソースコード(jurabiinc/bert-ner-japanese)をご覧ください。

準備

事前に以下のライブラリをpipなどでインストールしておく必要があります。

  • torch
  • transformers
  • unidic_lite
  • fugashi
  • sklearn
  • datasets
  • seqeval

学習用データをBERTの入力形式に変換

学習用データとして、ストックマーク株式会社が公開しているWikipediaを用いた日本語の固有表現抽出データセット(stockmarkteam/ner-wikipedia-dataset)を使用しました。

このデータは下記のようなJSONデータで、文とその文の中に含まれる固有表現を文中の位置、固有表現のタイプ(人名、法人名、政治的組織名、その他組織名、地名、施設名、製品名、イベント名)が表現されています。


{
    "curid": "473536",
    "text": "イギリスはリーマンショック直後の2008年10月にイングランド銀行のバランスシートを一気に3倍近く増やした後、2008年11月から2009年3月にかけて段階的に縮小させていった。",
    "entities": [
        {
            "name": "イギリス",
            "span": [0,4],
            "type": "地名"
        },
        {
            "name": "リーマンショック",
            "span": [5,13],
            "type": "イベント名"
        },
        {
            "name": "イングランド銀行",
            "span": [25,33],
            "type": "政治的組織名"
        }
    ]
}
BERTに入力するデータは、使用する事前学習モデルと同じトークナイザーでトークナイズし、固有表現のタイプをラベルとして付与する必要があります。
上記のJSONの例の場合だと、下記のようにラベル付けをします。
固有表現の先頭には「B-」、先頭以外には「I-」を接頭辞として追加し、固有表現以外は「O」というラベルを付与します。
イギリス リーマン ショック 直後 イングランド 銀行
B-地名 O B-イベント名 I-イベント名 O O B-政治的組織名 I-政治的組織名 O

このラベル付けをする際に、いくつかの問題があったので、その解決策とともに紹介します。

  1. 文中の半角スペースがトークナイズによって除去されてしまう。その結果としてJSON内で"span"で指定されている固有表現の位置とトークンの位置との間にずれが生じる。
    そのため、ラベル付けをする際には半角スペースを考慮する。
  2. トークナイザーによりトークンの先頭に「##」という文字列が追加されることがある(WordPieceの機能)。たとえば「Jurabi」という単語は「Ju」+「##ra」+「##b」+「##i」とトークナイズされる。その結果として"span"で指定されている固有表現の位置とトークンの位置との間にずれが生じる。
    そのため、半角スペースの場合と同様に、ラベル付けをする際には「##」を考慮する。
  3. トークナイザーの辞書に存在しないトークンは「[UNK]」というトークンに変換されてしまう。これはどうしようもないため、「[UNK]」が出現したらその文のラベル付けをやめる(この文は中途半端な学習用データとなるため、除去してしまった方が良いかもしれないが、今回は中途半端な状態のまま使用した)。
  4. トークンの切れ目が固有表現の切れ目と一致しないことがあった。これもどうしようもないため、ラベル付けを諦めた。
  5. いくつかのJSONデータに誤りと思われるものがあったので、これらは手でデータを修正してから使用した。

以下がラベル付けに使用したプログラムです。

import json
from transformers import BertJapaneseTokenizer
from label import label2id

MAX_LENGTH = 128  # 一文あたりの最大トークン数
BERT_MODEL = "cl-tohoku/bert-base-japanese-v2"  # 使用する学習済みモデル
DATASET_PATH = "./dataset/ner.json"
TAGGED_DATASET_PATH = "./dataset/ner_tagged.json"

# 1. データ読み込み

with open(DATASET_PATH) as f:
  ner_data_list = json.load(f)

# 2. 固有表現タグづけ

# 半角スペースによってエンティティの開始位置がずれるので、エンティティの開始・終了位置を調整する(トークナイズで半角スペースが削除されるため)
def adjust_entity_span(text, entities):
  white_spece_posisions = [i for i, c in enumerate(text) if c == " "]
  for entity in entities:
    start_pos = entity["span"][0]
    end_pos = entity["span"][1]
    start_diff = sum(white_spece_pos < start_pos for white_spece_pos in white_spece_posisions)
    end_diff = sum(white_spece_pos < end_pos for white_spece_pos in white_spece_posisions)
    entity["span"] = [start_pos - start_diff, end_pos - end_diff]

for ner_data in ner_data_list:
  adjust_entity_span(ner_data["text"], ner_data["entities"])

sentence_list = [ner_data["text"] for ner_data in ner_data_list]

tokenizer = BertJapaneseTokenizer.from_pretrained(BERT_MODEL)

encoded_sentence_list = [tokenizer(sentence, max_length=MAX_LENGTH, padding="max_length", truncation=True) for sentence in sentence_list]

def calc_token_length(token):
  return len(token) -2 if token.startswith("##") else len(token)

def warn_start_pos(pos, token, entity, curid):
  print("[WARN] トークンの開始位置がエンティティの開始位置を超えました。エンティティの開始=<" + str(entity["span"][0]) + "> トークンの開始=<" + str(pos) + "> curid=<" + curid + "> token=<" + token + "> entity=<" + entity["name"] + ">")

def warn_end_pos(pos, token, entity, curid):
  token_length = calc_token_length(token)
  print("[WARN] トークンの終了位置がエンティティの終了位置を超えました。エンティティの終了=<" + str(entity["span"][1]) + "> トークンの終了=<" + str(pos + token_length) + "> curid=<" + curid + "> token=<" + token + "> entity=<" + entity["name"] + ">")

def search_tokens(tokens, entity, curid):
  ret = {}

  entity_type = entity["type"]
  entity_span = entity["span"]
  entity_start_pos = entity_span[0]
  entity_end_pos = entity_span[1]

  pos = 0
  is_inside_entity = False
  for i, token in enumerate(tokens):
    if token in ["[UNK]", "[SEP]", "[PAD]"]:
      break
    elif token == "[CLS]":
      continue

    token_length = calc_token_length(token)
    if not is_inside_entity: # まだエンティティの中に入っていない場合
      if pos == entity_start_pos: # トークンの開始がエンティティの開始に一致した場合
        ret[i] = "B-" + entity_type
        if pos + token_length == entity_end_pos: # トークンの終了がエンティティの終了に一致した場合
          break
        elif pos + token_length < entity_end_pos:
          is_inside_entity = True
        else: # [WARN]トークンの終了がエンティティの終了を超えた場合
          warn_end_pos(pos, token, entity, curid)
          print(tokens)
      elif pos > entity_start_pos: # [WARN]トークンの開始がエンティティの開始を超えた場合
        warn_start_pos(pos, token, entity, curid)
        print(tokens)
        break
    else: # エンティティの中に入っている場合
      if pos + token_length == entity_end_pos: # トークンの終わりがエンティティの終わりに一致した場合
        ret[i] = "I-" + entity_type
        is_inside_entity = False
        break
      elif pos + token_length < entity_end_pos: # トークンがまだエンティティの終わりに達していない場合
        ret[i] = "I-" + entity_type
      else: # [WARN]トークンがエンティティの終わりを超えた場合
        warn_end_pos(pos, token, entity, curid)
        print(tokens)
        ret.clear()
        is_inside_entity = False
        break
    pos += token_length

  return ret

# トークンにタグ付けをする
tags_list = []
for i, encoded_sentence in enumerate(encoded_sentence_list):
  tokens = tokenizer.convert_ids_to_tokens(encoded_sentence["input_ids"])

  tags = ["O"] * MAX_LENGTH

  ner_data = ner_data_list[i]
  curid = ner_data["curid"]

  entities = ner_data["entities"]

  for entity in entities:
    found_token_pos_tags = search_tokens(tokens, entity, curid)
    for pos, tag in found_token_pos_tags.items():
      tags[pos] = tag

  tags_list.append(tags)

  # 固有表現タグをIDに変換
  encoded_tags_list = [[label2id[tag] for tag in tags] for tags in tags_list] # 学習で利用

# タグづけしたデータの保存
tagged_sentence_list = []
for encoded_sentence, encoded_tags in zip(encoded_sentence_list, encoded_tags_list):
  tagged_sentence = {}
  tagged_sentence['input_ids'] = encoded_sentence['input_ids']
  tagged_sentence['token_type_ids'] = encoded_sentence['token_type_ids']
  tagged_sentence['attention_mask'] = encoded_sentence['attention_mask']
  tagged_sentence['labels'] = encoded_tags
  tagged_sentence_list.append(tagged_sentence)

with open(TAGGED_DATASET_PATH, 'w') as f:
  json.dump(tagged_sentence_list, f)
  

学習

事前学習モデルには、東北大学乾研究室の日本語BERTモデル(cl-tohoku/bert-base-japanese-v2)を使いました。このモデルは、形態素解析器としてMeCabのラッパーであるfugashiを利用しています。

こちらが作成したプログラムになります。


import json
import torch
from transformers import BertJapaneseTokenizer, BertForTokenClassification, BertConfig
from label import label2id, id2label

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LENGTH = 128  # 一文あたりの最大トークン数
BERT_MODEL = "cl-tohoku/bert-base-japanese-v2"  # 使用する学習済みモデル
TAGGED_DATASET_PATH = "./dataset/ner_tagged.json"
MODEL_DIR = "./model"
LOG_DIR = "./logs"

# データの読み込み
with open(TAGGED_DATASET_PATH, 'r') as f:
  encoded_tagged_sentence_list = json.load(f)

# 3. データセットの作成
from sklearn.model_selection import train_test_split

class NERDataset(torch.utils.data.Dataset):
  def __init__(self, encoded_tagged_sentence_list):
    self.encoded_tagged_sentence_list = encoded_tagged_sentence_list

  def __len__(self):
    return len(self.encoded_tagged_sentence_list)

  def __getitem__(self, idx):
    # 辞書の値をTensorに変換
    item = {key: torch.tensor(val).to(device) for key, val in self.encoded_tagged_sentence_list[idx].items()}
    return item

# データを学習用、検証用に分割
train_encoded_tagged_sentence_list, eval_encoded_tagged_sentence_list = train_test_split(encoded_tagged_sentence_list)
# データセットに変換
train_data = NERDataset(train_encoded_tagged_sentence_list)
eval_data = NERDataset(eval_encoded_tagged_sentence_list)

# 4. Trainerの作成
from transformers import Trainer, TrainingArguments

import numpy as np
from datasets import load_metric

# 事前学習モデル
config = BertConfig.from_pretrained(BERT_MODEL, id2label=id2label, label2id=label2id)
model = BertForTokenClassification.from_pretrained(BERT_MODEL, config=config).to(device)
tokenizer = BertJapaneseTokenizer.from_pretrained(BERT_MODEL)

# 学習用パラメーター
training_args = TrainingArguments(
    output_dir = MODEL_DIR,
    num_train_epochs = 2,
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 32,
    warmup_steps = 500,  # 学習係数が0からこのステップ数で上昇
    weight_decay = 0.01,  # 重みの減衰率
    logging_dir = LOG_DIR,
)

metric = load_metric("seqeval")

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # ラベルのIDをラベルに変換
    predictions = [
        [id2label[p] for p in prediction] for prediction in predictions
    ]
    labels = [
        [id2label[l] for l in label] for label in labels
    ]

    results = metric.compute(predictions=predictions, references=labels)
    print(results)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

# Trainerの初期化
trainer = Trainer(
    model = model, # 学習対象のモデル
    args = training_args, # 学習用パラメーター
    compute_metrics = compute_metrics, # 評価用関数
    train_dataset = train_data, # 学習用データ
    eval_dataset = eval_data, # 検証用データ
    tokenizer = tokenizer
)

# 5. 学習
trainer.train()
trainer.evaluate()

trainer.save_model(MODEL_DIR)

下表が学習後のモデルを検証した結果です。
全体のf1値は88.7%で、まずまずの結果だと思います。

雑感ですが、人名はデータ数が多いためか精度は高くなっています。製品名は人間でも判定が難しいイメージがあるために精度が低いのでしょうか。

データ数 precision recall f1
人名 726 0.95 0.968 0.959
法人名 660 0.881 0.915 0.898
政治的組織名 312 0.82 0.891 0.854
その他組織名 268 0.86 0.869 0.865
地名 546 0.881 0.921 0.901
施設名 275 0.786 0.866 0.824
製品名 323 0.79 0.805 0.798
イベント名 241 0.866 0.888 0.877
全体 3351 0.870 0.905 0.887

結果

作成したモデルを使って、固有表現を抽出してみます。
Transformersで提供されているpipelineという機能を使うと、数行のコードで実装をすることができます。


from transformers import pipeline
from transformers import BertJapaneseTokenizer, BertForTokenClassification

MODEL_DIR = "./model"

model = BertForTokenClassification.from_pretrained(MODEL_DIR)
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_DIR)

ner_pipeline = pipeline('ner', model=model, tokenizer=tokenizer)

ner_pipeline("株式会社はJurabi、東京都台東区に本社を置くIT企業である。")

このプログラムを実行すると、抽出された固有表現が以下のような形式で出力されます。
「Jurabi」が法人名、「東京都台東区」が地名として抽出されています。


[
  {'entity': 'B-法人名', 'score': 0.9811771, 'index': 4, 'word': 'Ju', 'start': None, 'end': None
  },
  {'entity': 'I-法人名', 'score': 0.9945182, 'index': 5, 'word': '##ra', 'start': None, 'end': None
  },
  {'entity': 'I-法人名', 'score': 0.9943316, 'index': 6, 'word': '##b', 'start': None, 'end': None
  },
  {'entity': 'I-法人名', 'score': 0.9925209, 'index': 7, 'word': '##i', 'start': None, 'end': None
  },
  {'entity': 'B-地名', 'score': 0.99534696, 'index': 9, 'word': '東京', 'start': None, 'end': None
  },
  {'entity': 'I-地名', 'score': 0.9967154, 'index': 10, 'word': '都', 'start': None, 'end': None
  },
  {'entity': 'I-地名', 'score': 0.996228, 'index': 11, 'word': '台東', 'start': None, 'end': None
  },
  {'entity': 'I-地名', 'score': 0.9965228, 'index': 12, 'word': '区', 'start': None, 'end': None
  }
]

まとめ

今回は初めての試みのため、精度の改良までは行っていませんが、

  • 教師データを増やす
  • 大規模な事前学習モデルを使う

などにより、より精度の高いモデルの作成もしてみたいと思っています。

なお、今回作成したモデルはHugging Faceで公開しています。
jurabi/bert-ner-japanese