threedb.evaluators.classification

Implements a basic evaluator for classification-based tasks.

class threedb.evaluators.classification.SimpleClassificationEvaluator(*, topk: int, classmap_path: str)

Bases: threedb.evaluators.base_evaluator.BaseEvaluator

A concrete implementation of the abstract threedb.evaluators.base_evaluator.BaseEvaluator that is designed for classification tasks.

output_type: str = 'classes'
KEYS: List[str] = ['is_correct', 'loss', 'prediction']
__init__(*, topk: int, classmap_path: str)

Initialize an Evaluator for classification.

Parameters
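  • topk (int) – the number of highest-scoring classes that are checked against the target label when determining correctness.

  • classmap_path (str) – a path to a JSON file mapping model UIDs to class numbers.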
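The classmap file is described only as a JSON file mapping model UIDs to class numbers. The sketch below assumes a flat JSON object whose values are single integers. The helper name load_classmap and the example UIDs are hypothetical and are not part of threedb's API.

```python
import json
import os
import tempfile
from typing import Dict


def load_classmap(classmap_path: str) -> Dict[str, int]:
    # Hypothetical helper: read a JSON object of {model_uid: class_number}.
    with open(classmap_path) as f:
        raw = json.load(f)
    # JSON object keys are always strings; coerce values to int defensively.
    return {uid: int(cls) for uid, cls in raw.items()}


# Example file contents (UIDs and class numbers are made up for illustration).
example = {"model_a": 0, "model_b": 3}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
    json.dump(example, tmp)

classmap = load_classmap(tmp.name)
os.remove(tmp.name)  # clean up the temporary file
```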

get_segmentation_label(model_uid: str) → int

Given a model_uid, return a scalar label corresponding to the 3D model. This label is used only for generating a segmentation map. If the only goal is to separate the object from its background, this function can return any integer greater than zero.

Parameters

model_uid (str) – Unique identifier for the model.

Returns

An integer which will be the color of the model in the segmentation map.

Return type

int
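To illustrate why any positive label suffices for separating the object from its background, the following sketch assumes that background pixels in the segmentation map have the value 0. The functions segmentation_label and foreground_mask are hypothetical. Shifting the class number by one is just one way to keep class 0 from colliding with the background and is not necessarily what threedb does. Plain nested lists stand in for an image tensor.

```python
from typing import Dict, List


def segmentation_label(classmap: Dict[str, int], model_uid: str) -> int:
    # Hypothetical: shift class numbers by one so that class 0 does not
    # collide with the background value, which this sketch assumes is 0.
    return classmap[model_uid] + 1


def foreground_mask(seg_map: List[List[int]]) -> List[List[bool]]:
    # Any label > 0 marks an object pixel; 0 is treated as background.
    return [[value > 0 for value in row] for row in seg_map]


classmap = {"model_a": 0, "model_b": 3}
label = segmentation_label(classmap, "model_a")
seg_map = [[0, label], [label, 0]]
mask = foreground_mask(seg_map)
```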

get_target(model_uid: str, render_output: Dict[str, torch.Tensor]) → Union[List[int], int]

See the docstring of threedb.evaluators.base_evaluator.BaseEvaluator.get_target().

declare_outputs() → Dict[str, Tuple[List[int], str]]

See the docstring of threedb.evaluators.base_evaluator.BaseEvaluator.declare_outputs().

Returns

A dictionary mapping each output key to a tuple of its shape and data type.

Return type

Dict[str, Tuple[List[int], str]]
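For illustration, a return value of the declared type might look like the following. The shapes and dtype strings are assumptions chosen for this sketch: scalars for is_correct and loss, and one entry per top-k class for prediction. Consult the base-class docstring for the exact contract.

```python
from typing import Dict, List, Tuple

KEYS: List[str] = ["is_correct", "loss", "prediction"]


def declare_outputs(topk: int) -> Dict[str, Tuple[List[int], str]]:
    # Hypothetical shapes and dtype strings: an empty shape denotes a scalar,
    # and the prediction holds one class index per top-k slot.
    return {
        "is_correct": ([], "bool"),
        "loss": ([], "float32"),
        "prediction": ([topk], "int64"),
    }


outputs = declare_outputs(topk=5)
```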

summary_stats(pred: torch.Tensor, label: Union[List[int], int], input_shape: List[int]) → Dict[str, Union[bool, int, float, str, torch.Tensor]]

Concrete implementation of threedb.evaluators.base_evaluator.BaseEvaluator.summary_stats() (see that docstring for information on the abstract function). Returns correctness (binary value) and cross-entropy loss of the prediction.

Parameters
  • pred (ch.Tensor) – The output of the inference model: expected to be a 1D tensor of size (n_classes).

  • label (LabelType) – An integer or list of integers representing the target label.

Returns

A dictionary containing the results to log from this evaluator, namely the correctness and the cross-entropy loss.

Return type

Dict[str, Output]
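The following is a minimal sketch of the two quantities summary_stats is documented to compute: top-k correctness and cross-entropy loss. It uses a list of floats in place of a 1D torch.Tensor so it runs without PyTorch, takes topk as an argument rather than an attribute, and ignores input_shape. It also assumes that the prediction output holds the top-k class indices. For a list-valued label, it counts the prediction as correct if any listed class appears in the top k and reports the smallest loss among those classes. Both list-handling rules are assumptions.

```python
import math
from typing import Dict, List, Sequence, Union

LabelType = Union[List[int], int]


def summary_stats(pred: Sequence[float], label: LabelType, topk: int) -> Dict[str, object]:
    # `pred` holds one raw score (logit) per class.
    labels = [label] if isinstance(label, int) else list(label)

    # Indices of the `topk` highest scores, best first.
    ranked = sorted(range(len(pred)), key=lambda i: pred[i], reverse=True)
    top = ranked[:topk]

    # Correct if any acceptable label appears among the top-k predictions.
    is_correct = any(lab in top for lab in labels)

    # Numerically stable log-sum-exp gives the softmax normalizer log Z.
    m = max(pred)
    log_z = m + math.log(sum(math.exp(p - m) for p in pred))

    # Cross-entropy is -log softmax(pred)[lab] = log Z - pred[lab]; for a
    # list label, take the smallest loss over acceptable labels (assumption).
    loss = min(log_z - pred[lab] for lab in labels)

    return {"is_correct": is_correct, "loss": loss, "prediction": top}


stats = summary_stats([2.0, 0.5, 1.0], label=2, topk=2)
```

With the scores [2.0, 0.5, 1.0], class 2 ranks second, so it counts as correct for topk=2 but not for topk=1.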

threedb.evaluators.classification.Evaluator

alias of threedb.evaluators.classification.SimpleClassificationEvaluator