Core

torchmeter.core.Meter
A comprehensive instrumentation tool for PyTorch model performance analysis and visualization. The Meter class provides end-to-end measurement capabilities for neural networks, including parameter statistics, computational cost analysis, memory usage tracking, and inference latency and throughput analysis. It serves as a wrapper around PyTorch modules while maintaining full compatibility with native model operations.
Key Features:

- Zero-Intrusion Proxy
    - Acts as a drop-in wrapper that requires no changes to the underlying model
    - Seamlessly integrates with PyTorch modules while preserving full compatibility (attributes and methods)
- Full-Stack Model Analytics: holistic performance analytics across 5 dimensions:
    - Parameter distribution
    - Computational cost (FLOPs/MACs)
    - Memory access assessment
    - Inference latency
    - Throughput benchmarking
- Rich Visualization
    - Programmable tabular reports with real-time rendering
    - Hierarchical operation tree with smart folding of repeated blocks for model-structure insights
- Fine-Grained Customization
    - Real-time hot-reload rendering: dynamic adjustment of the rendering configuration for operation trees, report tables, and their nested components
    - Progressive updates: namespace assignment and dictionary batch update
- Config-Driven Runtime Management
    - Centralized control: singleton-managed global configuration for dynamic behavior adjustment
    - Portable presets: export/import YAML profiles for runtime behaviors, eliminating repetitive setup
- Portability and Practicality
    - Decoupled pipeline: separation of data collection and visualization
    - Automatic device synchronization: keeps model and input data co-located for production-ready use
    - Dual-mode reporting with export flexibility:
        - Measurement-unit mode vs. raw-data mode
        - Multi-format export (CSV/Excel) for analysis integration
Core Functionality:

- Parameter Analysis
    - Total/trainable parameter quantification
    - Layer-wise parameter distribution analysis
    - Gradient state tracking (`requires_grad` flags)
- Computational Profiling
    - Precise FLOPs/MACs calculation
    - Operation-wise calculation distribution analysis
    - Dynamic input/output detection (number, type, shape, ...)
- Memory Diagnostics
    - Input/output tensor memory awareness
    - Hierarchical memory consumption analysis
- Performance Benchmarking
    - Automatic warm-up phase execution (eliminates cold-start bias)
    - Device-specific high-precision timing
    - Inference latency and throughput benchmarking
- Visualization Engine
    - Centralized configuration management
    - Programmable tabular reports
        - Style customization and real-time rendering
        - Dynamic table structure adjustment
        - Programmable real-time data analysis
        - Multi-format data export
    - Rich-text hierarchical structure tree rendering
        - Style customization and real-time rendering
        - Smart module folding based on structural equivalence detection
- Cross-Platform Support
    - Automatic model-data co-location
    - Seamless device transitions (CPU/CUDA)
Attributes:

| Name | Type | Description |
|---|---|---|
| `ipt` | `Dict[str, Any]` | Input arguments captured for the underlying model's `forward` method. |
| `device` | `device` | Current computation device for the model and the tensor data in the input. |
| `model` | `Module` | The wrapped PyTorch model instance. |
| `optree` | `OperationTree` | A backend hierarchical data structure of model operations. |
| `tree_renderer` | `TreeRenderer` | A renderer for operation tree visualization. |
| `table_renderer` | `TabularRenderer` | A renderer for programmable tabular reports. |
| `ittp_warmup` | `int` | Number of warm-up (i.e., feed-forward inference) iterations before benchmarking. |
| `ittp_benchmark_time` | `int` | Number of benchmark iterations per operation when measuring inference time and throughput. |
| `tree_fold_repeat` | `bool` | Whether to fold repeated blocks in the rendered tree structure. |
| `tree_levels_args` | `FlagNameSpace` | Rendering configuration for the various levels of the rendered tree structure. |
| `tree_repeat_block_args` | `FlagNameSpace` | Rendering configuration for repeated blocks of the rendered tree structure. |
| `table_display_args` | `FlagNameSpace` | Comprehensive rendering configuration for rendered tables. |
| `table_column_args` | `FlagNameSpace` | Rendering configuration for all of the rendered tables' columns. |
| `structure` | `Tree` | A stylized tree representation of the model's operation hierarchy. |
| `param` | `ParamsMeter` | A ParamsMeter instance containing the measured parameter-related statistics. |
| `cal` | `CalMeter` | A CalMeter instance containing the measured computational cost data. |
| `mem` | `MemMeter` | A MemMeter instance containing the measured memory usage data. |
| `ittp` | `IttpMeter` | An IttpMeter instance containing fresh inference time and throughput data. |
| `model_info` | `Text` | A `rich` Text object summarizing the model's basic information. |
| `subnodes` | `List[str]` | A list of all nodes in the operation tree with their IDs and names. |
Methods:

| Name | Description |
|---|---|
| `__call__` | Execute model inference while maintaining input and model device synchronization. |
| `__getattr__` | Transparently proxy attribute access to the underlying model when not found in the Meter instance. |
| `__setattr__` | Prioritize setting attributes on the Meter instance first, falling back to the underlying model. |
| `__delattr__` | Try to delete attributes from the Meter instance first, falling back to the underlying model if needed. |
| `to` | Move the model to the specified device while keeping input and model device synchronization. |
| `profile` | Render a tabular report of the specified statistics with rich visualization. |
| `table_cols` | Get all column names of the backend dataframe for the specified statistics. |
| `stat_info` | Generate a formatted summary of the specified statistics. |
| `overview` | Generate an overview of all statistics in a formatted layout. |
| `rebase` | Rebase the Meter instance to a specific node in the operation tree. |
Note

- Requires at least one forward pass before most measurements become available.
- Implements lazy evaluation and caching for most statistics (i.e., `param`, `cal`, `mem`).
Example

```python
import torch
from rich import print
from torchmeter import Meter, get_config
from torchvision import models
# prepare your torch model
underlying_model = models.resnet152()
# wrap the model with Meter class
model = Meter(underlying_model)
# Basic usage
input = torch.randn(1, 3, 224, 224)
output = model(input) # Standard model execution
# Performance analysis
print(model.structure) # Visualize model hierarchy
print(model.param) # Show parameter statistics
model.profile("cal") # Display computational cost table
# Automatic device synchronization
if torch.cuda.is_available():
    model.to("cuda")
    model(input)
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch model to be instrumented for measurement. | required |
| `device` | `Optional[Union[str, device]]` | Target device for model execution and measurement. Accepts either a device string (e.g., `cuda:0`) or a `torch.device` object. If `None`, automatically detects the model's current device via its parameters. | `None` |
Raises:

| Type | Description |
|---|---|
| `TypeError` | If the provided model is not a `nn.Module` instance. |
| `UserWarning` | When device is not specified and the model contains no parameters (fallback to `CPU`). |
Initialization performs the following key operations:

1. Device configuration:
    - Uses the specified device or auto-detects it via model parameters
    - Moves the model to the target device
2. Measurement infrastructure setup:
    - Initializes the input capture dictionary (`_ipt`)
    - Builds the operation tree (`optree`) for model structure analysis
    - Prepares renderers for visualization (`tree_renderer`, `table_renderer`)
3. Measurement state initialization:
    - Resets measurement flags (`param`/`cal`/`mem`)
    - Sets default benchmark parameters (`ittp_warmup=50`, `ittp_benchmark_time=100`)
    - Initializes accuracy warning trackers (`_has_nocall_nodes`, `_has_not_support_nodes`)
Example

```python
from torchmeter import Meter
from torchvision import models

underlying_model = models.resnet18()

# auto detect device
model = Meter(underlying_model)

# init a gpu model
model = Meter(underlying_model, device="cuda")
model = Meter(underlying_model, device="cuda:1")
```
Source code in torchmeter/core.py

````python
def __init__(self, model: nn.Module, device: Optional[Union[str, tc_device]] = None) -> None:
"""Initialize a Meter instance for model performance measurement and visualization.
Args:
model (nn.Module): `PyTorch` model to be instrumented for measurement
device (Optional[Union[str, torch.device]]): Target device for model execution and measurement.
Accepts either device string (e.g., `cuda:0`) or
`torch.device` object. If `None`, automatically detects
model's current device via its parameters.
Raises:
TypeError: If provided model is not a `nn.Module` instance
UserWarning: When device is not specified and model contains no parameters (fallback to `CPU`)
Notes:
Initialization performs following key operations:
1. Device configuration:
- Uses specified device or auto-detects via model parameters
- Moves model to target device
2. Measurement infrastructure setup:
- Initializes input capture dictionary (`_ipt`)
- Builds operation tree (`optree`) for model structure analysis
- Prepares renderers for visualization (`tree_renderer`, `table_renderer`)
3. Measurement state initialization:
- Resets measurement flags (`param`/`cal`/`mem`)
- Sets default benchmark parameters (`ittp_warmup`=50, `ittp_benchmark_time`=100)
- Initializes accuracy warning trackers (`_has_nocall_nodes`, `_has_not_support_nodes`)
Example:
```python
from torchmeter import Meter
from torchvision import models
underlying_model = models.resnet18()
# auto detect device
model = Meter(underlying_model)
# init a gpu model
model = Meter(underlying_model, device="cuda")
model = Meter(underlying_model, device="cuda:1")
```
"""
from torchmeter.engine import OperationTree
from torchmeter.display import TreeRenderer, TabularRenderer
if not isinstance(model, nn.Module):
raise TypeError(f"model must be a nn.Module, but got `{type(model).__name__}`.")
device = device or self.__device_detect(model)
self.__device = tc_device(device) if isinstance(device, str) else device
self.model = model.to(self.__device)
self._ipt: IPT_TYPE = {"args": tuple(), "kwargs": dict()} # TODO: self.ipt_infer()
self.optree = OperationTree(self.model)
self.tree_renderer = TreeRenderer(self.optree.root)
self.table_renderer = TabularRenderer(self.optree.root)
self.__measure_param = False
self.__measure_cal = False
self.__measure_mem = False
self.ittp_warmup = 50
self.ittp_benchmark_time = 100
self.__has_nocall_nodes: Optional[bool] = None
    self.__has_not_support_nodes: Optional[bool] = None
````
cal (property)

Measures the calculation cost of the model during inference.

This property calculates the computational cost (i.e., FLOPs and MACs) for each node in the operation tree during a feed-forward inference pass.
Returns:

| Name | Type | Description |
|---|---|---|
| CalMeter | `CalMeter` | A CalMeter instance containing the measured computational cost data. |
Raises:

| Type | Description |
|---|---|
| `RuntimeError` | If no input data has been provided (i.e., the `ipt` property is empty). |
Notes

- You must first invoke the Meter instance (via a forward pass) before accessing this property.
- The measurement is performed only once per Meter instance; subsequent accesses return the cached result.
- The measurement results depend on the model input: different input tensor sizes lead to different calculation costs, which is normal. For consistent and comparable results, we recommend using a single sample for measuring all statistics, including `cal`. This can be achieved by running a single-sample forward pass through the meter instance at any time.
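A minimal usage sketch. The `Flops` field read at the end mirrors the `model.cal.Flops` access shown in the `profile` example further below; treat any other fields of `CalMeter` as implementation details:

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())
model(torch.randn(1, 3, 224, 224))  # a forward pass is required before accessing `cal`

cal = model.cal   # measured on first access, cached afterwards
print(cal)        # formatted FLOPs/MACs summary
print(cal.Flops)  # total FLOPs, as used in the `profile` example below
```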
device (writable property)

The device where the model and all input tensors are currently located.

Returns:

| Type | Description |
|---|---|
| `device` | Current device as a `torch.device` object. |
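A quick check of the property alongside the `to` method; the direct-assignment form is an assumption based on the writable flag, not something documented above:

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())
print(model.device)  # e.g. device(type='cpu')

if torch.cuda.is_available():
    model.to("cuda")     # the documented way to move model + captured input together
    print(model.device)  # device(type='cuda', index=0)
    # `device` is marked writable, so direct assignment is presumably also
    # supported, e.g. `model.device = "cuda:0"` (assumption).
```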
ipt (property)

Captured input dictionary of the underlying model.

Returns:

| Type | Description |
|---|---|
| `IPT_TYPE` | `Dict[str, Any]`: A dictionary containing the captured model input. |

Notes

- This property is read-only and cannot be directly modified.
- The returned dictionary contains:
    - `args` (`tuple`): positional arguments passed to the `forward()` of the underlying model.
    - `kwargs` (`dict`): keyword arguments passed to the `forward()` of the underlying model.
- The input can only be set/updated through `Meter` instance calls (i.e., feed-forward inference of the original model).
- If the input contains tensor data, its dimensions might directly impact the measurement results of multiple statistics (e.g., `cal`, `mem`, `ittp`). For consistent and comparable results, we recommend using a single sample for measuring all statistics. This can be achieved by running a single-sample forward pass through the meter instance at any time.
ittp (property)

Measures the inference time and throughput of the model.

This property calculates the inference time and throughput for each node in the operation tree. It performs a warm-up phase followed by a benchmark phase to ensure accurate measurements. The results are returned as an IttpMeter object.

Returns:

| Name | Type | Description |
|---|---|---|
| IttpMeter | `IttpMeter` | An IttpMeter instance containing fresh inference time and throughput data. |
Raises:

| Type | Description |
|---|---|
| `RuntimeError` | If no input data has been provided (i.e., the `ipt` property is empty). |
| `TypeError` | If a benchmark setting (e.g., `ittp_warmup`, `ittp_benchmark_time`) is of the wrong type. |
| `ValueError` | If a benchmark setting (e.g., `ittp_warmup`, `ittp_benchmark_time`) has an invalid value. |
Notes

- You must first invoke the Meter instance (via a forward pass) before accessing this property.
- The measurements are performed on the device specified by `meter_instance.device`.
- The unit `IPS` means Inputs Per Second, i.e., the number of inferences with the given input completed per second.
- Unlike other statistics, the measured result is not cached, so it is re-measured every time the `ittp` attribute is accessed.
- The warm-up phase runs for `meter_instance.ittp_warmup` iterations to stabilize the measurements.
- The benchmark phase runs for `meter_instance.ittp_benchmark_time` iterations per operation.
- The measurement results depend on the model input: different input tensor sizes lead to different latencies and throughput, which is normal. For consistent and comparable results, we recommend using a single sample for measuring all statistics, including `ittp`. This can be achieved by running a single-sample forward pass through the meter instance at any time.
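A minimal sketch tying the benchmark knobs to a measurement:

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())
model(torch.randn(1, 3, 224, 224))  # forward pass required before accessing `ittp`

model.ittp_warmup = 10          # default: 50 warm-up iterations
model.ittp_benchmark_time = 30  # default: 100 benchmark iterations per operation

print(model.ittp)  # re-measured on every access (results are not cached)
```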
mem (property)

Measures the memory cost of the model during inference.

This property calculates the memory usage for each node in the operation tree during a feed-forward inference pass.

Returns:

| Name | Type | Description |
|---|---|---|
| MemMeter | `MemMeter` | A MemMeter instance containing the measured memory usage data. |

Raises:

| Type | Description |
|---|---|
| `RuntimeError` | If no input data has been provided (i.e., the `ipt` property is empty). |

Notes

- You must first invoke the Meter instance (via a forward pass) before accessing this property.
- The measurement is performed only once per Meter instance; subsequent accesses return the cached result.
- The measurement results depend on the model input: different input tensor sizes lead to different memory costs, which is normal. For consistent and comparable results, we recommend using a single sample for measuring all statistics, including `mem`. This can be achieved by running a single-sample forward pass through the meter instance at any time.
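Usage mirrors `cal`; a brief sketch:

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())
model(torch.randn(1, 3, 224, 224))  # forward pass required before accessing `mem`

print(model.mem)  # per-node memory usage; measured once, cached afterwards
```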
model_info (property)

Generates a formatted summary of the model's basic information.

This property provides a detailed summary of the model, including its name, device, forward method signature, and a structured input representation.

Returns:

| Name | Type | Description |
|---|---|---|
| Text | `Text` | A `rich` Text object containing the formatted model summary. |

Notes

- If no input has been provided (i.e., `self._ipt` is empty), the input representation will indicate that it is not provided.
- Otherwise, all the values in `self._ipt` correspond to the formal arguments of the `forward` method, and a structured input representation with type prompts is generated through the `torchmeter.utils.data_repr` function.
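For example:

```python
import torch
from rich import print
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())
print(model.model_info)  # input shown as "not provided" before any forward pass

model(torch.randn(1, 3, 224, 224))
print(model.model_info)  # now includes a structured input representation
```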
param (property)

Measures the number of model parameters.

This property calculates parameter-related metrics (e.g., number of parameters, trainable parameters) for each node in the operation tree.

Returns:

| Name | Type | Description |
|---|---|---|
| ParamsMeter | `ParamsMeter` | A ParamsMeter instance containing the measured parameter-related statistics. |

Notes

The measurement is performed only once per Meter instance; subsequent accesses return the cached result.
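Unlike `cal`/`mem`/`ittp`, no input-related RuntimeError is documented for `param`, so it can be read right after wrapping the model:

```python
from rich import print
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())
print(model.param)  # measured on first access, cached afterwards
```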
structure (property)

Generates a stylized tree representation of the model's operation hierarchy.

This property renders the operation tree based on the current configuration settings. The rendering strategy (folded/unfolded) and customization options are determined by the active configuration parameters. Caching is applied to optimize rendering performance when the configuration remains unchanged.

Returns:

| Name | Type | Description |
|---|---|---|
| Tree | `Tree` | A `rich` Tree object representing the model's operation hierarchy. |

Notes

- Configuration parameters influence rendering behavior; you can access them directly via `metered_model.<param_name>`:
    - `tree_fold_repeat`: controls whether a repeated block is rendered as a single block. Defaults to True.
    - `tree_levels_args`: customizes rendering at different tree levels.
    - `tree_repeat_block_args`: detailed parameters controlling the rendering of the repeat blocks.
- Caching mechanism: reuses the cached render result if all configuration parameters remain unchanged since the last render.
- For information on how repeated blocks are identified and rendered, please refer to the description of the `Meter.tree_fold_repeat` property.
Example

```python
from rich import print
from torchmeter import Meter
from torchvision import models

underlying_model = models.vit_b_16()
model = Meter(underlying_model)

# use the default configuration
print(model.structure)

# reaccess the structure; the cached result is returned quickly
print(model.structure)

# use a custom configuration
model.tree_fold_repeat = False
model.tree_levels_args = {"default": {"guide_style": "red"}}
print(model.structure)
```
subnodes (property)

Retrieves a list of all nodes in the operation tree with their IDs and names.

This property returns a formatted list of all nodes in the operation tree, where each node is represented by its ID and name. This is useful for identifying specific nodes when rebasing or inspecting the tree structure.

Returns:

| Type | Description |
|---|---|
| `List[str]` | A list of strings, each combining a node's ID and its name. |
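Typically used to look up the node ID to pass to `rebase`:

```python
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())

for node in model.subnodes:
    print(node)  # each entry pairs a node ID with the node's name

rebased = model.rebase(node_id="5")  # pick an ID seen above
```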
table_column_args (writable property)

Gets the column rendering configuration for rendered tables.

This property directly binds to `torchmeter.display.TabularRenderer.col_args` to get the column-level rendering configuration (e.g., style, justify) for tables generated via `Meter.profile()`. The configuration persists across all subsequent table renderings until explicitly modified.

Returns:

| Name | Type | Description |
|---|---|---|
| FlagNameSpace | `FlagNameSpace` | A namespace containing concrete configuration names. Accessible keys match the valid column arguments of the table renderer. |
table_display_args (writable property)

Gets the comprehensive rendering configuration for rendered tables.

This property directly binds to `torchmeter.display.TabularRenderer.tb_args` to get the rendering configuration (e.g., style, highlight) for tables generated via `Meter.profile()`. The configuration persists across all subsequent table renderings until explicitly modified.

Returns:

| Name | Type | Description |
|---|---|---|
| FlagNameSpace | `FlagNameSpace` | A namespace containing concrete configuration names. Accessible keys match the valid table arguments of the table renderer. |
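A sketch of both update styles for these two namespaces. The key names (`justify`, `style`, `highlight`) are taken from the "e.g." hints above and are assumptions about the full key set:

```python
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())

# namespace assignment: tweak a single setting in place
model.table_column_args.justify = "center"
model.table_display_args.highlight = True

# dictionary batch update, as advertised under "Fine-Grained Customization"
model.table_display_args = {"style": "cyan", "highlight": True}

model.profile("param")  # subsequent renders pick up the new configuration
```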
tree_fold_repeat (writable property)

Controls whether repeated tree blocks are rendered as collapsed panels.

This property directly binds to the `tree_fold_repeat` setting in the global configuration. When enabled, repeated operation blocks are collapsed into a single panel during tree rendering via `Meter.structure`.

Returns:

| Name | Type | Description |
|---|---|---|
| bool | `bool` | True to collapse repeated blocks, False to expand them. |

Note

- Repeated blocks are identified only when two operations exhibit structural equivalence in:
    - their own parameter signatures;
    - their child operations' hierarchical parameters;
    - the execution order within the operation, if it is a container.
- The folding feature activates exclusively for such validated repetitive patterns. All other structures render sequentially following their topological order.
- If your model has no such repeated blocks (like `AlexNet`), setting this property to True or False won't affect the output.
tree_levels_args (writable property)

Gets the rendering configuration for the various levels of the rendered tree structure.

This property directly binds to `torchmeter.display.TreeRenderer.tree_levels_args` to get the rendering configuration (e.g., label, guide_style) for the various levels of the tree structure generated via the `Meter.structure` property. The configuration persists across all subsequent tree renderings until explicitly modified.

Returns:

| Name | Type | Description |
|---|---|---|
| FlagNameSpace | `FlagNameSpace` | A nested namespace whose outer-layer keys are the specific tree levels and whose values are the configuration namespaces for the corresponding levels. In each configuration namespace, the keys are the concrete configuration names, which match the valid rendering arguments for a tree level. |
tree_repeat_block_args (writable property)

Gets the rendering configuration for repeated blocks of the rendered tree structure.

This property directly binds to `torchmeter.display.TreeRenderer.repeat_block_args` to get the rendering configuration (e.g., `style`, `highlight`) for repeated blocks of the tree structure generated via the `Meter.structure` property. The configuration persists across all subsequent tree renderings until explicitly modified.

Returns:

| Name | Type | Description |
|---|---|---|
| FlagNameSpace | `FlagNameSpace` | A namespace containing concrete configuration names. Accessible keys match the valid arguments for rendering a repeat-block panel. |
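A sketch combining the three tree-related settings. The `default` level key and `guide_style` argument come from the `structure` example above; nested attribute access on the namespace and the `style` key for repeat blocks follow the "e.g." hint and are assumptions:

```python
from rich import print
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet50())

model.tree_fold_repeat = True                       # collapse repeated bottleneck blocks
model.tree_levels_args.default.guide_style = "red"  # namespace-style update (assumed access path)
model.tree_repeat_block_args = {"style": "dim"}     # dictionary batch update

print(model.structure)  # re-rendered with the new configuration
```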
__call__

Execute model inference while maintaining input and model device synchronization.

This method performs three key operations in order:

1. Captures the input arguments of the underlying model's `forward` method for measurement purposes
2. Aligns the device of the input tensors with the device the model is on
3. Executes the underlying model's feed-forward inference

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | | Positional arguments of the underlying model's `forward` method. | `()` |
| `**kwargs` | | Keyword arguments of the underlying model's `forward` method. | `{}` |
Returns:

| Name | Type | Description |
|---|---|---|
| Any | `Any` | Output from the underlying model's feed-forward inference. |

Raises:

| Type | Description |
|---|---|
| `RuntimeError` | If input data is needed but not provided (triggered by the `_ipt2device()` method). |
Notes

- From a macroscopic perspective, this is equivalent to direct model invocation: `meter_instance(input)` == `model(input)`.
- You can safely pass tensors from different devices; automatic synchronization is handled:
    - moves all tensors in the input to the current device via `_ipt2device()`;
    - ensures the model is on the current device before execution.
- Subsequent calls perform two key operations:
    1. Overwrite the captured inputs, enabling `ipt` updates through normal model invocation
    2. Clear cached measurements when the input differs (determined by the `Meter.__is_ipt_changed()` rules)
- If the input contains tensor data, its dimensions might directly impact the measurement results of multiple statistics (e.g., `cal`, `mem`, `ittp`). For consistent and comparable results, we recommend using a single sample for measuring all statistics. This can be achieved by passing in a single batch of sample data at any time.
Example

```python
import torch
import torch.nn as nn
from torchmeter import Meter

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(3, 10, 3)

    def forward(self, x, y=1):
        return self.conv(x) + y

underlying_model = MyModel()
model = Meter(underlying_model, device="cuda:0")

# Standard invocation
output = model(torch.randn(1, 3, 224, 224))

# Mixed argument types
output = model(torch.randn(1, 3, 224, 224), y=2)
```
Source code in torchmeter/core.py

````python
def __call__(self, *args, **kwargs) -> Any:
"""Execute model inference while maintaining input and model device synchronization.
This method performs three key operations in order:
1. Captures input arguments of underlying model's `forward` method for measurement purposes
2. Align the device where the input tensor is located with the device where the model on.
3. Executes the underlying model's feed-forward inference
Args:
*args: Positional arguments of the underlying model's `forward` method
**kwargs: Keyword arguments of the underlying model's `forward` method
Returns:
Any: Output from the underlying model's feed-forward inference
Raises:
RuntimeError: If input data is needed but not provided.
(triggered by `_ipt2device()` method).
Notes:
- From a macroscopic perspective, this is equivalent to direct model invocation:
`meter_instance(input)` == `model(input)`
- You can safely input tensors from different devices; automatic synchronization is handled:
- Moves all tensors in the input to current device via `_ipt2device()`
- Ensures model is on current device before execution
- Subsequent calls perform two key operations:
1. Overwrite captured inputs, enabling `ipt` updates through normal model invocation
2. Clear cached measurements when input differs (determined by `Meter.__is_ipt_changed()` rules)
- If there exists tensor data, its dimensions might directly impact the measurement results
of multiple statistics (e.g. `cal`, `mem`, `ittp`). For consistent and comparable results,
we recommend using **a single sample** for measuring all statistics. This can be achieved
by passing in a single batch of sample data whenever you want.
Example:
```python
import torch
import torch.nn as nn
from torchmeter import Meter
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv = nn.Conv2d(3, 10, 3)
def forward(self, x, y=1):
return self.conv(x) + y
underlying_model = MyModel()
model = Meter(underlying_model, device="cuda:0")
# Standard invocation
output = model(torch.randn(1, 3, 224, 224))
# Mixed argument types
output = model(torch.randn(1, 3, 224, 224), y=2)
```
"""
new_ipt: IPT_TYPE = {"args": args, "kwargs": kwargs}
if self.__is_ipt_changed(new_ipt):
self.__measure_param = False
self.__measure_cal = False
self.__measure_mem = False
self._ipt = new_ipt
self._ipt2device()
self.model.to(self.device)
    return self.model(*self._ipt["args"], **self._ipt["kwargs"])
````
__delattr__

Try to delete attributes from the Meter instance first, falling back to the underlying model if needed.

This method ensures:

1. Class attributes defined in `Meter` cannot be deleted
2. Instance attributes are deleted along with corresponding model attributes (if they exist)
3. Attributes prefixed with "ORIGIN_" delete the actual model attribute after removing the prefix

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Name of the attribute to delete. | required |

Raises:

| Type | Description |
|---|---|
| `AttributeError` | When trying to delete Meter's class attributes, when attempting to delete non-existent attributes, or when deletion fails for both the Meter instance and the underlying model. |

Notes

- When encountering conflicting attribute names between the Meter instance and the model, the Meter instance's attribute is prioritized for deletion by default. To delete the underlying model's attribute with the same name, prepend the "ORIGIN_" prefix. Example: `del meter_instance.ORIGIN_param` deletes the model's `param` attribute.
- This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement for the underlying model without requiring code modifications.
Source code in torchmeter/core.py

```python
def __delattr__(self, name: str) -> None:
"""Try to delete attributes from Meter instance first, fall back to underlying model if needed.
This method ensures:
1. Class attributes defined in `Meter` cannot be deleted
2. Instance attributes are deleted along with corresponding model attributes (if exists)
3. Attributes prefixed with "ORIGIN_" will delete the actual model attribute after removing prefix
Args:
name (str): Name of the attribute to delete
Raises:
AttributeError:
- When trying to delete Meter's class attributes
- When attempting to delete non-existent attributes
- When failed to delete attribute from both Meter instance and the underlying model
Notes:
- When encountering conflicting attribute names between Meter instance and the model:
The Meter instance's attribute will be prioritized for deletion by default.
To delete the underlying model's attribute with same name, prepend "ORIGIN_" prefix.
Example: `del meter_instance.ORIGIN_param` will delete model's `param` attribute
- This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
for the underlying model without requiring code modifications
"""
cls_attrs: Dict[str, bool] = self.__get_clsattr_with_settable_flag()
if name in cls_attrs:
raise AttributeError(f"`{name}` could never be deleted.")
try:
# delete the property with same name defined in Meter from origin model
if name.startswith("ORIGIN_"):
name = name[7:]
raise AttributeError
super().__delattr__(name)
except AttributeError:
        delattr(self.model, name)
```
__device_detect

Detects the device where the model is located via the model's parameters.

This method detects the model's device by checking its parameters' location. If no parameters are found, it raises a warning and moves the model to CPU by default.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model whose device is to be detected. | required |

Returns:

| Type | Description |
|---|---|
| `Union[str, device]` | The device where the model's parameters are located. If no parameters are found, returns 'cpu' as the default device. |

Raises:

| Type | Description |
|---|---|
| `UserWarning` | If the model has no parameters, a warning is issued indicating that the model will be moved to CPU for subsequent analysis. |
Source code in torchmeter/core.py

```python
def __device_detect(self, model: nn.Module) -> Union[str, tc_device]:
"""Detects the device where the model are located via model's parameters.
This method detects the model's device by checking its parameters' location.
If no parameters are found, it will raise a warning and move the model to CPU by default.
Args:
model (nn.Module): The model whose device is to be detected.
Returns:
Union[str, torch.device]: The device where the model's parameters are located. If no parameters are found,
returns 'cpu' as the default device.
Raises:
UserWarning: If the model has no parameters, a warning is issued indicating that the model will be moved
to CPU for subsequent analysis.
"""
import warnings
try:
model_first_param = next(model.parameters())
return model_first_param.device
except StopIteration:
warnings.warn(
category=UserWarning,
message="We can't detect the device where your model is located because no parameter was found "
+ "in your model. We'll move your model to CPU and do all subsequent analysis based on this CPU "
+ "version. If this isn't what you want, change the device with `to` method, "
+ "e.g. `metered_model.to('cuda')`.",
)
return "cpu"
__get_clsattr_with_settable_flag (classmethod)

Determines which class attributes have setter methods defined.

This method iterates over all properties of the class and checks whether a setter method is defined for each property. It returns a dictionary mapping attribute names to a boolean indicating whether the attribute is settable.

Returns:

| Type | Description |
|---|---|
| `Dict[str, bool]` | A dictionary where keys are attribute names and values indicate whether the attribute has a setter method (True if settable, False otherwise). |
Source code in torchmeter/core.py

```python
@classmethod
def __get_clsattr_with_settable_flag(cls) -> Dict[str, bool]:
"""Determines which class attributes have setter methods defined.
This method iterates over all properties of the class and checks if a setter method
is defined for each property. It returns a dictionary mapping attribute names to a
boolean indicating whether the attribute is settable.
Returns:
Dict[str, bool]: A dictionary where keys are attribute names and values indicate
whether the attribute has a setter method (True if settable, False otherwise).
"""
return {k: v.fset is not None for k, v in cls.__dict__.items()
            if isinstance(v, property)}  # fmt: skip
```
__getattr__

Transparently proxy attribute access to the underlying model when not found in the Meter instance.

This method enables seamless attribute access to the wrapped model while maintaining Meter's own attributes. It follows these resolution rules:

1. Directly returns Meter's own attributes if they exist
2. For attributes prefixed with "ORIGIN_", returns the underlying model's attribute with the prefix removed
3. Otherwise, returns the corresponding attribute from the wrapped model

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Name of the attribute to retrieve. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| Any | `Any` | The value of the requested attribute from either the Meter instance or the underlying model. |

Raises:

| Type | Description |
|---|---|
| `AttributeError` | When the attribute exists in neither the Meter instance nor the underlying model, or when the "ORIGIN_" prefix is used with an attribute that doesn't exist in the underlying model. |

Notes

- Attribute resolution priority: Meter's own attributes > underlying model's attributes (unless the "ORIGIN_" prefix is used).
- To bypass Meter's attributes and directly access the model's attribute with the same name, use the "ORIGIN_" prefix (e.g., `meter.ORIGIN_param` maps to `model.param`).
- This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement for the underlying model without requiring code modifications.
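A short sketch of the resolution rules:

```python
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())

print(model.fc)     # `fc` is not a Meter attribute -> proxied to the wrapped ResNet
print(model.param)  # `param` is a Meter property -> Meter wins the name conflict

# To reach a model attribute shadowed by a Meter name, use the "ORIGIN_" prefix:
# `model.ORIGIN_param` would look up `param` on the wrapped model instead
# (and raise AttributeError here, since ResNet defines no `param` attribute).
```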
Source code in torchmeter/core.py

```python
def __getattr__(self, name: str) -> Any:
"""Transparently proxy attribute access to the underlying model when not found in Meter instance
This method enables seamless attribute access to the wrapped model while maintaining Meter's
own attributes. It follows these resolution rules:
1. Directly returns Meter's own attributes if they exist
2. For attributes prefixed with "ORIGIN_", returns the underlying model's attribute with the prefix removed
3. Otherwise, returns the corresponding attribute from the wrapped model
Args:
name (str): Name of the attribute to retrieve
Returns:
Any: The value of the requested attribute from either Meter instance or underlying model
Raises:
AttributeError:
- When the attribute does not exist in both Meter instance and underlying model
- When using "ORIGIN_" prefix with non-existent attribute in underlying model
Notes:
- Attribute resolution priority:
Meter's own attributes > underlying model's attributes (unless "ORIGIN_" prefix is used)
- To bypass Meter's attributes and directly access model's attributes with same name:
Use "ORIGIN_" prefix (e.g., `meter.ORIGIN_param` maps to `model.param`)
- This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
for the underlying model without requiring code modifications
"""
try:
# get the property with same name defined in Meter from origin model
if name.startswith("ORIGIN_"):
name = name[7:]
raise AttributeError
return super().__getattribute__(name)
except AttributeError:
        return getattr(self.model, name)
```
__is_ipt_changed

Determines whether the new input differs from the currently captured input.

Compares both positional arguments (args) and keyword arguments (kwargs) between the current and new input:

- For Tensor arguments: checks shape and dtype equivalence
- For non-Tensor arguments: performs a value equality check
- Verifies argument structure consistency (same length for args, same keys for kwargs)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `new_ipt` | `IPT_TYPE` | New input arguments to compare against the currently stored input. It is a dictionary with two keys: `args`, a tuple containing all positional arguments, and `kwargs`, a dictionary containing all keyword arguments. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| bool | `bool` | `True` if any of the following conditions is met: (1) the current input is empty (first-time input); (2) positional arguments differ in length, value type, or shape (for Tensors); (3) keyword arguments have different keys or values; (4) any argument value differs (non-Tensor) or tensor properties differ (Tensor). |
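This check is what drives the cache invalidation described under `__call__`; for example:

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())

model(torch.randn(1, 3, 224, 224))
_ = model.cal  # measured and cached

model(torch.randn(1, 3, 224, 224))  # same shape/dtype -> input considered unchanged, cache kept
model(torch.randn(2, 3, 224, 224))  # new shape -> cached param/cal/mem results are cleared
```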
Source code in torchmeter/core.py

```python
def __is_ipt_changed(self, new_ipt: IPT_TYPE) -> bool: # noqa: C901
"""Determines if the new input differs from the current captured input.
Compares both positional arguments (args) and keyword arguments (kwargs) between current and new input:
- For Tensor arguments: Checks shape and dtype equivalence
- For non-Tensor arguments: Performs value equality check
- Verifies argument structure consistency (same length for args, same keys for kwargs)
Args:
new_ipt: New input arguments to compare against currently stored input.
It is a dictionary with two keys:
- `args`: A tuple containing all positional arguments.
- `kwargs`: A dictionary containing all keyword arguments.
Returns:
bool: `True` if any of following conditions met:
1. Current input is empty (first-time input)
2. Positional arguments differ in length/value type/shape (for Tensors)
3. Keyword arguments have different keys or values
4. Any argument value differs (non-Tensor) or tensor properties differ (Tensor)
"""
if self._is_ipt_empty():
return True
is_changed = False
# check anonymous arguments
if len(self._ipt["args"]) != len(new_ipt["args"]):
return True
for origin, new in zip(self._ipt["args"], new_ipt["args"]):
if type(origin) is not type(new):
is_changed = True
elif isinstance(origin, Tensor):
is_changed = origin.shape != new.shape or origin.dtype != new.dtype
else:
is_changed = origin != new
if is_changed:
return True
# check named arguments
if set(self._ipt["kwargs"].keys()) != set(new_ipt["kwargs"].keys()):
return True
for k, origin in self._ipt["kwargs"].items():
new = new_ipt["kwargs"][k]
if type(origin) is not type(new):
is_changed = True
elif isinstance(origin, Tensor):
is_changed = origin.shape != new.shape or origin.dtype != new.dtype
else:
is_changed = origin != new
if is_changed:
return True
    return False
```
__setattr__

Prioritize setting attributes on the Meter instance first, falling back to the underlying model.

This method ensures:

1. Non-modifiable class attributes defined in `Meter` are blocked
2. Instance attributes are set first, falling back to model attributes if not present
3. Attributes prefixed with "ORIGIN_" set the underlying model's attribute after removing the prefix

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Name of the attribute to assign. | required |
| `value` | `Any` | Value to be assigned to the attribute. | required |

Raises:

| Type | Description |
|---|---|
| `AttributeError` | When attempting to set non-modifiable Meter class attributes, or when attribute assignment fails for both the `Meter` instance and the underlying model. |

Notes

- When encountering conflicting attribute names between the Meter instance and the model, the Meter instance's attribute is prioritized for assignment by default. To assign to the underlying model's attribute with the same name, prepend the "ORIGIN_" prefix. Example: `meter_instance.ORIGIN_param = 1` sets the model's `param` attribute to 1.
- This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement for the underlying model without requiring code modifications.
- Non-modifiable Meter class attributes are attributes defined by `@property` but without a setter, namely: `ipt`, `structure`, `param`, `cal`, `mem`, `ittp`, `model_info`, `subnodes`.
Source code in torchmeter/core.py

```python
def __setattr__(self, name: str, value: Any) -> None:
"""Prioritize setting attributes on Meter instance first, falling back to the underlying model.
This method ensures:
1. Class attributes defined in `Meter` that cannot be modified are blocked
2. Instance attributes are set first, falling back to model attributes if not present
3. Attributes prefixed with "ORIGIN_" will set the underlying model attribute after removing prefix
Args:
name (str): Name of the attribute to assign
value (Any): Value to be assigned to the attribute
Raises:
AttributeError:
- When attempting to set non-modifiable Meter class attributes.
- When attribute assignment fails for both `Meter` instance and the underlying model
Notes:
- When encountering conflicting attribute names between Meter instance and the model:
The Meter instance's attribute will be prioritized for assignment by default.
To assign the underlying model's attribute with same name, prepend "ORIGIN_" prefix.
Example: `meter_instance.ORIGIN_param = 1` will set model's `param` attribute to 1
- This implementation ensures the Meter instance can be seamlessly used as a drop-in replacement
for the underlying model without requiring code modifications
- Non-modifiable Meter class attributes are attributes defined by `@property` but without a setter.
include:
1. `ipt`
2. `structure`
3. `param`
4. `cal`
5. `mem`
6. `ittp`
7. `model_info`
8. `subnodes`
"""
cls_attrs: Dict[str, bool] = self.__get_clsattr_with_settable_flag()
notchange_cls_attrs = [k for k, v in cls_attrs.items() if not v]
if name in notchange_cls_attrs:
raise AttributeError(f"`{name}` could never be set.")
try:
# set the property with same name defined in Meter from origin model
if name.startswith("ORIGIN_"):
name = name[7:]
raise AttributeError
super().__setattr__(name, value)
except AttributeError:
        setattr(self.model, name, value)
```
overview

Generates an overview of all statistics in a formatted layout.

This method creates a visual overview of model statistics, including basic model information and the core data of each specified statistic. You can customize which statistics appear in the rendered result, and their order, by passing in the statistics you want in the order you prefer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*order` | `str` | The names of the statistics to include in the overview. If not provided, all available statistics are included. | `()` |
| `show_warning` | `bool` | Whether to display warnings for potentially inaccurate results. | `True` |

Returns:

| Name | Type | Description |
|---|---|---|
| Columns | `Columns` | A `rich.Columns` object containing the formatted overview. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If any of the provided statistics names are invalid. |
Example

```python
from torch import randn
from torchmeter import Meter
from torchvision import models

underlying_model = models.resnet18()
model = Meter(underlying_model)
model(randn(1, 3, 224, 224))

# overview all statistics (i.e. param, cal, mem, ittp)
model.overview()

# only overview `cal` and `param`, in that order
model.overview("cal", "param")
```
Source code in torchmeter/core.py

````python
def overview(self, *order: str, show_warning: bool = True) -> Columns:
"""Generates an overview of all statistics in a formatted layout.
This method creates a visual overview of model statistics, including basic model
information and core data of each specified statistic. You can customize the statistics
contained in the rendering results and their order by passing in the statistics you want
in the order you prefer.
Args:
*order (str): The names of the statistics to include in the overview. If not provided,
all available statistics are included.
show_warning (bool): Whether to display warnings for potentially inaccurate results.
Defaults to True.
Returns:
Columns: A `rich.Columns` object containing the formatted overview.
Raises:
ValueError: If any of the provided statistics names are invalid.
Example:
```python
from torch import randn
from torchmeter import Meter
from torchvision import models
underlying_model = models.resnet18()
model = Meter(underlying_model)
model(randn(1, 3, 224, 224))
# overview all statistics (i.e. param, cal, mem, ittp)
model.overview()
# only overview `cal` and `param`
# and the order is `cal` then `param`
model.overview("cal", "param")
```
"""
from functools import partial
from rich.box import HORIZONTALS
from rich.panel import Panel
order = order or self.optree.root.statistics
invalid_stat = tuple(filter(lambda x: x not in self.optree.root.statistics, order))
if len(invalid_stat) > 0:
raise ValueError(f"Invalid statistics: {invalid_stat}")
container = Columns(expand=True, align="center")
format_cell = partial(Panel, safe_box=True, expand=False, highlight=True, box=HORIZONTALS)
container.add_renderable(format_cell(self.model_info, title="[b]Model INFO[/]", border_style="orange1"))
container.renderables.extend([
format_cell(
self.stat_info(stat_name, show_warning=show_warning),
title=f"[b]{stat_name.capitalize()} INFO[/]",
border_style="cyan",
)
for stat_name in order
])
    return container
````
profile

```python
profile(stat_name: str, show: bool = True, no_tree: bool = False, **tb_kwargs) -> Tuple[Table, DataFrame]
```

Render a tabular report of the specified statistics with rich visualization.

This method generates an interactive table visualization for the given statistical data, optionally combined with the model's operation tree structure. The rendering supports real-time customization through keyword arguments and can export data to multiple formats.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `stat_name` | `str` | Name of the statistics to profile (i.e., 'param', 'cal', 'mem', 'ittp'). | required |
| `show` | `bool` | Whether to immediately render the visualization and display it in the terminal. | `True` |
| `no_tree` | `bool` | Do not display the rendered tree when set to True. | `False` |
| `**tb_kwargs` | | Additional table customization options: `raw_data`, `pick_cols`, `exclude_cols`, `custom_cols`, `keep_custom_name`, `newcol_name`, `newcol_func`, `newcol_type`, `newcol_idx`, `keep_new_col`, `save_to`, `save_format`. See the docstring below for the meaning and default of each option. | `{}` |

Returns:

| Type | Description |
|---|---|
| `Tuple[Table, DataFrame]` | The rendered `rich.table.Table` object and the underlying polars DataFrame. |

Raises:

| Type | Description |
|---|---|
| `RuntimeWarning` | If your model has some modules defined but not explicitly called. |
| `AttributeError` | If `stat_name` is not a valid statistics name. |
| `ValueError` | If the horizontal gap in the global config is negative, a nonexistent column name is passed to `pick_cols`, `save_to`/`save_format` specify an unsupported export target, or `newcol_name` already exists in the backend dataframe. |
| `RuntimeError` | If the terminal is too narrow for display, if no input has been provided while `stat_name` is one of `cal`/`mem`/`ittp`, if no module is called, if the model is empty or is a single layer, or if `newcol_func` returns values whose length mismatches the dataframe's row count. |
| `TypeError` | If any argument is of an invalid type (see the docstring below for the complete per-argument list). |
Notes

1. Ensure at least one forward pass has been executed before profiling `cal`/`mem`/`ittp` statistics, to guarantee valid input capture.
2. Table and tree rendering styles can be preconfigured through the properties: `table_display_args`, `table_column_args`, `tree_fold_repeat`, `tree_levels_args`, `tree_repeat_block_args`.
3. The rendering result is progressively displayed line by line with a time interval. You can configure this interval as follows (the value must be non-negative):

    ```python
    from torchmeter import get_config

    cfg = get_config()
    cfg.render_interval = 0.5  # unit: seconds, must be non-negative
    ```

4. Disable rendering (`show=False`) when only exporting data, to reduce computational overhead.
5. Enable `no_tree` to focus on the data alone or to hide the operation tree on narrow terminals. The default layout shows the tree on the left and the table on the right (row separators may be enforced if space is constrained).
6. The horizontal spacing between tree and table is controlled by `combine.horizon_gap` in the global config. You should ensure the value is non-negative.
7. When `raw_data=True`, unformatted values are displayed:
    - `param`: parameter counts
    - `cal`: FLOPs/MACs counts
    - `mem`: bytes consumed
    - `ittp`: median inference time (seconds) and inferences per second, per module
8. Column management:
    - Use `pick_cols` to reorder columns (validate column names via `metered_instance.table_cols(stat_name)`).
    - Processing order: `pick_cols` -> `exclude_cols` -> `custom_cols` -> `newcol`.
    - Conflicts: picked columns override custom/newcol names; exclusions override picks.
9. About `newcol_func`:
    - It must have exactly 1 formal parameter (name irrelevant) that receives the underlying `polars.DataFrame` of the specified statistics.
    - Implement logic using the incoming dataframe to return new column values (must be 1D array-like data such as `Series`, `lists`, `ndarrays`, etc.). Note that you can use the `val` property to access the raw data for all statistics (for `ittp`, the return is a tuple of the median and IQR of the measurement data sequence).
    - The example below demonstrates adding a percentage column to the `cal` statistics. Refer to https://docs.pola.rs/api/python/stable/reference/dataframe/index.html for how to use a `polars.DataFrame`.
10. The `newcol_idx` parameter mostly follows Python list insertion semantics:
    - negative values count backward from the end (`-1` = append);
    - `0` inserts at the beginning;
    - values exceeding the column count clamp to the nearest valid position: negative values whose absolute value exceeds the column count insert at the start, positive values beyond the column count append at the end.
11. Session persistence: `keep_new_col` retains created columns; `keep_custom_name` preserves renamed columns.
12. Export paths: directory paths require an explicit `save_format`; file paths auto-detect the format from the extension unless `save_format` overrides it.
Example

```python
import torch
from torchmeter import Meter
from torchvision import models

# wrap your model with Meter
underlying_model = models.alexnet()
model = Meter(underlying_model)

# execute a forward inference (necessary, to provide input data)
input = torch.randn(1, 3, 224, 224)
model(input)

# check column names of the cal table
print(model.table_cols("cal"))
# ('Operation_Id', 'Operation_Name', 'Operation_Type', 'Kernel_Size', 'Bias',
#  'Input', 'Output', 'MACs', 'FLOPs')

def newcol_logic(df):
    flops_col = df["FLOPs"]
    return flops_col.map_elements(lambda x: f"{100 * x / model.cal.Flops:.4f} %")

# Customized profile with column operations
model.profile(
    "cal",
    # render and display immediately
    show=True,
    no_tree=True,
    raw_data=False,
    # columns management
    exclude_cols=["Kernel_Size", "Bias"],
    custom_cols={"Operation_Id": "ID", "Operation_Name": "Module Name", "Operation_Type": "Module Type"},
    newcol_name="Percentage",
    newcol_func=newcol_logic,
    newcol_type=str,
    newcol_idx=-1,
    # export
    save_to="./cal_profile.xlsx",
    save_format="xlsx",
)
```
Source code in torchmeter/core.py

````python
def profile(self, stat_name: str, show: bool = True, no_tree: bool = False, **tb_kwargs) -> Tuple[Table, DataFrame]:
"""Render a tabular report of the specified statistics with rich visualization.
This method generates an interactive table visualization for the given statistical data,
optionally combined with the model's operation tree structure. The rendering supports
real-time customization through keyword arguments and can export data to multiple formats.
Args:
stat_name (str): Name of the statistics to profile (i.e., 'param', 'cal', 'mem', 'ittp').
show (bool, optional): Whether to immediately render the visualization and display in terminal.
Defaults to True.
no_tree (bool, optional): Not to display the rendered tree when set to True. Defaults to False.
**tb_kwargs: Additional table customization options:
- raw_data (`bool`): Use raw numerical data instead of formatted values with unit.
Defaults to `False`.
- pick_cols (`Sequence[str]`): Whitelist of columns to display. Defaults to `[]`.
- exclude_cols (`Sequence[str]`): Blacklist of columns to hide. Defaults to `[]`.
- custom_cols (`Dict[str, str]`): Column rename mappings (original: new). Defaults to `{}`.
- keep_custom_name (`bool`): Whether to keep custom names after this call. Defaults to `False`.
- newcol_name (`str`): Name for new computed column. Defaults to `''`.
- newcol_func (`Callable[[DataFrame], ArrayLike]`): Function to compute new column values.
Defaults to `lambda df: [None]*len(df)`.
- newcol_type (`Optional[PolarsDataType]`): Explicit data type for new column. Defaults to `None`.
- newcol_idx (`int`): Insertion position for new column (-1=append). Defaults to `-1`.
- keep_new_col (`bool`): Retain new columns in backend dataframe and subsequent renders.
Defaults to `False`.
- save_to (`Optional[str]`): File path for data export, not None to trigger export. Defaults to `None`.
- save_format (`Optional[str]`): Export format, `None` to use the value in `save_to`.
Now we only support `csv` or `xlsx` file. Defaults to `None`.
Returns:
The rendered `rich.table.Table` object and underlying polars DataFrame.
Raises:
RuntimeWarning: If your model has some modules defined but not explicitly called.
AttributeError: If `stat_name` is not a valid statistics name.
ValueError:
- If horizontal gap defined in global config is negative when disable `no_tree`.
- If you specify any not existing column name to `pick_cols` when enable `show` and `pick_cols`.
- If you pass in a directory path as `save_to` but not specify `save_format`.
- If you pass in a non csv or xlsx file path as `save_to`.
- If you pass in a non-supported export format as `save_format`.
- If `newcol_name` already exists in the backend dataframe.
RuntimeError:
- If terminal width is insufficient for display when enable `show`.
- If no input data has been provided (i.e., `ipt` property is empty) and `stat_name` is
one of `cal`, `mem`, or `ittp`.
- If no module is called (e.g. the underlying model's `forward` method is not empty).
- If the whole model is empty and has no sublayers.
- If using a single layer as a model
- If `newcol_func` returns values with length mismatch to the underlying dataframe's row count
TypeError:
- If `stat_name` is not a string
- If `pick_cols` is not a list, tuple or set.
- If `exclude_cols` is not a list, tuple or set.
- If `custom_cols` is not a dict.
- If `newcol_name` is not a string.
- If `newcol_func` is uncallable
- If `newcol_func` doesn't have exactly **1** formal parameter
- If return value of `newcol_func` is not array-like
- If `newcol_idx` is not an integer.
- If `newcol_type` is not a valid Polars data type.
- If `save_to` is not a string, neither None.
- If `save_format` is not a string, neither None.
Notes:
1. Ensure at least one forward pass has been executed before accessing `cal`/`mem`/`ittp` statistics to
guarantee valid input capture.
2. Table and tree rendering styles can be preconfigured through the properties:
- `table_display_args`
- `table_column_args`
                - `tree_fold_repeat`
- `tree_levels_args`
- `tree_repeat_block_args`
3. The rendering result will be progressively displayed line-by-line with a time interval.
You can configure this interval through the following steps (must be non-negative):
```python
from torchmeter import get_config
cfg = get_config()
cfg.render_interval = 0.5 # unit second, should be non-negative
```
4. Disable rendering (`show=False`) when only exporting data to reduce computational overhead.
5. Enable `no_tree` to:
- Enable focused data analysis
- Hide the operation tree for narrow terminals
- Default layout shows tree left + table right (may enforce row separators if space constrained)
6. Horizontal spacing between tree and table is controlled by `combine.horizon_gap` in global config.
You should promise the value is non-negative.
7. When `raw_data=True` displays unformatted values:
- `param`: Parameter counts
- `cal`: FLOPs/MACs counts
- `mem`: Bytes consumed
- `ittp`: Median inference time (seconds) and inferences per second per module
8. Column management:
- Use `pick_cols` to reorder columns (validate column names via
`metered_instance.table_cols(stat_name)`)
- Processing order: `pick_cols` -> `exclude_cols` -> `custom_cols` -> `newcol`
- Conflicts:
- picked columns override custom/newcol names
- exclusions override picks
9. About `newcol_func`:
- must have exactly **1** formal parameter (name irrelevant) that will receive the
underlying `polars.DataFrame` of specified statistics.
- Implement logic using the incoming dataframe to return new column values (must be 1D array-like data
such as `Series`, `lists`, `ndarrays`, etc.). Note that you can use `val` property to access the
raw data for all statistics (for `ittp`, the return will be a tuple made up of the median and iqr of
the measurement data sequence).
- The example below demonstrates adding a percentage column of the `cal` statistics. Refer to
https://docs.pola.rs/api/python/stable/reference/dataframe/index.html for using `polars.Dataframe`.
10. The `newcol_idx` parameter mostly follows Python list insertion semantics:
- Negative values count backward from end (`-1`=`append`)
- `0` inserts at beginning
- Values exceeding column count clamp to nearest valid position:
* Negative abs values > column count → insert at start
* Positive values > column count → append at end
            11. Session persistence:
- `keep_new_col` retains created columns
- `keep_custom_name` preserves renamed columns
            12. Export paths:
- Directory paths require explicit `save_format`
- File paths auto-detect format from extension unless `save_format` overrides
Example:
```python
import torch
from torchmeter import Meter
from torchvision import models
# wrap your model with Meter
underlying_model = models.alexnet()
model = Meter(underlying_model)
# execute a forward inference (necessary, to provide input data)
input = torch.randn(1, 3, 224, 224)
model(input)
            # check column names of the cal table
print(model.table_cols("cal"))
# ('Operation_Id', 'Operation_Name', 'Operation_Type', 'Kernel_Size', 'Bias',
# 'Input', 'Output', 'MACs', 'FLOPs')
def newcol_logic(df):
flops_col = df["FLOPs"]
                return flops_col.map_elements(lambda x: f"{100 * x / model.cal.Flops:.4f} %")
# Customized profile with column operations
model.profile(
"cal",
# render and display immediately
show=True,
no_tree=True,
raw_data=False,
# columns management
exclude_cols=["Kernel_Size", "Bias"],
custom_cols={"Operation_Id": "ID", "Operation_Name": "Module Name", "Operation_Type": "Module Type"},
newcol_name="Percentage",
newcol_func=newcol_logic,
newcol_type=str,
newcol_idx=-1,
# export
save_to="./cal_profile.xlsx",
save_format="xlsx",
)
```
"""
from rich.rule import Rule
from rich.layout import Layout
# the horizontal gap between tree and table
TREE_TABLE_GAP = __cfg__.combine.horizon_gap
if not isinstance(stat_name, str):
raise TypeError(f"stat_name must be a string, but got `{type(stat_name).__name__}`.")
if TREE_TABLE_GAP < 0:
raise ValueError(
"The gap between the rendered tree and the rendered table should be non-negative, "
+ f"but got `{TREE_TABLE_GAP}`."
)
stat = getattr(self, stat_name)
tb, data = self.table_renderer(stat_name=stat_name, **tb_kwargs)
if not show:
return tb, data
tree = None if no_tree else self.structure
console = get_console()
tree_width = console.measure(tree).maximum if not no_tree else 0 # type: ignore
desirable_tb_width = console.measure(tb).maximum
actual_tb_width = min(desirable_tb_width, console.width - tree_width - TREE_TABLE_GAP)
if actual_tb_width <= 5: # 5 is the minimum width of table
raise RuntimeError(
"The width of the terminal is too small, try to maximize the window or "
+ "set a smaller `horizon_gap` value in config and try again."
)
# when some cells in the table overflow, show a separating line between rows
if actual_tb_width < desirable_tb_width:
tb.show_lines = True
# get main content(i.e. tree & statistics table)
if no_tree:
main_content: Union[Table, Layout] = tb
tree_height = 0
else:
main_content = Layout()
main_content.split_row(
Layout(tree, name="left", size=tree_width + TREE_TABLE_GAP),
Layout(tb, name="right", size=actual_tb_width),
)
tree_height = len(console.render_lines(tree)) # type: ignore
temp_options = console.options.update_width(actual_tb_width)
tb_height = len(console.render_lines(tb, options=temp_options))
main_content_height = max(tree_height, tb_height)
main_content_width = tree_width + actual_tb_width + (0 if no_tree else TREE_TABLE_GAP)
# get footer content
footer = Columns(
title=Rule("[gray54]s u m m a r y[/]", characters="-", style="gray54"), # type: ignore
padding=(1, 1),
equal=True,
expand=True,
)
model_info = self.model_info
stat_info = self.stat_info(stat_or_statname=stat, show_warning=False)
model_info.style = "dim"
stat_info.style = "dim"
footer.add_renderable(model_info)
footer.add_renderable(stat_info)
temp_options = console.options.update_width(main_content_width)
footer_height = len(console.render_lines(footer, options=temp_options))
# render profile
canvas = Layout()
canvas.split_column(
Layout(main_content, name="top", size=main_content_height), Layout(footer, name="down", size=footer_height)
)
origin_width = console.width
origin_height = console.height
console.width = main_content_width
console.height = main_content_height + footer_height
try:
render_perline(renderable=canvas)
finally:
# if the user interrupts the rendering when render_interval > 0
# still restore the console size
console.width = origin_width
console.height = origin_height
return tb, data
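As a supplement to the export-path notes in the docstring above, here is a minimal, hedged sketch of the two export styles. The `./reports/` directory and the file names are illustrative assumptions; the `profile` options used are only those listed in the notes and example above.

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.alexnet())
model(torch.randn(1, 3, 224, 224))  # a forward pass supplies the input data

# directory path: the format must be given explicitly via `save_format`
model.profile("mem", show=False, save_to="./reports/", save_format="csv")

# file path: the format is inferred from the extension (.xlsx -> Excel)
model.profile("mem", show=False, raw_data=True, save_to="./reports/mem_raw.xlsx")
```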
rebase
¶
Rebases the Meter instance to a specific node in the operation tree.
This method allows the Meter instance to focus on a specific node in the operation tree, effectively treating that node as the new root of a new Meter instance. If the provided node ID is "0", the original Meter instance is returned unchanged.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `node_id` | `str` | The ID of the node to rebase to. Must be a valid node ID in the operation tree. | *required* |
Returns:

| Name | Type | Description |
|---|---|---|
| Meter | `Meter` | A new Meter instance with the specified node as the root. |
Raises:

| Type | Description |
|---|---|
| `TypeError` | If `node_id` is not a string. |
| `ValueError` | If `node_id` does not exist in the operation tree. |
Notes:

- Use `Meter(your_model).subnodes` to retrieve a list of valid node IDs.
- If `node_id` is `"0"`, the original Meter instance is returned without modification.
Example
from torchmeter import Meter
from torchvision import models
underlying_model = models.resnet18()
model = Meter(underlying_model)
rebased_model = metered_model.rebase("5")
print(model) # Meter(model=0 ResNet: ResNet, device=cpu)
print(rebased_model) # Meter(model=0 Sequential: Sequential, device=cpu)
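Because valid node IDs depend on the wrapped model, a small sketch (based on the `subnodes` hint in the notes above) is to print the IDs first and then rebase; the ID `"5"` is simply the one used in the example.

```python
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())

# list the valid node IDs of the operation tree
print(model.subnodes)

# rebase onto one of the printed IDs (here "5", as in the example above)
rebased_model = model.rebase("5")
print(rebased_model)
```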
Source code in torchmeter/core.py
def rebase(self, node_id: str) -> Meter:
"""Rebases the Meter instance to a specific node in the operation tree.
This method allows the Meter instance to focus on a specific node in the operation tree,
effectively treating that node as the new root of a new Meter instance. If the provided
node ID is "0", the original Meter instance is returned unchanged.
Args:
node_id (str): The ID of the node to rebase to. Must be a valid node ID in the operation tree.
Returns:
Meter: A new Meter instance with the specified node as the root.
Raises:
TypeError: If `node_id` is not a string.
ValueError: If `node_id` does not exist in the operation tree.
Notes:
- Use `Meter(your_model).subnodes` to retrieve a list of valid node IDs.
- If `node_id` is "0", the original Meter instance is returned without modification.
Example:
```python
from torchmeter import Meter
from torchvision import models
underlying_model = models.resnet18()
model = Meter(underlying_model)
rebased_model = metered_model.rebase("5")
print(model) # Meter(model=0 ResNet: ResNet, device=cpu)
print(rebased_model) # Meter(model=0 Sequential: Sequential, device=cpu)
```
"""
if not isinstance(node_id, str):
raise TypeError(f"node_id must be a string, but got `{type(node_id).__name__}`.")
if node_id == "0":
return self
id_generator = ((node_idx, node.node_id) for node_idx, node in enumerate(self.optree.all_nodes))
for idx, valid_id in id_generator:
if node_id == valid_id:
new_base = self.optree.all_nodes[idx]
return self.__class__(new_base.operation, device=self.device)
else:
raise ValueError(f"Invalid node_id: {node_id}. Use `Meter(your_model).subnodes` to check valid ones.")
stat_info
¶
Generates a formatted summary of the specified statistics.
This method provides a summary of the given statistics, including its name and its crucial data.
However, some modules may be defined but never called explicitly, and some modules' calculation
measurement logic may not be supported in this version. To avoid confusing users, inaccuracy warnings
are shown in the summary in such cases. If you don't want to see these warnings, set `show_warning`
to `False` manually.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `stat_or_statname` | `Union[str, Statistics]` | The name of the statistics or the statistics object itself. | *required* |
| `show_warning` | `bool` | Whether to display warnings about potential inaccuracies. Defaults to True. | `True` |
Returns:

| Name | Type | Description |
|---|---|---|
| Text | `Text` | A `rich.Text` object containing the formatted summary. |
Raises:

| Type | Description |
|---|---|
| `TypeError` | If `stat_or_statname` is neither a string nor a `Statistics` object. |
Notes:

- The main content is obtained from the `crucial_data` property of the statistics object, which is
  defined in the corresponding statistics class.
- For `ittp`, the number of repeated measurements, namely `Benchmark Times`, is additionally displayed.
  This value can be accessed or modified through the `ittp_benchmark_time` attribute.
- `show_warning` is a keyword-only argument, so you must pass it by its keyword name.
- Warnings are only shown for the calculation (`cal`) and memory (`mem`) statistics, because only these
  two are affected by modules that are never called or not yet supported.
Example
from torch import randn
from torchmeter import Meter
from torchvision import models
from rich import print
underlying_model = models.vit_b_16()
model = Meter(underlying_model)
_ = model(randn(1, 3, 224, 224))
# using statistics name
print(model.stat_info("param"))
# using statistics object
cal = model.cal
print(model.stat_info(cal))
# not show warnings
print(model.stat_info("mem", show_warning=False))
Source code in torchmeter/core.py
def stat_info(self, stat_or_statname: Union[str, Statistics], *, show_warning: bool = True) -> Text: # noqa: C901
"""Generates a formatted summary of the specified statistics.
This method provides a summary of the given statistics, including its name and its crucial data.
However, some modules may be defined but never called explicitly, and some modules' calculation
measurement logic may not be supported in this version. To avoid confusing users, inaccuracy
warnings are shown in the summary in such cases. If you don't want to see these warnings, you can
set `show_warning` to `False` manually.
Args:
stat_or_statname (Union[str, Statistics]): The name of the statistics or the statistics object itself.
show_warning (bool): Whether to display warnings about potential inaccuracies. Defaults to True.
Returns:
Text: A `rich.Text` object containing the formatted summary.
Raises:
TypeError: If `stat_or_statname` is neither a string nor a `Statistics` object.
Notes:
- The main content will be obtained from the `crucial_data` property of the statistics object, which is
defined in the corresponding statistics class.
- For `ittp`, the number of repeated measurements, namely `Benchmark Times`, will be additionally
displayed. This value can be accessed or modified through the `ittp_benchmark_time` attribute.
- `show_warning` is a keyword-only argument, so you must pass it by its keyword name.
- Warnings are only shown for the following two statistics: calculation (`cal`) and memory (`mem`), because
only these two statistics are affected by modules that are never called or not yet supported.
Example:
```python
from torch import randn
from torchmeter import Meter
from torchvision import models
from rich import print
underlying_model = models.vit_b_16()
model = Meter(underlying_model)
_ = model(randn(1, 3, 224, 224))
# using statistics name
print(model.stat_info("param"))
# using statistics object
cal = model.cal
print(model.stat_info(cal))
# not show warnings
print(model.stat_info("mem", show_warning=False))
```
"""
if isinstance(stat_or_statname, str):
stat = getattr(self, stat_or_statname)
elif isinstance(stat_or_statname, Statistics):
stat = stat_or_statname
else:
raise TypeError(
f"Invalid type for stat_or_statname: `{type(stat_or_statname).__name__}`. "
+ "Please pass in the statistics name or the statistics object itself."
)
stat_name = stat.name
infos_ls: List[str] = [f"• [b]Statistics:[/b] {stat_name}"]
if stat_name == "ittp":
infos_ls.append(f"• [b]Benchmark Times:[/b] {self.ittp_benchmark_time}")
infos_ls.extend([f"• [b]{k}:[/b] {v}" for k, v in stat.crucial_data.items()])
# warning field, only works when stat is "cal" or "mem"
if show_warning and stat_name in ("cal", "mem"):
# cache for __has_nocall_nodes
if self.__has_nocall_nodes is None:
from operator import attrgetter
crucial_data_getter = attrgetter(f"{stat_name}.crucial_data")
try:
list(map(crucial_data_getter, self.optree.all_nodes))
self.__has_nocall_nodes = False
except RuntimeError:
self.__has_nocall_nodes = True
# cache for __has_not_support_nodes
if stat_name == "cal" and self.__has_not_support_nodes is None:
self.__has_not_support_nodes = any(n.cal.is_not_supported for n in self.optree.all_nodes)
warns_ls = []
if self.__has_nocall_nodes:
warns_ls.append(
" " * 2 + "[dim yellow]:arrow_forward: " + "Some nodes are defined but not called explicitly.[/]"
)
if stat_name == "cal" and self.__has_not_support_nodes:
warns_ls.append(
" " * 2
+ "[dim yellow]:arrow_forward: "
+ "Some modules don't support calculation measurement yet.[/]"
)
if warns_ls:
warns_ls.insert(0, "[dim yellow]:warning: Warning: the result may be inaccurate, cause:[/]")
warns_ls.append(
" " * 2
+ "[dim cyan]:ballot_box_with_check: "
+ f"use `Meter(your_model).profile('{stat_name}')` to see more.[/]"
)
infos_ls.extend(warns_ls)
infos = "\n".join(infos_ls)
console = get_console()
return console.render_str(infos)
table_cols
¶
Get all column names of the backend dataframe for the specified statistics.
This method returns the column names of the backend dataframe associated with the given statistics.
If the dataframe is empty (i.e., `profile` has not been called yet), it falls back to the values of the
`tb_fields` property of the corresponding statistics class.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `stat_name` | `str` | The name of the statistics for which to retrieve the columns. | *required* |
Returns:

| Type | Description |
|---|---|
| `Tuple[str, ...]` | A tuple of column names for the specified statistics backend dataframe. |
Raises:

| Type | Description |
|---|---|
| `TypeError` | If `stat_name` is not a string. |
| `KeyError` | If `stat_name` is not found in the available statistics (i.e. `param`, `cal`, `mem`, `ittp`). |
Default column names for each statistic:
- param: ("Operation_Id", "Operation_Name", "Operation_Type",
"Param_Name", "Requires_Grad", "Numeric_Num")
- cal: ("Operation_Id", "Operation_Name", "Operation_Type",
"Kernel_Size", "Bias", "Input", "Output", "MACs", "FLOPs")
- mem: ("Operation_Id", "Operation_Name", "Operation_Type",
"Param_Cost", "Buffer_Cost", "Output_Cost", "Total")
- ittp: ("Operation_Id", "Operation_Name", "Operation_Type",
"Infer_Time", "Throughput")
Example
from torchmeter import Meter
from torchvision import models
underlying_model = models.resnet18()
model = Meter(underlying_model)
model.table_cols("param")
# ('Operation_Id',
# 'Operation_Name',
# 'Operation_Type',
# 'Param_Name',
# 'Requires_Grad',
# 'Numeric_Num')
model.table_cols("cal")
# ('Operation_Id',
# 'Operation_Name',
# 'Operation_Type',
# 'Kernel_Size',
# 'Bias',
# 'Input',
# 'Output',
# 'MACs',
# 'FLOPs')
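A minimal sketch of using `table_cols` to validate column names before reordering them via `profile`'s `pick_cols` option (mentioned in the column-management notes above); the chosen columns are illustrative.

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())
model(torch.randn(1, 3, 224, 224))  # a forward pass supplies the input data

cols = model.table_cols("cal")
wanted = ["Operation_Name", "MACs", "FLOPs"]

# make sure every requested column actually exists before picking it
assert all(c in cols for c in wanted)
tb, data = model.profile("cal", show=False, pick_cols=wanted)
```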
Source code in torchmeter/core.py
def table_cols(self, stat_name: str) -> Tuple[str, ...]:
"""Get all column names of the backend dataframe for the specified statistics.
This method returns the column names of the backend dataframe associated with the given statistics.
If the dataframe is empty (i.e., `profile` has not been called yet), it falls back to the values of the
`tb_fields` property of the corresponding statistics class.
Args:
stat_name (str): The name of the statistics for which to retrieve the columns.
Returns:
A tuple of column names for the specified statistics backend dataframe.
Raises:
TypeError: If `stat_name` is not a string.
KeyError: If `stat_name` is not found in the available statistics (i.e. `param`, `cal`, `mem`, `ittp`).
Notes:
Default column names for each statistic:
- param: ("Operation_Id", "Operation_Name", "Operation_Type",
"Param_Name", "Requires_Grad", "Numeric_Num")
- cal: ("Operation_Id", "Operation_Name", "Operation_Type",
"Kernel_Size", "Bias", "Input", "Output", "MACs", "FLOPs")
- mem: ("Operation_Id", "Operation_Name", "Operation_Type",
"Param_Cost", "Buffer_Cost", "Output_Cost", "Total")
- ittp: ("Operation_Id", "Operation_Name", "Operation_Type",
"Infer_Time", "Throughput")
Example:
```python
from torchmeter import Meter
from torchvision import models
underlying_model = models.resnet18()
model = Meter(underlying_model)
model.table_cols("param")
# ('Operation_Id',
# 'Operation_Name',
# 'Operation_Type',
# 'Param_Name',
# 'Requires_Grad',
# 'Numeric_Num')
model.table_cols("cal")
# ('Operation_Id',
# 'Operation_Name',
# 'Operation_Type',
# 'Kernel_Size',
# 'Bias',
# 'Input',
# 'Output',
# 'MACs',
# 'FLOPs')
```
"""
if not isinstance(stat_name, str):
raise TypeError(f"stat_name must be a string, but got `{type(stat_name).__name__}`.")
stats_data_dict: Dict[str, DataFrame] = self.table_renderer.stats_data
if stat_name not in stats_data_dict:
raise KeyError(f"Statistics `{stat_name}` not in {tuple(stats_data_dict.keys())}.")
stat_data: DataFrame = stats_data_dict[stat_name]
if stat_data.is_empty():
cols: Tuple[str, ...] = getattr(self.optree.root, stat_name).tb_fields
else:
cols = tuple(stat_data.columns)
return cols
to
¶
Move the model to the specified device while keeping input and model device synchronization.
Simulates the `to` method of a PyTorch model; it moves the model and all tensor data in
`self._ipt` to the specified device.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `new_device` | `Union[str, device]` | Target device name or its corresponding torch.device object. | *required* |
Example
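A minimal usage sketch, mirroring the source below (a CUDA device is assumed to be available for the first call):

```python
import torch
from torchmeter import Meter
from torchvision import models

model = Meter(models.resnet18())

# move the wrapped model (and any captured input tensors) to cuda:0
model.to("cuda:0")

# move back to the CPU using a torch.device object
model.to(torch.device("cpu"))
```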
Source code in torchmeter/core.py
def to(self, new_device: Union[str, tc_device]) -> None:
"""Move the model to the specified device while keeping input and model device synchronization.
Simulate the `to` method of a PyTorch model and use it to move the model and all tensor data in
`self._ipt` to the specified device.
Args:
new_device (Union[str, torch.device]): Target device name or its corresponding torch.device object.
Example:
```python
import torch
from torchmeter import Meter
from torchvision import models
underlying_model = models.resnet18()
model = Meter(underlying_model)
# move to cuda:0
model.to("cuda:0")
# move to cpu
model.to(torch.device("cpu"))
```
"""
self.device = new_device # type: ignore