jiangjiang55 发表于 2024-4-4 08:03:20

期货量化软件:赫兹量化中包装 ONNX 模型

1. 我们会用到什么模型呢?
在之前的投票分类器中,我们用到了一个分类模型和一个回归模型。 在回归模型中,我们在用于计算分类时,用预测价格替代预测价格走势(下跌、上涨、不变)。 然而,在这种情况下,我们不能依据分类得到概率分布,而对于所谓的“软投票”这样是不允许的。
我们已准备了 3 个分类模型。 在“如何在 MQL5 中集成 ONNX 模型的示例”一文中已用到两个模型。 第一个模型(回归)被转换为分类模型。 基于 10 个 OHLC 价格序列进行了培训。 第二个模型是分类模型。 基于 63 个收盘价序列进
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
//--- price movement prediction
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2
//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period         |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
{
protected:
long            m_handle;         // created model session handle
string            m_symbol;         // symbol of trained data
ENUM_TIMEFRAMES   m_period;         // timeframe of trained data
datetime          m_next_bar;         // time of next bar (we work at bar begin only)
double            m_class_delta;      // delta to recognize "price the same" in regression models
public:
//+------------------------------------------------------------------+
//| Constructor                                                      |
//+------------------------------------------------------------------+
CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
{
m_handle=INVALID_HANDLE;
m_symbol=symbol;
m_period=period;
m_next_bar=0;
m_class_delta=class_delta;
}
//+------------------------------------------------------------------+
//| Destructor                                                       |
//| Check for initialization, create model                           |
//+------------------------------------------------------------------+
bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
{
//--- check symbol, period
if(symbol!=m_symbol || period!=m_period)
{
PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
return(false);
}
//--- create a model from static buffer
m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
if(m_handle==INVALID_HANDLE)
{
Print("OnnxCreateFromBuffer error ",GetLastError());
return(false);
}
//--- ok
return(true);
}
//+------------------------------------------------------------------+
m_next_bar=TimeCurrent();
m_next_bar-=m_next_bar%PeriodSeconds(m_period);
m_next_bar+=PeriodSeconds(m_period);
//--- work on new day bar
return(true);
}
//+------------------------------------------------------------------+
//| virtual stub for PredictPrice (regression model)               |
//+------------------------------------------------------------------+
virtual double PredictPrice(void)
{
return(DBL_MAX);
}
//+------------------------------------------------------------------+
//| Predict class (regression -> classification)                     |
//+------------------------------------------------------------------+
virtual int PredictClass(void)
{
double predicted_price=PredictPrice();
if(predicted_price==DBL_MAX)
return(-1);
int    predicted_class=-1;
double last_close=iClose(m_symbol,m_period,1);
//--- classify predicted price movement
double delta=last_close-predicted_price;
if(fabs(delta)<=m_class_delta)
predicted_class=PRICE_SAME;
else
private:
int               m_sample_size;
//+------------------------------------------------------------------+
virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
{
//--- check symbol, period, create model
if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
{
Print("model_eurusd_D1_10_class : initialization error");
return(false);
}
//--- since not all sizes defined in the input tensor we must set them explicitly
//--- first index - batch size, second index - series size, third index - number of series (OHLC)
const long input_shape[] = {1,m_sample_size,4};
if(!OnnxSetInputShape(m_handle,0,input_shape))
{
Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
return(false);
}
//--- since not all sizes defined in the output tensor we must set them explicitly
//--- first index - batch size, must match the batch size of the input tensor
//--- second index - number of classes (up, same or down)
const long output_shape[] = {1,3};
if(!OnnxSetOutputShape(m_handle,0,output_shape))
{
Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
return(false);
}
//--- ok
return(true);
}
//+------------------------------------------------------------------+
//| Predict class                                                    |
//+------------------------------------------------------------------+
virtual int PredictClass(void)
{
页: [1]
查看完整版本: 期货量化软件:赫兹量化中包装 ONNX 模型