pyjuice.transformations.prune_by_score

pyjuice.transformations.prune_by_score(root_nodes: CircuitNodes, key: str = '_scores', scores: Dict[CircuitNodes, ndarray | Tensor] | None = None, keep_frac: float | None = None, score_threshold: float | None = None, block_reduction: str = 'sum')

Prune sum-edge connections from a PC based on per-edge scores, returning a new, sparser PC. Edges are kept either by retaining the keep_frac fraction of edges with the highest scores, or by keeping every edge whose score is at least score_threshold (exactly one of the two must be given). Parameter-tied nodes and nodes without scores are left untouched.
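The two selection rules can be illustrated with a minimal NumPy sketch. It is not pyjuice's implementation. The helper name select_edges is hypothetical, scores are assumed to be NumPy arrays keyed by node, and keep_frac is assumed to rank all scored edges together in one global ordering, with ties broken by position.

```python
import numpy as np


def select_edges(scores, keep_frac=None, score_threshold=None):
    """Hypothetical helper: return a boolean keep-mask per node.

    Assumes `scores` maps node keys to NumPy arrays of per-edge scores.
    """
    # Mirror the documented rule: exactly one criterion must be given.
    if (keep_frac is None) == (score_threshold is None):
        raise ValueError("Exactly one of keep_frac and score_threshold must be given.")

    # Threshold mode: keep every edge whose score is >= the threshold.
    if score_threshold is not None:
        return {k: np.asarray(v) >= score_threshold for k, v in scores.items()}

    # Fraction mode (assumed global): rank all edges from every node together.
    keys = list(scores)
    flat = np.concatenate([np.ravel(scores[k]) for k in keys])
    num_keep = int(round(keep_frac * flat.size))

    # Stable sort on negated scores gives descending order, ties by position.
    order = np.argsort(-flat, kind="stable")
    keep_flat = np.zeros(flat.size, dtype=bool)
    keep_flat[order[:num_keep]] = True

    # Split the flat mask back into per-node masks with the original shapes.
    masks, offset = {}, 0
    for k in keys:
        n = np.size(scores[k])
        masks[k] = keep_flat[offset:offset + n].reshape(np.shape(scores[k]))
        offset += n
    return masks
```

With scores {"a": [0.1, 0.9, 0.5], "b": [0.3, 0.7]}, keep_frac=0.4 keeps the two highest-scoring edges (0.9 and 0.7), while score_threshold=0.5 keeps 0.9, 0.5 and 0.7.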

Parameters:
  • root_nodes (CircuitNodes) – the root of the PC to prune

  • key (str) – the attribute name under which per-node edge scores are stored on each node (used when scores is not given)

  • scores (Optional[Dict[CircuitNodes, Union[ndarray, Tensor]]]) – an explicit mapping from nodes to their edge-score arrays or tensors; overrides key when provided

  • keep_frac (Optional[float]) – the fraction of highest-scoring edges to keep; mutually exclusive with score_threshold

  • score_threshold (Optional[float]) – the minimum score for an edge to be kept; mutually exclusive with keep_frac

  • block_reduction (str) – how per-edge scores are reduced over a node block when blocks have size > 1 (e.g., “sum”)
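To illustrate what block_reduction does, the sketch below collapses an element-level score matrix into one score per (output block, input block) pair. The layout is an assumption: scores are taken to form a dense [num_out, num_in] array whose dimensions are divisible by the block sizes. The helper reduce_block_scores is hypothetical, and only "sum" is named in this reference; "max" is included purely for illustration.

```python
import numpy as np


def reduce_block_scores(edge_scores, block_size_out, block_size_in, reduction="sum"):
    """Hypothetical helper: reduce element-level edge scores to block level.

    Assumes `edge_scores` has shape [num_out, num_in] with both dimensions
    divisible by the corresponding block size.
    """
    num_out, num_in = edge_scores.shape
    if num_out % block_size_out or num_in % block_size_in:
        raise ValueError("Score shape must be divisible by the block sizes.")

    # Element (o, i) lands at [o // bs_out, o % bs_out, i // bs_in, i % bs_in].
    blocks = edge_scores.reshape(
        num_out // block_size_out, block_size_out,
        num_in // block_size_in, block_size_in,
    )

    # Reduce over the within-block axes, leaving one score per block pair.
    if reduction == "sum":
        return blocks.sum(axis=(1, 3))
    if reduction == "max":  # illustrative alternative, not confirmed by the docs
        return blocks.max(axis=(1, 3))
    raise ValueError(f"Unsupported reduction: {reduction!r}")
```

For a 4 x 4 score matrix with 2 x 2 blocks, "sum" yields a 2 x 2 matrix where each entry is the total score of the four edges in that block, so a block survives or is pruned as a unit.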

Returns:

a new PC with the low-scoring edges removed

Return type:

CircuitNodes