Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,23 @@ def change_out_bias(
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

def compute_fitting_stat(self, sample_merged) -> None:
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data..

Parameters
----------
sample_merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
"""
self.fitting_net.compute_input_stats(
sample_merged, protection=self.data_stat_protect
)

def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

Expand Down
4 changes: 1 addition & 3 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,7 @@ def wrapped_sampler():
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
)
self.compute_fitting_stat(wrapped_sampler)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def change_out_bias(
merged,
bias_adjust_mode=bias_adjust_mode,
)
if bias_adjust_mode == "set-by-statistic":
self.atomic_model.compute_fitting_stat(merged)

def forward_common_lower(
self,
Expand Down
Loading