diff --git a/TODO.md b/TODO.md index 913282b..19f6a4b 100644 --- a/TODO.md +++ b/TODO.md @@ -52,6 +52,9 @@ Deferred items from PR reviews that were not addressed before merge. | ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) | | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | | Multi-absorb weighted demeaning needs iterative alternating projections for N > 1 absorbed FE with survey weights; unweighted multi-absorb also uses single-pass (pre-existing, exact only for balanced panels) | `estimators.py` | #218 | Medium | +| CallawaySantAnna per-cell ATT(g,t) SEs under survey use influence-function variance, not full design-based TSL with strata/PSU/FPC. Design effects enter at aggregation via WIF and survey df. Full per-cell TSL would require constructing unit-level influence functions on the global index and passing through `compute_survey_vcov()`. | `staggered.py` | #233 | Medium | +| CallawaySantAnna survey + covariates + IPW/DR: DRDID panel nuisance-estimation IF corrections not implemented. Currently gated with NotImplementedError. Regression method with covariates works (has WLS nuisance IF correction). | `staggered.py` | #233 | Medium | +| EfficientDiD hausman_pretest() clustered covariance uses stale `n_cl` after filtering non-finite EIF rows — should recompute effective cluster count and remap indices after `row_finite` filtering | `efficient_did.py` | #230 | Medium | | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low | | ContinuousDiD event-study aggregation does not filter by `anticipation` — uses all (g,t) cells instead of anticipation-filtered subset; pre-existing in both survey and non-survey paths | `continuous_did.py` | #226 | Medium | | Survey design resolution/collapse patterns are inconsistent across panel estimators — ContinuousDiD rebuilds unit-level design in SE code, EfficientDiD builds once in fit(), StackedDiD re-resolves on stacked data; extract shared helpers for panel-to-unit collapse, post-filter re-resolution, and metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low | diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index f74eca9..8a386a6 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -177,6 +177,7 @@ def fit( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, ) -> ImputationDiDResults: """ Fit the imputation DiD estimator. @@ -225,6 +226,29 @@ def fit( # Create working copy df = data.copy() + # Resolve survey design if provided + from diff_diff.survey import ( + _inject_cluster_as_psu, + _resolve_effective_cluster, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, survey_weights, _, survey_metadata = _resolve_survey_for_fit( + survey_design, data, "analytical" + ) + + # Validate within-unit constancy for panel survey designs + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + + # Guard bootstrap + survey + if self.n_bootstrap > 0 and resolved_survey is not None: + raise NotImplementedError( + "Bootstrap inference with survey weights is not yet supported " + "for ImputationDiD. Use analytical inference (n_bootstrap=0)." + ) + # Ensure numeric types df[time] = pd.to_numeric(df[time]) df[first_treat] = pd.to_numeric(df[first_treat]) @@ -312,6 +336,26 @@ def fit( f"Available columns: {list(df.columns)}" ) + # Resolve effective cluster and inject cluster-as-PSU for survey variance + if resolved_survey is not None: + cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None + effective_cluster_ids = _resolve_effective_cluster( + resolved_survey, + cluster_ids_raw, + cluster_var if self.cluster is not None else None, + ) + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) + # Recompute metadata after PSU injection + if resolved_survey.psu is not None and survey_metadata is not None: + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) + # Compute relative time df["_rel_time"] = np.where( ~df["_never_treated"], @@ -321,7 +365,7 @@ def fit( # ---- Step 1: OLS on untreated observations ---- unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model( - df, outcome, unit, time, covariates, omega_0_mask + df, outcome, unit, time, covariates, omega_0_mask, weights=survey_weights ) # ---- Rank condition checks ---- @@ -382,20 +426,31 @@ def fit( # ---- Step 3: Aggregate ---- # Always compute overall ATT (simple aggregation) - valid_tau = tau_hat[np.isfinite(tau_hat)] + finite_mask = np.isfinite(tau_hat) + valid_tau = tau_hat[finite_mask] if len(valid_tau) == 0: overall_att = np.nan + elif survey_weights is not None: + # Survey-weighted ATT: use treated obs' survey weights + treated_survey_w = survey_weights[omega_1_mask.values] + w_finite = treated_survey_w[finite_mask] + overall_att = float(np.average(valid_tau, weights=w_finite)) else: overall_att = float(np.mean(valid_tau)) # ---- Conservative variance (Theorem 3) ---- - # Build weights matching the ATT: uniform over finite tau_hat, zero for NaN + # Build weights matching the ATT: proportional to survey weights for + # finite tau_hat, uniform when no survey overall_weights = np.zeros(n_omega_1) - finite_mask = np.isfinite(tau_hat) n_valid = int(finite_mask.sum()) if n_valid > 0: - overall_weights[finite_mask] = 1.0 / n_valid + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + sw_finite = treated_sw[finite_mask] + overall_weights[finite_mask] = sw_finite / sw_finite.sum() + else: + overall_weights[finite_mask] = 1.0 / n_valid if n_valid == 0: overall_se = np.nan @@ -416,9 +471,15 @@ def fit( weights=overall_weights, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, ) - overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) + # Survey degrees of freedom for t-distribution inference + _survey_df = resolved_survey.df_survey if resolved_survey is not None else None + + overall_t, overall_p, overall_ci = safe_inference( + overall_att, overall_se, alpha=self.alpha, df=_survey_df + ) # Event study and group aggregation event_study_effects = None @@ -442,6 +503,8 @@ def fit( treatment_groups=treatment_groups, balance_e=balance_e, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_df=_survey_df, ) if aggregate in ("group", "all"): @@ -461,6 +524,8 @@ def fit( cluster_var=cluster_var, treatment_groups=treatment_groups, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_df=_survey_df, ) # Build treatment effects dataframe @@ -599,6 +664,7 @@ def fit( alpha=self.alpha, bootstrap_results=bootstrap_results, _estimator_ref=self, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -616,6 +682,7 @@ def _iterative_fe( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> Tuple[Dict[Any, float], Dict[Any, float]]: """ Estimate unit and time FE via iterative alternating projection (Gauss-Seidel). @@ -624,6 +691,12 @@ def _iterative_fe( For balanced panels, converges in 1-2 iterations (identical to one-pass). For unbalanced panels, typically 5-20 iterations. + Parameters + ---------- + weights : np.ndarray, optional + Survey weights. When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. + Returns ------- unit_fe : dict @@ -635,25 +708,37 @@ def _iterative_fe( alpha = np.zeros(n) # unit FE broadcast to obs level beta = np.zeros(n) # time FE broadcast to obs level + # Precompute per-group weight sums (invariant across iterations) + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for iteration in range(max_iter): - # Update time FE: beta_t = mean_i(y_it - alpha_i) resid_after_alpha = y - alpha - beta_new = ( - pd.Series(resid_after_alpha, index=idx) - .groupby(time_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_t = pd.Series(resid_after_alpha * weights, index=idx) + beta_new = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + beta_new = ( + pd.Series(resid_after_alpha, index=idx) + .groupby(time_vals) + .transform("mean") + .values + ) - # Update unit FE: alpha_i = mean_t(y_it - beta_t) resid_after_beta = y - beta_new - alpha_new = ( - pd.Series(resid_after_beta, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(resid_after_beta * weights, index=idx) + alpha_new = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + alpha_new = ( + pd.Series(resid_after_beta, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) # Check convergence on FE changes max_change = max( @@ -677,25 +762,47 @@ def _iterative_demean( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> np.ndarray: """Demean a vector by iterative alternating projection (unit + time FE removal). Converges to the exact within-transformation for both balanced and unbalanced panels. For balanced panels, converges in 1-2 iterations. + + Parameters + ---------- + weights : np.ndarray, optional + Survey weights. When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. """ result = vals.copy() + + # Precompute per-group weight sums (invariant across iterations) + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for _ in range(max_iter): - time_means = ( - pd.Series(result, index=idx).groupby(time_vals).transform("mean").values - ) + if weights is not None: + wr_t = pd.Series(result * weights, index=idx) + time_means = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + time_means = ( + pd.Series(result, index=idx).groupby(time_vals).transform("mean").values + ) result_after_time = result - time_means - unit_means = ( - pd.Series(result_after_time, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(result_after_time * weights, index=idx) + unit_means = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + unit_means = ( + pd.Series(result_after_time, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) result_new = result_after_time - unit_means if np.max(np.abs(result_new - result)) < tol: result = result_new @@ -773,6 +880,7 @@ def _fit_untreated_model( time: str, covariates: Optional[List[str]], omega_0_mask: pd.Series, + weights: Optional[np.ndarray] = None, ) -> Tuple[ Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray] ]: @@ -783,6 +891,12 @@ def _fit_untreated_model( OLS fixed effects for both balanced and unbalanced panels. For balanced panels, converges in 1-2 iterations (identical to one-pass demeaning). + Parameters + ---------- + weights : np.ndarray, optional + Full-panel survey weights (same length as df). The untreated subset + is extracted internally via omega_0_mask. When None, unweighted. + Returns ------- unit_fe : dict @@ -798,13 +912,14 @@ def _fit_untreated_model( have finite coefficients. None if no covariates. """ df_0 = df.loc[omega_0_mask] + w_0 = weights[omega_0_mask.values] if weights is not None else None if covariates is None or len(covariates) == 0: # No covariates: estimate FE via iterative alternating projection # (exact OLS for both balanced and unbalanced panels) y = df_0[outcome].values.copy() unit_fe, time_fe = self._iterative_fe( - y, df_0[unit].values, df_0[time].values, df_0.index + y, df_0[unit].values, df_0[time].values, df_0.index, weights=w_0 ) # grand_mean = 0: iterative FE absorb the intercept return unit_fe, time_fe, 0.0, None, None @@ -819,10 +934,10 @@ def _fit_untreated_model( n_cov = len(covariates) # Step A: Iteratively demean Y and all X columns to remove unit+time FE - y_dm = self._iterative_demean(y, units, times, df_0.index) + y_dm = self._iterative_demean(y, units, times, df_0.index, weights=w_0) X_dm = np.column_stack( [ - self._iterative_demean(X_raw[:, j], units, times, df_0.index) + self._iterative_demean(X_raw[:, j], units, times, df_0.index, weights=w_0) for j in range(n_cov) ] ) @@ -834,6 +949,7 @@ def _fit_untreated_model( return_vcov=False, rank_deficient_action=self.rank_deficient_action, column_names=covariates, + weights=w_0, ) delta_hat = result[0] @@ -847,7 +963,7 @@ def _fit_untreated_model( # Step C: Recover FE from covariate-adjusted outcome using iterative FE y_adj = y - np.dot(X_raw, delta_hat_clean) - unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) + unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index, weights=w_0) # grand_mean = 0: iterative FE absorb the intercept return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask @@ -921,6 +1037,7 @@ def _compute_cluster_psi_sums( weights: np.ndarray, cluster_var: str, kept_cov_mask: Optional[np.ndarray] = None, + survey_weights_0: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: """ Compute cluster-level influence function sums (Theorem 3). @@ -960,8 +1077,16 @@ def _compute_cluster_psi_sums( w_total = float(np.sum(weights)) - n0_by_unit = df_0.groupby(unit).size().to_dict() - n0_by_time = df_0.groupby(time).size().to_dict() + # Use survey-weighted sums for untreated denominators when present + if survey_weights_0 is not None: + sw0_series = pd.Series(survey_weights_0, index=df_0.index) + n0_by_unit = sw0_series.groupby(df_0[unit]).sum().to_dict() + n0_by_time = sw0_series.groupby(df_0[time]).sum().to_dict() + n0_denom = float(np.sum(survey_weights_0)) + else: + n0_by_unit = df_0.groupby(unit).size().to_dict() + n0_by_time = df_0.groupby(time).size().to_dict() + n0_denom = n_0 untreated_units = df_0[unit].values untreated_times = df_0[time].values @@ -974,7 +1099,11 @@ def _compute_cluster_psi_sums( w_t = w_by_time.get(t, 0.0) n0_i = n0_by_unit.get(u, 1) n0_t = n0_by_time.get(t, 1) - v_untreated[j] = -(w_i / n0_i + w_t / n0_t - w_total / n_0) + base_v = -(w_i / n0_i + w_t / n0_t - w_total / n0_denom) + # WLS projection requires per-obs survey weight factor + if survey_weights_0 is not None: + base_v *= survey_weights_0[j] + v_untreated[j] = base_v else: v_untreated = self._compute_v_untreated_with_covariates( df_0, @@ -985,6 +1114,7 @@ def _compute_cluster_psi_sums( weights, delta_hat, kept_cov_mask=kept_cov_mask, + survey_weights_0=survey_weights_0, ) # ---- Compute auxiliary model residuals (Equation 8) ---- @@ -1043,6 +1173,7 @@ def _compute_conservative_variance( weights: np.ndarray, cluster_var: str, kept_cov_mask: Optional[np.ndarray] = None, + survey_weights: Optional[np.ndarray] = None, ) -> float: """ Compute conservative clustered variance (Theorem 3, Equation 7). @@ -1052,12 +1183,16 @@ def _compute_conservative_variance( weights : np.ndarray Aggregation weights w_it for treated observations. Shape: (n_treated,), must sum to 1. + survey_weights : np.ndarray, optional + Full-panel survey weights. When provided, untreated denominators + in v_it use survey-weighted sums instead of raw counts. Returns ------- float Standard error. """ + sw_0 = survey_weights[omega_0_mask.values] if survey_weights is not None else None cluster_psi_sums, _ = self._compute_cluster_psi_sums( df=df, outcome=outcome, @@ -1074,6 +1209,7 @@ def _compute_conservative_variance( weights=weights, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights_0=sw_0, ) sigma_sq = float((cluster_psi_sums**2).sum()) return np.sqrt(max(sigma_sq, 0.0)) @@ -1088,11 +1224,14 @@ def _compute_v_untreated_with_covariates( weights: np.ndarray, delta_hat: Optional[np.ndarray], kept_cov_mask: Optional[np.ndarray] = None, + survey_weights_0: Optional[np.ndarray] = None, ) -> np.ndarray: """ Compute v_it for untreated observations with covariates. Uses the projection: v_untreated = -A_0 (A_0'A_0)^{-1} A_1' w_treated + When survey_weights_0 is provided, uses weighted normal equations: + v_untreated = -A_0 (A_0' W A_0)^{-1} A_1' w_treated Uses scipy.sparse for FE dummy columns to reduce memory from O(N*(U+T)) to O(N) for the FE portion. @@ -1151,8 +1290,12 @@ def _build_A_sparse(df_sub, unit_vals, time_vals): # Compute A_1' w (sparse.T @ dense -> dense) A1_w = A_1.T @ weights # shape (p,) - # Solve (A_0'A_0) z = A_1' w using sparse direct solver - A0tA0_sparse = A_0.T @ A_0 # stays sparse + # Solve (A_0' [W] A_0) z = A_1' w using sparse direct solver + # When survey weights present, use weighted normal equations A_0' W A_0 + if survey_weights_0 is not None: + A0tA0_sparse = A_0.T @ A_0.multiply(survey_weights_0[:, None]) + else: + A0tA0_sparse = A_0.T @ A_0 # stays sparse try: z = spsolve(A0tA0_sparse.tocsc(), A1_w) except Exception: @@ -1160,8 +1303,10 @@ def _build_A_sparse(df_sub, unit_vals, time_vals): A0tA0_dense = A0tA0_sparse.toarray() z, _, _, _ = np.linalg.lstsq(A0tA0_dense, A1_w, rcond=None) - # v_untreated = -A_0 z (sparse @ dense -> dense) + # v_untreated = -[W_0] A_0 z (WLS projection requires per-obs weight) v_untreated = -(A_0 @ z) + if survey_weights_0 is not None: + v_untreated = v_untreated * survey_weights_0 return v_untreated def _compute_auxiliary_residuals_treated( @@ -1281,6 +1426,8 @@ def _aggregate_event_study( treatment_groups: List[Any], balance_e: Optional[int] = None, kept_cov_mask: Optional[np.ndarray] = None, + survey_weights: Optional[np.ndarray] = None, + survey_df: Optional[int] = None, ) -> Dict[int, Dict[str, Any]]: """Aggregate treatment effects by event-study horizon.""" df_1 = df.loc[omega_1_mask] @@ -1355,7 +1502,8 @@ def _aggregate_event_study( continue tau_h = tau_hat[h_mask] - valid_tau = tau_h[np.isfinite(tau_h)] + finite_h = np.isfinite(tau_h) + valid_tau = tau_h[finite_h] if len(valid_tau) == 0: event_study_effects[h] = { @@ -1368,10 +1516,32 @@ def _aggregate_event_study( } continue - effect = float(np.mean(valid_tau)) + # Survey-weighted or simple mean for per-horizon effect + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + sw_h = treated_sw[h_mask] + sw_valid = sw_h[finite_h] + effect = float(np.average(valid_tau, weights=sw_valid)) + else: + effect = float(np.mean(valid_tau)) # Compute SE via conservative variance with horizon-specific weights - weights_h, n_valid = _compute_target_weights(tau_hat, h_mask) + # When survey, aggregation weights are proportional to survey weights + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + n_1 = len(tau_hat) + weights_h = np.zeros(n_1) + sw_h = treated_sw[h_mask] + finite_in_h = np.isfinite(tau_h) + sw_finite = sw_h[finite_in_h] + # Set weights proportional to survey weights, summing to 1 + if sw_finite.sum() > 0: + h_indices = np.where(h_mask)[0] + finite_indices = h_indices[finite_in_h] + weights_h[finite_indices] = sw_finite / sw_finite.sum() + n_valid = int(finite_in_h.sum()) + else: + weights_h, n_valid = _compute_target_weights(tau_hat, h_mask) se = self._compute_conservative_variance( df=df, @@ -1389,9 +1559,10 @@ def _aggregate_event_study( weights=weights_h, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, ) - t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df) event_study_effects[h] = { "effect": effect, @@ -1449,6 +1620,8 @@ def _aggregate_group( cluster_var: str, treatment_groups: List[Any], kept_cov_mask: Optional[np.ndarray] = None, + survey_weights: Optional[np.ndarray] = None, + survey_df: Optional[int] = None, ) -> Dict[Any, Dict[str, Any]]: """Aggregate treatment effects by cohort.""" df_1 = df.loc[omega_1_mask] @@ -1465,7 +1638,8 @@ def _aggregate_group( continue tau_g = tau_hat[g_mask] - valid_tau = tau_g[np.isfinite(tau_g)] + finite_g = np.isfinite(tau_g) + valid_tau = tau_g[finite_g] if len(valid_tau) == 0: group_effects[g] = { @@ -1478,10 +1652,29 @@ def _aggregate_group( } continue - effect = float(np.mean(valid_tau)) + # Survey-weighted or simple mean for per-group effect + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + sw_g = treated_sw[g_mask] + sw_valid = sw_g[finite_g] + effect = float(np.average(valid_tau, weights=sw_valid)) + else: + effect = float(np.mean(valid_tau)) # Compute SE with group-specific weights - weights_g, _ = _compute_target_weights(tau_hat, g_mask) + # When survey, aggregation weights proportional to survey weights + if survey_weights is not None: + treated_sw = survey_weights[omega_1_mask.values] + n_1 = len(tau_hat) + weights_g = np.zeros(n_1) + sw_g = treated_sw[g_mask] + sw_finite = sw_g[finite_g] + if sw_finite.sum() > 0: + g_indices = np.where(g_mask)[0] + finite_indices = g_indices[finite_g] + weights_g[finite_indices] = sw_finite / sw_finite.sum() + else: + weights_g, _ = _compute_target_weights(tau_hat, g_mask) se = self._compute_conservative_variance( df=df, @@ -1499,9 +1692,10 @@ def _aggregate_group( weights=weights_g, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, ) - t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=survey_df) group_effects[g] = { "effect": effect, @@ -1706,6 +1900,7 @@ def imputation_did( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, **kwargs, ) -> ImputationDiDResults: """ @@ -1757,4 +1952,5 @@ def imputation_did( covariates=covariates, aggregate=aggregate, balance_e=balance_e, + survey_design=survey_design, ) diff --git a/diff_diff/imputation_results.py b/diff_diff/imputation_results.py index 6520fca..6589af1 100644 --- a/diff_diff/imputation_results.py +++ b/diff_diff/imputation_results.py @@ -139,6 +139,8 @@ class ImputationDiDResults: bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False) # Internal: stores data needed for pretrend_test() _estimator_ref: Optional[Any] = field(default=None, repr=False) + # Survey design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None, repr=False) def __repr__(self) -> str: """Concise string representation.""" @@ -182,6 +184,27 @@ def summary(self, alpha: Optional[float] = None) -> str: "", ] + # Survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "-" * 85, + "Survey Design".center(85), + "-" * 85, + f"{'Weight type:':<30} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<30} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<30} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<30} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<30} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<30} {sm.df_survey:>10}") + lines.extend(["-" * 85, ""]) + # Overall ATT lines.extend( [ diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py index 6c26d36..01b85e5 100644 --- a/diff_diff/linalg.py +++ b/diff_diff/linalg.py @@ -390,24 +390,18 @@ def _validate_weights(weights, weight_type, n): """Validate weights array and weight_type for solve_ols/LinearRegression.""" if weight_type not in _VALID_WEIGHT_TYPES: raise ValueError( - f"weight_type must be one of {_VALID_WEIGHT_TYPES}, " - f"got '{weight_type}'" + f"weight_type must be one of {_VALID_WEIGHT_TYPES}, " f"got '{weight_type}'" ) if weights is not None: weights = np.asarray(weights, dtype=np.float64) if weights.shape[0] != n: - raise ValueError( - f"weights length ({weights.shape[0]}) must match " - f"X rows ({n})" - ) + raise ValueError(f"weights length ({weights.shape[0]}) must match " f"X rows ({n})") if np.any(np.isnan(weights)): raise ValueError("Weights contain NaN values") if np.any(np.isinf(weights)): raise ValueError("Weights contain Inf values") if np.any(weights < 0): - raise ValueError( - "Weights must be non-negative" - ) + raise ValueError("Weights must be non-negative") if weight_type == "fweight": fractional = weights - np.round(weights) if np.any(np.abs(fractional) > 1e-10): @@ -693,13 +687,9 @@ def solve_ols( weights=weights, weight_type=weight_type, ) - vcov_out = _expand_vcov_with_nan( - vcov_reduced, _original_X.shape[1], kept_cols - ) + vcov_out = _expand_vcov_with_nan(vcov_reduced, _original_X.shape[1], kept_cols) else: - vcov_out = np.full( - (_original_X.shape[1], _original_X.shape[1]), np.nan - ) + vcov_out = np.full((_original_X.shape[1], _original_X.shape[1]), np.nan) else: vcov_out = _compute_robust_vcov_numpy( _original_X, @@ -1122,6 +1112,7 @@ def solve_logit( tol: float = 1e-8, check_separation: bool = True, rank_deficient_action: str = "warn", + weights: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: """ Fit logistic regression via IRLS (Fisher scoring). @@ -1147,6 +1138,13 @@ def solve_logit( - "warn": Emit warning and drop columns (default) - "error": Raise ValueError - "silent": Drop columns silently + weights : np.ndarray, optional + Survey/observation weights of shape (n_samples,). When provided, + the IRLS working weights become ``weights * mu * (1 - mu)`` + instead of ``mu * (1 - mu)``. This produces the survey-weighted + maximum likelihood estimator, matching R's ``svyglm(family=binomial)``. + When None (default), behavior is identical to unweighted logistic + regression. Returns ------- @@ -1203,11 +1201,16 @@ def solve_logit( mu = np.clip(mu, 1e-10, 1 - 1e-10) # Working weights and working response - w = mu * (1.0 - mu) - z = eta + (y - mu) / w + w_irls = mu * (1.0 - mu) + z = eta + (y - mu) / w_irls + + if weights is not None: + w_total = weights * w_irls + else: + w_total = w_irls # Weighted least squares: solve (X'WX) beta = X'Wz - sqrt_w = np.sqrt(w) + sqrt_w = np.sqrt(w_total) Xw = X_solve * sqrt_w[:, None] zw = z * sqrt_w beta_new, _, _, _ = np.linalg.lstsq(Xw, zw, rcond=None) @@ -1593,10 +1596,7 @@ def fit( _use_survey_vcov = self.survey_design.needs_survey_vcov # Canonicalize weights from survey_design to ensure consistency # between coefficient estimation and survey vcov computation - if ( - self.weights is not None - and self.weights is not self.survey_design.weights - ): + if self.weights is not None and self.weights is not self.survey_design.weights: warnings.warn( "Explicit weights= differ from survey_design.weights. " "Using survey_design weights for both coefficient " @@ -1609,9 +1609,7 @@ def fit( self.weight_type = self.survey_design.weight_type if self.weights is not None: - self.weights = _validate_weights( - self.weights, self.weight_type, X.shape[0] - ) + self.weights = _validate_weights(self.weights, self.weight_type, X.shape[0]) # Inject cluster as PSU for survey variance when no PSU specified. # Use a local variable to avoid mutating self.survey_design, which @@ -1622,7 +1620,9 @@ def fit( and _effective_survey_design is not None and _use_survey_vcov ): - from diff_diff.survey import ResolvedSurveyDesign as _RSD, _inject_cluster_as_psu + from diff_diff.survey import ResolvedSurveyDesign as _RSD + from diff_diff.survey import _inject_cluster_as_psu + if isinstance(_effective_survey_design, _RSD) and _effective_survey_design.psu is None: _effective_survey_design = _inject_cluster_as_psu( _effective_survey_design, effective_cluster_ids @@ -1864,9 +1864,7 @@ def get_inference( # Use project-standard NaN-safe inference (returns all-NaN when SE <= 0) from diff_diff.utils import safe_inference - t_stat, p_value, conf_int = safe_inference( - coef, se, alpha=effective_alpha, df=effective_df - ) + t_stat, p_value, conf_int = safe_inference(coef, se, alpha=effective_alpha, df=effective_df) return InferenceResult( coefficient=coef, diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index 3be7cb0..f66709d 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -11,27 +11,28 @@ import numpy as np import pandas as pd from scipy import linalg as scipy_linalg + from diff_diff.linalg import ( - solve_ols, - solve_logit, _check_propensity_diagnostics, _detect_rank_deficiency, _format_dropped_columns, + solve_logit, + solve_ols, ) -from diff_diff.utils import safe_inference, safe_inference_batch - -# Import from split modules -from diff_diff.staggered_results import ( - GroupTimeEffect, - CallawaySantAnnaResults, +from diff_diff.staggered_aggregation import ( + CallawaySantAnnaAggregationMixin, ) from diff_diff.staggered_bootstrap import ( - CSBootstrapResults, CallawaySantAnnaBootstrapMixin, + CSBootstrapResults, ) -from diff_diff.staggered_aggregation import ( - CallawaySantAnnaAggregationMixin, + +# Import from split modules +from diff_diff.staggered_results import ( + CallawaySantAnnaResults, + GroupTimeEffect, ) +from diff_diff.utils import safe_inference, safe_inference_batch # Re-export for backward compatibility __all__ = [ @@ -49,6 +50,7 @@ def _linear_regression( X: np.ndarray, y: np.ndarray, rank_deficient_action: str = "warn", + weights: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: """ Fit OLS regression. @@ -64,6 +66,8 @@ def _linear_regression( - "warn": Issue warning and drop linearly dependent columns (default) - "error": Raise ValueError - "silent": Drop columns silently without warning + weights : np.ndarray, optional + Observation weights for WLS. When None, OLS is used. Returns ------- @@ -82,6 +86,7 @@ def _linear_regression( y, return_vcov=False, rank_deficient_action=rank_deficient_action, + weights=weights, ) return beta, residuals @@ -333,6 +338,7 @@ def _precompute_structures( covariates: Optional[List[str]], time_periods: List[Any], treatment_groups: List[Any], + resolved_survey=None, ) -> PrecomputedData: """ Pre-compute data structures for efficient ATT(g,t) computation. @@ -352,7 +358,6 @@ def _precompute_structures( unit_info = df.groupby(unit)[first_treat].first() all_units = unit_info.index.values unit_cohorts = unit_info.values - n_units = len(all_units) # Create unit index mapping for fast lookups unit_to_idx = {u: i for i, u in enumerate(all_units)} @@ -385,6 +390,15 @@ def _precompute_structures( is_balanced = not np.any(np.isnan(outcome_matrix)) + # Extract per-unit survey weights (one weight per unit) + if resolved_survey is not None: + sw_by_unit = ( + pd.Series(resolved_survey.weights, index=df.index).groupby(df[unit]).first() + ) + survey_weights_arr = sw_by_unit.reindex(all_units).values + else: + survey_weights_arr = None + return { "all_units": all_units, "unit_to_idx": unit_to_idx, @@ -396,6 +410,8 @@ def _precompute_structures( "covariate_by_period": covariate_by_period, "time_periods": time_periods, "is_balanced": is_balanced, + "survey_weights": survey_weights_arr, + "df_survey": resolved_survey.df_survey if resolved_survey is not None else None, } def _compute_att_gt_fast( @@ -406,12 +422,22 @@ def _compute_att_gt_fast( covariates: Optional[List[str]], pscore_cache: Optional[Dict] = None, cho_cache: Optional[Dict] = None, - ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]]]: + ) -> Tuple[Optional[float], float, int, int, Optional[Dict[str, Any]], Optional[float]]: """ Compute ATT(g,t) using pre-computed data structures (fast version). Uses vectorized numpy operations on pre-pivoted outcome matrix instead of repeated pandas filtering. + + Returns + ------- + att_gt : float or None + se_gt : float + n_treated : int + n_control : int + inf_func_info : dict or None + survey_weight_sum : float or None + Sum of survey weights for treated units (for aggregation weighting). """ period_to_col = precomputed["period_to_col"] outcome_matrix = precomputed["outcome_matrix"] @@ -434,11 +460,11 @@ def _compute_att_gt_fast( if base_period_val not in period_to_col: # Base period must exist; no fallback to maintain methodological consistency - return None, 0.0, 0, 0, None + return None, 0.0, 0, 0, None, None # Check if periods exist in the data if base_period_val not in period_to_col or t not in period_to_col: - return None, 0.0, 0, 0, None + return None, 0.0, 0, 0, None, None base_col = period_to_col[base_period_val] post_col = period_to_col[t] @@ -476,12 +502,17 @@ def _compute_att_gt_fast( n_control = np.sum(control_valid) if n_treated == 0 or n_control == 0: - return None, 0.0, 0, 0, None + return None, 0.0, 0, 0, None, None # Extract outcome changes for treated and control treated_change = outcome_change[treated_valid] control_change = outcome_change[control_valid] + # Extract survey weights for treated and control + survey_w = precomputed.get("survey_weights") + sw_treated = survey_w[treated_valid] if survey_w is not None else None + sw_control = survey_w[control_valid] if survey_w is not None else None + # Get covariates if specified (from the base period) X_treated = None X_control = None @@ -522,9 +553,15 @@ def _compute_att_gt_fast( # Estimation method if self.estimation_method == "reg": att_gt, se_gt, inf_func = self._outcome_regression( - treated_change, control_change, X_treated, X_control + treated_change, + control_change, + X_treated, + X_control, + sw_treated=sw_treated, + sw_control=sw_control, ) elif self.estimation_method == "ipw": + sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None att_gt, se_gt, inf_func = self._ipw_estimation( treated_change, control_change, @@ -534,8 +571,12 @@ def _compute_att_gt_fast( X_control, pscore_cache=pscore_cache, pscore_key=pscore_key, + sw_treated=sw_treated, + sw_control=sw_control, + sw_all=sw_all, ) else: # doubly robust + sw_all = np.concatenate([sw_treated, sw_control]) if sw_treated is not None else None att_gt, se_gt, inf_func = self._doubly_robust( treated_change, control_change, @@ -545,6 +586,9 @@ def _compute_att_gt_fast( pscore_key=pscore_key, cho_cache=cho_cache, cho_key=cho_key, + sw_treated=sw_treated, + sw_control=sw_control, + sw_all=sw_all, ) # Package influence function info with index arrays (positions into @@ -563,7 +607,8 @@ def _compute_att_gt_fast( "control_inf": inf_func[n_t:], } - return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info + sw_sum = float(np.sum(sw_treated)) if sw_treated is not None else None + return att_gt, se_gt, int(n_treated), int(n_control), inf_func_info, sw_sum def _compute_all_att_gt_vectorized( self, @@ -590,6 +635,7 @@ def _compute_all_att_gt_vectorized( cohort_masks = precomputed["cohort_masks"] never_treated_mask = precomputed["never_treated_mask"] unit_cohorts = precomputed["unit_cohorts"] + survey_w = precomputed.get("survey_weights") group_time_effects = {} influence_func_info = {} @@ -618,7 +664,9 @@ def _compute_all_att_gt_vectorized( if base_period_val not in period_to_col or t not in period_to_col: continue - tasks.append((g, t, period_to_col[base_period_val], period_to_col[t], base_period_val)) + tasks.append( + (g, t, period_to_col[base_period_val], period_to_col[t], base_period_val) + ) # Process all tasks atts = [] @@ -658,17 +706,38 @@ def _compute_all_att_gt_vectorized( n_c = int(n_control) # Inline no-covariates regression (difference in means) - att = float(np.mean(treated_change) - np.mean(control_change)) + if survey_w is not None: + sw_t = survey_w[treated_valid] + sw_c = survey_w[control_valid] + sw_t_norm = sw_t / np.sum(sw_t) + sw_c_norm = sw_c / np.sum(sw_c) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + # Influence function (survey-weighted) + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + # SE derived from IF: sum(IF_i^2) + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) + sw_sum = float(np.sum(sw_t)) + else: + att = float(np.mean(treated_change) - np.mean(control_change)) - var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0 - var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0 - se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 + var_t = float(np.var(treated_change, ddof=1)) if n_t > 1 else 0.0 + var_c = float(np.var(control_change, ddof=1)) if n_c > 1 else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = -(control_change - np.mean(control_change)) / n_c + # Influence function + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = -(control_change - np.mean(control_change)) / n_c + sw_sum = None - group_time_effects[(g, t)] = { + gte_entry = { "effect": att, "se": se, # t_stat, p_value, conf_int filled by batch inference below @@ -678,6 +747,9 @@ def _compute_all_att_gt_vectorized( "n_treated": n_t, "n_control": n_c, } + if sw_sum is not None: + gte_entry["survey_weight_sum"] = sw_sum + group_time_effects[(g, t)] = gte_entry all_units = precomputed["all_units"] treated_positions = np.where(treated_valid)[0] @@ -697,8 +769,12 @@ def _compute_all_att_gt_vectorized( # Batch inference for all (g,t) pairs at once if task_keys: + df_survey_val = precomputed.get("df_survey") t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( - np.array(atts), np.array(ses), alpha=self.alpha + np.array(atts), + np.array(ses), + alpha=self.alpha, + df=df_survey_val, ) for idx, key in enumerate(task_keys): group_time_effects[key]["t_stat"] = float(t_stats[idx]) @@ -1048,6 +1124,7 @@ def fit( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, ) -> CallawaySantAnnaResults: """ Fit the Callaway-Sant'Anna estimator. @@ -1077,6 +1154,9 @@ def fit( balance_e : int, optional For event study, balance the panel at relative time e. Ensures all groups contribute to each relative period. + survey_design : SurveyDesign, optional + Survey design specification for design-based inference. + Supports weights, strata, PSU, and FPC. Returns ------- @@ -1096,6 +1176,45 @@ def fit( if covariates is not None and len(covariates) == 0: covariates = None + # Resolve survey design if provided + from diff_diff.survey import ( + _inject_cluster_as_psu, + _resolve_effective_cluster, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( + _resolve_survey_for_fit(survey_design, data, "analytical") + ) + + # Validate within-unit constancy for panel survey designs + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + + # Guard bootstrap + survey + if self.n_bootstrap > 0 and resolved_survey is not None: + raise NotImplementedError( + "Bootstrap inference with survey weights is not yet supported " + "for CallawaySantAnna. Use analytical inference (n_bootstrap=0)." + ) + + # Guard covariates + survey + IPW/DR (nuisance IF corrections not yet + # implemented to match DRDID panel formula) + if ( + resolved_survey is not None + and covariates is not None + and len(covariates) > 0 + and self.estimation_method in ("ipw", "dr") + ): + raise NotImplementedError( + f"Survey weights with covariates and estimation_method=" + f"'{self.estimation_method}' is not yet supported for " + f"CallawaySantAnna. The DRDID panel nuisance-estimation IF " + f"corrections are not yet implemented. Use estimation_method='reg' " + f"with covariates, or use any method without covariates." + ) + # Validate inputs required_cols = [outcome, unit, time, first_treat] if covariates: @@ -1147,13 +1266,46 @@ def fit( "cohorts when there are no never-treated units." ) + # Resolve effective cluster and inject cluster-as-PSU for survey variance + if resolved_survey is not None: + cluster_var = self.cluster if self.cluster is not None else unit + cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None + effective_cluster_ids = _resolve_effective_cluster( + resolved_survey, + cluster_ids_raw, + cluster_var if self.cluster is not None else None, + ) + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) + # Recompute metadata after PSU injection + if resolved_survey.psu is not None and survey_metadata is not None: + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + data[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(data), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) + # Pre-compute data structures for efficient ATT(g,t) computation precomputed = self._precompute_structures( - df, outcome, unit, time, first_treat, covariates, time_periods, treatment_groups + df, + outcome, + unit, + time, + first_treat, + covariates, + time_periods, + treatment_groups, + resolved_survey=resolved_survey, ) + # Survey df for safe_inference calls + df_survey = resolved_survey.df_survey if resolved_survey is not None else None + # Compute ATT(g,t) for each group-time combination min_period = min(time_periods) + has_survey = resolved_survey is not None if covariates is None and self.estimation_method == "reg": # Fast vectorized path for the common no-covariates regression case @@ -1164,6 +1316,7 @@ def fit( covariates is not None and self.estimation_method == "reg" and self.rank_deficient_action != "error" + and not has_survey # Cholesky cache uses X'X; survey needs X'WX ): # Optimized covariate regression path with Cholesky caching group_time_effects, influence_func_info = self._compute_all_att_gt_covariate_reg( @@ -1177,12 +1330,14 @@ def fit( # Propensity score cache for IPW/DR with covariates pscore_cache = {} if (covariates and self.estimation_method in ("ipw", "dr")) else None # Cholesky cache for DR outcome regression component + # Skip cache when survey weights present (X'WX differs from X'X) cho_cache = ( {} if ( covariates and self.estimation_method == "dr" and self.rank_deficient_action != "error" + and not has_survey ) else None ) @@ -1197,7 +1352,7 @@ def fit( ] for t in valid_periods: - att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast( + att_gt, se_gt, n_treat, n_ctrl, inf_info, sw_sum = self._compute_att_gt_fast( precomputed, g, t, @@ -1207,9 +1362,14 @@ def fit( ) if att_gt is not None: - t_stat, p_val, ci = safe_inference(att_gt, se_gt, alpha=self.alpha) + t_stat, p_val, ci = safe_inference( + att_gt, + se_gt, + alpha=self.alpha, + df=df_survey, + ) - group_time_effects[(g, t)] = { + gte_entry = { "effect": att_gt, "se": se_gt, "t_stat": t_stat, @@ -1218,6 +1378,9 @@ def fit( "n_treated": n_treat, "n_control": n_ctrl, } + if sw_sum is not None: + gte_entry["survey_weight_sum"] = sw_sum + group_time_effects[(g, t)] = gte_entry if inf_info is not None: influence_func_info[(g, t)] = inf_info @@ -1232,7 +1395,12 @@ def fit( overall_att, overall_se = self._aggregate_simple( group_time_effects, influence_func_info, df, unit, precomputed ) - overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) + overall_t, overall_p, overall_ci = safe_inference( + overall_att, + overall_se, + alpha=self.alpha, + df=df_survey, + ) # Compute additional aggregations if requested event_study_effects = None @@ -1383,6 +1551,7 @@ def fit( bootstrap_results=bootstrap_results, cband_crit_value=cband_crit_value, pscore_trim=self.pscore_trim, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -1394,6 +1563,8 @@ def _outcome_regression( control_change: np.ndarray, X_treated: Optional[np.ndarray] = None, X_control: Optional[np.ndarray] = None, + sw_treated: Optional[np.ndarray] = None, + sw_control: Optional[np.ndarray] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using outcome regression. @@ -1405,6 +1576,11 @@ def _outcome_regression( Without covariates: Simple difference in means. + + Parameters + ---------- + sw_treated, sw_control : np.ndarray, optional + Survey weights for treated and control units. """ n_t = len(treated_change) n_c = len(control_change) @@ -1416,43 +1592,80 @@ def _outcome_regression( X_control, control_change, rank_deficient_action=self.rank_deficient_action, + weights=sw_control, ) # Predict counterfactual for treated units X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated]) predicted_control = np.dot(X_treated_with_intercept, beta) - # ATT = mean(observed treated change - predicted counterfactual) - att = np.mean(treated_change - predicted_control) - - # Standard error using sandwich estimator - # Variance from treated: Var(Y_1 - m(X)) + # ATT: survey-weighted mean of treated residuals treated_residuals = treated_change - predicted_control - var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0 - # Variance from control regression (residual variance) - var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0 + if sw_treated is not None: + sw_t_sum = float(np.sum(sw_treated)) + sw_t_norm = sw_treated / sw_t_sum + att = float(np.sum(sw_t_norm * treated_residuals)) + + # Survey-weighted OR influence function. + # Mirrors the unweighted structure: treated uses (resid - ATT)/n_t, + # control uses -resid/n_c. For survey: scale by w_i/sum(w_treated). + # The WLS residuals are orthogonal to W*X by construction, so the + # regression nuisance IF correction is implicit in the residual + # structure (same as the unweighted case). + X_c_int = np.column_stack([np.ones(n_c), X_control]) + resid_c = control_change - np.dot(X_c_int, beta) + + inf_treated = (sw_treated / sw_t_sum) * (treated_residuals - att) + inf_control = -(sw_control / sw_t_sum) * resid_c + inf_func = np.concatenate([inf_treated, inf_control]) + + # SE from influence function variance + se = float(np.sqrt(np.sum(inf_func**2))) + se = se if se > 0 else 0.0 + else: + att = float(np.mean(treated_residuals)) - # Approximate SE (ignoring estimation error in beta for simplicity) - se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0 + # Standard error using sandwich estimator + var_t = np.var(treated_residuals, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(residuals, ddof=1) if n_c > 1 else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function - inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t - inf_control = -residuals / n_c - inf_func = np.concatenate([inf_treated, inf_control]) + # Influence function + inf_treated = (treated_residuals - np.mean(treated_residuals)) / n_t + inf_control = -residuals / n_c + inf_func = np.concatenate([inf_treated, inf_control]) else: # Simple difference in means (no covariates) - att = np.mean(treated_change) - np.mean(control_change) - - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 - var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + if sw_treated is not None: + sw_t_norm = sw_treated / np.sum(sw_treated) + sw_c_norm = sw_control / np.sum(sw_control) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + # Influence function (survey-weighted) + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + inf_func = np.concatenate([inf_treated, inf_control]) + + # SE from influence function variance + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) + else: + att = float(np.mean(treated_change) - np.mean(control_change)) - se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0 + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function (for aggregation) - inf_treated = treated_change - np.mean(treated_change) - inf_control = control_change - np.mean(control_change) - inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c]) + # Influence function (for aggregation) + inf_treated = treated_change - np.mean(treated_change) + inf_control = control_change - np.mean(control_change) + inf_func = np.concatenate([inf_treated / n_t, -inf_control / n_c]) return att, se, inf_func @@ -1466,6 +1679,9 @@ def _ipw_estimation( X_control: Optional[np.ndarray] = None, pscore_cache: Optional[Dict] = None, pscore_key: Optional[Any] = None, + sw_treated: Optional[np.ndarray] = None, + sw_control: Optional[np.ndarray] = None, + sw_all: Optional[np.ndarray] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using inverse probability weighting. @@ -1477,6 +1693,11 @@ def _ipw_estimation( Without covariates: Simple difference in means with unconditional propensity weighting. + + Parameters + ---------- + sw_treated, sw_control, sw_all : np.ndarray, optional + Survey weights for treated, control, and all units. """ n_t = len(treated_change) n_c = len(control_change) @@ -1508,6 +1729,7 @@ def _ipw_estimation( X_all, D, rank_deficient_action=self.rank_deficient_action, + weights=sw_all, ) _check_propensity_diagnostics(pscore, self.pscore_trim) # Cache the fitted coefficients @@ -1533,51 +1755,124 @@ def _ipw_estimation( pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) pscore_treated = np.clip(pscore_treated, self.pscore_trim, 1 - self.pscore_trim) - # IPW weights for control units: p(X) / (1 - p(X)) - # This reweights controls to have same covariate distribution as treated - weights_control = pscore_control / (1 - pscore_control) - weights_control = weights_control / np.sum(weights_control) # normalize + if sw_treated is not None: + # IPW weights compose with survey weights: + # w_i = sw_i * p(X_i) / (1 - p(X_i)) + weights_control = sw_control * pscore_control / (1 - pscore_control) + weights_control_norm = weights_control / np.sum(weights_control) + + # ATT: survey-weighted treated mean minus composite-weighted control mean + sw_t_norm = sw_treated / np.sum(sw_treated) + mu_t = float(np.sum(sw_t_norm * treated_change)) + att = mu_t - float(np.sum(weights_control_norm * control_change)) + + # Influence function (survey-weighted) + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -weights_control_norm * ( + control_change - np.sum(weights_control_norm * control_change) + ) + inf_func = np.concatenate([inf_treated, inf_control]) + + # Propensity score IF correction + # Accounts for estimation uncertainty in logistic regression coefficients + X_all_int = np.column_stack([np.ones(n_t + n_c), X_all]) + pscore_all = np.concatenate([pscore_treated, pscore_control]) + + # Survey-weighted PS Hessian: sum(w_i * mu_i * (1-mu_i) * x_i * x_i') + W_ps = pscore_all * (1 - pscore_all) + if sw_all is not None: + W_ps = W_ps * sw_all + H = X_all_int.T @ (W_ps[:, None] * X_all_int) + try: + H_inv = np.linalg.solve(H, np.eye(H.shape[0])) + except np.linalg.LinAlgError: + H_inv = np.linalg.lstsq(H, np.eye(H.shape[0]), rcond=None)[0] + + # PS score: w_i * (D_i - pi_i) * X_i + D_all = np.concatenate([np.ones(n_t), np.zeros(n_c)]) + score_ps = (D_all - pscore_all)[:, None] * X_all_int + if sw_all is not None: + score_ps = score_ps * sw_all[:, None] + asy_lin_rep_ps = score_ps @ H_inv # shape (n_t + n_c, p) + + # M2: gradient of ATT w.r.t. PS parameters + att_control_weighted = np.sum(weights_control_norm * control_change) + M2 = np.mean( + (weights_control_norm * (control_change - att_control_weighted))[:, None] + * X_all_int[n_t:], + axis=0, + ) - # ATT = mean(treated) - weighted_mean(control) - att = np.mean(treated_change) - np.sum(weights_control * control_change) + # PS correction to influence function + inf_ps_correction = asy_lin_rep_ps @ M2 + inf_func = inf_func + inf_ps_correction - # Compute standard error - # Variance of treated mean - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + # SE from influence function variance + var_psi = np.sum(inf_func**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 + else: + # IPW weights for control units: p(X) / (1 - p(X)) + # This reweights controls to have same covariate distribution as treated + weights_control = pscore_control / (1 - pscore_control) + weights_control = weights_control / np.sum(weights_control) # normalize - # Variance of weighted control mean - weighted_var_c = np.sum( - weights_control * (control_change - np.sum(weights_control * control_change)) ** 2 - ) + # ATT = mean(treated) - weighted_mean(control) + att = float(np.mean(treated_change) - np.sum(weights_control * control_change)) - se = np.sqrt(var_t / n_t + weighted_var_c) if (n_t > 0 and n_c > 0) else 0.0 + # Compute standard error + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 - # Influence function - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = -weights_control * ( - control_change - np.sum(weights_control * control_change) - ) - inf_func = np.concatenate([inf_treated, inf_control]) + weighted_var_c = np.sum( + weights_control + * (control_change - np.sum(weights_control * control_change)) ** 2 + ) + + se = float(np.sqrt(var_t / n_t + weighted_var_c)) if (n_t > 0 and n_c > 0) else 0.0 + + # Influence function + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = -weights_control * ( + control_change - np.sum(weights_control * control_change) + ) + inf_func = np.concatenate([inf_treated, inf_control]) else: # Unconditional IPW (reduces to difference in means) - p_treat = n_treated / n_total # unconditional propensity score + if sw_treated is not None: + # Survey-weighted difference in means + sw_t_norm = sw_treated / np.sum(sw_treated) + sw_c_norm = sw_control / np.sum(sw_control) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + inf_func = np.concatenate([inf_treated, inf_control]) + + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) + else: + p_treat = n_treated / n_total # unconditional propensity score - att = np.mean(treated_change) - np.mean(control_change) + att = float(np.mean(treated_change) - np.mean(control_change)) - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 - var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 - # Adjusted variance for IPW - se = ( - np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) - if (n_t > 0 and n_c > 0 and p_treat > 0) - else 0.0 - ) + # Adjusted variance for IPW + se = float( + np.sqrt(var_t / n_t + var_c * (1 - p_treat) / (n_c * p_treat)) + if (n_t > 0 and n_c > 0 and p_treat > 0) + else 0.0 + ) - # Influence function (for aggregation) - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = (control_change - np.mean(control_change)) / n_c - inf_func = np.concatenate([inf_treated, -inf_control]) + # Influence function (for aggregation) + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = (control_change - np.mean(control_change)) / n_c + inf_func = np.concatenate([inf_treated, -inf_control]) return att, se, inf_func @@ -1591,6 +1886,9 @@ def _doubly_robust( pscore_key: Optional[Any] = None, cho_cache: Optional[Dict] = None, cho_key: Optional[Any] = None, + sw_treated: Optional[np.ndarray] = None, + sw_control: Optional[np.ndarray] = None, + sw_all: Optional[np.ndarray] = None, ) -> Tuple[float, float, np.ndarray]: """ Estimate ATT using doubly robust estimation. @@ -1607,6 +1905,11 @@ def _doubly_robust( Without covariates: Reduces to simple difference in means. + + Parameters + ---------- + sw_treated, sw_control, sw_all : np.ndarray, optional + Survey weights for treated, control, and all units. """ n_t = len(treated_change) n_c = len(control_change) @@ -1614,7 +1917,7 @@ def _doubly_robust( if X_treated is not None and X_control is not None and X_treated.shape[1] > 0: # Doubly robust estimation with covariates # Step 1: Outcome regression - fit E[Delta Y | X] on control - # Try Cholesky cache for outcome regression + # Try Cholesky cache for outcome regression (disabled when survey weights present) beta = None X_control_with_intercept = np.column_stack([np.ones(n_c), X_control]) if cho_cache is not None and cho_key is not None: @@ -1650,6 +1953,7 @@ def _doubly_robust( X_control, control_change, rank_deficient_action=self.rank_deficient_action, + weights=sw_control, ) # Zero NaN coefficients for prediction only — dropped columns # contribute 0 to the column space projection. Note: solve_ols @@ -1684,6 +1988,7 @@ def _doubly_robust( X_all, D, rank_deficient_action=self.rank_deficient_action, + weights=sw_all, ) _check_propensity_diagnostics(pscore, self.pscore_trim) if pscore_cache is not None and pscore_key is not None: @@ -1705,43 +2010,75 @@ def _doubly_robust( # Clip propensity scores pscore_control = np.clip(pscore_control, self.pscore_trim, 1 - self.pscore_trim) - # IPW weights for control: p(X) / (1 - p(X)) - weights_control = pscore_control / (1 - pscore_control) + if sw_treated is not None: + # IPW weights compose with survey weights + weights_control = sw_control * pscore_control / (1 - pscore_control) - # Step 3: Doubly robust ATT - # ATT = mean(treated - m(X_treated)) - # + weighted_mean_control((m(X) - Y) * weight) - att_treated_part = np.mean(treated_change - m_treated) + # Step 3: DR ATT (survey-weighted) + sw_t_sum = np.sum(sw_treated) + att_treated_part = float( + np.sum(sw_treated * (treated_change - m_treated)) / sw_t_sum + ) + augmentation = float( + np.sum(weights_control * (m_control - control_change)) / sw_t_sum + ) + att = att_treated_part + augmentation - # Augmentation term from control - augmentation = np.sum(weights_control * (m_control - control_change)) / n_t + # Step 4: Influence function (survey-weighted DR) + psi_treated = (sw_treated / sw_t_sum) * (treated_change - m_treated - att) + psi_control = (weights_control / sw_t_sum) * (m_control - control_change) - att = att_treated_part + augmentation + var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 - # Step 4: Standard error using influence function - # Influence function for DR estimator - psi_treated = (treated_change - m_treated - att) / n_t - psi_control = (weights_control * (m_control - control_change)) / n_t + inf_func = np.concatenate([psi_treated, psi_control]) + else: + # IPW weights for control: p(X) / (1 - p(X)) + weights_control = pscore_control / (1 - pscore_control) + + # Step 3: Doubly robust ATT + att_treated_part = float(np.mean(treated_change - m_treated)) + augmentation = float(np.sum(weights_control * (m_control - control_change)) / n_t) + att = att_treated_part + augmentation + + # Step 4: Standard error using influence function + psi_treated = (treated_change - m_treated - att) / n_t + psi_control = (weights_control * (m_control - control_change)) / n_t - # Variance is sum of squared influence functions - var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) - se = np.sqrt(var_psi) if var_psi > 0 else 0.0 + var_psi = np.sum(psi_treated**2) + np.sum(psi_control**2) + se = float(np.sqrt(var_psi)) if var_psi > 0 else 0.0 - # Full influence function - inf_func = np.concatenate([psi_treated, psi_control]) + inf_func = np.concatenate([psi_treated, psi_control]) else: # Without covariates, DR simplifies to difference in means - att = np.mean(treated_change) - np.mean(control_change) + if sw_treated is not None: + sw_t_norm = sw_treated / np.sum(sw_treated) + sw_c_norm = sw_control / np.sum(sw_control) + mu_t = float(np.sum(sw_t_norm * treated_change)) + mu_c = float(np.sum(sw_c_norm * control_change)) + att = mu_t - mu_c + + inf_treated = sw_t_norm * (treated_change - mu_t) + inf_control = -sw_c_norm * (control_change - mu_c) + inf_func = np.concatenate([inf_treated, inf_control]) + + se = ( + float(np.sqrt(np.sum(inf_treated**2) + np.sum(inf_control**2))) + if (n_t > 0 and n_c > 0) + else 0.0 + ) + else: + att = float(np.mean(treated_change) - np.mean(control_change)) - var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 - var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 + var_t = np.var(treated_change, ddof=1) if n_t > 1 else 0.0 + var_c = np.var(control_change, ddof=1) if n_c > 1 else 0.0 - se = np.sqrt(var_t / n_t + var_c / n_c) if (n_t > 0 and n_c > 0) else 0.0 + se = float(np.sqrt(var_t / n_t + var_c / n_c)) if (n_t > 0 and n_c > 0) else 0.0 - # Influence function for DR estimator - inf_treated = (treated_change - np.mean(treated_change)) / n_t - inf_control = (control_change - np.mean(control_change)) / n_c - inf_func = np.concatenate([inf_treated, -inf_control]) + # Influence function for DR estimator + inf_treated = (treated_change - np.mean(treated_change)) / n_t + inf_control = (control_change - np.mean(control_change)) / n_c + inf_func = np.concatenate([inf_treated, -inf_control]) return att, se, inf_func diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 7faf043..e6c8045 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -67,19 +67,24 @@ def _aggregate_simple( # Pre-treatment effects are for parallel trends, not overall ATT if t < g - self.anticipation: continue - effects.append(data['effect']) - weights_list.append(data['n_treated']) + effects.append(data["effect"]) + # Use survey_weight_sum for aggregation when available + if data.get("survey_weight_sum") is not None: + weights_list.append(data["survey_weight_sum"]) + else: + weights_list.append(data["n_treated"]) gt_pairs.append((g, t)) groups_for_gt.append(g) # Guard against empty post-treatment set if len(effects) == 0: import warnings + warnings.warn( "No post-treatment effects available for overall ATT aggregation. " "This can occur when cohorts lack post-treatment periods in the data.", UserWarning, - stacklevel=2 + stacklevel=2, ) return np.nan, np.nan @@ -92,6 +97,7 @@ def _aggregate_simple( n_nan = int(np.sum(~finite_mask)) if n_nan > 0: import warnings + warnings.warn( f"{n_nan} group-time effect(s) are NaN and excluded from overall ATT " "aggregation. Inspect group_time_effects for details.", @@ -105,6 +111,7 @@ def _aggregate_simple( if len(effects) == 0: import warnings + warnings.warn( "All post-treatment effects are NaN. Cannot compute overall ATT.", UserWarning, @@ -121,8 +128,14 @@ def _aggregate_simple( # Compute SE using influence function aggregation with wif adjustment overall_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights_norm, effects, groups_for_gt, - influence_func_info, df, unit, precomputed + gt_pairs, + weights_norm, + effects, + groups_for_gt, + influence_func_info, + df, + unit, + precomputed, ) return overall_att, overall_se @@ -158,13 +171,13 @@ def _compute_aggregated_se( if n_units is None: # Fallback: infer size from influence function info max_idx = 0 - for (g, t) in gt_pairs: + for g, t in gt_pairs: if (g, t) in influence_func_info: info = influence_func_info[(g, t)] - if len(info['treated_idx']) > 0: - max_idx = max(max_idx, info['treated_idx'].max()) - if len(info['control_idx']) > 0: - max_idx = max(max_idx, info['control_idx'].max()) + if len(info["treated_idx"]) > 0: + max_idx = max(max_idx, info["treated_idx"].max()) + if len(info["control_idx"]) > 0: + max_idx = max(max_idx, info["control_idx"].max()) n_units = max_idx + 1 if n_units == 0: @@ -181,16 +194,16 @@ def _compute_aggregated_se( w = weights[j] # Vectorized influence function aggregation using index arrays - treated_idx = info['treated_idx'] + treated_idx = info["treated_idx"] if len(treated_idx) > 0: - np.add.at(psi_overall, treated_idx, w * info['treated_inf']) + np.add.at(psi_overall, treated_idx, w * info["treated_inf"]) - control_idx = info['control_idx'] + control_idx = info["control_idx"] if len(control_idx) > 0: - np.add.at(psi_overall, control_idx, w * info['control_inf']) + np.add.at(psi_overall, control_idx, w * info["control_inf"]) # Compute variance: Var(θ̄) = (1/n) Σᵢ ψᵢ² - variance = np.sum(psi_overall ** 2) + variance = np.sum(psi_overall**2) return np.sqrt(variance) def _compute_combined_influence_function( @@ -228,39 +241,49 @@ def _compute_combined_influence_function( # Build unit index mapping (local or global) if global_unit_to_idx is not None and n_global_units is not None: - unit_to_idx = global_unit_to_idx n_units = n_global_units all_units = None # caller already has the unit list else: all_units_set: Set[Any] = set() - for (g, t) in gt_pairs: + for g, t in gt_pairs: if (g, t) in influence_func_info: info = influence_func_info[(g, t)] - all_units_set.update(info['treated_units']) - all_units_set.update(info['control_units']) + all_units_set.update(info["treated_units"]) + all_units_set.update(info["control_units"]) if not all_units_set: return np.zeros(0), [] all_units = sorted(all_units_set) n_units = len(all_units) - unit_to_idx = {u: i for i, u in enumerate(all_units)} - # Get unique groups and their information unique_groups = sorted(set(groups_for_gt)) unique_groups_set = set(unique_groups) group_to_idx = {g: i for i, g in enumerate(unique_groups)} + # Check for survey weights in precomputed data + survey_w = precomputed.get("survey_weights") if precomputed is not None else None + # Compute group-level probabilities matching R's formula: # pg[g] = n_g / n_all (fraction of ALL units in group g) + # With survey weights: pg[g] = sum(sw_g) / sum(sw_all) group_sizes = {} - for g in unique_groups: - treated_in_g = df[df['first_treat'] == g][unit].nunique() - group_sizes[g] = treated_in_g + if survey_w is not None: + # Survey-weighted group sizes + precomputed_cohorts = precomputed["unit_cohorts"] + for g in unique_groups: + mask_g = precomputed_cohorts == g + group_sizes[g] = float(np.sum(survey_w[mask_g])) + total_weight = float(np.sum(survey_w)) + else: + for g in unique_groups: + treated_in_g = df[df["first_treat"] == g][unit].nunique() + group_sizes[g] = treated_in_g + total_weight = float(n_units) # pg indexed by group - pg_by_group = np.array([group_sizes[g] / n_units for g in unique_groups]) + pg_by_group = np.array([group_sizes[g] / total_weight for g in unique_groups]) # pg indexed by keeper (each (g,t) pair gets its group's pg) pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt]) @@ -281,13 +304,13 @@ def _compute_combined_influence_function( w = weights[j] # Vectorized influence function aggregation using precomputed index arrays - treated_idx = info['treated_idx'] + treated_idx = info["treated_idx"] if len(treated_idx) > 0: - np.add.at(psi_standard, treated_idx, w * info['treated_inf']) + np.add.at(psi_standard, treated_idx, w * info["treated_inf"]) - control_idx = info['control_idx'] + control_idx = info["control_idx"] if len(control_idx) > 0: - np.add.at(psi_standard, control_idx, w * info['control_inf']) + np.add.at(psi_standard, control_idx, w * info["control_inf"]) # Build unit-group array: normalize iterator to (idx, uid) pairs unit_groups_array = np.full(n_units, -1, dtype=np.float64) @@ -298,8 +321,8 @@ def _compute_combined_influence_function( ) if precomputed is not None: - precomputed_cohorts = precomputed['unit_cohorts'] - precomputed_unit_to_idx = precomputed['unit_to_idx'] + precomputed_cohorts = precomputed["unit_cohorts"] + precomputed_unit_to_idx = precomputed["unit_to_idx"] for idx, uid in idx_uid_pairs: if uid in precomputed_unit_to_idx: cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]] @@ -307,37 +330,69 @@ def _compute_combined_influence_function( unit_groups_array[idx] = cohort else: for idx, uid in idx_uid_pairs: - unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0] + unit_first_treat = df[df[unit] == uid]["first_treat"].iloc[0] if unit_first_treat in unique_groups_set: unit_groups_array[idx] = unit_first_treat # Vectorized WIF computation groups_for_gt_array = np.array(groups_for_gt) - indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64) - indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1) + indicator_matrix = ( + unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :] + ).astype(np.float64) + + if survey_w is not None: + # Survey-weighted WIF matching R's did::wif() / compute.aggte.R. + # pg_k = E[w_i * 1{G_i=g}] is the weighted group share. + # IF_i(p_g) = (w_i * 1{G_i=g} - pg_k), NOT s_i * (1{G_i=g} - pg_k). + # The pg subtraction is NOT weighted by s_i because pg is already + # the population-level expected value of w_i * 1{G_i=g}. + if global_unit_to_idx is not None and precomputed is not None: + unit_sw = np.zeros(n_units) + precomputed_unit_to_idx_local = precomputed["unit_to_idx"] + for idx, uid in idx_uid_pairs: + if uid in precomputed_unit_to_idx_local: + pc_idx = precomputed_unit_to_idx_local[uid] + unit_sw[idx] = survey_w[pc_idx] + else: + unit_sw = np.ones(n_units) + + # w_i * 1{G_i == g_k} - pg_k (matches R's did::wif) + weighted_indicator = indicator_matrix * unit_sw[:, np.newaxis] + indicator_diff = weighted_indicator - pg_keepers + indicator_sum_w = np.sum(indicator_diff, axis=1) + + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + if1_matrix = indicator_diff / sum_pg_keepers + if2_matrix = np.outer(indicator_sum_w, pg_keepers) / (sum_pg_keepers**2) + wif_matrix = if1_matrix - if2_matrix + wif_contrib = wif_matrix @ effects + else: + indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1) - with np.errstate(divide='ignore', invalid='ignore', over='ignore'): - if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers - if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2) - wif_matrix = if1_matrix - if2_matrix - wif_contrib = wif_matrix @ effects + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers + if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers**2) + wif_matrix = if1_matrix - if2_matrix + wif_contrib = wif_matrix @ effects # Check for non-finite values from edge cases if not np.all(np.isfinite(wif_contrib)): import warnings + n_nonfinite = np.sum(~np.isfinite(wif_contrib)) warnings.warn( f"Non-finite values ({n_nonfinite}/{len(wif_contrib)}) in weight influence " "function computation. This may occur with very small samples or extreme " "weights. Returning NaN for SE to signal invalid inference.", RuntimeWarning, - stacklevel=2 + stacklevel=2, ) nan_result = np.full(n_units, np.nan) return nan_result, all_units - # Scale by 1/n_units to match R's getSE formula - psi_wif = wif_contrib / n_units + # Scale by 1/total_weight to match R's getSE formula + # (for non-survey, total_weight == n_units; for survey, total_weight == sum(sw)) + psi_wif = wif_contrib / total_weight # Combine standard and wif terms psi_total = psi_standard + psi_wif @@ -372,14 +427,20 @@ def _compute_aggregated_se_with_wif( global_unit_to_idx = None n_global_units = None if precomputed is not None: - global_unit_to_idx = precomputed['unit_to_idx'] - n_global_units = len(precomputed['all_units']) + global_unit_to_idx = precomputed["unit_to_idx"] + n_global_units = len(precomputed["all_units"]) elif df is not None and unit is not None: n_global_units = df[unit].nunique() psi_total, _ = self._compute_combined_influence_function( - gt_pairs, weights, effects, groups_for_gt, - influence_func_info, df, unit, precomputed, + gt_pairs, + weights, + effects, + groups_for_gt, + influence_func_info, + df, + unit, + precomputed, global_unit_to_idx=global_unit_to_idx, n_global_units=n_global_units, ) @@ -391,7 +452,7 @@ def _compute_aggregated_se_with_wif( if not np.all(np.isfinite(psi_total)): return np.nan - variance = np.sum(psi_total ** 2) + variance = np.sum(psi_total**2) return np.sqrt(variance) def _aggregate_event_study( @@ -414,41 +475,46 @@ def _aggregate_event_study( adjustment that accounts for uncertainty in group-size weights, matching R's did::aggte(..., type="dynamic"). """ - n_units = len(precomputed['all_units']) if precomputed is not None else None - # Organize effects by relative time, keeping track of (g,t) pairs - effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {} + effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {} for (g, t), data in group_time_effects.items(): e = t - g # Relative time if e not in effects_by_e: effects_by_e[e] = [] - effects_by_e[e].append(( - (g, t), # Keep track of the (g,t) pair - data['effect'], - data['n_treated'] - )) + # Use survey_weight_sum for aggregation when available + w = data.get("survey_weight_sum", data["n_treated"]) + effects_by_e[e].append( + ( + (g, t), # Keep track of the (g,t) pair + data["effect"], + w, + ) + ) # Balance the panel if requested if balance_e is not None: # Keep only groups that have effects at relative time balance_e groups_at_e = set() for (g, t), data in group_time_effects.items(): - if t - g == balance_e and np.isfinite(data['effect']): + if t - g == balance_e and np.isfinite(data["effect"]): groups_at_e.add(g) # Filter effects to only include balanced groups - balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {} + balanced_effects: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {} for (g, t), data in group_time_effects.items(): if g in groups_at_e: e = t - g if e not in balanced_effects: balanced_effects[e] = [] - balanced_effects[e].append(( - (g, t), - data['effect'], - data['n_treated'] - )) + w = data.get("survey_weight_sum", data["n_treated"]) + balanced_effects[e].append( + ( + (g, t), + data["effect"], + w, + ) + ) effects_by_e = balanced_effects # Compute aggregated effects and SEs for all relative periods @@ -479,8 +545,7 @@ def _aggregate_event_study( # Compute SE with WIF adjustment (matching R's did::aggte) groups_for_gt = np.array([g for (g, t) in gt_pairs]) agg_se = self._compute_aggregated_se_with_wif( - gt_pairs, weights, effs, groups_for_gt, - influence_func_info, df, unit, precomputed + gt_pairs, weights, effs, groups_for_gt, influence_func_info, df, unit, precomputed ) agg_effects_list.append(agg_effect) @@ -490,32 +555,36 @@ def _aggregate_event_study( # Batch inference for all relative periods if not agg_effects_list: return {} + df_survey_val = precomputed.get("df_survey") if precomputed is not None else None t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( - np.array(agg_effects_list), np.array(agg_ses_list), alpha=self.alpha + np.array(agg_effects_list), + np.array(agg_ses_list), + alpha=self.alpha, + df=df_survey_val, ) event_study_effects = {} for idx, (e, _) in enumerate(sorted_periods): event_study_effects[e] = { - 'effect': agg_effects_list[idx], - 'se': agg_ses_list[idx], - 't_stat': float(t_stats[idx]), - 'p_value': float(p_values[idx]), - 'conf_int': (float(ci_lowers[idx]), float(ci_uppers[idx])), - 'n_groups': agg_n_groups[idx], + "effect": agg_effects_list[idx], + "se": agg_ses_list[idx], + "t_stat": float(t_stats[idx]), + "p_value": float(p_values[idx]), + "conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])), + "n_groups": agg_n_groups[idx], } # Add reference period for universal base period mode (matches R did package) - if getattr(self, 'base_period', 'varying') == "universal": + if getattr(self, "base_period", "varying") == "universal": ref_period = -1 - self.anticipation if event_study_effects and ref_period not in event_study_effects: event_study_effects[ref_period] = { - 'effect': 0.0, - 'se': np.nan, - 't_stat': np.nan, - 'p_value': np.nan, - 'conf_int': (np.nan, np.nan), - 'n_groups': 0, + "effect": 0.0, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_groups": 0, } return event_study_effects @@ -535,13 +604,13 @@ def _aggregate_by_group( Standard errors use influence function aggregation to account for covariances across time periods within a cohort. """ - n_units = len(precomputed['all_units']) if precomputed is not None else None + n_units = len(precomputed["all_units"]) if precomputed is not None else None # Collect all group aggregation data first group_data_list = [] for g in groups: g_effects = [ - ((g, t), data['effect']) + ((g, t), data["effect"]) for (gg, t), data in group_time_effects.items() if gg == g and t >= g - self.anticipation ] @@ -574,19 +643,23 @@ def _aggregate_by_group( # Batch inference agg_effects = np.array([x[1] for x in group_data_list]) agg_ses = np.array([x[2] for x in group_data_list]) + df_survey_val = precomputed.get("df_survey") if precomputed is not None else None t_stats, p_values, ci_lowers, ci_uppers = safe_inference_batch( - agg_effects, agg_ses, alpha=self.alpha + agg_effects, + agg_ses, + alpha=self.alpha, + df=df_survey_val, ) group_effects = {} for idx, (g, agg_effect, agg_se, n_periods) in enumerate(group_data_list): group_effects[g] = { - 'effect': agg_effect, - 'se': agg_se, - 't_stat': float(t_stats[idx]), - 'p_value': float(p_values[idx]), - 'conf_int': (float(ci_lowers[idx]), float(ci_uppers[idx])), - 'n_periods': n_periods, + "effect": agg_effect, + "se": agg_se, + "t_stat": float(t_stats[idx]), + "p_value": float(p_values[idx]), + "conf_int": (float(ci_lowers[idx]), float(ci_uppers[idx])), + "n_periods": n_periods, } return group_effects diff --git a/diff_diff/staggered_results.py b/diff_diff/staggered_results.py index 53eeba1..7f47dc3 100644 --- a/diff_diff/staggered_results.py +++ b/diff_diff/staggered_results.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -117,6 +117,8 @@ class CallawaySantAnnaResults: bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False) cband_crit_value: Optional[float] = None pscore_trim: float = 0.01 + # Survey design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None, repr=False) def __repr__(self) -> str: """Concise string representation.""" @@ -160,6 +162,27 @@ def summary(self, alpha: Optional[float] = None) -> str: "", ] + # Survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "-" * 85, + "Survey Design".center(85), + "-" * 85, + f"{'Weight type:':<30} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<30} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<30} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<30} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<30} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<30} {sm.df_survey:>10}") + lines.extend(["-" * 85, ""]) + # Overall ATT lines.extend( [ diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index 15f0426..7caed6c 100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -480,8 +480,8 @@ def fit( survey_design : SurveyDesign, optional Survey design specification for complex survey data. When provided, uses survey weights for estimation and Taylor Series - Linearization (TSL) for variance estimation. Only supported - with estimation_method="reg". + Linearization (TSL) for variance estimation. Supported with + all estimation methods ("reg", "ipw", "dr"). Returns ------- @@ -493,7 +493,7 @@ def fit( ValueError If required columns are missing or data validation fails. NotImplementedError - If survey_design is used with estimation_method="ipw" or "dr". + If survey_design is used with wild_bootstrap inference. """ # Resolve survey design if provided from diff_diff.survey import ( @@ -507,13 +507,6 @@ def fit( _resolve_survey_for_fit(survey_design, data, "analytical") ) - # Guard IPW/DR with survey weights - if survey_design is not None and self.estimation_method in ("ipw", "dr"): - raise NotImplementedError( - "IPW and doubly robust methods with survey weights require " - "weighted solve_logit(), planned for Phase 5." - ) - # Validate inputs self._validate_data(data, outcome, group, partition, time, covariates) @@ -572,9 +565,25 @@ def fit( resolved_survey=resolved_survey, ) elif self.estimation_method == "ipw": - att, se, r_squared, pscore_stats = self._ipw_estimation(y, G, P, T, X) + att, se, r_squared, pscore_stats = self._ipw_estimation( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + resolved_survey=resolved_survey, + ) else: # doubly robust - att, se, r_squared, pscore_stats = self._doubly_robust(y, G, P, T, X) + att, se, r_squared, pscore_stats = self._doubly_robust( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + resolved_survey=resolved_survey, + ) # Compute inference # When survey design is active, use survey df (n_PSU - n_strata) @@ -765,6 +774,8 @@ def _ipw_estimation( P: np.ndarray, T: np.ndarray, X: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + resolved_survey=None, ) -> Tuple[float, float, Optional[float], Optional[Dict[str, float]]]: """ Estimate ATT using inverse probability weighting via three-DiD @@ -774,7 +785,15 @@ def _ipw_estimation( subgroup membership P(subgroup=4|X) within {j, 4} subset. Matches R's triplediff::ddd() with est_method="ipw". """ - return self._estimate_ddd_decomposition(y, G, P, T, X) + return self._estimate_ddd_decomposition( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + resolved_survey=resolved_survey, + ) def _doubly_robust( self, @@ -783,6 +802,8 @@ def _doubly_robust( P: np.ndarray, T: np.ndarray, X: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + resolved_survey=None, ) -> Tuple[float, float, Optional[float], Optional[Dict[str, float]]]: """ Estimate ATT using doubly robust estimation via three-DiD @@ -793,7 +814,15 @@ def _doubly_robust( correctly specified. Matches R's triplediff::ddd() with est_method="dr". """ - return self._estimate_ddd_decomposition(y, G, P, T, X) + return self._estimate_ddd_decomposition( + y, + G, + P, + T, + X, + survey_weights=survey_weights, + resolved_survey=resolved_survey, + ) def _estimate_ddd_decomposition( self, @@ -865,6 +894,9 @@ def _estimate_ddd_decomposition( PA4 = (sg_sub == 4).astype(float) PAa = (sg_sub == j).astype(float) + # Subset survey weights for this comparison (needed for logit) + w_sub = survey_weights[mask] if survey_weights is not None else None + # --- Propensity scores --- if est_method == "reg": # RA: no propensity scores needed @@ -878,6 +910,7 @@ def _estimate_ddd_decomposition( covX_sub[:, 1:], PA4, rank_deficient_action=self.rank_deficient_action, + weights=w_sub, ) except Exception: if self.rank_deficient_action == "error": @@ -910,6 +943,8 @@ def _estimate_ddd_decomposition( # Hessian only when PS was actually estimated if ps_estimated: W_ps = pscore_sub * (1 - pscore_sub) + if w_sub is not None: + W_ps = W_ps * w_sub try: XWX = covX_sub.T @ (W_ps[:, None] * covX_sub) hessian = np.linalg.inv(XWX) * n_sub @@ -933,9 +968,6 @@ def _estimate_ddd_decomposition( overlap_issues.append((j, frac_trimmed)) hessian = None - # Subset survey weights for this comparison - w_sub = survey_weights[mask] if survey_weights is not None else None - # --- Outcome regression --- if est_method == "ipw": # IPW: no outcome regression @@ -1045,11 +1077,18 @@ def _estimate_ddd_decomposition( if resolved_survey is not None: # Survey-weighted SE via TSL on the combined influence function. - # Treat the IF as a single-parameter score vector: - # compute_survey_vcov(ones, IF, resolved) gives V(ATT). + # The pairwise IFs already incorporate survey weights (via weighted + # Riesz representers), but compute_survey_vcov multiplies by weights + # again internally. Divide out the survey weights to get the + # unweighted IF that TSL will correctly re-weight. from diff_diff.survey import compute_survey_vcov - vcov_survey = compute_survey_vcov(np.ones((n, 1)), inf_func, resolved_survey) + inf_for_tsl = inf_func.copy() + sw = survey_weights + if sw is not None: + nz = sw > 0 + inf_for_tsl[nz] = inf_for_tsl[nz] / sw[nz] + vcov_survey = compute_survey_vcov(np.ones((n, 1)), inf_for_tsl, resolved_survey) se = float(np.sqrt(vcov_survey[0, 0])) elif self._cluster_ids is not None: # Cluster-robust SE: sum IF within clusters, then Liang-Zeger variance @@ -1192,7 +1231,17 @@ def _compute_did_rc( Matches R's triplediff::compute_did_rc(). """ if est_method == "ipw": - return self._compute_did_rc_ipw(y, post, PA4, PAa, pscore, covX, hessian, n) + return self._compute_did_rc_ipw( + y, + post, + PA4, + PAa, + pscore, + covX, + hessian, + n, + weights=weights, + ) elif est_method == "reg": return self._compute_did_rc_reg( y, @@ -1221,6 +1270,7 @@ def _compute_did_rc( or_trt_post, hessian, n, + weights=weights, ) def _compute_did_rc_ipw( @@ -1233,6 +1283,7 @@ def _compute_did_rc_ipw( covX: np.ndarray, hessian: Optional[np.ndarray], n: int, + weights: Optional[np.ndarray] = None, ) -> Tuple[float, np.ndarray]: """IPW DiD for a single pairwise comparison (RC).""" # Riesz representers (IPW weights * indicators) @@ -1241,6 +1292,13 @@ def _compute_did_rc_ipw( riesz_control_pre = pscore * PAa * (1 - post) / (1 - pscore) riesz_control_post = pscore * PAa * post / (1 - pscore) + # Incorporate survey weights into Riesz representers + if weights is not None: + riesz_treat_pre = riesz_treat_pre * weights + riesz_treat_post = riesz_treat_post * weights + riesz_control_pre = riesz_control_pre * weights + riesz_control_post = riesz_control_post * weights + # Hajek-normalized cell-time means def _hajek(riesz, y_vals): denom = np.mean(riesz) @@ -1274,8 +1332,12 @@ def _hajek(riesz, y_vals): # Propensity score correction for influence function if hessian is not None: score_ps = (PA4 - pscore)[:, None] * covX + if weights is not None: + score_ps = score_ps * weights[:, None] asy_lin_rep_ps = score_ps @ hessian + # Riesz representers already incorporate survey weights, + # so use np.mean (not np.average with weights) to avoid double-weighting. M2_pre = np.mean( (riesz_control_pre * (y - att_control_pre))[:, None] * covX, axis=0, @@ -1399,6 +1461,7 @@ def _compute_did_rc_dr( or_trt_post: np.ndarray, hessian: Optional[np.ndarray], n: int, + weights: Optional[np.ndarray] = None, ) -> Tuple[float, np.ndarray]: """Doubly robust DiD for a single pairwise comparison (RC).""" or_ctrl = post * or_ctrl_post + (1 - post) * or_ctrl_pre @@ -1412,6 +1475,16 @@ def _compute_did_rc_dr( riesz_dt1 = PA4 * post riesz_dt0 = PA4 * (1 - post) + # Incorporate survey weights into Riesz representers + if weights is not None: + riesz_treat_pre = riesz_treat_pre * weights + riesz_treat_post = riesz_treat_post * weights + riesz_control_pre = riesz_control_pre * weights + riesz_control_post = riesz_control_post * weights + riesz_d = riesz_d * weights + riesz_dt1 = riesz_dt1 * weights + riesz_dt0 = riesz_dt0 * weights + # DR cell-time components def _safe_ratio(num, denom): return num / denom if denom > 0 else 0.0 @@ -1452,9 +1525,15 @@ def _safe_ratio(num, denom): # --- Influence function --- # OLS asymptotic linear representations (control subgroup) weights_ols_pre = PAa * (1 - post) - wols_x_pre = weights_ols_pre[:, None] * covX - wols_eX_pre = (weights_ols_pre * (y - or_ctrl_pre))[:, None] * covX - XpX_pre = wols_x_pre.T @ covX / n + if weights is not None: + w_sum = np.sum(weights) + wols_x_pre = (weights_ols_pre * weights)[:, None] * covX + wols_eX_pre = (weights_ols_pre * weights * (y - or_ctrl_pre))[:, None] * covX + XpX_pre = wols_x_pre.T @ covX / w_sum + else: + wols_x_pre = weights_ols_pre[:, None] * covX + wols_eX_pre = (weights_ols_pre * (y - or_ctrl_pre))[:, None] * covX + XpX_pre = wols_x_pre.T @ covX / n try: XpX_inv_pre = np.linalg.inv(XpX_pre) except np.linalg.LinAlgError: @@ -1462,9 +1541,14 @@ def _safe_ratio(num, denom): asy_lin_rep_ols_pre = wols_eX_pre @ XpX_inv_pre weights_ols_post = PAa * post - wols_x_post = weights_ols_post[:, None] * covX - wols_eX_post = (weights_ols_post * (y - or_ctrl_post))[:, None] * covX - XpX_post = wols_x_post.T @ covX / n + if weights is not None: + wols_x_post = (weights_ols_post * weights)[:, None] * covX + wols_eX_post = (weights_ols_post * weights * (y - or_ctrl_post))[:, None] * covX + XpX_post = wols_x_post.T @ covX / w_sum + else: + wols_x_post = weights_ols_post[:, None] * covX + wols_eX_post = (weights_ols_post * (y - or_ctrl_post))[:, None] * covX + XpX_post = wols_x_post.T @ covX / n try: XpX_inv_post = np.linalg.inv(XpX_post) except np.linalg.LinAlgError: @@ -1473,9 +1557,14 @@ def _safe_ratio(num, denom): # OLS representations (treated subgroup) weights_ols_pre_treat = PA4 * (1 - post) - wols_x_pre_treat = weights_ols_pre_treat[:, None] * covX - wols_eX_pre_treat = (weights_ols_pre_treat * (y - or_trt_pre))[:, None] * covX - XpX_pre_treat = wols_x_pre_treat.T @ covX / n + if weights is not None: + wols_x_pre_treat = (weights_ols_pre_treat * weights)[:, None] * covX + wols_eX_pre_treat = (weights_ols_pre_treat * weights * (y - or_trt_pre))[:, None] * covX + XpX_pre_treat = wols_x_pre_treat.T @ covX / w_sum + else: + wols_x_pre_treat = weights_ols_pre_treat[:, None] * covX + wols_eX_pre_treat = (weights_ols_pre_treat * (y - or_trt_pre))[:, None] * covX + XpX_pre_treat = wols_x_pre_treat.T @ covX / n try: XpX_inv_pre_treat = np.linalg.inv(XpX_pre_treat) except np.linalg.LinAlgError: @@ -1483,9 +1572,16 @@ def _safe_ratio(num, denom): asy_lin_rep_ols_pre_treat = wols_eX_pre_treat @ XpX_inv_pre_treat weights_ols_post_treat = PA4 * post - wols_x_post_treat = weights_ols_post_treat[:, None] * covX - wols_eX_post_treat = (weights_ols_post_treat * (y - or_trt_post))[:, None] * covX - XpX_post_treat = wols_x_post_treat.T @ covX / n + if weights is not None: + wols_x_post_treat = (weights_ols_post_treat * weights)[:, None] * covX + wols_eX_post_treat = (weights_ols_post_treat * weights * (y - or_trt_post))[ + :, None + ] * covX + XpX_post_treat = wols_x_post_treat.T @ covX / w_sum + else: + wols_x_post_treat = weights_ols_post_treat[:, None] * covX + wols_eX_post_treat = (weights_ols_post_treat * (y - or_trt_post))[:, None] * covX + XpX_post_treat = wols_x_post_treat.T @ covX / n try: XpX_inv_post_treat = np.linalg.inv(XpX_post_treat) except np.linalg.LinAlgError: @@ -1494,6 +1590,8 @@ def _safe_ratio(num, denom): # Propensity score linear representation score_ps = (PA4 - pscore)[:, None] * covX + if weights is not None: + score_ps = score_ps * weights[:, None] if hessian is not None: asy_lin_rep_ps = score_ps @ hessian else: @@ -1515,6 +1613,8 @@ def _safe_ratio(num, denom): ) # OR correction for treated + # Riesz representers already incorporate survey weights, + # so use np.mean (not weighted average) to avoid double-weighting. M1_post = ( (-np.mean((riesz_treat_post * post)[:, None] * covX, axis=0) / m_riesz_treat_post) if m_riesz_treat_post > 0 @@ -1605,16 +1705,14 @@ def _safe_ratio(num, denom): # OR combination mom_post = ( np.mean( - (riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX, - axis=0, + (riesz_d[:, None] / m_riesz_d - riesz_dt1[:, None] / m_riesz_dt1) * covX, axis=0 ) if (m_riesz_d > 0 and m_riesz_dt1 > 0) else np.zeros(covX.shape[1]) ) mom_pre = ( np.mean( - (riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX, - axis=0, + (riesz_d[:, None] / m_riesz_d - riesz_dt0[:, None] / m_riesz_dt0) * covX, axis=0 ) if (m_riesz_d > 0 and m_riesz_dt0 > 0) else np.zeros(covX.shape[1]) diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index 4b7603e..b468a2b 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -22,6 +22,7 @@ """ import warnings +from dataclasses import replace from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -37,12 +38,11 @@ from diff_diff.linalg import solve_ols from diff_diff.two_stage_bootstrap import TwoStageDiDBootstrapMixin from diff_diff.two_stage_results import ( - TwoStageBootstrapResults, + TwoStageBootstrapResults, # noqa: F401 TwoStageDiDResults, ) # noqa: F401 (re-export) from diff_diff.utils import safe_inference - # ============================================================================= # Main Estimator # ============================================================================= @@ -173,6 +173,7 @@ def fit( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, ) -> TwoStageDiDResults: """ Fit the two-stage DiD estimator. @@ -218,7 +219,32 @@ def fit( if missing: raise ValueError(f"Missing columns: {missing}") + # Create working copy df = data.copy() + + # Resolve survey design if provided + from diff_diff.survey import ( + _inject_cluster_as_psu, + _resolve_effective_cluster, + _resolve_survey_for_fit, + _validate_unit_constant_survey, + ) + + resolved_survey, survey_weights, survey_weight_type, survey_metadata = ( + _resolve_survey_for_fit(survey_design, data, "analytical") + ) + + # Validate within-unit constancy for panel survey designs + if resolved_survey is not None: + _validate_unit_constant_survey(data, unit, survey_design) + + # Guard bootstrap + survey + if self.n_bootstrap > 0 and resolved_survey is not None: + raise NotImplementedError( + "Bootstrap inference with survey weights is not yet supported " + "for TwoStageDiD. Use analytical inference (n_bootstrap=0)." + ) + df[time] = pd.to_numeric(df[time]) df[first_treat] = pd.to_numeric(df[first_treat]) @@ -261,6 +287,51 @@ def fit( ) df = df[~df[unit].isin(always_treated_units)].copy() + # Subset survey arrays to match filtered df + if survey_weights is not None: + keep_mask = ~data[unit].isin(always_treated_units) + survey_weights = survey_weights[keep_mask.values] + if resolved_survey is not None: + keep_mask = ~data[unit].isin(always_treated_units) + resolved_survey = replace( + resolved_survey, + weights=resolved_survey.weights[keep_mask.values], + strata=( + resolved_survey.strata[keep_mask.values] + if resolved_survey.strata is not None + else None + ), + psu=( + resolved_survey.psu[keep_mask.values] + if resolved_survey.psu is not None + else None + ), + fpc=( + resolved_survey.fpc[keep_mask.values] + if resolved_survey.fpc is not None + else None + ), + ) + # Recompute n_psu/n_strata after subsetting + new_n_psu = ( + len(np.unique(resolved_survey.psu)) if resolved_survey.psu is not None else 0 + ) + new_n_strata = ( + len(np.unique(resolved_survey.strata)) + if resolved_survey.strata is not None + else 0 + ) + resolved_survey = replace(resolved_survey, n_psu=new_n_psu, n_strata=new_n_strata) + # Recompute survey_metadata since it depends on these counts + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + df[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(df), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) + # Treatment indicator with anticipation effective_treat = df[first_treat] - self.anticipation df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat) @@ -302,6 +373,26 @@ def fit( f"Available columns: {list(df.columns)}" ) + # Resolve effective cluster and inject cluster-as-PSU for survey variance + if resolved_survey is not None: + cluster_ids_raw = df[cluster_var].values if cluster_var in df.columns else None + effective_cluster_ids = _resolve_effective_cluster( + resolved_survey, + cluster_ids_raw, + cluster_var if self.cluster is not None else None, + ) + resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids) + # Recompute metadata after PSU injection + if resolved_survey.psu is not None and survey_metadata is not None: + from diff_diff.survey import compute_survey_metadata + + raw_w = ( + df[survey_design.weights].values.astype(np.float64) + if survey_design.weights + else np.ones(len(df), dtype=np.float64) + ) + survey_metadata = compute_survey_metadata(resolved_survey, raw_w) + # Relative time df["_rel_time"] = np.where( ~df["_never_treated"], @@ -311,7 +402,7 @@ def fit( # ---- Stage 1: OLS on untreated observations ---- unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model( - df, outcome, unit, time, covariates, omega_0_mask + df, outcome, unit, time, covariates, omega_0_mask, weights=survey_weights ) # ---- Rank condition checks ---- @@ -359,6 +450,9 @@ def fit( # Build design matrices and compute effects + GMM variance ref_period = -1 - self.anticipation + # Survey degrees of freedom for t-distribution inference + _survey_df = resolved_survey.df_survey if resolved_survey is not None else None + # Always compute overall ATT (static specification) overall_att, overall_se = self._stage2_static( df=df, @@ -374,9 +468,13 @@ def fit( delta_hat=delta_hat, cluster_var=cluster_var, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, ) - overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) + overall_t, overall_p, overall_ci = safe_inference( + overall_att, overall_se, alpha=self.alpha, df=_survey_df + ) # Event study and group aggregation event_study_effects = None @@ -400,6 +498,9 @@ def fit( ref_period=ref_period, balance_e=balance_e, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, + survey_df=_survey_df, ) if aggregate in ("group", "all"): @@ -418,6 +519,9 @@ def fit( cluster_var=cluster_var, treatment_groups=treatment_groups, kept_cov_mask=kept_cov_mask, + survey_weights=survey_weights, + survey_weight_type=survey_weight_type, + survey_df=_survey_df, ) # Build treatment effects DataFrame @@ -530,6 +634,7 @@ def fit( n_control_units=n_control_units, alpha=self.alpha, bootstrap_results=bootstrap_results, + survey_metadata=survey_metadata, ) self.is_fitted_ = True @@ -547,10 +652,17 @@ def _iterative_fe( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> Tuple[Dict[Any, float], Dict[Any, float]]: """ Estimate unit and time FE via iterative alternating projection. + Parameters + ---------- + weights : np.ndarray, optional + Survey weights. When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. + Returns ------- unit_fe : dict @@ -562,23 +674,36 @@ def _iterative_fe( alpha = np.zeros(n) beta = np.zeros(n) + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for iteration in range(max_iter): resid_after_alpha = y - alpha - beta_new = ( - pd.Series(resid_after_alpha, index=idx) - .groupby(time_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_t = pd.Series(resid_after_alpha * weights, index=idx) + beta_new = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + beta_new = ( + pd.Series(resid_after_alpha, index=idx) + .groupby(time_vals) + .transform("mean") + .values + ) resid_after_beta = y - beta_new - alpha_new = ( - pd.Series(resid_after_beta, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(resid_after_beta * weights, index=idx) + alpha_new = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + alpha_new = ( + pd.Series(resid_after_beta, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) max_change = max( np.max(np.abs(alpha_new - alpha)), @@ -601,21 +726,43 @@ def _iterative_demean( idx: pd.Index, max_iter: int = 100, tol: float = 1e-10, + weights: Optional[np.ndarray] = None, ) -> np.ndarray: - """Demean a vector by iterative alternating projection.""" + """Demean a vector by iterative alternating projection (unit + time FE removal). + + Parameters + ---------- + weights : np.ndarray, optional + Survey weights. When provided, uses weighted group means + (sum(w*x)/sum(w)) instead of unweighted means. + """ result = vals.copy() + + if weights is not None: + w_series = pd.Series(weights, index=idx) + wsum_t = w_series.groupby(time_vals).transform("sum").values + wsum_u = w_series.groupby(unit_vals).transform("sum").values + with np.errstate(invalid="ignore", divide="ignore"): for _ in range(max_iter): - time_means = ( - pd.Series(result, index=idx).groupby(time_vals).transform("mean").values - ) + if weights is not None: + wr_t = pd.Series(result * weights, index=idx) + time_means = wr_t.groupby(time_vals).transform("sum").values / wsum_t + else: + time_means = ( + pd.Series(result, index=idx).groupby(time_vals).transform("mean").values + ) result_after_time = result - time_means - unit_means = ( - pd.Series(result_after_time, index=idx) - .groupby(unit_vals) - .transform("mean") - .values - ) + if weights is not None: + wr_u = pd.Series(result_after_time * weights, index=idx) + unit_means = wr_u.groupby(unit_vals).transform("sum").values / wsum_u + else: + unit_means = ( + pd.Series(result_after_time, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) result_new = result_after_time - unit_means if np.max(np.abs(result_new - result)) < tol: result = result_new @@ -631,22 +778,30 @@ def _fit_untreated_model( time: str, covariates: Optional[List[str]], omega_0_mask: pd.Series, + weights: Optional[np.ndarray] = None, ) -> Tuple[ Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray] ]: """ Stage 1: Estimate unit + time FE on untreated observations. + Parameters + ---------- + weights : np.ndarray, optional + Full-panel survey weights (same length as df). The untreated subset + is extracted internally via omega_0_mask. When None, unweighted. + Returns ------- unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask """ df_0 = df.loc[omega_0_mask] + w_0 = weights[omega_0_mask.values] if weights is not None else None if covariates is None or len(covariates) == 0: y = df_0[outcome].values.copy() unit_fe, time_fe = self._iterative_fe( - y, df_0[unit].values, df_0[time].values, df_0.index + y, df_0[unit].values, df_0[time].values, df_0.index, weights=w_0 ) return unit_fe, time_fe, 0.0, None, None @@ -657,10 +812,10 @@ def _fit_untreated_model( times = df_0[time].values n_cov = len(covariates) - y_dm = self._iterative_demean(y, units, times, df_0.index) + y_dm = self._iterative_demean(y, units, times, df_0.index, weights=w_0) X_dm = np.column_stack( [ - self._iterative_demean(X_raw[:, j], units, times, df_0.index) + self._iterative_demean(X_raw[:, j], units, times, df_0.index, weights=w_0) for j in range(n_cov) ] ) @@ -671,13 +826,14 @@ def _fit_untreated_model( return_vcov=False, rank_deficient_action=self.rank_deficient_action, column_names=covariates, + weights=w_0, ) delta_hat = result[0] kept_cov_mask = np.isfinite(delta_hat) delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0) y_adj = y - np.dot(X_raw, delta_hat_clean) - unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) + unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index, weights=w_0) return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask @@ -736,6 +892,8 @@ def _stage2_static( delta_hat: Optional[np.ndarray], cluster_var: str, kept_cov_mask: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + survey_weight_type: str = "pweight", ) -> Tuple[float, float]: """ Static (simple ATT) Stage 2: OLS of y_tilde on D_it. @@ -763,7 +921,13 @@ def _stage2_static( return np.nan, np.nan # Stage 2 OLS for point estimate (discard naive SE) - coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + coef, residuals, _ = solve_ols( + X_2, + y_tilde, + return_vcov=False, + weights=survey_weights, + weight_type=survey_weight_type, + ) att = float(coef[0]) # GMM sandwich variance @@ -782,6 +946,7 @@ def _stage2_static( X_2=X_2, eps_2=eps_2, cluster_ids=df[cluster_var].values, + survey_weights=survey_weights, ) se = float(np.sqrt(max(V[0, 0], 0.0))) @@ -805,6 +970,9 @@ def _stage2_event_study( ref_period: int, balance_e: Optional[int], kept_cov_mask: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + survey_weight_type: str = "pweight", + survey_df: Optional[int] = None, ) -> Dict[int, Dict[str, Any]]: """Event study Stage 2: OLS of y_tilde on relative-time dummies.""" y_tilde = df["_y_tilde"].values.copy() @@ -921,7 +1089,13 @@ def _stage2_event_study( X_2[i, horizon_to_col[h_int]] = 1.0 # Stage 2 OLS - coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + coef, residuals, _ = solve_ols( + X_2, + y_tilde, + return_vcov=False, + weights=survey_weights, + weight_type=survey_weight_type, + ) eps_2 = y_tilde - np.dot(X_2, coef) # GMM variance for full coefficient vector @@ -938,6 +1112,7 @@ def _stage2_event_study( X_2=X_2, eps_2=eps_2, cluster_ids=df[cluster_var].values, + survey_weights=survey_weights, ) # Build results dict @@ -971,7 +1146,7 @@ def _stage2_event_study( effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha, df=survey_df) event_study_effects[h] = { "effect": effect, @@ -1011,6 +1186,9 @@ def _stage2_group( cluster_var: str, treatment_groups: List[Any], kept_cov_mask: Optional[np.ndarray], + survey_weights: Optional[np.ndarray] = None, + survey_weight_type: str = "pweight", + survey_df: Optional[int] = None, ) -> Dict[Any, Dict[str, Any]]: """Group (cohort) Stage 2: OLS of y_tilde on cohort dummies.""" y_tilde = df["_y_tilde"].values.copy() @@ -1033,7 +1211,13 @@ def _stage2_group( X_2[i, group_to_col[g]] = 1.0 # Stage 2 OLS - coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + coef, residuals, _ = solve_ols( + X_2, + y_tilde, + return_vcov=False, + weights=survey_weights, + weight_type=survey_weight_type, + ) eps_2 = y_tilde - np.dot(X_2, coef) # GMM variance @@ -1050,6 +1234,7 @@ def _stage2_group( X_2=X_2, eps_2=eps_2, cluster_ids=df[cluster_var].values, + survey_weights=survey_weights, ) group_effects: Dict[Any, Dict[str, Any]] = {} @@ -1071,7 +1256,7 @@ def _stage2_group( effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha) + t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha, df=survey_df) group_effects[g] = { "effect": effect, @@ -1137,6 +1322,7 @@ def _compute_gmm_variance( X_2: np.ndarray, eps_2: np.ndarray, cluster_ids: np.ndarray, + survey_weights: Optional[np.ndarray] = None, ) -> np.ndarray: """ Compute GMM sandwich variance (Butts & Gardner 2022). @@ -1155,6 +1341,12 @@ def _compute_gmm_variance( S_g = gamma_hat' c_g - X'_{2g} eps_{2g} c_g = X'_{10g} eps_{10g} + With survey weights W (diagonal): + Bread: (X'_2 W X_2)^{-1} + gamma_hat: (X'_{10} W X_{10})^{-1} (X'_1 W X_2) + c_g = sum_{i in g} w_i * x_{10i} * eps_{10i} + s2_g = sum_{i in g} w_i * x_{2i} * eps_{2i} + Parameters ---------- X_2 : np.ndarray, shape (n, k) @@ -1163,6 +1355,9 @@ def _compute_gmm_variance( Stage 2 residuals. cluster_ids : np.ndarray, shape (n,) Cluster identifiers. + survey_weights : np.ndarray, optional + Survey weights of shape (n,). When None, unweighted (identical + to current code). Returns ------- @@ -1207,27 +1402,37 @@ def _compute_gmm_variance( eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0] # Stage 1 residual eps_10[~omega_0] = y_vals[~omega_0] # x_{10i} = 0, so eps_10 = Y - # 1. gamma_hat = (X'_{10} X_{10})^{-1} (X'_1 X_2) [p x k] - XtX_10 = X_10_sparse.T @ X_10_sparse # (p x p) sparse - Xt1_X2 = X_1_sparse.T @ X_2 # (p x k) dense + # 1. gamma_hat = (X'_{10} W X_{10})^{-1} (X'_1 W X_2) [p x k] + # With survey weights, both cross-products need W + if survey_weights is not None: + XtWX_10 = X_10_sparse.T @ X_10_sparse.multiply(survey_weights[:, None]) + Xt1_WX2 = X_1_sparse.T @ (X_2 * survey_weights[:, None]) + else: + XtWX_10 = X_10_sparse.T @ X_10_sparse # (p x p) sparse + Xt1_WX2 = X_1_sparse.T @ X_2 # (p x k) dense try: - solve_XtX = sparse_factorized(XtX_10.tocsc()) - if Xt1_X2.ndim == 1: - gamma_hat = solve_XtX(Xt1_X2).reshape(-1, 1) + solve_XtX = sparse_factorized(XtWX_10.tocsc()) + if Xt1_WX2.ndim == 1: + gamma_hat = solve_XtX(Xt1_WX2).reshape(-1, 1) else: gamma_hat = np.column_stack( - [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])] + [solve_XtX(Xt1_WX2[:, j]) for j in range(Xt1_WX2.shape[1])] ) except RuntimeError: # Singular matrix — fall back to dense least-squares - gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0] + gamma_hat = np.linalg.lstsq(XtWX_10.toarray(), Xt1_WX2, rcond=None)[0] if gamma_hat.ndim == 1: gamma_hat = gamma_hat.reshape(-1, 1) - # 2. Per-cluster Stage 1 scores: c_g = X'_{10g} eps_{10g} + # 2. Per-cluster Stage 1 scores: c_g = sum_{i in g} w_i * x_{10i} * eps_{10i} # Only untreated obs have non-zero X_10 rows - weighted_X10 = X_10_sparse.multiply(eps_10[:, None]) # sparse element-wise + # With survey weights: multiply eps_10 by survey_weights before sparse multiply + if survey_weights is not None: + weighted_eps_10 = survey_weights * eps_10 + else: + weighted_eps_10 = eps_10 + weighted_X10 = X_10_sparse.multiply(weighted_eps_10[:, None]) # sparse element-wise unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True) G = len(unique_clusters) @@ -1246,8 +1451,12 @@ def _compute_gmm_variance( for j_col in range(p): np.add.at(c_by_cluster[:, j_col], cluster_indices, weighted_X10_dense[:, j_col]) - # 3. Per-cluster Stage 2 scores: X'_{2g} eps_{2g} - weighted_X2 = X_2 * eps_2[:, None] # (n x k) dense + # 3. Per-cluster Stage 2 scores: s2_g = sum_{i in g} w_i * x_{2i} * eps_{2i} + if survey_weights is not None: + weighted_eps_2 = survey_weights * eps_2 + else: + weighted_eps_2 = eps_2 + weighted_X2 = X_2 * weighted_eps_2[:, None] # (n x k) dense s2_by_cluster = np.zeros((G, k)) for j_col in range(k): np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col]) @@ -1259,13 +1468,16 @@ def _compute_gmm_variance( with np.errstate(invalid="ignore", over="ignore"): meat = S.T @ S # (k x k) - # 6. Bread: (X'_2 X_2)^{-1} + # 6. Bread: (X'_2 W X_2)^{-1} with np.errstate(invalid="ignore", over="ignore", divide="ignore"): - XtX_2 = X_2.T @ X_2 + if survey_weights is not None: + XtWX_2 = X_2.T @ (X_2 * survey_weights[:, None]) + else: + XtWX_2 = X_2.T @ X_2 try: - bread = np.linalg.solve(XtX_2, np.eye(k)) + bread = np.linalg.solve(XtWX_2, np.eye(k)) except np.linalg.LinAlgError: - bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0] + bread = np.linalg.lstsq(XtWX_2, np.eye(k), rcond=None)[0] # 7. V = bread @ meat @ bread V = bread @ meat @ bread @@ -1402,6 +1614,7 @@ def two_stage_did( covariates: Optional[List[str]] = None, aggregate: Optional[str] = None, balance_e: Optional[int] = None, + survey_design: object = None, **kwargs, ) -> TwoStageDiDResults: """ @@ -1453,4 +1666,5 @@ def two_stage_did( covariates=covariates, aggregate=aggregate, balance_e=balance_e, + survey_design=survey_design, ) diff --git a/diff_diff/two_stage_results.py b/diff_diff/two_stage_results.py index 16b4916..b06cd00 100644 --- a/diff_diff/two_stage_results.py +++ b/diff_diff/two_stage_results.py @@ -137,6 +137,8 @@ class TwoStageDiDResults: n_control_units: int alpha: float = 0.05 bootstrap_results: Optional[TwoStageBootstrapResults] = field(default=None, repr=False) + # Survey design metadata (SurveyMetadata instance from diff_diff.survey) + survey_metadata: Optional[Any] = field(default=None, repr=False) def __repr__(self) -> str: """Concise string representation.""" @@ -180,6 +182,27 @@ def summary(self, alpha: Optional[float] = None) -> str: "", ] + # Survey design info + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend( + [ + "-" * 85, + "Survey Design".center(85), + "-" * 85, + f"{'Weight type:':<30} {sm.weight_type:>10}", + ] + ) + if sm.n_strata is not None: + lines.append(f"{'Strata:':<30} {sm.n_strata:>10}") + if sm.n_psu is not None: + lines.append(f"{'PSU/Cluster:':<30} {sm.n_psu:>10}") + lines.append(f"{'Effective sample size:':<30} {sm.effective_n:>10.1f}") + lines.append(f"{'Design effect (DEFF):':<30} {sm.design_effect:>10.2f}") + if sm.df_survey is not None: + lines.append(f"{'Survey d.f.:':<30} {sm.df_survey:>10}") + lines.extend(["-" * 85, ""]) + # Overall ATT lines.extend( [ diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 9c5e9c0..bee9e10 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -416,6 +416,9 @@ The multiplier bootstrap uses random weights w_i with E[w]=0 and Var(w)=1: a base period later than `t` (matching R's `did::att_gt()`) - Does not require never-treated units: when all units are eventually treated, not-yet-treated cohorts serve as controls for each other (requires ≥2 cohorts) +- **Note:** CallawaySantAnna survey weights: regression method supports covariates; IPW/DR support no-covariate only (covariates+IPW/DR+survey raises NotImplementedError — DRDID nuisance IF not yet implemented). Survey weights compose with IPW weights multiplicatively. WIF in aggregation matches R's did::wif() formula. Bootstrap + survey deferred. +- **Note (deviation from R):** CallawaySantAnna survey reg+covariates per-cell SEs use a conservative plug-in IF based on WLS residuals (`inf_control = -w_i/sum(w_t) * resid_i`), mirroring the unweighted code's structure. This omits the semiparametrically efficient nuisance correction from DRDID's `reg_did_panel`, so SEs may be wider than necessary but have correct asymptotic coverage. The unweighted code uses the same plug-in approach. The efficient IF correction is deferred to future work. +- **Note (deviation from R):** Per-cell ATT(g,t) SEs under survey weights use influence-function-based variance (matching R's `did::att_gt` analytical SE path) rather than full Taylor-series linearization with strata/PSU/FPC structure. The survey design structure is reflected in aggregation-level SEs via the WIF and survey degrees of freedom, but individual (g,t) cell SEs do not incorporate the full design-based variance. This is consistent with R's approach where per-cell SEs are influence-function-based and design effects enter at the aggregation stage. **Reference implementation(s):** - R: `did::att_gt()` (Callaway & Sant'Anna's official package) @@ -840,6 +843,7 @@ Y_it = alpha_i + beta_t [+ X'_it * delta] + W'_it * gamma + epsilon_it - **treatment_effects DataFrame weights:** `weight` column uses `1/n_valid` for finite tau_hat and 0 for NaN tau_hat, consistent with the ATT estimand. - **Rank-deficient covariates in variance:** Covariates with NaN coefficients (dropped for rank deficiency in Step 1) are excluded from the variance design matrices `A_0`/`A_1`. Only covariates with finite coefficients participate in the `v_it` projection. - **Sparse variance solver:** `_compute_v_untreated_with_covariates` uses `scipy.sparse.linalg.spsolve` to solve `(A_0'A_0) z = A_1'w` without densifying the normal equations matrix. Falls back to dense `lstsq` if the sparse solver fails. +- **Note:** Survey weights enter ImputationDiD via weighted iterative FE (Step 1), survey-weighted ATT aggregation (Step 3), and survey-weighted conservative variance (Theorem 3). Survey df used for t-distribution inference. Bootstrap + survey deferred. - **Bootstrap inference:** Uses multiplier bootstrap on the Theorem 3 influence function: `psi_i = sum_t v_it * epsilon_tilde_it`. Cluster-level psi sums are pre-computed for each aggregation target (overall, per-horizon, per-group), then perturbed with multiplier weights (Rademacher by default; configurable via `bootstrap_weights` parameter to use Mammen or Webb weights, matching CallawaySantAnna). This is a library extension (not in the paper) consistent with CallawaySantAnna/SunAbraham bootstrap patterns. - **Auxiliary residuals (Equation 8):** Uses v_it-weighted tau_tilde_g formula: `tau_tilde_g = sum(v_it * tau_hat_it) / sum(v_it)` within each partition group. Zero-weight groups (common in event-study SE computation) fall back to unweighted mean. @@ -917,6 +921,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus - **No never-treated units (Proposition 5):** When there are no never-treated units and multiple treatment cohorts, horizons h >= h_bar (where h_bar = max(groups) - min(groups)) are unidentified per Proposition 5 of Borusyak et al. (2024). These produce NaN inference with n_obs > 0 (treated observations exist but counterfactual is unidentified) and a warning listing affected horizons. Matches ImputationDiD behavior. Proposition 5 applies to event study horizons only, not cohort aggregation — a cohort whose treated obs all fall at Prop 5 horizons naturally gets n_obs=0 in group effects because all its y_tilde values are NaN. - **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some non-Prop-5 event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0. - **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0. +- **Note:** Survey weights in TwoStageDiD GMM sandwich via weighted cross-products: bread uses (X'_2 W X_2)^{-1}, gamma_hat uses (X'_{10} W X_{10})^{-1}(X'_1 W X_2), per-cluster scores multiply by survey weights. Survey df used for t-distribution. Bootstrap + survey deferred. **Reference implementation(s):** - R: `did2s::did2s()` (Kyle Butts & John Gardner) @@ -1241,8 +1246,7 @@ has no additional effect. - [x] Influence function SE: std(w3·IF_3 + w2·IF_2 - w1·IF_1) / sqrt(n) - [x] Cluster-robust SE via Liang-Zeger variance on influence function - [x] ATT and SE match R within <0.001% for all methods and DGP types -- [x] Survey design support (Phase 3): regression method with weighted OLS + TSL on combined influence functions; IPW/DR deferred -- **Note:** TripleDifference IPW/DR with survey weights deferred until weighted solve_logit() (Phase 5) +- [x] Survey design support: all methods (reg, IPW, DR) with weighted OLS/logit + TSL on combined influence functions. Weighted solve_logit() for propensity scores in IPW/DR paths. --- diff --git a/docs/survey-roadmap.md b/docs/survey-roadmap.md index e59fd21..f228e8c 100644 --- a/docs/survey-roadmap.md +++ b/docs/survey-roadmap.md @@ -1,7 +1,7 @@ # Survey Data Support Roadmap This document captures planned future work for survey data support in diff-diff. -Phases 1-3 are implemented. Phases 4-5 are deferred for future PRs. +Phases 1-4 are implemented. Phase 5 is deferred for future PRs. ## Implemented (Phases 1-2) @@ -21,19 +21,18 @@ Phases 1-3 are implemented. Phases 4-5 are deferred for future PRs. | StackedDiD | `stacked_did.py` | pweight only | Q-weights compose multiplicatively with survey weights; TSL vcov on composed weights; fweight/aweight rejected (composition changes weight semantics) | | SunAbraham | `sun_abraham.py` | Full | Survey weights in LinearRegression + weighted within-transform; bootstrap+survey deferred | | BaconDecomposition | `bacon.py` | Diagnostic | Weighted cell means, weighted within-transform, weighted group shares; no inference (diagnostic only) | -| TripleDifference | `triple_diff.py` | Reg only | Regression method with weighted OLS + TSL on influence functions; IPW/DR deferred (needs weighted `solve_logit()`) | +| TripleDifference | `triple_diff.py` | Full | Regression, IPW, and DR methods with weighted OLS/logit + TSL on influence functions | | ContinuousDiD | `continuous_did.py` | Analytical | Weighted B-spline OLS + TSL on influence functions; bootstrap+survey deferred | | EfficientDiD | `efficient_did.py` | Analytical | Weighted means/covariances in Omega* + TSL on EIF scores; bootstrap+survey deferred | ### Phase 3 Deferred Work The following capabilities were deferred from Phase 3 because they depend on -Phase 5 infrastructure (weighted `solve_logit()` or bootstrap+survey interaction): +Phase 5 infrastructure (bootstrap+survey interaction): | Estimator | Deferred Capability | Blocker | |-----------|-------------------|---------| | SunAbraham | Pairs bootstrap + survey | Phase 5: bootstrap+survey interaction | -| TripleDifference | IPW and DR methods + survey | Phase 5: weighted `solve_logit()` | | ContinuousDiD | Multiplier bootstrap + survey | Phase 5: bootstrap+survey interaction | | EfficientDiD | Multiplier bootstrap + survey | Phase 5: bootstrap+survey interaction | | EfficientDiD | Covariates (DR path) + survey | DR nuisance estimation needs survey weight threading | @@ -41,19 +40,43 @@ Phase 5 infrastructure (weighted `solve_logit()` or bootstrap+survey interaction All blocked combinations raise `NotImplementedError` when attempted, with a message pointing to the planned phase or describing the limitation. -## Phase 4: Complex Standalone Estimators +## Implemented (Phase 4): Complex Standalone Estimators + Weighted Logit -These require more substantial changes beyond threading weights. +| Estimator | File | Survey Support | Notes | +|-----------|------|----------------|-------| +| ImputationDiD | `imputation.py` | Analytical | Weighted iterative FE, weighted ATT aggregation, weighted conservative variance (Theorem 3); bootstrap+survey deferred | +| TwoStageDiD | `two_stage.py` | Analytical | Weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance; bootstrap+survey deferred | +| CallawaySantAnna | `staggered.py` | Analytical | Survey-weighted regression (all cases), IPW and DR (no-covariate only); survey-weighted WIF in aggregation; covariates+IPW/DR deferred (needs DRDID nuisance IF); bootstrap+survey deferred | + +**Infrastructure**: Weighted `solve_logit()` added to `linalg.py` — survey weights +enter the IRLS working weights as `w_survey * mu * (1 - mu)`. This also unblocked +TripleDifference IPW/DR from Phase 3 deferred work. + +### Phase 4 Deferred Work + +| Estimator | Deferred Capability | Blocker | +|-----------|-------------------|---------| +| ImputationDiD | Bootstrap + survey | Phase 5: bootstrap+survey interaction | +| TwoStageDiD | Bootstrap + survey | Phase 5: bootstrap+survey interaction | +| CallawaySantAnna | Bootstrap + survey | Phase 5: bootstrap+survey interaction | + +### Remaining for Phase 5 | Estimator | File | Complexity | Notes | |-----------|------|------------|-------| -| ImputationDiD | `imputation.py` | Medium | Weighted Stage 1 OLS, weighted ATT aggregation, weighted conservative variance | -| TwoStageDiD | `two_stage.py` | Medium-High | Weighted within-transformation, weighted GMM sandwich | -| CallawaySantAnna | `staggered.py` | High | Weights enter propensity score (weighted solve_logit), IPW/DR reweighting (w_survey x w_ipw), multiplier bootstrap | -| SyntheticDiD | `synthetic_did.py` | Medium | Survey-weighted treated mean in optimization, weighted placebo variance | -| TROP | `trop.py` | Medium | Survey weights in ATT aggregation and LOOCV (distinct from TROP's internal weights) | +| SyntheticDiD | `synthetic_did.py` | Medium | Survey-weighted treated mean in optimization, weighted placebo variance (bootstrap-based SE only) | +| TROP | `trop.py` | Medium | Survey weights in ATT aggregation and LOOCV (bootstrap-based SE only) | + +## Phase 5: Advanced Features + Remaining Estimators -## Phase 5: Advanced Features +### SyntheticDiD and TROP Survey Support +Both estimators use bootstrap/placebo for SE with no analytical variance path. +Phase 5 provides survey-weighted point estimates and survey-aware bootstrap SE. + +### Bootstrap + Survey Interaction +Unblock bootstrap + survey for all estimators that currently defer it +(ImputationDiD, TwoStageDiD, CallawaySantAnna, SunAbraham, ContinuousDiD, +EfficientDiD). Requires survey-aware resampling schemes. ### Replicate Weight Variance Re-run WLS for each replicate weight column, compute variance from distribution @@ -64,16 +87,7 @@ Add `replicate_weights`, `replicate_type`, `replicate_rho` fields to SurveyDesig Compare survey vcov to SRS vcov element-wise. Report design effect per coefficient. Effective n = n / DEFF. -### Wild Bootstrap with Survey Weights -Add `survey_weights` parameter to `wild_bootstrap_se()`. Requires careful -interaction between bootstrap resampling and survey weight structure. - ### Subpopulation Analysis `SurveyDesign.subpopulation(data, mask)` — zero-out weights for excluded observations while preserving the full design structure for correct variance estimation (unlike simple subsetting, which would drop design information). - -### Weighted `solve_logit()` -Add weights to the IRLS iteration in `solve_logit()`. Required by -CallawaySantAnna (propensity score estimation) and TripleDifference (IPW method). -Working weights become `w_survey * mu * (1 - mu)`. diff --git a/tests/test_survey_phase3.py b/tests/test_survey_phase3.py index 3248fa0..3cddfbc 100644 --- a/tests/test_survey_phase3.py +++ b/tests/test_survey_phase3.py @@ -496,35 +496,37 @@ def test_smoke_reg_method(self, ddd_survey_data): assert np.isfinite(result.se) assert result.survey_metadata is not None - def test_ipw_survey_raises(self, ddd_survey_data): - """IPW + survey should raise NotImplementedError.""" + def test_ipw_survey_works(self, ddd_survey_data): + """IPW + survey now works (unblocked by weighted solve_logit in Phase 4).""" from diff_diff import TripleDifference sd = SurveyDesign(weights="weight") - with pytest.raises(NotImplementedError, match="IPW"): - TripleDifference(estimation_method="ipw").fit( - ddd_survey_data, - "outcome", - "group", - "partition", - "time", - survey_design=sd, - ) + result = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) - def test_dr_survey_raises(self, ddd_survey_data): - """DR + survey should raise NotImplementedError.""" + def test_dr_survey_works(self, ddd_survey_data): + """DR + survey now works (unblocked by weighted solve_logit in Phase 4).""" from diff_diff import TripleDifference sd = SurveyDesign(weights="weight") - with pytest.raises(NotImplementedError, match="doubly robust"): - TripleDifference(estimation_method="dr").fit( - ddd_survey_data, - "outcome", - "group", - "partition", - "time", - survey_design=sd, - ) + result = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) def test_weighted_changes_att(self, ddd_survey_data): """Survey weights should change ATT.""" diff --git a/tests/test_survey_phase4.py b/tests/test_survey_phase4.py new file mode 100644 index 0000000..8931774 --- /dev/null +++ b/tests/test_survey_phase4.py @@ -0,0 +1,1314 @@ +"""Tests for Phase 4 survey support: complex standalone estimators. + +Covers: ImputationDiD, TwoStageDiD, CallawaySantAnna, weighted solve_logit(), +TripleDifference IPW/DR unblock, and cross-estimator scale invariance. +""" + +import numpy as np +import pandas as pd +import pytest + +from diff_diff import ( + CallawaySantAnna, + ImputationDiD, + SurveyDesign, + TripleDifference, + TwoStageDiD, + generate_staggered_data, +) +from diff_diff.linalg import solve_logit + +# ============================================================================= +# Shared Fixtures +# ============================================================================= + + +@pytest.fixture +def staggered_survey_data(): + """Staggered treatment panel with survey design columns. + + 200 units via generate_staggered_data, then add unit-level survey columns + (weights, stratum, psu, fpc) that are constant within each unit. + """ + data = generate_staggered_data(n_units=200, seed=42) + + # Add unit-level survey columns (constant within unit) + unit_ids = data["unit"].unique() + n_units = len(unit_ids) + np.random.RandomState(42) + + unit_weight = 1.0 + 0.5 * (np.arange(n_units) % 5) + unit_stratum = np.arange(n_units) // 40 # 5 strata + unit_psu = np.arange(n_units) // 10 # 20 PSUs + unit_fpc = np.full(n_units, 400.0) # population per stratum + + unit_map = {uid: i for i, uid in enumerate(unit_ids)} + idx = data["unit"].map(unit_map).values + + data["weight"] = unit_weight[idx] + data["stratum"] = unit_stratum[idx] + data["psu"] = unit_psu[idx] + data["fpc"] = unit_fpc[idx] + + return data + + +@pytest.fixture +def survey_design_weights_only(): + """SurveyDesign with weights only.""" + return SurveyDesign(weights="weight") + + +@pytest.fixture +def survey_design_full(): + """SurveyDesign with weights, strata, psu.""" + return SurveyDesign(weights="weight", strata="stratum", psu="psu") + + +@pytest.fixture +def ddd_survey_data(): + """Cross-sectional DDD data with survey columns for TripleDifference.""" + np.random.seed(42) + n = 400 + data = pd.DataFrame( + { + "outcome": np.random.randn(n) + 0.5, + "group": np.random.choice([0, 1], n), + "partition": np.random.choice([0, 1], n), + "time": np.random.choice([0, 1], n), + "weight": np.random.uniform(0.5, 2.0, n), + "stratum": np.random.choice([1, 2, 3], n), + } + ) + # Add treatment effect for treated+eligible+post + mask = (data["group"] == 1) & (data["partition"] == 1) & (data["time"] == 1) + data.loc[mask, "outcome"] += 1.5 + return data + + +# ============================================================================= +# TestWeightedSolveLogit +# ============================================================================= + + +class TestWeightedSolveLogit: + """Tests for weighted solve_logit() (IRLS with survey weights).""" + + def test_uniform_weights_match_unweighted(self): + """Uniform weights should produce same coefficients as unweighted.""" + rng = np.random.RandomState(123) + n = 200 + X = rng.randn(n, 2) + y = (X @ [1.0, -0.5] + rng.randn(n) > 0).astype(float) + + beta_unw, probs_unw = solve_logit(X, y) + beta_w, probs_w = solve_logit(X, y, weights=np.ones(n)) + + np.testing.assert_allclose(beta_unw, beta_w, atol=1e-10) + np.testing.assert_allclose(probs_unw, probs_w, atol=1e-10) + + def test_convergence_with_weights(self): + """solve_logit converges with non-uniform survey weights.""" + rng = np.random.RandomState(42) + n = 200 + X = rng.randn(n, 2) + y = (X @ [0.5, -0.5] + rng.randn(n) * 1.5 > 0).astype(float) + weights = rng.uniform(0.5, 3.0, n) + + beta, probs = solve_logit(X, y, weights=weights) + + # Should produce finite coefficients (convergence) + assert np.all(np.isfinite(beta)) + assert np.all(np.isfinite(probs)) + assert np.all(probs > 0) and np.all(probs < 1) + + def test_separation_detection_with_weights(self): + """Separation warning should still fire with survey weights.""" + rng = np.random.RandomState(789) + n = 100 + # Create near-separation: x1 perfectly predicts y + x1 = np.linspace(-5, 5, n) + X = x1.reshape(-1, 1) + y = (x1 > 0).astype(float) + weights = rng.uniform(1.0, 3.0, n) + + with pytest.warns(UserWarning): + beta, probs = solve_logit(X, y, weights=weights) + + assert np.all(np.isfinite(beta)) + + def test_known_answer_small_dataset(self): + """Manual check on a small dataset — weights should shift coefficients.""" + rng = np.random.RandomState(333) + n = 50 + X = rng.randn(n, 1) + prob = 1.0 / (1.0 + np.exp(-(0.5 + 1.0 * X[:, 0]))) + y = (rng.rand(n) < prob).astype(float) + + # Unweighted fit + beta_unw, _ = solve_logit(X, y) + + # Non-uniform weights: upweight observations where y=1 + weights = np.where(y == 1, 5.0, 1.0) + beta_w, _ = solve_logit(X, y, weights=weights) + + # Weighted fit should shift intercept (upweighting y=1 shifts boundary) + assert beta_w[0] != pytest.approx(beta_unw[0], abs=0.1) + + def test_rank_deficiency_with_weights(self): + """Rank-deficient columns should still be detected with weights.""" + rng = np.random.RandomState(111) + n = 100 + x1 = rng.randn(n) + # x2 is a perfect linear combination of x1 + X = np.column_stack([x1, 2.0 * x1]) + y = (x1 > 0).astype(float) + weights = rng.uniform(0.5, 2.0, n) + + with pytest.warns(UserWarning, match="[Rr]ank"): + beta, probs = solve_logit(X, y, weights=weights) + + assert np.all(np.isfinite(probs)) + + def test_weight_scale_invariance(self): + """Multiplying weights by a constant should not change beta.""" + rng = np.random.RandomState(222) + n = 200 + X = rng.randn(n, 2) + y = (X @ [0.8, -0.3] + rng.randn(n) > 0).astype(float) + weights = rng.uniform(1.0, 4.0, n) + + beta1, _ = solve_logit(X, y, weights=weights) + beta2, _ = solve_logit(X, y, weights=weights * 2.0) + + np.testing.assert_allclose(beta1, beta2, atol=1e-10) + + +# ============================================================================= +# TestImputationDiDSurvey +# ============================================================================= + + +class TestImputationDiDSurvey: + """Survey design support for ImputationDiD.""" + + def test_smoke_weights_only(self, staggered_survey_data, survey_design_weights_only): + """ImputationDiD runs with weights-only survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_uniform_weights_match_unweighted(self, staggered_survey_data): + """Uniform survey weights should match unweighted result.""" + staggered_survey_data["uniform_w"] = 1.0 + sd = SurveyDesign(weights="uniform_w") + + r_unw = ImputationDiD().fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert abs(r_unw.overall_att - r_w.overall_att) < 1e-10 + + def test_survey_metadata_fields(self, staggered_survey_data, survey_design_full): + """survey_metadata has correct fields with full design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_full, + ) + sm = result.survey_metadata + assert sm is not None + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + assert sm.design_effect > 0 + assert sm.n_strata is not None + assert sm.n_psu is not None + + def test_se_differs_with_design(self, staggered_survey_data): + """Weights-only vs full design: same ATT, different inference via survey df.""" + sd_w = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + + r_w = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_w, + ) + r_full = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_full, + ) + # ATTs should be the same (same weights) + assert abs(r_w.overall_att - r_full.overall_att) < 1e-10 + # Full design should carry survey df (strata/PSU structure) + assert r_full.survey_metadata is not None + assert r_full.survey_metadata.n_strata is not None + assert r_full.survey_metadata.n_psu is not None + # P-values should differ due to t-distribution with survey df + if np.isfinite(r_w.overall_p_value) and np.isfinite(r_full.overall_p_value): + assert r_w.overall_p_value != r_full.overall_p_value + + def test_weighted_att_differs(self, staggered_survey_data, survey_design_weights_only): + """Non-uniform survey weights should change the overall ATT.""" + r_unw = ImputationDiD().fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + # ATT should differ because non-uniform weights change aggregation + assert r_unw.overall_att != r_w.overall_att + + def test_event_study_with_survey(self, staggered_survey_data, survey_design_weights_only): + """Event study effects exist when using survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + survey_design=survey_design_weights_only, + ) + assert result.event_study_effects is not None + assert len(result.event_study_effects) > 0 + for h, eff in result.event_study_effects.items(): + assert np.isfinite(eff["effect"]) + assert np.isfinite(eff["se"]) + + def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): + ImputationDiD(n_bootstrap=99).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + + def test_summary_includes_survey(self, staggered_survey_data, survey_design_weights_only): + """Summary output should include survey design section.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + summary = result.summary() + assert "Survey Design" in summary + assert "pweight" in summary + + def test_se_scale_invariance_fe_only(self, staggered_survey_data): + """SE must be invariant to weight rescaling (FE-only, no covariates).""" + data = staggered_survey_data.copy() + data["weight2"] = data["weight"] * 3.1 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + r1 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + r2 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + assert np.isclose(r1.overall_att, r2.overall_att, atol=1e-8) + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"SE not scale-invariant (FE-only): {r1.overall_se} vs {r2.overall_se}" + + def test_se_scale_invariance_with_covariates(self, staggered_survey_data): + """SE must be invariant to weight rescaling (with covariates).""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(99).normal(0, 1, len(data)) + data["weight2"] = data["weight"] * 3.1 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + r1 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd1, + ) + r2 = ImputationDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=sd2, + ) + assert np.isclose(r1.overall_att, r2.overall_att, atol=1e-8) + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"SE not scale-invariant (covariates): {r1.overall_se} vs {r2.overall_se}" + + def test_wrapper_imputation_did_with_survey(self, staggered_survey_data): + """imputation_did() wrapper forwards survey_design correctly.""" + from diff_diff import imputation_did + + sd = SurveyDesign(weights="weight") + r_wrapper = imputation_did( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + r_direct = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isclose(r_wrapper.overall_att, r_direct.overall_att, atol=1e-10) + assert r_wrapper.survey_metadata is not None + + def test_aggregate_group_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='group' works with survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="group", + survey_design=survey_design_weights_only, + ) + assert result.group_effects is not None + assert len(result.group_effects) > 0 + for g, eff in result.group_effects.items(): + assert np.isfinite(eff["effect"]) + + def test_aggregate_all_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='all' works with survey design.""" + result = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="all", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert result.event_study_effects is not None + assert result.group_effects is not None + + +# ============================================================================= +# TestTwoStageDiDSurvey +# ============================================================================= + + +class TestTwoStageDiDSurvey: + """Survey design support for TwoStageDiD.""" + + def test_smoke_weights_only(self, staggered_survey_data, survey_design_weights_only): + """TwoStageDiD runs with weights-only survey design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_uniform_weights_match_unweighted(self, staggered_survey_data): + """Uniform survey weights should match unweighted result.""" + staggered_survey_data["uniform_w"] = 1.0 + sd = SurveyDesign(weights="uniform_w") + + r_unw = TwoStageDiD().fit(staggered_survey_data, "outcome", "unit", "period", "first_treat") + r_w = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert abs(r_unw.overall_att - r_w.overall_att) < 1e-10 + + def test_survey_metadata_fields(self, staggered_survey_data, survey_design_full): + """survey_metadata has correct fields with full design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_full, + ) + sm = result.survey_metadata + assert sm is not None + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + assert sm.design_effect > 0 + assert sm.n_strata is not None + assert sm.n_psu is not None + + def test_se_differs_with_design(self, staggered_survey_data): + """Weights-only vs full design: same ATT, different inference via survey df.""" + sd_w = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + + r_w = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_w, + ) + r_full = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_full, + ) + # ATTs should be the same (same weights) + assert abs(r_w.overall_att - r_full.overall_att) < 1e-10 + # Full design should carry survey df (strata/PSU structure) + assert r_full.survey_metadata is not None + assert r_full.survey_metadata.n_strata is not None + assert r_full.survey_metadata.n_psu is not None + # P-values should differ due to t-distribution with survey df + if np.isfinite(r_w.overall_p_value) and np.isfinite(r_full.overall_p_value): + assert r_w.overall_p_value != r_full.overall_p_value + + def test_weighted_gmm_variance(self, staggered_survey_data, survey_design_weights_only): + """GMM SE should differ from unweighted (weights affect sandwich).""" + r_unw = TwoStageDiD().fit(staggered_survey_data, "outcome", "unit", "period", "first_treat") + r_w = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + # SE magnitude should differ (not just sign) + assert abs(r_unw.overall_se - r_w.overall_se) > 1e-6 + + def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): + TwoStageDiD(n_bootstrap=99).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + + def test_summary_includes_survey(self, staggered_survey_data, survey_design_weights_only): + """Summary output should include survey design section.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + summary = result.summary() + assert "Survey Design" in summary + assert "pweight" in summary + + def test_wrapper_two_stage_did_with_survey(self, staggered_survey_data): + """two_stage_did() wrapper forwards survey_design correctly.""" + from diff_diff import two_stage_did + + sd = SurveyDesign(weights="weight") + r_wrapper = two_stage_did( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + r_direct = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isclose(r_wrapper.overall_att, r_direct.overall_att, atol=1e-10) + assert r_wrapper.survey_metadata is not None + + def test_aggregate_group_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='group' works with survey design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="group", + survey_design=survey_design_weights_only, + ) + assert result.group_effects is not None + assert len(result.group_effects) > 0 + + def test_aggregate_all_with_survey(self, staggered_survey_data, survey_design_weights_only): + """aggregate='all' works with survey design.""" + result = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="all", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert result.event_study_effects is not None + assert result.group_effects is not None + + def test_always_treated_with_survey(self, staggered_survey_data): + """TwoStageDiD with survey + always-treated units should not crash.""" + data = staggered_survey_data.copy() + # Make some units always-treated (first_treat at or before min time) + min_time = data["period"].min() + units = data["unit"].unique() + always_treated_units = units[:3] + for u in always_treated_units: + data.loc[data["unit"] == u, "first_treat"] = min_time + sd = SurveyDesign(weights="weight") + result = TwoStageDiD().fit( + data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + +# ============================================================================= +# TestCallawaySantAnnaSurvey +# ============================================================================= + + +class TestCallawaySantAnnaSurvey: + """Survey design support for CallawaySantAnna.""" + + def test_smoke_reg_weights_only(self, staggered_survey_data, survey_design_weights_only): + """CallawaySantAnna regression method works with survey design.""" + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_smoke_ipw_weights_only(self, staggered_survey_data, survey_design_weights_only): + """CallawaySantAnna IPW method works with survey design.""" + result = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_smoke_dr_weights_only(self, staggered_survey_data, survey_design_weights_only): + """CallawaySantAnna DR method works with survey design.""" + result = CallawaySantAnna(estimation_method="dr").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + assert np.isfinite(result.overall_se) + assert result.survey_metadata is not None + + def test_uniform_weights_match_unweighted(self, staggered_survey_data): + """Uniform survey weights should match unweighted result — all methods.""" + staggered_survey_data["uniform_w"] = 1.0 + sd = SurveyDesign(weights="uniform_w") + + for method in ["reg", "ipw", "dr"]: + r_unw = CallawaySantAnna(estimation_method=method).fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = CallawaySantAnna(estimation_method=method).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd, + ) + assert abs(r_unw.overall_att - r_w.overall_att) < 1e-8, f"method={method}: ATT mismatch" + + def test_survey_metadata_fields(self, staggered_survey_data, survey_design_full): + """survey_metadata has correct fields with full design.""" + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_full, + ) + sm = result.survey_metadata + assert sm is not None + assert sm.weight_type == "pweight" + assert sm.effective_n > 0 + assert sm.design_effect > 0 + assert sm.n_strata is not None + assert sm.n_psu is not None + + def test_se_differs_with_design(self, staggered_survey_data): + """Weights-only vs full design: same ATT, different inference via survey df.""" + sd_w = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + + r_w = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_w, + ) + r_full = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd_full, + ) + # ATTs should be the same (same weights) + assert abs(r_w.overall_att - r_full.overall_att) < 1e-10 + # Full design should carry survey df (strata/PSU structure) + assert r_full.survey_metadata is not None + assert r_full.survey_metadata.n_strata is not None + assert r_full.survey_metadata.n_psu is not None + # P-values should differ due to t-distribution with survey df + if np.isfinite(r_w.overall_p_value) and np.isfinite(r_full.overall_p_value): + assert r_w.overall_p_value != r_full.overall_p_value + + def test_bootstrap_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """Bootstrap + survey should raise NotImplementedError.""" + with pytest.raises(NotImplementedError, match="[Bb]ootstrap"): + CallawaySantAnna(estimation_method="reg", n_bootstrap=99).fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + + def test_ipw_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """IPW + covariates + survey should raise NotImplementedError.""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + with pytest.raises(NotImplementedError, match="covariates"): + CallawaySantAnna(estimation_method="ipw").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + + def test_dr_covariates_survey_raises(self, staggered_survey_data, survey_design_weights_only): + """DR + covariates + survey should raise NotImplementedError.""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + with pytest.raises(NotImplementedError, match="covariates"): + CallawaySantAnna(estimation_method="dr").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + + def test_reg_covariates_survey_works(self, staggered_survey_data, survey_design_weights_only): + """Regression + covariates + survey should work (has nuisance IF correction).""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + result = CallawaySantAnna(estimation_method="reg").fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + survey_design=survey_design_weights_only, + ) + assert np.isfinite(result.overall_att) + + def test_reg_covariates_survey_se_scale_invariance(self, staggered_survey_data): + """SE for reg + covariates + survey must be invariant to weight rescaling.""" + data = staggered_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + data["weight2"] = data["weight"] * 4.3 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + est = CallawaySantAnna(estimation_method="reg") + r1 = est.fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + aggregate="simple", + survey_design=sd1, + ) + r2 = est.fit( + data, + "outcome", + "unit", + "period", + "first_treat", + covariates=["x1"], + aggregate="simple", + survey_design=sd2, + ) + assert np.isclose( + r1.overall_att, r2.overall_att, atol=1e-8 + ), "ATT not scale-invariant for reg+cov+survey" + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"SE not scale-invariant for reg+cov+survey: {r1.overall_se} vs {r2.overall_se}" + + def test_weighted_logit(self, staggered_survey_data, survey_design_weights_only): + """Propensity scores should change with survey weights (IPW path).""" + r_unw = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + # Non-uniform weights should produce different ATT + # (propensity scores change with survey weights) + assert r_unw.overall_att != r_w.overall_att + + def test_ipw_survey_weight_composition(self, staggered_survey_data, survey_design_weights_only): + """w_survey x w_ipw should compose — ATT differs from unweighted IPW.""" + r_unw = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, "outcome", "unit", "period", "first_treat" + ) + r_w = CallawaySantAnna(estimation_method="ipw").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + # Weighted IPW should produce different ATT than unweighted + assert abs(r_unw.overall_att - r_w.overall_att) > 1e-6 + + def test_aggregation_with_survey(self, staggered_survey_data, survey_design_weights_only): + """Simple aggregation should use survey weights.""" + r_unw = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + ) + r_w = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + aggregate="event_study", + survey_design=survey_design_weights_only, + ) + # Event study ATTs should differ with non-uniform weights + assert r_unw.overall_att != r_w.overall_att + # Event study effects should exist + assert r_w.event_study_effects is not None + assert len(r_w.event_study_effects) > 0 + + def test_summary_includes_survey(self, staggered_survey_data, survey_design_weights_only): + """Summary output should include survey design section.""" + result = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=survey_design_weights_only, + ) + summary = result.summary() + assert "Survey Design" in summary + assert "pweight" in summary + + +# ============================================================================= +# TestTripleDifferenceIPWSurvey +# ============================================================================= + + +class TestTripleDifferenceIPWSurvey: + """Verify TripleDifference IPW/DR + survey is unblocked.""" + + def test_ipw_survey_no_longer_raises(self, ddd_survey_data): + """IPW + survey should no longer raise NotImplementedError.""" + sd = SurveyDesign(weights="weight") + # Should not raise + result = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert result is not None + + def test_dr_survey_no_longer_raises(self, ddd_survey_data): + """DR + survey should no longer raise NotImplementedError.""" + sd = SurveyDesign(weights="weight") + # Should not raise + result = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert result is not None + + def test_ipw_survey_results_finite(self, ddd_survey_data): + """IPW + survey should produce finite results.""" + sd = SurveyDesign(weights="weight") + result = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isfinite(result.att) + assert np.isfinite(result.se) + assert result.survey_metadata is not None + + def test_ipw_nonuniform_weights_change_att(self, ddd_survey_data): + """Non-uniform survey weights should change IPW ATT vs unweighted.""" + sd = SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="ipw").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Non-uniform survey weights should change IPW ATT" + + def test_dr_nonuniform_weights_change_att(self, ddd_survey_data): + """Non-uniform survey weights should change DR ATT vs unweighted.""" + sd = SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="dr").fit( + ddd_survey_data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Non-uniform survey weights should change DR ATT" + + def test_ipw_uniform_weights_match_unweighted(self, ddd_survey_data): + """Uniform survey weights should match unweighted IPW result.""" + data = ddd_survey_data.copy() + data["uw"] = 1.0 + sd = SurveyDesign(weights="uw") + r_no = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isclose(r_no.att, r_sv.att, atol=1e-6) + + def test_dr_uniform_weights_match_unweighted(self, ddd_survey_data): + """Uniform survey weights should match unweighted DR result.""" + data = ddd_survey_data.copy() + data["uw"] = 1.0 + sd = SurveyDesign(weights="uw") + r_no = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + ) + r_sv = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + survey_design=sd, + ) + assert np.isclose(r_no.att, r_sv.att, atol=1e-6) + + def test_ipw_covariate_survey_nonuniform(self, ddd_survey_data): + """IPW + covariates + non-uniform survey weights should change ATT.""" + data = ddd_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + sd = SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + ) + r_sv = TripleDifference(estimation_method="ipw").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + survey_design=sd, + ) + assert np.isfinite(r_sv.att) + assert np.isfinite(r_sv.se) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Covariate IPW + non-uniform survey weights should change ATT" + + def test_dr_covariate_survey_nonuniform(self, ddd_survey_data): + """DR + covariates + non-uniform survey weights should change ATT.""" + data = ddd_survey_data.copy() + data["x1"] = np.random.default_rng(42).normal(0, 1, len(data)) + sd = SurveyDesign(weights="weight") + r_no = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + ) + r_sv = TripleDifference(estimation_method="dr").fit( + data, + "outcome", + "group", + "partition", + "time", + covariates=["x1"], + survey_design=sd, + ) + assert np.isfinite(r_sv.att) + assert np.isfinite(r_sv.se) + assert not np.isclose( + r_no.att, r_sv.att, atol=1e-6 + ), "Covariate DR + non-uniform survey weights should change ATT" + + +# ============================================================================= +# TestCallawaySantAnnaSurveyInference +# ============================================================================= + + +class TestCallawaySantAnnaSurveyInference: + """Validate CS survey inference beyond smoke tests.""" + + def test_se_scale_invariance_all_methods(self, staggered_survey_data): + """SE should be invariant under weight rescaling for all methods.""" + data = staggered_survey_data + data = data.copy() + data["weight2"] = data["weight"] * 3.7 + sd1 = SurveyDesign(weights="weight") + sd2 = SurveyDesign(weights="weight2") + + for method in ["reg", "ipw", "dr"]: + est = CallawaySantAnna(estimation_method=method) + r1 = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd1, + ) + r2 = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd2, + ) + assert np.isclose( + r1.overall_att, r2.overall_att, atol=1e-8 + ), f"{method}: ATT not scale-invariant" + assert np.isclose( + r1.overall_se, r2.overall_se, atol=1e-8 + ), f"{method}: SE not scale-invariant" + + def test_survey_weights_change_per_cell_att(self, staggered_survey_data): + """Non-uniform survey weights should change per-cell ATT(g,t).""" + data = staggered_survey_data + sd = SurveyDesign(weights="weight") + for method in ["reg", "ipw", "dr"]: + r_no = CallawaySantAnna(estimation_method=method).fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ) + r_sv = CallawaySantAnna(estimation_method=method).fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + survey_design=sd, + ) + # At least one cell ATT should differ + effects_no = [d["effect"] for d in r_no.group_time_effects.values()] + effects_sv = [d["effect"] for d in r_sv.group_time_effects.values()] + assert not np.allclose( + effects_no, effects_sv, atol=1e-6 + ), f"{method}: survey weights should change per-cell ATT" + + def test_survey_df_affects_pvalues(self, staggered_survey_data): + """Survey df (from strata/PSU) should affect p-values via t-distribution.""" + data = staggered_survey_data + sd_weights = SurveyDesign(weights="weight") + sd_full = SurveyDesign(weights="weight", strata="stratum", psu="psu") + est = CallawaySantAnna(estimation_method="reg") + r_w = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd_weights, + ) + r_f = est.fit( + data, + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + aggregate="simple", + survey_design=sd_full, + ) + # ATT should be same (same weights), but p-values differ (different df) + assert np.isclose(r_w.overall_att, r_f.overall_att, atol=1e-8) + # Survey df from strata/PSU should change inference + assert r_f.survey_metadata.df_survey is not None + + +# ============================================================================= +# TestScaleInvariance +# ============================================================================= + + +class TestScaleInvariance: + """Multiplying all survey weights by a constant should not change ATT or SE.""" + + def test_weight_scale_invariance_imputation(self, staggered_survey_data): + """ImputationDiD: 2*w gives same ATT as w.""" + sd1 = SurveyDesign(weights="weight") + r1 = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + + staggered_survey_data["weight_x2"] = staggered_survey_data["weight"] * 2.0 + sd2 = SurveyDesign(weights="weight_x2") + r2 = ImputationDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + + assert abs(r1.overall_att - r2.overall_att) < 1e-10 + assert abs(r1.overall_se - r2.overall_se) < 1e-8 + + def test_weight_scale_invariance_two_stage(self, staggered_survey_data): + """TwoStageDiD: 2*w gives same ATT as w.""" + sd1 = SurveyDesign(weights="weight") + r1 = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + + staggered_survey_data["weight_x2"] = staggered_survey_data["weight"] * 2.0 + sd2 = SurveyDesign(weights="weight_x2") + r2 = TwoStageDiD().fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + + assert abs(r1.overall_att - r2.overall_att) < 1e-10 + assert abs(r1.overall_se - r2.overall_se) < 1e-8 + + def test_weight_scale_invariance_callaway_santanna(self, staggered_survey_data): + """CallawaySantAnna: 2*w gives same ATT as w.""" + sd1 = SurveyDesign(weights="weight") + r1 = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd1, + ) + + staggered_survey_data["weight_x2"] = staggered_survey_data["weight"] * 2.0 + sd2 = SurveyDesign(weights="weight_x2") + r2 = CallawaySantAnna(estimation_method="reg").fit( + staggered_survey_data, + "outcome", + "unit", + "period", + "first_treat", + survey_design=sd2, + ) + + assert abs(r1.overall_att - r2.overall_att) < 1e-10 + assert abs(r1.overall_se - r2.overall_se) < 1e-8