pyjuice.nodes.distributions.MaskedCategorical

class pyjuice.nodes.distributions.MaskedCategorical(num_cats: int, mask_mode: str)

A class representing Categorical distributions whose categories are restricted by masks.

Parameters:
  • num_cats (int) – number of categories

  • mask_mode (str) – type of mask; should be one of "range", "full_mask", or "rev_range"
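This page does not say how each mask_mode is applied. As a rough illustration only, the pure-Python sketch below assumes that "range" keeps the categories in a half-open interval [lo, hi), "rev_range" keeps the complement of that interval, and "full_mask" takes one boolean per category. The helpers build_mask and masked_probs, and the bounds and full_mask arguments, are hypothetical and are not part of the pyjuice API.

```python
from typing import List, Optional, Sequence, Tuple

# Mode names taken from the documented mask_mode values.
VALID_MODES = ("range", "full_mask", "rev_range")


def build_mask(num_cats: int, mask_mode: str,
               bounds: Optional[Tuple[int, int]] = None,
               full_mask: Optional[Sequence[bool]] = None) -> List[bool]:
    """Return a per-category boolean mask (True = category allowed).

    Hypothetical helper: the semantics of each mode are assumptions
    inferred from the mode names, not taken from pyjuice.
    """
    if mask_mode not in VALID_MODES:
        raise ValueError(f"mask_mode should be in {VALID_MODES}, got {mask_mode!r}")
    if mask_mode == "full_mask":
        # Assumed: an explicit boolean per category.
        if full_mask is None or len(full_mask) != num_cats:
            raise ValueError("full_mask needs one boolean per category")
        return [bool(m) for m in full_mask]
    if bounds is None:
        raise ValueError(f"mode {mask_mode!r} needs (lo, hi) bounds")
    lo, hi = bounds
    inside = [lo <= c < hi for c in range(num_cats)]
    # Assumed: "range" keeps [lo, hi); "rev_range" keeps everything else.
    return inside if mask_mode == "range" else [not x for x in inside]


def masked_probs(probs: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Zero out disallowed categories and renormalize the remaining mass."""
    kept = [p if m else 0.0 for p, m in zip(probs, mask)]
    total = sum(kept)
    if total <= 0.0:
        raise ValueError("mask removes all probability mass")
    return [p / total for p in kept]


if __name__ == "__main__":
    probs = [0.1, 0.2, 0.3, 0.4]
    mask = build_mask(4, "range", bounds=(1, 3))
    print(mask)                       # [False, True, True, False]
    print(masked_probs(probs, mask))  # [0.0, 0.4, 0.6, 0.0]
```

Renormalizing after masking keeps the result a valid distribution over the allowed categories; whether pyjuice renormalizes in the same way is not stated on this page.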

__init__(num_cats: int, mask_mode: str)

Methods

num_param_flows()

The number of parameter flows per node.

num_parameters()

The number of parameters per node.

Attributes

need_meta_parameters

A flag indicating whether users need to pass in meta-parameters to the constructor of InputNodes.