
Hugging Face Fine-tune for Multilingual NER (Japanese Example)

(Please download source code from here.)

Hugging Face multilingual fine-tuning (series of posts)

  1. Named Entity Recognition (NER)
  2. Text Summarization
  3. Question Answering

Many AI companies now provide pre-trained large language models (LLMs), along with methods for tuning them on your own data. Among such tools and frameworks, Hugging Face is widely used because of its flexibility, and over 20,000 transformer-based models are available on its hub.

In this series of posts, I'll show you brief examples of how to fine-tune multilingual transformer models in Hugging Face.
I won't focus on the generic auto-regressive models (i.e., decoder-only models) that are widely used in today's LLMs, but rather on models for specific purposes (i.e., encoder-only models and encoder-decoder models).
The first example is NER (named entity recognition) for Japanese text.

Note : For fine-tuning a generic auto-regressive (decoder-only) LLM, see this example.

In particular, some languages, such as Chinese, Korean, and Japanese, don't have explicit whitespace tokenization, so there are several things to consider before fine-tuning models. In this post I'll focus on Japanese, but you can fine-tune in the same way for these other languages as well.

NER (named entity recognition) is a common NLP task that identifies entities in text, such as person names, organization names, or location names.
NER doesn't classify each word in isolation; with a transformer architecture, it classifies tokens using the context of the surrounding text. For instance, the word "mean" has several senses, such as "average" or "poor". A well-trained NER classifier will recognize which sense applies from the context.

Note : If you’re new to transformers in deep learning, see these tutorials.

Set up Environment

In this example, I have used a GPU instance with an Ubuntu Server 20.04 LTS image in Microsoft Azure.

Set up GPU drivers (NVIDIA CUDA) and install Hugging Face standard libraries (transformers, datasets) as follows.

# compilers and development settings
sudo apt-get update
sudo apt install -y gcc
sudo apt-get install -y make

# install CUDA 11.4.4 (because I use old generation K80 GPU)
wget https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run
sudo sh cuda_11.4.4_470.82.01_linux.run
echo -e "export LD_LIBRARY_PATH=/usr/local/cuda-11.4/lib64" >> ~/.bashrc
source ~/.bashrc

# install and upgrade pip
sudo apt-get install -y python3-pip
sudo -H pip3 install --upgrade pip

# install pytorch with GPU accelerated
# (see https://pytorch.org/get-started/locally/ )
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu114

# install sentencepiece for multi-lingual modeling
pip3 install omegaconf hydra-core fairseq sentencepiece

# install Hugging Face libraries
pip3 install transformers datasets

# install additional packages
pip3 install numpy seqeval pandas matplotlib scikit-learn

# install jupyter if you run code in notebook
pip3 install jupyter

Note : In general, the checkpoint files saved during training become quite large, because the model itself is large. Please expand the disk in your Azure VM if needed. (See here to expand a disk in Azure.)

Note : In this series of posts, I'll show you examples with a small footprint, which you can run on a single GPU or on consumer GPUs (such as NVIDIA RTX).
In practical large-model training (including fine-tuning), however, the model often won't fit into a single GPU (or its memory), and a mixture of optimization techniques – such as model parallelism, quantization, or compression – is often used.
In such cases, you can run training on multiple GPUs on a single host in Azure (communicating through NVLink), or on multiple GPUs across multiple hosts in Azure (connected with an NVIDIA InfiniBand network).
For these implementations (model parallelism, quantization, or compression), you can use the DeepSpeed library, which simplifies the complicated configuration of a variety of mixed optimizations, such as 3D parallelism.
DeepSpeed is also well integrated with Hugging Face.
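
For reference, here is a minimal sketch of how the Trainer integration is typically enabled: you pass a DeepSpeed config (a dict or a path to a JSON file) to TrainingArguments. The config values and the variable name deepspeed_args below are illustrative placeholders, not tuned recommendations, and the deepspeed package must be installed separately.

from transformers import TrainingArguments

# illustrative DeepSpeed config (ZeRO stage 2 with fp16);
# "auto" lets the Trainer fill in values from its own arguments
ds_config = {
  "fp16": {"enabled": True},
  "zero_optimization": {"stage": 2},
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
}

deepspeed_args = TrainingArguments(
  output_dir="out",
  per_device_train_batch_size=12,
  fp16=True,
  deepspeed=ds_config,  # or deepspeed="ds_config.json"
)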

Preprocess Data (Dataset)

I have used the Japanese NER dataset by Stockmark Inc., because it is labeled with high accuracy. (You can download the dataset (ner.json) from here.)
This dataset (JSON) consists of text from Japanese Wikipedia articles, and every entity is annotated with its character span. (See below.)
The dataset has 5,343 records.

original dataset (json)

[
  {
    "curid": "3572156",
    "text": "SPRiNGSと最も仲の良いライバルグループ。",
    "entities": [
      {
        "name": "SPRiNGS",
        "span": [
          0,
          7
        ],
        "type": "その他の組織名"
      }
    ]
  },
  {
    "curid": "2415078",
    "text": "レッドフォックス株式会社は、東京都千代田区に本社を置くITサービス企業である。",
    "entities": [
      {
        "name": "レッドフォックス株式会社",
        "span": [
          0,
          12
        ],
        "type": "法人名"
      },
      {
        "name": "東京都千代田区",
        "span": [
          14,
          21
        ],
        "type": "地名"
      }
    ]
  },
  .....

]

Note : When you perform cross-lingual or multilingual model training, you can also use the XTREME corpora, such as PAN-X. These datasets include labeled sets in multiple languages – such as PAN-X.ja, PAN-X.ko, PAN-X.zh, etc. – and have a large number of records. (For instance, PAN-X.ja has 40,000 records, of which 20,000 are for training, 10,000 for validation, and 10,000 for test.)
Later in this post I'll describe cross-lingual model training.

First I convert the original dataset and generate a Hugging Face dataset in which every character is annotated with a label.
Each character gets one of 9 label indices – O (other/nothing), PER (person), ORG (general corporation), ORG-P (political organization), ORG-O (other organization), LOC (location), INS (institution, facility), PRD (product), or EVT (event).

import json
from datasets import Dataset, Features, Sequence, Value, ClassLabel

# load dataset
with open("ner.json") as f:
  json_all = json.load(f)

# create character-based annotated dataset
tokens_list = []
ner_tags_list = []
for json_dat in json_all:
  tokens = list(json_dat["text"])
  ner_tags = ["O"] * len(tokens)
  for ent in json_dat["entities"]:
    for i in range(ent["span"][0], ent["span"][1]):
      # See https://github.com/stockmarkteam/ner-wikipedia-dataset
      if ent["type"] == "人名":  # person
        ner_tags[i] = "PER"
      elif ent["type"] == "法人名":  # organization (corporation general)
        ner_tags[i] = "ORG"
      elif ent["type"] == "政治的組織名":  # organization (political)
        ner_tags[i] = "ORG-P"
      elif ent["type"] == "その他の組織名":  # organization (others)
        ner_tags[i] = "ORG-O"
      elif ent["type"] == "地名":  # location
        ner_tags[i] = "LOC"
      elif ent["type"] == "施設名":  # institution (facility)
        ner_tags[i] = "INS"
      elif ent["type"] == "製品名":  # product
        ner_tags[i] = "PRD"
      elif ent["type"] == "イベント名":  # event
        ner_tags[i] = "EVT"
  tokens_list.append(tokens)
  ner_tags_list.append(ner_tags)

features = Features({
  "tokens": Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
  "ner_tags": Sequence(feature=ClassLabel(names=["O", "PER", "ORG", "ORG-P", "ORG-O", "LOC", "INS", "PRD", "EVT"], id=None), length=-1, id=None)
})
ds = Dataset.from_dict(
  {"tokens": tokens_list, "ner_tags": ner_tags_list},
  features=features
)

# generate converter for index(int)-to-tag(string) and tag(string)-to-index(int)
tags = ds.features["ner_tags"].feature
index2tag = {idx: tag for idx, tag in enumerate(tags.names)}
tag2index = {tag: idx for idx, tag in enumerate(tags.names)}

# separate dataset into train dataset and validation dataset
ds = ds.train_test_split(test_size=0.1, shuffle=True)

For instance, the Japanese sentence “松友美佐紀は、日本のバドミントン選手。” (= “Misaki Matsutomo is a Japanese badminton player.”) is labeled as follows.
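
You can check how such a record is labeled by printing each character next to its tag name. A minimal sketch (the index 0 is arbitrary; pick any record from the split):

# print the characters of one record together with their tag names
example = ds["train"][0]
for ch, tag_id in zip(example["tokens"], example["ner_tags"]):
  print(ch, index2tag[tag_id])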

In Hugging Face transformers, a SentencePiece tokenizer (with the Unigram algorithm) can be used for subword segmentation over multilingual corpora. This is useful for handling multiple languages, especially those without explicit whitespace tokenization (such as Japanese, Korean, and Chinese), because it is agnostic to language-specific properties such as accents, punctuation, etc.
In this example, I'll use the XLM-RoBERTa model (an improved BERT-based architecture for cross-lingual language modeling, XLM-R for short). Its tokenizer keeps known common words – such as “わたし” (= I) and “です” (= is) – as single tokens, and splits unknown proper nouns – such as “マーベリック” (= Maverick) – into smaller pieces, often single characters. (See below.)

Remember that all text in our prepared dataset is currently tokenized character by character.
We should therefore convert the dataset to fit this tokenizer, as follows.

Tokens in original dataset (character-based)

Tokens  わ  た  し  は  マ   ー   ベ   リ   ッ   ク   2  8  で  す  。
Tags    O   O   O   O   PER  PER  PER  PER  PER  PER  O  O  O  O  O

Tokens we need in this model (subword-based)

Tokens  ▁わたし  は  マー  ベ   リ   ック  28  です  。
Tags    O        O   PER   PER  PER  PER   O   O     O
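
A quick way to see this splitting behavior for yourself is to call the tokenizer directly. A minimal sketch (the tokenizer is loaded here just for this check, and the sample sentence is only illustrative):

from transformers import AutoTokenizer

# tokenize a sample sentence with the XLM-RoBERTa SentencePiece tokenizer
xlmr_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
print(xlmr_tokenizer.tokenize("わたしはマーベリック28です。"))
# known common words tend to remain single subwords, while unknown proper
# nouns are split into smaller pieces (often single characters)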

The following code converts the original dataset (character-based annotated data) into a model-compliant dataset.
The labels are also realigned to the new tokens.

from transformers import AutoTokenizer

# load tokenizer of pre-trained XLM-RoBERTa model
xlmr_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# define function for dataset conversion
def tokenize_and_align_labels(data):
  text = ["".join(t) for t in data["tokens"]]
  # tokenized_inputs = xlmr_tokenizer(text, truncation=True, max_length=512)
  tokenized_inputs = xlmr_tokenizer(text)

  #
  # map label to the new token
  #
  # [example]
  #   org token (data)      : ["松", "崎", "は", "日", "本", "に", "い", "る"]
  #   new token (tokenized_inputs): ["_", "松", "崎", "は", "日本", "に", "いる"]
  labels = []
  for row_idx, label_old in enumerate(data["ner_tags"]):
    # label is initialized as [[], [], [], [], [], [], []]
    label_new = [[] for t in tokenized_inputs.tokens(batch_index=row_idx)]
    # label becomes [[1], [1], [1], [0], [5, 5], [0], [0, 0]]
    for char_idx in range(len(data["tokens"][row_idx])):
      token_idx = tokenized_inputs.char_to_token(row_idx, char_idx)
      if token_idx is not None:
        label_new[token_idx].append(data["ner_tags"][row_idx][char_idx])
        if (tokenized_inputs.tokens(batch_index=row_idx)[token_idx] == "▁") and (data["ner_tags"][row_idx][char_idx] != 0):
          label_new[token_idx+1].append(data["ner_tags"][row_idx][char_idx])
    # label becomes [1, 1, 1, 0, 5, 0, 0]
    label_new = list(map(lambda i : max(i, default=0), label_new))
    # append result
    labels.append(label_new)

  tokenized_inputs["labels"] = labels
  return tokenized_inputs

# run conversion
tokenized_ds = ds.map(
  tokenize_and_align_labels,
  remove_columns=["tokens", "ner_tags"],
  batched=True,
  batch_size=128)
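
To sanity-check the conversion, you can decode one converted record and line up its subword tokens with the realigned labels. A minimal sketch (the index 0 is arbitrary):

# compare the subword tokens of one converted record with their labels
sample = tokenized_ds["train"][0]
print(xlmr_tokenizer.convert_ids_to_tokens(sample["input_ids"]))
print([index2tag[l] for l in sample["labels"]])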

Note : The SentencePiece tokenizer adds "<s>" and "</s>" instead of "[CLS]" and "[SEP]".

Note that several pre-processing tasks, such as normalization, are also performed by the Hugging Face tokenizer.
For instance, when you pass full-width bracket characters (specific to Japanese text) into the tokenizer, they are normalized into regular bracket characters.

Note : In general, the quality of data matters more for successful LLM fine-tuning than the quantity of data.
See the LIMA paper ("Less Is More for Alignment").

Fine-tuning for NER

Now that we have completed the dataset preparation, let's start training.

In Hugging Face, there are the following two options for running training (fine-tuning).

  • Use the transformers Trainer class, with which you can run training without writing the training loop manually.
  • Build your own training loop. (See here for an example.)

In this example, I’ll use Trainer class for fine-tuning the pre-trained model.

Note : Even when using the Trainer wrapper in Hugging Face, you can seamlessly integrate DeepSpeed optimization.

First I prepare the pre-trained XLM-RoBERTa (XLM-R) model as follows.

import torch
from transformers import AutoConfig
from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

xlmr_config = AutoConfig.from_pretrained(
  "xlm-roberta-base",
  num_labels=tags.num_classes,
  id2label=index2tag,
  label2id=tag2index
)
model = (RobertaForTokenClassification
         .from_pretrained("xlm-roberta-base", config=xlmr_config)
         .to(device))

The XLM-RoBERTa model is a variant of the BERT model class.
There exist many transformer architectures (such as BERT-based, T5-based, or GPT-based), but a BERT-based encoder-only model is well suited to classification tasks, so I use this model for our token classification (NER).

As you can see above, here I have used the built-in RobertaForTokenClassification model.
This model consists of a pre-trained RobertaModel with a classification head – i.e., a dropout and a linear layer on top of RoBERTa's hidden-state outputs (for all tokens) – as follows.
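
You can check this structure by printing the head modules of the loaded model. The shapes in the comments below are what I'd expect for xlm-roberta-base with our 9 labels:

# inspect the token-classification head (dropout + linear) on top of the encoder
print(model.dropout)     # e.g. Dropout(p=0.1, inplace=False)
print(model.classifier)  # e.g. Linear(in_features=768, out_features=9, bias=True)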

If you prefer, you can also bring your own model class (see below) and fully customize the model for Hugging Face fine-tuning.

custom model class (example)

import torch
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel

class MyCustomizationSampleModel(RobertaPreTrainedModel):
  _keys_to_ignore_on_load_unexpected = [r"pooler"]  # because we don't add pooling
  _keys_to_ignore_on_load_missing = [r"position_ids"]

  def __init__(self, config):
    super().__init__(config)
    
    #
    # The attribute names ("roberta", etc.) are important!
    # If you change these names, the corresponding weights and biases might be ignored when saving a checkpoint.
    #

    self.num_labels = config.num_labels
    # hf roberta model
    self.roberta = RobertaModel(config, add_pooling_layer=False)
    # linear for classification
    self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
    self.linear = torch.nn.Linear(config.hidden_size, self.num_labels)
    # initialize weights
    ### self.init_weights()
    self.post_init()

  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
    # build model
    roberta_output = self.roberta(
      input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids, # this will always be None
      **kwargs
    )
    x = self.dropout(roberta_output[0])
    logits = self.linear(x)
    # calculate loss if labels are provided
    loss = None
    if labels is not None:
      cross_entropy = torch.nn.CrossEntropyLoss()
      loss = cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
    # return result
    return TokenClassifierOutput(
      loss=loss,
      logits=logits,
      hidden_states=roberta_output.hidden_states,
      attentions=roberta_output.attentions
    )

Note : When you use your own model with the Hugging Face Trainer, please follow the Hugging Face conventions, such as the attribute names, because the built-in save_pretrained() (which is used when saving checkpoints) will ignore the weights and biases of unknown attributes.
See here for the source code.

Note : As I'll describe later, the padded label id becomes -100 during data pre-processing.
In the above code, padded tokens are then ignored in the loss, because the PyTorch cross-entropy loss class (torch.nn.CrossEntropyLoss) has an ignore_index attribute whose default value is -100.
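
Here is a tiny, self-contained illustration of that behavior (the logits and labels are random dummy values):

import torch

loss_fn = torch.nn.CrossEntropyLoss()         # ignore_index defaults to -100
logits = torch.randn(5, 9)                    # 5 token positions, 9 classes
labels = torch.tensor([0, 1, -100, 5, -100])  # two padded positions
print(loss_fn(logits, labels))                # averaged over the 3 non-padded positions only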

To use the Trainer class, I first prepare the trainer's parameters, a TrainingArguments object.
In this object, I enable FP16 (half-precision floating point) to save memory, as follows. This reduces numerical precision, but it is sufficient for training our model.

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir = "xlm-roberta-ner-ja",
  log_level = "error",
  num_train_epochs = 1,
  per_device_train_batch_size = 12,
  per_device_eval_batch_size = 12,
  evaluation_strategy = "epoch",
  fp16 = True,
  logging_steps = len(tokenized_ds["train"]),
  push_to_hub = False
)

Next I prepare a Hugging Face data collator.
A data collator batches and pre-processes the input data, and here I use the token-classification data collator, the DataCollatorForTokenClassification class.

This data collator (DataCollatorForTokenClassification) pads each batch to the longest sequence in the batch, and the padded positions are filled with label id -100 by default.
These positions are then ignored in the loss and in the validation metrics in our code.

from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(
  xlmr_tokenizer,
  return_tensors="pt")
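
To see the -100 padding concretely, you can call the collator on a small ragged batch. A minimal sketch (the token ids below are dummy values, not meaningful vocabulary entries):

# the collator pads input_ids / attention_mask and fills extra label slots with -100
features = [
  {"input_ids": [0, 1000, 2000, 2], "labels": [0, 1, 1, 0]},
  {"input_ids": [0, 3000, 2], "labels": [0, 5, 0]},
]
batch = data_collator(features)
print(batch["labels"])  # the shorter sequence gets -100 in its padded position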

Finally I prepare the metrics function.
This function is invoked at the end of each epoch and computes the validation metric (here, the F1 score).

As I mentioned above, the padded label id is -100, so positions with label id -100 are ignored in this code.

import numpy as np
from seqeval.metrics import f1_score

def metrics_func(eval_arg):
  preds = np.argmax(eval_arg.predictions, axis=2)
  batch_size, seq_len = preds.shape
  y_true, y_pred = [], []
  for b in range(batch_size):
    true_label, pred_label = [], []
    for s in range(seq_len):
      if eval_arg.label_ids[b, s] != -100:  # -100 must be ignored
        true_label.append(index2tag[eval_arg.label_ids[b][s]])
        pred_label.append(index2tag[preds[b][s]])
    y_true.append(true_label)
    y_pred.append(pred_label)
  return {"f1": f1_score(y_true, y_pred)}

Now let's put it all together.
In the following code, I build a Trainer with the model above, the training arguments (TrainingArguments), the data collator (DataCollatorForTokenClassification), and the metrics function.

from transformers import Trainer

trainer = Trainer(
  model = model,
  args = training_args,
  data_collator = data_collator,
  compute_metrics = metrics_func,
  train_dataset = tokenized_ds["train"],
  eval_dataset = tokenized_ds["test"],
  tokenizer = xlmr_tokenizer
)

Now let's run training for NER classification.
As I mentioned above, you don't need to write the training loop manually. (The Trainer class handles all of that.)

trainer.train()
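
When training is done, you can also trigger an explicit evaluation pass on the validation split. The metric key in the comment below is what I'd expect given our metrics_func:

# evaluate on the validation (here "test") split
eval_metrics = trainer.evaluate()
print(eval_metrics)  # should include "eval_f1" (from metrics_func) and "eval_loss"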

After training has completed, you can save the model with the Hugging Face libraries as follows.

import os
import torch
from transformers import AutoConfig

# save fine-tuned model in local
os.makedirs("./trained_ner_classifier_jp", exist_ok=True)
if hasattr(trainer.model, "module"):
  trainer.model.module.save_pretrained("./trained_ner_classifier_jp")
else:
  trainer.model.save_pretrained("./trained_ner_classifier_jp")

# load from the saved checkpoint
xlmr_config = AutoConfig.from_pretrained(
  "xlm-roberta-base",
  num_labels=tags.num_classes,
  id2label=index2tag,
  label2id=tag2index
)
model = (RobertaForTokenClassification
         .from_pretrained("./trained_ner_classifier_jp", config=xlmr_config)
         .to(device))

Predict and Evaluate

The following code runs prediction on new text with our fine-tuned model.

As you can see below, the first occurrence of “中国” (which means “China”) is recognized as a location name, but the second “中国” in “中国共産党” (“Communist Party of China”) is recognized as part of an organization name. Also, “共産党” (which means “Communist Party”) is recognized as an organization, but “一党” (“one party”) is not.

from datasets import Dataset
import torch
from torch.utils.data import DataLoader
import pandas as pd

# create dataset for prediction
sample_encoding = xlmr_tokenizer([
  "鈴木は4月の陽気の良い日に、鈴をつけて熊本県の阿蘇山に登った",
  "中国では、中国共産党による一党統治が続く",
], truncation=True, max_length=512)
sample_dataset = Dataset.from_dict(sample_encoding)
sample_dataset = sample_dataset.with_format("torch")

# predict
sample_dataloader = DataLoader(sample_dataset, batch_size=1)
tokens = []
labels = []
for batch in sample_dataloader:
  # predict
  with torch.no_grad():
    output = model(batch["input_ids"].to(device), batch["attention_mask"].to(device))
  predicted_label_id = torch.argmax(output.logits, axis=-1).cpu().numpy()
  # create output
  tokens.append(xlmr_tokenizer.convert_ids_to_tokens(batch["input_ids"][0]))
  labels.append([index2tag[i] for i in predicted_label_id[0]])

# show the first result
pd.DataFrame([tokens[0], labels[0]], index=["Tokens", "Tags"])

pd.DataFrame([tokens[1], labels[1]], index=["Tokens", "Tags"])

Note : You can also use the predict() method of the Trainer class to run prediction.

Here I show you the confusion matrix of the fine-tuned model.
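
One way to compute such a token-level confusion matrix is to run trainer.predict() on the validation split and mask out the padded (-100) positions. A sketch (not necessarily the exact code used for the figure):

import numpy as np
from sklearn.metrics import confusion_matrix

# predict on the validation split and flatten, ignoring padded (-100) positions
pred_out = trainer.predict(tokenized_ds["test"])
pred_ids = np.argmax(pred_out.predictions, axis=2)
mask = pred_out.label_ids != -100
cm = confusion_matrix(
  pred_out.label_ids[mask],
  pred_ids[mask],
  labels=list(range(tags.num_classes)),
)
print(cm)  # rows/columns follow the order of tags.names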

I have also published the fine-tuned model to the Hugging Face hub, and you can quickly try your own text in the hosted widget. (Please give it a try.)

tsmatz/xlm-roberta-ner-japanese (Hugging Face hub)
https://huggingface.co/tsmatz/xlm-roberta-ner-japanese
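
If you want to try the published checkpoint from code instead of the widget, the pipeline API is the quickest way. A short sketch (the sample sentence is taken from the dataset example above):

from transformers import pipeline

# load the published checkpoint and aggregate subword predictions into entities
ner_pipeline = pipeline(
  "token-classification",
  model="tsmatz/xlm-roberta-ner-japanese",
  aggregation_strategy="simple",
)
print(ner_pipeline("レッドフォックス株式会社は、東京都千代田区に本社を置くITサービス企業である。"))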

Cross-Lingual Models

As you saw above, I have used a model class called XLM-RoBERTa (XLM stands for Cross-Lingual Language Model).
What does "cross-lingual" (or multilingual) in this model's name imply?

In fact, this pre-trained model already has the ability to relate “中国” (the Japanese word) to “China” (the English word), or “ニューヨーク” (the Japanese word) to “New York” (the English word).
When you evaluate our XLM-RoBERTa model (fine-tuned on Japanese) on English text, you will find that it works reasonably well for English too (see below), although it has never seen an English labeled corpus for NER.
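
You can try this yourself by feeding an English sentence to the model fine-tuned above. A minimal sketch reusing the objects from the prediction section (the sample sentence is just illustrative):

import torch
import pandas as pd

# run the Japanese-fine-tuned model on an English sentence
eng_encoding = xlmr_tokenizer(
  "John Smith lives in New York and works for Microsoft.",
  return_tensors="pt")
with torch.no_grad():
  output = model(eng_encoding["input_ids"].to(device), eng_encoding["attention_mask"].to(device))
pred_ids = torch.argmax(output.logits, axis=-1).cpu().numpy()[0]
eng_tokens = xlmr_tokenizer.convert_ids_to_tokens(eng_encoding["input_ids"][0])
pd.DataFrame([eng_tokens, [index2tag[i] for i in pred_ids]], index=["Tokens", "Tags"])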

You will see that it works even better for Korean, since Korean has grammatical forms similar to Japanese.
Likewise, a model fine-tuned on German will also perform better on other languages in the Germanic family, such as English.

In this example, I have fine-tuned a model with a monolingual corpus (Japanese), but you can also improve a model by training with multiple languages – such as Japanese and Korean, or German and English – using the XTREME corpora, such as PAN-X or WikiANN.

 

Note : See this notebook for running the above example on multiple distributed nodes with Azure Machine Learning. (In that example, I have used the XTREME dataset, which has 40,000 records, and applied a distributed data parallel (DDP) architecture.)

 

Reference :

Fine-tuned Model (Hugging Face Hub)
https://huggingface.co/tsmatz/xlm-roberta-ner-japanese

Source Code / Notebook (GitHub)
https://github.com/tsmatz/huggingface-finetune-japanese

Document : Hugging Face and DeepSpeed integration
https://huggingface.co/docs/transformers/main_classes/deepspeed

 
