#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Helpers
**Module name:** :mod:`qmlt.helpers`

Collection of helpers for setting up an experiment with either the numerical
or the TensorFlow (tf) circuit learner.

sample_from_distribution

"""

import numpy as np

def sample_from_distribution(distribution, *, rng=None):
    r"""Sample a Fock state from a nested probability distribution of Fock states.

    Args:
        distribution (array_like): Nested array of Fock-state probabilities
            with one axis per mode. The probability of the Fock state
            :math:`\lvert i,j,k \rangle` is ``distribution[i, j, k]``. Can be
            the result of :func:`state.all_fock_probs`. The entries need not
            sum to one (they are normalised before sampling), and the axes may
            have different lengths.
        rng (numpy.random.Generator or numpy.random.RandomState, optional):
            Source of randomness. Defaults to NumPy's global random state
            (``np.random``), so ``np.random.seed`` still controls the result.

    Returns:
        ndarray: 1-D integer array of photon numbers, one entry per mode,
        representing the sampled Fock state.

    Raises:
        ValueError: If the probabilities do not have a positive, finite sum,
            or (raised by ``choice``) if any probability is negative.
    """
    probs = np.asarray(distribution)
    probs_flat = probs.reshape(-1)

    # Normalise so that inputs whose probabilities do not sum exactly to one
    # (e.g. a distribution truncated at a cutoff) can still be sampled.
    total = probs_flat.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError(
            "distribution must have a positive, finite total probability"
        )

    sampler = np.random if rng is None else rng
    # Draw a flat (C-order) index, then map it back to one photon number per
    # mode; unravel_index handles axes of unequal length.
    sample_index = sampler.choice(probs_flat.size, p=probs_flat / total)
    return np.array(np.unravel_index(sample_index, probs.shape), dtype=np.intp)