PyTorch: How to implement one link per neuron?


For example, I would like to have a standard feed-forward neural network with the following structure:

  1. n input neurons
  2. n neurons on the second layer
  3. 2 neurons on the third layer
  4. n neurons on the fourth layer

where

  • the i-th neuron in the first layer is connected precisely to the i-th neuron in the second layer (I don't know how to do that)
  • the second and the third layer are fully connected, the same goes for the third and the fourth layer (I know how to do that - using nn.Linear)
  • the loss function is MSE plus the L1 norm of the (vector of) weights between the first two layers (whether I can do that depends on the answer to this question)

Motivation: I want to implement an autoencoder and try to achieve some sparsity; this is why each input is multiplied by a single weight on its way from the first to the second layer.
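
Concretely, the loss from the last bullet would look roughly like this (just a sketch with toy shapes; lam is a sparsity coefficient I would tune):

import torch
import torch.nn.functional as F

recon = torch.rand(5, 10)   # network output (reconstruction)
target = torch.rand(5, 10)  # input the autoencoder should reconstruct
w = torch.rand(10)          # vector of weights between the first two layers
lam = 0.01                  # sparsity coefficient (to be tuned)

loss = F.mse_loss(recon, target) + lam * w.abs().sum()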

1 Answer

Berriel:

You can implement a custom layer, similar to nn.Linear:

import math
import torch
from torch import nn

class ElementWiseLinear(nn.Module):
    """Applies an element-wise affine transform: y = w * x + b."""
    __constants__ = ['n_features']
    n_features: int
    weight: torch.Tensor

    def __init__(self, n_features: int, bias: bool = True) -> None:
        super().__init__()
        self.n_features = n_features
        # one weight per feature instead of a full n x n matrix
        self.weight = nn.Parameter(torch.empty(1, n_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(n_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # same initialization scheme as nn.Linear
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # element-wise product: the i-th input is scaled by the i-th weight
        output = torch.mul(input, self.weight)
        if self.bias is not None:
            output += self.bias
        return output

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.n_features, self.n_features, self.bias is not None
        )

and use it like this:

x = torch.rand(3)
layer = ElementWiseLinear(3, bias=False)
output = layer(x)
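
To cover the loss from the question, here is a minimal sketch of the full autoencoder with the L1 term (the layer sizes, the 0.01 coefficient, and the absence of activations are assumptions on my part, not requirements):

import torch
from torch import nn
import torch.nn.functional as F

n = 10
model = nn.Sequential(
    ElementWiseLinear(n),  # one weight per input neuron (layers 1 -> 2)
    nn.Linear(n, 2),       # fully connected (layers 2 -> 3)
    nn.Linear(2, n),       # fully connected (layers 3 -> 4)
)

x = torch.rand(5, n)
recon = model(x)
l1 = model[0].weight.abs().sum()          # L1 norm of the element-wise weights
loss = F.mse_loss(recon, x) + 0.01 * l1   # MSE + sparsity penalty
loss.backward()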

Of course, you can make things a lot simpler than that :)
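
For example, a stripped-down version could look like this (just a sketch; the class name and the plain torch.rand initialization are my own choices):

import torch
from torch import nn

class SimpleElementWise(nn.Module):
    def __init__(self, n_features: int):
        super().__init__()
        # a single weight per neuron, broadcast over the batch dimension
        self.weight = nn.Parameter(torch.rand(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight

layer = SimpleElementWise(3)
print(layer(torch.rand(3)))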