LoginSignup
4
5

More than 1 year has passed since last update.

AIトレードシステムのMT4埋め込み(2)

Posted at

サーバー側が完成しました。

ポイントは、PostgreSQLサーバーをDockerで立ち上げ、為替価格をテーブルに格納したことです。
これまでよりも、スムーズに価格リストを引き出せるようになりました。

shell.sh
api
PostgreSQLによる為替価格データベースを構築

%pip install sqlalchemy psycopg2 pandas_datareader mplfinance torch torchvision databases fsspec asyncpg fastapi_crudrouter sqlalchemy_utils --user

PostgreSQLの設定です。
(ほとんどのソースは、書籍から使わせていただきました)

main.py
import os

class DBConfigurations:
    """PostgreSQL connection settings used to build the SQLAlchemy URL.

    Every value can be overridden through an environment variable, the same
    way APIConfigurations reads its settings, so credentials need not live in
    source. The defaults are the previously hard-coded values.
    """

    postgres_username = os.getenv("POSTGRES_USER", "user")
    postgres_password = os.getenv("POSTGRES_PASSWORD", "password")
    # Kept as an int, as before, even when supplied through the environment.
    postgres_port = int(os.getenv("POSTGRES_PORT", "5432"))
    postgres_db = os.getenv("POSTGRES_DB", "model_db")
    postgres_server = os.getenv("POSTGRES_SERVER", "localhost")
    sql_alchemy_database_url = (
        f"postgresql://{postgres_username}:{postgres_password}@{postgres_server}:{postgres_port}/{postgres_db}"
    )
    
class APIConfigurations:
    """FastAPI metadata (title, description, version), overridable via env vars."""

    # Each attribute falls back to its built-in default when the variable is unset.
    title = os.environ.get("API_TITLE", "Model_DB_Service")
    description = os.environ.get(
        "API_DESCRIPTION",
        "machine learning system training patterns",
    )
    version = os.environ.get("API_VERSION", "0.1")

Web APIの場合、DBセッションをリクエストのたびに生成するのですね・・・
(ここで、トークン認証などの処理がどうなるのかが、まだ疑問ですが。。。)

main.py
import os
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# Engine shared by all requests. pool_recycle=3600 makes the pool replace
# connections older than one hour instead of reusing them.
# The former encoding="utf-8" argument was dropped: it has no effect on
# Python 3, is deprecated in SQLAlchemy 1.4 and is not accepted by 2.0.
engine = create_engine(
    DBConfigurations.sql_alchemy_database_url,
    pool_recycle=3600,
    echo=False,
)
# Session factory; get_db() opens one session from it per request.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base that the forex_* ORM models inherit from.
Base = declarative_base()

def get_db():
    """FastAPI dependency that yields one SQLAlchemy session per request.

    If anything is raised while the session is in use (any BaseException,
    matching a bare ``except``), the transaction is rolled back and the
    exception re-raised. The session is closed in every case.
    """
    session = SessionLocal()
    try:
        yield session
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()

為替データのクラスです。
時間足（1分足・5分足・15分足）の数だけテーブルを宣言しています。

main.py
from sqlalchemy import Column, DateTime, ForeignKey, String, Text,Float,Integer
from sqlalchemy.sql.functions import current_timestamp
from sqlalchemy.types import JSON
from sqlalchemy import desc

import subprocess
import os
import shutil

class forex_m1(Base):
    """ORM model for the ``forex_m1`` table (1-minute bars).

    The bar timestamp is the primary key. open/high/low/close are floats and
    volume is an integer.
    """

    __tablename__ = "forex_m1"

    id = Column(DateTime(timezone=True), primary_key=True, index=True)
    open = Column(Float)
    high = Column(Float)
    low = Column(Float)
    close = Column(Float)
    volume = Column(Integer)

class forex_m5(Base):
    """ORM model for the ``forex_m5`` table (5-minute bars).

    The bar timestamp is the primary key. open/high/low/close are floats and
    volume is an integer.
    """

    __tablename__ = "forex_m5"

    id = Column(DateTime(timezone=True), primary_key=True, index=True)
    open = Column(Float)
    high = Column(Float)
    low = Column(Float)
    close = Column(Float)
    volume = Column(Integer)
    
class forex_m15(Base):
    """ORM model for the ``forex_m15`` table (15-minute bars).

    The bar timestamp is the primary key. open/high/low/close are floats and
    volume is an integer.
    """

    __tablename__ = "forex_m15"

    id = Column(DateTime(timezone=True), primary_key=True, index=True)
    open = Column(Float)
    high = Column(Float)
    low = Column(Float)
    close = Column(Float)
    volume = Column(Integer)

立ち上げ時に一度だけ、MT4のヒストリデータから取得したCSVデータを、SQLでぶっこんでいます。

main.py
import pandas as pd
import datetime as dt
from pandas_datareader import data
import mplfinance as mpf
import torch
from torchvision.datasets import ImageFolder
from torchvision import models, transforms
import torch.nn as nn
import numpy as np

import os
# NOTE(review): presumably works around the "duplicate OpenMP runtime" abort
# that can occur when torch and another library each load libiomp — confirm.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Column layout of the MT4 history CSV exports (the files have no header row).
_MT4_CSV_COLUMNS = ('date', 'time', 'open', 'high', 'low', 'close', 'volume')


def _load_mt4_csv(path):
    """Read one MT4 history CSV and index it by its combined date + time.

    Returns a DataFrame with open/high/low/close/volume columns and a
    DatetimeIndex parsed from the ``date`` and ``time`` fields.
    """
    frame = pd.read_csv(path, sep=",", names=_MT4_CSV_COLUMNS)
    frame.index = pd.to_datetime(frame['date'] + " " + frame['time'])
    return frame.drop(['date', 'time'], axis=1)


df1 = _load_mt4_csv(r'C://Users//User//Desktop//EURUSD.oj5k1.csv')
df2 = _load_mt4_csv(r'C://Users//User//Desktop//EURUSD.oj5k5.csv')
df3 = _load_mt4_csv(r'C://Users//User//Desktop//EURUSD.oj5k15.csv')

# One-off bulk load: each table is dropped and recreated from its CSV, with
# the DatetimeIndex written as the "id" column.
# NOTE(review): with if_exists='replace' pandas creates the table itself, so
# the primary key / timezone-aware type declared on the forex_* models is
# presumably not applied — confirm this is intended.
for _frame, _table in ((df1, 'forex_m1'), (df2, 'forex_m5'), (df3, 'forex_m15')):
    _frame.to_sql(_table, con=engine, index=True, index_label='id', if_exists='replace')

ここからが、APIの処理です。
少しSQLAlchemyのサブクエリを使っています(かなり勉強させられました)

最大のポイントは、テーブル名からクラスを引き出すところ・・・

丸2日調べまくりました。
(最後はあっけなくgetメソッドで解決。dir関数でメソッドを調べ尽くしました)

func.py
def select_forex_by_name(db: Session,table_name:str):
    """Return the ORM model class mapped to the table called ``table_name``.

    ``db`` is not used; it is kept so existing call sites keep working.

    Raises:
        ValueError: if ``Base.metadata`` has no such table or no model class
            is mapped to it. Previously ``None`` was returned, which only
            failed later, e.g. as "'NoneType' object is not callable".
    """
    table = Base.metadata.tables.get(table_name)
    model = get_class_by_table(Base, table) if table is not None else None
    if model is None:
        raise ValueError(f"unknown forex table: {table_name!r}")
    return model
main.py
from fastapi import Body, FastAPI
from sqlalchemy_utils import get_class_by_table
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends
from sqlalchemy.sql import func

# FastAPI application. Its OpenAPI title/description/version come from
# APIConfigurations, which was previously defined but never applied.
app = FastAPI(
    title=APIConfigurations.title,
    description=APIConfigurations.description,
    version=APIConfigurations.version,
)

def select_forex_by_name(db: Session,table_name:str):
    """Return the ORM model class mapped to the table called ``table_name``.

    ``db`` is not used; it is kept so existing call sites keep working.

    Raises:
        ValueError: if ``Base.metadata`` has no such table or no model class
            is mapped to it. Previously ``None`` was returned, which only
            failed later, e.g. as "'NoneType' object is not callable".
    """
    table = Base.metadata.tables.get(table_name)
    model = get_class_by_table(Base, table) if table is not None else None
    if model is None:
        raise ValueError(f"unknown forex table: {table_name!r}")
    return model

def get_last_time(db: Session,table_name:str):
    """Return the newest bar timestamp stored in ``table_name``, formatted for MT4.

    The value is ``str(datetime)`` with ``-`` replaced by ``.``
    (e.g. ``2021.01.01 00:00:00``), the form the MQL4 side passes to
    ``StringToTime``.

    Returns ``None`` when the table is empty. As before, it also returns
    ``None`` if the lookup fails (best effort), but the failure is now logged
    instead of being silently swallowed.
    """
    import logging

    try:
        model = select_forex_by_name(db=db, table_name=table_name)
        # MAX(id) directly; the former subquery + re-select yielded the same value.
        latest = db.query(func.max(model.id)).scalar()
    except Exception:
        logging.getLogger(__name__).exception(
            "get_last_time failed for table %r", table_name
        )
        return None
    if latest is None:  # empty table
        return None
    return str(latest).replace("-", ".")
    
@app.get("/getlasttime/")
async def gettime(db: Session = Depends(get_db)):
    """Report the newest stored bar time per timeframe under keys m1/m5/m15."""
    latest = {}
    for period in ("m1", "m5", "m15"):
        latest[period] = get_last_time(db=db, table_name="forex_" + period)
    return latest

価格更新の部分です。
基本はモデルクラスを作って、ぶっこめばイイんですね。
さっきの関数が大活躍です。

main.py
from typing import Dict, List, Optional

def add_forex(
    db: Session,
    table_name: str,
    time: Optional[str] = None,
    value: Optional[str] = None,
    commit: bool = True,
):
    """Insert one bar row containing only its timestamp and close price.

    ``time`` becomes the ``id`` (primary key) and ``value`` the ``close``
    column; open/high/low/volume are left unset. When ``commit`` is true the
    row is committed and refreshed from the database. Returns the ORM object.
    """
    model = select_forex_by_name(db=db, table_name=table_name)
    row = model(id=time, close=value)
    db.add(row)
    if not commit:
        return row
    db.commit()
    db.refresh(row)
    return row

@app.post("/gettick/")
async def gettick(db: Session = Depends(get_db), body=Body(...)):
    """Store one bar posted by MT4 as ``"<time>,<table>,<close>"`` in body["content"].

    Responds with the name of the table that was written to.
    """
    payload = body["content"]
    time, peristr, value = payload.split(",")
    add_forex(db=db, table_name=peristr, time=time, value=value, commit=True)
    return {"msg": peristr}

ここからは推論の部分です。
pix2pixの呼び出しがあるので、APIを3段階に分けています。

他はほとんど、これまでの投稿と同じです。

main.py
from sqlalchemy import desc
import re

def get_dataframe( db: Session,    table_name: str,    framesize: int,):
    """Return the close prices of the newest ``framesize`` bars, oldest first.

    Returns a list of floats ([] for an empty table). As before, ``None`` is
    returned if the query fails (best effort), but the failure is now logged
    instead of being silently swallowed.
    """
    import logging

    try:
        model = select_forex_by_name(db=db, table_name=table_name)
        # Fetch the newest N rows and reverse them in Python into chronological
        # order. The former self-join against a LIMIT subquery gave the same
        # order, but multiplied rows whenever a timestamp was stored twice.
        newest = db.query(model).order_by(desc(model.id)).limit(framesize).all()
        return [float(row.close) for row in reversed(newest)]
    except Exception:
        logging.getLogger(__name__).exception(
            "get_dataframe failed for table %r", table_name
        )
        return None

def _reset_dir(path):
    """Delete ``path`` and its contents if it exists, then recreate it empty."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


@app.get("/predict/")
async def predict(db:Session = Depends(get_db)):
    """Step 1 of 3: render the latest bars into the pix2pix input image.

    Reads the newest ``wsize`` closes from the M1/M5/M15 tables and builds an
    image with ``imagemake``. The image is saved under
    ``datasets/facades2/test/``, named after the newest M1 timestamp
    (``:`` -> ``_``, ``.`` -> ``-``) so /predict2/ can recover the date from
    the result file names. The pix2pix input and result directories are
    emptied first, so only this sample is processed.

    Returns ``{"msg": <saved path>}`` (previously a set literal was returned
    by mistake). Responds with HTTP 503 when price data is unavailable,
    instead of failing with a TypeError/IndexError.
    """
    from fastapi import HTTPException

    wsize = 96

    df11 = get_dataframe(db=db, table_name="forex_m1", framesize=wsize)
    df22 = get_dataframe(db=db, table_name="forex_m5", framesize=wsize)
    df33 = get_dataframe(db=db, table_name="forex_m15", framesize=wsize)
    # get_dataframe() yields None on failure and [] for an empty table.
    if not (df11 and df22 and df33):
        raise HTTPException(status_code=503, detail="price data unavailable")

    print(df11[-1], df22[-1], df33[-1])

    img = imagemake(df11, df22, df33)

    fn = get_last_time(db=db, table_name="forex_m1")
    if fn is None:
        raise HTTPException(status_code=503, detail="latest M1 time unavailable")

    # Directories may not exist on the first run; recreate them either way.
    _reset_dir('datasets/facades2/test/')
    _reset_dir("results/facades_pix2pix2/test_latest/images")

    fname = "datasets/facades2/test/" + fn.replace(":", "_").replace(".", "-") + "sk.png"
    img.save(fname)

    return {"msg": fname}
        
@app.get("/predict1/")
async def predict1(db:Session = Depends(get_db)):
    """Step 2 of 3: run pix2pix ``test.py`` on the image written by /predict/.

    The inference is a blocking child process. It now runs in the default
    thread-pool executor, so it no longer stalls the event loop (and every
    other request) while it runs. ``check_output`` still raises
    ``CalledProcessError`` on a non-zero exit.

    Returns ``{"msg": "pass1"}`` (previously a set literal was returned by mistake).
    """
    import asyncio

    # Argument list instead of a shell string: no shell is needed to start Python.
    cmd = [
        'python', 'test.py',
        '--dataroot', './datasets/facades2',
        '--name', 'facades_pix2pix2',
        '--model', 'pix2pix',
        '--direction', 'AtoB',
    ]
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, subprocess.check_output, cmd)

    return {"msg": "pass1"}
    
# Timestamp that /predict/ embeds in the image name (":" written as "_").
_RESULT_DATE_RE = re.compile(r"\d\d\d\d-\d\d-\d\d \d\d_\d\d_\d\d")
# (filename marker, key in the per-sample dict handed to getprice()).
_RESULT_KINDS = (("fake_B", "fakeB"), ("real_A", "realA"), ("real_B", "realB"))


def _collect_result_images(directory):
    """Group the pix2pix result images in ``directory`` by embedded timestamp.

    Returns a list, oldest first, of dicts that map ``fakeB``/``realA``/``realB``
    to file paths and ``date`` to the timestamp (``_`` turned back into ``:``).
    Samples missing any of the three images are skipped. Grouping no longer
    relies on ``os.listdir`` returning a sample's files consecutively.
    """
    groups = {}
    for name in os.listdir(directory):
        match = _RESULT_DATE_RE.search(name)
        if match is None:
            continue
        entry = groups.setdefault(match.group(0), {})
        for marker, key in _RESULT_KINDS:
            if marker in name:
                entry[key] = os.path.join(directory, name)

    samples = []
    for stamp in sorted(groups):  # "YYYY-MM-DD HH_MM_SS" sorts chronologically
        entry = groups[stamp]
        if all(key in entry for _, key in _RESULT_KINDS):
            entry['date'] = stamp.replace("_", ":")
            samples.append(entry)
    return samples


@app.get("/predict2/")
async def predict2(db:Session = Depends(get_db)):
    """Step 3 of 3: turn the pix2pix results into a trading signal.

    Runs ``getprice``/``GetSignal`` on every complete sample and returns the
    signal of the newest one as ``{"signal": ...}`` (0 when none is complete).
    """
    results_dir = os.path.join(os.getcwd(), "results/facades_pix2pix2/test_latest/images")

    signal = 0
    transform = transforms.PILToTensor()

    for item in _collect_result_images(results_dir):
        v2, _date = getprice(item, transform, 0)
        signal = GetSignal(v2)

    return {"signal": signal}

MT4

一分ごとに、価格更新して、推論サーバーを呼び出しています。

main.c
// Base URL of the FastAPI server. No port is given, so HTTP's default port 80
// is used. NOTE(review): confirm the server actually listens there.
string URL = "http://127.0.0.1/";
// Open time of the M1 bar seen on the previous tick; OnTick() compares the
// current bar against it to detect that a new minute has started.
datetime mBeforeBarCreationDateTime;

// EA start-up: bring the server's price tables up to date once.
int OnInit()
{
   Init();
   return INIT_SUCCEEDED;
}


// Synchronise the server with the local history. Ask /getlasttime/ for the
// newest stored bar time of each timeframe, then upload every bar from that
// time onward via SendValue().
// The reply {"m1":"<t>","m5":"<t>","m15":"<t>"} split on '"' yields 13
// pieces, with the three times at indices 3, 7 and 11. Any other shape (a GET
// failure, or null for a table the server could not read) is now reported and
// skipped. Previously it indexed past the end of sep_str ("array out of
// range") or uploaded with possibly uninitialised times.
int Init()
{
   string sep_str[];

   string res = GET(URL + "getlasttime/", "");
   int pos = StringSplit(res, '\"', sep_str);
   Print("pos ", pos);

   if (pos < 13)
   {
      Print(__FUNCTION__, ": unexpected getlasttime reply, skipping upload: ", res);
      return(INIT_SUCCEEDED);
   }

   datetime m1  = StringToTime(sep_str[3]);
   datetime m5  = StringToTime(sep_str[7]);
   datetime m15 = StringToTime(sep_str[11]);

   SendValue(PERIOD_M1, "forex_m1", m1);
   SendValue(PERIOD_M5, "forex_m5", m5);
   SendValue(PERIOD_M15, "forex_m15", m15);

   return(INIT_SUCCEEDED);
}

// Upload bars of `timeframe` to /gettick/ as "<time>,<table>,<close>",
// newest first, stopping at the first bar older than `end`. The bar stamped
// exactly `end` is therefore sent again.
// `timeframe` was previously declared `string`, although callers pass
// PERIOD_* constants and iTime()/iClose() take an int timeframe.
void SendValue(int timeframe, string forex, datetime end)
{
   // Bound the walk by the bars actually available (capped at the old
   // 100000 limit). Past the history, iTime() yields no valid bar time, so
   // with end == 0 the old fixed-length loop posted bogus bars.
   int bars = iBars(NULL, timeframe);
   if (bars > 100000)
      bars = 100000;

   for (int i = 0; i < bars; i++)
   {
      datetime barTime = iTime(NULL, timeframe, i);
      if (end > barTime)
         break;

      string data = TimeToString(barTime) + "," + forex + "," + DoubleToString(iClose(NULL, timeframe, i));
      Print(data);
      POST(URL + "gettick/", data);
   }
}

// Runs once per new M1 bar: sync prices, then call the three prediction steps.
void On1Minute()
{
   Print("On1Minute");

   // Push any bars newer than what the server already stores.
   Init();

   // The server splits inference into three calls: build the input image,
   // run pix2pix, then derive the signal.
   Print("predict", GET(URL + "predict/", ""));
   Print("predict1", GET(URL + "predict1/", ""));
   Print("predict2", GET(URL + "predict2/", ""));
}


//! @brief  Per-tick handler: calls On1Minute() whenever a new M1 bar opens.
void OnTick()
{
   // Open time of the newest (still forming) M1 bar.
   datetime barOpen = iTime(NULL, PERIOD_M1, 0);

   // Same bar as on the previous tick: nothing to do yet.
   if (barOpen == mBeforeBarCreationDateTime)
      return;

   // A new bar has started, i.e. the previous one is final.
   On1Minute();

   // Remember this bar so the next tick in it is ignored.
   mBeforeBarCreationDateTime = barOpen;
}




// POST `text` to `url` wrapped as {"content":"<text>"}.
// Returns true when WebRequest succeeds (the response is printed), false otherwise.
bool POST(string url, string text)
{
   string headers = "Content-Type: application/json\r\n";

   // Escape newlines so the text stays inside one JSON string value.
   StringReplace(text, "\n", "\\n");
   string payload = "{\"content\":\"" + text + "\"}";

   char body[], response[];
   // StringToCharArray also copies the terminating NUL; drop it from the body.
   int copied = StringToCharArray(payload, body, 0, WHOLE_ARRAY, CP_UTF8);
   ArrayResize(body, copied - 1);

   int status = WebRequest("POST", url, headers, 5000, body, response, headers);
   if (status == -1)
   {
      Print(__FUNCTION__ + " Error code =", GetLastError(), payload);
      return(false);
   }

   Print("POST success! ", CharArrayToString(response, 0, -1));
   return(true);
}

// Send a GET request to `url`, carrying {"content":"<text>"} as its body,
// and return the response body as a string, or "" if the request fails.
// (The original returned `false` from this string function and lacked its
// closing brace.)
string GET(string url, string text)
{
   string headers = "Content-Type: application/json\r\n";
   char post[], result[];

   // Escape newlines so the text stays inside one JSON string value.
   StringReplace(text, "\n", "\\n");
   string data = "{\"content\":\"" + text + "\"}";
   // StringToCharArray also copies the terminating NUL; drop it from the body.
   ArrayResize(post, StringToCharArray(data, post, 0, WHOLE_ARRAY, CP_UTF8) - 1);

   int res = WebRequest("GET", url, headers, 5000, post, result, headers);
   if (res == -1)
   {
      Print(__FUNCTION__ + " Error code =", GetLastError(), data);
      return("");
   }

   return(CharArrayToString(result));
}

さあ、きょうはHi Low binaryの、自動化するぞ~

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5