Source code for medner_j.ner

"""日本語医療文書のための病名抽出システム

症例報告などの医療文書から病名を抽出(・正規化)するシステムです.
BERT-CRFを使用しています

Args:
    DEFAULT_CACHE_PATH (env): モデルのダウンロード先指定のための環境変数(default: ~/.cache)
"""

import pathlib
from pathlib import Path
import itertools
import sys
import os
from logging import getLogger, StreamHandler, INFO

from transformers import BertJapaneseTokenizer, BertModel
from allennlp.modules.conditional_random_field import allowed_transitions
import torch
from torch.nn.utils.rnn import pad_sequence

from .model import BertCrf, ListTokenizer
from .util import (
    create_label_vocab_from_file,
    convert_dict_to_xml,
    convert_iob_to_dict,
    download_fileobj,
)
from .normalize import load_dict, DictNormalizer


logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(INFO)
logger.setLevel(INFO)
logger.addHandler(handler)
logger.propagate = False

DEFAULT_CACHE_PATH = os.getenv("DEFAULT_CACHE_PATH", "~/.cache")
DEFAULT_MEDNERJ_PATH = Path(
    os.path.expanduser(os.path.join(DEFAULT_CACHE_PATH, "MedNERJ"))
)
DEFAULT_MODEL_PATH = DEFAULT_MEDNERJ_PATH / "pretrained"

BERT_URL = "http://aoi.naist.jp/MedEXJ2/pretrained"


class Ner(object):
    """The main NER model.

    In general, create instances via from_pretrained().

    Examples:
        Creating an instance::

            from medner_j import Ner
            model = Ner.from_pretrained()

    Attributes:
        label_vocab (dict): {label: label_idx, ...}
        itol (dict): {label_idx: label, ...}
        basic_tokenizer (callable): tokenizer used for word segmentation
        subword_tokenizer (callable): tokenizer used for subword segmentation
        model (nn.Module): BertCrf model
        normalizer (callable): disease-name normalization function
    """

    def __init__(
        self,
        base_model,
        basic_tokenizer,
        subword_tokenizer,
        model_dir=DEFAULT_MODEL_PATH,
        normalizer=None,
    ):
        """Initialize the model.

        Calling this directly is discouraged; use from_pretrained() instead.

        Args:
            base_model (nn.Module): underlying BERT model wrapped by BertCrf
            basic_tokenizer (callable): tokenizer used for word segmentation
            subword_tokenizer (callable): tokenizer used for subword segmentation
            model_dir (pathlib.Path or str): model directory containing labels.txt and final.model
            normalizer (callable): disease-name normalization function
        """
        if not isinstance(model_dir, pathlib.PurePath):
            model_dir = Path(model_dir)

        # self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = "cpu"

        label_vocab = create_label_vocab_from_file(str(model_dir / "labels.txt"))
        self.itol = {i: l for l, i in label_vocab.items()}
        constraints = allowed_transitions("BIO", {i: w for w, i in label_vocab.items()})

        self.model = BertCrf(base_model, len(label_vocab), constraints)
        self.model.load_state_dict(
            torch.load(str(model_dir / "final.model"), map_location=self.device)
        )
        self.model.to(self.device)

        self.basic_tokenizer = basic_tokenizer
        self.subword_tokenizer = subword_tokenizer
        self.normalizer = normalizer

    def basic_tokenize(self, sents):
        return [self.basic_tokenizer.tokenize(s) for s in sents]

    def subword_tokenize(self, tokens):
        # Split each word into subwords and remember how many subwords each word produced
        subwords = [[self.subword_tokenizer.tokenize(s) for s in ss] for ss in tokens]
        lengths = [[len(s) for s in ss] for ss in subwords]
        subwords = [list(itertools.chain.from_iterable(ss)) for ss in subwords]

        return subwords, lengths

    def numericalize(self, tokens):
        return [self.subword_tokenizer.convert_tokens_to_ids(t) for t in tokens]

    def encode(self, sents):
        tokens = self.basic_tokenize(sents)
        subwords, lengths = self.subword_tokenize(tokens)
        subwords = [["[CLS]"] + sub for sub in subwords]
        inputs = self.numericalize(subwords)
        inputs = [torch.tensor(i).to(self.device) for i in inputs]

        return inputs, lengths, tokens

    def _infer_space_tag(self, pre_tag, tag, post_tag):
        # Infer a tag for a word that produced no subwords from its neighbours.
        # This assumes the label ids in labels.txt map O to 0 and each B-* tag
        # (odd id) is immediately followed by its matching I-* tag (id + 1).
        if pre_tag % 2 == 1 and post_tag == pre_tag + 1:
            return pre_tag + 1
        elif pre_tag != 0 and post_tag == pre_tag:
            return pre_tag
        return 0

    def integrate_subwords_tags(self, tags, lengths):
        # Collapse subword-level tags back to word-level tags, taking the tag of
        # the first subword of each word.
        results = []
        for ts, ls in zip(tags, lengths):
            result = []
            idx = 0
            for l in ls:
                if l == 0:
                    pre_tag = 0 if idx == 0 else ts[idx - 1]
                    post_tag = 0 if idx == len(ts) - 1 else ts[idx + 1]
                    tag = self._infer_space_tag(pre_tag, ts[idx], post_tag)
                else:
                    # tag = merge(ts[idx : idx + l])
                    tag = ts[idx : idx + l][0]
                idx += l
                result.append(tag)
            results.append(result)
        return results

    def predict(self, sents, output_format="xml"):
        """Extract disease names.

        Takes a list of sentences and extracts disease names from them.

        Args:
            sents (List): list of input sentences
            output_format (str): output format, "xml" or "dict" (default: "xml")

        Returns:
            List: list of outputs

        Output format (xml)::

            ["<C>脳梗塞</C>を認める."]

        Output format (dict)::

            [{"span": (0, 3), "type": "C", "disease": "脳梗塞", "norm": "脳梗塞"}]
        """
        inputs, lengths, tokens = self.encode(sents)
        results = []

        # Decode in batches of 16 sentences.
        for s_idx in range(0, len(inputs), 16):
            e_idx = min(len(inputs), s_idx + 16)

            batch_inputs = inputs[s_idx:e_idx]
            padded_batch_inputs = pad_sequence(
                batch_inputs, batch_first=True, padding_value=0
            )
            mask = [[int(i > 0) for i in ii] for ii in padded_batch_inputs]
            mask = torch.tensor(mask).to(self.device)

            tags = self.model.decode(padded_batch_inputs, mask)
            tags = [t[0] for t in tags]
            tags = self.integrate_subwords_tags(tags, lengths[s_idx:e_idx])
            results.extend(tags)

        results = [[self.itol[t] for t in tt] for tt in results]
        results = convert_iob_to_dict(tokens, results)

        if self.normalizer is not None:
            self._normalize(results)

        if output_format == "xml":
            results = convert_dict_to_xml(tokens, results)

        return results

    def _normalize(self, dict_list):
        for dd in dict_list:
            for d in dd:
                d["norm"] = self.normalizer(d["disease"])

    @classmethod
    def from_pretrained(cls, model_name="BERT", normalizer="dict"):
        """Load a pretrained model.

        Loads the pretrained model and returns an Ner instance.
        If the pretrained model is not cached yet, it is downloaded to ~/.cache.
        Set the DEFAULT_CACHE_PATH environment variable to change the download
        location.

        Args:
            model_name (str): model name; only "BERT" is implemented in the current version.
            normalizer (str or callable): normalization method, "dict" or "dnorm",
                or a custom normalization callable.

        Returns:
            Ner: an Ner instance
        """
        assert model_name == "BERT", "Only BERT is implemented"

        if model_name == "BERT":
            model_dir = DEFAULT_MODEL_PATH
            src_url = BERT_URL
            base_model = BertModel.from_pretrained("cl-tohoku/bert-base-japanese-char")
            basic_tokenizer = ListTokenizer()
            subword_tokenizer = BertJapaneseTokenizer.from_pretrained(
                "cl-tohoku/bert-base-japanese-char", do_basic_tokenize=False
            )

        if not model_dir.parent.is_dir():
            logger.info("creating %s", str(model_dir.parent))
            model_dir.parent.mkdir()

        if not model_dir.is_dir():
            logger.info("creating %s", str(model_dir))
            model_dir.mkdir()

        if not (model_dir / "final.model").is_file():
            logger.info("not found %s", str(model_dir / "final.model"))
            download_fileobj(src_url + "/final.model", model_dir / "final.model")

        if not (model_dir / "labels.txt").is_file():
            logger.info("not found %s", str(model_dir / "labels.txt"))
            download_fileobj(src_url + "/labels.txt", model_dir / "labels.txt")

        if isinstance(normalizer, str):
            if normalizer == "dnorm":
                logger.info("try %s normalizer", "dnorm")
                try:
                    from dnorm_j import DNorm

                    normalizer = DNorm.from_pretrained().normalize
                    logger.info("use %s normalizer", "dnorm")
                except Exception:
                    logger.warning("You did not install dnorm")
                    logger.warning("use %s normalizer", "Dict")
                    normalizer = DictNormalizer(
                        DEFAULT_MEDNERJ_PATH / "norm_dic.csv"
                    ).normalize
            else:
                logger.info("use %s normalizer", "Dict")
                normalizer = DictNormalizer(
                    DEFAULT_MEDNERJ_PATH / "norm_dic.csv"
                ).normalize
        elif callable(normalizer):
            logger.info("use %s normalizer", "user-defined")
        else:
            raise TypeError("normalizer must be a str or a callable")

        ner = cls(
            base_model,
            basic_tokenizer,
            subword_tokenizer,
            model_dir=model_dir,
            normalizer=normalizer,
        )

        return ner
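

# Usage sketch (illustrative; the lambda below only demonstrates passing a custom
# normalization callable):
#
#     from medner_j import Ner
#
#     model = Ner.from_pretrained()                 # dictionary-based normalizer
#     model.predict(["脳梗塞を認める."])
#     # -> XML-tagged strings of the form ["<C>脳梗塞</C>を認める."]
#
#     model = Ner.from_pretrained(normalizer=lambda disease: disease)
#     model.predict(["脳梗塞を認める."], output_format="dict")
#     # -> per-sentence lists of span dicts of the form
#     #    [{"span": (0, 3), "type": "C", "disease": "脳梗塞", "norm": "脳梗塞"}]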


if __name__ == "__main__":
    fn = Path(sys.argv[1])
    sents = []
    with open(str(fn), "r") as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            sents.append(line)
    print(sents)

    model = Ner.from_pretrained()
    results = model.predict(sents)
    print(results)
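
# Command-line sketch for the entry point above (assumes the package and its
# dependencies are installed, so the relative imports resolve):
#
#     $ python -m medner_j.ner input.txt   # one sentence per line, blank lines skipped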