1. 我们会用到什么模型呢?
在之前的投票分类器中,我们用到了一个分类模型和一个回归模型。 在回归模型中,我们在用于计算分类时,用预测价格替代预测价格走势(下跌、上涨、不变)。 然而,在这种情况下,我们不能依据分类得到概率分布,而对于所谓的“软投票”这样是不允许的。
我们已准备了 3 个分类模型。 在“如何在 MQL5 中集成 ONNX 模型的示例”一文中已用到两个模型。 第一个模型(回归)被转换为分类模型。 基于 10 个 OHLC 价格序列进行了培训。 第二个模型是分类模型。 基于 63 个收盘价序列进
//| https://www.mql5.com |
//+------------------------------------------------------------------+
//--- class codes for the predicted price movement
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2
//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period         |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
  {
protected:
   long              m_handle;      // created model session handle
   string            m_symbol;      // symbol of trained data
   ENUM_TIMEFRAMES   m_period;      // timeframe of trained data
   datetime          m_next_bar;    // time of next bar (we work at bar begin only)
   double            m_class_delta; // delta to recognize "price the same" in regression models

public:
//+------------------------------------------------------------------+
//| Constructor                                                      |
//+------------------------------------------------------------------+
//--- symbol      - symbol the model was trained on (stored in m_symbol)
//--- period      - timeframe the model was trained on (stored in m_period)
//--- class_delta - delta used to recognize "price the same" in regression models
//--- No model session is created here: m_handle starts as INVALID_HANDLE
//--- and is only assigned by CheckInit().
CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
  {
   m_handle=INVALID_HANDLE;
   m_symbol=symbol;
   m_period=period;
   m_next_bar=0;
   m_class_delta=class_delta;
  }
//+------------------------------------------------------------------+
//| Check for initialization, create model                           |
//+------------------------------------------------------------------+
//--- NOTE(review): a "Destructor" banner title preceded this method, but no
//--- destructor body is present in this excerpt - confirm that the class
//--- releases m_handle somewhere.
//--- symbol, period - symbol and timeframe the caller wants to run on; they must
//---                  equal the m_symbol / m_period the model was constructed with
//--- model          - buffer holding the ONNX model
//--- Returns true when the symbol/period match and the model session was created;
//--- otherwise prints a diagnostic and returns false.
bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
  {
//--- check symbol, period
   if(symbol!=m_symbol || period!=m_period)
     {
      PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
      return(false);
     }
//--- create a model from static buffer
   m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
   if(m_handle==INVALID_HANDLE)
     {
      Print("OnnxCreateFromBuffer error ",GetLastError());
      return(false);
     }
//--- ok
   return(true);
  }
//+------------------------------------------------------------------+
m_next_bar=TimeCurrent();2 X) J& h0 t# w. f( ?
m_next_bar-=m_next_bar%PeriodSeconds(m_period);/ O8 u( h0 q! w, U7 u9 T
m_next_bar+=PeriodSeconds(m_period);
) I% t# c. s2 b" d' l//--- work on new day bar" e8 ~) ?8 V- h- y( h% ]; f
return(true);
5 Y; I! p! u& X, S}4 |. c$ l, r" w! \% V
//+------------------------------------------------------------------+
//| virtual stub for PredictPrice (regression model)                 |
//+------------------------------------------------------------------+
//--- The base implementation makes no prediction and always returns DBL_MAX.
virtual double PredictPrice(void)
  {
   return(DBL_MAX);
  }
//+------------------------------------------------------------------+
//| Predict class (regression -> classification) |
//+------------------------------------------------------------------+
virtual int PredictClass(void)
9 b4 _9 O2 t% U1 l* a8 T{
0 m1 H; I' x& E, x2 {' ydouble predicted_price=PredictPrice();
' f. j- v/ q$ m4 W, N! r3 u' Nif(predicted_price==DBL_MAX)
3 `( J# Z1 a# D' y4 X+ @- }return(-1);
" w: E0 T9 [, W- ]/ Hint predicted_class=-1;& j+ ^; ~, W$ D( h* T
double last_close=iClose(m_symbol,m_period,1);" g$ G) \2 ]- \ \
//--- classify predicted price movement5 { w+ `# a! o- p2 e' Z8 j2 ~
double delta=last_close-predicted_price;
" j- l1 S9 U) {# V! C/ T; {2 tif(fabs(delta)<=m_class_delta)2 A4 [# c" }, f8 U. C, e$ l9 D
predicted_class=PRICE_SAME;
: P9 [3 s8 [* x2 M2 Welse {; C' t+ h( F
private:
+ @8 y; p3 x" }' l# o7 y% aint m_sample_size;
//+------------------------------------------------------------------+
//| Check symbol/period, create the model and set tensor shapes      |
//+------------------------------------------------------------------+
//--- symbol, period - passed to CModelSymbolPeriod::CheckInit for validation
//--- Returns true when the model session was created and both tensor shapes
//--- were accepted; otherwise prints a diagnostic and returns false.
virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
  {
//--- check symbol, period, create model
   if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
     {
      Print("model_eurusd_D1_10_class : initialization error");
      return(false);
     }
//--- since not all sizes defined in the input tensor we must set them explicitly
//--- first index - batch size, second index - series size, third index - number of series (OHLC)
   const long input_shape[] = {1,m_sample_size,4};
   if(!OnnxSetInputShape(m_handle,0,input_shape))
     {
      Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
      return(false);
     }
//--- since not all sizes defined in the output tensor we must set them explicitly
//--- first index - batch size, must match the batch size of the input tensor
//--- second index - number of classes (up, same or down)
   const long output_shape[] = {1,3};
   if(!OnnxSetOutputShape(m_handle,0,output_shape))
     {
      Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
      return(false);
     }
//--- ok
   return(true);
  }
//+------------------------------------------------------------------+
//| Predict class |
//+------------------------------------------------------------------+
9 m3 Q4 B4 ~8 \/ I4 k( i; R7 _virtual int PredictClass(void)
9 A4 G3 x, P5 y% G% I; X{ |