From 31681b80e5bab40944b6c81dc3319444a510fa27 Mon Sep 17 00:00:00 2001
From: bastien-mva <bastien.batardiere@gmail.com>
Date: Thu, 25 Jan 2024 09:47:37 +0100
Subject: [PATCH] add the offsets in the formula of the latent prob.

---
 pyPLNmodels/_closed_forms.py |  4 ++--
 pyPLNmodels/models.py        | 29 +++++++++++++++++++++--------
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py
index 3524d48d..0acfc8ae 100644
--- a/pyPLNmodels/_closed_forms.py
+++ b/pyPLNmodels/_closed_forms.py
@@ -101,7 +101,7 @@ def _closed_formula_pi(
     return torch._sigmoid(poiss_param + torch.mm(exog, _coef_inflation)) * dirac
 
 
-def _closed_formula_latent_prob(exog, coef, coef_infla, cov, dirac):
+def _closed_formula_latent_prob(exog, coef, offsets, coef_infla, cov, dirac):
     if exog is not None:
         XB = exog @ coef
         XB_zero = exog @ coef_infla
@@ -112,4 +112,4 @@ def _closed_formula_latent_prob(exog, coef, coef_infla, cov, dirac):
     pi = torch.sigmoid(XB_zero)
     diag = torch.diag(cov)
     full_diag = diag.expand(exog.shape[0], -1)
-    return torch.sigmoid(XB_zero - torch.log(phi(XB, full_diag))) * dirac
+    return torch.sigmoid(XB_zero - torch.log(phi(XB + offsets, full_diag))) * dirac
diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py
index 41b2fc6c..e77cf2fb 100644
--- a/pyPLNmodels/models.py
+++ b/pyPLNmodels/models.py
@@ -3818,7 +3818,26 @@ class ZIPln(_model):
         The closed form for the latent probability.
         """
         return _closed_formula_latent_prob(
-            self._exog, self._coef, self._coef_inflation, self._covariance, self._dirac
+            self._exog,
+            self._coef,
+            self._offsets,
+            self._coef_inflation,
+            self._covariance,
+            self._dirac,
+        )
+
+    @property
+    def closed_formula_latent_prob_b(self):
+        """
+        The closed form for the latent probability for the batch.
+        """
+        return _closed_formula_latent_prob(
+            self._exog_b,
+            self._coef,
+            self._offsets_b,
+            self._coef_inflation,
+            self._covariance,
+            self._dirac_b,
         )
 
     def compute_elbo(self):
@@ -3841,13 +3860,7 @@ class ZIPln(_model):
 
     def _compute_elbo_b(self):
         if self._use_closed_form_prob is True:
-            latent_prob_b = _closed_formula_latent_prob(
-                self._exog_b,
-                self._coef,
-                self._coef_inflation,
-                self._covariance,
-                self._dirac_b,
-            )
+            latent_prob_b = self.closed_formula_latent_prob_b
         else:
             latent_prob_b = self._latent_prob_b
         return elbo_zi_pln(
-- 
GitLab