雲のメモ帳

猫とクラウドと旅行が好きなインフラエンジニアです。 日々の調べたことや興味が持ったことをこのブログにアウトプットします。

rinnaの日本語特化GPTモデルを触ってみた

これはなに?

昨日チャットボットのりんなで有名なrinna株式会社からGPTに日本語モデルが公開されましたので、とうとう日本語モデルが来たかとわくわくしながら触ってみました。

prtimes.jp

GPTとは

以下元記事からの引用です。
これまでOpenAI GPT-3でも日本の文章生成をすることができましたが、日本語の学習が足りていないため、英語の文章生成と比べると劣っていました。今回rinna社が日本語のwikipediaやC4やcc-100といった膨大なデータモデルを学習することで、日本語生成に関する精度が高いモデルが公開されました。

■日本語GPTモデルの機能
言語モデルは、会話や文章の「人間が使う言葉」を確率としてモデル化します。優れた言語モデルとは、確率を正確に推定できるものを指します。 例えば、 “確率(吾輩は猫である)>確率(吾輩が猫である)”
と推定できることが、言語モデルの能力です。
GPTは、単語の確率の組み合わせから次の単語の確率を計算する言語モデルです。例えば、“確率(吾輩は猫である) = 確率(吾輩)×確率(は|吾輩)×確率(猫|吾輩,は)×確率(で|吾輩,は,猫)×確率(ある|吾輩,は,猫,で)”のような方法で計算を行います。この性質を用いて、GPTは「吾輩は」という単語を入力したとき、次の単語として確率が高い「猫」を予測することができます。

引用: rinna社、日本語に特化した13億パラメータのGPT言語モデルを公開|rinna株式会社のプレスリリース

実行環境

モデルを実行するために、環境準備していきましょう。以下のブログで紹介しているDocker環境で実行しています。
Dockerで自然言語処理(NLP)をするための環境構築 - 雲のメモ帳

環境構築

Docker上でJupyterLabを起動後に、transformersを使っているので、必要なライブラリをインストールします。

!pip install transformers SentencePiece

注意点としては、SentencePiece をインストールした後にカーネルを再起動しないとうまく動作しないので、インストール後にカーネル再起動してください。

環境構築は以上OKです。

ソースコード

rinna社のほうでモデルを実行するためのソースコードが公開されているのでこちらを利用します。
再実行がしやすいように文書作成の部分を関数化しました。

rinna/japanese-gpt-1b · Hugging Face

出力結果を見るとパラメータの調整が必要な感触でした。GPT2の時のですが、パラメータを解説してくださっている記事があったのでリンク張っておきます。
rinna GPT-2モデルの生成パラメータ

import torch
from transformers import T5Tokenizer, AutoModelForCausalLM

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-1b")

if torch.cuda.is_available():
    model = model.to("cuda")

def generate_text(text):
    token_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            max_length=100,
            min_length=100,
            do_sample=True,
            top_k=500,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bad_word_ids=[[tokenizer.unk_token_id]]
        )

    output = tokenizer.decode(output_ids.tolist()[0])
    return output 

text='今日は日曜日です。明日は、'
generate_text(text)

動作確認

いくつか文章を入れて、後の文章予測の結果を見てみましょう。
色を付けた部分がGPTによって生成された文章です。

単純な質問

Q: キズナアイは

A: 'キズナアイは、キズナアイ」という愛称で親しまれているバーチャルYouTuberです。バーチャルYouTuberとは、インターネット上に存在する仮想のYouTuberで、主に動画投稿サイト「YouTube」で配信されています。キズナアイは、その容姿から「キズナアイ」という愛称で親しまれているバーチャルYouTuberです。バーチャルYouTuberとは、インターネット上に存在する仮想のYouTuberで'
'

結果としては、いい感じですが、文章生成でよくあることで同じような文章でループしちゃってますね。
repetition_penaltyを入れるとループに関するペナルティを調整できるようですが、このパラメータを入れると生成される文章が短くループするようになったので、パラメータの指定の仕方間違っているかも。

物語の生成

Q:'むかしむかし、あるところに、おじいさんとおばあさんが住んでいました。 おじいさんは'

A: 'むかしむかし、あるところに、おじいさんとおばあさんが住んでいました。 おじいさんは山へ芝刈りに、おばあさんは川へ洗濯に行きました。 夕方になったので、おじいさんが「おなかがすいた」と言ったら、おばあさんが言いました。 「そうか。じゃあ、夕飯のしたくをしよう。」 すると、おじいさんが言いました。 「そうか。それなら私は畑へ行ってくるよ。」 夕方、お父さんとお母さんが畑の仕事に出かけると、おじいさんはおじいさんの弟を畑に残して、'

最初の方は桃太郎っぽい雰囲気でしたが、桃太郎に出会えなかった世界の話ですかね笑 どんな物語を学習データにしたのか気になりますね。

情報をインプットした後での質問

事故報告書のインプット & 事故の発生原因は

A: ○○部長殿 この度は業務中に交通事故を起こしたことで、多大なご迷惑をおかけしました。以下、今回の事故について報告致します。 記 事故の内容 発生日時:○月○日午後2時頃 発生場所:○市1丁目交差点付近 事故状況:左車線から進路変更してきた車と接触 発生原因 左車線から進路変更する車に気づくのが遅くなり、ブレーキのタイミングを誤ったことによる衝突 発生状況 ・○月○日午後1時40時頃、取引先〇〇社との打ち合わせのため社を出発 ・同日午後2時時頃、○市1丁目交差点付近に差し掛かる、進行方向の信号は青 ・左車線の軽トラックが加速して車線変更、接触 被害状況 軽トラックの荷台後部が左バンパーとライト部分に接触し、破損 衝突の際の衝撃で頸椎捻挫、全治約2カ月(相手方には健康上の被害はなし) 事故の発生原因は、○市1丁目交差点付近で左車線から進路変更して進路変更する車があったことに気づかず、それを止めるのが遅れたこと ・出先ということもあり、携帯電話を手にしていたため、必要以上に気を取られてしまい、スピードを出していたこと 以上、猛省し、改めて自分の行動や自覚の甘さを痛感致しました。今後は交通ルールや、基本的な車間距離や安全確認を怠らないよう日々努力していく所存です。 どうかお許しください。

インプットとして入れたデータをアウトプットの部分に出してくれてますね。文章の雰囲気を見てなのか、謝罪文のテイストで文章生成してるのがすごい。

AI同士で会話させてみる

GPTで出力されたテキストをまたGPTで分析してテキストを生成するというのを繰り替えして、AI同士が会話してるっぽくしてみました。

実行結果

結局文章の次の文章を予測するものなので、独り言を二人に区切ったみたいなものになることが多かったです。会話っぽくはならないですね。何回かためしてみましたが、ラジオのパーソナリティみたいな文章になることが多かったです。
文章の語尾に何かつけるとうまいことなるかもしれませえんが、もう少し検証が必要そうです。

・1回目
途中空白になっている部分には、HTMLのタグが入ってきました。

Aさん: こんにちは、今日はいい天気ですね。
Bさん: 明日は雨が降るとか言います。どうなっているんだ。土曜日なんて暇すぎます。うちの会社は基本的に土日がお休みなので土日です
Aさん: 。土日が休みの会社、多いですよねー。ほとんどの方は3連休ですね。みなさんどんな風に過ごされるんでしょうか。僕は土曜日に何しよう、お祭り行きますかねー。日曜日も何しようかな、適当に映画でも観るかなー。
Bさん:
Aさん: コラボレーションという形で、多くの素晴らしい作家さんたちとつながっている。それって非常に素晴らしいことですね。それを継続させるために、皆さん、本当に強いんです!
Bさん:

・2回目
言葉が途中で切れちゃうと会話じゃなくて前の文章の続きになっちゃいますね。

Aさん: こんにちは、今日はいい天気ですねー!最近どうですか
Bさん: ?自分は10日間ほど前に、とある大学の、とある映画を見てきましたよ!!しかも二本
Aさん: とも!すごく面白くて、元気になる映画でしたよ!また記事の方で告知しようかと。映画おすすめしたいですね、青春は素晴らしい!
Bさん: また明日です!おやすみなさい!

3回目

Aさん: こんにちは、今日はいい天気ですねー!最近どうですか
Bさん: ?晴れてる昼間は暑いですが、朝晩は寒いです風邪などひかないよう、
Aさん: お気を付けくださいませ。それでは!おやすみなさい。
Bさん: ←。良い夢見れますように。明日も頑張ろう!【告知】いつも閲覧していただいている

ソースコード

import torch
from transformers import T5Tokenizer, AutoModelForCausalLM

text='こんにちは、今日はいい天気ですねー!最近どうですか'

# tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
# model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-1b")

# if torch.cuda.is_available():
#     model = model.to("cuda")

def generate_text(text):
    tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt-1b")
    model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-1b")

    if torch.cuda.is_available():
        model = model.to("cuda")
    token_ids = tokenizer.encode(text[-100:], add_special_tokens=False, return_tensors="pt")
    
    max_length=30
    min_length=30
#     print(len(text))
    if max_length < len(text):
        max_length = len(text)*2
        min_length = len(text)*2
    
    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            max_length=max_length,
            min_length=30,
            do_sample=True,
            top_k=500,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bad_word_ids=[[tokenizer.unk_token_id]]
        )

    output = tokenizer.decode(output_ids.tolist()[0])
    return output.replace(text[-100:],'').replace('</s>','')


text_list_A=[text]
text_list_B=[]

count=0
while count < 5:
    if count == 0:
#         text_list_A.append(generate_text(text))
        print('Aさん: ' + text_list_A[0])
        
        text_list_B.append(generate_text(text_list_A[0]))
        print('Bさん: ' + text_list_B[0])
    else:
        text_list_A.append(generate_text(text_list_B[count-1]))
        print('Aさん: ' + text_list_A[count])
        text_list_B.append(generate_text(text_list_A[count]))
        print('Bさん: ' + text_list_B[count])
        
    count=count+1

最後に

かなり自然な日本語に近い文章が生成されるようになりましたね。文章を工夫すれば、分類や翻訳などにも利用できるようです。チャットボットの実態としては精度的にルールベースのものが現時点では多いと思いますが、今後この技術が進化していけばモデルを学習させ続けるだけでよくなるかもしれませんね。
懸念としては、生成する文章が結構バイオレンスな内容も含んでるので、本番を見据えるとモデルが生成した内容をチェックする必要があるかもです。すごいバイアスがかかった内容が生成されてびっくりすることがありました。
チャットボット以外にこれをどうやって使うか考えてみよう。