CrabNet

CrabNet(Compositionally-Restricted Attention-Based Network)は、組成式に基づく材料物性予測に特化したオープンソースのTransformerモデルである。自然言語処理で使われる注意機構を材料組成に適用し、化学式を元素分布行列として表現して学習することで、特徴量設計を最小限にして回帰モデルを構築できる。

MatDaCs ツールレビュー: CrabNet

概要

CrabNetは、化学式のみを入力としてTransformer系の注意機構を材料インフォマティクスに適用するモデルである。本レビューでは、pip最新版(2.0.8)をmacOS(Apple M4 Pro)に導入し、どの程度の手間で微調整できるか、またMatminerやDScribeのような記述子ベース手法とどう補完関係にあるかを確認した。

CrabNetとは

CrabNetは、化学式を元素分布行列(EDM)に変換し、積み重ねた注意層と再帰的ヘッドを通して物性値と不確かさ(主にエピステミック)を推定する。Sterling Bairdによるリファクタ版リポジトリは、scikit-learnに近いAPI(例: CrabNet(mat_prop=\"...\").fit(df))を提供し、データ読み込みや可視化の補助ユーティリティも同梱する。

  • crabnet.utils.data.get_data: ベンチマークCSV(弾性率、硬さ等)や簡易サブセットを読み込む。
  • crabnet.utils.figures: ドキュメント品質の可視化(注意マップ等)を支援する。
  • extend_features: Transformerの後段でプロセス変数(温度・圧力など)を追加できる。

ドキュメントには、基本利用、追加特徴量との比較、教育用最小例の3つの例が示されている。

インストール

pip/condaのCrabNetは現状Python 3.7〜3.10を主に対象とする。pipホイールはCPU向け依存(NumPy、pandas、scikit-learn等)を導入するが、PyTorchは別途追加する必要がある。GPU利用(CUDA ≤11.6)を行う場合はsgbairdチャネルが案内されている。MatDaCsの記事では、Pythonバージョン制約を明記すると再現性が高まる。

例: 基本ワークフロー

公式の基本例は、弾性率データセットを読み込み、CrabNet(mat_prop=\"elasticity\")で学習し、不確かさ付きで予測する流れである。

from crabnet.utils.data import get_data
from crabnet.data.materials_data import elasticity
from crabnet.crabnet_ import CrabNet

train_df, val_df = get_data(elasticity, "train.csv", dummy=True)
cb = CrabNet(mat_prop="elasticity")
cb.fit(train_df)
val_pred, val_sigma = cb.predict(val_df, return_uncertainty=True)

本レビューでは、同等の処理をcrabnetbasicdemo.pyにまとめ、指標計算や成果物(CSV/図)を一括で出力できるようにした。実務的な報告に必要な情報(データ分割、パラメータ数、実行デバイスなど)をログし、検証用CSVを出力する構成である。

from pathlib import Path
import numpy as np
import pandas as pd

from crabnet.utils.data import get_data
from crabnet.data.materials_data import elasticity
from crabnet.crabnet_ import CrabNet

print('Loading full elasticity dataset...')
train_df, val_df = get_data(elasticity, 'train.csv', dummy=False)
print(f'Train rows: {len(train_df)}, Val rows: {len(val_df)}')

model = CrabNet(mat_prop='elasticity', epochs=20, batch_size=256, verbose=True)
model.fit(train_df)

val_pred, val_sigma = model.predict(val_df, return_uncertainty=True)
print('Validation predictions head:')
print(pd.DataFrame({'formula': val_df['formula'].head(), 'target': val_df['target'].head(), 'pred': val_pred[:5], 'sigma': val_sigma[:5]}))

mae = float(np.mean(np.abs(val_df['target'].values[:len(val_pred)] - val_pred)))
print(f'Validation MAE: {mae:.3f}')

Path('crabnet_val_predictions.csv').write_text(
    pd.DataFrame({'formula': val_df['formula'], 'target': val_df['target'], 'pred': val_pred, 'sigma': val_sigma}).to_csv(index=False)
)
print('Saved predictions to crabnet_val_predictions.csv')

ローカルでconda run -n crabnet310 python crabnetbasicdemo.pyを実行すると(CPUで約5分半)、validation MAEは概ね14 GPa程度となり、公式チュートリアルに近い水準を再現できた。保存したCSVから不確かさ付きのパリティプロットを作成し、MatDaCs向けの可視化素材も用意した。

このスクリプト化により、注意モデルのベースライン取得と、数値表および図のエクスポートを簡潔に行える。

所感

  • dummy=Trueはスモークテスト向けであり、再現可能な数値を示す場合はdummy=Falseで全データを用いる方がよい。
  • チェックポイントはmodels/trained_models/に保存される。複数試行では.pthが肥大化するため整理が必要である。
  • 学習は内部で並列化されており、特別なfork設定をせずに実行しやすい。
  • Torch 2.4ではnested tensorsの警告が出る場合があるが、現状は致命的ではない。

まとめ

CrabNetは、記述子ベースのツールに対する有力なニューラルベースラインとなる。少ないコードで弾性率や硬さなどのタスクで学習・予測・不確かさ推定ができ、extend_featuresによりプロセス変数も扱える。一方で最大の障壁はPythonバージョン互換性であり、3.9/3.10環境の準備が現実的である。導入後は、Matminer/DScribe系のワークフローと並べて比較する価値が高い。

参考

  • CrabNet documentation: <https://crabnet.readthedocs.io/en/latest/>
  • CrabNet GitHub repository: <https://github.com/sparks-baird/CrabNet>
  • Official example – Basic usage: <https://crabnet.readthedocs.io/en/latest/examples.html#basic-usage>
  • Official example – Extend features comparison: <https://crabnet.readthedocs.io/en/latest/examples.html#extend-features-comparison>
  • Official example – Bare-bones teaching example: <https://crabnet.readthedocs.io/en/latest/examples.html#bare-bones-teaching-example>
  • Wang et al., npj Comput. Mater. 7, 77 (2021)