
differentiable samples (rsample) #82

Open · RafaelPo opened this issue Nov 24, 2020 · 11 comments

@RafaelPo

Are there plans to introduce differentiable samples?

Thanks!

@srush (Collaborator) commented Nov 24, 2020

Yeah... we are actually trying that out currently. There are a lot of different ways to do it with discrete distributions; did you have one in mind?

@RafaelPo (Author)

Hi,

I was thinking of applying the results from https://arxiv.org/pdf/2002.08676.pdf recursively on the marginals... do you think that would work?

@srush (Collaborator) commented Nov 25, 2020

Yes, I think that would be cool. We have some of the papers referenced in that work already implemented, such as the differentiable dynamic programming semiring, but it is not exposed in the API. I'm a bit hesitant to call it rsample, because it is biased. Maybe we should have a separate API function that exposes some of these tricks? If you are interested, I would be happy for a contribution.

@RafaelPo (Author)

Hi,

here is some code I have been playing with:
[two screenshots of code attached as images]

@srush (Collaborator) commented Nov 25, 2020

Nice, that is similar in spirit to this code which we have been working on in #81.

We can integrate them both into the library.

There might also be a way to do this by calling cvxpy many fewer times.

@RafaelPo (Author)

I will have a look, thanks!

How could you save on the number of runs?

Also, I think they are supposed to be unbiased, no?

@srush (Collaborator) commented Nov 25, 2020

Very neat. So I think that instead of first computing marginals, we can apply this approach in the backward operation of the semiring itself. This is how I compute unbiased gumbel-max samples (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR70).

It seems like I can just change this line (https://github.com/harvardnlp/pytorch-struct/pull/81/files#diff-5775ad09d6cfbdc4d52edd6797aba8e68ac66ae04dee680b0e456058bef106dcR71) from an argmax to your CVX code to get a differentiable sample? This should work for all models.

Another advantage of this method is that it will batch across n (our internal code does log n steps instead of n for linear chain).

I agree the forward sample is unbiased, but I will have to read the paper to understand if the gradient is unbiased too (but I believe you).
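
For concreteness, here is a minimal standalone sketch of the two options on a single categorical (not the pytorch-struct semiring code itself). Gumbel-max perturbation plus a hard argmax gives an exact sample; swapping the argmax for a relaxed operator (a plain softmax below, standing in for the CVX projection discussed above) gives a differentiable but biased sample.

import torch

def gumbel_max_sample(log_potentials):
    # Exact (unbiased) sample: perturb the logits with Gumbel noise, take a hard argmax.
    g = torch.distributions.Gumbel(0.0, 1.0).sample(log_potentials.shape)
    idx = (log_potentials + g).argmax(-1)
    return torch.nn.functional.one_hot(idx, log_potentials.shape[-1]).float()

def relaxed_gumbel_sample(log_potentials, temp=1.0):
    # Differentiable but biased: replace the argmax with a softmax at temperature temp.
    # A regularized convex projection (e.g. the CVX code above) could be dropped in here instead.
    g = torch.distributions.Gumbel(0.0, 1.0).sample(log_potentials.shape)
    return torch.softmax((log_potentials + g) / temp, dim=-1)

theta = torch.randn(4, 5, requires_grad=True)  # a batch of 4 categoricals over 5 classes
z = relaxed_gumbel_sample(theta, temp=0.5)
z.sum().backward()                             # gradients flow back into theta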

@teffland commented Dec 5, 2020

Hi,

Not sure how this compares to what you guys have been working on, but for what it's worth I have implemented a version of a biased rsample that uses local Gumbel perturbations and temperature-controlled marginals (this is the marginal stochastic softmax trick from https://arxiv.org/abs/2006.08063) directly in the StructDistribution class as:

# Assumes this lives inside the StructDistribution class, with `torch` and
# `LogSemiring` (from torch_struct.semirings) imported at module level.
def rsample(self, sample_shape=torch.Size(), temp=1.0, noise_shape=None, sample_batch_size=10):
    r"""
    Compute structured samples from the _relaxed_ distribution :math:`z \sim p(z;\theta+\gamma, \tau)`

    NOTE: These samples are biased.

    This uses Gumbel perturbations on the potentials followed by the (above-zero-temperature)
    marginals to get approximate samples. As temp varies from 0 to inf, the samples vary from
    exact one-hots drawn from an approximate distribution to a deterministic output that is
    uniform over all values.

    The approximation of the zero-temp limit comes from the fact that we use a polynomial
    number of local perturbations (instead of an exponential number, one per structure), see:
      [Perturb-and-MAP](https://ttic.uchicago.edu/~gpapan/pubs/confr/PapandreouYuille_PerturbAndMap_ieee-c-iccv11.pdf)
      [Stochastic Softmax Tricks](https://arxiv.org/abs/2006.08063)

    Parameters:
        sample_shape (int or torch.Size): number of samples
        temp (float): (default=1.0) relaxation temperature
        noise_shape (torch.Size): specify lower-order perturbations by placing ones along any of the potential dims
        sample_batch_size (int): size of the batches in which samples are computed

    Returns:
        samples (*sample_shape x batch_shape x event_shape*)

    """
    # Sanity checks
    if type(sample_shape) == int:
        nsamples = sample_shape
    else:
        assert len(sample_shape) == 1
        nsamples = sample_shape[0]
    if sample_batch_size > nsamples:
        sample_batch_size = nsamples
    samples = []

    # By default, perturb every entry of the potentials.
    if noise_shape is None:
        noise_shape = self.log_potentials.shape[1:]

    assert len(noise_shape) == len(self.log_potentials.shape[1:])
    assert all(
        s1 == 1 or s1 == s2 for s1, s2 in zip(noise_shape, self.log_potentials.shape[1:])
    ), f"Noise shapes must match dimension or be 1: got: {list(zip(noise_shape, self.log_potentials.shape[1:]))}"

    # Sampling: draw samples in chunks of sample_batch_size, then truncate to nsamples at the end.
    for k in range(nsamples):
        if k % sample_batch_size == 0:
            shape = self.log_potentials.shape
            B = shape[0]
            # Tile the potentials so each element of the chunk gets its own copy.
            s_log_potentials = (
                self.log_potentials.reshape(1, *shape)
                .repeat(sample_batch_size, *tuple(1 for _ in shape))
                .reshape(-1, *shape[1:])
            )

            s_lengths = self.lengths
            if s_lengths is not None:
                s_shape = s_lengths.shape
                s_lengths = (
                    s_lengths.reshape(1, *s_shape)
                    .repeat(sample_batch_size, *tuple(1 for _ in s_shape))
                    .reshape(-1, *s_shape[1:])
                )

            # Gumbel-perturb the tiled potentials and scale by the temperature.
            noise = (
                torch.distributions.Gumbel(0, 1)
                .sample((sample_batch_size * B, *noise_shape))
                .expand_as(s_log_potentials)
            ).to(s_log_potentials.device)
            noisy_potentials = (s_log_potentials + noise) / temp

            # Relaxed samples are the marginals of the perturbed distribution.
            r_sample = (
                self._struct(LogSemiring)
                .marginals(noisy_potentials, s_lengths)
                .reshape(sample_batch_size, B, *shape[1:])
            )
            samples.append(r_sample)
    return torch.cat(samples, dim=0)[:nsamples]
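
A hypothetical usage sketch, assuming the rsample method above were added to StructDistribution (LinearChainCRF and its constructor are torch_struct's existing API; the rsample signature is the one proposed above):

import torch
from torch_struct import LinearChainCRF

# Linear-chain log-potentials: (batch=2, N=6 edges, C=3 classes, C=3 classes).
log_potentials = torch.randn(2, 6, 3, 3, requires_grad=True)
dist = LinearChainCRF(log_potentials)

# 16 relaxed samples at temperature 0.5; result shape (16, 2, 6, 3, 3).
samples = dist.rsample((16,), temp=0.5)

# Gradients flow through the relaxed marginals back into the potentials.
samples.sum().backward()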

Let me know if you'd like me to submit this as a PR (with whatever changes you think make sense).

Thanks,
Tom

@srush (Collaborator) commented Dec 5, 2020

Awesome, sounds like we have three different methods. The one in my PR is from Yao's NeurIPS work (https://arxiv.org/abs/2011.14244), which is unbiased in the forward pass and biased in the backward pass. Maybe we should have a phone call to figure out the differences and how to document and compare them.

@teffland commented Dec 7, 2020

Very interesting, I'll take a look at the paper; an unbiased forward pass sounds like a big plus. I'm available for a call to discuss pretty much whenever.

@RafaelPo (Author) commented Dec 7, 2020

Not sure how what I proposed compares to the rest; it seems (way) more computationally expensive. I would be interested in a call as well, though I am based in England.
