pyjuice.queries.marginal

pyjuice.queries.marginal(pc: TensorCircuit, data: torch.Tensor, missing_mask: torch.Tensor | None = None, fw_input_fn: str | Callable | None = None, **kwargs)

Compute the marginal probability P(e) of the evidence e, i.e., of an assignment to a subset of the variables, with the remaining variables marginalized out.

Parameters:
  • pc (TensorCircuit) – the input PC

  • data (torch.Tensor) – the evidence: a tensor of size [B, num_vars] for hard evidence (B is the batch size), or of a custom shape when paired with fw_input_fn

  • missing_mask (Optional[torch.Tensor]) – a boolean mask whose True entries mark the variables to marginalize out; its size is either [num_vars] (one mask shared by the whole batch) or [B, num_vars] (one mask per sample)

  • fw_input_fn (Optional[Union[str,Callable]]) – an optional custom function for the forward pass of input layers
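
To make the meaning of data and missing_mask concrete, the following is a minimal brute-force sketch in plain Python. It does not use PyJuice or a TensorCircuit: the toy joint distribution (P_ONE, joint) and the helper brute_force_marginal are hypothetical names introduced only for illustration, and the sketch assumes that True in the mask means "marginalize this variable out". It shows the quantity P(e) that the query represents, not how the library computes it.

```python
from itertools import product

# Hypothetical toy joint over three binary variables, factorized as
# independent Bernoullis. It stands in for a PC; it is NOT a TensorCircuit.
P_ONE = [0.2, 0.5, 0.9]  # P(X_i = 1) for i = 0, 1, 2


def joint(x):
    """Probability of a full assignment x under the toy distribution."""
    p = 1.0
    for xi, pi in zip(x, P_ONE):
        p *= pi if xi == 1 else 1.0 - pi
    return p


def brute_force_marginal(row, mask):
    """P(e): sum the joint over every completion of the masked variables.

    Assumption: mask[i] == True means variable i is marginalized out, so the
    value stored in row[i] is ignored for that variable.
    """
    total = 0.0
    for x in product([0, 1], repeat=len(row)):
        # Keep x only if it agrees with the evidence on all observed variables.
        if all(m or xi == ri for xi, ri, m in zip(x, row, mask)):
            total += joint(x)
    return total


# Evidence for a batch of B = 2 samples over num_vars = 3 variables.
data = [[1, 0, 1], [0, 1, 0]]

# A mask of size [num_vars] is reused for every sample in the batch.
shared_mask = [False, True, False]
shared = [brute_force_marginal(row, shared_mask) for row in data]

# A mask of size [B, num_vars] gives each sample its own missing variables.
per_row_mask = [[False, True, False], [True, True, False]]
per_row = [brute_force_marginal(row, m) for row, m in zip(data, per_row_mask)]
```

With a real circuit, the corresponding call would follow the signature above, e.g. pyjuice.queries.marginal(pc, data=data, missing_mask=missing_mask), with data and missing_mask given as torch.Tensor objects of the sizes listed in the parameters.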