程序化交易吧 关注:3,286贴子:8,504
  • 0回复贴,共1

期货_期货量化交易软件:利用回归衡量度评估 ONNX 模型

只看楼主收藏回复

概述回归是一项依据未标记样本预测真实数值的任务。 一个著名的回归例子是基于大小、重量、颜色、净度、等特征来估算钻石的价值。而所谓的回归衡量度则是用来评估回归模型的预测准确性。 尽管算法相似,回归衡量度在语义上与相似的损失函数有所区别。 了解它们之间的区别很重要。 它可按如下方式表述:
当赫兹将构建模型的问题降解为优化问题时,此刻损失函数就会浮现。 通常要求它具有良好的性质(例如,可微性)
衡量度是外部客观品质准则,通常不依赖于模型参数,而仅建立在预测值之上。
MQL5 中的回归衡量度MQL5 语言具有以下衡量度:
平均绝对误差,MAE
均方误差,MSE
均方根误差,RMSE
R-平方,R2
平均绝对百分比误差,MAPE
均方百分比误差,MSPE
均方根对数误差,RMSLE
预计 MQL5 中回归衡量度的数量还会增加。
回归衡量度的简明特征MAE 估算绝对误差 — 预测数字偏离实际数字的程度。 误差的测量单位与目标函数的数值相同。 错误值基于可能的数值范围进行解释。 例如,如果目标值在 1 到 1.5 的范围内,则平均绝对误差值为 10 就是一个非常大的误差;而对于 10000...15000 的范围,那就可以接受。 它不适合针对扩散较大的数值评估预测。在 MSE 中,由于平方,每个误差都有自己的权重。 由此,预测与现实之间的巨大差异更加明显。RMSE 具有与 MSE 相同的优点,但更易于理解,因为误差的测量单位与目标函数的数值相同。 它对异常和尖峰非常敏感。 MAE 和 RMSE 可以一起使用,来检测一组预测中的误差变化。 RMSE 始终大于或等于 MAE。 它们之间的差值越大,样本中独立误差的扩散就越大。 如果 RMSE = MAE,则所有误差的幅度相同。R2 — 判定率表示两个随机变量之间的关系强度。 它有助于判定模型能够解释的数据多样性的份额。 如果模型始终准确预测,则衡量度为 1。 对于泛泛的模型,它是 0。 如果模型预测比泛泛者还糟糕,同时模型不遵循数据趋势,则衡量度数值可能为负值。MAPE — 误差没有量纲,非常容易解释。 它既可以表示为小数,也可以表示为百分比。 在 MQL5 中,它以小数表达。 例如,值 0.1 表示误差为真实数值的 10%。 该衡量度背后的思想是对于相对偏差的敏感性。 它不适用于需要使用真实测量单位的任务。MSPE 可被认为是 MSE 的加权版本,其中权重与观测值的平方成反比。 因此,随着观测值的递增,误差趋于递减。RMSLE 用于实际值超出若干个数量级时。 根据定义,预测值和实际观测值不能为负值。计算上述所有衡量度的算法在源文件 VectorRegressionMetric.mqh 中提供。
ONNX 模型赫兹量化用到 4 个回归模型,依据日线的前一根柱线预测当天的收盘价(EURUSD, D1)。 我们在之前的文章中研究了这些模型:“在类中包装 ONNX 模型”、“如何在 MQL5 中集成 ONNX 模型的示例”,以及 “如何在 MQL5 中使用 ONNX 模型”。 故此,我们不会在此重复训练模型所用的规则。 训练所有模型的脚本位于本文随附的 zip 存档的 Python 子文件夹之中。 经训练的 onnx 模型 — model.eurusd.D1.10、model.eurusd.D1.30、model.eurusd.D1.52 和 model.eurusd.D1.63 也位于那里。添加图片注释,不超过 140 字(可选)在类中包装 ONNX 模型在上一篇文章中,赫兹量化展示了 ONNX 模型的基类和分类模型的派生类。 赫兹量化已针对基类进行了一些小的修改,令其更加灵活。//+------------------------------------------------------------------+//| ModelSymbolPeriod.mqh |//| Copyright 2023, MetaQuotes Ltd. |//| https://www.mql5.com |//+------------------------------------------------------------------+//--- price movement prediction#define PRICE_UP 0#define PRICE_SAME 1#define PRICE_DOWN 2//+------------------------------------------------------------------+//| Base class for models based on trained symbol and period |//+------------------------------------------------------------------+class CModelSymbolPeriod {protected: string m_name; // model name long m_handle; // created model session handle string m_symbol; // symbol of trained data ENUM_TIMEFRAMES m_period; // timeframe of trained data datetime m_next_bar; // time of next bar (we work at bar begin only) double m_class_delta; // delta to recognize "price the same" in regression modelspublic: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001) { m_name=""; m_handle=INVALID_HANDLE; m_symbol=symbol; m_period=period; m_next_bar=0; m_class_delta=class_delta; } //+------------------------------------------------------------------+ //| Destructor | //+------------------------------------------------------------------+ ~CModelSymbolPeriod(void) { Shutdown(); } //+------------------------------------------------------------------+ //| | //+------------------------------------------------------------------+ string GetModelName(void) { return(m_name); } //+------------------------------------------------------------------+ //| virtual stub for Init | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { return(false); } //+------------------------------------------------------------------+ //| Check for initialization, create model | //+------------------------------------------------------------------+ bool CheckInit(const string symbol, const ENUM_TIMEFRAMES period,const uchar& model[]) { //--- check symbol, period if(symbol!=m_symbol || period!=m_period) { PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period)); return(false); } //--- create a model from static buffer m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT); if(m_handle==INVALID_HANDLE) { Print("OnnxCreateFromBuffer error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Release ONNX session | //+------------------------------------------------------------------+ void Shutdown(void) { if(m_handle!=INVALID_HANDLE) { OnnxRelease(m_handle); m_handle=INVALID_HANDLE; } } //+------------------------------------------------------------------+ //| Check for continue OnTick | //+------------------------------------------------------------------+ virtual bool CheckOnTick(void) { //--- check new bar if(TimeCurrent()<m_next_bar) return(false); //--- set next bar time m_next_bar=TimeCurrent(); m_next_bar-=m_next_bar%PeriodSeconds(m_period); m_next_bar+=PeriodSeconds(m_period); //--- work on new day bar return(true); } //+------------------------------------------------------------------+ //| virtual stub for PredictPrice (regression model) | //+------------------------------------------------------------------+ virtual double PredictPrice(datetime date) { return(DBL_MAX); } //+------------------------------------------------------------------+ //| Predict class (regression ~> classification) | //+------------------------------------------------------------------+ virtual int PredictClass(datetime date,vector& probabilities) { date-=date%PeriodSeconds(m_period); double predicted_price=PredictPrice(date); if(predicted_price==DBL_MAX) return(-1); double last_close[2]; if(CopyClose(m_symbol,m_period,date,2,last_close)!=2) return(-1); double prev_price=last_close[0]; //--- classify predicted price movement int predicted_class=-1; double delta=prev_price-predicted_price; if(fabs(delta)<=m_class_delta) predicted_class=PRICE_SAME; else { if(delta<0) predicted_class=PRICE_UP; else predicted_class=PRICE_DOWN; } //--- set predicted probability as 1.0 probabilities.Fill(0); if(predicted_class<(int)probabilities.Size()) probabilities[predicted_class]=1; //--- and return predicted class return(predicted_class); } };//+------------------------------------------------------------------+我们已在 PredictPrice 和 PredictClass 方法中加入了一个 datetime 参数,以便我们可以在任何时间点进行预测,而不仅只在当下。 这对于形成预测向量很实用。
D1_10 模型类我们的第一个模型称为 model.eurusd.D1.10.onnx。 回归模型依据 EURUSD D1 的 10 个 OHLC 价格序列进行了训练。//+------------------------------------------------------------------+//| ModelEurusdD1_10.mqh |//| Copyright 2023, MetaQuotes Ltd. |//| https://www.mql5.com |//+------------------------------------------------------------------+#include "ModelSymbolPeriod.mqh"#resource "Python/model.eurusd.D1.10.onnx" as uchar model_eurusd_D1_10[]//+------------------------------------------------------------------+//| ONNX-model wrapper class |//+------------------------------------------------------------------+class CModelEurusdD1_10 : public CModelSymbolPeriod {private: int m_sample_size;public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_10(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)


IP属地:浙江1楼2024-01-10 17:26回复