I'm attempting to select a single sample from a range of Normal distributions based upon the output of a Categorical distribution, but I can't seem to come up with quite the right way to do it. Using something along the lines of:
import tensorflow_probability as tfp

tfp.distributions.JointDistributionSequential([
    tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
    lambda c: tfp.distributions.Normal([0, 1, -10, 30], 1)[..., c]
])
This returns exactly what I want for the single case. However, if I want multiple samples at once it breaks (as c becomes a numpy array rather than an integer). Is this possible, and if so, how should I go about it?
(I also attempted using OneHotCategorical and multiplying but that didn't work at all!)
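You could do this, if you don't want to use MixtureSameFamily as Brian suggests, by indexing into the locs with tf.gather instead of slicing the Normal. Note I needed to add a trailing . to the locs in the gather (0. rather than 0, and so on) to avoid a dtype error.
Here, what we end up doing is drawing n samples from the Categorical and then building n Normals, whose locs are obtained by indexing n times into the 4-vector of locs. The result is an n-batch of Normals.
The previous approach doesn't work because Distribution slicing doesn't support this kind of "fancy indexing" (indexing with an array of positions rather than a single scalar, which is why the one-sample case works). It would be cool if we did! TF doesn't support it in general, for reasons.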