Example

00. Basic Usage

"""
Example 00: Circuit Basics
"""
import jax
import jax.numpy as jnp

import diffqc
from diffqc import dense as op

def main():
    nqubits = 5

    @jax.jit
    def circuit(params):
        # Initialize |00..0> state
        x = op.zeros(nqubits, jnp.complex64)

        # Apply quantum operations.
        # For each quantum operator:
        # - the first argument is the qubit state,
        # - the second argument is the wire(s), given as a `tuple`,
        # - the remaining arguments are gate parameters, and
        # - the return value is the new qubit state.

        for i in range(nqubits):
            # 1-qubit operation: Hadamard (H) on the i-th wire
            x = op.Hadamard(x, (i,))

        for i in range(nqubits-1):
            # 2-qubit operation: CRZ on the i-th and (i+1)-th wires.
            # The first wire is the control qubit.
            x = op.CRZ(x, (i, i+1), params["CRZ"][i])

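        # 3-parameter single-qubit rotation Rot(phi, theta, omega) on the last wire (4)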
        x = op.Rot(x, (4, ),
                   params["Rot"]["phi"], params["Rot"]["theta"], params["Rot"]["omega"])

        # Convert the internal representation to a state vector,
        # i.e. [|00000>, |00001>, |00010>, ..., |11111>]
        return op.to_state(x)

    @jax.jit
    def expval(params):
        s = circuit(params)

        # Calculate the probability of each basis state
        p = diffqc.prob(s)

        # Expectation value of measuring |1> on wire 4.
        return diffqc.expval(p, 4)


    # Define the parameters as a JAX-compatible pytree.
    # Ref: https://jax.readthedocs.io/en/latest/pytrees.html
    p = {
        "CRZ": jnp.ones((4,)),
        "Rot": {
            "phi": jnp.ones((1,)),
            "theta": jnp.ones((1,)),
            "omega": jnp.ones((1,)),
        }
    }

    # Now you can call the circuit and take its gradient.
    print(f"expval: {expval(p)}")
    print(f"grad: {jax.grad(expval)(p)}")

if __name__ == "__main__":
    main()
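
For reference, here is a minimal sketch (not part of the example above) of what `diffqc.prob` and `diffqc.expval` compute, assuming wire 0 corresponds to the most significant bit of the basis-state index, as the state-vector ordering comment in `circuit` suggests. The helper names `prob_from_state` and `expval_one` are hypothetical; they only illustrate the convention: the probability of each basis state is the squared magnitude of its amplitude, and the expectation of |1> on a wire is the sum of probabilities over basis states where that wire's bit is 1.

import jax.numpy as jnp

def prob_from_state(s):
    # Probability of each computational basis state: |amplitude|^2
    return jnp.abs(s) ** 2

def expval_one(p, wire, nqubits):
    # Sum the probabilities of basis states whose bit for `wire` is 1,
    # assuming wire 0 is the most significant bit of the basis index.
    idx = jnp.arange(p.shape[0])
    bit = (idx >> (nqubits - 1 - wire)) & 1
    return jnp.sum(p * bit)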

01. Quantum Circuit Learning (QCL) with Flax

"""
Example 01: QCL classification with Flax

K. Mitarai et al., "Quantum Circuit Learning", Phys. Rev. A 98, 032309 (2018)
https://doi.org/10.1103/PhysRevA.98.032309
https://arxiv.org/abs/1803.00745

This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/
"""
import functools
import time
from typing import Callable

from diffqc import dense as op

from flax import linen as nn
from flax.training import train_state

import jax
import jax.numpy as jnp

import optax

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler


def circuit(n_qubits, depth, features, weights):
    msg = "BUG: n_qubits ({}) must be greater than feature size ({})."
    assert n_qubits >= features.shape[0], msg.format(n_qubits, features.shape[0])

    q = op.zeros(n_qubits, jnp.complex64)

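    # Encode each feature x as RY(arcsin(x)) followed by RZ(arccos(x^2)),
    # the input encoding used in Mitarai et al. (see the reference above).
    # Features are reused cyclically when n_qubits exceeds the feature count.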
    for idx in range(n_qubits):
        i = idx % features.shape[0]
        q = op.RY(q, (idx,), jnp.arcsin(features.at[i].get()))
        q = op.RZ(q, (idx,), jnp.arccos(features.at[i].get() ** 2))

    # Note: We use a much simpler circuit than that of the original paper;
    #       however, it seems sufficient for this easy task.
    for k in range(depth):
        for idx in range(0, n_qubits-1):
            q = op.CNOT(q, (idx, idx+1))
        for idx in range(n_qubits):
            q = op.RY(q, (idx,), weights.at[k, idx].get())

    return jnp.stack(tuple(op.expectZ(q, (idx,)) for idx in range(n_qubits)))


class QCL(nn.Module):
    n_qubits: int
    depth: int
    output_dim: int
    circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)

    @nn.compact
    def __call__(self, inputs):
        x = inputs

        w = self.param("circuit", self.circuit_init, (self.depth, self.n_qubits))

        @jax.vmap
        def batch_circuit(e):
            return circuit(self.n_qubits, self.depth, e, w)

        x = batch_circuit(x)
        x = nn.Dense(self.output_dim)(x)
        return x


def load_data():
    features, labels = load_iris(return_X_y=True)
    features = features[:,2:] # only petal length/width
    scaler = MinMaxScaler(feature_range=(-1, 1))
    features_std = scaler.fit_transform(features)
    x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
                                                        random_state=42, shuffle=True)
    return (jnp.asarray(x_train), jnp.asarray(y_train),
            jnp.asarray(x_test) , jnp.asarray(y_test))


def cross_entropy_loss(logits, labels, num_classes):
    y = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=y).mean()


def compute_metrics(logits, labels, num_classes):
    return {
        "loss": cross_entropy_loss(logits, labels, num_classes),
        "accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
    }


def create_train_state(rng, n_qubits, depth, lr, feature_shape, output_dim):
    qcl = QCL(n_qubits=n_qubits, depth=depth, output_dim=output_dim)
    params = qcl.init(rng, jnp.ones((1, *feature_shape)))["params"]

    tx = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)


@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, n_qubits, depth, output_dim):
    def loss_fn(params):
        logits = QCL(n_qubits=n_qubits,
                     depth=depth,
                     output_dim=output_dim).apply({"params": params}, x)
        loss = cross_entropy_loss(logits, y, output_dim)
        return loss, logits

    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
    return state, metrics


@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, n_qubits, depth, output_dim):
    logits = QCL(n_qubits=n_qubits,
                 depth=depth,
                 output_dim=output_dim).apply({"params": params}, x)
    return compute_metrics(logits=logits, labels=y, num_classes=output_dim)


def main():
    # Circuit
    n_qubits = 4
    depth = 2

    # Learning
    lr = 0.05
    epochs = 100

    # Data
    num_classes = 3
    x_train, y_train, x_test, y_test = load_data()

    rng = jax.random.PRNGKey(0)
    rng, rng_apply = jax.random.split(rng)

    state = create_train_state(rng_apply,
                               n_qubits=n_qubits, depth=depth, lr=lr,
                               feature_shape=x_train.shape[1:],
                               output_dim=num_classes)


    for e in range(epochs):
        # Note: The data set is small enough that we don't split it into mini-batches.
        t = time.perf_counter()

        state, metrics = train_step(state, x_train, y_train,
                                    n_qubits=n_qubits, depth=depth,
                                    output_dim=num_classes)
        print(f"Epoch: {e:2d} [Train] Loss: {metrics['loss']:.4f}, " +
              f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")

        metrics = eval_step(state.params, x_test, y_test,
                            n_qubits=n_qubits, depth=depth,
                            output_dim=num_classes)
        print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
              f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")

        print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")


if __name__ == "__main__":
    main()
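
After training, the learned parameters live in `state.params`. Below is a minimal inference sketch, assuming `state` and a feature batch `x` shaped like `x_test` are available (e.g. by returning them from `main()`, which the example above does not do):

logits = state.apply_fn({"params": state.params}, x)  # raw class scores
pred = jnp.argmax(logits, axis=-1)                    # predicted class labels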

02. CNN-like QCL

"""
Example 02: CNN-like QCL classification with Flax

This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/

Warnings
--------
This implementation differs from QCNN [1],
because intermediate measurements and the operations conditioned on them are not easy
to simulate with diffqc, and because this example needs fewer qubits and a shallower circuit.

[1] I. Cong et al., "Quantum Convolutional Neural Networks",
    Nature Phys. 15 1273-1278 (2019)
    https://doi.org/10.1038/s41567-019-0648-8
    https://arxiv.org/abs/1810.03787
"""
import functools
import time
from typing import Callable

import diffqc
from diffqc import dense as op

from flax import linen as nn
from flax.training import train_state

import jax
import jax.numpy as jnp

import optax

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

from tqdm import tqdm


def conv3x3cell(x, w):
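    # Each 3x3 patch is flattened to 9 values and encoded on 9 qubits via RY;
    # w.shape[0] layers of a CNOT chain plus trainable RY rotations follow,
    # and the cell output is the Z expectation value of qubit 0.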
    x = jnp.reshape(x, (9,))

    q = op.zeros(9, jnp.complex64)
    for i in range(9):
        q = op.RY(q, (i,), x.at[i].get())

    for k in range(w.shape[0]):
        for i in range(8):
            q = op.CNOT(q, (i, i+1))
        for i in range(9):
            q = op.RY(q, (i,), w.at[k, i].get())

    return op.expectZ(q, (0,))


def ConvLayer(x, w):
    # [0, 1) -> [-pi/2, pi/2)
    x = jnp.arcsin(2 *(x - 0.5))

    # convolution
    F = diffqc.nn.Convolution(op, conv3x3cell,
                              kernel_shape=(3, 3),
                              slide=(1, 1),
                              padding=(1, 1))
    x = F(x, w)

    # pooling
    x = diffqc.nn.MaxPooling(x, (2, 2))

    return x

def DenseLayer(x, w):
    x = jnp.reshape(x, (-1,))

    q = op.zeros(x.shape[0], jnp.complex64)
    for i in range(x.shape[0]):
        q = op.RY(q, (i,), jnp.arcsin(x.at[i].get()))
        q = op.RZ(q, (i,), jnp.arccos(x.at[i].get() ** 2))

    for k in range(w.shape[0]):
        for i in range(x.shape[0]-1):
            q = op.CNOT(q, (i, i+1))
        for i in range(x.shape[0]):
            q = op.RY(q, (i,), w.at[k,i].get())

    for i in range(x.shape[0]):
        q = op.PauliZ(q, (i,))

    p = diffqc.prob(op.to_state(q))
    return jnp.stack(tuple(diffqc.expval(p, i) for i in range(x.shape[0])))


class ConvQCL(nn.Module):
    cdepth: int
    ddepth: int
    output_dim: int
    circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        wc = self.param("conv", self.circuit_init, (self.cdepth, 9))
        wd = self.param("dense", self.circuit_init, (self.ddepth, 16))

        @jax.vmap
        def batch(xi):
            # 8 x 8 -> 4 x 4
            xi = ConvLayer(xi, wc)

            xi = DenseLayer(xi, wd)
            return xi

        x = batch(x)
        x = nn.Dense(self.output_dim)(x)
        return x

def load_data():
    features, labels = load_digits(return_X_y=True)
    features_std = features.reshape((-1, 8, 8)) / 256.0
    x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
                                                        random_state=42, shuffle=True)
    return (jnp.asarray(x_train), jnp.asarray(y_train),
            jnp.asarray(x_test) , jnp.asarray(y_test))


def cross_entropy_loss(logits, labels, num_classes):
    y = jax.nn.one_hot(labels, num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=y).mean()


def compute_metrics(logits, labels, num_classes):
    return {
        "loss": cross_entropy_loss(logits, labels, num_classes),
        "accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
    }


def create_train_state(rng, cdepth, ddepth, lr, input_shape, output_dim):
    qcl = ConvQCL(cdepth=cdepth, ddepth=ddepth, output_dim=output_dim)
    params = qcl.init(rng, jnp.ones(input_shape))["params"]

    tx = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)


@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, cdepth, ddepth, output_dim):
    def loss_fn(params):
        logits = ConvQCL(cdepth=cdepth,
                         ddepth=ddepth,
                         output_dim=output_dim).apply({"params": params}, x)
        loss = cross_entropy_loss(logits, y, output_dim)
        return loss, logits

    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
    return state, metrics


@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, cdepth, ddepth, output_dim):
    logits = ConvQCL(cdepth=cdepth,
                     ddepth=ddepth,
                     output_dim=output_dim).apply({"params": params}, x)
    return compute_metrics(logits=logits, labels=y, num_classes=output_dim)


def main():
    # Circuit
    cdepth = 2
    ddepth = 2

    # Learning
    lr = 0.05
    epochs = 50
    batch_size = 16

    # Data
    num_classes = 10
    x_train, y_train, x_test, y_test = load_data()

    rng = jax.random.PRNGKey(0)
    rng, rng_apply = jax.random.split(rng)

    state = create_train_state(rng_apply,
                               cdepth=cdepth, ddepth=ddepth, lr=lr,
                               input_shape=(batch_size, *x_train.shape[1:]),
                               output_dim=num_classes)


    train_ds_size = len(x_train)
    steps_per_epoch = train_ds_size // batch_size
    for e in range(epochs):
        t = time.perf_counter()

        rng, rng_apply = jax.random.split(rng)
        perms = jax.random.permutation(rng_apply, train_ds_size)
        perms = perms[:steps_per_epoch * batch_size]
        perms = perms.reshape((steps_per_epoch, batch_size))

        batch_metrics = []
        for perm_idx in tqdm(perms, ascii=True):
            state, metrics = train_step(state,
                                        x_train.at[perm_idx].get(),
                                        y_train.at[perm_idx].get(),
                                        cdepth=cdepth, ddepth=ddepth,
                                        output_dim=num_classes)
            batch_metrics.append(metrics)

        loss = jnp.mean(jnp.asarray(tuple(m["loss"] for m in batch_metrics)))
        acc = jnp.mean(jnp.asarray(tuple(m["accuracy"] for m in batch_metrics)))
        print(f"Epoch: {e:2d} [Train] Loss: {loss:.4f}, " +
              f"Accuracy: {acc * 100:.2f}%", end=" ")

        metrics = eval_step(state.params, x_test, y_test,
                            cdepth=cdepth, ddepth=ddepth,
                            output_dim=num_classes)
        print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
              f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")

        print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")


if __name__ == "__main__":
    main()
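
A quick shape check, separate from the training script (a sketch assuming `ConvLayer` and `DenseLayer` above are in scope): with kernel_shape=(3, 3), slide=(1, 1), and padding=(1, 1), the convolution should keep the 8 x 8 spatial size, the 2 x 2 max pooling should reduce it to 4 x 4, and DenseLayer should then map it to a 16-dimensional vector.

import jax.numpy as jnp

x = jnp.zeros((8, 8))    # dummy image with values in [0, 1)
wc = jnp.zeros((2, 9))   # convolution weights: (cdepth, 9)
wd = jnp.zeros((2, 16))  # dense weights: (ddepth, 16)
h = ConvLayer(x, wc)     # expected shape: (4, 4)
y = DenseLayer(h, wd)    # expected shape: (16,)
print(h.shape, y.shape)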

03. PennyLane plugin

"""
Example 03: PennyLane plugin

Ref: https://pennylane.ai/qml/demos/tutorial_jax_transformations.html
"""
import pennylane as qml

def main():
    dev = qml.device("diffqc.qubit", wires=2)

    @qml.qnode(dev, interface="jax")
    def circuit(param):
        qml.RX(param, wires=0)
        qml.CNOT(wires=[0,1])

        return qml.expval(qml.PauliZ(0))

    print(circuit(0.123))


if __name__ == "__main__":
    main()
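
Because the QNode is created with interface="jax", it can also be differentiated with jax.grad, as in the referenced PennyLane tutorial. A minimal sketch, assuming `import jax` is added to the script and the line is placed inside `main()` after the existing print:

    print(jax.grad(circuit)(0.123))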

04. Builtin Parameterized Quantum Circuit (PQC): Circuit Centric Block

"""
Example 04: Builtin Variational: Circuit Centric

This uses the code block with range = 1 described in [1].
According to [2], this variational circuit is one of the
best choices in terms of expressibility, entangling capability,
and circuit cost.

The rest of this example is the same as Example 01.

[1] M. Schuld et al., "Circuit-centric quantum classifiers",
    Phys. Rev. A 101, 032308 (2020) (arXiv:1804.00633)

[2] S. Sim et al., "Expressibility and entangling capability of
    parameterized quantum circuits for hybrid quantum-classical algorithms",
    Adv. Quantum Technol. 2 (2019) 1900070 (arXiv:1905.10876)

This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/
"""
import functools
import time
from typing import Callable

from diffqc import dense as op
from diffqc.nn import CircuitCentricBlock

from flax import linen as nn
from flax.training import train_state

import jax
import jax.numpy as jnp

import optax

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler


def circuit(n_qubits, depth, features, weights):
    msg = "BUG: n_qubits ({}) must be greater than feature size ({})."
    assert n_qubits >= features.shape[0], msg.format(n_qubits, features.shape[0])

    q = op.zeros(n_qubits, jnp.complex64)

    for idx in range(n_qubits):
        i = idx % features.shape[0]
        q = op.RY(q, (idx,), jnp.arcsin(features.at[i].get()))
        q = op.RZ(q, (idx,), jnp.arccos(features.at[i].get() ** 2))

    # Variational Gate: Circuit Centric Block
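    # `weights` has shape (depth, 3 * n_qubits), matching the "circuit" param
    # defined in QCL below; presumably three rotation angles per qubit per block.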
    q = CircuitCentricBlock(op, q, tuple(i for i in range(n_qubits)), weights)

    return jnp.stack(tuple(op.expectZ(q, (idx,)) for idx in range(n_qubits)))


class QCL(nn.Module):
    n_qubits: int
    depth: int
    output_dim: int
    circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)

    @nn.compact
    def __call__(self, inputs):
        x = inputs

        w = self.param("circuit", self.circuit_init, (self.depth, 3 * self.n_qubits))

        @jax.vmap
        def batch_circuit(e):
            return circuit(self.n_qubits, self.depth, e, w)

        x = batch_circuit(x)
        x = nn.Dense(self.output_dim)(x)
        return x


def load_data():
    features, labels = load_iris(return_X_y=True)
    features = features[:,2:] # only petal length/width
    scaler = MinMaxScaler(feature_range=(-1, 1))
    features_std = scaler.fit_transform(features)
    x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
                                                        random_state=42, shuffle=True)
    return (jnp.asarray(x_train), jnp.asarray(y_train),
            jnp.asarray(x_test) , jnp.asarray(y_test))


def cross_entropy_loss(logits, labels, num_classes):
    y = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=y).mean()


def compute_metrics(logits, labels, num_classes):
    return {
        "loss": cross_entropy_loss(logits, labels, num_classes),
        "accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
    }


def create_train_state(rng, n_qubits, depth, lr, feature_shape, output_dim):
    qcl = QCL(n_qubits=n_qubits, depth=depth, output_dim=output_dim)
    params = qcl.init(rng, jnp.ones((1, *feature_shape)))["params"]

    tx = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)


@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, n_qubits, depth, output_dim):
    def loss_fn(params):
        logits = QCL(n_qubits=n_qubits,
                     depth=depth,
                     output_dim=output_dim).apply({"params": params}, x)
        loss = cross_entropy_loss(logits, y, output_dim)
        return loss, logits

    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
    return state, metrics


@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, n_qubits, depth, output_dim):
    logits = QCL(n_qubits=n_qubits,
                 depth=depth,
                 output_dim=output_dim).apply({"params": params}, x)
    return compute_metrics(logits=logits, labels=y, num_classes=output_dim)


def main():
    # Circuit
    n_qubits = 4
    depth = 3

    # Learning
    lr = 0.05
    epochs = 100

    # Data
    num_classes = 3
    x_train, y_train, x_test, y_test = load_data()

    rng = jax.random.PRNGKey(0)
    rng, rng_apply = jax.random.split(rng)

    state = create_train_state(rng_apply,
                               n_qubits=n_qubits, depth=depth, lr=lr,
                               feature_shape=x_train.shape[1:],
                               output_dim=num_classes)


    for e in range(epochs):
        # Note: The data set is small enough that we don't split it into mini-batches.
        t = time.perf_counter()

        state, metrics = train_step(state, x_train, y_train,
                                    n_qubits=n_qubits, depth=depth,
                                    output_dim=num_classes)
        print(f"Epoch: {e:2d} [Train] Loss: {metrics['loss']:.4f}, " +
              f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")

        metrics = eval_step(state.params, x_test, y_test,
                            n_qubits=n_qubits, depth=depth,
                            output_dim=num_classes)
        print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
              f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")

        print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")


if __name__ == "__main__":
    main()

05. Builtin Parameterized Quantum Circuit (PQC): Josephson Sampler

"""
Example 05: Builtin Variational: Josephson Sampler

This uses the Josephson Sampler circuit described in [1].
According to [2], this variational circuit is one of the
best choices in terms of expressibility, entangling capability,
and circuit cost.

The rest of this example is the same as Example 01.

[1] M. R. Geller, "Sampling and scrambling on a chain of superconducting qubits",
    Phys. Rev. Applied 10, 024052 (2018) (arXiv:1711.11026)

[2] S. Sim et al., "Expressibility and entangling capability of
    parameterized quantum circuits for hybrid quantum-classical algorithms",
    Adv. Quantum Technol. 2 (2019) 1900070 (arXiv:1905.10876)

This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/
"""
import functools
import time
from typing import Callable

from diffqc import dense as op
from diffqc.nn import JosephsonSampler

from flax import linen as nn
from flax.training import train_state

import jax
import jax.numpy as jnp

import optax

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler


def circuit(n_qubits, depth, features, weights):
    msg = "BUG: n_qubits ({}) must be greater than feature size ({})."
    assert n_qubits >= features.shape[0], msg.format(n_qubits, features.shape[0])

    q = op.zeros(n_qubits, jnp.complex64)

    for idx in range(n_qubits):
        i = idx % features.shape[0]
        q = op.RY(q, (idx,), jnp.arcsin(features.at[i].get()))
        q = op.RZ(q, (idx,), jnp.arccos(features.at[i].get() ** 2))

    # Variational Gate: Josephson Sampler
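    # `weights` has shape (depth, 4 * (n_qubits - 1)), matching the "circuit" param
    # defined in QCL below.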
    q = JosephsonSampler(op, q, tuple(i for i in range(n_qubits)), weights)

    return jnp.stack(tuple(op.expectZ(q, (idx,)) for idx in range(n_qubits)))


class QCL(nn.Module):
    n_qubits: int
    depth: int
    output_dim: int
    circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)

    @nn.compact
    def __call__(self, inputs):
        x = inputs

        w = self.param("circuit", self.circuit_init, (self.depth,
                                                      4 * (self.n_qubits - 1)))

        @jax.vmap
        def batch_circuit(e):
            return circuit(self.n_qubits, self.depth, e, w)

        x = batch_circuit(x)
        x = nn.Dense(self.output_dim)(x)
        return x


def load_data():
    features, labels = load_iris(return_X_y=True)
    features = features[:,2:] # only petal length/width
    scaler = MinMaxScaler(feature_range=(-1, 1))
    features_std = scaler.fit_transform(features)
    x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
                                                        random_state=42, shuffle=True)
    return (jnp.asarray(x_train), jnp.asarray(y_train),
            jnp.asarray(x_test) , jnp.asarray(y_test))


def cross_entropy_loss(logits, labels, num_classes):
    y = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=y).mean()


def compute_metrics(logits, labels, num_classes):
    return {
        "loss": cross_entropy_loss(logits, labels, num_classes),
        "accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
    }


def create_train_state(rng, n_qubits, depth, lr, feature_shape, output_dim):
    qcl = QCL(n_qubits=n_qubits, depth=depth, output_dim=output_dim)
    params = qcl.init(rng, jnp.ones((1, *feature_shape)))["params"]

    tx = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)


@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, n_qubits, depth, output_dim):
    def loss_fn(params):
        logits = QCL(n_qubits=n_qubits,
                     depth=depth,
                     output_dim=output_dim).apply({"params": params}, x)
        loss = cross_entropy_loss(logits, y, output_dim)
        return loss, logits

    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, logits = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
    return state, metrics


@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, n_qubits, depth, output_dim):
    logits = QCL(n_qubits=n_qubits,
                 depth=depth,
                 output_dim=output_dim).apply({"params": params}, x)
    return compute_metrics(logits=logits, labels=y, num_classes=output_dim)


def main():
    # Circuit
    n_qubits = 4
    depth = 3

    # Learning
    lr = 0.05
    epochs = 100

    # Data
    num_classes = 3
    x_train, y_train, x_test, y_test = load_data()

    rng = jax.random.PRNGKey(0)
    rng, rng_apply = jax.random.split(rng)

    state = create_train_state(rng_apply,
                               n_qubits=n_qubits, depth=depth, lr=lr,
                               feature_shape=x_train.shape[1:],
                               output_dim=num_classes)


    for e in range(epochs):
        # Note: The data set is small enough that we don't split it into mini-batches.
        t = time.perf_counter()

        state, metrics = train_step(state, x_train, y_train,
                                    n_qubits=n_qubits, depth=depth,
                                    output_dim=num_classes)
        print(f"Epoch: {e:2d} [Train] Loss: {metrics['loss']:.4f}, " +
              f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")

        metrics = eval_step(state.params, x_test, y_test,
                            n_qubits=n_qubits, depth=depth,
                            output_dim=num_classes)
        print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
              f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")

        print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")


if __name__ == "__main__":
    main()