アメリカ気象局の天気図の前線描画を学習させたSemantic Segmentationのニューラルネットワークに、日本付近の前線描画をさせるまで
1. はじめに
しばらくアメリカ暮らしをすることになりました。日本で使っていたパソコンも無事移設することが出来ましたので、せっかくの機会ですからアメリカの気象データを使って機械学習をやってみました。まずは「気象可視化画像から前線を自動描画する」を試しました。また、このネットワークに日本付近の前線解析をさせてみました。アメリカ流の気象データ解析(前線を検知して描画する)を学習して、日本のデータを解析してみた。ということになります。
結果としては、こんな感じです。
(停滞前線については、各機関作成の気象図では赤と青で温暖前線・寒冷前線成分を表しますが、ニューラルネットワーク生成版では緑色として表示しています)
静止画
アメリカ合衆国付近(2019年1月5日 12時)
左;ニューラルネットワークが生成、右:アメリカ気象局が解析
日本付近(2023年3月3日 0時)
左;ニューラルネットワークが生成、右:日本気象庁が解析
動画
アメリカ合衆国付近(2019年8月)
左;ニューラルネットワークが生成、右:アメリカ気象局が解析
日本付近(2023年8月)
左;ニューラルネットワークが生成、右:日本気象庁が解析
前線というのは、大まかには性質の異なる空気の接するところ、です。ただ、実際の気象場の中にはそういうところは沢山あって、それら全てに前線記号をつけていくわけではありません。どれにどう線を描いていくのかは各国の気象機関によって異なります。そういったわけで、気象データが同じでも機関や解析する人の判断が入った図になっています。その判断の部分を機械学習によって模倣できるか?というのが前線自動描画です。
必要な材料は以下の通りです。
(1) スーパーコンピュータによる全球気象予報モデルが作成したデータファイル
アメリカの気象機関が配信しているデータファイルです。これをもとに可視化画像を作成します。2.で説明します。
(2) 地上気象解析図
アメリカの気象機関が業務として行なっている気象解析(前線などを記入したもの)のアウトプットです。
ここから前線などの情報を抽出します。これについては下記に詳細をまとめました。
(3) Semantic Segmentationを行うプログラム
セマンティックセグメンテーションによって、地図の各ピクセルが前線に対応する・しない、前線に対応するならばどの種類の前線に対応するのか、を計算して色分けします。3.で説明します。
(1)と(2)は沢山あれば沢山あるほど学習データが増えます。
今回の投稿では、(1)と(3)について説明をした後、実際の自動描画データを見てみます。
2. 全球気象予報モデルが作成したデータファイル
2.1 アメリカ合衆国の気象機関NOAA
アメリカ合衆国の国家海洋大気庁(NOAA)は商務省に属する機関です。日本の気象庁は旧運輸省、現国土交通省の傘下の組織ですね。
NOAAの下には日本の気象庁に相当するNational Weather Service(国立気象局)など、いくつかの組織があります。
The organizational chart for NOAA headquarters leadership and senior staff as of May 30, 2023. (Image credit: NOAA)
日本の気象庁との違って地震地象関係は管掌範囲にないようです。一方で宇宙天気予報関係はNOAAの範囲です。
NOAA傘下の様々な組織は観測や予報のデータを広く公開していて"free and available to the public without restriction on use"であるとの記載が見られます。
2.2 予報モデルのデータ
いくつかの機関が予報モデルの結果データを公開しています。過去に遡って取得できるデータとしては私が検索した範囲では、過去のデータも含めて取得できる機関として、環境情報センター(NCEI)が目的に沿っていました。
2.2.1 Analysis data of GFS(全球予報モデルの解析データ)
NWSも全球予報モデルを運用しています。この計算結果が格子データとして利用できるようになっています。予報データ(forecast)と解析データ(analysis)に分かれていますが、機械学習で前線を描画する際には地表面解析の天気図と対応させますので、解析データの方を利用します。解析データはその時刻における、気象要素の値を格子点で計算したものです。限りなくその時刻に観測される値に近いと考えられます。
NCEIの全球予報モデルの解説ページによると、NCEIがサーバからダウンロードできるものとして提供している全球予報モデル(解析)は以下の種類があります。
Grid003(格子間隔緯度で1度)と、Grid004(0.5度)のものがあります。今回は高精度の0.5度のものを用います。
日本と異なりGPVデータを直接ダウンロード出来るようになっていて、自由に使ってよいとのことです。これらのデータは下記サイトからダウンロードできます。
2.2.2 データファイルの内容
ファイル形式はGRIBですので、日本気象庁の全球予報モデルのGPVを読み取るプログラムがほぼそのまま利用できます。
とはいえ、ファイルに含まれている気温、湿度といった物理量の種類も異なりますし、その物理量を指定するパラメータ名も異なるので、GRIBファイルをのぞいて調べてみます(フォーマットを解説しているページもありましたがレコード数が異なったりするので実物を見るのみです)。
Pythonのライブラリであるpygrib
を用いてファイルに含まれるデータの一覧を調べてみます。
>>> import pygrib
>>> grbs=pygrib.open("gfs_4_20200811_0000_000.grb2")
>>> for g in grbs:
... g
...
1:Cloud mixing ratio:kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
2:Ice water mixing ratio:kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
3:Rain mixing ratio:kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
4:Snow mixing ratio:kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
5:Graupel (snow pellets):kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
略
520:Pressure reduced to MSL:Pa (instant):regular_ll:meanSea:level 0:fcst time 0 hrs:from 202008110000
521:5-wave geopotential height:gpm (instant):regular_ll:isobaricInhPa:level 50000 Pa:fcst time 0 hrs:from 202008110000
522:Land-sea coverage (nearest neighbor) [land=1,sea=0]:~ (instant):regular_ll:surface:level 0:fcst time 0 hrs:from 202008110000
>>>
日本気象庁の場合と同じように要素別に格納されていますがデータの種別はかなり多いです。
いずれかの要素を取得するための指定には、parameterName
というパラメタを用いることが一般的です。この一覧を出してみます。
>>> for g in grbs:
... print(f"parameterName={g['parameterName']} record={g}")
...
parameterName=Ice water mixing ratio record=2:Ice water mixing ratio:kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
parameterName=Rain mixing ratio record=3:Rain mixing ratio:kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
parameterName=Snow mixing ratio record=4:Snow mixing ratio:kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
parameterName=Graupel (snow pellets) record=5:Graupel (snow pellets):kg kg**-1 (instant):regular_ll:hybrid:level 1:fcst time 0 hrs:from 202008110000
parameterName=196 record=6:Maximum/Composite radar reflectivity:dB (instant):regular_ll:atmosphere:level 0 -:fcst time 0 hrs:from 202008110000
parameterName=Visibility record=7:Visibility:m (instant):regular_ll:surface:level 0:fcst time 0 hrs:from 202008110000
parameterName=u-component of wind record=8:U component of wind:m s**-1 (instant):regular_ll:planetaryBoundaryLayer:level 0:fcst time 0 hrs:from 202008110000
parameterName=v-component of wind record=9:V component of wind:m s**-1 (instant):regular_ll:planetaryBoundaryLayer:level 0:fcst time 0 hrs:from 202008110000
parameterName=224 record=10:Ventilation Rate:m**2 s**-1 (instant):regular_ll:planetaryBoundaryLayer:level 0:fcst time 0 hrs:from 202008110000
以下略
paramterName
がただの数値(196とか224とか)のものもあります。
2.2.3 可視化プログラムについて
実は長い期間でデータを取得していると、日によってなぜかparameterName
が異なっている時があって、突然エラーになってしまったりします。このようなイレギュラリティ、どこかで解説されているのか見つけきれていません。そこで対処療法で遭遇した場合にparamterName
を調査しては以下のような対処を入れました。
letter_mslp = "Pressure reduced to MSL"
if os.path.exists(_GSMfilename) :
grbs = pygrib.open(_GSMfilename)
try:
grb_prs = grbs.select(parameterName=letter_mslp , level=0 , forecastTime=0)
except:
try:
grb_prs = grbs.select(parameterName="192" , level=0 , forecastTime=0)
大抵の場合は"Pressure reduced to MSL"
というparamterNameで取得できますが、まれに192
となっていることがあるためこういう処理を入れています。
2018年2月28日18時(UTC)のデータを可視化したもの。
上段左から850hPa相対温位、700hPa上昇流、700hPa湿数。
中段左から500hPa風・温度・気圧、850hPa風・温度・気圧、地表風・温度・気圧。
下段300hPa風速
その他の部分も含めた可視化のコードの詳細は以下に折りたたまれています。
気象独自のライブラリはPygribと、MetPy、地図のライブラリとして私はBasemapを用いています。
可視化コードの詳細はこちらから
import pygrib
import sys
import os
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import math
from mpl_toolkits.basemap import Basemap
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
#
#from metpy.calc import dewpoint_rh
from metpy.calc import dewpoint_from_relative_humidity
from metpy.calc import equivalent_potential_temperature
from metpy.calc import lat_lon_grid_deltas
from metpy.calc import vorticity
from metpy.calc import potential_temperature
from metpy.calc import wind_direction, wind_speed
from metpy.units import units
MetPyは数年前のバージョンから関数名が変更になったりしています。
dewpoint_rh
は、dewpoint_from_relative_humidity
に変わりました。
複数のデータを対象として描画するようにプログラムをする場合、Basemapの生成は時間がかかるので、最初に1度だけ行なってインスタンスを使い回すのがよいです。そういう理由で下記の描画ルーチンではBasemapのインスタンスm
を呼び出し元から受け取っています。
def drawmap( values , levels , cmap , _save_filename , flag_show , name , m) :
# 等値線とカラーコンターの組み合わせを描画する
# values:描画するデータ levels:数値範囲 cmap:使用するカラーマップ m:Basemapインスタンス
#図の生成と余白を小さくする調整
fig,ax = plt.subplots(figsize=(5.8,4.2))
plt.subplots_adjust(left=0.01, right=0.98, top=0.99, bottom=0.01)
#地図要素(緯度経度、海岸線、州境界)の描画
m.drawparallels(np.arange(-80.,81.,10.))
m.drawmeridians(np.arange(-180.,181.,10.))
m.drawcoastlines(linewidth=0.5)
m.drawstates(linewidth=0.3)
#データの緯度経度を描画座標に変換
x , y = m(lons, lats)
#等値線
m.contour( x , y , values , levels=levels , linewidths=0.4 , colors='k' )
#カラーコンター
m.contourf( x , y , values , levels , cmap=cmap )
plt.savefig(_save_filename)
plt.close()
def windmap( u , v , values , levels , cmap , _save_filename , flag_show ,
name , values2 , levels2 , m) :
# 矢羽根図付きの等値線・カラーコンターを描画する
# u, v:矢羽根図のためのデータ values, values2:描画するデータ2種
# levels, levels2:数値範囲2種 cmap:使用するカラーマップ m:Basemapインスタンス
#図の生成と余白を小さくする調整
fig,ax = plt.subplots(figsize=(5.8,4.2))
plt.subplots_adjust(left=0.01, right=0.98, top=0.99, bottom=0.01)
print(name)
print("max={0}, min={1}".format(values.max(),values.min()))
#地図要素(緯度経度、海岸線、州境界)の描画
m.drawparallels(np.arange(-80.,81.,10.))
m.drawmeridians(np.arange(-180.,181.,10.))
m.drawcoastlines(linewidth=0.5)
m.drawstates(linewidth=0.3)
#矢羽根をplotする点を間引いておく
lons_s = lons[::3, ::3]
lats_s = lats[::3, ::3]
x , y = m(lons, lats)
xs , ys = m(lons_s, lats_s)
us = u[::3, ::3]
vs = v[::3, ::3]
# 等値線2種(等温線、等圧線)
m.contour( x , y , values , levels=levels , linewidths=0.4 , colors='k' )
m.contour( x , y , values2 , levels=levels2 , linewidths=0.2 , colors='k' )
# カラーコンター
m.contourf( x , y , values2 , levels2 , cmap=cmap )
# 矢羽根図
m.barbs( xs , ys , us , vs , length=3.5, linewidth=0.4,
sizes=dict(spacing=0.16, emptybarb=0.00) )
plt.savefig(_save_filename)
plt.close()
GRIB形式のファイルデータを処理していきます。
#- level for contour map
levels_prs = np.arange(930.0,1080.0,4.0) #地上気圧(海面更正気圧)
levels_tmp = np.arange(210,316,2) #温度(ケルビン)
levels_tmps = np.arange(-35,45,2) #温度 (摂氏)
levels_dp2 = np.arange(0,3,1) #湿数(露点温度と気温の差)
levels_vv = np.arange(-5,5,0.5) #鉛直速度
levels_ept = np.arange(200,400,3) #相当温位
levels_wind = np.arange(-100,100,2) # 300hPaでの風速
levels_gh_500 = np.arange(4500,6200,40) # 500hPaでのジオポテンシャル高度
levels_gh_850 = np.arange(1000,2000,40) # 850hPaでのジオポテンシャル高度
#- parameterName for select from GRIB file
letter_mslp = "Pressure reduced to MSL" # 海面更正気圧
letter_prs = "Pressure reduced to MSL" # Same for JMA and NOAA
letter_tmp = "Temperature" # 温度
letter_rh = "Relative humidity" # 相対湿度
letter_vv = "Vertical velocity" # 鉛直速度
letter_wu = "u-component of wind" # 風速の東西成分
letter_wv = "v-component of wind" # 風速の南北成分
letter_gh = "Geopotential height" # ジオポテンシャル高度
#- Directory and File name settings
_GSMfile_head = "gfs_4_"
_sourcedir = Your GPV data directory
_outdir = Your visualized image output directory
# Basemap instance
# 地図インスタンス生成。これは以下のループ内では使い回します。
m = Basemap(projection='stere', llcrnrlat=18, urcrnrlat=51, llcrnrlon=-126, urcrnrlon=-54, lat_0=40, lon_0=-100, resolution='i' )
対象となる日付、時刻のGRIBファイルをオープンして必要なデータを取得します。その上で描画関数を呼び出して、画像を作っていきます。この時、GRIBファイルには含まれていない描画対象データを、複数のGRIBに含まれるデータ項目から計算する必要があるものがあります。
例えば、相当温位は気温、気圧、露点温度から計算されます。露点温度はGRIBに無いですが、相対湿度と気温から計算することができます。そういうわけで、必要なデータの取得はまとめて行なって、それから描画に入った方がよいということになります。
なお、NCEIのデータは欠落している時がみられますので、GRIBファイルの存在チェックをしておくことが
必須です。
# ひと月分のGPVを対象とする
for _dy in range(1, 32):
for _hr in _hr_list : # 描画する対象時刻
_day = _dy
# ファイル名を生成する (gfs_4_YYYYMMDD_HH00_000.grb2)
_GSMfilename =
_sourcedir + _GSMfile_head
+ str(_year) + str( _mon ) + str( _day ).zfill(2) + "_" + str( _hr )
+ "00_000.grb2"
#- Open GRIB file
# データが欠落している日・時刻が散在しているのでファイル存在チェックは必須
if os.path.exists(_GSMfilename) :
grbs = pygrib.open(_GSMfilename)
#-- Surface(地表)のパラメタ揺れ対応
try:
grb_prs = grbs.select(parameterName=letter_mslp , level=0 , forecastTime=0)
except:
try: #地表気圧のパラメタ名が違うパターンの場合
grb_prs = grbs.select(parameterName="192" , level=0 , forecastTime=0)
except: #未知の地表気圧のパラメタ名が違うパターンがあったら、の場合
print("GPV {0}{1}{2}{3} presshure unknown format".format(str(_year), str(_mon), str(_day).zfill(2), str(_hr)))
continue
try:
# データ取得
print("surface")
grb_tmp = grbs.select(parameterName=letter_tmp , level=80 , forecastTime=0)
grb_wu = grbs.select(parameterName=letter_wu , level=80, typeOfLevel="heightAboveGround", forecastTime=0)
grb_wv = grbs.select(parameterName=letter_wv , level=80, typeOfLevel="heightAboveGround", forecastTime=0)
grb_rh = grbs.select(parameterName=letter_rh , level=1000 , forecastTime=0)
#-- 850 hPa level : temperature wind equivalent-potential-temperature
print("850")
grb_tmp85 = grbs.select(parameterName=letter_tmp , level=850 , forecastTime=0)
grb_wu85 = grbs.select(parameterName=letter_wu , level=850 , forecastTime=0)
grb_wv85 = grbs.select(parameterName=letter_wv , level=850 , forecastTime=0)
grb_rh85 = grbs.select(parameterName=letter_rh , level=850 , forecastTime=0) # to be used for dew-point
grb_gh85 = grbs.select(name=letter_gh , level=850 , forecastTime=0)
# to be used for dew-point
#-- 700 hPa level : Sitsusu(T-Td) Vertical-wind
print("700")
grb_tmp70 = grbs.select(parameterName=letter_tmp , level=700 , forecastTime=0)
grb_rh70 = grbs.select(parameterName=letter_rh , level=700 , forecastTime=0) # to be used for dew-point
grb_vv70 = grbs.select(name=letter_vv , level=700 , forecastTime=0)
#-- 500 hPa level : vorticity geopotential-height
print("500")
grb_tmp50 = grbs.select(parameterName=letter_tmp , level=500 , forecastTime=0)
grb_wu50 = grbs.select(parameterName=letter_wu , level=500 , forecastTime=0)
grb_wv50 = grbs.select(parameterName=letter_wv , level=500 , forecastTime=0)
grb_rh50 = grbs.select(parameterName=letter_rh , level=500 , forecastTime=0) # to be used for dew-point
grb_gh50 = grbs.select(name=letter_gh , level=500 , forecastTime=0)
# to be used for dew-point
#-- 300 hPa level : wind geopotential-height
print("300")
grb_wu30 = grbs.select(parameterName=letter_wu , level=300 , forecastTime=0)
grb_wv30 = grbs.select(parameterName=letter_wv , level=300 , forecastTime=0)
grb_rh30 = grbs.select(parameterName=letter_rh , level=300 , forecastTime=0) # to be used for dew-point
grb_gh30 = grbs.select(name=letter_gh , level=300 , forecastTime=0)
grb_tmp30 = grbs.select(parameterName=letter_tmp , level=300 , forecastTime=0)
#
except: # 未知のパラメタ揺れが生じた場合への対処
print("GPV {0}{1}{2}{3} select error".format(
str(_year), str(_mon), str(_day).zfill(2), str(_hr)
)
)
continue
# 緯度経度を取得
g1=grb_tmp[0]
g2=grb_prs[0]
lats, lons = g1.latlons()
lats2, lons2 = g2.latlons()
_file_sufx = Your file name sufix etc.
ここまでがデータの取得です。この後、描画を行います(対象となる日・時刻のループ内の処理が継続しています)。
# 各可視化図の描画
# Surface
#- Surface pressure (MSLP)
_save_filename = _outdir + "gfs_prs_srf" + _file_sufx
drawmap( grb_prs[0].values / 100.0 , levels_prs , "coolwarm" , _save_filename ,_flag_show , "MSLP" , m )
#- Surface wind, tempareture, pressure
_save_filename = _outdir + "gfs_wnd_srf" + _file_sufx
windmap( grb_wu[0].values , grb_wv[0].values , grb_prs[0].values / 100.0 , levels_prs, "coolwarm" , _save_filename , _flag_show , "Surface wind srf.press." , grb_tmp[0].values - 273.15 , levels_tmps , m)
# 850 hPa
#- 850 hPa : equivalent-potential-temparature(相当温位);MetPyで算出
_save_filename = _outdir + "gfs_ept_850" + _file_sufx
dpval = dewpoint_from_relative_humidity(np.array(grb_tmp85[0].values) * units('K'), grb_rh85[0].values/100).to(units('K'))
prsval = dpval # 配列と同じDataタイプを作る
prsval = 850
prsval = prsval * units('hPa')
ept_val = equivalent_potential_temperature(prsval, np.array(grb_tmp85[0].values)*units('K'), dpval)
drawmap( ept_val , levels_ept , "coolwarm" , _save_filename , _flag_show , "850hPa equivalent potential temparature" , m)
#- 850 hpa wind and geo-hight, temperature
_save_filename = _outdir + "gfs_wnd_850" + _file_sufx
windmap( grb_wu85[0].values , grb_wv85[0].values , grb_gh85[0].values , levels_gh_850 , "coolwarm" , _save_filename , _flag_show , "850 hPa wind geo-hight" , grb_tmp85[0].values, levels_tmp , m)
#- 850 hpa potential temperature
# 700 hPa
#- 700 hPa : vertical velocity
_save_filename = _outdir + "gfs_vvl_700" + _file_sufx
drawmap( grb_vv70[0].values , levels_vv , "bwr" , _save_filename , _flag_show , "700hPa vertical velocity" , m )
#- 700 hPa : Sitsusu
_save_filename = _outdir + "gfs_situ_700" + _file_sufx
dpval = dewpoint_from_relative_humidity(np.array(grb_tmp70[0].values) * units('K'), grb_rh70[0].values/100).to(units('K'))
situsu_val = np.array(grb_tmp70[0].values)*units('K') - dpval
drawmap( situsu_val , levels_dp2 , "coolwarm" , _save_filename, _flag_show , "700 hPa situsu" , m )
# 500 hPa
#- 500 hpa wind and geopotential height, tempareture
_save_filename = _outdir + "gfs_wnd_500" + _file_sufx
windmap( grb_wu50[0].values , grb_wv50[0].values , grb_gh50[0].values , levels_gh_500, "coolwarm" , _save_filename , _flag_show , "500 hPa wind geo-height" , grb_tmp50[0].values, levels_tmp , m )
# 300 hPa
# 300hPa wind speed
_save_filename = _outdir + "gfs_wsp_300" + _file_sufx
drawmap( wind_speed(np.array(grb_wu30[0].values)*units('m/s'), np.array(grb_wv30[0].values)*units('m/s')) , levels_wind , "coolwarm" , _save_filename , _flag_show , "300hPa Wind speed" , m)
grbs.close()
else :
print("GPV file not found. {0}".format(_GSMfilename))
3. セマンティックセグメンテーションを用いた気象前線の自動生成
3.1 ResNet50
KerasのCode Exampleの中に、Multiclass semantic segmentation using DeepLabV3+というチュートリアルがあって、コードも公開されています。
入力となる気象データは7種類の画像データに可視化しますので、7種類x3ch=21chのデータです。
ResNetは3chの写真を想定していますので、ResNet50で扱えるように、ResNet50の前段にConvolution層を一段設置してチャンネル数を3個にします。
サンプルコードに対する主な修正は次の2点です。
(1) 前段のConvolution層の配置
# 追加:入力データを3チャンネルのデータに変換する
def Merge_Multiple_Var(mmv_input):
mmvs = mmv_input.shape
x = layers.Conv2D( 3, kernel_size=(3,3), \
activation='relu', padding='same')(mmv_input)
return x
# 本体ネットワークへの組み込み
def DeeplabV3Plus():
中略
# オリジナルのモデルの定義(NN_1として定義)
NN_1 = Model(inputs=model_input, outputs=model_output)
# 入力変換のための層を通った後にResNet50(NN_1)を通るようにネットワークNN_2を定義
addD3Net_input = tf.keras.Input(shape=(image_size, image_size, num_met_var))
y = Merge_Multiple_Var(addD3Net_input)
addD3Net_output = NN_1(y)
NN_2 = Model(inputs=addD3Net_input, outputs=addD3Net_output)
関数Muerge_Multiple_Var
を定義して、本体ネットワークDeeplabV3Plus
に組み込みます。
具体的には、まずオリジナルのResNet50をNN_1
として定義します。
そして入力であるaddD3Net_input
が関数Merge_Multiple_Var
を通り、その後NN_1
を通ってaddD3Net_output
として出力される、という流れを定義します。
その上で、addD3Net_input
を入力、addD3Net_output
を出力とするNN_2
を定義します。
このNN_2
が欲しいネットワークとなります。
(2) ResNet50の重みの登録
前後に層を追加したことで、ResNet50を定義するタイミングで重みのロードが出来なくなります。
層の数が違う、というエラーメッセージが出てしまいます。この対処方法は、ChatGPTに聞いたところ、たちどころに解決しました。慣れている人には当たり前の処理なんでしょうね...
個人的には、ChatGPTはデバッグ、エラーメッセージの意味を聞いてこれを解消する方法を示唆してもらうという使い方がとても役に立っています。
ということで、最初にresnet50
を定義する際にはweights=NONE
を指定します。
そのあとで、resnet50
を重み付き(weights="imagenet"
)でoriginal_model
として定義し、必要な層の重みをコピーするということを行うのです。
# 重みをコピーしないでresnet50を定義する
resnet50 = keras.applications.ResNet50( \
weights=None, \
include_top=False, \
input_tensor=model_input \
)
# weights="imagenet", include_top=False, input_tensor=x
中略
# ResNet50を重み付きで再定義
original_model = keras.applications.ResNet50( \
weights="imagenet", include_top=False,
input_shape=(256,256,3) \
)
# original_modelからresnet50にlayerをコピーする
for layer_new, layer_original in zip(resnet50.layers, original_model.layers):
layer_new.set_weights(layer_original.get_weights())
このようにして定義したCNNの全体はこちらから
def convolution_block( \
block_input, \
num_filters=256, \
kernel_size=3, \
dilation_rate=1, \
padding="same", \
use_bias=False, \
):
x = layers.Conv2D( \
num_filters, \
kernel_size=kernel_size, \
dilation_rate=dilation_rate, \
padding="same", \
use_bias=use_bias, \
kernel_initializer=keras.initializers.HeNormal(), \
)(block_input)
x = layers.BatchNormalization()(x)
return tf.nn.relu(x)
def DilatedSpatialPyramidPooling(dspp_input):
dims = dspp_input.shape
x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
x = convolution_block(x, kernel_size=1, use_bias=True)
out_pool = layers.UpSampling2D( \
size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]), \
interpolation="bilinear", \
)(x)
out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
output = convolution_block(x, kernel_size=1)
return output
# 追加:入力データを3チャンネルのデータに変換する
def Merge_Multiple_Var(mmv_input):
mmvs = mmv_input.shape
x = layers.Conv2D( 3, kernel_size=(3,3), \
activation='relu', padding='same')(mmv_input)
return x
# 一部修正
def DeeplabV3Plus( \
image_size, num_classes, num_met_var, \
use_saved_model_flag=0, \
model_input_dir="model_input_dir",\
paramdir="your_model_paramdir", \
param_basename="your_param_basename", \
param_sufix=".hdf5" \
):
if ( use_saved_model_flag == 0 ):
model_input = keras.Input(shape=(image_size, image_size, 3))
# 重みをコピーしないでresnet50を定義する
resnet50 = keras.applications.ResNet50( \
weights=None, \
include_top=False, \
input_tensor=model_input \
)
# weights="imagenet", include_top=False, input_tensor=x \
x = resnet50.get_layer("conv4_block6_2_relu").output
x = DilatedSpatialPyramidPooling(x)
input_a = layers.UpSampling2D( \
size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]), \
interpolation="bilinear", \
)(x)
input_b = resnet50.get_layer("conv2_block3_2_relu").output
input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
x = layers.Concatenate(axis=-1)([input_a, input_b])
x = convolution_block(x)
x = convolution_block(x)
x = layers.UpSampling2D( \
size=(image_size // x.shape[1], image_size // x.shape[2]), \
interpolation="bilinear", \
)(x)
model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
NN_1 = Model(inputs=model_input, outputs=model_output)
NN_1.summary()
# 入力変換のための層を通った後にResNet50(NN_1)を通るようにネットワークNN_2を定義
addD3Net_input = tf.keras.Input(shape=(image_size, image_size, num_met_var))
y = Merge_Multiple_Var(addD3Net_input)
addD3Net_output = NN_1(y)
NN_2 = Model(inputs=addD3Net_input, outputs=addD3Net_output)
NN_2.summary()
##学習済みimagenetの重みをコピーする
original_model = keras.applications.ResNet50( \
weights="imagenet", include_top=False,
input_shape=(256,256,3) \
)
for layer_new, layer_original in zip(resnet50.layers, original_model.layers):
print(layer_new.name, layer_original.name)
layer_new.set_weights(layer_original.get_weights())
else:
paramfiledir = paramdir + "/" + model_input_dir + "/"
paramfile = paramfiledir + param_basename + "model" + param_sufix
weightfile = paramfiledir + param_basename + "weight" + param_sufix
print("Read model from {0}".format(paramfile))
NN_2 = load_model( paramfile )
print("Read weights from {0}".format(weightfile))
NN_2.load_weights(weightfile)
#return keras.Model(inputs=model_input, outputs=model_output)
return NN_2
3.2 教師画像となる前線などの画像
教師画像となる前線のデータは、気象通報形式の地上気象解析データから作成します。
具体的な手法は、こちらをご覧ください。
結果として以下のように気象要素ごとのマスクを得ることができます。
最上段 左:寒冷前線 右:温暖前線
2段目 左:停滞前線 右:閉塞前線
3段目 トラフ
最下段 左:高気圧記号 右:低気圧記号
気象データ可視化画像を入力、教師画像を出力として3.1のネットワークを用いて学習させることでセマンティックセグメンテーション画像を得ることができます。
4. 学習結果
学習は2009年から2018年までの10年分のデータ(1日2件)について実施しました。
200エポック程度で下記のような前線などを描画するようになります。ResNetは気象データなど学習したことは無いと思いますが、学習時に再現し始めるまでのエポック数は、自作ネットワークと同じ程度でした。これはResNetの構造が優れていることと、事前学習された重みが初期値として優れている、ということなのだと思います。
4.1 アメリカ付近の気象データの前線描画
左が自動生成された前線等の記号、右は米国気象局が実際に解析した気象要素です。
生成図では、停滞前線を緑色で表現しています。
前線が概ね再現できています。強いていうとフロリダ半島を通る寒冷前線の副前線のようなものが生成されていません。
これも概ね再現できていますが、テキサス州から東海岸へ伸びる停滞前線と寒冷前線がやや不明瞭。
概ね再現できています。西部のワシントン州沖あたりの温暖前線が、生成図では上手く再現できていない。
中西部以西の前線の再現が不明瞭です。東部大西洋の寒冷前線も再現できていないです。
これは東海岸と大西洋の前線があまり再現性がよろしくないです。
4.2 日本付近の気象データの前線描画(by アメリカ気象局流)
さて、アメリカの10年分の気象図描画を学んだAIに、日本付近の気象図を解析させてみました。
日本気象庁ではなく、アメリカンな気象図を描くでしょうか?
左がAIが生成したもの、右は日本の気象庁が解析したものです。
このあたりは概ね日本気象庁と同じような図を描きました。
日本気象庁とは少し異なるパターンを解析し始めたものを示します。
東北沖の前線は、日本の気象庁の図の解析範囲から切れた範囲までつながっています。
オホーツク海と中国東北区に日本気象庁は解析していない前線を描画しました。
こちらは逆に朝鮮半島から華南に伸びる前線を、全く生成していないパターンでした。
5. まとめ
アメリカ気象局の全球予報モデルのGRIB形式のデータを可視化して、Semantic Segmentationによって前線描画を行うニューラルネットワークを作りました。ネットワークはResNet50をベースにして、入力部分を変更しました。
6. 関連投稿
(1)日本気象庁のデータから前線を描画する
以前に、前線を自動描画するセマンティックセグメンテーションは、日本の気象庁のデータをもとに取り組んだことがあります。
この時は、自作のUnetのようなCNNを使用しました。合わせてご覧いただけると幸いです。
(2)日本気象庁の前線解析を学習して、世界のデータを解析する
今回と逆のパターンを過去にやったことがあります。日本の天気図で訓練されたAIに、世界の他の場所の天気図を解析させました。
この頃は大きなサイズの動画ファイルを投稿できたので、GIFアニメを載せていました。