1. 我们会用到什么模型呢?
在之前的投票分类器中,我们用到了一个分类模型和一个回归模型。回归模型输出的是预测价格,而不是计算分类时所需的预测价格走势(下跌、上涨、不变)。因此,这种情况下我们无法得到各类别的概率分布,也就无法进行所谓的“软投票”。
我们已准备了 3 个分类模型。其中两个模型已在“如何在 MQL5 中集成 ONNX 模型的示例”一文中用过。第一个模型(回归模型)被转换为分类模型,它基于 10 个 OHLC 价格序列进行了训练。第二个模型是分类模型,它基于 63 个收盘价序列进行了训练。
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
//--- price movement prediction: class indices produced by PredictClass
//--- (order presumably matches the models' 3-class output: up, same, down — TODO confirm)
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2
//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period         |
//+------------------------------------------------------------------+
) K3 b! u/ [+ c% m7 z0 M- Gclass CModelSymbolPeriod8 x& P; g0 D7 i n
{
) ]9 u7 h8 e1 h$ c1 U3 sprotected:# S$ e) w2 O" Z C, G E$ H$ z
long m_handle; // created model session handle. `; O, |& T4 \1 f
string m_symbol; // symbol of trained data+ w' u: k. h+ p2 k1 t$ V4 a
ENUM_TIMEFRAMES m_period; // timeframe of trained data
4 m: o ~! y6 l+ {) Edatetime m_next_bar; // time of next bar (we work at bar begin only)" y. n2 J" G( K% @
double m_class_delta; // delta to recognize "price the same" in regression models
3 m, y+ u! D1 a" G6 Vpublic:
//+------------------------------------------------------------------+
//| Constructor                                                      |
//| symbol, period - symbol and timeframe the model was trained on   |
//| class_delta    - max |last close - predicted price| still treated |
//|                  as PRICE_SAME when a regression model is         |
//|                  converted to a class prediction                  |
//| No ONNX session is created here; m_handle starts invalid.        |
//+------------------------------------------------------------------+
CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
   : m_handle(INVALID_HANDLE),
     m_symbol(symbol),
     m_period(period),
     m_next_bar(0),
     m_class_delta(class_delta)
  {
  }
//+------------------------------------------------------------------+
//| Destructor                                                       |
//| NOTE(review): the destructor body is missing from this listing;  |
//| confirm the ONNX session (m_handle) is released with OnnxRelease.|
//+------------------------------------------------------------------+
//+------------------------------------------------------------------+
//| Check for initialization, create model                           |
//| symbol, period - must equal the symbol/timeframe the model was   |
//|                  trained on (m_symbol, m_period)                  |
//| model[]        - ONNX model bytes passed to OnnxCreateFromBuffer  |
//| Returns true when the session handle was created successfully.   |
//+------------------------------------------------------------------+
bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
  {
//--- check symbol, period
   if(symbol!=m_symbol || period!=m_period)
     {
      PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
      return(false);
     }
//--- a repeated initialization must not leak the previously created session
   if(m_handle!=INVALID_HANDLE)
     {
      OnnxRelease(m_handle);
      m_handle=INVALID_HANDLE;
     }
//--- create a model from static buffer
   m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
   if(m_handle==INVALID_HANDLE)
     {
      Print("OnnxCreateFromBuffer error ",GetLastError());
      return(false);
     }
//--- ok
   return(true);
  }
" k& r) X; W3 d+ h//+------------------------------------------------------------------+
( E" I7 I: ]0 {. C# ^m_next_bar=TimeCurrent();
4 z; b( ~/ w* ?& Cm_next_bar-=m_next_bar%PeriodSeconds(m_period);1 h2 a i0 t8 o2 B% R. F; L
m_next_bar+=PeriodSeconds(m_period);+ {0 S& V- g" X& R# c
//--- work on new day bar5 R$ S5 v3 C4 r, c
return(true);
. s6 v2 i) y6 d) X2 F& M( N}
//+------------------------------------------------------------------+
//| virtual stub for PredictPrice (regression model)                 |
//| Returns DBL_MAX, which PredictClass interprets as "no prediction"|
//| and maps to class -1. Presumably overridden by regression models.|
//+------------------------------------------------------------------+
virtual double PredictPrice(void)
  {
   return(DBL_MAX);
  }
//+------------------------------------------------------------------+
" b! Z" n6 i# k- x" r* v//| Predict class (regression -> classification) |4 O$ m9 d6 p0 C* I2 w
//+------------------------------------------------------------------+
; S$ A" Q; e% ~- Cvirtual int PredictClass(void)
4 j4 v" D3 C4 r3 q7 R, l+ k{! ~$ ^# {0 m! j. t
double predicted_price=PredictPrice();
' _( A& v& x; Q0 \: y# U" Cif(predicted_price==DBL_MAX)& Q/ b) Z& ` b& p0 W% [* V
return(-1);
! T g, ~0 w2 \; C$ I+ Xint predicted_class=-1;$ i* ~- Q/ |4 h4 R2 q
double last_close=iClose(m_symbol,m_period,1);
( _! ~& D& |" {//--- classify predicted price movement
9 s3 U8 f. R& \* odouble delta=last_close-predicted_price;$ N; t6 Q1 u W+ i% Y
if(fabs(delta)<=m_class_delta)% d5 @3 l1 t2 I8 B
predicted_class=PRICE_SAME; A$ D' u. y- V
else
@/ |: |: O! ~8 o$ R5 ?; b$ Hprivate:8 L+ b8 L% C: A9 ^' y% Q- J
int m_sample_size;
//+------------------------------------------------------------------+
//| Initialize the model: verify symbol/period, create the ONNX      |
//| session from model_eurusd_D1_10_class and set the tensor shapes  |
//| that the model leaves undefined.                                 |
//| Returns false on any failure; if the session was already created,|
//| it is released and m_handle is reset to INVALID_HANDLE.          |
//+------------------------------------------------------------------+
virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
  {
//--- check symbol, period, create model
   if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
     {
      Print("model_eurusd_D1_10_class : initialization error");
      return(false);
     }
//--- since not all sizes defined in the input tensor we must set them explicitly
//--- first index - batch size, second index - series size, third index - number of series (OHLC)
   const long input_shape[] = {1,m_sample_size,4};
   if(!OnnxSetInputShape(m_handle,0,input_shape))
     {
      //--- report the error first: GetLastError must be read before any other call
      Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
      //--- do not leak the session created by CheckInit
      OnnxRelease(m_handle);
      m_handle=INVALID_HANDLE;
      return(false);
     }
//--- since not all sizes defined in the output tensor we must set them explicitly
//--- first index - batch size, must match the batch size of the input tensor
//--- second index - number of classes (up, same or down)
   const long output_shape[] = {1,3};
   if(!OnnxSetOutputShape(m_handle,0,output_shape))
     {
      Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
      //--- do not leak the session created by CheckInit
      OnnxRelease(m_handle);
      m_handle=INVALID_HANDLE;
      return(false);
     }
//--- ok
   return(true);
  }
! } T K5 d/ B) b, i6 u//+------------------------------------------------------------------+, j: }& D: x5 t7 c }; q) G
//| Predict class |
9 C0 {& p9 G/ K7 X7 Q//+------------------------------------------------------------------+
; q* K# w7 i+ `" Zvirtual int PredictClass(void)
1 ?2 s8 a3 Y7 B" P{ |