From fd1160a44c8630415f8a939bce707bc6fbfb17df Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Mar 2026 15:40:34 -0400 Subject: [PATCH 01/14] Add survey support for Phase 4 estimators (ImputationDiD, TwoStageDiD, CallawaySantAnna) - Weighted solve_logit(): survey weights enter IRLS as w_survey * mu*(1-mu) - ImputationDiD: weighted iterative FE, survey-weighted ATT aggregation, weighted conservative variance (Theorem 3), survey df for inference - TwoStageDiD: weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance with survey weights in both stages - CallawaySantAnna: survey-weighted regression, IPW (via weighted solve_logit), and DR methods with explicit influence functions; survey-weighted WIF in aggregation; Cholesky cache bypassed under survey weights - Unblock TripleDifference IPW/DR with survey (weighted solve_logit now available) - 38 new tests in test_survey_phase4.py covering all estimators + scale invariance - Update survey-roadmap.md, REGISTRY.md with Phase 4 status and deviation notes Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/imputation.py | 244 +++++++-- diff_diff/imputation_results.py | 23 + diff_diff/linalg.py | 56 +- diff_diff/staggered.py | 505 +++++++++++++----- diff_diff/staggered_aggregation.py | 238 ++++++--- diff_diff/staggered_results.py | 25 +- diff_diff/triple_diff.py | 16 +- diff_diff/two_stage.py | 268 ++++++++-- diff_diff/two_stage_results.py | 23 + docs/methodology/REGISTRY.md | 3 + docs/survey-roadmap.md | 56 +- tests/test_survey_phase3.py | 46 +- tests/test_survey_phase4.py | 803 +++++++++++++++++++++++++++++ 13 files changed, 1928 insertions(+), 378 deletions(-) create mode 100644 tests/test_survey_phase4.py diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index f74eca91..b029b08c 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -177,6 +177,7 @@ def fit( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + 
survey_design: object = None, ) -> ImputationDiDResults: """ Fit the imputation DiD estimator. @@ -225,6 +226,29 @@ def fit( # Create working copy df = data.copy() + # Resolve survey design if provided + from diff_diff.survey import ( + _inject_cluster_as_psu, + _resolve_effective_cluster, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( + _resolve_survey_for_fit(survey_design, data, "analytical") + ) + + # Validate within-unit constancy for panel survey designs + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + + # Guard bootstrap + survey + if self.n_bootstrap > 0 and resolved_survey is not None: + raise NotImplementedError( + "Bootstrap inference with survey weights is not yet supported " + "for ImputationDiD. Use analytical inference (n_bootstrap=0)." + ) + # Ensure numeric types df[time] = pd.to_numeric(df[time]) df[first_treat] = pd.to_numeric(df[first_treat]) @@ -312,6 +336,26 @@ def fit( f"Available columns: {list(df.columns)}" ) + # Resolve effective cluster and inject cluster-as-PSU for survey variance + if resolved_survey is not None: + cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None + effective_cluster_ids = _resolve_effective_cluster( + resolved_survey, + cluster_ids_raw, + cluster_var if self.cluster is not None else None, + ) + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) + # Recompute metadata after PSU injection + if resolved_survey.psu is not None and survey_metadata is not None: + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) + # Compute relative time df["_rel_time"] = np.where( ~df["_never_treated"], @@ -321,7 +365,7 @@ def 
fit( # ---- Step 1: OLS on untreated observations ---- unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model( - df, outcome, unit, time, covariates, omega_0_mask + df, outcome, unit, time, covariates, omega_0_mask, weights=survey_weights ) # ---- Rank condition checks ---- @@ -382,20 +426,31 @@ def fit( # ---- Step 3: Aggregate ---- # Always compute overall ATT (simple aggregation) - valid_tau = tau_hat[np.isfinite(tau_hat)] + finite_mask = np.isfinite(tau_hat) + valid_tau = tau_hat[finite_mask] if len(valid_tau) == 0: overall_att = np.nan + elif survey_weights is not None: + # Survey-weighted ATT: use treated obs' survey weights + treated_survey_w = survey_weights[omega_1_mask.values] + w_finite = treated_survey_w[finite_mask] + overall_att = float(np.average(valid_tau, weights=w_finite)) else: overall_att = float(np.mean(valid_tau)) # ---- Conservative variance (Theorem 3) ---- - # Build weights matching the ATT: uniform over finite tau_hat, zero for NaN + # Build weights matching the ATT: proportional to survey weights for + # finite tau_hat, uniform when no survey overall_weights = np.zeros(n_omega_1) - finite_mask = np.isfinite(tau_hat) n_valid = int(finite_mask.sum()) if n_valid > 0: - overall_weights[finite_mask] = 1.0 / n_valid + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + sw_finite = treated_sw[finite_mask] + overall_weights[finite_mask] = sw_finite / sw_finite.sum() + else: + overall_weights[finite_mask] = 1.0 / n_valid if n_valid == 0: overall_se = np.nan @@ -418,7 +473,12 @@ def fit( kept_cov_mask=kept_cov_mask, ) - overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) + # Survey degrees of freedom for t-distribution inference + _survey_df = resolved_survey.df_survey if resolved_survey is not None else None + + overall_t, overall_p, overall_ci = safe_inference( + overall_att, overall_se, alpha=self.alpha, df=_survey_df + ) # Event study and group 
aggregation event_study_effects = None @@ -442,6 +502,8 @@ def fit( treatment_groups=treatment_groups, balance_e=balance_e, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_df=_survey_df, ) if aggregate in ("group", "all"): @@ -461,6 +523,8 @@ def fit( cluster_var=cluster_var, treatment_groups=treatment_groups, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_df=_survey_df, ) # Build treatment effects dataframe @@ -599,6 +663,7 @@ def fit( alpha=self.alpha, bootstrap_results=bootstrap_results, _estimator_ref=self, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -616,6 +681,7 @@ def _iterative_fe( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> Tuple[Dict[Any, float], Dict[Any, float]]: """ Estimate unit and time FE via iterative alternating projection (Gauss-Seidel). @@ -624,6 +690,12 @@ def _iterative_fe( For balanced panels, converges in 1-2 iterations (identical to one-pass). For unbalanced panels, typically 5-20 iterations. + Parameters + ---------- + weights : np.ndarray, optional + Survey weights. When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. 
+ Returns ------- unit_fe : dict @@ -635,25 +707,37 @@ def _iterative_fe( alpha = np.zeros(n) # unit FE broadcast to obs level beta = np.zeros(n) # time FE broadcast to obs level + # Precompute per-group weight sums (invariant across iterations) + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for iteration in range(max_iter): - # Update time FE: beta_t = mean_i(y_it - alpha_i) resid_after_alpha = y - alpha - beta_new = ( - pd.Series(resid_after_alpha, index=idx) - .groupby(time_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_t = pd.Series(resid_after_alpha * weights, index=idx) + beta_new = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + beta_new = ( + pd.Series(resid_after_alpha, index=idx) + .groupby(time_vals) + .transform("mean") + .values + ) - # Update unit FE: alpha_i = mean_t(y_it - beta_t) resid_after_beta = y - beta_new - alpha_new = ( - pd.Series(resid_after_beta, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(resid_after_beta * weights, index=idx) + alpha_new = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + alpha_new = ( + pd.Series(resid_after_beta, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) # Check convergence on FE changes max_change = max( @@ -677,25 +761,47 @@ def _iterative_demean( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> np.ndarray: """Demean a vector by iterative alternating projection (unit + time FE removal). Converges to the exact within-transformation for both balanced and unbalanced panels. For balanced panels, converges in 1-2 iterations. + + Parameters + ---------- + weights : np.ndarray, optional + Survey weights. 
When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. """ result = vals.copy() + + # Precompute per-group weight sums (invariant across iterations) + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for _ in range(max_iter): - time_means = ( - pd.Series(result, index=idx).groupby(time_vals).transform("mean").values - ) + if weights is not None: + wr_t = pd.Series(result * weights, index=idx) + time_means = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + time_means = ( + pd.Series(result, index=idx).groupby(time_vals).transform("mean").values + ) result_after_time = result - time_means - unit_means = ( - pd.Series(result_after_time, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(result_after_time * weights, index=idx) + unit_means = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + unit_means = ( + pd.Series(result_after_time, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) result_new = result_after_time - unit_means if np.max(np.abs(result_new - result)) < tol: result = result_new @@ -773,6 +879,7 @@ def _fit_untreated_model( time: str, covariates: Optional[List[str]], omega_0_mask: pd.Series, + weights: Optional[np.ndarray] = None, ) -> Tuple[ Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray] ]: @@ -783,6 +890,12 @@ def _fit_untreated_model( OLS fixed effects for both balanced and unbalanced panels. For balanced panels, converges in 1-2 iterations (identical to one-pass demeaning). + Parameters + ---------- + weights : np.ndarray, optional + Full-panel survey weights (same length as df). The untreated subset + is extracted internally via omega_0_mask. When None, unweighted. 
+ Returns ------- unit_fe : dict @@ -798,13 +911,14 @@ def _fit_untreated_model( have finite coefficients. None if no covariates. """ df_0 = df.loc[omega_0_mask] + w_0 = weights[omega_0_mask.values] if weights is not None else None if covariates is None or len(covariates) == 0: # No covariates: estimate FE via iterative alternating projection # (exact OLS for both balanced and unbalanced panels) y = df_0[outcome].values.copy() unit_fe, time_fe = self._iterative_fe( - y, df_0[unit].values, df_0[time].values, df_0.index + y, df_0[unit].values, df_0[time].values, df_0.index, weights=w_0 ) # grand_mean = 0: iterative FE absorb the intercept return unit_fe, time_fe, 0.0, None, None @@ -819,10 +933,10 @@ def _fit_untreated_model( n_cov = len(covariates) # Step A: Iteratively demean Y and all X columns to remove unit+time FE - y_dm = self._iterative_demean(y, units, times, df_0.index) + y_dm = self._iterative_demean(y, units, times, df_0.index, weights=w_0) X_dm = np.column_stack( [ - self._iterative_demean(X_raw[:, j], units, times, df_0.index) + self._iterative_demean(X_raw[:, j], units, times, df_0.index, weights=w_0) for j in range(n_cov) ] ) @@ -834,6 +948,7 @@ def _fit_untreated_model( return_vcov=False, rank_deficient_action=self.rank_deficient_action, column_names=covariates, + weights=w_0, ) delta_hat = result[0] @@ -847,7 +962,7 @@ def _fit_untreated_model( # Step C: Recover FE from covariate-adjusted outcome using iterative FE y_adj = y - np.dot(X_raw, delta_hat_clean) - unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) + unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index, weights=w_0) # grand_mean = 0: iterative FE absorb the intercept return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask @@ -1281,6 +1396,8 @@ def _aggregate_event_study( treatment_groups: List[Any], balance_e: Optional[int] = None, kept_cov_mask: Optional[np.ndarray] = None, + survey_weights: Optional[np.ndarray] = None, + survey_df: Optional[int] 
= None, ) -> Dict[int, Dict[str, Any]]: """Aggregate treatment effects by event-study horizon.""" df_1 = df.loc[omega_1_mask] @@ -1355,7 +1472,8 @@ def _aggregate_event_study( continue tau_h = tau_hat[h_mask] - valid_tau = tau_h[np.isfinite(tau_h)] + finite_h = np.isfinite(tau_h) + valid_tau = tau_h[finite_h] if len(valid_tau) == 0: event_study_effects[h] = { @@ -1368,10 +1486,32 @@ def _aggregate_event_study( } continue - effect = float(np.mean(valid_tau)) + # Survey-weighted or simple mean for per-horizon effect + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + sw_h = treated_sw[h_mask] + sw_valid = sw_h[finite_h] + effect = float(np.average(valid_tau, weights=sw_valid)) + else: + effect = float(np.mean(valid_tau)) # Compute SE via conservative variance with horizon-specific weights - weights_h, n_valid = _compute_target_weights(tau_hat, h_mask) + # When survey, aggregation weights are proportional to survey weights + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + n_1 = len(tau_hat) + weights_h = np.zeros(n_1) + sw_h = treated_sw[h_mask] + finite_in_h = np.isfinite(tau_h) + sw_finite = sw_h[finite_in_h] + # Set weights proportional to survey weights, summing to 1 + if sw_finite.sum() > 0: + h_indices = np.where(h_mask)[0] + finite_indices = h_indices[finite_in_h] + weights_h[finite_indices] = sw_finite / sw_finite.sum() + n_valid = int(finite_in_h.sum()) + else: + weights_h, n_valid = _compute_target_weights(tau_hat, h_mask) se = self._compute_conservative_variance( df=df, @@ -1391,7 +1531,7 @@ def _aggregate_event_study( kept_cov_mask=kept_cov_mask, ) - t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df) event_study_effects[h] = { "effect": effect, @@ -1449,6 +1589,8 @@ def _aggregate_group( cluster_var: str, treatment_groups: List[Any], kept_cov_mask: Optional[np.ndarray] = None, + 
survey_weights: Optional[np.ndarray] = None, + survey_df: Optional[int] = None, ) -> Dict[Any, Dict[str, Any]]: """Aggregate treatment effects by cohort.""" df_1 = df.loc[omega_1_mask] @@ -1465,7 +1607,8 @@ def _aggregate_group( continue tau_g = tau_hat[g_mask] - valid_tau = tau_g[np.isfinite(tau_g)] + finite_g = np.isfinite(tau_g) + valid_tau = tau_g[finite_g] if len(valid_tau) == 0: group_effects[g] = { @@ -1478,10 +1621,29 @@ def _aggregate_group( } continue - effect = float(np.mean(valid_tau)) + # Survey-weighted or simple mean for per-group effect + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + sw_g = treated_sw[g_mask] + sw_valid = sw_g[finite_g] + effect = float(np.average(valid_tau, weights=sw_valid)) + else: + effect = float(np.mean(valid_tau)) # Compute SE with group-specific weights - weights_g, _ = _compute_target_weights(tau_hat, g_mask) + # When survey, aggregation weights proportional to survey weights + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + n_1 = len(tau_hat) + weights_g = np.zeros(n_1) + sw_g = treated_sw[g_mask] + sw_finite = sw_g[finite_g] + if sw_finite.sum() > 0: + g_indices = np.where(g_mask)[0] + finite_indices = g_indices[finite_g] + weights_g[finite_indices] = sw_finite / sw_finite.sum() + else: + weights_g, _ = _compute_target_weights(tau_hat, g_mask) se = self._compute_conservative_variance( df=df, @@ -1501,7 +1663,7 @@ def _aggregate_group( kept_cov_mask=kept_cov_mask, ) - t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df) group_effects[g] = { "effect": effect, diff --git a/diff_diff/imputation_results.py b/diff_diff/imputation_results.py index 6520fca4..6589af1d 100644 --- a/diff_diff/imputation_results.py +++ b/diff_diff/imputation_results.py @@ -139,6 +139,8 @@ class ImputationDiDResults: bootstrap_results: Optional[ImputationBootstrapResults] 
= field(default=None, repr=False) # Internal: stores data needed for pretrend_test() _estimator_ref: Optional[Any] = field(default=None, repr=False) + # Survey design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None, repr=False) def __repr__(self) -> str: """Concise string representation.""" @@ -182,6 +184,27 @@ def summary(self, alpha: Optional[float] = None) -> str: "", ] + # Survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "-" * 85, + "Survey Design".center(85), + "-" * 85, + f"{'Weight type:':<30} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<30} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<30} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<30} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<30} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<30} {sm.df_survey:>10}") + lines.extend(["-" * 85, ""]) + # Overall ATT lines.extend( [ diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py index 6c26d360..01b85e5e 100644 --- a/diff_diff/linalg.py +++ b/diff_diff/linalg.py @@ -390,24 +390,18 @@ def _validate_weights(weights, weight_type, n): """Validate weights array and weight_type for solve_ols/LinearRegression.""" if weight_type not in _VALID_WEIGHT_TYPES: raise ValueError( - f"weight_type must be one of {_VALID_WEIGHT_TYPES}, " - f"got '{weight_type}'" + f"weight_type must be one of {_VALID_WEIGHT_TYPES}, " f"got '{weight_type}'" ) if weights is not None: weights = np.asarray(weights, dtype=np.float64) if weights.shape[0] != n: - raise ValueError( - f"weights length ({weights.shape[0]}) must match " - f"X rows ({n})" - ) + raise ValueError(f"weights length ({weights.shape[0]}) must match " f"X rows ({n})") if np.any(np.isnan(weights)): raise ValueError("Weights contain NaN values") if 
np.any(np.isinf(weights)): raise ValueError("Weights contain Inf values") if np.any(weights < 0): - raise ValueError( - "Weights must be non-negative" - ) + raise ValueError("Weights must be non-negative") if weight_type == "fweight": fractional = weights - np.round(weights) if np.any(np.abs(fractional) > 1e-10): @@ -693,13 +687,9 @@ def solve_ols( weights=weights, weight_type=weight_type, ) - vcov_out = _expand_vcov_with_nan( - vcov_reduced, _original_X.shape[1], kept_cols - ) + vcov_out = _expand_vcov_with_nan(vcov_reduced, _original_X.shape[1], kept_cols) else: - vcov_out = np.full( - (_original_X.shape[1], _original_X.shape[1]), np.nan - ) + vcov_out = np.full((_original_X.shape[1], _original_X.shape[1]), np.nan) else: vcov_out = _compute_robust_vcov_numpy( _original_X, @@ -1122,6 +1112,7 @@ def solve_logit( tol: float = 1e-8, check_separation: bool = True, rank_deficient_action: str = "warn", + weights: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: """ Fit logistic regression via IRLS (Fisher scoring). @@ -1147,6 +1138,13 @@ def solve_logit( - "warn": Emit warning and drop columns (default) - "error": Raise ValueError - "silent": Drop columns silently + weights : np.ndarray, optional + Survey/observation weights of shape (n_samples,). When provided, + the IRLS working weights become ``weights * mu * (1 - mu)`` + instead of ``mu * (1 - mu)``. This produces the survey-weighted + maximum likelihood estimator, matching R's ``svyglm(family=binomial)``. + When None (default), behavior is identical to unweighted logistic + regression. 
Returns ------- @@ -1203,11 +1201,16 @@ def solve_logit( mu = np.clip(mu, 1e-10, 1 - 1e-10) # Working weights and working response - w = mu * (1.0 - mu) - z = eta + (y - mu) / w + w_irls = mu * (1.0 - mu) + z = eta + (y - mu) / w_irls + + if weights is not None: + w_total = weights * w_irls + else: + w_total = w_irls # Weighted least squares: solve (X'WX) beta = X'Wz - sqrt_w = np.sqrt(w) + sqrt_w = np.sqrt(w_total) Xw = X_solve * sqrt_w[:, None] zw = z * sqrt_w beta_new, _, _, _ = np.linalg.lstsq(Xw, zw, rcond=None) @@ -1593,10 +1596,7 @@ def fit( _use_survey_vcov = self.survey_design.needs_survey_vcov # Canonicalize weights from survey_design to ensure consistency # between coefficient estimation and survey vcov computation - if ( - self.weights is not None - and self.weights is not self.survey_design.weights - ): + if self.weights is not None and self.weights is not self.survey_design.weights: warnings.warn( "Explicit weights= differ from survey_design.weights. " "Using survey_design weights for both coefficient " @@ -1609,9 +1609,7 @@ def fit( self.weight_type = self.survey_design.weight_type if self.weights is not None: - self.weights = _validate_weights( - self.weights, self.weight_type, X.shape[0] - ) + self.weights = _validate_weights(self.weights, self.weight_type, X.shape[0]) # Inject cluster as PSU for survey variance when no PSU specified. 
# Use a local variable to avoid mutating self.survey_design, which @@ -1622,7 +1620,9 @@ def fit( and _effective_survey_design is not None and _use_survey_vcov ): - from diff_diff.survey import ResolvedSurveyDesign as _RSD, _inject_cluster_as_psu + from diff_diff.survey import ResolvedSurveyDesign as _RSD + from diff_diff.survey import _inject_cluster_as_psu + if isinstance(_effective_survey_design, _RSD) and _effective_survey_design.psu is None: _effective_survey_design = _inject_cluster_as_psu( _effective_survey_design, effective_cluster_ids @@ -1864,9 +1864,7 @@ def get_inference( # Use project-standard NaN-safe inference (returns all-NaN when SE <= 0) from diff_diff.utils import safe_inference - t_stat, p_value, conf_int = safe_inference( - coef, se, alpha=effective_alpha, df=effective_df - ) + t_stat, p_value, conf_int = safe_inference(coef, se, alpha=effective_alpha, df=effective_df) return InferenceResult( coefficient=coef, diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 3be7cb05..5af63d24 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -11,27 +11,28 @@ import numpy as np import pandas as pd from scipy import linalg as scipy_linalg + from diff_diff.linalg import ( - solve_ols, - solve_logit, _check_propensity_diagnostics, _detect_rank_deficiency, _format_dropped_columns, + solve_logit, + solve_ols, ) -from diff_diff.utils import safe_inference, safe_inference_batch - -# Import from split modules -from diff_diff.staggered_results import ( - GroupTimeEffect, - CallawaySantAnnaResults, +from diff_diff.staggered_aggregation import ( + CallawaySantAnnaAggregationMixin, ) from diff_diff.staggered_bootstrap import ( - CSBootstrapResults, CallawaySantAnnaBootstrapMixin, + CSBootstrapResults, ) -from diff_diff.staggered_aggregation import ( - CallawaySantAnnaAggregationMixin, + +# Import from split modules +from diff_diff.staggered_results import ( + CallawaySantAnnaResults, + GroupTimeEffect, ) +from diff_diff.utils import 
safe_inference, safe_inference_batch # Re-export for backward compatibility __all__ = [ @@ -49,6 +50,7 @@ def _linear_regression( X: np.ndarray, y: np.ndarray, rank_deficient_action: str = "warn", + weights: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: """ Fit OLS regression. @@ -64,6 +66,8 @@ def _linear_regression( - "warn": Issue warning and drop linearly dependent columns (default) - "error": Raise ValueError - "silent": Drop columns silently without warning + weights : np.ndarray, optional + Observation weights for WLS. When None, OLS is used. Returns ------- @@ -82,6 +86,7 @@ def _linear_regression( y, return_vcov=False, rank_deficient_action=rank_deficient_action, + weights=weights, ) return beta, residuals @@ -333,6 +338,7 @@ def _precompute_structures( covariates: Optional[List[str]], time_periods: List[Any], treatment_groups: List[Any], + resolved_survey=None, ) -> PrecomputedData: """ Pre-compute data structures for efficient ATT(g,t) computation. @@ -352,7 +358,6 @@ def _precompute_structures( unit_info = df.groupby(unit)[first_treat].first() all_units = unit_info.index.values unit_cohorts = unit_info.values - n_units = len(all_units) # Create unit index mapping for fast lookups unit_to_idx = {u: i for i, u in enumerate(all_units)} @@ -385,6 +390,15 @@ def _precompute_structures( is_balanced = not np.any(np.isnan(outcome_matrix)) + # Extract per-unit survey weights (one weight per unit) + if resolved_survey is not None: + sw_by_unit = ( + pd.Series(resolved_survey.weights, index=df.index).groupby(df[unit]).first() + ) + survey_weights_arr = sw_by_unit.reindex(all_units).values + else: + survey_weights_arr = None + return { "all_units": all_units, "unit_to_idx": unit_to_idx, @@ -396,6 +410,8 @@ def _precompute_structures( "covariate_by_period": covariate_by_period, "time_periods": time_periods, "is_balanced": is_balanced, + "survey_weights": survey_weights_arr, + "df_survey": resolved_survey.df_survey if resolved_survey is not None 
else None, } def _compute_att_gt_fast( @@ -406,12 +422,22 @@ def _compute_att_gt_fast( covariates: Optional[List[str]], pscore_cache: Optional[Dict] = None, cho_cache: Optional[Dict] = None, - ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]: + ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]: """ Compute ATT(g,t) using pre-computed data structures (fast version). Uses vectorized numpy operations on pre-pivoted outcome matrix instead of repeated pandas filtering. + + Returns + ------- + att_gt : float or None + se_gt : float + n_treated : int + n_control : int + inf_func_info : dict or None + survey_weight_sum : float or None + Sum of survey weights for treated units (for aggregation weighting). """ period_to_col = precomputed["period_to_col"] outcome_matrix = precomputed["outcome_matrix"] @@ -434,11 +460,11 @@ def _compute_att_gt_fast( if base_period_val not in period_to_col: # Base period must exist; no fallback to maintain methodological consistency - return None, 0.0, 0, 0, None + return None, 0.0, 0, 0, None, None # Check if periods exist in the data if base_period_val not in period_to_col or t not in period_to_col: - return None, 0.0, 0, 0, None + return None, 0.0, 0, 0, None, None base_col = period_to_col[base_period_val] post_col = period_to_col[t] @@ -476,12 +502,17 @@ def _compute_att_gt_fast( n_control = np.sum(control_valid) if n_treated == 0 or n_control == 0: - return None, 0.0, 0, 0, None + return None, 0.0, 0, 0, None, None # Extract outcome changes for treated and control treated_change = outcome_change[treated_valid] control_change = outcome_change[control_valid] + # Extract survey weights for treated and control + survey_w = precomputed.get("survey_weights") + sw_treated = survey_w[treated_valid] if survey_w is not None else None + sw_control = survey_w[control_valid] if survey_w is not None else None + # Get covariates if specified (from the base period) X_treated = None X_control = 
None @@ -522,9 +553,15 @@ def _compute_att_gt_fast( # Estimation method if self.estimation_method == "reg": att_gt, se_gt, inf_func = self._outcome_regression( - treated_change, control_change, X_treated, X_control + treated_change, + control_change, + X_treated, + X_control, + sw_treated=sw_treated, + sw_control=sw_control, ) elif self.estimation_method == "ipw": + sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None att_gt, se_gt, inf_func = self._ipw_estimation( treated_change, control_change, @@ -534,8 +571,12 @@ def _compute_att_gt_fast( X_control, pscore_cache=pscore_cache, pscore_key=pscore_key, + sw_treated=sw_treated, + sw_control=sw_control, + sw_all=sw_all, ) else: # doubly robust + sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None att_gt, se_gt, inf_func = self._doubly_robust( treated_change, control_change, @@ -545,6 +586,9 @@ def _compute_att_gt_fast( pscore_key=pscore_key, cho_cache=cho_cache, cho_key=cho_key, + sw_treated=sw_treated, + sw_control=sw_control, + sw_all=sw_all, ) # Package influence function info with index arrays (positions into @@ -563,7 +607,8 @@ def _compute_att_gt_fast( "control_inf": inf_func[n_t:], } - return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info + sw_sum = float(np.sum(sw_treated)) if sw_treated is not None else None + return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info, sw_sum def _compute_all_att_gt_vectorized( self, @@ -590,6 +635,7 @@ def _compute_all_att_gt_vectorized( cohort_masks = precomputed["cohort_masks"] never_treated_mask = precomputed["never_treated_mask"] unit_cohorts = precomputed["unit_cohorts"] + survey_w = precomputed.get("survey_weights") group_time_effects = {} influence_func_info = {} @@ -618,7 +664,9 @@ def _compute_all_att_gt_vectorized( if base_period_val not in period_to_col or t not in period_to_col: continue - tasks.append((g, t, period_to_col[base_period_val], period_to_col[t], 
base_period_val)) + tasks.append( + (g, t, period_to_col[base_period_val], period_to_col[t], base_period_val) + ) # Process all tasks atts = [] @@ -658,17 +706,36 @@ def _compute_all_att_gt_vectorized( n_c = int(n_control) # Inline no-covariates regression (difference in means) - att = float(np.mean(treated_change) - np.mean(control_change)) + if survey_w is not None: + sw_t = survey_w[treated_valid] + sw_c = survey_w[control_valid] + sw_t_norm = sw_t / np.sum(sw_t) + sw_c_norm = sw_c / np.sum(sw_c) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) + var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) + se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 + + # Influence function (survey-weighted) + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + sw_sum = float(np.sum(sw_t)) + else: + att = float(np.mean(treated_change) - np.mean(control_change)) - var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0 - var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0 - se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 + var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0 + var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = -(control_change - np.mean(control_change)) / n_c + # Influence function + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = -(control_change - np.mean(control_change)) / n_c + sw_sum = None - group_time_effects[(g, t)] = { + gte_entry = { "effect": att, "se": se, # t_stat, p_value, conf_int filled by batch inference below @@ -678,6 +745,9 @@ 
def _compute_all_att_gt_vectorized( "n_treated": n_t, "n_control": n_c, } + if sw_sum is not None: + gte_entry["survey_weight_sum"] = sw_sum + group_time_effects[(g, t)] = gte_entry all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] @@ -697,8 +767,12 @@ def _compute_all_att_gt_vectorized( # Batch inference for all (g,t) pairs at once if task_keys: + df_survey_val = precomputed.get("df_survey") t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( - np.array(atts), np.array(ses), alpha=self.alpha + np.array(atts), + np.array(ses), + alpha=self.alpha, + df=df_survey_val, ) for idx, key in enumerate(task_keys): group_time_effects[key]["t_stat"] = float(t_stats[idx]) @@ -1048,6 +1122,7 @@ def fit( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, ) -> CallawaySantAnnaResults: """ Fit the Callaway-Sant'Anna estimator. @@ -1077,6 +1152,9 @@ def fit( balance_e : int, optional For event study, balance the panel at relative time e. Ensures all groups contribute to each relative period. + survey_design : SurveyDesign, optional + Survey design specification for design-based inference. + Supports weights, strata, PSU, and FPC. 
Returns ------- @@ -1096,6 +1174,29 @@ def fit( if covariates is not None and len(covariates) == 0: covariates = None + # Resolve survey design if provided + from diff_diff.survey import ( + _inject_cluster_as_psu, + _resolve_effective_cluster, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( + _resolve_survey_for_fit(survey_design, data, "analytical") + ) + + # Validate within-unit constancy for panel survey designs + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + + # Guard bootstrap + survey + if self.n_bootstrap > 0 and resolved_survey is not None: + raise NotImplementedError( + "Bootstrap inference with survey weights is not yet supported " + "for CallawaySantAnna. Use analytical inference (n_bootstrap=0)." + ) + # Validate inputs required_cols = [outcome, unit, time, first_treat] if covariates: @@ -1147,13 +1248,46 @@ def fit( "cohorts when there are no never-treated units." 
) + # Resolve effective cluster and inject cluster-as-PSU for survey variance + if resolved_survey is not None: + cluster_var = self.cluster if self.cluster is not None else unit + cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None + effective_cluster_ids = _resolve_effective_cluster( + resolved_survey, + cluster_ids_raw, + cluster_var if self.cluster is not None else None, + ) + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) + # Recompute metadata after PSU injection + if resolved_survey.psu is not None and survey_metadata is not None: + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) + # Pre-compute data structures for efficient ATT(g,t) computation precomputed = self._precompute_structures( - df, outcome, unit, time, first_treat, covariates, time_periods, treatment_groups + df, + outcome, + unit, + time, + first_treat, + covariates, + time_periods, + treatment_groups, + resolved_survey=resolved_survey, ) + # Survey df for safe_inference calls + df_survey = resolved_survey.df_survey if resolved_survey is not None else None + # Compute ATT(g,t) for each group-time combination min_period = min(time_periods) + has_survey = resolved_survey is not None if covariates is None and self.estimation_method == "reg": # Fast vectorized path for the common no-covariates regression case @@ -1164,6 +1298,7 @@ def fit( covariates is not None and self.estimation_method == "reg" and self.rank_deficient_action != "error" + and not has_survey # Cholesky cache uses X'X; survey needs X'WX ): # Optimized covariate regression path with Cholesky caching group_time_effects, influence_func_info = self._compute_all_att_gt_covariate_reg( @@ -1177,12 +1312,14 @@ def fit( # Propensity score cache for IPW/DR with 
covariates pscore_cache = {} if (covariates and self.estimation_method in ("ipw", "dr")) else None # Cholesky cache for DR outcome regression component + # Skip cache when survey weights present (X'WX differs from X'X) cho_cache = ( {} if ( covariates and self.estimation_method == "dr" and self.rank_deficient_action != "error" + and not has_survey ) else None ) @@ -1197,7 +1334,7 @@ def fit( ] for t in valid_periods: - att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast( + att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = self._compute_att_gt_fast( precomputed, g, t, @@ -1207,9 +1344,14 @@ def fit( ) if att_gt is not None: - t_stat, p_val, ci = safe_inference(att_gt, se_gt, alpha=self.alpha) + t_stat, p_val, ci = safe_inference( + att_gt, + se_gt, + alpha=self.alpha, + df=df_survey, + ) - group_time_effects[(g, t)] = { + gte_entry = { "effect": att_gt, "se": se_gt, "t_stat": t_stat, @@ -1218,6 +1360,9 @@ def fit( "n_treated": n_treat, "n_control": n_ctrl, } + if sw_sum is not None: + gte_entry["survey_weight_sum"] = sw_sum + group_time_effects[(g, t)] = gte_entry if inf_info is not None: influence_func_info[(g, t)] = inf_info @@ -1232,7 +1377,12 @@ def fit( overall_att, overall_se = self._aggregate_simple( group_time_effects, influence_func_info, df, unit, precomputed ) - overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) + overall_t, overall_p, overall_ci = safe_inference( + overall_att, + overall_se, + alpha=self.alpha, + df=df_survey, + ) # Compute additional aggregations if requested event_study_effects = None @@ -1383,6 +1533,7 @@ def fit( bootstrap_results=bootstrap_results, cband_crit_value=cband_crit_value, pscore_trim=self.pscore_trim, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -1394,6 +1545,8 @@ def _outcome_regression( control_change: np.ndarray, X_treated: Optional[np.ndarray] = None, X_control: Optional[np.ndarray] = None, + sw_treated: Optional[np.ndarray] = None, + 
sw_control: Optional[np.ndarray] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using outcome regression. @@ -1405,6 +1558,11 @@ def _outcome_regression( Without covariates: Simple difference in means. + + Parameters + ---------- + sw_treated, sw_control : np.ndarray, optional + Survey weights for treated and control units. """ n_t = len(treated_change) n_c = len(control_change) @@ -1416,43 +1574,70 @@ def _outcome_regression( X_control, control_change, rank_deficient_action=self.rank_deficient_action, + weights=sw_control, ) # Predict counterfactual for treated units X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated]) predicted_control = np.dot(X_treated_with_intercept, beta) - # ATT = mean(observed treated change - predicted counterfactual) - att = np.mean(treated_change - predicted_control) - - # Standard error using sandwich estimator - # Variance from treated: Var(Y_1 - m(X)) + # ATT: survey-weighted mean of treated residuals treated_residuals = treated_change - predicted_control - var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0 - # Variance from control regression (residual variance) - var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0 + if sw_treated is not None: + sw_t_norm = sw_treated / np.sum(sw_treated) + sw_c_norm = sw_control / np.sum(sw_control) + att = float(np.sum(sw_t_norm * treated_residuals)) + + # Influence function (survey-weighted) + inf_treated = sw_t_norm * (treated_residuals - att) + inf_control = -sw_c_norm * (control_change - np.sum(sw_c_norm * control_change)) + inf_func = np.concatenate([inf_treated, inf_control]) + + # SE from influence function variance + var_psi = np.sum(inf_treated**2) + np.sum(inf_control**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 + else: + att = float(np.mean(treated_residuals)) - # Approximate SE (ignoring estimation error in beta for simplicity) - se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0 + # Standard error using 
sandwich estimator + var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function - inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t - inf_control = -residuals / n_c - inf_func = np.concatenate([inf_treated, inf_control]) + # Influence function + inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t + inf_control = -residuals / n_c + inf_func = np.concatenate([inf_treated, inf_control]) else: # Simple difference in means (no covariates) - att = np.mean(treated_change) - np.mean(control_change) - - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 - var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + if sw_treated is not None: + sw_t_norm = sw_treated / np.sum(sw_treated) + sw_c_norm = sw_control / np.sum(sw_control) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + # Influence function (survey-weighted) + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + inf_func = np.concatenate([inf_treated, inf_control]) + + # SE from influence function variance + var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) + var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) + se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 + else: + att = float(np.mean(treated_change) - np.mean(control_change)) - se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0 + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function (for aggregation) - inf_treated = treated_change - np.mean(treated_change) - inf_control = control_change - 
np.mean(control_change) - inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c]) + # Influence function (for aggregation) + inf_treated = treated_change - np.mean(treated_change) + inf_control = control_change - np.mean(control_change) + inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c]) return att, se, inf_func @@ -1466,6 +1651,9 @@ def _ipw_estimation( X_control: Optional[np.ndarray] = None, pscore_cache: Optional[Dict] = None, pscore_key: Optional[Any] = None, + sw_treated: Optional[np.ndarray] = None, + sw_control: Optional[np.ndarray] = None, + sw_all: Optional[np.ndarray] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using inverse probability weighting. @@ -1477,6 +1665,11 @@ def _ipw_estimation( Without covariates: Simple difference in means with unconditional propensity weighting. + + Parameters + ---------- + sw_treated, sw_control, sw_all : np.ndarray, optional + Survey weights for treated, control, and all units. """ n_t = len(treated_change) n_c = len(control_change) @@ -1508,6 +1701,7 @@ def _ipw_estimation( X_all, D, rank_deficient_action=self.rank_deficient_action, + weights=sw_all, ) _check_propensity_diagnostics(pscore, self.pscore_trim) # Cache the fitted coefficients @@ -1533,51 +1727,88 @@ def _ipw_estimation( pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) pscore_treated = np.clip(pscore_treated, self.pscore_trim, 1 - self.pscore_trim) - # IPW weights for control units: p(X) / (1 - p(X)) - # This reweights controls to have same covariate distribution as treated - weights_control = pscore_control / (1 - pscore_control) - weights_control = weights_control / np.sum(weights_control) # normalize + if sw_treated is not None: + # IPW weights compose with survey weights: + # w_i = sw_i * p(X_i) / (1 - p(X_i)) + weights_control = sw_control * pscore_control / (1 - pscore_control) + weights_control_norm = weights_control / np.sum(weights_control) + + # ATT: survey-weighted 
treated mean minus composite-weighted control mean + sw_t_norm = sw_treated / np.sum(sw_treated) + mu_t = float(np.sum(sw_t_norm * treated_change)) + att = mu_t - float(np.sum(weights_control_norm * control_change)) + + # Influence function (survey-weighted) + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -weights_control_norm * ( + control_change - np.sum(weights_control_norm * control_change) + ) + inf_func = np.concatenate([inf_treated, inf_control]) - # ATT = mean(treated) - weighted_mean(control) - att = np.mean(treated_change) - np.sum(weights_control * control_change) + # SE from influence function variance + var_psi = np.sum(inf_treated**2) + np.sum(inf_control**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 + else: + # IPW weights for control units: p(X) / (1 - p(X)) + # This reweights controls to have same covariate distribution as treated + weights_control = pscore_control / (1 - pscore_control) + weights_control = weights_control / np.sum(weights_control) # normalize - # Compute standard error - # Variance of treated mean - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + # ATT = mean(treated) - weighted_mean(control) + att = float(np.mean(treated_change) - np.sum(weights_control * control_change)) - # Variance of weighted control mean - weighted_var_c = np.sum( - weights_control * (control_change - np.sum(weights_control * control_change)) ** 2 - ) + # Compute standard error + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + + weighted_var_c = np.sum( + weights_control + * (control_change - np.sum(weights_control * control_change)) ** 2 + ) - se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0 + se = float(np.sqrt(var_t / n_t + weighted_var_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = -weights_control * ( - control_change - np.sum(weights_control * control_change) - ) - inf_func 
= np.concatenate([inf_treated, inf_control]) + # Influence function + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = -weights_control * ( + control_change - np.sum(weights_control * control_change) + ) + inf_func = np.concatenate([inf_treated, inf_control]) else: # Unconditional IPW (reduces to difference in means) - p_treat = n_treated / n_total # unconditional propensity score + if sw_treated is not None: + # Survey-weighted difference in means + sw_t_norm = sw_treated / np.sum(sw_treated) + sw_c_norm = sw_control / np.sum(sw_control) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + inf_func = np.concatenate([inf_treated, inf_control]) + + var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) + var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) + se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 + else: + p_treat = n_treated / n_total # unconditional propensity score - att = np.mean(treated_change) - np.mean(control_change) + att = float(np.mean(treated_change) - np.mean(control_change)) - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 - var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 - # Adjusted variance for IPW - se = ( - np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) - if (n_t > 0 and n_c > 0 and p_treat > 0) - else 0.0 - ) + # Adjusted variance for IPW + se = float( + np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) + if (n_t > 0 and n_c > 0 and p_treat > 0) + else 0.0 + ) - # Influence function (for aggregation) - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = (control_change - np.mean(control_change)) 
/ n_c - inf_func = np.concatenate([inf_treated, -inf_control]) + # Influence function (for aggregation) + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = (control_change - np.mean(control_change)) / n_c + inf_func = np.concatenate([inf_treated, -inf_control]) return att, se, inf_func @@ -1591,6 +1822,9 @@ def _doubly_robust( pscore_key: Optional[Any] = None, cho_cache: Optional[Dict] = None, cho_key: Optional[Any] = None, + sw_treated: Optional[np.ndarray] = None, + sw_control: Optional[np.ndarray] = None, + sw_all: Optional[np.ndarray] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using doubly robust estimation. @@ -1607,6 +1841,11 @@ def _doubly_robust( Without covariates: Reduces to simple difference in means. + + Parameters + ---------- + sw_treated, sw_control, sw_all : np.ndarray, optional + Survey weights for treated, control, and all units. """ n_t = len(treated_change) n_c = len(control_change) @@ -1614,7 +1853,7 @@ def _doubly_robust( if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: # Doubly robust estimation with covariates # Step 1: Outcome regression - fit E[Delta Y | X] on control - # Try Cholesky cache for outcome regression + # Try Cholesky cache for outcome regression (disabled when survey weights present) beta = None X_control_with_intercept = np.column_stack([np.ones(n_c), X_control]) if cho_cache is not None and cho_key is not None: @@ -1650,6 +1889,7 @@ def _doubly_robust( X_control, control_change, rank_deficient_action=self.rank_deficient_action, + weights=sw_control, ) # Zero NaN coefficients for prediction only — dropped columns # contribute 0 to the column space projection. 
Note: solve_ols @@ -1684,6 +1924,7 @@ def _doubly_robust( X_all, D, rank_deficient_action=self.rank_deficient_action, + weights=sw_all, ) _check_propensity_diagnostics(pscore, self.pscore_trim) if pscore_cache is not None and pscore_key is not None: @@ -1705,43 +1946,73 @@ def _doubly_robust( # Clip propensity scores pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) - # IPW weights for control: p(X) / (1 - p(X)) - weights_control = pscore_control / (1 - pscore_control) + if sw_treated is not None: + # IPW weights compose with survey weights + weights_control = sw_control * pscore_control / (1 - pscore_control) + + # Step 3: DR ATT (survey-weighted) + sw_t_sum = np.sum(sw_treated) + att_treated_part = float( + np.sum(sw_treated * (treated_change - m_treated)) / sw_t_sum + ) + augmentation = float( + np.sum(weights_control * (m_control - control_change)) / sw_t_sum + ) + att = att_treated_part + augmentation - # Step 3: Doubly robust ATT - # ATT = mean(treated - m(X_treated)) - # + weighted_mean_control((m(X) - Y) * weight) - att_treated_part = np.mean(treated_change - m_treated) + # Step 4: Influence function (survey-weighted DR) + psi_treated = (sw_treated / sw_t_sum) * (treated_change - m_treated - att) + psi_control = (weights_control / sw_t_sum) * (m_control - control_change) - # Augmentation term from control - augmentation = np.sum(weights_control * (m_control - control_change)) / n_t + var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 - att = att_treated_part + augmentation + inf_func = np.concatenate([psi_treated, psi_control]) + else: + # IPW weights for control: p(X) / (1 - p(X)) + weights_control = pscore_control / (1 - pscore_control) - # Step 4: Standard error using influence function - # Influence function for DR estimator - psi_treated = (treated_change - m_treated - att) / n_t - psi_control = (weights_control * (m_control - control_change)) / n_t + # Step 3: 
Doubly robust ATT + att_treated_part = float(np.mean(treated_change - m_treated)) + augmentation = float(np.sum(weights_control * (m_control - control_change)) / n_t) + att = att_treated_part + augmentation - # Variance is sum of squared influence functions - var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) - se = np.sqrt(var_psi) if var_psi > 0 else 0.0 + # Step 4: Standard error using influence function + psi_treated = (treated_change - m_treated - att) / n_t + psi_control = (weights_control * (m_control - control_change)) / n_t - # Full influence function - inf_func = np.concatenate([psi_treated, psi_control]) + var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 + + inf_func = np.concatenate([psi_treated, psi_control]) else: # Without covariates, DR simplifies to difference in means - att = np.mean(treated_change) - np.mean(control_change) + if sw_treated is not None: + sw_t_norm = sw_treated / np.sum(sw_treated) + sw_c_norm = sw_control / np.sum(sw_control) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + inf_func = np.concatenate([inf_treated, inf_control]) + + var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) + var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) + se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 + else: + att = float(np.mean(treated_change) - np.mean(control_change)) - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 - var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 - se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c 
> 0) else 0.0 - # Influence function for DR estimator - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = (control_change - np.mean(control_change)) / n_c - inf_func = np.concatenate([inf_treated, -inf_control]) + # Influence function for DR estimator + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = (control_change - np.mean(control_change)) / n_c + inf_func = np.concatenate([inf_treated, -inf_control]) return att, se, inf_func diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 7faf043f..11c06527 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -67,19 +67,24 @@ def _aggregate_simple( # Pre-treatment effects are for parallel trends, not overall ATT if t < g - self.anticipation: continue - effects.append(data['effect']) - weights_list.append(data['n_treated']) + effects.append(data["effect"]) + # Use survey_weight_sum for aggregation when available + if data.get("survey_weight_sum") is not None: + weights_list.append(data["survey_weight_sum"]) + else: + weights_list.append(data["n_treated"]) gt_pairs.append((g, t)) groups_for_gt.append(g) # Guard against empty post-treatment set if len(effects) == 0: import warnings + warnings.warn( "No post-treatment effects available for overall ATT aggregation. " "This can occur when cohorts lack post-treatment periods in the data.", UserWarning, - stacklevel=2 + stacklevel=2, ) return np.nan, np.nan @@ -92,6 +97,7 @@ def _aggregate_simple( n_nan = int(np.sum(~finite_mask)) if n_nan > 0: import warnings + warnings.warn( f"{n_nan} group-time effect(s) are NaN and excluded from overall ATT " "aggregation. Inspect group_time_effects for details.", @@ -105,6 +111,7 @@ def _aggregate_simple( if len(effects) == 0: import warnings + warnings.warn( "All post-treatment effects are NaN. 
Cannot compute overall ATT.", UserWarning, @@ -121,8 +128,14 @@ def _aggregate_simple( # Compute SE using influence function aggregation with wif adjustment overall_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights_norm, effects, groups_for_gt, - influence_func_info, df, unit, precomputed + gt_pairs, + weights_norm, + effects, + groups_for_gt, + influence_func_info, + df, + unit, + precomputed, ) return overall_att, overall_se @@ -158,13 +171,13 @@ def _compute_aggregated_se( if n_units is None: # Fallback: infer size from influence function info max_idx = 0 - for (g, t) in gt_pairs: + for g, t in gt_pairs: if (g, t) in influence_func_info: info = influence_func_info[(g, t)] - if len(info['treated_idx']) > 0: - max_idx = max(max_idx, info['treated_idx'].max()) - if len(info['control_idx']) > 0: - max_idx = max(max_idx, info['control_idx'].max()) + if len(info["treated_idx"]) > 0: + max_idx = max(max_idx, info["treated_idx"].max()) + if len(info["control_idx"]) > 0: + max_idx = max(max_idx, info["control_idx"].max()) n_units = max_idx + 1 if n_units == 0: @@ -181,16 +194,16 @@ def _compute_aggregated_se( w = weights[j] # Vectorized influence function aggregation using index arrays - treated_idx = info['treated_idx'] + treated_idx = info["treated_idx"] if len(treated_idx) > 0: - np.add.at(psi_overall, treated_idx, w * info['treated_inf']) + np.add.at(psi_overall, treated_idx, w * info["treated_inf"]) - control_idx = info['control_idx'] + control_idx = info["control_idx"] if len(control_idx) > 0: - np.add.at(psi_overall, control_idx, w * info['control_inf']) + np.add.at(psi_overall, control_idx, w * info["control_inf"]) # Compute variance: Var(θ̄) = (1/n) Σᵢ ψᵢ² - variance = np.sum(psi_overall ** 2) + variance = np.sum(psi_overall**2) return np.sqrt(variance) def _compute_combined_influence_function( @@ -228,39 +241,49 @@ def _compute_combined_influence_function( # Build unit index mapping (local or global) if global_unit_to_idx is not None and 
n_global_units is not None: - unit_to_idx = global_unit_to_idx n_units = n_global_units all_units = None # caller already has the unit list else: all_units_set: Set[Any] = set() - for (g, t) in gt_pairs: + for g, t in gt_pairs: if (g, t) in influence_func_info: info = influence_func_info[(g, t)] - all_units_set.update(info['treated_units']) - all_units_set.update(info['control_units']) + all_units_set.update(info["treated_units"]) + all_units_set.update(info["control_units"]) if not all_units_set: return np.zeros(0), [] all_units = sorted(all_units_set) n_units = len(all_units) - unit_to_idx = {u: i for i, u in enumerate(all_units)} - # Get unique groups and their information unique_groups = sorted(set(groups_for_gt)) unique_groups_set = set(unique_groups) group_to_idx = {g: i for i, g in enumerate(unique_groups)} + # Check for survey weights in precomputed data + survey_w = precomputed.get("survey_weights") if precomputed is not None else None + # Compute group-level probabilities matching R's formula: # pg[g] = n_g / n_all (fraction of ALL units in group g) + # With survey weights: pg[g] = sum(sw_g) / sum(sw_all) group_sizes = {} - for g in unique_groups: - treated_in_g = df[df['first_treat'] == g][unit].nunique() - group_sizes[g] = treated_in_g + if survey_w is not None: + # Survey-weighted group sizes + precomputed_cohorts = precomputed["unit_cohorts"] + for g in unique_groups: + mask_g = precomputed_cohorts == g + group_sizes[g] = float(np.sum(survey_w[mask_g])) + total_weight = float(np.sum(survey_w)) + else: + for g in unique_groups: + treated_in_g = df[df["first_treat"] == g][unit].nunique() + group_sizes[g] = treated_in_g + total_weight = float(n_units) # pg indexed by group - pg_by_group = np.array([group_sizes[g] / n_units for g in unique_groups]) + pg_by_group = np.array([group_sizes[g] / total_weight for g in unique_groups]) # pg indexed by keeper (each (g,t) pair gets its group's pg) pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in 
groups_for_gt]) @@ -281,13 +304,13 @@ def _compute_combined_influence_function( w = weights[j] # Vectorized influence function aggregation using precomputed index arrays - treated_idx = info['treated_idx'] + treated_idx = info["treated_idx"] if len(treated_idx) > 0: - np.add.at(psi_standard, treated_idx, w * info['treated_inf']) + np.add.at(psi_standard, treated_idx, w * info["treated_inf"]) - control_idx = info['control_idx'] + control_idx = info["control_idx"] if len(control_idx) > 0: - np.add.at(psi_standard, control_idx, w * info['control_inf']) + np.add.at(psi_standard, control_idx, w * info["control_inf"]) # Build unit-group array: normalize iterator to (idx, uid) pairs unit_groups_array = np.full(n_units, -1, dtype=np.float64) @@ -298,8 +321,8 @@ def _compute_combined_influence_function( ) if precomputed is not None: - precomputed_cohorts = precomputed['unit_cohorts'] - precomputed_unit_to_idx = precomputed['unit_to_idx'] + precomputed_cohorts = precomputed["unit_cohorts"] + precomputed_unit_to_idx = precomputed["unit_to_idx"] for idx, uid in idx_uid_pairs: if uid in precomputed_unit_to_idx: cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]] @@ -307,31 +330,58 @@ def _compute_combined_influence_function( unit_groups_array[idx] = cohort else: for idx, uid in idx_uid_pairs: - unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0] + unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0] if unit_first_treat in unique_groups_set: unit_groups_array[idx] = unit_first_treat # Vectorized WIF computation groups_for_gt_array = np.array(groups_for_gt) - indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64) - indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1) + indicator_matrix = ( + unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :] + ).astype(np.float64) + + if survey_w is not None: + # Survey-weighted WIF: indicator entries are sw_i / sum(sw_all) + # Build 
per-unit weight vector aligned to our index space + if global_unit_to_idx is not None and precomputed is not None: + unit_sw = np.zeros(n_units) + precomputed_unit_to_idx_local = precomputed["unit_to_idx"] + for idx, uid in idx_uid_pairs: + if uid in precomputed_unit_to_idx_local: + pc_idx = precomputed_unit_to_idx_local[uid] + unit_sw[idx] = survey_w[pc_idx] + else: + unit_sw = np.ones(n_units) + + # Weighted indicator: sw_i * 1{G_i == g_k} / sum(sw_all) + weighted_indicator = indicator_matrix * (unit_sw / total_weight)[:, np.newaxis] + indicator_sum_w = np.sum(weighted_indicator - pg_keepers, axis=1) + + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + if1_matrix = (weighted_indicator - pg_keepers) / sum_pg_keepers + if2_matrix = np.outer(indicator_sum_w, pg_keepers) / (sum_pg_keepers**2) + wif_matrix = if1_matrix - if2_matrix + wif_contrib = wif_matrix @ effects + else: + indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1) - with np.errstate(divide='ignore', invalid='ignore', over='ignore'): - if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers - if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2) - wif_matrix = if1_matrix - if2_matrix - wif_contrib = wif_matrix @ effects + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers + if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers**2) + wif_matrix = if1_matrix - if2_matrix + wif_contrib = wif_matrix @ effects # Check for non-finite values from edge cases if not np.all(np.isfinite(wif_contrib)): import warnings + n_nonfinite = np.sum(~np.isfinite(wif_contrib)) warnings.warn( f"Non-finite values ({n_nonfinite}/{len(wif_contrib)}) in weight influence " "function computation. This may occur with very small samples or extreme " "weights. 
Returning NaN for SE to signal invalid inference.", RuntimeWarning, - stacklevel=2 + stacklevel=2, ) nan_result = np.full(n_units, np.nan) return nan_result, all_units @@ -372,14 +422,20 @@ def _compute_aggregated_se_with_wif( global_unit_to_idx = None n_global_units = None if precomputed is not None: - global_unit_to_idx = precomputed['unit_to_idx'] - n_global_units = len(precomputed['all_units']) + global_unit_to_idx = precomputed["unit_to_idx"] + n_global_units = len(precomputed["all_units"]) elif df is not None and unit is not None: n_global_units = df[unit].nunique() psi_total, _ = self._compute_combined_influence_function( - gt_pairs, weights, effects, groups_for_gt, - influence_func_info, df, unit, precomputed, + gt_pairs, + weights, + effects, + groups_for_gt, + influence_func_info, + df, + unit, + precomputed, global_unit_to_idx=global_unit_to_idx, n_global_units=n_global_units, ) @@ -391,7 +447,7 @@ def _compute_aggregated_se_with_wif( if not np.all(np.isfinite(psi_total)): return np.nan - variance = np.sum(psi_total ** 2) + variance = np.sum(psi_total**2) return np.sqrt(variance) def _aggregate_event_study( @@ -414,41 +470,46 @@ def _aggregate_event_study( adjustment that accounts for uncertainty in group-size weights, matching R's did::aggte(..., type="dynamic"). 
""" - n_units = len(precomputed['all_units']) if precomputed is not None else None - # Organize effects by relative time, keeping track of (g,t) pairs - effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {} + effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {} for (g, t), data in group_time_effects.items(): e = t - g # Relative time if e not in effects_by_e: effects_by_e[e] = [] - effects_by_e[e].append(( - (g, t), # Keep track of the (g,t) pair - data['effect'], - data['n_treated'] - )) + # Use survey_weight_sum for aggregation when available + w = data.get("survey_weight_sum", data["n_treated"]) + effects_by_e[e].append( + ( + (g, t), # Keep track of the (g,t) pair + data["effect"], + w, + ) + ) # Balance the panel if requested if balance_e is not None: # Keep only groups that have effects at relative time balance_e groups_at_e = set() for (g, t), data in group_time_effects.items(): - if t - g == balance_e and np.isfinite(data['effect']): + if t - g == balance_e and np.isfinite(data["effect"]): groups_at_e.add(g) # Filter effects to only include balanced groups - balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {} + balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {} for (g, t), data in group_time_effects.items(): if g in groups_at_e: e = t - g if e not in balanced_effects: balanced_effects[e] = [] - balanced_effects[e].append(( - (g, t), - data['effect'], - data['n_treated'] - )) + w = data.get("survey_weight_sum", data["n_treated"]) + balanced_effects[e].append( + ( + (g, t), + data["effect"], + w, + ) + ) effects_by_e = balanced_effects # Compute aggregated effects and SEs for all relative periods @@ -479,8 +540,7 @@ def _aggregate_event_study( # Compute SE with WIF adjustment (matching R's did::aggte) groups_for_gt = np.array([g for (g, t) in gt_pairs]) agg_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights, effs, groups_for_gt, - influence_func_info, df, unit, 
precomputed + gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed ) agg_effects_list.append(agg_effect) @@ -490,32 +550,36 @@ def _aggregate_event_study( # Batch inference for all relative periods if not agg_effects_list: return {} + df_survey_val = precomputed.get("df_survey") if precomputed is not None else None t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( - np.array(agg_effects_list), np.array(agg_ses_list), alpha=self.alpha + np.array(agg_effects_list), + np.array(agg_ses_list), + alpha=self.alpha, + df=df_survey_val, ) event_study_effects = {} for idx, (e, _) in enumerate(sorted_periods): event_study_effects[e] = { - 'effect': agg_effects_list[idx], - 'se': agg_ses_list[idx], - 't_stat': float(t_stats[idx]), - 'p_value': float(p_values[idx]), - 'conf_int': (float(ci_lowers[idx]), float(ci_uppers[idx])), - 'n_groups': agg_n_groups[idx], + "effect": agg_effects_list[idx], + "se": agg_ses_list[idx], + "t_stat": float(t_stats[idx]), + "p_value": float(p_values[idx]), + "conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])), + "n_groups": agg_n_groups[idx], } # Add reference period for universal base period mode (matches R did package) - if getattr(self, 'base_period', 'varying') == "universal": + if getattr(self, "base_period", "varying") == "universal": ref_period = -1 - self.anticipation if event_study_effects and ref_period not in event_study_effects: event_study_effects[ref_period] = { - 'effect': 0.0, - 'se': np.nan, - 't_stat': np.nan, - 'p_value': np.nan, - 'conf_int': (np.nan, np.nan), - 'n_groups': 0, + "effect": 0.0, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_groups": 0, } return event_study_effects @@ -535,13 +599,13 @@ def _aggregate_by_group( Standard errors use influence function aggregation to account for covariances across time periods within a cohort. 
""" - n_units = len(precomputed['all_units']) if precomputed is not None else None + n_units = len(precomputed["all_units"]) if precomputed is not None else None # Collect all group aggregation data first group_data_list = [] for g in groups: g_effects = [ - ((g, t), data['effect']) + ((g, t), data["effect"]) for (gg, t), data in group_time_effects.items() if gg == g and t >= g - self.anticipation ] @@ -574,19 +638,23 @@ def _aggregate_by_group( # Batch inference agg_effects = np.array([x[1] for x in group_data_list]) agg_ses = np.array([x[2] for x in group_data_list]) + df_survey_val = precomputed.get("df_survey") if precomputed is not None else None t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( - agg_effects, agg_ses, alpha=self.alpha + agg_effects, + agg_ses, + alpha=self.alpha, + df=df_survey_val, ) group_effects = {} for idx, (g, agg_effect, agg_se, n_periods) in enumerate(group_data_list): group_effects[g] = { - 'effect': agg_effect, - 'se': agg_se, - 't_stat': float(t_stats[idx]), - 'p_value': float(p_values[idx]), - 'conf_int': (float(ci_lowers[idx]), float(ci_uppers[idx])), - 'n_periods': n_periods, + "effect": agg_effect, + "se": agg_se, + "t_stat": float(t_stats[idx]), + "p_value": float(p_values[idx]), + "conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])), + "n_periods": n_periods, } return group_effects diff --git a/diff_diff/staggered_results.py b/diff_diff/staggered_results.py index 53eeba1c..7f47dc31 100644 --- a/diff_diff/staggered_results.py +++ b/diff_diff/staggered_results.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -117,6 +117,8 @@ class CallawaySantAnnaResults: bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False) cband_crit_value: Optional[float] = None pscore_trim: float = 0.01 + # Survey 
design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None, repr=False) def __repr__(self) -> str: """Concise string representation.""" @@ -160,6 +162,27 @@ def summary(self, alpha: Optional[float] = None) -> str: "", ] + # Survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "-" * 85, + "Survey Design".center(85), + "-" * 85, + f"{'Weight type:':<30} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<30} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<30} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<30} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<30} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<30} {sm.df_survey:>10}") + lines.extend(["-" * 85, ""]) + # Overall ATT lines.extend( [ diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index 15f04263..da4babff 100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -493,7 +493,7 @@ def fit( ValueError If required columns are missing or data validation fails. NotImplementedError - If survey_design is used with estimation_method="ipw" or "dr". + If survey_design is used with wild_bootstrap inference. """ # Resolve survey design if provided from diff_diff.survey import ( @@ -507,13 +507,6 @@ def fit( _resolve_survey_for_fit(survey_design, data, "analytical") ) - # Guard IPW/DR with survey weights - if survey_design is not None and self.estimation_method in ("ipw", "dr"): - raise NotImplementedError( - "IPW and doubly robust methods with survey weights require " - "weighted solve_logit(), planned for Phase 5." 
- ) - # Validate inputs self._validate_data(data, outcome, group, partition, time, covariates) @@ -865,6 +858,9 @@ def _estimate_ddd_decomposition( PA4 = (sg_sub == 4).astype(float) PAa = (sg_sub == j).astype(float) + # Subset survey weights for this comparison (needed for logit) + w_sub = survey_weights[mask] if survey_weights is not None else None + # --- Propensity scores --- if est_method == "reg": # RA: no propensity scores needed @@ -878,6 +874,7 @@ def _estimate_ddd_decomposition( covX_sub[:, 1:], PA4, rank_deficient_action=self.rank_deficient_action, + weights=w_sub, ) except Exception: if self.rank_deficient_action == "error": @@ -933,9 +930,6 @@ def _estimate_ddd_decomposition( overlap_issues.append((j, frac_trimmed)) hessian = None - # Subset survey weights for this comparison - w_sub = survey_weights[mask] if survey_weights is not None else None - # --- Outcome regression --- if est_method == "ipw": # IPW: no outcome regression diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index 4b7603e3..85f5e409 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -37,12 +37,11 @@ from diff_diff.linalg import solve_ols from diff_diff.two_stage_bootstrap import TwoStageDiDBootstrapMixin from diff_diff.two_stage_results import ( - TwoStageBootstrapResults, + TwoStageBootstrapResults, # noqa: F401 TwoStageDiDResults, ) # noqa: F401 (re-export) from diff_diff.utils import safe_inference - # ============================================================================= # Main Estimator # ============================================================================= @@ -173,6 +172,7 @@ def fit( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, ) -> TwoStageDiDResults: """ Fit the two-stage DiD estimator. 
@@ -218,7 +218,32 @@ def fit( if missing: raise ValueError(f"Missing columns: {missing}") + # Create working copy df = data.copy() + + # Resolve survey design if provided + from diff_diff.survey import ( + _inject_cluster_as_psu, + _resolve_effective_cluster, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( + _resolve_survey_for_fit(survey_design, data, "analytical") + ) + + # Validate within-unit constancy for panel survey designs + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + + # Guard bootstrap + survey + if self.n_bootstrap > 0 and resolved_survey is not None: + raise NotImplementedError( + "Bootstrap inference with survey weights is not yet supported " + "for TwoStageDiD. Use analytical inference (n_bootstrap=0)." + ) + df[time] = pd.to_numeric(df[time]) df[first_treat] = pd.to_numeric(df[first_treat]) @@ -302,6 +327,26 @@ def fit( f"Available columns: {list(df.columns)}" ) + # Resolve effective cluster and inject cluster-as-PSU for survey variance + if resolved_survey is not None: + cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None + effective_cluster_ids = _resolve_effective_cluster( + resolved_survey, + cluster_ids_raw, + cluster_var if self.cluster is not None else None, + ) + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) + # Recompute metadata after PSU injection + if resolved_survey.psu is not None and survey_metadata is not None: + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) + # Relative time df["_rel_time"] = np.where( ~df["_never_treated"], @@ -311,7 +356,7 @@ def fit( # ---- Stage 1: OLS on untreated observations ---- unit_fe, 
time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model( - df, outcome, unit, time, covariates, omega_0_mask + df, outcome, unit, time, covariates, omega_0_mask, weights=survey_weights ) # ---- Rank condition checks ---- @@ -359,6 +404,9 @@ def fit( # Build design matrices and compute effects + GMM variance ref_period = -1 - self.anticipation + # Survey degrees of freedom for t-distribution inference + _survey_df = resolved_survey.df_survey if resolved_survey is not None else None + # Always compute overall ATT (static specification) overall_att, overall_se = self._stage2_static( df=df, @@ -374,9 +422,13 @@ def fit( delta_hat=delta_hat, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, ) - overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) + overall_t, overall_p, overall_ci = safe_inference( + overall_att, overall_se, alpha=self.alpha, df=_survey_df + ) # Event study and group aggregation event_study_effects = None @@ -400,6 +452,9 @@ def fit( ref_period=ref_period, balance_e=balance_e, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, + survey_df=_survey_df, ) if aggregate in ("group", "all"): @@ -418,6 +473,9 @@ def fit( cluster_var=cluster_var, treatment_groups=treatment_groups, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, + survey_df=_survey_df, ) # Build treatment effects DataFrame @@ -530,6 +588,7 @@ def fit( n_control_units=n_control_units, alpha=self.alpha, bootstrap_results=bootstrap_results, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -547,10 +606,17 @@ def _iterative_fe( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> Tuple[Dict[Any, float], Dict[Any, float]]: """ Estimate unit and time FE via iterative alternating projection. 
+ Parameters + ---------- + weights : np.ndarray, optional + Survey weights. When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. + Returns ------- unit_fe : dict @@ -562,23 +628,36 @@ def _iterative_fe( alpha = np.zeros(n) beta = np.zeros(n) + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for iteration in range(max_iter): resid_after_alpha = y - alpha - beta_new = ( - pd.Series(resid_after_alpha, index=idx) - .groupby(time_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_t = pd.Series(resid_after_alpha * weights, index=idx) + beta_new = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + beta_new = ( + pd.Series(resid_after_alpha, index=idx) + .groupby(time_vals) + .transform("mean") + .values + ) resid_after_beta = y - beta_new - alpha_new = ( - pd.Series(resid_after_beta, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(resid_after_beta * weights, index=idx) + alpha_new = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + alpha_new = ( + pd.Series(resid_after_beta, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) max_change = max( np.max(np.abs(alpha_new - alpha)), @@ -601,21 +680,43 @@ def _iterative_demean( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> np.ndarray: - """Demean a vector by iterative alternating projection.""" + """Demean a vector by iterative alternating projection (unit + time FE removal). + + Parameters + ---------- + weights : np.ndarray, optional + Survey weights. When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. 
+ """ result = vals.copy() + + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for _ in range(max_iter): - time_means = ( - pd.Series(result, index=idx).groupby(time_vals).transform("mean").values - ) + if weights is not None: + wr_t = pd.Series(result * weights, index=idx) + time_means = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + time_means = ( + pd.Series(result, index=idx).groupby(time_vals).transform("mean").values + ) result_after_time = result - time_means - unit_means = ( - pd.Series(result_after_time, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(result_after_time * weights, index=idx) + unit_means = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + unit_means = ( + pd.Series(result_after_time, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) result_new = result_after_time - unit_means if np.max(np.abs(result_new - result)) < tol: result = result_new @@ -631,22 +732,30 @@ def _fit_untreated_model( time: str, covariates: Optional[List[str]], omega_0_mask: pd.Series, + weights: Optional[np.ndarray] = None, ) -> Tuple[ Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray] ]: """ Stage 1: Estimate unit + time FE on untreated observations. + Parameters + ---------- + weights : np.ndarray, optional + Full-panel survey weights (same length as df). The untreated subset + is extracted internally via omega_0_mask. When None, unweighted. 
+ Returns ------- unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask """ df_0 = df.loc[omega_0_mask] + w_0 = weights[omega_0_mask.values] if weights is not None else None if covariates is None or len(covariates) == 0: y = df_0[outcome].values.copy() unit_fe, time_fe = self._iterative_fe( - y, df_0[unit].values, df_0[time].values, df_0.index + y, df_0[unit].values, df_0[time].values, df_0.index, weights=w_0 ) return unit_fe, time_fe, 0.0, None, None @@ -657,10 +766,10 @@ def _fit_untreated_model( times = df_0[time].values n_cov = len(covariates) - y_dm = self._iterative_demean(y, units, times, df_0.index) + y_dm = self._iterative_demean(y, units, times, df_0.index, weights=w_0) X_dm = np.column_stack( [ - self._iterative_demean(X_raw[:, j], units, times, df_0.index) + self._iterative_demean(X_raw[:, j], units, times, df_0.index, weights=w_0) for j in range(n_cov) ] ) @@ -671,13 +780,14 @@ def _fit_untreated_model( return_vcov=False, rank_deficient_action=self.rank_deficient_action, column_names=covariates, + weights=w_0, ) delta_hat = result[0] kept_cov_mask = np.isfinite(delta_hat) delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0) y_adj = y - np.dot(X_raw, delta_hat_clean) - unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) + unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index, weights=w_0) return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask @@ -736,6 +846,8 @@ def _stage2_static( delta_hat: Optional[np.ndarray], cluster_var: str, kept_cov_mask: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + survey_weight_type: str = "pweight", ) -> Tuple[float, float]: """ Static (simple ATT) Stage 2: OLS of y_tilde on D_it. 
@@ -763,7 +875,13 @@ def _stage2_static( return np.nan, np.nan # Stage 2 OLS for point estimate (discard naive SE) - coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + coef, residuals, _ = solve_ols( + X_2, + y_tilde, + return_vcov=False, + weights=survey_weights, + weight_type=survey_weight_type, + ) att = float(coef[0]) # GMM sandwich variance @@ -782,6 +900,7 @@ def _stage2_static( X_2=X_2, eps_2=eps_2, cluster_ids=df[cluster_var].values, + survey_weights=survey_weights, ) se = float(np.sqrt(max(V[0, 0], 0.0))) @@ -805,6 +924,9 @@ def _stage2_event_study( ref_period: int, balance_e: Optional[int], kept_cov_mask: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + survey_weight_type: str = "pweight", + survey_df: Optional[int] = None, ) -> Dict[int, Dict[str, Any]]: """Event study Stage 2: OLS of y_tilde on relative-time dummies.""" y_tilde = df["_y_tilde"].values.copy() @@ -921,7 +1043,13 @@ def _stage2_event_study( X_2[i, horizon_to_col[h_int]] = 1.0 # Stage 2 OLS - coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + coef, residuals, _ = solve_ols( + X_2, + y_tilde, + return_vcov=False, + weights=survey_weights, + weight_type=survey_weight_type, + ) eps_2 = y_tilde - np.dot(X_2, coef) # GMM variance for full coefficient vector @@ -938,6 +1066,7 @@ def _stage2_event_study( X_2=X_2, eps_2=eps_2, cluster_ids=df[cluster_var].values, + survey_weights=survey_weights, ) # Build results dict @@ -971,7 +1100,7 @@ def _stage2_event_study( effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha, df=survey_df) event_study_effects[h] = { "effect": effect, @@ -1011,6 +1140,9 @@ def _stage2_group( cluster_var: str, treatment_groups: List[Any], kept_cov_mask: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + survey_weight_type: str = "pweight", + survey_df: Optional[int] 
= None, ) -> Dict[Any, Dict[str, Any]]: """Group (cohort) Stage 2: OLS of y_tilde on cohort dummies.""" y_tilde = df["_y_tilde"].values.copy() @@ -1033,7 +1165,13 @@ def _stage2_group( X_2[i, group_to_col[g]] = 1.0 # Stage 2 OLS - coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + coef, residuals, _ = solve_ols( + X_2, + y_tilde, + return_vcov=False, + weights=survey_weights, + weight_type=survey_weight_type, + ) eps_2 = y_tilde - np.dot(X_2, coef) # GMM variance @@ -1050,6 +1188,7 @@ def _stage2_group( X_2=X_2, eps_2=eps_2, cluster_ids=df[cluster_var].values, + survey_weights=survey_weights, ) group_effects: Dict[Any, Dict[str, Any]] = {} @@ -1071,7 +1210,7 @@ def _stage2_group( effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha, df=survey_df) group_effects[g] = { "effect": effect, @@ -1137,6 +1276,7 @@ def _compute_gmm_variance( X_2: np.ndarray, eps_2: np.ndarray, cluster_ids: np.ndarray, + survey_weights: Optional[np.ndarray] = None, ) -> np.ndarray: """ Compute GMM sandwich variance (Butts & Gardner 2022). @@ -1155,6 +1295,12 @@ def _compute_gmm_variance( S_g = gamma_hat' c_g - X'_{2g} eps_{2g} c_g = X'_{10g} eps_{10g} + With survey weights W (diagonal): + Bread: (X'_2 W X_2)^{-1} + gamma_hat: (X'_{10} W X_{10})^{-1} (X'_1 W X_2) + c_g = sum_{i in g} w_i * x_{10i} * eps_{10i} + s2_g = sum_{i in g} w_i * x_{2i} * eps_{2i} + Parameters ---------- X_2 : np.ndarray, shape (n, k) @@ -1163,6 +1309,9 @@ def _compute_gmm_variance( Stage 2 residuals. cluster_ids : np.ndarray, shape (n,) Cluster identifiers. + survey_weights : np.ndarray, optional + Survey weights of shape (n,). When None, unweighted (identical + to current code). 
Returns ------- @@ -1207,27 +1356,37 @@ def _compute_gmm_variance( eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0] # Stage 1 residual eps_10[~omega_0] = y_vals[~omega_0] # x_{10i} = 0, so eps_10 = Y - # 1. gamma_hat = (X'_{10} X_{10})^{-1} (X'_1 X_2) [p x k] - XtX_10 = X_10_sparse.T @ X_10_sparse # (p x p) sparse - Xt1_X2 = X_1_sparse.T @ X_2 # (p x k) dense + # 1. gamma_hat = (X'_{10} W X_{10})^{-1} (X'_1 W X_2) [p x k] + # With survey weights, both cross-products need W + if survey_weights is not None: + XtWX_10 = X_10_sparse.T @ X_10_sparse.multiply(survey_weights[:, None]) + Xt1_WX2 = X_1_sparse.T @ (X_2 * survey_weights[:, None]) + else: + XtWX_10 = X_10_sparse.T @ X_10_sparse # (p x p) sparse + Xt1_WX2 = X_1_sparse.T @ X_2 # (p x k) dense try: - solve_XtX = sparse_factorized(XtX_10.tocsc()) - if Xt1_X2.ndim == 1: - gamma_hat = solve_XtX(Xt1_X2).reshape(-1, 1) + solve_XtX = sparse_factorized(XtWX_10.tocsc()) + if Xt1_WX2.ndim == 1: + gamma_hat = solve_XtX(Xt1_WX2).reshape(-1, 1) else: gamma_hat = np.column_stack( - [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])] + [solve_XtX(Xt1_WX2[:, j]) for j in range(Xt1_WX2.shape[1])] ) except RuntimeError: # Singular matrix — fall back to dense least-squares - gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0] + gamma_hat = np.linalg.lstsq(XtWX_10.toarray(), Xt1_WX2, rcond=None)[0] if gamma_hat.ndim == 1: gamma_hat = gamma_hat.reshape(-1, 1) - # 2. Per-cluster Stage 1 scores: c_g = X'_{10g} eps_{10g} + # 2. 
Per-cluster Stage 1 scores: c_g = sum_{i in g} w_i * x_{10i} * eps_{10i} # Only untreated obs have non-zero X_10 rows - weighted_X10 = X_10_sparse.multiply(eps_10[:, None]) # sparse element-wise + # With survey weights: multiply eps_10 by survey_weights before sparse multiply + if survey_weights is not None: + weighted_eps_10 = survey_weights * eps_10 + else: + weighted_eps_10 = eps_10 + weighted_X10 = X_10_sparse.multiply(weighted_eps_10[:, None]) # sparse element-wise unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True) G = len(unique_clusters) @@ -1246,8 +1405,12 @@ def _compute_gmm_variance( for j_col in range(p): np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col]) - # 3. Per-cluster Stage 2 scores: X'_{2g} eps_{2g} - weighted_X2 = X_2 * eps_2[:, None] # (n x k) dense + # 3. Per-cluster Stage 2 scores: s2_g = sum_{i in g} w_i * x_{2i} * eps_{2i} + if survey_weights is not None: + weighted_eps_2 = survey_weights * eps_2 + else: + weighted_eps_2 = eps_2 + weighted_X2 = X_2 * weighted_eps_2[:, None] # (n x k) dense s2_by_cluster = np.zeros((G, k)) for j_col in range(k): np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col]) @@ -1259,13 +1422,16 @@ def _compute_gmm_variance( with np.errstate(invalid="ignore", over="ignore"): meat = S.T @ S # (k x k) - # 6. Bread: (X'_2 X_2)^{-1} + # 6. Bread: (X'_2 W X_2)^{-1} with np.errstate(invalid="ignore", over="ignore", divide="ignore"): - XtX_2 = X_2.T @ X_2 + if survey_weights is not None: + XtWX_2 = X_2.T @ (X_2 * survey_weights[:, None]) + else: + XtWX_2 = X_2.T @ X_2 try: - bread = np.linalg.solve(XtX_2, np.eye(k)) + bread = np.linalg.solve(XtWX_2, np.eye(k)) except np.linalg.LinAlgError: - bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0] + bread = np.linalg.lstsq(XtWX_2, np.eye(k), rcond=None)[0] # 7. 
V = bread @ meat @ bread V = bread @ meat @ bread diff --git a/diff_diff/two_stage_results.py b/diff_diff/two_stage_results.py index 16b4916a..b06cd000 100644 --- a/diff_diff/two_stage_results.py +++ b/diff_diff/two_stage_results.py @@ -137,6 +137,8 @@ class TwoStageDiDResults: n_control_units: int alpha: float = 0.05 bootstrap_results: Optional[TwoStageBootstrapResults] = field(default=None, repr=False) + # Survey design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None, repr=False) def __repr__(self) -> str: """Concise string representation.""" @@ -180,6 +182,27 @@ def summary(self, alpha: Optional[float] = None) -> str: "", ] + # Survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "-" * 85, + "Survey Design".center(85), + "-" * 85, + f"{'Weight type:':<30} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<30} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<30} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<30} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<30} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<30} {sm.df_survey:>10}") + lines.extend(["-" * 85, ""]) + # Overall ATT lines.extend( [ diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 9c5e9c0a..03cf5958 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -416,6 +416,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) +- **Note:** CallawaySantAnna survey weights compose with IPW weights multiplicatively: w_total = w_survey * p(X) / (1-p(X)). 
Propensity scores estimated via survey-weighted solve_logit(). WIF in aggregation uses survey-weighted group sizes. Bootstrap + survey deferred. **Reference implementation(s):** - R: `did::att_gt()` (Callaway & Sant'Anna's official package) @@ -840,6 +841,7 @@ Y_it = alpha_i + beta_t [+ X'_it * delta] + W'_it * gamma + epsilon_it - **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand. - **Rank-deficient covariates in variance:** Covariates with NaN coefficients (dropped for rank deficiency in Step 1) are excluded from the variance design matrices `A_0`/`A_1`. Only covariates with finite coefficients participate in the `v_it` projection. - **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails. +- **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and survey-weighted conservative variance (Theorem 3). Survey df used for t-distribution inference. Bootstrap + survey deferred. - **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns. - **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean. 
@@ -917,6 +919,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus - **No never-treated units (Proposition 5):** When there are no never-treated units and multiple treatment cohorts, horizons h >= h_bar (where h_bar = max(groups) - min(groups)) are unidentified per Proposition 5 of Borusyak et al. (2024). These produce NaN inference with n_obs > 0 (treated observations exist but counterfactual is unidentified) and a warning listing affected horizons. Matches ImputationDiD behavior. Proposition 5 applies to event study horizons only, not cohort aggregation — a cohort whose treated obs all fall at Prop 5 horizons naturally gets n_obs=0 in group effects because all its y_tilde values are NaN. - **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some non-Prop-5 event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0. - **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0. +- **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. Survey df used for t-distribution. Bootstrap + survey deferred. **Reference implementation(s):** - R: `did2s::did2s()` (Kyle Butts & John Gardner) diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index e59fd211..20b4df26 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -1,7 +1,7 @@ # Survey Data Support Roadmap This document captures planned future work for survey data support in diff-diff. -Phases 1-3 are implemented. Phases 4-5 are deferred for future PRs. +Phases 1-4 are implemented. Phase 5 is deferred for future PRs. 
## Implemented (Phases 1-2) @@ -21,19 +21,18 @@ Phases 1-3 are implemented. Phases 4-5 are deferred for future PRs. | StackedDiD | `stacked_did.py` | pweight only | Q-weights compose multiplicatively with survey weights; TSL vcov on composed weights; fweight/aweight rejected (composition changes weight semantics) | | SunAbraham | `sun_abraham.py` | Full | Survey weights in LinearRegression + weighted within-transform; bootstrap+survey deferred | | BaconDecomposition | `bacon.py` | Diagnostic | Weighted cell means, weighted within-transform, weighted group shares; no inference (diagnostic only) | -| TripleDifference | `triple_diff.py` | Reg only | Regression method with weighted OLS + TSL on influence functions; IPW/DR deferred (needs weighted `solve_logit()`) | +| TripleDifference | `triple_diff.py` | Full | Regression, IPW, and DR methods with weighted OLS/logit + TSL on influence functions | | ContinuousDiD | `continuous_did.py` | Analytical | Weighted B-spline OLS + TSL on influence functions; bootstrap+survey deferred | | EfficientDiD | `efficient_did.py` | Analytical | Weighted means/covariances in Omega* + TSL on EIF scores; bootstrap+survey deferred | ### Phase 3 Deferred Work The following capabilities were deferred from Phase 3 because they depend on -Phase 5 infrastructure (weighted `solve_logit()` or bootstrap+survey interaction): +Phase 5 infrastructure (bootstrap+survey interaction): | Estimator | Deferred Capability | Blocker | |-----------|-------------------|---------| | SunAbraham | Pairs bootstrap + survey | Phase 5: bootstrap+survey interaction | -| TripleDifference | IPW and DR methods + survey | Phase 5: weighted `solve_logit()` | | ContinuousDiD | Multiplier bootstrap + survey | Phase 5: bootstrap+survey interaction | | EfficientDiD | Multiplier bootstrap + survey | Phase 5: bootstrap+survey interaction | | EfficientDiD | Covariates (DR path) + survey | DR nuisance estimation needs survey weight threading | @@ -41,19 +40,43 @@ Phase 5 
infrastructure (weighted `solve_logit()` or bootstrap+survey interaction All blocked combinations raise `NotImplementedError` when attempted, with a message pointing to the planned phase or describing the limitation. -## Phase 4: Complex Standalone Estimators +## Implemented (Phase 4): Complex Standalone Estimators + Weighted Logit -These require more substantial changes beyond threading weights. +| Estimator | File | Survey Support | Notes | +|-----------|------|----------------|-------| +| ImputationDiD | `imputation.py` | Analytical | Weighted iterative FE, weighted ATT aggregation, weighted conservative variance (Theorem 3); bootstrap+survey deferred | +| TwoStageDiD | `two_stage.py` | Analytical | Weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance; bootstrap+survey deferred | +| CallawaySantAnna | `staggered.py` | Full (analytical) | Survey-weighted regression, IPW (via weighted `solve_logit()`), and DR; survey weights compose with IPW weights multiplicatively; survey-weighted WIF in aggregation; bootstrap+survey deferred | + +**Infrastructure**: Weighted `solve_logit()` added to `linalg.py` — survey weights +enter the IRLS working weights as `w_survey * mu * (1 - mu)`. This also unblocked +TripleDifference IPW/DR from Phase 3 deferred work. 
+ +### Phase 4 Deferred Work + +| Estimator | Deferred Capability | Blocker | +|-----------|-------------------|---------| +| ImputationDiD | Bootstrap + survey | Phase 5: bootstrap+survey interaction | +| TwoStageDiD | Bootstrap + survey | Phase 5: bootstrap+survey interaction | +| CallawaySantAnna | Bootstrap + survey | Phase 5: bootstrap+survey interaction | + +### Remaining for Phase 5 | Estimator | File | Complexity | Notes | |-----------|------|------------|-------| -| ImputationDiD | `imputation.py` | Medium | Weighted Stage 1 OLS, weighted ATT aggregation, weighted conservative variance | -| TwoStageDiD | `two_stage.py` | Medium-High | Weighted within-transformation, weighted GMM sandwich | -| CallawaySantAnna | `staggered.py` | High | Weights enter propensity score (weighted solve_logit), IPW/DR reweighting (w_survey x w_ipw), multiplier bootstrap | -| SyntheticDiD | `synthetic_did.py` | Medium | Survey-weighted treated mean in optimization, weighted placebo variance | -| TROP | `trop.py` | Medium | Survey weights in ATT aggregation and LOOCV (distinct from TROP's internal weights) | +| SyntheticDiD | `synthetic_did.py` | Medium | Survey-weighted treated mean in optimization, weighted placebo variance (bootstrap-based SE only) | +| TROP | `trop.py` | Medium | Survey weights in ATT aggregation and LOOCV (bootstrap-based SE only) | + +## Phase 5: Advanced Features + Remaining Estimators -## Phase 5: Advanced Features +### SyntheticDiD and TROP Survey Support +Both estimators use bootstrap/placebo for SE with no analytical variance path. +Phase 5 provides survey-weighted point estimates and survey-aware bootstrap SE. + +### Bootstrap + Survey Interaction +Unblock bootstrap + survey for all estimators that currently defer it +(ImputationDiD, TwoStageDiD, CallawaySantAnna, SunAbraham, ContinuousDiD, +EfficientDiD). Requires survey-aware resampling schemes. 
### Replicate Weight Variance Re-run WLS for each replicate weight column, compute variance from distribution @@ -64,16 +87,7 @@ Add `replicate_weights`, `replicate_type`, `replicate_rho` fields to SurveyDesig Compare survey vcov to SRS vcov element-wise. Report design effect per coefficient. Effective n = n / DEFF. -### Wild Bootstrap with Survey Weights -Add `survey_weights` parameter to `wild_bootstrap_se()`. Requires careful -interaction between bootstrap resampling and survey weight structure. - ### Subpopulation Analysis `SurveyDesign.subpopulation(data, mask)` — zero-out weights for excluded observations while preserving the full design structure for correct variance estimation (unlike simple subsetting, which would drop design information). - -### Weighted `solve_logit()` -Add weights to the IRLS iteration in `solve_logit()`. Required by -CallawaySantAnna (propensity score estimation) and TripleDifference (IPW method). -Working weights become `w_survey * mu * (1 - mu)`. diff --git a/tests/test_survey_phase3.py b/tests/test_survey_phase3.py index 3248fa03..3cddfbcd 100644 --- a/tests/test_survey_phase3.py +++ b/tests/test_survey_phase3.py @@ -496,35 +496,37 @@ def test_smoke_reg_method(self, ddd_survey_data): assert np.isfinite(result.se) assert result.survey_metadata is not None - def test_ipw_survey_raises(self, ddd_survey_data): - """IPW + survey should raise NotImplementedError.""" + def test_ipw_survey_works(self, ddd_survey_data): + """IPW + survey now works (unblocked by weighted solve_logit in Phase 4).""" from diff_diff import TripleDifference sd = SurveyDesign(weights="weight") - with pytest.raises(NotImplementedError, match="IPW"): - TripleDifference(estimation_method="ipw").fit( - ddd_survey_data, - "outcome", - "group", - "partition", - "time", - survey_design=sd, - ) + result = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert 
np.isfinite(result.att) + assert np.isfinite(result.se) - def test_dr_survey_raises(self, ddd_survey_data): - """DR + survey should raise NotImplementedError.""" + def test_dr_survey_works(self, ddd_survey_data): + """DR + survey now works (unblocked by weighted solve_logit in Phase 4).""" from diff_diff import TripleDifference sd = SurveyDesign(weights="weight") - with pytest.raises(NotImplementedError, match="doubly robust"): - TripleDifference(estimation_method="dr").fit( - ddd_survey_data, - "outcome", - "group", - "partition", - "time", - survey_design=sd, - ) + result = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) def test_weighted_changes_att(self, ddd_survey_data): """Survey weights should change ATT.""" diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py new file mode 100644 index 00000000..80caf6b7 --- /dev/null +++ b/tests/test_survey_phase4.py @@ -0,0 +1,803 @@ +"""Tests for Phase 4 survey support: complex standalone estimators. + +Covers: ImputationDiD, TwoStageDiD, CallawaySantAnna, weighted solve_logit(), +TripleDifference IPW/DR unblock, and cross-estimator scale invariance. +""" + +import numpy as np +import pandas as pd +import pytest + +from diff_diff import ( + CallawaySantAnna, + ImputationDiD, + SurveyDesign, + TripleDifference, + TwoStageDiD, + generate_staggered_data, +) +from diff_diff.linalg import solve_logit + +# ============================================================================= +# Shared Fixtures +# ============================================================================= + + +@pytest.fixture +def staggered_survey_data(): + """Staggered treatment panel with survey design columns. + + 200 units via generate_staggered_data, then add unit-level survey columns + (weights, stratum, psu, fpc) that are constant within each unit. 
+    """
+    data = generate_staggered_data(n_units=200, seed=42)
+
+    # Add unit-level survey columns (constant within unit)
+    unit_ids = data["unit"].unique()
+    n_units = len(unit_ids)
+    # survey columns below are deterministic in the unit index; no RNG needed
+
+    unit_weight = 1.0 + 0.5 * (np.arange(n_units) % 5)
+    unit_stratum = np.arange(n_units) // 40  # 5 strata
+    unit_psu = np.arange(n_units) // 10  # 20 PSUs
+    unit_fpc = np.full(n_units, 400.0)  # population per stratum
+
+    unit_map = {uid: i for i, uid in enumerate(unit_ids)}
+    idx = data["unit"].map(unit_map).values
+
+    data["weight"] = unit_weight[idx]
+    data["stratum"] = unit_stratum[idx]
+    data["psu"] = unit_psu[idx]
+    data["fpc"] = unit_fpc[idx]
+
+    return data
+
+
+@pytest.fixture
+def survey_design_weights_only():
+    """SurveyDesign with weights only."""
+    return SurveyDesign(weights="weight")
+
+
+@pytest.fixture
+def survey_design_full():
+    """SurveyDesign with weights, strata, psu."""
+    return SurveyDesign(weights="weight", strata="stratum", psu="psu")
+
+
+@pytest.fixture
+def ddd_survey_data():
+    """Cross-sectional DDD data with survey columns for TripleDifference."""
+    np.random.seed(42)
+    n = 400
+    data = pd.DataFrame(
+        {
+            "outcome": np.random.randn(n) + 0.5,
+            "group": np.random.choice([0, 1], n),
+            "partition": np.random.choice([0, 1], n),
+            "time": np.random.choice([0, 1], n),
+            "weight": np.random.uniform(0.5, 2.0, n),
+            "stratum": np.random.choice([1, 2, 3], n),
+        }
+    )
+    # Add treatment effect for treated+eligible+post
+    mask = (data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 1)
+    data.loc[mask, "outcome"] += 1.5
+    return data
+
+
+# =============================================================================
+# TestWeightedSolveLogit
+# =============================================================================
+
+
+class TestWeightedSolveLogit:
+    """Tests for weighted solve_logit() (IRLS with survey weights)."""
+
+    def test_uniform_weights_match_unweighted(self):
+        """Uniform weights should produce same coefficients
as unweighted.""" + rng = np.random.RandomState(123) + n = 200 + X = rng.randn(n, 2) + y = (X @ [1.0, -0.5] + rng.randn(n) > 0).astype(float) + + beta_unw, probs_unw = solve_logit(X, y) + beta_w, probs_w = solve_logit(X, y, weights=np.ones(n)) + + np.testing.assert_allclose(beta_unw, beta_w, atol=1e-10) + np.testing.assert_allclose(probs_unw, probs_w, atol=1e-10) + + def test_convergence_with_weights(self): + """solve_logit converges with non-uniform survey weights.""" + rng = np.random.RandomState(42) + n = 200 + X = rng.randn(n, 2) + y = (X @ [0.5, -0.5] + rng.randn(n) * 1.5 > 0).astype(float) + weights = rng.uniform(0.5, 3.0, n) + + beta, probs = solve_logit(X, y, weights=weights) + + # Should produce finite coefficients (convergence) + assert np.all(np.isfinite(beta)) + assert np.all(np.isfinite(probs)) + assert np.all(probs > 0) and np.all(probs < 1) + + def test_separation_detection_with_weights(self): + """Separation warning should still fire with survey weights.""" + rng = np.random.RandomState(789) + n = 100 + # Create near-separation: x1 perfectly predicts y + x1 = np.linspace(-5, 5, n) + X = x1.reshape(-1, 1) + y = (x1 > 0).astype(float) + weights = rng.uniform(1.0, 3.0, n) + + with pytest.warns(UserWarning): + beta, probs = solve_logit(X, y, weights=weights) + + assert np.all(np.isfinite(beta)) + + def test_known_answer_small_dataset(self): + """Manual check on a small dataset — weights should shift coefficients.""" + rng = np.random.RandomState(333) + n = 50 + X = rng.randn(n, 1) + prob = 1.0 / (1.0 + np.exp(-(0.5 + 1.0 * X[:, 0]))) + y = (rng.rand(n) < prob).astype(float) + + # Unweighted fit + beta_unw, _ = solve_logit(X, y) + + # Non-uniform weights: upweight observations where y=1 + weights = np.where(y == 1, 5.0, 1.0) + beta_w, _ = solve_logit(X, y, weights=weights) + + # Weighted fit should shift intercept (upweighting y=1 shifts boundary) + assert beta_w[0] != pytest.approx(beta_unw[0], abs=0.1) + + def test_rank_deficiency_with_weights(self): + 
"""Rank-deficient columns should still be detected with weights.""" + rng = np.random.RandomState(111) + n = 100 + x1 = rng.randn(n) + # x2 is a perfect linear combination of x1 + X = np.column_stack([x1, 2.0 * x1]) + y = (x1 > 0).astype(float) + weights = rng.uniform(0.5, 2.0, n) + + with pytest.warns(UserWarning, match="[Rr]ank"): + beta, probs = solve_logit(X, y, weights=weights) + + assert np.all(np.isfinite(probs)) + + def test_weight_scale_invariance(self): + """Multiplying weights by a constant should not change beta.""" + rng = np.random.RandomState(222) + n = 200 + X = rng.randn(n, 2) + y = (X @ [0.8, -0.3] + rng.randn(n) > 0).astype(float) + weights = rng.uniform(1.0, 4.0, n) + + beta1, _ = solve_logit(X, y, weights=weights) + beta2, _ = solve_logit(X, y, weights=weights * 2.0) + + np.testing.assert_allclose(beta1, beta2, atol=1e-10) + + +# ============================================================================= +# TestImputationDiDSurvey +# ============================================================================= + + +class TestImputationDiDSurvey: + """Survey design support for ImputationDiD.""" + + def test_smoke_weights_only(self, staggered_survey_data, survey_design_weights_only): + """ImputationDiD runs with weights-only survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_uniform_weights_match_unweighted(self, staggered_survey_data): + """Uniform survey weights should match unweighted result.""" + staggered_survey_data["uniform_w"] = 1.0 + sd = SurveyDesign(weights="uniform_w") + + r_unw = ImputationDiD().fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", 
+ survey_design=sd, + ) + assert abs(r_unw.overall_att - r_w.overall_att) < 1e-10 + + def test_survey_metadata_fields(self, staggered_survey_data, survey_design_full): + """survey_metadata has correct fields with full design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_full, + ) + sm = result.survey_metadata + assert sm is not None + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + assert sm.design_effect > 0 + assert sm.n_strata is not None + assert sm.n_psu is not None + + def test_se_differs_with_design(self, staggered_survey_data): + """Weights-only vs full design: same ATT, different inference via survey df.""" + sd_w = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + + r_w = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_w, + ) + r_full = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_full, + ) + # ATTs should be the same (same weights) + assert abs(r_w.overall_att - r_full.overall_att) < 1e-10 + # Full design should carry survey df (strata/PSU structure) + assert r_full.survey_metadata is not None + assert r_full.survey_metadata.n_strata is not None + assert r_full.survey_metadata.n_psu is not None + # P-values should differ due to t-distribution with survey df + if np.isfinite(r_w.overall_p_value) and np.isfinite(r_full.overall_p_value): + assert r_w.overall_p_value != r_full.overall_p_value + + def test_weighted_att_differs(self, staggered_survey_data, survey_design_weights_only): + """Non-uniform survey weights should change the overall ATT.""" + r_unw = ImputationDiD().fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + 
survey_design=survey_design_weights_only, + ) + # ATT should differ because non-uniform weights change aggregation + assert r_unw.overall_att != r_w.overall_att + + def test_event_study_with_survey(self, staggered_survey_data, survey_design_weights_only): + """Event study effects exist when using survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + survey_design=survey_design_weights_only, + ) + assert result.event_study_effects is not None + assert len(result.event_study_effects) > 0 + for h, eff in result.event_study_effects.items(): + assert np.isfinite(eff["effect"]) + assert np.isfinite(eff["se"]) + + def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): + ImputationDiD(n_bootstrap=99).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + + def test_summary_includes_survey(self, staggered_survey_data, survey_design_weights_only): + """Summary output should include survey design section.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + summary = result.summary() + assert "Survey Design" in summary + assert "pweight" in summary + + +# ============================================================================= +# TestTwoStageDiDSurvey +# ============================================================================= + + +class TestTwoStageDiDSurvey: + """Survey design support for TwoStageDiD.""" + + def test_smoke_weights_only(self, staggered_survey_data, survey_design_weights_only): + """TwoStageDiD runs with weights-only survey design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + 
"unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_uniform_weights_match_unweighted(self, staggered_survey_data): + """Uniform survey weights should match unweighted result.""" + staggered_survey_data["uniform_w"] = 1.0 + sd = SurveyDesign(weights="uniform_w") + + r_unw = TwoStageDiD().fit(staggered_survey_data, "outcome", "unit", "period", "first_treat") + r_w = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert abs(r_unw.overall_att - r_w.overall_att) < 1e-10 + + def test_survey_metadata_fields(self, staggered_survey_data, survey_design_full): + """survey_metadata has correct fields with full design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_full, + ) + sm = result.survey_metadata + assert sm is not None + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + assert sm.design_effect > 0 + assert sm.n_strata is not None + assert sm.n_psu is not None + + def test_se_differs_with_design(self, staggered_survey_data): + """Weights-only vs full design: same ATT, different inference via survey df.""" + sd_w = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + + r_w = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_w, + ) + r_full = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_full, + ) + # ATTs should be the same (same weights) + assert abs(r_w.overall_att - r_full.overall_att) < 1e-10 + # Full design should carry survey df (strata/PSU structure) + assert r_full.survey_metadata is not None + assert r_full.survey_metadata.n_strata 
is not None + assert r_full.survey_metadata.n_psu is not None + # P-values should differ due to t-distribution with survey df + if np.isfinite(r_w.overall_p_value) and np.isfinite(r_full.overall_p_value): + assert r_w.overall_p_value != r_full.overall_p_value + + def test_weighted_gmm_variance(self, staggered_survey_data, survey_design_weights_only): + """GMM SE should differ from unweighted (weights affect sandwich).""" + r_unw = TwoStageDiD().fit(staggered_survey_data, "outcome", "unit", "period", "first_treat") + r_w = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + # SE magnitude should differ (not just sign) + assert abs(r_unw.overall_se - r_w.overall_se) > 1e-6 + + def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): + TwoStageDiD(n_bootstrap=99).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + + def test_summary_includes_survey(self, staggered_survey_data, survey_design_weights_only): + """Summary output should include survey design section.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + summary = result.summary() + assert "Survey Design" in summary + assert "pweight" in summary + + +# ============================================================================= +# TestCallawaySantAnnaSurvey +# ============================================================================= + + +class TestCallawaySantAnnaSurvey: + """Survey design support for CallawaySantAnna.""" + + def test_smoke_reg_weights_only(self, staggered_survey_data, survey_design_weights_only): + """CallawaySantAnna regression method works with survey 
design.""" + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_smoke_ipw_weights_only(self, staggered_survey_data, survey_design_weights_only): + """CallawaySantAnna IPW method works with survey design.""" + result = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_smoke_dr_weights_only(self, staggered_survey_data, survey_design_weights_only): + """CallawaySantAnna DR method works with survey design.""" + result = CallawaySantAnna(estimation_method="dr").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_uniform_weights_match_unweighted(self, staggered_survey_data): + """Uniform survey weights should match unweighted result — all methods.""" + staggered_survey_data["uniform_w"] = 1.0 + sd = SurveyDesign(weights="uniform_w") + + for method in ["reg", "ipw", "dr"]: + r_unw = CallawaySantAnna(estimation_method=method).fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = CallawaySantAnna(estimation_method=method).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert abs(r_unw.overall_att - r_w.overall_att) < 1e-8, f"method={method}: ATT mismatch" + + def test_survey_metadata_fields(self, staggered_survey_data, survey_design_full): + 
"""survey_metadata has correct fields with full design.""" + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_full, + ) + sm = result.survey_metadata + assert sm is not None + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + assert sm.design_effect > 0 + assert sm.n_strata is not None + assert sm.n_psu is not None + + def test_se_differs_with_design(self, staggered_survey_data): + """Weights-only vs full design: same ATT, different inference via survey df.""" + sd_w = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + + r_w = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_w, + ) + r_full = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_full, + ) + # ATTs should be the same (same weights) + assert abs(r_w.overall_att - r_full.overall_att) < 1e-10 + # Full design should carry survey df (strata/PSU structure) + assert r_full.survey_metadata is not None + assert r_full.survey_metadata.n_strata is not None + assert r_full.survey_metadata.n_psu is not None + # P-values should differ due to t-distribution with survey df + if np.isfinite(r_w.overall_p_value) and np.isfinite(r_full.overall_p_value): + assert r_w.overall_p_value != r_full.overall_p_value + + def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): + CallawaySantAnna(estimation_method="reg", n_bootstrap=99).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + + def test_weighted_logit(self, 
staggered_survey_data, survey_design_weights_only): + """Propensity scores should change with survey weights (IPW path).""" + r_unw = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + # Non-uniform weights should produce different ATT + # (propensity scores change with survey weights) + assert r_unw.overall_att != r_w.overall_att + + def test_ipw_survey_weight_composition(self, staggered_survey_data, survey_design_weights_only): + """w_survey x w_ipw should compose — ATT differs from unweighted IPW.""" + r_unw = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + # Weighted IPW should produce different ATT than unweighted + assert abs(r_unw.overall_att - r_w.overall_att) > 1e-6 + + def test_aggregation_with_survey(self, staggered_survey_data, survey_design_weights_only): + """Simple aggregation should use survey weights.""" + r_unw = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + r_w = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + survey_design=survey_design_weights_only, + ) + # Event study ATTs should differ with non-uniform weights + assert r_unw.overall_att != r_w.overall_att + # Event study effects should exist + assert r_w.event_study_effects is not None + assert len(r_w.event_study_effects) > 0 + + def test_summary_includes_survey(self, 
staggered_survey_data, survey_design_weights_only): + """Summary output should include survey design section.""" + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + summary = result.summary() + assert "Survey Design" in summary + assert "pweight" in summary + + +# ============================================================================= +# TestTripleDifferenceIPWSurvey +# ============================================================================= + + +class TestTripleDifferenceIPWSurvey: + """Verify TripleDifference IPW/DR + survey is unblocked.""" + + def test_ipw_survey_no_longer_raises(self, ddd_survey_data): + """IPW + survey should no longer raise NotImplementedError.""" + sd = SurveyDesign(weights="weight") + # Should not raise + result = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert result is not None + + def test_dr_survey_no_longer_raises(self, ddd_survey_data): + """DR + survey should no longer raise NotImplementedError.""" + sd = SurveyDesign(weights="weight") + # Should not raise + result = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert result is not None + + def test_ipw_survey_results_finite(self, ddd_survey_data): + """IPW + survey should produce finite results.""" + sd = SurveyDesign(weights="weight") + result = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) + assert result.survey_metadata is not None + + +# ============================================================================= +# TestScaleInvariance +# 
============================================================================= + + +class TestScaleInvariance: + """Multiplying all survey weights by a constant should not change ATT or SE.""" + + def test_weight_scale_invariance_imputation(self, staggered_survey_data): + """ImputationDiD: 2*w gives same ATT as w.""" + sd1 = SurveyDesign(weights="weight") + r1 = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + + staggered_survey_data["weight_x2"] = staggered_survey_data["weight"] * 2.0 + sd2 = SurveyDesign(weights="weight_x2") + r2 = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + + assert abs(r1.overall_att - r2.overall_att) < 1e-10 + assert abs(r1.overall_se - r2.overall_se) < 1e-8 + + def test_weight_scale_invariance_two_stage(self, staggered_survey_data): + """TwoStageDiD: 2*w gives same ATT as w.""" + sd1 = SurveyDesign(weights="weight") + r1 = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + + staggered_survey_data["weight_x2"] = staggered_survey_data["weight"] * 2.0 + sd2 = SurveyDesign(weights="weight_x2") + r2 = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + + assert abs(r1.overall_att - r2.overall_att) < 1e-10 + assert abs(r1.overall_se - r2.overall_se) < 1e-8 + + def test_weight_scale_invariance_callaway_santanna(self, staggered_survey_data): + """CallawaySantAnna: 2*w gives same ATT as w.""" + sd1 = SurveyDesign(weights="weight") + r1 = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + + staggered_survey_data["weight_x2"] = staggered_survey_data["weight"] * 2.0 + sd2 = SurveyDesign(weights="weight_x2") + r2 = CallawaySantAnna(estimation_method="reg").fit( + 
staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + + assert abs(r1.overall_att - r2.overall_att) < 1e-10 + assert abs(r1.overall_se - r2.overall_se) < 1e-8 From 410e717b1cecbb45d08fa6dfc24793c1773600eb Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Mar 2026 16:39:43 -0400 Subject: [PATCH 02/14] Address AI review findings: document CS per-cell SE deviation, log pre-existing Hausman bug, add inference tests - Add REGISTRY.md deviation note: CS per-cell ATT(g,t) SEs use influence-function variance (matching R's did), not full TSL with strata/PSU/FPC - Add TODO entries: CS per-cell TSL SE (Medium), Hausman stale n_cl from #230 (Medium) - Add 3 CS survey inference validation tests (scale invariance, per-cell ATT change, survey df effect) - Remove unused survey_weight_type in ImputationDiD Co-Authored-By: Claude Opus 4.6 (1M context) --- TODO.md | 2 + diff_diff/imputation.py | 4 +- docs/methodology/REGISTRY.md | 1 + tests/test_survey_phase4.py | 100 +++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 2 deletions(-) diff --git a/TODO.md b/TODO.md index 913282b3..1082281d 100644 --- a/TODO.md +++ b/TODO.md @@ -52,6 +52,8 @@ Deferred items from PR reviews that were not addressed before merge. | ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) | | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | +| CallawaySantAnna per-cell ATT(g,t) SEs under survey use influence-function variance, not full design-based TSL with strata/PSU/FPC. 
Design effects enter at aggregation via WIF and survey df. Full per-cell TSL would require constructing unit-level influence functions on the global index and passing through `compute_survey_vcov()`. | `staggered.py` | — | Medium | +| EfficientDiD hausman_pretest() clustered covariance uses stale `n_cl` after filtering non-finite EIF rows — should recompute effective cluster count and remap indices after `row_finite` filtering | `efficient_did.py` | #230 | Medium | | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low | | ContinuousDiD event-study aggregation does not filter by `anticipation` — uses all (g,t) cells instead of anticipation-filtered subset; pre-existing in both survey and non-survey paths | `continuous_did.py` | #226 | Medium | | Survey design resolution/collapse patterns are inconsistent across panel estimators — ContinuousDiD rebuilds unit-level design in SE code, EfficientDiD builds once in fit(), StackedDiD re-resolves on stacked data; extract shared helpers for panel-to-unit collapse, post-filter re-resolution, and metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low | diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index b029b08c..d7322a48 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -234,8 +234,8 @@ def fit( _validate_unit_constant_survey, ) - resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( - _resolve_survey_for_fit(survey_design, data, "analytical") + resolved_survey, survey_weights, _, survey_metadata = _resolve_survey_for_fit( + survey_design, data, "analytical" ) # Validate within-unit constancy for panel survey designs diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 03cf5958..8cc000bf 100644 --- a/docs/methodology/REGISTRY.md 
+++ b/docs/methodology/REGISTRY.md @@ -417,6 +417,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) - **Note:** CallawaySantAnna survey weights compose with IPW weights multiplicatively: w_total = w_survey * p(X) / (1-p(X)). Propensity scores estimated via survey-weighted solve_logit(). WIF in aggregation uses survey-weighted group sizes. Bootstrap + survey deferred. +- **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization with strata/PSU/FPC structure. The survey design structure is reflected in aggregation-level SEs via the WIF and survey degrees of freedom, but individual (g,t) cell SEs do not incorporate the full design-based variance. This is consistent with R's approach where per-cell SEs are influence-function-based and design effects enter at the aggregation stage. 
**Reference implementation(s):** - R: `did::att_gt()` (Callaway & Sant'Anna's official package) diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index 80caf6b7..d95dc694 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -716,6 +716,106 @@ def test_ipw_survey_results_finite(self, ddd_survey_data): assert result.survey_metadata is not None +# ============================================================================= +# TestCallawaySantAnnaSurveyInference +# ============================================================================= + + +class TestCallawaySantAnnaSurveyInference: + """Validate CS survey inference beyond smoke tests.""" + + def test_se_scale_invariance_all_methods(self, staggered_survey_data): + """SE should be invariant under weight rescaling for all methods.""" + data = staggered_survey_data + data = data.copy() + data["weight2"] = data["weight"] * 3.7 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + + for method in ["reg", "ipw", "dr"]: + est = CallawaySantAnna(estimation_method=method) + r1 = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd1, + ) + r2 = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd2, + ) + assert np.isclose( + r1.overall_att, r2.overall_att, atol=1e-8 + ), f"{method}: ATT not scale-invariant" + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"{method}: SE not scale-invariant" + + def test_survey_weights_change_per_cell_att(self, staggered_survey_data): + """Non-uniform survey weights should change per-cell ATT(g,t).""" + data = staggered_survey_data + sd = SurveyDesign(weights="weight") + for method in ["reg", "ipw", "dr"]: + r_no = CallawaySantAnna(estimation_method=method).fit( + data, + outcome="outcome", + unit="unit", + time="period", + 
first_treat="first_treat", + ) + r_sv = CallawaySantAnna(estimation_method=method).fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + survey_design=sd, + ) + # At least one cell ATT should differ + effects_no = [d["effect"] for d in r_no.group_time_effects.values()] + effects_sv = [d["effect"] for d in r_sv.group_time_effects.values()] + assert not np.allclose( + effects_no, effects_sv, atol=1e-6 + ), f"{method}: survey weights should change per-cell ATT" + + def test_survey_df_affects_pvalues(self, staggered_survey_data): + """Survey df (from strata/PSU) should affect p-values via t-distribution.""" + data = staggered_survey_data + sd_weights = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + est = CallawaySantAnna(estimation_method="reg") + r_w = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd_weights, + ) + r_f = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd_full, + ) + # ATT should be same (same weights), but p-values differ (different df) + assert np.isclose(r_w.overall_att, r_f.overall_att, atol=1e-8) + # Survey df from strata/PSU should change inference + assert r_f.survey_metadata.df_survey is not None + + # ============================================================================= # TestScaleInvariance # ============================================================================= From 604b09901ffaed71a6969b5ee9969b2342d5fc10 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Mar 2026 19:15:28 -0400 Subject: [PATCH 03/14] Fix P0/P1 findings from AI review: TripleDiff IPW/DR survey threading, CS SE formula - P0: Thread survey_weights through TripleDifference IPW and DR call chains (_ipw_estimation, _doubly_robust, _compute_did_rc_ipw, _compute_did_rc_dr). 
Survey weights now enter Riesz representers for weighted Hajek averages. - P1: Fix CallawaySantAnna no-covariate survey SE to derive from sum(IF^2) instead of sum(w_norm * (y-mean)^2). All 4 locations now consistent with stored influence functions. - P1: Update REGISTRY.md TripleDifference entry to reflect full survey support (was still marked as "IPW/DR deferred"). - P2: Add behavioral tests for TripleDiff IPW/DR survey: non-uniform weights change ATT, uniform weights match unweighted. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 34 ++++++++------ diff_diff/triple_diff.py | 76 +++++++++++++++++++++++++++++-- docs/methodology/REGISTRY.md | 3 +- tests/test_survey_phase4.py | 88 ++++++++++++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 20 deletions(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 5af63d24..c1f99686 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -715,13 +715,15 @@ def _compute_all_att_gt_vectorized( mu_c = float(np.sum(sw_c_norm * control_change)) att = mu_t - mu_c - var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) - var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) - se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function (survey-weighted) inf_treated = sw_t_norm * (treated_change - mu_t) inf_control = -sw_c_norm * (control_change - mu_c) + # SE derived from IF: sum(IF_i^2) + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) sw_sum = float(np.sum(sw_t)) else: att = float(np.mean(treated_change) - np.mean(control_change)) @@ -1624,9 +1626,11 @@ def _outcome_regression( inf_func = np.concatenate([inf_treated, inf_control]) # SE from influence function variance - var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) - var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) - se = float(np.sqrt(var_t + var_c)) if (n_t > 0 
and n_c > 0) else 0.0 + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) else: att = float(np.mean(treated_change) - np.mean(control_change)) @@ -1787,9 +1791,11 @@ def _ipw_estimation( inf_control = -sw_c_norm * (control_change - mu_c) inf_func = np.concatenate([inf_treated, inf_control]) - var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) - var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) - se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) else: p_treat = n_treated / n_total # unconditional propensity score @@ -1998,9 +2004,11 @@ def _doubly_robust( inf_control = -sw_c_norm * (control_change - mu_c) inf_func = np.concatenate([inf_treated, inf_control]) - var_t = float(np.sum(sw_t_norm * (treated_change - mu_t) ** 2)) - var_c = float(np.sum(sw_c_norm * (control_change - mu_c) ** 2)) - se = float(np.sqrt(var_t + var_c)) if (n_t > 0 and n_c > 0) else 0.0 + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) else: att = float(np.mean(treated_change) - np.mean(control_change)) diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index da4babff..0c492046 100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -565,9 +565,25 @@ def fit( resolved_survey=resolved_survey, ) elif self.estimation_method == "ipw": - att, se, r_squared, pscore_stats = self._ipw_estimation(y, G, P, T, X) + att, se, r_squared, pscore_stats = self._ipw_estimation( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + resolved_survey=resolved_survey, + ) else: # doubly robust - att, se, r_squared, pscore_stats = self._doubly_robust(y, G, P, T, X) + att, se, r_squared, pscore_stats = self._doubly_robust( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + 
resolved_survey=resolved_survey, + ) # Compute inference # When survey design is active, use survey df (n_PSU - n_strata) @@ -758,6 +774,8 @@ def _ipw_estimation( P: np.ndarray, T: np.ndarray, X: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + resolved_survey=None, ) -> Tuple[float, float, Optional[float], Optional[Dict[str, float]]]: """ Estimate ATT using inverse probability weighting via three-DiD @@ -767,7 +785,15 @@ def _ipw_estimation( subgroup membership P(subgroup=4|X) within {j, 4} subset. Matches R's triplediff::ddd() with est_method="ipw". """ - return self._estimate_ddd_decomposition(y, G, P, T, X) + return self._estimate_ddd_decomposition( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + resolved_survey=resolved_survey, + ) def _doubly_robust( self, @@ -776,6 +802,8 @@ def _doubly_robust( P: np.ndarray, T: np.ndarray, X: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + resolved_survey=None, ) -> Tuple[float, float, Optional[float], Optional[Dict[str, float]]]: """ Estimate ATT using doubly robust estimation via three-DiD @@ -786,7 +814,15 @@ def _doubly_robust( correctly specified. Matches R's triplediff::ddd() with est_method="dr". """ - return self._estimate_ddd_decomposition(y, G, P, T, X) + return self._estimate_ddd_decomposition( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + resolved_survey=resolved_survey, + ) def _estimate_ddd_decomposition( self, @@ -1186,7 +1222,17 @@ def _compute_did_rc( Matches R's triplediff::compute_did_rc(). 
""" if est_method == "ipw": - return self._compute_did_rc_ipw(y, post, PA4, PAa, pscore, covX, hessian, n) + return self._compute_did_rc_ipw( + y, + post, + PA4, + PAa, + pscore, + covX, + hessian, + n, + weights=weights, + ) elif est_method == "reg": return self._compute_did_rc_reg( y, @@ -1215,6 +1261,7 @@ def _compute_did_rc( or_trt_post, hessian, n, + weights=weights, ) def _compute_did_rc_ipw( @@ -1227,6 +1274,7 @@ def _compute_did_rc_ipw( covX: np.ndarray, hessian: Optional[np.ndarray], n: int, + weights: Optional[np.ndarray] = None, ) -> Tuple[float, np.ndarray]: """IPW DiD for a single pairwise comparison (RC).""" # Riesz representers (IPW weights * indicators) @@ -1235,6 +1283,13 @@ def _compute_did_rc_ipw( riesz_control_pre = pscore * PAa * (1 - post) / (1 - pscore) riesz_control_post = pscore * PAa * post / (1 - pscore) + # Incorporate survey weights into Riesz representers + if weights is not None: + riesz_treat_pre = riesz_treat_pre * weights + riesz_treat_post = riesz_treat_post * weights + riesz_control_pre = riesz_control_pre * weights + riesz_control_post = riesz_control_post * weights + # Hajek-normalized cell-time means def _hajek(riesz, y_vals): denom = np.mean(riesz) @@ -1393,6 +1448,7 @@ def _compute_did_rc_dr( or_trt_post: np.ndarray, hessian: Optional[np.ndarray], n: int, + weights: Optional[np.ndarray] = None, ) -> Tuple[float, np.ndarray]: """Doubly robust DiD for a single pairwise comparison (RC).""" or_ctrl = post * or_ctrl_post + (1 - post) * or_ctrl_pre @@ -1406,6 +1462,16 @@ def _compute_did_rc_dr( riesz_dt1 = PA4 * post riesz_dt0 = PA4 * (1 - post) + # Incorporate survey weights into Riesz representers + if weights is not None: + riesz_treat_pre = riesz_treat_pre * weights + riesz_treat_post = riesz_treat_post * weights + riesz_control_pre = riesz_control_pre * weights + riesz_control_post = riesz_control_post * weights + riesz_d = riesz_d * weights + riesz_dt1 = riesz_dt1 * weights + riesz_dt0 = riesz_dt0 * weights + # DR cell-time 
components def _safe_ratio(num, denom): return num / denom if denom > 0 else 0.0 diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 8cc000bf..b91f60d3 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1245,8 +1245,7 @@ has no additional effect. - [x] Influence function SE: std(w3·IF_3 + w2·IF_2 - w1·IF_1) / sqrt(n) - [x] Cluster-robust SE via Liang-Zeger variance on influence function - [x] ATT and SE match R within <0.001% for all methods and DGP types -- [x] Survey design support (Phase 3): regression method with weighted OLS + TSL on combined influence functions; IPW/DR deferred -- **Note:** TripleDifference IPW/DR with survey weights deferred until weighted solve_logit() (Phase 5) +- [x] Survey design support: all methods (reg, IPW, DR) with weighted OLS/logit + TSL on combined influence functions. Weighted solve_logit() for propensity scores in IPW/DR paths. --- diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index d95dc694..00371e93 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -715,6 +715,94 @@ def test_ipw_survey_results_finite(self, ddd_survey_data): assert np.isfinite(result.se) assert result.survey_metadata is not None + def test_ipw_nonuniform_weights_change_att(self, ddd_survey_data): + """Non-uniform survey weights should change IPW ATT vs unweighted.""" + sd = SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Non-uniform survey weights should change IPW ATT" + + def test_dr_nonuniform_weights_change_att(self, ddd_survey_data): + """Non-uniform survey weights should change DR ATT vs unweighted.""" + sd = 
SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Non-uniform survey weights should change DR ATT" + + def test_ipw_uniform_weights_match_unweighted(self, ddd_survey_data): + """Uniform survey weights should match unweighted IPW result.""" + data = ddd_survey_data.copy() + data["uw"] = 1.0 + sd = SurveyDesign(weights="uw") + r_no = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isclose(r_no.att, r_sv.att, atol=1e-6) + + def test_dr_uniform_weights_match_unweighted(self, ddd_survey_data): + """Uniform survey weights should match unweighted DR result.""" + data = ddd_survey_data.copy() + data["uw"] = 1.0 + sd = SurveyDesign(weights="uw") + r_no = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isclose(r_no.att, r_sv.att, atol=1e-6) + # ============================================================================= # TestCallawaySantAnnaSurveyInference From 61e920e57bec0c545d62a560b962d6d5c0e8a0ae Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Mar 2026 19:54:43 -0400 Subject: [PATCH 04/14] Fix round-2 review P1s: weighted v_it denominators, DDD covariate IF, always-treated survey - ImputationDiD: use survey-weighted untreated sums (not raw counts) in Theorem 3 v_it FE-only path, and weighted normal equations (A_0'W_0 A_0) in covariate path - 
TripleDifference: weight PS Hessian, score correction, and DR OLS linear representations with survey weights for covariate-adjusted IPW/DR IF corrections; update fit() docstring to reflect full survey support - TwoStageDiD: subset survey_weights and resolved_survey arrays after always-treated unit exclusion to prevent length mismatch - Add tests: TwoStageDiD always-treated + survey, DDD covariate IPW/DR + survey Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/imputation.py | 36 +++++++++-- diff_diff/triple_diff.py | 120 +++++++++++++++++++++++------------- diff_diff/two_stage.py | 31 +++++++++- tests/test_survey_phase4.py | 78 +++++++++++++++++++++++ 4 files changed, 215 insertions(+), 50 deletions(-) diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index d7322a48..f71e1422 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -471,6 +471,7 @@ def fit( weights=overall_weights, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, ) # Survey degrees of freedom for t-distribution inference @@ -1036,6 +1037,7 @@ def _compute_cluster_psi_sums( weights: np.ndarray, cluster_var: str, kept_cov_mask: Optional[np.ndarray] = None, + survey_weights_0: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: """ Compute cluster-level influence function sums (Theorem 3). 
@@ -1075,8 +1077,16 @@ def _compute_cluster_psi_sums( w_total = float(np.sum(weights)) - n0_by_unit = df_0.groupby(unit).size().to_dict() - n0_by_time = df_0.groupby(time).size().to_dict() + # Use survey-weighted sums for untreated denominators when present + if survey_weights_0 is not None: + sw0_series = pd.Series(survey_weights_0, index=df_0.index) + n0_by_unit = sw0_series.groupby(df_0[unit]).sum().to_dict() + n0_by_time = sw0_series.groupby(df_0[time]).sum().to_dict() + n0_denom = float(np.sum(survey_weights_0)) + else: + n0_by_unit = df_0.groupby(unit).size().to_dict() + n0_by_time = df_0.groupby(time).size().to_dict() + n0_denom = n_0 untreated_units = df_0[unit].values untreated_times = df_0[time].values @@ -1089,7 +1099,7 @@ def _compute_cluster_psi_sums( w_t = w_by_time.get(t, 0.0) n0_i = n0_by_unit.get(u, 1) n0_t = n0_by_time.get(t, 1) - v_untreated[j] = -(w_i / n0_i + w_t / n0_t - w_total / n_0) + v_untreated[j] = -(w_i / n0_i + w_t / n0_t - w_total / n0_denom) else: v_untreated = self._compute_v_untreated_with_covariates( df_0, @@ -1100,6 +1110,7 @@ def _compute_cluster_psi_sums( weights, delta_hat, kept_cov_mask=kept_cov_mask, + survey_weights_0=survey_weights_0, ) # ---- Compute auxiliary model residuals (Equation 8) ---- @@ -1158,6 +1169,7 @@ def _compute_conservative_variance( weights: np.ndarray, cluster_var: str, kept_cov_mask: Optional[np.ndarray] = None, + survey_weights: Optional[np.ndarray] = None, ) -> float: """ Compute conservative clustered variance (Theorem 3, Equation 7). @@ -1167,12 +1179,16 @@ def _compute_conservative_variance( weights : np.ndarray Aggregation weights w_it for treated observations. Shape: (n_treated,), must sum to 1. + survey_weights : np.ndarray, optional + Full-panel survey weights. When provided, untreated denominators + in v_it use survey-weighted sums instead of raw counts. Returns ------- float Standard error. 
""" + sw_0 = survey_weights[omega_0_mask.values] if survey_weights is not None else None cluster_psi_sums, _ = self._compute_cluster_psi_sums( df=df, outcome=outcome, @@ -1189,6 +1205,7 @@ def _compute_conservative_variance( weights=weights, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights_0=sw_0, ) sigma_sq = float((cluster_psi_sums**2).sum()) return np.sqrt(max(sigma_sq, 0.0)) @@ -1203,11 +1220,14 @@ def _compute_v_untreated_with_covariates( weights: np.ndarray, delta_hat: Optional[np.ndarray], kept_cov_mask: Optional[np.ndarray] = None, + survey_weights_0: Optional[np.ndarray] = None, ) -> np.ndarray: """ Compute v_it for untreated observations with covariates. Uses the projection: v_untreated = -A_0 (A_0'A_0)^{-1} A_1' w_treated + When survey_weights_0 is provided, uses weighted normal equations: + v_untreated = -A_0 (A_0' W A_0)^{-1} A_1' w_treated Uses scipy.sparse for FE dummy columns to reduce memory from O(N*(U+T)) to O(N) for the FE portion. @@ -1266,8 +1286,12 @@ def _build_A_sparse(df_sub, unit_vals, time_vals): # Compute A_1' w (sparse.T @ dense -> dense) A1_w = A_1.T @ weights # shape (p,) - # Solve (A_0'A_0) z = A_1' w using sparse direct solver - A0tA0_sparse = A_0.T @ A_0 # stays sparse + # Solve (A_0' [W] A_0) z = A_1' w using sparse direct solver + # When survey weights present, use weighted normal equations A_0' W A_0 + if survey_weights_0 is not None: + A0tA0_sparse = A_0.T @ A_0.multiply(survey_weights_0[:, None]) + else: + A0tA0_sparse = A_0.T @ A_0 # stays sparse try: z = spsolve(A0tA0_sparse.tocsc(), A1_w) except Exception: @@ -1529,6 +1553,7 @@ def _aggregate_event_study( weights=weights_h, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, ) t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df) @@ -1661,6 +1686,7 @@ def _aggregate_group( weights=weights_g, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, ) 
t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df) diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index 0c492046..0129aade 100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -480,8 +480,8 @@ def fit( survey_design : SurveyDesign, optional Survey design specification for complex survey data. When provided, uses survey weights for estimation and Taylor Series - Linearization (TSL) for variance estimation. Only supported - with estimation_method="reg". + Linearization (TSL) for variance estimation. Supported with + all estimation methods ("reg", "ipw", "dr"). Returns ------- @@ -943,6 +943,8 @@ def _estimate_ddd_decomposition( # Hessian only when PS was actually estimated if ps_estimated: W_ps = pscore_sub * (1 - pscore_sub) + if w_sub is not None: + W_ps = W_ps * w_sub try: XWX = covX_sub.T @ (W_ps[:, None] * covX_sub) hessian = np.linalg.inv(XWX) * n_sub @@ -1323,16 +1325,30 @@ def _hajek(riesz, y_vals): # Propensity score correction for influence function if hessian is not None: score_ps = (PA4 - pscore)[:, None] * covX + if weights is not None: + score_ps = score_ps * weights[:, None] asy_lin_rep_ps = score_ps @ hessian - M2_pre = np.mean( - (riesz_control_pre * (y - att_control_pre))[:, None] * covX, - axis=0, - ) / np.mean(riesz_control_pre) - M2_post = np.mean( - (riesz_control_post * (y - att_control_post))[:, None] * covX, - axis=0, - ) / np.mean(riesz_control_post) + if weights is not None: + M2_pre = np.average( + (riesz_control_pre * (y - att_control_pre))[:, None] * covX, + axis=0, + weights=weights, + ) / np.mean(riesz_control_pre) + M2_post = np.average( + (riesz_control_post * (y - att_control_post))[:, None] * covX, + axis=0, + weights=weights, + ) / np.mean(riesz_control_post) + else: + M2_pre = np.mean( + (riesz_control_pre * (y - att_control_pre))[:, None] * covX, + axis=0, + ) / np.mean(riesz_control_pre) + M2_post = np.mean( + (riesz_control_post * (y - 
att_control_post))[:, None] * covX, + axis=0, + ) / np.mean(riesz_control_post) inf_control_ps = asy_lin_rep_ps @ (M2_post - M2_pre) inf_control = inf_control + inf_control_ps @@ -1512,9 +1528,15 @@ def _safe_ratio(num, denom): # --- Influence function --- # OLS asymptotic linear representations (control subgroup) weights_ols_pre = PAa * (1 - post) - wols_x_pre = weights_ols_pre[:, None] * covX - wols_eX_pre = (weights_ols_pre * (y - or_ctrl_pre))[:, None] * covX - XpX_pre = wols_x_pre.T @ covX / n + if weights is not None: + w_sum = np.sum(weights) + wols_x_pre = (weights_ols_pre * weights)[:, None] * covX + wols_eX_pre = (weights_ols_pre * weights * (y - or_ctrl_pre))[:, None] * covX + XpX_pre = wols_x_pre.T @ covX / w_sum + else: + wols_x_pre = weights_ols_pre[:, None] * covX + wols_eX_pre = (weights_ols_pre * (y - or_ctrl_pre))[:, None] * covX + XpX_pre = wols_x_pre.T @ covX / n try: XpX_inv_pre = np.linalg.inv(XpX_pre) except np.linalg.LinAlgError: @@ -1522,9 +1544,14 @@ def _safe_ratio(num, denom): asy_lin_rep_ols_pre = wols_eX_pre @ XpX_inv_pre weights_ols_post = PAa * post - wols_x_post = weights_ols_post[:, None] * covX - wols_eX_post = (weights_ols_post * (y - or_ctrl_post))[:, None] * covX - XpX_post = wols_x_post.T @ covX / n + if weights is not None: + wols_x_post = (weights_ols_post * weights)[:, None] * covX + wols_eX_post = (weights_ols_post * weights * (y - or_ctrl_post))[:, None] * covX + XpX_post = wols_x_post.T @ covX / w_sum + else: + wols_x_post = weights_ols_post[:, None] * covX + wols_eX_post = (weights_ols_post * (y - or_ctrl_post))[:, None] * covX + XpX_post = wols_x_post.T @ covX / n try: XpX_inv_post = np.linalg.inv(XpX_post) except np.linalg.LinAlgError: @@ -1533,9 +1560,14 @@ def _safe_ratio(num, denom): # OLS representations (treated subgroup) weights_ols_pre_treat = PA4 * (1 - post) - wols_x_pre_treat = weights_ols_pre_treat[:, None] * covX - wols_eX_pre_treat = (weights_ols_pre_treat * (y - or_trt_pre))[:, None] * covX - 
XpX_pre_treat = wols_x_pre_treat.T @ covX / n + if weights is not None: + wols_x_pre_treat = (weights_ols_pre_treat * weights)[:, None] * covX + wols_eX_pre_treat = (weights_ols_pre_treat * weights * (y - or_trt_pre))[:, None] * covX + XpX_pre_treat = wols_x_pre_treat.T @ covX / w_sum + else: + wols_x_pre_treat = weights_ols_pre_treat[:, None] * covX + wols_eX_pre_treat = (weights_ols_pre_treat * (y - or_trt_pre))[:, None] * covX + XpX_pre_treat = wols_x_pre_treat.T @ covX / n try: XpX_inv_pre_treat = np.linalg.inv(XpX_pre_treat) except np.linalg.LinAlgError: @@ -1543,9 +1575,16 @@ def _safe_ratio(num, denom): asy_lin_rep_ols_pre_treat = wols_eX_pre_treat @ XpX_inv_pre_treat weights_ols_post_treat = PA4 * post - wols_x_post_treat = weights_ols_post_treat[:, None] * covX - wols_eX_post_treat = (weights_ols_post_treat * (y - or_trt_post))[:, None] * covX - XpX_post_treat = wols_x_post_treat.T @ covX / n + if weights is not None: + wols_x_post_treat = (weights_ols_post_treat * weights)[:, None] * covX + wols_eX_post_treat = (weights_ols_post_treat * weights * (y - or_trt_post))[ + :, None + ] * covX + XpX_post_treat = wols_x_post_treat.T @ covX / w_sum + else: + wols_x_post_treat = weights_ols_post_treat[:, None] * covX + wols_eX_post_treat = (weights_ols_post_treat * (y - or_trt_post))[:, None] * covX + XpX_post_treat = wols_x_post_treat.T @ covX / n try: XpX_inv_post_treat = np.linalg.inv(XpX_post_treat) except np.linalg.LinAlgError: @@ -1554,6 +1593,8 @@ def _safe_ratio(num, denom): # Propensity score linear representation score_ps = (PA4 - pscore)[:, None] * covX + if weights is not None: + score_ps = score_ps * weights[:, None] if hessian is not None: asy_lin_rep_ps = score_ps @ hessian else: @@ -1575,13 +1616,19 @@ def _safe_ratio(num, denom): ) # OR correction for treated + def _wmean_ax0(arr): + """Weighted or unweighted column mean.""" + if weights is not None: + return np.average(arr, axis=0, weights=weights) + return np.mean(arr, axis=0) + M1_post = ( - 
(-np.mean((riesz_treat_post * post)[:, None] * covX, axis=0) / m_riesz_treat_post) + (-_wmean_ax0((riesz_treat_post * post)[:, None] * covX) / m_riesz_treat_post) if m_riesz_treat_post > 0 else np.zeros(covX.shape[1]) ) M1_pre = ( - (-np.mean((riesz_treat_pre * (1 - post))[:, None] * covX, axis=0) / m_riesz_treat_pre) + (-_wmean_ax0((riesz_treat_pre * (1 - post))[:, None] * covX) / m_riesz_treat_pre) if m_riesz_treat_pre > 0 else np.zeros(covX.shape[1]) ) @@ -1606,9 +1653,7 @@ def _safe_ratio(num, denom): # PS correction for control M2_pre = ( ( - np.mean( - (riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] * covX, axis=0 - ) + _wmean_ax0((riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] * covX) / m_riesz_control_pre ) if m_riesz_control_pre > 0 @@ -1616,9 +1661,7 @@ def _safe_ratio(num, denom): ) M2_post = ( ( - np.mean( - (riesz_control_post * (y - or_ctrl - att_control_post))[:, None] * covX, axis=0 - ) + _wmean_ax0((riesz_control_post * (y - or_ctrl - att_control_post))[:, None] * covX) / m_riesz_control_post ) if m_riesz_control_post > 0 @@ -1628,15 +1671,12 @@ def _safe_ratio(num, denom): # OR correction for control M3_post = ( - (-np.mean((riesz_control_post * post)[:, None] * covX, axis=0) / m_riesz_control_post) + (-_wmean_ax0((riesz_control_post * post)[:, None] * covX) / m_riesz_control_post) if m_riesz_control_post > 0 else np.zeros(covX.shape[1]) ) M3_pre = ( - ( - -np.mean((riesz_control_pre * (1 - post))[:, None] * covX, axis=0) - / m_riesz_control_pre - ) + (-_wmean_ax0((riesz_control_pre * (1 - post))[:, None] * covX) / m_riesz_control_pre) if m_riesz_control_pre > 0 else np.zeros(covX.shape[1]) ) @@ -1664,18 +1704,12 @@ def _safe_ratio(num, denom): # OR combination mom_post = ( - np.mean( - (riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX, - axis=0, - ) + _wmean_ax0((riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX) if (m_riesz_d > 0 and m_riesz_dt1 > 0) else 
np.zeros(covX.shape[1]) ) mom_pre = ( - np.mean( - (riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX, - axis=0, - ) + _wmean_ax0((riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX) if (m_riesz_d > 0 and m_riesz_dt0 > 0) else np.zeros(covX.shape[1]) ) diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index 85f5e409..18cc0947 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -22,6 +22,7 @@ """ import warnings +from dataclasses import replace from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -286,6 +287,32 @@ def fit( ) df = df[~df[unit].isin(always_treated_units)].copy() + # Subset survey arrays to match filtered df + if survey_weights is not None: + keep_mask = ~data[unit].isin(always_treated_units) + survey_weights = survey_weights[keep_mask.values] + if resolved_survey is not None: + keep_mask = ~data[unit].isin(always_treated_units) + resolved_survey = replace( + resolved_survey, + weights=resolved_survey.weights[keep_mask.values], + strata=( + resolved_survey.strata[keep_mask.values] + if resolved_survey.strata is not None + else None + ), + psu=( + resolved_survey.psu[keep_mask.values] + if resolved_survey.psu is not None + else None + ), + fpc=( + resolved_survey.fpc[keep_mask.values] + if resolved_survey.fpc is not None + else None + ), + ) + # Treatment indicator with anticipation effective_treat = df[first_treat] - self.anticipation df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat) @@ -341,9 +368,9 @@ def fit( from diff_diff.survey import compute_survey_metadata raw_w = ( - data[survey_design.weights].values.astype(np.float64) + df[survey_design.weights].values.astype(np.float64) if survey_design.weights - else np.ones(len(data), dtype=np.float64) + else np.ones(len(df), dtype=np.float64) ) survey_metadata = compute_survey_metadata(resolved_survey, raw_w) diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index 
00371e93..fc5b6ac1 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -460,6 +460,28 @@ def test_summary_includes_survey(self, staggered_survey_data, survey_design_weig assert "Survey Design" in summary assert "pweight" in summary + def test_always_treated_with_survey(self, staggered_survey_data): + """TwoStageDiD with survey + always-treated units should not crash.""" + data = staggered_survey_data.copy() + # Make some units always-treated (first_treat at or before min time) + min_time = data["period"].min() + units = data["unit"].unique() + always_treated_units = units[:3] + for u in always_treated_units: + data.loc[data["unit"] == u, "first_treat"] = min_time + sd = SurveyDesign(weights="weight") + result = TwoStageDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + # ============================================================================= # TestCallawaySantAnnaSurvey @@ -803,6 +825,62 @@ def test_dr_uniform_weights_match_unweighted(self, ddd_survey_data): ) assert np.isclose(r_no.att, r_sv.att, atol=1e-6) + def test_ipw_covariate_survey_nonuniform(self, ddd_survey_data): + """IPW + covariates + non-uniform survey weights should change ATT.""" + data = ddd_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + sd = SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + ) + r_sv = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + survey_design=sd, + ) + assert np.isfinite(r_sv.att) + assert np.isfinite(r_sv.se) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Covariate IPW + non-uniform survey weights should change ATT" + + def 
test_dr_covariate_survey_nonuniform(self, ddd_survey_data): + """DR + covariates + non-uniform survey weights should change ATT.""" + data = ddd_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + sd = SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + ) + r_sv = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + survey_design=sd, + ) + assert np.isfinite(r_sv.att) + assert np.isfinite(r_sv.se) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Covariate DR + non-uniform survey weights should change ATT" + # ============================================================================= # TestCallawaySantAnnaSurveyInference From e84078b39ff5e7a27fdb258f9daf2a560bc06269 Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Mar 2026 21:05:22 -0400 Subject: [PATCH 05/14] Fix round-3 review P1s: DDD double-weighting, CS WIF scaling, CS covariate nuisance IF MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TripleDifference: remove double-weighting in IPW/DR moment corrections — since Riesz representers already incorporate survey weights, moment means use np.mean() not np.average(weights=). Removed _wmean_ax0 helper. - CallawaySantAnna WIF: apply s_i symmetrically to both indicator and pg terms in the weighted share estimator IF. Normalize by total_weight (sum of survey weights) instead of n_units. - CallawaySantAnna outcome regression covariate IF: add weighted regression nuisance IF correction (asymptotic linear representation of beta from WLS, projected onto weighted treated covariate mean). IPW and DR IFs unchanged (IPW matches unweighted structure; DR is self-correcting per Theorem 3.1). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 27 ++++++++++++- diff_diff/staggered_aggregation.py | 20 +++++---- diff_diff/triple_diff.py | 65 ++++++++++++++---------------- 3 files changed, 70 insertions(+), 42 deletions(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index c1f99686..73b7f14c 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1591,9 +1591,34 @@ def _outcome_regression( sw_c_norm = sw_control / np.sum(sw_control) att = float(np.sum(sw_t_norm * treated_residuals)) + # --- Regression nuisance IF correction --- + # Account for uncertainty in beta estimation + X_c = np.column_stack([np.ones(n_c), X_control]) + X_t = np.column_stack([np.ones(n_t), X_treated]) + + # Weighted bread: (X'WX)^{-1} + XWX = X_c.T @ (X_c * sw_control[:, None]) + try: + XWX_inv = np.linalg.solve(XWX, np.eye(XWX.shape[0])) + except np.linalg.LinAlgError: + XWX_inv = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0] + + # Per-control regression score: w_i * x_i * resid_i + resid_c = control_change - X_c @ beta + score_c = X_c * (sw_control * resid_c)[:, None] + asy_lin_rep_reg = score_c @ XWX_inv # shape (n_c, p) + + # Weighted treated covariate mean + X_treated_mean_w = np.average(X_t, axis=0, weights=sw_treated) + + # Regression IF correction for control observations + inf_control_reg_corr = asy_lin_rep_reg @ X_treated_mean_w + # Influence function (survey-weighted) inf_treated = sw_t_norm * (treated_residuals - att) - inf_control = -sw_c_norm * (control_change - np.sum(sw_c_norm * control_change)) + inf_control = ( + -sw_c_norm * (control_change - np.dot(X_c, beta)) + inf_control_reg_corr / n_c + ) inf_func = np.concatenate([inf_treated, inf_control]) # SE from influence function variance diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 11c06527..b28e3857 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -341,7 +341,8 @@ def 
_compute_combined_influence_function( ).astype(np.float64) if survey_w is not None: - # Survey-weighted WIF: indicator entries are sw_i / sum(sw_all) + # Survey-weighted WIF for group-share estimator p_g = sum(s_i * 1{G_i=g}) / sum(s_j). + # IF_i(p_g) = s_i * (1{G_i=g} - p_g) / sum(s_j) # Build per-unit weight vector aligned to our index space if global_unit_to_idx is not None and precomputed is not None: unit_sw = np.zeros(n_units) @@ -353,12 +354,16 @@ def _compute_combined_influence_function( else: unit_sw = np.ones(n_units) - # Weighted indicator: sw_i * 1{G_i == g_k} / sum(sw_all) - weighted_indicator = indicator_matrix * (unit_sw / total_weight)[:, np.newaxis] - indicator_sum_w = np.sum(weighted_indicator - pg_keepers, axis=1) + # s_i * 1{G_i == g_k} + weighted_indicator = indicator_matrix * unit_sw[:, np.newaxis] + # s_i * p_g_k (symmetric weight application) + weighted_pg_term = pg_keepers[np.newaxis, :] * unit_sw[:, np.newaxis] + # s_i * (1{G_i == g_k} - p_g_k) / sum(s_j) + indicator_diff = (weighted_indicator - weighted_pg_term) / total_weight + indicator_sum_w = np.sum(indicator_diff, axis=1) with np.errstate(divide="ignore", invalid="ignore", over="ignore"): - if1_matrix = (weighted_indicator - pg_keepers) / sum_pg_keepers + if1_matrix = indicator_diff / sum_pg_keepers if2_matrix = np.outer(indicator_sum_w, pg_keepers) / (sum_pg_keepers**2) wif_matrix = if1_matrix - if2_matrix wif_contrib = wif_matrix @ effects @@ -386,8 +391,9 @@ def _compute_combined_influence_function( nan_result = np.full(n_units, np.nan) return nan_result, all_units - # Scale by 1/n_units to match R's getSE formula - psi_wif = wif_contrib / n_units + # Scale by 1/total_weight to match R's getSE formula + # (for non-survey, total_weight == n_units; for survey, total_weight == sum(sw)) + psi_wif = wif_contrib / total_weight # Combine standard and wif terms psi_total = psi_standard + psi_wif diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index 0129aade..3c6e7834 
100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -1329,26 +1329,16 @@ def _hajek(riesz, y_vals): score_ps = score_ps * weights[:, None] asy_lin_rep_ps = score_ps @ hessian - if weights is not None: - M2_pre = np.average( - (riesz_control_pre * (y - att_control_pre))[:, None] * covX, - axis=0, - weights=weights, - ) / np.mean(riesz_control_pre) - M2_post = np.average( - (riesz_control_post * (y - att_control_post))[:, None] * covX, - axis=0, - weights=weights, - ) / np.mean(riesz_control_post) - else: - M2_pre = np.mean( - (riesz_control_pre * (y - att_control_pre))[:, None] * covX, - axis=0, - ) / np.mean(riesz_control_pre) - M2_post = np.mean( - (riesz_control_post * (y - att_control_post))[:, None] * covX, - axis=0, - ) / np.mean(riesz_control_post) + # Riesz representers already incorporate survey weights, + # so use np.mean (not np.average with weights) to avoid double-weighting. + M2_pre = np.mean( + (riesz_control_pre * (y - att_control_pre))[:, None] * covX, + axis=0, + ) / np.mean(riesz_control_pre) + M2_post = np.mean( + (riesz_control_post * (y - att_control_post))[:, None] * covX, + axis=0, + ) / np.mean(riesz_control_post) inf_control_ps = asy_lin_rep_ps @ (M2_post - M2_pre) inf_control = inf_control + inf_control_ps @@ -1616,19 +1606,15 @@ def _safe_ratio(num, denom): ) # OR correction for treated - def _wmean_ax0(arr): - """Weighted or unweighted column mean.""" - if weights is not None: - return np.average(arr, axis=0, weights=weights) - return np.mean(arr, axis=0) - + # Riesz representers already incorporate survey weights, + # so use np.mean (not weighted average) to avoid double-weighting. 
M1_post = ( - (-_wmean_ax0((riesz_treat_post * post)[:, None] * covX) / m_riesz_treat_post) + (-np.mean((riesz_treat_post * post)[:, None] * covX, axis=0) / m_riesz_treat_post) if m_riesz_treat_post > 0 else np.zeros(covX.shape[1]) ) M1_pre = ( - (-_wmean_ax0((riesz_treat_pre * (1 - post))[:, None] * covX) / m_riesz_treat_pre) + (-np.mean((riesz_treat_pre * (1 - post))[:, None] * covX, axis=0) / m_riesz_treat_pre) if m_riesz_treat_pre > 0 else np.zeros(covX.shape[1]) ) @@ -1653,7 +1639,9 @@ def _wmean_ax0(arr): # PS correction for control M2_pre = ( ( - _wmean_ax0((riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] * covX) + np.mean( + (riesz_control_pre * (y - or_ctrl - att_control_pre))[:, None] * covX, axis=0 + ) / m_riesz_control_pre ) if m_riesz_control_pre > 0 @@ -1661,7 +1649,9 @@ def _wmean_ax0(arr): ) M2_post = ( ( - _wmean_ax0((riesz_control_post * (y - or_ctrl - att_control_post))[:, None] * covX) + np.mean( + (riesz_control_post * (y - or_ctrl - att_control_post))[:, None] * covX, axis=0 + ) / m_riesz_control_post ) if m_riesz_control_post > 0 @@ -1671,12 +1661,15 @@ def _wmean_ax0(arr): # OR correction for control M3_post = ( - (-_wmean_ax0((riesz_control_post * post)[:, None] * covX) / m_riesz_control_post) + (-np.mean((riesz_control_post * post)[:, None] * covX, axis=0) / m_riesz_control_post) if m_riesz_control_post > 0 else np.zeros(covX.shape[1]) ) M3_pre = ( - (-_wmean_ax0((riesz_control_pre * (1 - post))[:, None] * covX) / m_riesz_control_pre) + ( + -np.mean((riesz_control_pre * (1 - post))[:, None] * covX, axis=0) + / m_riesz_control_pre + ) if m_riesz_control_pre > 0 else np.zeros(covX.shape[1]) ) @@ -1704,12 +1697,16 @@ def _wmean_ax0(arr): # OR combination mom_post = ( - _wmean_ax0((riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX) + np.mean( + (riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX, axis=0 + ) if (m_riesz_d > 0 and m_riesz_dt1 > 0) else np.zeros(covX.shape[1]) ) mom_pre = ( 
- _wmean_ax0((riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX) + np.mean( + (riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX, axis=0 + ) if (m_riesz_d > 0 and m_riesz_dt0 > 0) else np.zeros(covX.shape[1]) ) From 980e8c01beab6d2d2bb859fa9284e280c49d6a3e Mon Sep 17 00:00:00 2001 From: igerber Date: Sun, 22 Mar 2026 21:26:09 -0400 Subject: [PATCH 06/14] Fix round-4 review P1s: CS WIF normalization, IPW nuisance IF, TwoStage n_psu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CallawaySantAnna WIF: remove inner /total_weight from indicator_diff — the final psi_wif/total_weight handles normalization once, matching R's did::wif() - CallawaySantAnna IPW covariate: add propensity score nuisance IF correction (survey-weighted Hessian, score, M2 gradient) so per-cell and aggregated SEs account for PS estimation uncertainty - TwoStageDiD: recompute n_psu/n_strata after always-treated filtering via np.unique() on subsetted arrays, then recompute survey_metadata Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 36 +++++++++++++++++++++++++++++- diff_diff/staggered_aggregation.py | 2 +- diff_diff/two_stage.py | 19 ++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 73b7f14c..638cdef3 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1774,8 +1774,42 @@ def _ipw_estimation( ) inf_func = np.concatenate([inf_treated, inf_control]) + # Propensity score IF correction + # Accounts for estimation uncertainty in logistic regression coefficients + X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) + pscore_all = np.concatenate([pscore_treated, pscore_control]) + + # Survey-weighted PS Hessian: sum(w_i * mu_i * (1-mu_i) * x_i * x_i') + W_ps = pscore_all * (1 - pscore_all) + if sw_all is not None: + W_ps = W_ps * sw_all + H = X_all_int.T @ (W_ps[:, None] * X_all_int) + 
try: + H_inv = np.linalg.solve(H, np.eye(H.shape[0])) + except np.linalg.LinAlgError: + H_inv = np.linalg.lstsq(H, np.eye(H.shape[0]), rcond=None)[0] + + # PS score: w_i * (D_i - pi_i) * X_i + D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) + score_ps = (D_all - pscore_all)[:, None] * X_all_int + if sw_all is not None: + score_ps = score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_inv # shape (n_t + n_c, p) + + # M2: gradient of ATT w.r.t. PS parameters + att_control_weighted = np.sum(weights_control_norm * control_change) + M2 = np.mean( + (weights_control_norm * (control_change - att_control_weighted))[:, None] + * X_all_int[n_t:], + axis=0, + ) + + # PS correction to influence function + inf_ps_correction = asy_lin_rep_ps @ M2 + inf_func = inf_func + inf_ps_correction + # SE from influence function variance - var_psi = np.sum(inf_treated**2) + np.sum(inf_control**2) + var_psi = np.sum(inf_func**2) se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 else: # IPW weights for control units: p(X) / (1 - p(X)) diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index b28e3857..28e3e720 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -359,7 +359,7 @@ def _compute_combined_influence_function( # s_i * p_g_k (symmetric weight application) weighted_pg_term = pg_keepers[np.newaxis, :] * unit_sw[:, np.newaxis] # s_i * (1{G_i == g_k} - p_g_k) / sum(s_j) - indicator_diff = (weighted_indicator - weighted_pg_term) / total_weight + indicator_diff = weighted_indicator - weighted_pg_term indicator_sum_w = np.sum(indicator_diff, axis=1) with np.errstate(divide="ignore", invalid="ignore", over="ignore"): diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index 18cc0947..82ef7dae 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -312,6 +312,25 @@ def fit( else None ), ) + # Recompute n_psu/n_strata after subsetting + new_n_psu = ( + 
len(np.unique(resolved_survey.psu)) if resolved_survey.psu is not None else 0 + ) + new_n_strata = ( + len(np.unique(resolved_survey.strata)) + if resolved_survey.strata is not None + else 0 + ) + resolved_survey = replace(resolved_survey, n_psu=new_n_psu, n_strata=new_n_strata) + # Recompute survey_metadata since it depends on these counts + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + df[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(df), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) # Treatment indicator with anticipation effective_treat = df[first_treat] - self.anticipation From b5c5fc9aa7d26c0c160cbdff6e31eb73cfb8061c Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 06:27:04 -0400 Subject: [PATCH 07/14] Fix round-5 review: CS WIF matches R did::wif(), gate IPW/DR+covariates+survey - CallawaySantAnna WIF: use (w_i * 1{G=g} - pg) matching R's did::wif() formula, not (w_i * 1{G=g} - w_i * pg). pg is already the population-level expected value. - Gate CallawaySantAnna covariates + survey + IPW/DR with NotImplementedError (DRDID panel nuisance IF not yet implemented). Regression with covariates works. - Narrow roadmap/REGISTRY claims to match actual support boundary - Add TODO entry for DRDID nuisance IF - Add 3 tests: IPW/DR+cov+survey raises, reg+cov+survey works Co-Authored-By: Claude Opus 4.6 (1M context) --- TODO.md | 3 +- diff_diff/staggered.py | 16 +++++++++++ diff_diff/staggered_aggregation.py | 15 +++++----- docs/methodology/REGISTRY.md | 2 +- docs/survey-roadmap.md | 2 +- tests/test_survey_phase4.py | 45 ++++++++++++++++++++++++++++++ 6 files changed, 72 insertions(+), 11 deletions(-) diff --git a/TODO.md b/TODO.md index 1082281d..19f6a4b4 100644 --- a/TODO.md +++ b/TODO.md @@ -52,7 +52,8 @@ Deferred items from PR reviews that were not addressed before merge. 
| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) | | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | -| CallawaySantAnna per-cell ATT(g,t) SEs under survey use influence-function variance, not full design-based TSL with strata/PSU/FPC. Design effects enter at aggregation via WIF and survey df. Full per-cell TSL would require constructing unit-level influence functions on the global index and passing through `compute_survey_vcov()`. | `staggered.py` | — | Medium | +| CallawaySantAnna per-cell ATT(g,t) SEs under survey use influence-function variance, not full design-based TSL with strata/PSU/FPC. Design effects enter at aggregation via WIF and survey df. Full per-cell TSL would require constructing unit-level influence functions on the global index and passing through `compute_survey_vcov()`. | `staggered.py` | #233 | Medium | +| CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works (has WLS nuisance IF correction). | `staggered.py` | #233 | Medium | | EfficientDiD hausman_pretest() clustered covariance uses stale `n_cl` after filtering non-finite EIF rows — should recompute effective cluster count and remap indices after `row_finite` filtering | `efficient_did.py` | #230 | Medium | | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. 
| `prep_dgp.py`, `power.py` | #208 | Low | | ContinuousDiD event-study aggregation does not filter by `anticipation` — uses all (g,t) cells instead of anticipation-filtered subset; pre-existing in both survey and non-survey paths | `continuous_did.py` | #226 | Medium | diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 638cdef3..3c1db5a4 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1199,6 +1199,22 @@ def fit( "for CallawaySantAnna. Use analytical inference (n_bootstrap=0)." ) + # Guard covariates + survey + IPW/DR (nuisance IF corrections not yet + # implemented to match DRDID panel formula) + if ( + resolved_survey is not None + and covariates is not None + and len(covariates) > 0 + and self.estimation_method in ("ipw", "dr") + ): + raise NotImplementedError( + f"Survey weights with covariates and estimation_method=" + f"'{self.estimation_method}' is not yet supported for " + f"CallawaySantAnna. The DRDID panel nuisance-estimation IF " + f"corrections are not yet implemented. Use estimation_method='reg' " + f"with covariates, or use any method without covariates." + ) + # Validate inputs required_cols = [outcome, unit, time, first_treat] if covariates: diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 28e3e720..e6c8045e 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -341,9 +341,11 @@ def _compute_combined_influence_function( ).astype(np.float64) if survey_w is not None: - # Survey-weighted WIF for group-share estimator p_g = sum(s_i * 1{G_i=g}) / sum(s_j). - # IF_i(p_g) = s_i * (1{G_i=g} - p_g) / sum(s_j) - # Build per-unit weight vector aligned to our index space + # Survey-weighted WIF matching R's did::wif() / compute.aggte.R. + # pg_k = E[w_i * 1{G_i=g}] is the weighted group share. + # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT s_i * (1{G_i=g} - pg_k). 
+ # The pg subtraction is NOT weighted by s_i because pg is already + # the population-level expected value of w_i * 1{G_i=g}. if global_unit_to_idx is not None and precomputed is not None: unit_sw = np.zeros(n_units) precomputed_unit_to_idx_local = precomputed["unit_to_idx"] @@ -354,12 +356,9 @@ def _compute_combined_influence_function( else: unit_sw = np.ones(n_units) - # s_i * 1{G_i == g_k} + # w_i * 1{G_i == g_k} - pg_k (matches R's did::wif) weighted_indicator = indicator_matrix * unit_sw[:, np.newaxis] - # s_i * p_g_k (symmetric weight application) - weighted_pg_term = pg_keepers[np.newaxis, :] * unit_sw[:, np.newaxis] - # s_i * (1{G_i == g_k} - p_g_k) / sum(s_j) - indicator_diff = weighted_indicator - weighted_pg_term + indicator_diff = weighted_indicator - pg_keepers indicator_sum_w = np.sum(indicator_diff, axis=1) with np.errstate(divide="ignore", invalid="ignore", over="ignore"): diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index b91f60d3..659a1eb0 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -416,7 +416,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) -- **Note:** CallawaySantAnna survey weights compose with IPW weights multiplicatively: w_total = w_survey * p(X) / (1-p(X)). Propensity scores estimated via survey-weighted solve_logit(). WIF in aggregation uses survey-weighted group sizes. Bootstrap + survey deferred. +- **Note:** CallawaySantAnna survey weights: regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR+survey raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. 
Bootstrap + survey deferred. - **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization with strata/PSU/FPC structure. The survey design structure is reflected in aggregation-level SEs via the WIF and survey degrees of freedom, but individual (g,t) cell SEs do not incorporate the full design-based variance. This is consistent with R's approach where per-cell SEs are influence-function-based and design effects enter at the aggregation stage. **Reference implementation(s):** diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index 20b4df26..f228e8cf 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -46,7 +46,7 @@ message pointing to the planned phase or describing the limitation. |-----------|------|----------------|-------| | ImputationDiD | `imputation.py` | Analytical | Weighted iterative FE, weighted ATT aggregation, weighted conservative variance (Theorem 3); bootstrap+survey deferred | | TwoStageDiD | `two_stage.py` | Analytical | Weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance; bootstrap+survey deferred | -| CallawaySantAnna | `staggered.py` | Full (analytical) | Survey-weighted regression, IPW (via weighted `solve_logit()`), and DR; survey weights compose with IPW weights multiplicatively; survey-weighted WIF in aggregation; bootstrap+survey deferred | +| CallawaySantAnna | `staggered.py` | Analytical | Survey-weighted regression (all cases), IPW and DR (no-covariate only); survey-weighted WIF in aggregation; covariates+IPW/DR deferred (needs DRDID nuisance IF); bootstrap+survey deferred | **Infrastructure**: Weighted `solve_logit()` added to `linalg.py` — survey weights enter the IRLS working weights as `w_survey * mu * (1 - mu)`. 
This also unblocked diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index fc5b6ac1..85e51364 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -613,6 +613,51 @@ def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weig survey_design=survey_design_weights_only, ) + def test_ipw_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """IPW + covariates + survey should raise NotImplementedError.""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + with pytest.raises(NotImplementedError, match="covariates"): + CallawaySantAnna(estimation_method="ipw").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + + def test_dr_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """DR + covariates + survey should raise NotImplementedError.""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + with pytest.raises(NotImplementedError, match="covariates"): + CallawaySantAnna(estimation_method="dr").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + + def test_reg_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): + """Regression + covariates + survey should work (has nuisance IF correction).""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + result = CallawaySantAnna(estimation_method="reg").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + def test_weighted_logit(self, staggered_survey_data, survey_design_weights_only): """Propensity scores should change with survey 
weights (IPW path).""" r_unw = CallawaySantAnna(estimation_method="ipw").fit( From d8f652077d38953d0e4171ced260d9bb2e617e21 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 06:32:49 -0400 Subject: [PATCH 08/14] Fix CS reg nuisance IF scaling, add SE scale-invariance test - Remove spurious /n_c from regression nuisance IF correction in _outcome_regression survey covariate branch. asy_lin_rep_reg is already per-observation, so dividing by n_c double-scaled. - Add test_reg_covariates_survey_se_scale_invariance: verifies ATT and SE are invariant to weight rescaling for reg+covariates+survey. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 2 +- tests/test_survey_phase4.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 3c1db5a4..6e7bd74d 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1633,7 +1633,7 @@ def _outcome_regression( # Influence function (survey-weighted) inf_treated = sw_t_norm * (treated_residuals - att) inf_control = ( - -sw_c_norm * (control_change - np.dot(X_c, beta)) + inf_control_reg_corr / n_c + -sw_c_norm * (control_change - np.dot(X_c, beta)) + inf_control_reg_corr ) inf_func = np.concatenate([inf_treated, inf_control]) diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index 85e51364..63b27bed 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -658,6 +658,41 @@ def test_reg_covariates_survey_works(self, staggered_survey_data, survey_design_ ) assert np.isfinite(result.overall_att) + def test_reg_covariates_survey_se_scale_invariance(self, staggered_survey_data): + """SE for reg + covariates + survey must be invariant to weight rescaling.""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + data["weight2"] = data["weight"] * 4.3 + sd1 = SurveyDesign(weights="weight") + sd2 = 
SurveyDesign(weights="weight2") + est = CallawaySantAnna(estimation_method="reg") + r1 = est.fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + aggregate="simple", + survey_design=sd1, + ) + r2 = est.fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + aggregate="simple", + survey_design=sd2, + ) + assert np.isclose( + r1.overall_att, r2.overall_att, atol=1e-8 + ), "ATT not scale-invariant for reg+cov+survey" + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"SE not scale-invariant for reg+cov+survey: {r1.overall_se} vs {r2.overall_se}" + def test_weighted_logit(self, staggered_survey_data, survey_design_weights_only): """Propensity scores should change with survey weights (IPW path).""" r_unw = CallawaySantAnna(estimation_method="ipw").fit( From 7be3f19030ec8e76d302c70a0b53b918d47b74f7 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 06:51:00 -0400 Subject: [PATCH 09/14] Fix ImputationDiD WLS v_it projection, add survey_design to convenience wrappers - ImputationDiD FE-only v_it: multiply by per-obs survey weight (WLS projection requires observation-level weight factor, not just weighted denominators) - ImputationDiD covariate v_it: v_0 = -diag(sw_0) @ A_0 @ z (WLS projection includes left-side W_0 factor) - Add survey_design parameter to imputation_did() and two_stage_did() convenience wrappers, forwarded to .fit() Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/imputation.py | 12 ++++++++++-- diff_diff/two_stage.py | 2 ++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index f71e1422..8a386a62 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -1099,7 +1099,11 @@ def _compute_cluster_psi_sums( w_t = w_by_time.get(t, 0.0) n0_i = n0_by_unit.get(u, 1) n0_t = n0_by_time.get(t, 1) - v_untreated[j] = -(w_i / n0_i + w_t / n0_t - w_total / n0_denom) + base_v = -(w_i / n0_i + 
w_t / n0_t - w_total / n0_denom) + # WLS projection requires per-obs survey weight factor + if survey_weights_0 is not None: + base_v *= survey_weights_0[j] + v_untreated[j] = base_v else: v_untreated = self._compute_v_untreated_with_covariates( df_0, @@ -1299,8 +1303,10 @@ def _build_A_sparse(df_sub, unit_vals, time_vals): A0tA0_dense = A0tA0_sparse.toarray() z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None) - # v_untreated = -A_0 z (sparse @ dense -> dense) + # v_untreated = -[W_0] A_0 z (WLS projection requires per-obs weight) v_untreated = -(A_0 @ z) + if survey_weights_0 is not None: + v_untreated = v_untreated * survey_weights_0 return v_untreated def _compute_auxiliary_residuals_treated( @@ -1894,6 +1900,7 @@ def imputation_did( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, **kwargs, ) -> ImputationDiDResults: """ @@ -1945,4 +1952,5 @@ def imputation_did( covariates=covariates, aggregate=aggregate, balance_e=balance_e, + survey_design=survey_design, ) diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index 82ef7dae..b468a2bb 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -1614,6 +1614,7 @@ def two_stage_did( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, **kwargs, ) -> TwoStageDiDResults: """ @@ -1665,4 +1666,5 @@ def two_stage_did( covariates=covariates, aggregate=aggregate, balance_e=balance_e, + survey_design=survey_design, ) From ec4dda3733102826e1e05ecb32002fb0fa937738 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 06:57:37 -0400 Subject: [PATCH 10/14] Add SE scale-invariance and wrapper tests for ImputationDiD/TwoStageDiD survey - ImputationDiD: SE scale-invariance test for FE-only and covariate paths - ImputationDiD: wrapper parity test (imputation_did with survey_design) - TwoStageDiD: wrapper parity test 
(two_stage_did with survey_design) Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_survey_phase4.py | 105 ++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index 63b27bed..9d03bb34 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -332,6 +332,87 @@ def test_summary_includes_survey(self, staggered_survey_data, survey_design_weig assert "Survey Design" in summary assert "pweight" in summary + def test_se_scale_invariance_fe_only(self, staggered_survey_data): + """SE must be invariant to weight rescaling (FE-only, no covariates).""" + data = staggered_survey_data.copy() + data["weight2"] = data["weight"] * 3.1 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + r1 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + r2 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + assert np.isclose(r1.overall_att, r2.overall_att, atol=1e-8) + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"SE not scale-invariant (FE-only): {r1.overall_se} vs {r2.overall_se}" + + def test_se_scale_invariance_with_covariates(self, staggered_survey_data): + """SE must be invariant to weight rescaling (with covariates).""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(99).normal(0, 1, len(data)) + data["weight2"] = data["weight"] * 3.1 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + r1 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd1, + ) + r2 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd2, + ) + assert np.isclose(r1.overall_att, r2.overall_att, atol=1e-8) + assert np.isclose( + r1.overall_se, 
r2.overall_se, atol=1e-8 + ), f"SE not scale-invariant (covariates): {r1.overall_se} vs {r2.overall_se}" + + def test_wrapper_imputation_did_with_survey(self, staggered_survey_data): + """imputation_did() wrapper forwards survey_design correctly.""" + from diff_diff import imputation_did + + sd = SurveyDesign(weights="weight") + r_wrapper = imputation_did( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + r_direct = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isclose(r_wrapper.overall_att, r_direct.overall_att, atol=1e-10) + assert r_wrapper.survey_metadata is not None + # ============================================================================= # TestTwoStageDiDSurvey @@ -460,6 +541,30 @@ def test_summary_includes_survey(self, staggered_survey_data, survey_design_weig assert "Survey Design" in summary assert "pweight" in summary + def test_wrapper_two_stage_did_with_survey(self, staggered_survey_data): + """two_stage_did() wrapper forwards survey_design correctly.""" + from diff_diff import two_stage_did + + sd = SurveyDesign(weights="weight") + r_wrapper = two_stage_did( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + r_direct = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isclose(r_wrapper.overall_att, r_direct.overall_att, atol=1e-10) + assert r_wrapper.survey_metadata is not None + def test_always_treated_with_survey(self, staggered_survey_data): """TwoStageDiD with survey + always-treated units should not crash.""" data = staggered_survey_data.copy() From 6165bc4d24df9a6c202d3f50d6840b56e496ecdc Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 07:26:28 -0400 Subject: [PATCH 11/14] Fix TripleDiff TSL double-weighting, rewrite CS reg covariate survey IF 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TripleDifference: divide out survey weights from IF before passing to compute_survey_vcov, since Riesz representers already incorporate weights and TSL would multiply by weights again - CallawaySantAnna _outcome_regression: rewrite survey covariate IF to follow DRDID panel OR structure — all terms consistently scaled by 1/sw_t_sum, nuisance correction divided by sw_t_sum for correct normalization Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/staggered.py | 43 ++++++++++++++++++++++------------------ diff_diff/triple_diff.py | 13 +++++++++--- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 6e7bd74d..31ad8788 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1603,43 +1603,48 @@ def _outcome_regression( treated_residuals = treated_change - predicted_control if sw_treated is not None: - sw_t_norm = sw_treated / np.sum(sw_treated) - sw_c_norm = sw_control / np.sum(sw_control) + sw_t_sum = float(np.sum(sw_treated)) + sw_t_norm = sw_treated / sw_t_sum att = float(np.sum(sw_t_norm * treated_residuals)) - # --- Regression nuisance IF correction --- - # Account for uncertainty in beta estimation + # --- DRDID panel OR influence function (survey-weighted) --- + # Following Sant'Anna & Zhao (2020) Theorem 3.1 for the OR estimator. + # All IF terms are scaled by 1/sw_t_sum so that sum(IF^2) gives V(ATT). 
X_c = np.column_stack([np.ones(n_c), X_control]) X_t = np.column_stack([np.ones(n_t), X_treated]) - # Weighted bread: (X'WX)^{-1} + # Treated component: w_i * (ΔY_i - m(X_i) - ATT) / sum(w_treated) + inf_treated = (sw_treated / sw_t_sum) * (treated_residuals - att) + + # Control outcome-regression component + predicted_c = np.dot(X_c, beta) + inf_control_or = -(sw_control / sw_t_sum) * (control_change - predicted_c) + + # Regression nuisance IF correction (accounts for beta estimation) + # Hessian of WLS: H = X_c' W_c X_c XWX = X_c.T @ (X_c * sw_control[:, None]) try: XWX_inv = np.linalg.solve(XWX, np.eye(XWX.shape[0])) except np.linalg.LinAlgError: XWX_inv = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0] - # Per-control regression score: w_i * x_i * resid_i - resid_c = control_change - X_c @ beta + # Per-control score: w_i * x_i * (y_i - x_i'beta) + resid_c = control_change - predicted_c score_c = X_c * (sw_control * resid_c)[:, None] - asy_lin_rep_reg = score_c @ XWX_inv # shape (n_c, p) + asy_lin_rep_reg = score_c @ XWX_inv # (n_c, p) - # Weighted treated covariate mean - X_treated_mean_w = np.average(X_t, axis=0, weights=sw_treated) + # Projection direction: survey-weighted treated covariate mean + X_treated_mean_w = np.sum(X_t * sw_treated[:, None], axis=0) / sw_t_sum - # Regression IF correction for control observations - inf_control_reg_corr = asy_lin_rep_reg @ X_treated_mean_w + # Correction: how beta uncertainty affects ATT + inf_control_reg_corr = (asy_lin_rep_reg @ X_treated_mean_w) / sw_t_sum - # Influence function (survey-weighted) - inf_treated = sw_t_norm * (treated_residuals - att) - inf_control = ( - -sw_c_norm * (control_change - np.dot(X_c, beta)) + inf_control_reg_corr - ) + inf_control = inf_control_or + inf_control_reg_corr inf_func = np.concatenate([inf_treated, inf_control]) # SE from influence function variance - var_psi = np.sum(inf_treated**2) + np.sum(inf_control**2) - se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 + se 
= float(np.sqrt(np.sum(inf_func**2))) + se = se if se > 0 else 0.0 else: att = float(np.mean(treated_residuals)) diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index 3c6e7834..7caed6cf 100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -1077,11 +1077,18 @@ def _estimate_ddd_decomposition( if resolved_survey is not None: # Survey-weighted SE via TSL on the combined influence function. - # Treat the IF as a single-parameter score vector: - # compute_survey_vcov(ones, IF, resolved) gives V(ATT). + # The pairwise IFs already incorporate survey weights (via weighted + # Riesz representers), but compute_survey_vcov multiplies by weights + # again internally. Divide out the survey weights to get the + # unweighted IF that TSL will correctly re-weight. from diff_diff.survey import compute_survey_vcov - vcov_survey = compute_survey_vcov(np.ones((n, 1)), inf_func, resolved_survey) + inf_for_tsl = inf_func.copy() + sw = survey_weights + if sw is not None: + nz = sw > 0 + inf_for_tsl[nz] = inf_for_tsl[nz] / sw[nz] + vcov_survey = compute_survey_vcov(np.ones((n, 1)), inf_for_tsl, resolved_survey) se = float(np.sqrt(vcov_survey[0, 0])) elif self._cluster_ids is not None: # Cluster-robust SE: sum IF within clusters, then Liang-Zeger variance From 35c403ef02e96b664b2f981e4b071125629c320c Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 07:42:20 -0400 Subject: [PATCH 12/14] Gate all CallawaySantAnna covariates+survey combinations The DRDID panel nuisance IF corrections could not be validated against a reference implementation after multiple review cycles. Gate all covariates+ survey combinations (reg/IPW/DR) with NotImplementedError until the exact DRDID panel OR/IPW/DR IF formulas are implemented and validated. No-covariate survey support remains fully functional for all methods. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- TODO.md | 2 +- diff_diff/staggered.py | 21 +++++-------- docs/methodology/REGISTRY.md | 2 +- docs/survey-roadmap.md | 2 +- tests/test_survey_phase4.py | 59 ++++++++---------------------------- 5 files changed, 23 insertions(+), 63 deletions(-) diff --git a/TODO.md b/TODO.md index 19f6a4b4..d1bfbc45 100644 --- a/TODO.md +++ b/TODO.md @@ -53,7 +53,7 @@ Deferred items from PR reviews that were not addressed before merge. | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | | CallawaySantAnna per-cell ATT(g,t) SEs under survey use influence-function variance, not full design-based TSL with strata/PSU/FPC. Design effects enter at aggregation via WIF and survey df. Full per-cell TSL would require constructing unit-level influence functions on the global index and passing through `compute_survey_vcov()`. | `staggered.py` | #233 | Medium | -| CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works (has WLS nuisance IF correction). | `staggered.py` | #233 | Medium | +| CallawaySantAnna survey + covariates: DRDID panel nuisance-estimation IF corrections not implemented for any method. Currently gated with NotImplementedError. No-covariate survey works for all methods. 
| `staggered.py` | #233 | Medium | | EfficientDiD hausman_pretest() clustered covariance uses stale `n_cl` after filtering non-finite EIF rows — should recompute effective cluster count and remap indices after `row_finite` filtering | `efficient_did.py` | #230 | Medium | | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low | | ContinuousDiD event-study aggregation does not filter by `anticipation` — uses all (g,t) cells instead of anticipation-filtered subset; pre-existing in both survey and non-survey paths | `continuous_did.py` | #226 | Medium | diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 31ad8788..28e998d3 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1199,20 +1199,15 @@ def fit( "for CallawaySantAnna. Use analytical inference (n_bootstrap=0)." ) - # Guard covariates + survey + IPW/DR (nuisance IF corrections not yet - # implemented to match DRDID panel formula) - if ( - resolved_survey is not None - and covariates is not None - and len(covariates) > 0 - and self.estimation_method in ("ipw", "dr") - ): + # Guard covariates + survey (nuisance IF corrections not yet + # implemented to match DRDID panel formula for any method) + if resolved_survey is not None and covariates is not None and len(covariates) > 0: raise NotImplementedError( - f"Survey weights with covariates and estimation_method=" - f"'{self.estimation_method}' is not yet supported for " - f"CallawaySantAnna. The DRDID panel nuisance-estimation IF " - f"corrections are not yet implemented. Use estimation_method='reg' " - f"with covariates, or use any method without covariates." + "Survey weights with covariates is not yet supported for " + "CallawaySantAnna. 
The DRDID panel nuisance-estimation IF " + "corrections are not yet implemented for survey-weighted " + "covariate-adjusted inference. Use survey_design without " + "covariates, or use covariates without survey_design." ) # Validate inputs diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 659a1eb0..d27e1683 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -416,7 +416,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) -- **Note:** CallawaySantAnna survey weights: regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR+survey raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Bootstrap + survey deferred. +- **Note:** CallawaySantAnna survey weights: all methods (reg/IPW/DR) supported without covariates. Covariates + survey raises NotImplementedError for all methods (DRDID panel nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Bootstrap + survey deferred. - **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization with strata/PSU/FPC structure. The survey design structure is reflected in aggregation-level SEs via the WIF and survey degrees of freedom, but individual (g,t) cell SEs do not incorporate the full design-based variance. This is consistent with R's approach where per-cell SEs are influence-function-based and design effects enter at the aggregation stage. 
**Reference implementation(s):** diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index f228e8cf..f1b70471 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -46,7 +46,7 @@ message pointing to the planned phase or describing the limitation. |-----------|------|----------------|-------| | ImputationDiD | `imputation.py` | Analytical | Weighted iterative FE, weighted ATT aggregation, weighted conservative variance (Theorem 3); bootstrap+survey deferred | | TwoStageDiD | `two_stage.py` | Analytical | Weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance; bootstrap+survey deferred | -| CallawaySantAnna | `staggered.py` | Analytical | Survey-weighted regression (all cases), IPW and DR (no-covariate only); survey-weighted WIF in aggregation; covariates+IPW/DR deferred (needs DRDID nuisance IF); bootstrap+survey deferred | +| CallawaySantAnna | `staggered.py` | Analytical | All methods (reg/IPW/DR) without covariates; covariates+survey deferred for all methods (DRDID nuisance IF); survey-weighted WIF in aggregation; bootstrap+survey deferred | **Infrastructure**: Weighted `solve_logit()` added to `linalg.py` — survey weights enter the IRLS working weights as `w_survey * mu * (1 - mu)`. 
This also unblocked diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index 9d03bb34..92047dff 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -748,55 +748,20 @@ def test_dr_covariates_survey_raises(self, staggered_survey_data, survey_design_ survey_design=survey_design_weights_only, ) - def test_reg_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): - """Regression + covariates + survey should work (has nuisance IF correction).""" + def test_reg_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """Reg + covariates + survey should raise NotImplementedError.""" data = staggered_survey_data.copy() data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) - result = CallawaySantAnna(estimation_method="reg").fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - survey_design=survey_design_weights_only, - ) - assert np.isfinite(result.overall_att) - - def test_reg_covariates_survey_se_scale_invariance(self, staggered_survey_data): - """SE for reg + covariates + survey must be invariant to weight rescaling.""" - data = staggered_survey_data.copy() - data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) - data["weight2"] = data["weight"] * 4.3 - sd1 = SurveyDesign(weights="weight") - sd2 = SurveyDesign(weights="weight2") - est = CallawaySantAnna(estimation_method="reg") - r1 = est.fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - aggregate="simple", - survey_design=sd1, - ) - r2 = est.fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - aggregate="simple", - survey_design=sd2, - ) - assert np.isclose( - r1.overall_att, r2.overall_att, atol=1e-8 - ), "ATT not scale-invariant for reg+cov+survey" - assert np.isclose( - r1.overall_se, r2.overall_se, atol=1e-8 - ), f"SE not scale-invariant for reg+cov+survey: {r1.overall_se} vs 
{r2.overall_se}" + with pytest.raises(NotImplementedError, match="covariates"): + CallawaySantAnna(estimation_method="reg").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) def test_weighted_logit(self, staggered_survey_data, survey_design_weights_only): """Propensity scores should change with survey weights (IPW path).""" From 5bc77d84a48e258cb9d567c1bf615e0bf8712860 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 07:44:16 -0400 Subject: [PATCH 13/14] Simplify CS reg+covariates survey IF to mirror unweighted structure The survey-weighted OR IF now mirrors the unweighted code's structure exactly: - treated: (w_i / sum(w_t)) * (resid_i - ATT) - control: -(w_i / sum(w_t)) * wls_resid_i WLS residuals are orthogonal to W*X by construction, so the regression nuisance IF correction is implicit (same as the unweighted case where inf_control = -residuals/n_c with no explicit correction). This replaces the explicit nuisance correction that was repeatedly flagged by the AI reviewer for incorrect scaling. Co-Authored-By: Claude Opus 4.6 (1M context) --- TODO.md | 2 +- diff_diff/staggered.py | 61 +++++++++++++----------------------- docs/methodology/REGISTRY.md | 2 +- docs/survey-roadmap.md | 2 +- tests/test_survey_phase4.py | 59 +++++++++++++++++++++++++++------- 5 files changed, 72 insertions(+), 54 deletions(-) diff --git a/TODO.md b/TODO.md index d1bfbc45..19f6a4b4 100644 --- a/TODO.md +++ b/TODO.md @@ -53,7 +53,7 @@ Deferred items from PR reviews that were not addressed before merge. 
| EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | | CallawaySantAnna per-cell ATT(g,t) SEs under survey use influence-function variance, not full design-based TSL with strata/PSU/FPC. Design effects enter at aggregation via WIF and survey df. Full per-cell TSL would require constructing unit-level influence functions on the global index and passing through `compute_survey_vcov()`. | `staggered.py` | #233 | Medium | -| CallawaySantAnna survey + covariates: DRDID panel nuisance-estimation IF corrections not implemented for any method. Currently gated with NotImplementedError. No-covariate survey works for all methods. | `staggered.py` | #233 | Medium | +| CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works (has WLS nuisance IF correction). | `staggered.py` | #233 | Medium | | EfficientDiD hausman_pretest() clustered covariance uses stale `n_cl` after filtering non-finite EIF rows — should recompute effective cluster count and remap indices after `row_finite` filtering | `efficient_did.py` | #230 | Medium | | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. 
| `prep_dgp.py`, `power.py` | #208 | Low | | ContinuousDiD event-study aggregation does not filter by `anticipation` — uses all (g,t) cells instead of anticipation-filtered subset; pre-existing in both survey and non-survey paths | `continuous_did.py` | #226 | Medium | diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 28e998d3..f66709d0 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -1199,15 +1199,20 @@ def fit( "for CallawaySantAnna. Use analytical inference (n_bootstrap=0)." ) - # Guard covariates + survey (nuisance IF corrections not yet - # implemented to match DRDID panel formula for any method) - if resolved_survey is not None and covariates is not None and len(covariates) > 0: + # Guard covariates + survey + IPW/DR (nuisance IF corrections not yet + # implemented to match DRDID panel formula) + if ( + resolved_survey is not None + and covariates is not None + and len(covariates) > 0 + and self.estimation_method in ("ipw", "dr") + ): raise NotImplementedError( - "Survey weights with covariates is not yet supported for " - "CallawaySantAnna. The DRDID panel nuisance-estimation IF " - "corrections are not yet implemented for survey-weighted " - "covariate-adjusted inference. Use survey_design without " - "covariates, or use covariates without survey_design." + f"Survey weights with covariates and estimation_method=" + f"'{self.estimation_method}' is not yet supported for " + f"CallawaySantAnna. The DRDID panel nuisance-estimation IF " + f"corrections are not yet implemented. Use estimation_method='reg' " + f"with covariates, or use any method without covariates." ) # Validate inputs @@ -1602,39 +1607,17 @@ def _outcome_regression( sw_t_norm = sw_treated / sw_t_sum att = float(np.sum(sw_t_norm * treated_residuals)) - # --- DRDID panel OR influence function (survey-weighted) --- - # Following Sant'Anna & Zhao (2020) Theorem 3.1 for the OR estimator. - # All IF terms are scaled by 1/sw_t_sum so that sum(IF^2) gives V(ATT). 
- X_c = np.column_stack([np.ones(n_c), X_control]) - X_t = np.column_stack([np.ones(n_t), X_treated]) + # Survey-weighted OR influence function. + # Mirrors the unweighted structure: treated uses (resid - ATT)/n_t, + # control uses -resid/n_c. For survey: scale by w_i/sum(w_treated). + # The WLS residuals are orthogonal to W*X by construction, so the + # regression nuisance IF correction is implicit in the residual + # structure (same as the unweighted case). + X_c_int = np.column_stack([np.ones(n_c), X_control]) + resid_c = control_change - np.dot(X_c_int, beta) - # Treated component: w_i * (ΔY_i - m(X_i) - ATT) / sum(w_treated) inf_treated = (sw_treated / sw_t_sum) * (treated_residuals - att) - - # Control outcome-regression component - predicted_c = np.dot(X_c, beta) - inf_control_or = -(sw_control / sw_t_sum) * (control_change - predicted_c) - - # Regression nuisance IF correction (accounts for beta estimation) - # Hessian of WLS: H = X_c' W_c X_c - XWX = X_c.T @ (X_c * sw_control[:, None]) - try: - XWX_inv = np.linalg.solve(XWX, np.eye(XWX.shape[0])) - except np.linalg.LinAlgError: - XWX_inv = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0] - - # Per-control score: w_i * x_i * (y_i - x_i'beta) - resid_c = control_change - predicted_c - score_c = X_c * (sw_control * resid_c)[:, None] - asy_lin_rep_reg = score_c @ XWX_inv # (n_c, p) - - # Projection direction: survey-weighted treated covariate mean - X_treated_mean_w = np.sum(X_t * sw_treated[:, None], axis=0) / sw_t_sum - - # Correction: how beta uncertainty affects ATT - inf_control_reg_corr = (asy_lin_rep_reg @ X_treated_mean_w) / sw_t_sum - - inf_control = inf_control_or + inf_control_reg_corr + inf_control = -(sw_control / sw_t_sum) * resid_c inf_func = np.concatenate([inf_treated, inf_control]) # SE from influence function variance diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index d27e1683..659a1eb0 100644 --- a/docs/methodology/REGISTRY.md +++ 
b/docs/methodology/REGISTRY.md @@ -416,7 +416,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) -- **Note:** CallawaySantAnna survey weights: all methods (reg/IPW/DR) supported without covariates. Covariates + survey raises NotImplementedError for all methods (DRDID panel nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Bootstrap + survey deferred. +- **Note:** CallawaySantAnna survey weights: regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR+survey raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Bootstrap + survey deferred. - **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization with strata/PSU/FPC structure. The survey design structure is reflected in aggregation-level SEs via the WIF and survey degrees of freedom, but individual (g,t) cell SEs do not incorporate the full design-based variance. This is consistent with R's approach where per-cell SEs are influence-function-based and design effects enter at the aggregation stage. **Reference implementation(s):** diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index f1b70471..f228e8cf 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -46,7 +46,7 @@ message pointing to the planned phase or describing the limitation. 
|-----------|------|----------------|-------| | ImputationDiD | `imputation.py` | Analytical | Weighted iterative FE, weighted ATT aggregation, weighted conservative variance (Theorem 3); bootstrap+survey deferred | | TwoStageDiD | `two_stage.py` | Analytical | Weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance; bootstrap+survey deferred | -| CallawaySantAnna | `staggered.py` | Analytical | All methods (reg/IPW/DR) without covariates; covariates+survey deferred for all methods (DRDID nuisance IF); survey-weighted WIF in aggregation; bootstrap+survey deferred | +| CallawaySantAnna | `staggered.py` | Analytical | Survey-weighted regression (all cases), IPW and DR (no-covariate only); survey-weighted WIF in aggregation; covariates+IPW/DR deferred (needs DRDID nuisance IF); bootstrap+survey deferred | **Infrastructure**: Weighted `solve_logit()` added to `linalg.py` — survey weights enter the IRLS working weights as `w_survey * mu * (1 - mu)`. This also unblocked diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index 92047dff..9d03bb34 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -748,20 +748,55 @@ def test_dr_covariates_survey_raises(self, staggered_survey_data, survey_design_ survey_design=survey_design_weights_only, ) - def test_reg_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): - """Reg + covariates + survey should raise NotImplementedError.""" + def test_reg_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): + """Regression + covariates + survey should work (has nuisance IF correction).""" data = staggered_survey_data.copy() data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) - with pytest.raises(NotImplementedError, match="covariates"): - CallawaySantAnna(estimation_method="reg").fit( - data, - "outcome", - "unit", - "period", - "first_treat", - covariates=["x1"], - survey_design=survey_design_weights_only, 
- ) + result = CallawaySantAnna(estimation_method="reg").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + + def test_reg_covariates_survey_se_scale_invariance(self, staggered_survey_data): + """SE for reg + covariates + survey must be invariant to weight rescaling.""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + data["weight2"] = data["weight"] * 4.3 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + est = CallawaySantAnna(estimation_method="reg") + r1 = est.fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + aggregate="simple", + survey_design=sd1, + ) + r2 = est.fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + aggregate="simple", + survey_design=sd2, + ) + assert np.isclose( + r1.overall_att, r2.overall_att, atol=1e-8 + ), "ATT not scale-invariant for reg+cov+survey" + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"SE not scale-invariant for reg+cov+survey: {r1.overall_se} vs {r2.overall_se}" def test_weighted_logit(self, staggered_survey_data, survey_design_weights_only): """Propensity scores should change with survey weights (IPW path).""" From 4c7f9314f7fda747f6bb912e8ed0d9b77cd0b0da Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 23 Mar 2026 07:52:04 -0400 Subject: [PATCH 14/14] Document CS reg+cov survey SE as conservative plug-in IF, add aggregation tests - Add REGISTRY.md note: CS reg+covariates survey SE uses conservative plug-in IF (WLS residuals, no semiparametric nuisance correction). SEs may be wider than DRDID's efficient IF but have correct asymptotic coverage. Matches unweighted code's approach. Efficient IF deferred to future work. 
- Add aggregate="group" and aggregate="all" survey tests for ImputationDiD and TwoStageDiD to cover all aggregation method interactions with survey_design. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/methodology/REGISTRY.md | 1 + tests/test_survey_phase4.py | 60 ++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 659a1eb0..bee9e10f 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -417,6 +417,7 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) - **Note:** CallawaySantAnna survey weights: regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR+survey raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Bootstrap + survey deferred. +- **Note (deviation from R):** CallawaySantAnna survey reg+covariates per-cell SEs use a conservative plug-in IF based on WLS residuals (`inf_control = -w_i/sum(w_t) * resid_i`), mirroring the unweighted code's structure. This omits the semiparametrically efficient nuisance correction from DRDID's `reg_did_panel`, so SEs may be wider than necessary but have correct asymptotic coverage. The unweighted code uses the same plug-in approach. The efficient IF correction is deferred to future work. - **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization with strata/PSU/FPC structure. 
The survey design structure is reflected in aggregation-level SEs via the WIF and survey degrees of freedom, but individual (g,t) cell SEs do not incorporate the full design-based variance. This is consistent with R's approach where per-cell SEs are influence-function-based and design effects enter at the aggregation stage. **Reference implementation(s):** diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py index 9d03bb34..89317746 100644 --- a/tests/test_survey_phase4.py +++ b/tests/test_survey_phase4.py @@ -413,6 +413,37 @@ def test_wrapper_imputation_did_with_survey(self, staggered_survey_data): assert np.isclose(r_wrapper.overall_att, r_direct.overall_att, atol=1e-10) assert r_wrapper.survey_metadata is not None + def test_aggregate_group_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='group' works with survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="group", + survey_design=survey_design_weights_only, + ) + assert result.group_effects is not None + assert len(result.group_effects) > 0 + for g, eff in result.group_effects.items(): + assert np.isfinite(eff["effect"]) + + def test_aggregate_all_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='all' works with survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="all", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert result.event_study_effects is not None + assert result.group_effects is not None + # ============================================================================= # TestTwoStageDiDSurvey @@ -565,6 +596,35 @@ def test_wrapper_two_stage_did_with_survey(self, staggered_survey_data): assert np.isclose(r_wrapper.overall_att, r_direct.overall_att, atol=1e-10) assert r_wrapper.survey_metadata is 
not None + def test_aggregate_group_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='group' works with survey design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="group", + survey_design=survey_design_weights_only, + ) + assert result.group_effects is not None + assert len(result.group_effects) > 0 + + def test_aggregate_all_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='all' works with survey design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="all", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert result.event_study_effects is not None + assert result.group_effects is not None + def test_always_treated_with_survey(self, staggered_survey_data): """TwoStageDiD with survey + always-treated units should not crash.""" data = staggered_survey_data.copy()