pyjuice.nodes.distributions.MaskedCategorical
- class pyjuice.nodes.distributions.MaskedCategorical(num_cats: int, mask_mode: str)
A class representing Categorical distributions with masks.
- Parameters:
num_cats (int) – number of categories
mask_mode (str) – type of mask; must be one of "range", "full_mask", or "rev_range"
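The reference above lists the three mask modes by name only. The following is a minimal standard-library sketch of one plausible reading of them. It is not PyJuice's implementation: the helpers `build_mask` and `masked_probs`, the `mask_spec` argument, and the half-open `[lo, hi)` range convention are all hypothetical and introduced purely for illustration. Whether the library renormalizes over the kept categories in this way is also an assumption.

```python
from typing import List, Sequence

# Mode names taken from the parameter description; their semantics below are assumed.
VALID_MASK_MODES = ("range", "full_mask", "rev_range")


def build_mask(num_cats: int, mask_mode: str, mask_spec: Sequence[int]) -> List[bool]:
    """Hypothetical helper: return one keep (True) / drop (False) flag per category."""
    if mask_mode not in VALID_MASK_MODES:
        raise ValueError(f"mask_mode must be one of {VALID_MASK_MODES}, got {mask_mode!r}")

    if mask_mode == "full_mask":
        # Assumed: one 0/1 entry per category.
        if len(mask_spec) != num_cats:
            raise ValueError("full_mask needs exactly num_cats entries")
        return [bool(m) for m in mask_spec]

    # Assumed: a (lo, hi) pair describing the half-open interval [lo, hi).
    lo, hi = mask_spec
    in_range = [lo <= c < hi for c in range(num_cats)]
    if mask_mode == "range":
        return in_range                 # keep categories inside the range
    return [not f for f in in_range]    # "rev_range": keep categories outside it


def masked_probs(probs: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Zero out dropped categories and renormalize over the kept ones."""
    kept = [p if m else 0.0 for p, m in zip(probs, mask)]
    total = sum(kept)
    if total == 0.0:
        raise ValueError("mask removes all probability mass")
    return [p / total for p in kept]


uniform = [0.25, 0.25, 0.25, 0.25]
range_mask = build_mask(4, "range", (1, 3))
rev_mask = build_mask(4, "rev_range", (1, 3))
full_mask = build_mask(4, "full_mask", [1, 0, 0, 1])
range_probs = masked_probs(uniform, range_mask)
```

Under these assumptions, "range" and "rev_range" are complements of each other for the same `(lo, hi)` pair, and "full_mask" can express either of them as well as non-contiguous masks.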
Methods
num_param_flows()
The number of parameter flows per node.
num_parameters()
The number of parameters per node.
Attributes
need_meta_parameters
A flag indicating whether users need to pass in meta-parameters to the constructor of InputNodes.