@@ -58,6 +58,13 @@ namespace mio
5858 * decreasing. This is no limitation as the support is only needed for StateAgeFunctions of Type a) as given above.
5959 * For classes of type b) a dummy implementation logging an error and returning -2 for get_support_max() should be implemented.
6060 *
61+ * The get_mean method is virtual and implements a basic version to determine the mean value of the StateAgeFunction.
62+ * The base class implementation uses the fact that the StateAgeFunction is a survival function
63+ * (i.e. 1-CDF for any cumulative distribution function CDF).
64+ * Therefore, the base class implementation should only be used for StateAgeFunction%s of type a).
65+ * For some derived classes there is a more efficient way (see e.g., ExponentialDecay) to do this which is
66+ * why it can be overridden.
67+ *
6168 * See ExponentialDecay, SmootherCosine and ConstantFunction for examples of derived classes.
6269 */
6370struct StateAgeFunction {
@@ -69,8 +76,10 @@ struct StateAgeFunction {
6976 */
7077 StateAgeFunction (ScalarType init_parameter)
7178 : m_parameter{init_parameter}
72- , m_support_max{-1 .} // initialize support maximum as not set
73- , m_support_tol{-1 .} // initialize support tolerance as not set
79+ , m_mean{-1 .} // Initialize mean as not set.
80+ , m_mean_tol{-1 .} // Initialize tolerance for computation of mean as not set.
81+ , m_support_max{-1 .} // Initialize support maximum as not set.
82+ , m_support_tol{-1 .} // Initialize tolerance for computation of support as not set.
7483 {
7584 }
7685
@@ -144,6 +153,7 @@ struct StateAgeFunction {
144153 m_parameter = new_parameter;
145154
146155 m_support_max = -1 .;
156+ m_mean = -1 ;
147157 }
148158
149159 /* *
@@ -176,6 +186,37 @@ struct StateAgeFunction {
176186 return m_support_max;
177187 }
178188
189+ /* *
190+ * @brief Computes the mean value of the function using the time step size dt and some tolerance tol.
191+ *
192+ * This is a basic version to determine the mean value of a survival function
193+ * through numerical integration of the integral that describes the expected value.
194+ * This basic implementation is only valid if the StateAgeFunction is of type a). Otherwise it should be overridden.
195+ *
196+ * For some specific derivations of StateAgeFunction%s there are more efficient ways to determine the
197+ * the mean value which is why this member function is virtual and can be overridden (see, e.g., ExponentialDecay).
198+ * The mean value is only needed for StateAgeFunction%s that are used as TransitionDistribution%s.
199+ *
200+ * @param[in] dt Time step size used for the numerical integration.
201+ * @param[in] tol The maximum support used for numerical integration is calculated using this tolerance.
202+ * @return ScalarType mean value.
203+ */
204+ virtual ScalarType get_mean (ScalarType dt = 1 ., ScalarType tol = 1e-10 )
205+ {
206+ if (!floating_point_equal (m_mean_tol, tol, 1e-14 ) || floating_point_equal (m_mean, -1 ., 1e-14 )) {
207+ // Integration using Trapezoidal rule.
208+ ScalarType mean = 0.5 * dt * eval (0 * dt);
209+ ScalarType supp_max_idx = std::ceil (get_support_max (dt, tol) / dt);
210+ for (int i = 1 ; i < supp_max_idx; i++) {
211+ mean += dt * eval (i * dt);
212+ }
213+
214+ m_mean = mean;
215+ m_mean_tol = tol;
216+ }
217+ return m_mean;
218+ }
219+
179220 /* *
180221 * @brief Get type of StateAgeFunction, i.e.which derived class is used.
181222 *
@@ -205,6 +246,8 @@ struct StateAgeFunction {
205246 virtual StateAgeFunction* clone_impl () const = 0;
206247
207248 ScalarType m_parameter; // /< Parameter for function in derived class.
249+ ScalarType m_mean; // /< Mean value of the function.
250+ ScalarType m_mean_tol; // /< Tolerance for computation of the mean.
208251 ScalarType m_support_max; // /< Maximum of the support of the function.
209252 ScalarType m_support_tol; // /< Tolerance for computation of the support.
210253};
@@ -241,6 +284,22 @@ struct ExponentialDecay : public StateAgeFunction {
241284 return std::exp (-m_parameter * state_age);
242285 }
243286
287+ /* *
288+ * @brief Computes the mean value of the function.
289+ *
290+ * For ExponentialDecay, the mean value is the reciprocal of the function parameter.
291+ *
292+ * @param[in] dt Time step size used for the numerical integration (unused for ExponentialDecay).
293+ * @param[in] tol The maximum support used for numerical integration is calculated using this tolerance (unused for ExponentialDecay).
294+ * @return ScalarType mean value.
295+ */
296+ ScalarType get_mean (ScalarType dt = 1 ., ScalarType tol = 1e-10 ) override
297+ {
298+ unused (dt);
299+ unused (tol);
300+ return 1 . / m_parameter;
301+ }
302+
244303protected:
245304 /* *
246305 * @brief Implements clone for ExponentialDecay.
@@ -298,6 +357,10 @@ struct SmootherCosine : public StateAgeFunction {
298357 return m_support_max;
299358 }
300359
360+ // TODO: There is also a closed form for the mean value of Smoothercosine: 0.5*m_parameter.
361+ // However, a StateAgeFunction that uses the default implementation is required for testing purposes.
362+ // Therefore, the closed form is only used for comparison in the tests.
363+ // If another StateAgeFunction is implemented that uses the default implementation, the function get_mean() should be overwritten here.
301364protected:
302365 /* *
303366 * @brief Clones unique pointer to a StateAgeFunction.
@@ -365,6 +428,22 @@ struct ConstantFunction : public StateAgeFunction {
365428 return m_support_max;
366429 }
367430
431+ /* *
432+ * @brief Computes the mean value of the function.
433+ *
434+ * For ConstantFunction, the mean value is the function parameter.
435+ *
436+ * @param[in] dt Time step size used for the numerical integration (unused for ConstantFunction).
437+ * @param[in] tol The maximum support used for numerical integration is calculated using this tolerance (unused for ConstantFunction).
438+ * @return ScalarType mean value.
439+ */
440+ ScalarType get_mean (ScalarType dt = 1 ., ScalarType tol = 1e-10 ) override
441+ {
442+ unused (dt);
443+ unused (tol);
444+ return m_parameter;
445+ }
446+
368447protected:
369448 /* *
370449 * @brief Clones unique pointer to a StateAgeFunction.
@@ -460,7 +539,7 @@ struct StateAgeFunctionWrapper {
460539 /* *
461540 * @brief Get type of StateAgeFunction, i.e. which derived class is used.
462541 *
463- * @param[out] string
542+ * @return string
464543 */
465544 std::string get_state_age_function_type () const
466545 {
@@ -498,11 +577,30 @@ struct StateAgeFunctionWrapper {
498577 m_function->set_parameter (new_parameter);
499578 }
500579
580+ /* *
581+ * @brief Get the m_support_max object of m_function.
582+ *
583+ * @param[in] dt Time step size at which function will be evaluated.
584+ * @param[in] tol Tolerance used for cutting the support if the function value falls below.
585+ * @return ScalarType m_support_max
586+ */
501587 ScalarType get_support_max (ScalarType dt, ScalarType tol = 1e-10 ) const
502588 {
503589 return m_function->get_support_max (dt, tol);
504590 }
505591
592+ /* *
593+ * @brief Get the m_mean object of m_function.
594+ *
595+ * @param[in] dt Time step size used for the numerical integration.
596+ * @param[in] tol The maximum support used for numerical integration is calculated using this tolerance.
597+ * @return ScalarType m_mean
598+ */
599+ ScalarType get_mean (ScalarType dt = 1 ., ScalarType tol = 1e-10 ) const
600+ {
601+ return m_function->get_mean (dt, tol);
602+ }
603+
506604private:
507605 std::unique_ptr<StateAgeFunction> m_function; // /< Stores StateAgeFunction that is used in Wrapper.
508606};
0 commit comments