期货量化软件:在赫兹量化中用类包装 ONNX 模型
1. 我们会用到哪些模型?在之前的投票分类器中,我们用到了一个分类模型和一个回归模型。对于回归模型,在计算分类时我们使用的是预测价格,而不是预测的价格走势(下跌、上涨、不变)。然而,这样就无法得到各分类的概率分布,因而也就无法进行所谓的“软投票”。
我们已准备了 3 个分类模型。其中两个已在“如何在 MQL5 中集成 ONNX 模型的示例”一文中用到过。第一个模型(原为回归模型)已被转换为分类模型,它基于 10 个 OHLC 价格序列进行了训练。第二个模型是分类模型,基于 63 个收盘价序列进行了训练。
//| https://www.mql5.com |
//+------------------------------------------------------------------+
//--- price movement prediction classes returned by PredictClass
//--- (presumably also the order of the classifier's three outputs — confirm against the model)
#define PRICE_UP   (0)   // predicted price movement: up
#define PRICE_SAME (1)   // predicted movement within the "same price" delta
#define PRICE_DOWN (2)   // predicted price movement: down
//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
{
protected:
long m_handle; // created model session handle
string m_symbol; // symbol of trained data
ENUM_TIMEFRAMES m_period; // timeframe of trained data
datetime m_next_bar; // time of next bar (we work at bar begin only)
double m_class_delta; // delta to recognize "price the same" in regression models
public:
//+------------------------------------------------------------------+
//| Constructor                                                      |
//| Remembers the symbol/timeframe the model was trained on and the  |
//| delta used to map a regression prediction to "price the same".   |
//| No ONNX session is created here (m_handle starts invalid).       |
//+------------------------------------------------------------------+
CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001) :
   m_handle(INVALID_HANDLE),
   m_symbol(symbol),
   m_period(period),
   m_next_bar(0),
   m_class_delta(class_delta)
  {
  }
//+------------------------------------------------------------------+
//| Check for initialization, create model                           |
//| Verifies that the requested symbol/timeframe match the ones the  |
//| model was trained on, then creates an ONNX session from the      |
//| given buffer and stores it in m_handle.                          |
//| A session left over from a previous call is released first, so   |
//| repeated initialization does not leak ONNX handles.              |
//| symbol - symbol the caller wants to run the model on             |
//| period - timeframe the caller wants to run the model on          |
//| model  - buffer holding the ONNX model                           |
//| returns true on success, false on mismatch or creation error     |
//+------------------------------------------------------------------+
bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
  {
   //--- check symbol, period
   if(symbol!=m_symbol || period!=m_period)
     {
      PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
      return(false);
     }
   //--- release a session created by an earlier call;
   //--- overwriting m_handle below would otherwise leak it
   if(m_handle!=INVALID_HANDLE)
     {
      OnnxRelease(m_handle);
      m_handle=INVALID_HANDLE;
     }
   //--- create a model from static buffer
   m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
   if(m_handle==INVALID_HANDLE)
     {
      Print("OnnxCreateFromBuffer error ",GetLastError());
      return(false);
     }
   //--- ok
   return(true);
  }
//+------------------------------------------------------------------+
m_next_bar=TimeCurrent();
m_next_bar-=m_next_bar%PeriodSeconds(m_period);
m_next_bar+=PeriodSeconds(m_period);
//--- work on new day bar
return(true);
}
//+------------------------------------------------------------------+
//| Default PredictPrice for the base class                          |
//| Always returns DBL_MAX, meaning "no price prediction available";  |
//| intended to be overridden by regression models. PredictClass     |
//| treats DBL_MAX as a failure and returns -1.                      |
//+------------------------------------------------------------------+
virtual double PredictPrice(void)
  {
   const double no_prediction=DBL_MAX;
   return(no_prediction);
  }
//+------------------------------------------------------------------+
//| Predict class (regression -> classification) |
//+------------------------------------------------------------------+
virtual int PredictClass(void)
{
double predicted_price=PredictPrice();
if(predicted_price==DBL_MAX)
return(-1);
int predicted_class=-1;
double last_close=iClose(m_symbol,m_period,1);
//--- classify predicted price movement
double delta=last_close-predicted_price;
if(fabs(delta)<=m_class_delta)
predicted_class=PRICE_SAME;
else
private:
int m_sample_size;
//+------------------------------------------------------------------+
//| Initialize model_eurusd_D1_10_class                              |
//| Creates the ONNX session through CheckInit, then fixes the input |
//| tensor shape {1, m_sample_size, 4} and output shape {1, 3}.      |
//| If setting a shape fails, the session just created is released  |
//| and m_handle is reset, so a failed Init leaves no open handle.   |
//| NOTE(review): m_sample_size is assumed to be set by this class's |
//| constructor (not visible here) — confirm.                        |
//| returns true on success, false on any error                      |
//+------------------------------------------------------------------+
virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
  {
   //--- check symbol, period, create model
   if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
     {
      Print("model_eurusd_D1_10_class : initialization error");
      return(false);
     }
   //--- since not all sizes defined in the input tensor we must set them explicitly
   //--- first index - batch size, second index - series size, third index - number of series (OHLC)
   const long input_shape[] = {1,m_sample_size,4};
   if(!OnnxSetInputShape(m_handle,0,input_shape))
     {
      //--- report the error before releasing so GetLastError is not overwritten
      Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
      OnnxRelease(m_handle);
      m_handle=INVALID_HANDLE;
      return(false);
     }
   //--- since not all sizes defined in the output tensor we must set them explicitly
   //--- first index - batch size, must match the batch size of the input tensor
   //--- second index - number of classes (up, same or down)
   const long output_shape[] = {1,3};
   if(!OnnxSetOutputShape(m_handle,0,output_shape))
     {
      //--- report the error before releasing so GetLastError is not overwritten
      Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
      OnnxRelease(m_handle);
      m_handle=INVALID_HANDLE;
      return(false);
     }
   //--- ok
   return(true);
  }
//+------------------------------------------------------------------+
//| Predict class |
//+------------------------------------------------------------------+
virtual int PredictClass(void)
{
页:
[1]