From 31681b80e5bab40944b6c81dc3319444a510fa27 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Thu, 25 Jan 2024 09:47:37 +0100 Subject: [PATCH] add the offsets in the formula of the latent prob. --- pyPLNmodels/_closed_forms.py | 4 ++-- pyPLNmodels/models.py | 29 +++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py index 3524d48d..0acfc8ae 100644 --- a/pyPLNmodels/_closed_forms.py +++ b/pyPLNmodels/_closed_forms.py @@ -101,7 +101,7 @@ def _closed_formula_pi( return torch._sigmoid(poiss_param + torch.mm(exog, _coef_inflation)) * dirac -def _closed_formula_latent_prob(exog, coef, coef_infla, cov, dirac): +def _closed_formula_latent_prob(exog, coef, offsets, coef_infla, cov, dirac): if exog is not None: XB = exog @ coef XB_zero = exog @ coef_infla @@ -112,4 +112,4 @@ def _closed_formula_latent_prob(exog, coef, coef_infla, cov, dirac): pi = torch.sigmoid(XB_zero) diag = torch.diag(cov) full_diag = diag.expand(exog.shape[0], -1) - return torch.sigmoid(XB_zero - torch.log(phi(XB, full_diag))) * dirac + return torch.sigmoid(XB_zero - torch.log(phi(XB + offsets, full_diag))) * dirac diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index 41b2fc6c..e77cf2fb 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -3818,7 +3818,26 @@ class ZIPln(_model): The closed form for the latent probability. """ return _closed_formula_latent_prob( - self._exog, self._coef, self._coef_inflation, self._covariance, self._dirac + self._exog, + self._coef, + self._offsets, + self._coef_inflation, + self._covariance, + self._dirac, + ) + + @property + def closed_formula_latent_prob_b(self): + """ + The closed form for the latent probability for the batch. + """ + return _closed_formula_latent_prob( + self._exog_b, + self._coef, + self._offsets_b, + self._coef_inflation, + self._covariance, + self._dirac_b, ) def compute_elbo(self): @@ -3841,13 +3860,7 @@ class ZIPln(_model): def _compute_elbo_b(self): if self._use_closed_form_prob is True: - latent_prob_b = _closed_formula_latent_prob( - self._exog_b, - self._coef, - self._coef_inflation, - self._covariance, - self._dirac_b, - ) + latent_prob_b = self.closed_formula_latent_prob_b else: latent_prob_b = self._latent_prob_b return elbo_zi_pln( -- GitLab