pyjuice.queries.marginal
- pyjuice.queries.marginal(pc: TensorCircuit, data: torch.Tensor, missing_mask: torch.Tensor | None = None, fw_input_fn: str | Callable | None = None, **kwargs)
Compute the marginal probability of an assignment to a subset of the variables (the evidence e), i.e., P(e), with all remaining variables summed out.
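For reference, the quantity can be written out explicitly. Let O be the observed (unmasked) variables with evidence e, and let M be the marginalized (masked) variables. The identity below is the general definition of a marginal. The remark that it reduces to one forward pass comes from the standard property of smooth, decomposable probabilistic circuits, where the sum can be pushed down to the input units. It is not a description of how pyjuice implements it internally.

```latex
P(\mathbf{e}) \;=\; \sum_{\mathbf{x}_M} p\big(\mathbf{X}_O = \mathbf{e},\; \mathbf{X}_M = \mathbf{x}_M\big)
% (an integral over x_M for continuous variables)
%
% In a smooth, decomposable PC the sum distributes over product units, so every
% input unit f over a marginalized variable X_j evaluates to
%   \sum_{x_j} f(x_j) = 1   (normalized input distribution),
% and P(e) is obtained from a single bottom-up (forward) pass.
```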
- Parameters:
pc (TensorCircuit) – the input PC
data (torch.Tensor) – hard evidence of shape [B, num_vars], or a custom shape when used together with fw_input_fn
missing_mask (Optional[torch.Tensor]) – a boolean mask marking the variables to marginalize out; its shape can be [num_vars] (one mask shared by the whole batch) or [B, num_vars] (one mask per sample)
fw_input_fn (Optional[Union[str,Callable]]) – an optional custom function for the forward pass of input layers
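The sketch below illustrates what `data` and `missing_mask` mean, including how a `[num_vars]` mask broadcasts over the batch while a `[B, num_vars]` mask applies per sample. It does not call pyjuice. It uses a hypothetical toy model, `factorized_marginal`, built on a fully factorized distribution over binary variables (a trivial PC with one product node over Bernoulli inputs) and NumPy in place of `torch` tensors. As in the definition of P(e), each masked variable contributes a factor of 1.

```python
import numpy as np


def factorized_marginal(probs, data, missing_mask=None):
    """Hypothetical toy stand-in (not pyjuice): marginal P(e) under a fully
    factorized distribution over binary variables.

    probs:        [num_vars] array, probs[i] = P(X_i = 1)
    data:         [B, num_vars] array of 0/1 values (hard evidence)
    missing_mask: None, [num_vars], or [B, num_vars] boolean array;
                  True marks a variable to marginalize out
    Returns a [B] array of marginal probabilities.
    """
    probs = np.asarray(probs, dtype=float)
    data = np.asarray(data)
    # Likelihood of each observed value: p_i if x_i == 1, else 1 - p_i.
    lik = np.where(data == 1, probs, 1.0 - probs)  # shape [B, num_vars]
    if missing_mask is not None:
        # A [num_vars] mask is broadcast to every sample; a [B, num_vars]
        # mask is used as-is (one mask per sample).
        mask = np.broadcast_to(np.asarray(missing_mask, dtype=bool), lik.shape)
        # A normalized input summed over its domain equals 1, so a
        # marginalized variable contributes a factor of 1 to the product.
        lik = np.where(mask, 1.0, lik)
    return lik.prod(axis=1)


probs = [0.2, 0.7, 0.5]
data = np.array([[1, 0, 1],
                 [0, 1, 1]])

# Shared mask: marginalize X_2 for every sample.
shared = np.array([False, False, True])
print(factorized_marginal(probs, data, shared))  # approximately [0.06, 0.56]

# Per-sample mask: keep only X_0 in sample 0, marginalize everything in sample 1.
per_sample = np.array([[False, True, True],
                       [True, True, True]])
print(factorized_marginal(probs, data, per_sample))  # approximately [0.2, 1.0]
```

Marginalizing every variable gives 1, and leaving all variables unmasked gives the full joint probability. These two limiting cases are useful sanity checks for any marginal query.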