microsoft/nni

ChannelDependency does not handle concat properly

Open

#4637 opened on Mar 14, 2022

Labels: bug, good first issue, help wanted, model compression

Description

Describe the issue: While computing channel dependencies, reshape_break_channel_dependency runs the following code to check whether the number of input channels equals the number of output channels:

in_shape = op_node.auxiliary['in_shape']
out_shape = op_node.auxiliary['out_shape']
in_channel = in_shape[1]
out_channel = out_shape[1]
return in_channel != out_channel

This is correct for most reshape operations, as long as they take only one input. For concatenation, however, in_shape is a list of the shapes of the concatenated tensors, so in_channel is assigned a full shape (e.g. [1, 20, 32, 32]) instead of a single integer (e.g. 20). Because a list never compares equal to an integer, the check then always reports that the concatenation breaks the channel dependency.
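The effect of the quoted check can be reproduced with plain Python values. This is a minimal sketch that copies the comparison into a standalone function; `buggy_breaks_dependency` and the example shapes are illustrative rather than NNI code, and the output shape assumes concatenation along dimension 2.

```python
# Shapes as described in the issue: for a concat node, in_shape holds one
# shape per concatenated input instead of a single shape.
concat_in_shape = [[1, 20, 32, 32], [1, 20, 32, 32]]
concat_out_shape = [1, 20, 64, 32]  # assumed: concatenation along dim 2


def buggy_breaks_dependency(in_shape, out_shape):
    """Mirrors the comparison quoted from reshape_break_channel_dependency."""
    in_channel = in_shape[1]    # for concat this is a whole shape, not an int
    out_channel = out_shape[1]
    return in_channel != out_channel


print(concat_in_shape[1])                                          # [1, 20, 32, 32]
print(buggy_breaks_dependency(concat_in_shape, concat_out_shape))  # True
```

Since the list `[1, 20, 32, 32]` is never equal to the integer `20`, the result is `True` even though every input keeps 20 channels. A genuine single-input reshape with matching channel counts still returns `False`.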

As a result, channel dependencies caused by concatenations are never created (although concatenating feature maps along a non-channel dimension is rather rare).
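One possible way for the check to handle concatenation is to look at the concat dimension and at each input's channel count. The sketch below is hypothetical: `concat_breaks_channel_dependency` and its `cat_dim` parameter are not part of NNI, and it assumes that `in_shapes` is the list of input shapes described above and that tensors are in NCHW layout.

```python
from typing import Sequence


def concat_breaks_channel_dependency(
    in_shapes: Sequence[Sequence[int]],
    out_shape: Sequence[int],
    cat_dim: int,
) -> bool:
    """Hypothetical concat-aware variant of the channel check (not NNI code).

    in_shapes: one NCHW shape per concatenated input.
    out_shape: shape of the concatenated output.
    cat_dim:   the dim argument passed to torch.cat (may be negative).
    """
    out_channel = out_shape[1]
    dim = cat_dim % len(out_shape)  # normalise negative dims, e.g. -3 -> 1 for 4-D
    if dim == 1:
        # Concatenation along channels: each output channel comes from exactly
        # one input, so the inputs' channels are not tied to each other.
        return True
    # Concatenation along another dim: channel i of every input feeds channel i
    # of the output, so the dependency holds only if all channel counts match.
    return any(shape[1] != out_channel for shape in in_shapes)


# Shapes from the NaiveModel reproduction: two [1, 20, 6, 6] maps joined on dim 2.
print(concat_breaks_channel_dependency([[1, 20, 6, 6]] * 2, [1, 20, 12, 6], 2))  # False
# The same maps joined on the channel dimension instead.
print(concat_breaks_channel_dependency([[1, 20, 6, 6]] * 2, [1, 40, 6, 6], 1))   # True
```

The design choice is that only a non-channel concatenation with equal channel counts keeps the dependency, which is exactly the case the current check misses.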

Environment:

  • NNI version: 2.5
  • Training service: local
  • Client OS: Linux
  • Server OS (for remote mode only):
  • Python version: 3.8.6
  • PyTorch/TensorFlow version: 1.10.0
  • Is conda/virtualenv/venv used?: yes
  • Is running in Docker?: no

Configuration: N/A

Log message: N/A

How to reproduce it?: Model code to quickly replicate the problem:

import torch
import torch.nn.functional as F


class NaiveModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(1, 20, 5, 1)
        self.fc1 = torch.nn.Linear(6 * 6 * 40, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(4, 4)
        self.max_pool2 = torch.nn.MaxPool2d(4, 4)

    def forward(self, x):
        x1 = self.relu1(self.conv1(x))
        x1 = self.max_pool1(x1)
        x2 = self.relu2(self.conv2(x))
        x2 = self.max_pool2(x2)
        x = torch.cat([x1, x2], 2)
        x = x.view(-1, x.size()[1:].numel())
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

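from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Keep the model and the dummy input on the same device
model = NaiveModel().to(device)
dummy_input = torch.ones([1, 1, 28, 28]).to(device)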

config_list = [{"sparsity": 0.5, "op_types": ["Conv2d"]}]
pruner = L1FilterPruner(
    model, config_list, dependency_aware=True, dummy_input=dummy_input
)

# Step into the dependency analysis: the channel dependency introduced by torch.cat is not detected
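The tensor shapes in the reproduction can be traced by hand to confirm that torch.cat along dim 2 keeps 20 channels, which is why the concatenation should tie the channels of conv1 and conv2 together. This is an arithmetic sketch using the standard output-size formula for Conv2d and MaxPool2d (dilation 1, floor rounding); `conv_out` is a hypothetical helper, not part of PyTorch or NNI.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output-size formula shared by Conv2d and MaxPool2d (dilation=1, floor mode).
    return (size + 2 * padding - kernel) // stride + 1


side = 28                          # dummy_input is [1, 1, 28, 28]
side = conv_out(side, 5, 1)        # conv1 / conv2: 28 -> 24
side = conv_out(side, 4, 4)        # max_pool1 / max_pool2: 24 -> 6

x1_shape = [1, 20, side, side]
x2_shape = [1, 20, side, side]
# torch.cat([x1, x2], 2) stacks along height; the channel count stays 20.
cat_shape = [1, 20, x1_shape[2] + x2_shape[2], side]
flat_features = cat_shape[1] * cat_shape[2] * cat_shape[3]

print(cat_shape)                    # [1, 20, 12, 6]
print(flat_features == 6 * 6 * 40)  # True
```

The flattened size of 1440 matches the `6 * 6 * 40` input features of `fc1`, so the model runs as written while the concatenated tensor still has 20 channels.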
