diff --git a/sparsebit/quantization/quantizers/__init__.py b/sparsebit/quantization/quantizers/__init__.py
index d070308..f283576 100644
--- a/sparsebit/quantization/quantizers/__init__.py
+++ b/sparsebit/quantization/quantizers/__init__.py
@@ -13,6 +13,7 @@ def register_quantizer(quantizer):
 from . import lsq_plus
 from . import pact
 from . import adaround
+from . import quadapter
 
 
 def build_quantizer(cfg):
diff --git a/sparsebit/quantization/quantizers/adaround.py b/sparsebit/quantization/quantizers/adaround.py
index 20eccad..b3c9567 100644
--- a/sparsebit/quantization/quantizers/adaround.py
+++ b/sparsebit/quantization/quantizers/adaround.py
@@ -23,6 +23,7 @@ def __init__(self, config):
             config.TARGET[0] == QuantTarget.WEIGHT
         ), "AdaRound only supports to quant weights"
         self.zeta, self.gamma = 1.1, -0.1  # stretch-parameters
+        self.reconstruct_qlayer = reconstruct_qlayer
 
     def init_variables(self, x):
         x_floor = torch.floor(x / self.scale)
diff --git a/sparsebit/quantization/quantizers/quadapter.py b/sparsebit/quantization/quantizers/quadapter.py
new file mode 100644
index 0000000..7b66ca5
--- /dev/null
+++ b/sparsebit/quantization/quantizers/quadapter.py
@@ -0,0 +1,65 @@
+import math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from sparsebit.quantization.quantizers import Quantizer as BaseQuantizer
+from sparsebit.quantization.quantizers import register_quantizer
+from .quant_tensor import STE
+
+
+@register_quantizer
+class Quantizer(BaseQuantizer):
+    TYPE = "Quadapter"
+
+    def __init__(self, config):
+        super(Quantizer, self).__init__(config)
+        self.reconstruct_qlayer = reconstruct_qlayer
+
+    def init_variables(self, x: torch.Tensor):
+        alpha_shape = [1 for _ in range(self.dims)]
+        alpha_shape[self.qdesc._ch_axis] = x.shape[self.qdesc._ch_axis]
+        self.alpha = nn.Parameter(torch.ones(alpha_shape).to(self.device))
+
+    def update_observer(self, x):
+        self.dims = len(x.shape)
+        self.observer.data_cache.update(x.detach())
+
+    def _forward(self, x_f, scale, zero_point):
+        x_f = x_f * self.alpha
+        x_dq = STE.apply(x_f, scale, zero_point, self.qdesc, self.backend)
+        x_dq = x_dq / self.alpha
+        return x_dq
+
+
+def reconstruct_qlayer(
+    layer,
+    inputs: torch.Tensor,
+    outputs: torch.Tensor,
+    batch_size=32,
+    max_steps=20000,
+    p=2.0,
+):
+    # init
+    layer.eval()
+    layer.set_quant(w_quant=True, a_quant=True)
+    layer.input_quantizer.init_variables(inputs)
+    layer.input_quantizer.train()
+    opt_params = [layer.input_quantizer.alpha]
+    optimizer = torch.optim.Adam(opt_params)
+    print_freq = 500
+    # training
+    device = layer.input_quantizer.device
+    inputs, outputs = inputs.to(device), outputs.to(device)
+    for step in range(max_steps):
+        idx = torch.randperm(inputs.size(0))[:batch_size]
+        cur_input, cur_output = inputs[idx], outputs[idx]
+        optimizer.zero_grad()
+        quant_output = layer(cur_input)
+        loss = (quant_output - cur_output).abs().pow(p).sum(1).mean()
+        loss.backward(retain_graph=True)
+        optimizer.step()
+        if step % print_freq == 0:
+            print("Loss: {:.3f} step={}".format(loss, step))
+    torch.cuda.empty_cache()
+    layer.input_quantizer.eval()
diff --git a/sparsebit/quantization/tools/calibration.py b/sparsebit/quantization/tools/calibration.py
index 8a94706..12468e1 100644
--- a/sparsebit/quantization/tools/calibration.py
+++ b/sparsebit/quantization/tools/calibration.py
@@ -3,7 +3,6 @@ from functools import partial
 
 from sparsebit.quantization.modules import QuantOpr
-from sparsebit.quantization.quantizers.adaround import reconstruct_qlayer
 from .graph_wrapper import GraphVisitor, fx_symbolic_trace
 from .tensor_wrapper import to_cpu, to_device, to_detach
 
 
@@ -89,6 +88,8 @@ def layerwise_calibration(self, device, asym=False, w_quant=False, a_quant=False
             float_outputs = self.module_forward(batch_num, node, device)
             self.builder.storage.set_output(node.target, float_outputs)
             self.run_weight_calibration(node, asym, a_quant=a_quant)
+            # layerwise reconstruction
+            self.run_layerwise_reconstruction(node, asym, a_quant=a_quant)
             # foward quant output
             if asym:
                 quant_outputs = self.module_forward(
@@ -115,15 +116,41 @@ def run_weight_calibration(self, node, asym=False, a_quant=False):
         if isinstance(module, QuantOpr) and getattr(module, "weight_quantizer", None):
             module.weight_quantizer.update_observer(module.weight)
             module.weight_quantizer.calc_qparams()
-            if module.weight_quantizer.TYPE.lower() == "adaround":
+
+    def run_layerwise_reconstruction(self, node, asym=False, a_quant=False):
+        module = self.model
+        for n in node.target.split("."):
+            module = getattr(module, n)
+        if isinstance(module, QuantOpr):
+            if (
+                getattr(module, "input_quantizer", None)
+                and not module.input_quantizer.fake_fused
+                and module.input_quantizer.TYPE.lower() == "quadapter"
+            ):
+                assert (
+                    len(node.all_input_nodes) == 1
+                ), "Quadapter not supports the oprs which has more than one inputs"
+                _storage = self.builder.qstorage if asym else self.builder.storage
+                inp_tensors = _storage.get_output(node.all_input_nodes[0].target)
+                out_tensors = self.builder.storage.get_output(node.target)
+                print("Reconstruct input_quantizer of {}".format(node.target))
+                module.input_quantizer.reconstruct_qlayer(
+                    module,
+                    torch.cat(inp_tensors, dim=0),
+                    torch.cat(out_tensors, dim=0),
+                )
+            if (
+                getattr(module, "weight_quantizer", None)
+                and module.weight_quantizer.TYPE.lower() == "adaround"
+            ):
                 assert (
                     len(node.all_input_nodes) == 1
                 ), "AdaRound not supports the oprs which has more than one inputs"
                 _storage = self.builder.qstorage if asym else self.builder.storage
                 inp_tensors = _storage.get_output(node.all_input_nodes[0].target)
                 out_tensors = self.builder.storage.get_output(node.target)
-                print("Reconstruct {}".format(node.target))
-                reconstruct_qlayer(
+                print("Reconstruct weight_quantizer of {}".format(node.target))
+                module.weight_quantizer.reconstruct_qlayer(
                     module,
                     torch.cat(inp_tensors, dim=0),
                     torch.cat(out_tensors, dim=0),
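
Reviewer note (not part of the patch): a minimal, self-contained sketch of the idea behind Quantizer._forward in quadapter.py. A learnable per-channel prefactor alpha rescales the activation before fake-quantization and is divided out afterwards, so the layer's nominal output is unchanged while the rounding error is reshaped; reconstruct_qlayer then trains alpha with Adam against the float outputs. The fake_quant helper, tensor shape, channel axis, and qmin/qmax below are illustrative assumptions; the actual code routes through STE.apply with the quantizer's qdesc and backend.

import torch
import torch.nn as nn

def fake_quant(x, scale, zero_point, qmin=-128, qmax=127):
    # illustrative stand-in for STE.apply: round-to-nearest quantize, then dequantize
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

# assumed NCHW activation; per-channel alpha on dim=1, same shape rule as init_variables
x = torch.randn(8, 16, 14, 14)
alpha = nn.Parameter(torch.ones(1, 16, 1, 1))
scale, zero_point = 0.1, 0.0

# mirrors Quantizer._forward: rescale, fake-quantize, undo the rescale
x_dq = fake_quant(x * alpha, scale, zero_point) / alpha
print("mean abs quantization error:", (x_dq - x).abs().mean().item())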