サーバー側が完成しました。
ポイントとしては、Postgressサーバーをドッカーで立ち上げ、為替価格をテーブルに入れたこと
これまでよりも、スムーズに価格リストを引き出せるようになりました。
api
POSTGRESによる為替価格データベースを構築
%pip install sqlalchemy psycopg2 pandas_datareader mplfinance torch torchvision databases fsspec asyncpg fastapi_crudrouter sqlalchemy_utils --user
postgressのコンフィギュレーションです。
(ほとんどのソースは、書籍から使わせていただきました)
import os
class DBConfigurations:
postgres_username = "user"
postgres_password = "password"
postgres_port = 5432
postgres_db = "model_db"
postgres_server = "localhost"
sql_alchemy_database_url = (
f"postgresql://{postgres_username}:{postgres_password}@{postgres_server}:{postgres_port}/{postgres_db}"
)
class APIConfigurations:
title = os.getenv("API_TITLE", "Model_DB_Service")
description = os.getenv("API_DESCRIPTION", "machine learning system training patterns")
version = os.getenv("API_VERSION", "0.1")
WEB-APIの場合、SESSIONを接続のたびに構築するのですね・・・
(ここで、トークン認証などの処理がどうなるのかが、まだ疑問ですが。。。)
import os
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
engine = create_engine(
DBConfigurations.sql_alchemy_database_url,
encoding="utf-8",
pool_recycle=3600,
echo=False,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
except:
db.rollback()
raise
finally:
db.close()
為替データのクラスです。
テーブル名を足の数だけ宣言しています。
from sqlalchemy import Column, DateTime, ForeignKey, String, Text,Float,Integer
from sqlalchemy.sql.functions import current_timestamp
from sqlalchemy.types import JSON
from sqlalchemy import desc
import subprocess
import os
import shutil
class forex_m1(Base):
__tablename__= "forex_m1"
id = Column(DateTime(timezone=True), primary_key=True, index=True)
open = Column(Float())
high = Column(Float())
low = Column(Float())
close = Column(Float())
volume = Column(Integer())
class forex_m5(Base):
__tablename__= "forex_m5"
id = Column(DateTime(timezone=True), primary_key=True, index=True)
open = Column(Float())
high = Column(Float())
low = Column(Float())
close = Column(Float())
volume = Column(Integer())
class forex_m15(Base):
__tablename__= "forex_m15"
id = Column(DateTime(timezone=True), primary_key=True, index=True)
open = Column(Float())
high = Column(Float())
low = Column(Float())
close = Column(Float())
volume = Column(Integer())
立ち上げ時のみ一度だけ、CSVデータをMT4のヒストリデータから取得したものを、SQLでぶっこんでいます。
import pandas as pd
import datetime as dt
from pandas_datareader import data
import mplfinance as mpf
import torch
from torchvision.datasets import ImageFolder
from torchvision import models, transforms
import torch.nn as nn
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
df1 = pd.read_csv(r'C://Users//User//Desktop//EURUSD.oj5k1.csv', sep=",",names=('date', 'time', 'open', 'high', 'low', 'close', 'volume'))
df1.index = pd.to_datetime(df1['date']+" "+df1['time'])
df1 = df1.drop(['date', 'time'], axis=1)
df2 = pd.read_csv(r'C://Users//User//Desktop//EURUSD.oj5k5.csv', sep=",",names=('date', 'time', 'open', 'high', 'low', 'close', 'volume'))
df2.index = pd.to_datetime(df2['date']+" "+df2['time'])
df2 = df2.drop(['date', 'time'], axis=1)
df3 = pd.read_csv(r'C://Users//User//Desktop//EURUSD.oj5k15.csv', sep=",",names=('date', 'time', 'open', 'high', 'low', 'close', 'volume'))
df3.index = pd.to_datetime(df3['date']+" "+df3['time'])
df3 = df3.drop(['date', 'time'], axis=1)
df1.to_sql('forex_m1', con=engine, index=True, index_label='id', if_exists='replace')
df2.to_sql('forex_m5', con=engine, index=True, index_label='id', if_exists='replace')
df3.to_sql('forex_m15', con=engine, index=True, index_label='id', if_exists='replace')
ここからが、APIの処理です。
少しSqlalchemyのサブクエリを使っています(かなり勉強させられました)
最大のポイントは、テーブル名からクラスを引き出すところ・・・
丸2日調べまくりました。
(最後はあっけなくgetメソドで解決。dir関数でメソドを調べつくしました)
def select_forex_by_name(db: Session,table_name:str):
return get_class_by_table(Base,Base.metadata.tables.get(table_name))
from fastapi import Body, FastAPI
from sqlalchemy_utils import get_class_by_table
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends
from sqlalchemy.sql import func
app = FastAPI()
def select_forex_by_name(db: Session,table_name:str):
return get_class_by_table(Base,Base.metadata.tables.get(table_name))
def get_last_time(db: Session,table_name:str):
try:
model = select_forex_by_name( db=db, table_name=table_name, )
q = db.query( func.max(model.id).label('id_max')).subquery('sub1')
r = db.query(model).filter(model.id == q.c.id_max ).all()
return str(r[0].id).replace("-",".")
except:
pass
return
@app.get("/getlasttime/")
async def gettime(db:Session = Depends(get_db),):
return {"m1":get_last_time(db=db,table_name = "forex_m1"),
"m5":get_last_time(db=db,table_name = "forex_m5"),
"m15":get_last_time(db=db,table_name = "forex_m15")
}
価格更新の部分です。
基本はモデルクラスを作って、ぶっこめばイイんですね。
さっきの関数が大活躍です。
from typing import Dict, List, Optional
def add_forex( db: Session, table_name: str, time: Optional[str] = None, value: Optional[str] = None, commit: bool = True,):
dataframe = select_forex_by_name( db=db, table_name=table_name, )
data = dataframe(id=time,close=value,)
db.add(data)
if commit:
db.commit()
db.refresh(data)
return data
@app.post("/gettick/")
async def gettick(db:Session = Depends(get_db),body=Body(...)):
time,peristr,value = body["content"].split(",")
r = add_forex(
db=db,
table_name=peristr,
time = time,
value = value,
commit=True,
)
return {"msg":peristr}
ここからは推論の部分です。
PIX2PIXのコールがあるので、三段のAPIにしています。
他はほとんど、これまでの投稿と同じです。
from sqlalchemy import desc
import re
def get_dataframe( db: Session, table_name: str, framesize: int,):
try:
dataframe = select_forex_by_name( db=db, table_name=table_name, )
q = db.query(dataframe).order_by(desc(dataframe.id)).limit(framesize).subquery('sub1')
r = db.query(dataframe).filter(dataframe.id == q.c.id ).order_by(dataframe.id).all()
return [float(result.close) for result in r]
except:
pass
return
@app.get("/predict/")
async def predict(db:Session = Depends(get_db)):
wsize = 96
df11 = get_dataframe( db=db, table_name="forex_m1", framesize = wsize, )
df22 = get_dataframe( db=db, table_name="forex_m5", framesize = wsize, )
df33 = get_dataframe( db=db, table_name="forex_m15", framesize = wsize, )
print(df11[-1],df22[-1],df33[-1])
img = imagemake( df11, df22, df33)
fn = get_last_time(db=db,table_name = "forex_m1")
shutil.rmtree('datasets/facades2/test/')
os.mkdir('datasets/facades2/test/')
shutil.rmtree("results/facades_pix2pix2/test_latest/images")
os.mkdir("results/facades_pix2pix2/test_latest/images")
fname = "datasets/facades2/test/" + fn.replace(":","_").replace(".","-") + "sk.png"
img.save(fname)
return {"msg",fname}
@app.get("/predict1/")
async def predict1(db:Session = Depends(get_db)):
cmd = 'python test.py --dataroot ./datasets/facades2 --name facades_pix2pix2 --model pix2pix --direction AtoB'
subprocess.check_output(cmd, shell=True)
return {"msg","pass1"}
@app.get("/predict2/")
async def predict2(db:Session = Depends(get_db)):
path = os.getcwd()
new_dir_path = "results/facades_pix2pix2/test_latest/images"
img=[]
image={}
for imageName in os.listdir(new_dir_path):
inputPath = os.path.join(path, new_dir_path,imageName)
if "fake_B" in imageName : image['fakeB']=inputPath
if "real_A" in imageName: image['realA']=inputPath
if "real_B" in imageName: image['realB']=inputPath
if len(image)==3:
ddd=re.findall(r"\d\d\d\d-\d\d-\d\d \d\d_\d\d_\d\d",inputPath)
try:
image['date']=ddd[0].replace("_",":")
img.append(image)
image={}
except:
pass
signal = 0
transform = transforms.PILToTensor()
for item in img:
v2,date = getprice(item,transform,0)
signal =GetSignal(v2)
return {"signal":signal}
MT4
一分ごとに、価格更新して、推論サーバーを呼び出しています。
string URL = "http://127.0.0.1/";
datetime mBeforeBarCreationDateTime;
int OnInit()
{
Init();
return(INIT_SUCCEEDED);
};
int Init(){
string res,str,filename,sep_str[];
datetime m1,m5,m15,current;
int pos;
res = GET(URL + "getlasttime/", "");
pos =StringSplit(res,'\"',sep_str);
Print("pos ",pos);
if(pos!=0)
{
m1 = StringToTime(sep_str[3]);
m5 = StringToTime(sep_str[7]);
m15 = StringToTime(sep_str[11]);
}
SendValue(PERIOD_M1,"forex_m1",m1);
SendValue(PERIOD_M5,"forex_m5",m5);
SendValue(PERIOD_M15,"forex_m15",m15);
return(INIT_SUCCEEDED);
}
void SendValue(string peristr,string forex, datetime end){
for(int i=0; i<100000; i++){
if (end > iTime(NULL,peristr,i))
break;
string data = TimeToString(iTime(NULL,peristr,i))
+","+ forex + "," + DoubleToString(iClose(NULL,peristr,i));
Print(data);
POST(URL + "gettick/", data);
}
}
void On1Minute()
{
string res,str,pos,filename,sep_str[];
Print("On1Minute");
Init();
Print("predict", GET(URL + "predict/", ""));
Print("predict1", GET(URL + "predict1/", ""));
Print("predict2", GET(URL + "predict2/", ""));
}
//! @brief ティック毎の処理
void OnTick()
{
// 最新の1分足のバーの形成開始時刻を取得。
datetime current = iTime(NULL, PERIOD_M1, 0);
// 前のティックでの形成開始時刻と比較。
if (current != mBeforeBarCreationDateTime)
{
// 違うならば前のバーが確定し新しいバーになった=1分毎の更新タイミング。
On1Minute();
// バーの形成開始時刻を更新。
mBeforeBarCreationDateTime = current;
}
};
bool POST(string url, string text){
string headers;
string data;
char post[],result[];
headers = "Content-Type: application/json\r\n";
StringReplace(text, "\n", "\\n");
data = "{\"content\":\""+text+"\"}";
ArrayResize(post,StringToCharArray(data,post,0,WHOLE_ARRAY,CP_UTF8)-1);
int res=WebRequest("POST",url,headers,5000,post,result,headers);
if(res == -1){
Print(__FUNCTION__ + " Error code =",GetLastError(),data);
return(false);
}
Print("POST success! ", CharArrayToString(result, 0, -1));
return(true);
}
string GET(string url, string text){
string headers;
string data,str;
char post[],result[];
headers = "Content-Type: application/json\r\n";
StringReplace(text, "\n", "\\n");
data = "{\"content\":\""+text+"\"}";
ArrayResize(post,StringToCharArray(data,post,0,WHOLE_ARRAY,CP_UTF8)-1);
int res=WebRequest("GET",url,headers,5000,post,result,headers);
if(res == -1){
Print(__FUNCTION__ + " Error code =",GetLastError(),data);
return(false);
}
//--- Receive a link to the image uploaded to the server
str=CharArrayToString(result);
return(str);
さあ、きょうはHi Low binaryの、自動化するぞ~