Example#
00. Basic Usage#
"""
Example 00: Circuit Basics
"""
import jax
import jax.numpy as jnp
import diffqc
from diffqc import dense as op
def main():
nqubits = 5
@jax.jit
def circuit(params):
# Initialize |00..0> state
x = op.zeros(nqubits, jnp.complex64)
# Apply quantum operations.
# For every quantum operator,
# - the first argument is the qubit state,
# - the second argument is the wire(s), given as a `tuple`,
# - the remaining arguments are parameters, and
# - the return value is the new qubit state.
for i in range(nqubits):
# 1-qubit operation: Hadamard (H) on the i-th wire
x = op.Hadamard(x, (i,))
for i in range(nqubits-1):
# 2-qubit operation: CRZ on the i-th and (i+1)-th wires.
# The first wire is the control qubit.
x = op.CRZ(x, (i, i+1), params["CRZ"][i])
# General 1-qubit rotation Rot(phi, theta, omega) on the last wire (wire 4)
x = op.Rot(x, (4,),
params["Rot"]["phi"], params["Rot"]["theta"], params["Rot"]["omega"])
# Convert the internal representation to a state-vector of 2**nqubits amplitudes,
# i.e. [|00000>, |00001>, |00010>, ..., |11111>]
return op.to_state(x)
@jax.jit
def expval(params):
s = circuit(params)
# Calculate the probability of each basis state
p = diffqc.prob(s)
# Expectation value of |1> on a certain wire (here wire 4).
return diffqc.expval(p, 4)
# Define parameters as a JAX-compatible pytree.
# Ref: https://jax.readthedocs.io/en/latest/pytrees.html
p = {
"CRZ": jnp.ones((4,)),
"Rot": {
"phi": jnp.ones((1,)),
"theta": jnp.ones((1,)),
"omega": jnp.ones((1,)),
}
}
# Now you can call the circuit and take its gradient.
print(f"expval: {expval(p)}")
print(f"grad: {jax.grad(expval)(p)}")
if __name__ == "__main__":
main()
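Because `circuit` and `expval` above are plain JAX functions, they compose with other JAX transformations as well. The following is a minimal sketch (not part of the original example; it only reuses the `diffqc.dense` operations shown above) that evaluates a tiny 2-qubit circuit for a whole batch of angles with `jax.vmap` and differentiates it with `jax.grad`.

import jax
import jax.numpy as jnp
import diffqc
from diffqc import dense as op

def p1_on_wire1(theta):
    # H on wire 0, CNOT(0 -> 1), then a parameterized RY(theta) on wire 1
    q = op.zeros(2, jnp.complex64)
    q = op.Hadamard(q, (0,))
    q = op.CNOT(q, (0, 1))
    q = op.RY(q, (1,), theta)
    # expectation of |1> on wire 1
    return diffqc.expval(diffqc.prob(op.to_state(q)), 1)

thetas = jnp.linspace(0.0, jnp.pi, 5)
print(jax.vmap(p1_on_wire1)(thetas))  # one value per angle in the batch
print(jax.grad(p1_on_wire1)(0.5))     # per-sample gradient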
01. Quantum Circuit Learning (QCL) with Flax#
"""
Example 01: QCL classification with Flax
K. Mitarai et al., "Quantum Circuit Learning", Phys. Rev. A 98, 032309 (2018)
https://doi.org/10.1103/PhysRevA.98.032309
https://arxiv.org/abs/1803.00745
This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/
"""
import functools
import time
from typing import Callable
from diffqc import dense as op
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
def circuit(n_qubits, depth, features, weights):
msg = "BUG: n_qubits ({}) must be greater than or equal to feature size ({})."
assert n_qubits >= features.shape[0], msg.format(n_qubits, features.shape[0])
q = op.zeros(n_qubits, jnp.complex64)
for idx in range(n_qubits):
i = idx % features.shape[0]
q = op.RY(q, (idx,), jnp.arcsin(features.at[i].get()))
q = op.RZ(q, (idx,), jnp.arccos(features.at[i].get() ** 2))
# Note: We use a much simpler circuit than that of the original paper;
# however, it seems fine for this easy task.
for k in range(depth):
for idx in range(0, n_qubits-1):
q = op.CNOT(q, (idx, idx+1))
for idx in range(n_qubits):
q = op.RY(q, (idx,), weights.at[k, idx].get())
return jnp.stack(tuple(op.expectZ(q, (idx,)) for idx in range(n_qubits)))
class QCL(nn.Module):
n_qubits: int
depth: int
output_dim: int
circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)
@nn.compact
def __call__(self, inputs):
x = inputs
w = self.param("circuit", self.circuit_init, (self.depth, self.n_qubits))
@jax.vmap
def batch_circuit(e):
return circuit(self.n_qubits, self.depth, e, w)
x = batch_circuit(x)
x = nn.Dense(self.output_dim)(x)
return x
def load_data():
features, labels = load_iris(return_X_y=True)
features = features[:,2:] # use only petal length/width
scaler = MinMaxScaler(feature_range=(-1, 1))
features_std = scaler.fit_transform(features)
x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
random_state=42, shuffle=True)
return (jnp.asarray(x_train), jnp.asarray(y_train),
jnp.asarray(x_test) , jnp.asarray(y_test))
def cross_entropy_loss(logits, labels, num_classes):
y = jax.nn.one_hot(labels, num_classes=num_classes)
return optax.softmax_cross_entropy(logits=logits, labels=y).mean()
def compute_metrics(logits, labels, num_classes):
return {
"loss": cross_entropy_loss(logits, labels, num_classes),
"accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
}
def create_train_state(rng, n_qubits, depth, lr, feature_shape, output_dim):
qcl = QCL(n_qubits=n_qubits, depth=depth, output_dim=output_dim)
params = qcl.init(rng, jnp.ones((1, *feature_shape)))["params"]
tx = optax.adam(learning_rate=lr)
return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)
@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, n_qubits, depth, output_dim):
def loss_fn(params):
logits = QCL(n_qubits=n_qubits,
depth=depth,
output_dim=output_dim).apply({"params": params}, x)
loss = cross_entropy_loss(logits, y, output_dim)
return loss, logits
grad_fn = jax.grad(loss_fn, has_aux=True)
grads, logits = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
return state, metrics
@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, n_qubits, depth, output_dim):
logits = QCL(n_qubits=n_qubits,
depth=depth,
output_dim=output_dim).apply({"params": params}, x)
return compute_metrics(logits=logits, labels=y, num_classes=output_dim)
def main():
# Circuit
n_qubits = 4
depth = 2
# Learning
lr = 0.05
epochs = 100
# Data
num_classes = 3
x_train, y_train, x_test, y_test = load_data()
rng = jax.random.PRNGKey(0)
rng, rng_apply = jax.random.split(rng)
state = create_train_state(rng_apply,
n_qubits=n_qubits, depth=depth, lr=lr,
feature_shape=x_train.shape[1:],
output_dim=num_classes)
for e in range(epochs):
# Note: The dataset is small enough that we do not split it into mini-batches.
t = time.perf_counter()
state, metrics = train_step(state, x_train, y_train,
n_qubits=n_qubits, depth=depth,
output_dim=num_classes)
print(f"Epoch: {e:2d} [Train] Loss: {metrics['loss']:.4f}, " +
f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")
metrics = eval_step(state.params, x_test, y_test,
n_qubits=n_qubits, depth=depth,
output_dim=num_classes)
print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")
print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")
if __name__ == "__main__":
main()
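As a usage note (a sketch, not part of the example above): the bare `circuit` function can also be called outside Flax. The weight shape `(depth, n_qubits)` matches the `self.param("circuit", ...)` declaration in `QCL`, and the features are assumed to be already scaled into [-1, 1] as in `load_data`.

import jax.numpy as jnp
# assumes `circuit` and `op` from the example above are in scope
feats = jnp.array([0.3, -0.7])   # one toy sample with two features in [-1, 1]
w = jnp.zeros((2, 4))            # (depth, n_qubits) = (2, 4) as in main()
print(circuit(4, 2, feats, w))   # one Pauli-Z expectation per qubit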
02. CNN-like QCL#
"""
Example 02: CNN-like QCL classification with Flax
This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/
Warnings
--------
This implementation differs from QCNN [1],
because intermediate measurement and feed-forward are not easy to simulate with diffqc,
and because this example keeps the number of qubits small and the circuit shallow.
[1] I. Cong et al., "Quantum Convolutional Neural Networks",
Nature Phys. 15 1273-1278 (2019)
https://doi.org/10.1038/s41567-019-0648-8
https://arxiv.org/abs/1810.03787
"""
import functools
import time
from typing import Callable
import diffqc
from diffqc import dense as op
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from tqdm import tqdm
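# conv3x3cell: the 3x3 patch is flattened to 9 values and encoded as RY
# rotation angles on 9 qubits; each weight layer applies a CNOT chain
# followed by parameterized RY rotations, and the cell output is the
# Pauli-Z expectation of qubit 0.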
def conv3x3cell(x, w):
x = jnp.reshape(x, (9,))
q = op.zeros(9, jnp.complex64)
for i in range(9):
q = op.RY(q, (i,), x.at[i].get())
for k in range(w.shape[0]):
for i in range(8):
q = op.CNOT(q, (i, i+1))
for i in range(9):
q = op.RY(q, (i,), w.at[k, i].get())
return op.expectZ(q, (0,))
def ConvLayer(x, w):
# [0, 1) -> [-pi/2, pi/2)
x = jnp.arcsin(2 * (x - 0.5))
# convolution
F = diffqc.nn.Convolution(op, conv3x3cell,
kernel_shape = (3, 3),
slide = (1, 1),
padding = (1, 1))
x = F(x, w)
# pooling
x = diffqc.nn.MaxPooling(x, (2, 2))
return x
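# DenseLayer: the pooled feature map is flattened (here 4 x 4 -> 16 values)
# and re-encoded on one qubit per value with the same arcsin/arccos scheme
# as Example 01; each weight layer applies a CNOT chain and parameterized
# RY rotations, and the layer returns the expectation of |1> on every qubit.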
def DenseLayer(x, w):
x = jnp.reshape(x, (-1,))
q = op.zeros(x.shape[0], jnp.complex64)
for i in range(x.shape[0]):
q = op.RY(q, (i,), jnp.arcsin(x.at[i].get()))
q = op.RZ(q, (i,), jnp.arccos(x.at[i].get() ** 2))
for k in range(w.shape[0]):
for i in range(x.shape[0]-1):
q = op.CNOT(q, (i, i+1))
for i in range(x.shape[0]):
q = op.RY(q, (i,), w.at[k,i].get())
for i in range(x.shape[0]):
q = op.PauliZ(q, (i,))
p = diffqc.prob(op.to_state(q))
return jnp.stack(tuple(diffqc.expval(p, i) for i in range(x.shape[0])))
class ConvQCL(nn.Module):
cdepth: int
ddepth: int
output_dim: int
circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)
@nn.compact
def __call__(self, inputs):
x = inputs
wc = self.param("conv", self.circuit_init, (self.cdepth, 9))
wd = self.param("dense", self.circuit_init, (self.ddepth, 16))
@jax.vmap
def batch(xi):
# 8 x 8 -> 4 x 4
xi = ConvLayer(xi, wc)
xi = DenseLayer(xi, wd)
return xi
x = batch(x)
x = nn.Dense(self.output_dim)(x)
return x
def load_data():
features, labels = load_digits(return_X_y=True)
features_std = features.reshape((-1, 8, 8)) / 256.0
x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
random_state=42, shuffle=True)
return (jnp.asarray(x_train), jnp.asarray(y_train),
jnp.asarray(x_test) , jnp.asarray(y_test))
def cross_entropy_loss(logits, labels, num_classes):
y = jax.nn.one_hot(labels, num_classes)
return optax.softmax_cross_entropy(logits=logits, labels=y).mean()
def compute_metrics(logits, labels, num_classes):
return {
"loss": cross_entropy_loss(logits, labels, num_classes),
"accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
}
def create_train_state(rng, cdepth, ddepth, lr, input_shape, output_dim):
qcl = ConvQCL(cdepth=cdepth, ddepth=ddepth, output_dim=output_dim)
params = qcl.init(rng, jnp.ones(input_shape))["params"]
tx = optax.adam(learning_rate=lr)
return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)
@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, cdepth, ddepth, output_dim):
def loss_fn(params):
logits = ConvQCL(cdepth=cdepth,
ddepth=ddepth,
output_dim=output_dim).apply({"params": params}, x)
loss = cross_entropy_loss(logits, y, output_dim)
return loss, logits
grad_fn = jax.grad(loss_fn, has_aux=True)
grads, logits = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
return state, metrics
@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, cdepth, ddepth, output_dim):
logits = ConvQCL(cdepth=cdepth,
ddepth=ddepth,
output_dim=output_dim).apply({"params": params}, x)
return compute_metrics(logits=logits, labels=y, num_classes=output_dim)
def main():
# Circuit
cdepth = 2
ddepth = 2
# Learning
lr = 0.05
epochs = 50
batch_size = 16
# Data
num_classes = 10
x_train, y_train, x_test, y_test = load_data()
rng = jax.random.PRNGKey(0)
rng, rng_apply = jax.random.split(rng)
state = create_train_state(rng_apply,
cdepth=cdepth, ddepth=ddepth, lr=lr,
input_shape=(batch_size, *x_train.shape[1:]),
output_dim=num_classes)
train_ds_size = len(x_train)
steps_per_epoch = train_ds_size // batch_size
for e in range(epochs):
t = time.perf_counter()
rng, rng_apply = jax.random.split(rng)
perms = jax.random.permutation(rng_apply, train_ds_size)
perms = perms[:steps_per_epoch * batch_size]
perms = perms.reshape((steps_per_epoch, batch_size))
batch_metrics = []
for perm_idx in tqdm(perms, ascii=True):
state, metrics = train_step(state,
x_train.at[perm_idx].get(),
y_train.at[perm_idx].get(),
cdepth=cdepth, ddepth=ddepth,
output_dim=num_classes)
batch_metrics.append(metrics)
loss = jnp.mean(jnp.asarray(tuple(m["loss"] for m in batch_metrics)))
acc = jnp.mean(jnp.asarray(tuple(m["accuracy"] for m in batch_metrics)))
print(f"Epoch: {e:2d} [Train] Loss: {loss:.4f}, " +
f"Accuracy: {acc * 100:.2f}%", end=" ")
metrics = eval_step(state.params, x_test, y_test,
cdepth=cdepth, ddepth=ddepth,
output_dim=num_classes)
print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")
print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")
if __name__ == "__main__":
main()
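As a quick shape check (a sketch, not part of the example above; it reuses `ConvLayer` and the `(cdepth, 9)` weight shape declared in `ConvQCL`): with a 3x3 kernel, slide (1, 1) and padding (1, 1), the convolution keeps the 8x8 spatial size, and `MaxPooling` with a (2, 2) window then halves it, which is the "8 x 8 -> 4 x 4" step in `batch()`.

import jax.numpy as jnp
# assumes `ConvLayer` from the example above is in scope
x = jnp.full((8, 8), 0.5)       # toy 8x8 "image" with values in [0, 1)
w = jnp.zeros((2, 9))           # (cdepth, 9): one 3x3 cell per layer
print(ConvLayer(x, w).shape)    # expected: (4, 4)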
03. PennyLane plugin#
"""
Example 03: PennyLane plugin
Ref: https://pennylane.ai/qml/demos/tutorial_jax_transformations.html
"""
import pennylane as qml
def main():
dev = qml.device("diffqc.qubit", wires=2)
@qml.qnode(dev, interface="jax")
def circuit(param):
qml.RX(param, wires=0)
qml.CNOT(wires=[0,1])
return qml.expval(qml.PauliZ(0))
print(circuit(0.123))
if __name__ == "__main__":
main()
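As a small follow-up (a sketch, assuming the qnode behaves as in the JAX tutorial linked above): because `interface="jax"` makes the qnode a JAX-traceable function, its gradient is available through `jax.grad`.

import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("diffqc.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(param):
    qml.RX(param, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

# <Z0> = cos(param), so the gradient is -sin(param) (about -0.707 here)
print(jax.grad(circuit)(jnp.pi / 4))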
04. Builtin Parameterized Quantum Circuit (PQC): Circuit Centric Block#
"""
Example 04: Builtin Variational: Circuit Centric
This uses the Code Block with range = 1 described in [1].
According to [2], this variational circuit is one of the
best choices in terms of expressibility, entangling capability,
and circuit cost.
The rest of this example is the same as Example 01.
[1] M. Schuld et al., "Circuit-centric quantum classifiers",
Phys. Rev. A 101, 032308 (2020) (arXiv:1804.00633)
[2] S. Sim et al., "Expressibility and entangling capability of
parameterized quantum circuits for hybrid quantum-classical algorithms",
Adv. Quantum Technol. 2 (2019) 1900070 (arXiv:1905.10876)
This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/
"""
import functools
import time
from typing import Callable
from diffqc import dense as op
from diffqc.nn import CircuitCentricBlock
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
def circuit(n_qubits, depth, features, weights):
msg = "BUG: n_qubits ({}) must be greater than or equal to feature size ({})."
assert n_qubits >= features.shape[0], msg.format(n_qubits, features.shape[0])
q = op.zeros(n_qubits, jnp.complex64)
for idx in range(n_qubits):
i = idx % features.shape[0]
q = op.RY(q, (idx,), jnp.arcsin(features.at[i].get()))
q = op.RZ(q, (idx,), jnp.arccos(features.at[i].get() ** 2))
# Variational Gate: Circuit Centric Block
q = CircuitCentricBlock(op, q, tuple(i for i in range(n_qubits)), weights)
return jnp.stack(tuple(op.expectZ(q, (idx,)) for idx in range(n_qubits)))
class QCL(nn.Module):
n_qubits: int
depth: int
output_dim: int
circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)
@nn.compact
def __call__(self, inputs):
x = inputs
w = self.param("circuit", self.circuit_init, (self.depth, 3 * self.n_qubits))
@jax.vmap
def batch_circuit(e):
return circuit(self.n_qubits, self.depth, e, w)
x = batch_circuit(x)
x = nn.Dense(self.output_dim)(x)
return x
def load_data():
features, labels = load_iris(return_X_y=True)
features = features[:,2:] # use only petal length/width
scaler = MinMaxScaler(feature_range=(-1, 1))
features_std = scaler.fit_transform(features)
x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
random_state=42, shuffle=True)
return (jnp.asarray(x_train), jnp.asarray(y_train),
jnp.asarray(x_test) , jnp.asarray(y_test))
def cross_entropy_loss(logits, labels, num_classes):
y = jax.nn.one_hot(labels, num_classes=num_classes)
return optax.softmax_cross_entropy(logits=logits, labels=y).mean()
def compute_metrics(logits, labels, num_classes):
return {
"loss": cross_entropy_loss(logits, labels, num_classes),
"accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
}
def create_train_state(rng, n_qubits, depth, lr, feature_shape, output_dim):
qcl = QCL(n_qubits=n_qubits, depth=depth, output_dim=output_dim)
params = qcl.init(rng, jnp.ones((1, *feature_shape)))["params"]
tx = optax.adam(learning_rate=lr)
return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)
@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, n_qubits, depth, output_dim):
def loss_fn(params):
logits = QCL(n_qubits=n_qubits,
depth=depth,
output_dim=output_dim).apply({"params": params}, x)
loss = cross_entropy_loss(logits, y, output_dim)
return loss, logits
grad_fn = jax.grad(loss_fn, has_aux=True)
grads, logits = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
return state, metrics
@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, n_qubits, depth, output_dim):
logits = QCL(n_qubits=n_qubits,
depth=depth,
output_dim=output_dim).apply({"params": params}, x)
return compute_metrics(logits=logits, labels=y, num_classes=output_dim)
def main():
# Circuit
n_qubits = 4
depth = 3
# Learning
lr = 0.05
epochs = 100
# Data
num_classes = 3
x_train, y_train, x_test, y_test = load_data()
rng = jax.random.PRNGKey(0)
rng, rng_apply = jax.random.split(rng)
state = create_train_state(rng_apply,
n_qubits=n_qubits, depth=depth, lr=lr,
feature_shape=x_train.shape[1:],
output_dim=num_classes)
for e in range(epochs):
# Note: The dataset is small enough that we do not split it into mini-batches.
t = time.perf_counter()
state, metrics = train_step(state, x_train, y_train,
n_qubits=n_qubits, depth=depth,
output_dim=num_classes)
print(f"Epoch: {e:2d} [Train] Loss: {metrics['loss']:.4f}, " +
f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")
metrics = eval_step(state.params, x_test, y_test,
n_qubits=n_qubits, depth=depth,
output_dim=num_classes)
print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")
print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")
if __name__ == "__main__":
main()
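As a usage note (a minimal sketch, not part of the example above): `CircuitCentricBlock` can be applied directly to a state prepared with the dense operations, with weights of shape `(depth, 3 * n_qubits)` as declared in `QCL` above.

import jax.numpy as jnp
from diffqc import dense as op
from diffqc.nn import CircuitCentricBlock

n_qubits, depth = 4, 3
q = op.zeros(n_qubits, jnp.complex64)
w = jnp.zeros((depth, 3 * n_qubits))
q = CircuitCentricBlock(op, q, tuple(range(n_qubits)), w)
print(op.to_state(q).shape)   # (2**n_qubits,) state-vector amplitudes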
05. Builtin Parameterized Quantum Circuit (PQC): Josephson Sampler#
"""
Example 05: Builtin Variational: Josephson Sampler
This uses the Josephson Sampler described in [1].
According to [2], this variational circuit is one of the
best choices in terms of expressibility, entangling capability,
and circuit cost.
The rest of this example is the same as Example 01.
[1] M. R. Geller, "Sampling and scrambling on a chain of superconducting qubits",
Phys. Rev. Applied 10, 024052 (2018) (arXiv:1711.11026)
[2] S. Sim et al., "Expressibility and entangling capability of
parameterized quantum circuits for hybrid quantum-classical algorithms",
Adv. Quantum Technol. 2 (2019) 1900070 (arXiv:1905.10876)
This example additionally requires the following:
* Flax: https://flax.readthedocs.io/en/latest/index.html
* Optax: https://optax.readthedocs.io/en/latest/
* scikit-learn: https://scikit-learn.org/stable/
"""
import functools
import time
from typing import Callable
from diffqc import dense as op
from diffqc.nn import JosephsonSampler
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import optax
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
def circuit(n_qubits, depth, features, weights):
msg = "BUG: n_qubits ({}) must be greater than or equal to feature size ({})."
assert n_qubits >= features.shape[0], msg.format(n_qubits, features.shape[0])
q = op.zeros(n_qubits, jnp.complex64)
for idx in range(n_qubits):
i = idx % features.shape[0]
q = op.RY(q, (idx,), jnp.arcsin(features.at[i].get()))
q = op.RZ(q, (idx,), jnp.arccos(features.at[i].get() ** 2))
# Variational Gate: Josephson Sampler
q = JosephsonSampler(op, q, tuple(i for i in range(n_qubits)), weights)
return jnp.stack(tuple(op.expectZ(q, (idx,)) for idx in range(n_qubits)))
class QCL(nn.Module):
n_qubits: int
depth: int
output_dim: int
circuit_init: Callable = nn.initializers.uniform(2 * jnp.pi)
@nn.compact
def __call__(self, inputs):
x = inputs
w = self.param("circuit", self.circuit_init, (self.depth,
4 * (self.n_qubits - 1)))
@jax.vmap
def batch_circuit(e):
return circuit(self.n_qubits, self.depth, e, w)
x = batch_circuit(x)
x = nn.Dense(self.output_dim)(x)
return x
def load_data():
features, labels = load_iris(return_X_y=True)
features = features[:,2:] # use only petal length/width
scaler = MinMaxScaler(feature_range=(-1, 1))
features_std = scaler.fit_transform(features)
x_train, x_test, y_train, y_test = train_test_split(features_std, labels,
random_state=42, shuffle=True)
return (jnp.asarray(x_train), jnp.asarray(y_train),
jnp.asarray(x_test) , jnp.asarray(y_test))
def cross_entropy_loss(logits, labels, num_classes):
y = jax.nn.one_hot(labels, num_classes=num_classes)
return optax.softmax_cross_entropy(logits=logits, labels=y).mean()
def compute_metrics(logits, labels, num_classes):
return {
"loss": cross_entropy_loss(logits, labels, num_classes),
"accuracy": jnp.mean(jnp.argmax(logits, -1) == labels),
}
def create_train_state(rng, n_qubits, depth, lr, feature_shape, output_dim):
qcl = QCL(n_qubits=n_qubits, depth=depth, output_dim=output_dim)
params = qcl.init(rng, jnp.ones((1, *feature_shape)))["params"]
tx = optax.adam(learning_rate=lr)
return train_state.TrainState.create(apply_fn=qcl.apply, params=params, tx=tx)
@functools.partial(jax.jit, static_argnums=(3,4,5))
def train_step(state, x, y, n_qubits, depth, output_dim):
def loss_fn(params):
logits = QCL(n_qubits=n_qubits,
depth=depth,
output_dim=output_dim).apply({"params": params}, x)
loss = cross_entropy_loss(logits, y, output_dim)
return loss, logits
grad_fn = jax.grad(loss_fn, has_aux=True)
grads, logits = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
metrics = compute_metrics(logits=logits, labels=y, num_classes=output_dim)
return state, metrics
@functools.partial(jax.jit, static_argnums=(3,4,5))
def eval_step(params, x, y, n_qubits, depth, output_dim):
logits = QCL(n_qubits=n_qubits,
depth=depth,
output_dim=output_dim).apply({"params": params}, x)
return compute_metrics(logits=logits, labels=y, num_classes=output_dim)
def main():
# Circuit
n_qubits = 4
depth = 3
# Learning
lr = 0.05
epochs = 100
# Data
num_classes = 3
x_train, y_train, x_test, y_test = load_data()
rng = jax.random.PRNGKey(0)
rng, rng_apply = jax.random.split(rng)
state = create_train_state(rng_apply,
n_qubits=n_qubits, depth=depth, lr=lr,
feature_shape=x_train.shape[1:],
output_dim=num_classes)
for e in range(epochs):
# Note: The dataset is small enough that we do not split it into mini-batches.
t = time.perf_counter()
state, metrics = train_step(state, x_train, y_train,
n_qubits=n_qubits, depth=depth,
output_dim=num_classes)
print(f"Epoch: {e:2d} [Train] Loss: {metrics['loss']:.4f}, " +
f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")
metrics = eval_step(state.params, x_test, y_test,
n_qubits=n_qubits, depth=depth,
output_dim=num_classes)
print(f"[Eval] Loss: {metrics['loss']:.4f}, " +
f"Accuracy: {metrics['accuracy'] * 100:.2f}%", end=" ")
print(f"[Time] Elapsed: {time.perf_counter() - t:.4f}s")
if __name__ == "__main__":
main()
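As a usage note (a minimal sketch, not part of the example above): `JosephsonSampler` takes the same `(op, state, wires, weights)` arguments as `CircuitCentricBlock`, but its weights have shape `(depth, 4 * (n_qubits - 1))` as declared in `QCL` above.

import jax.numpy as jnp
from diffqc import dense as op
from diffqc.nn import JosephsonSampler

n_qubits, depth = 4, 3
q = op.zeros(n_qubits, jnp.complex64)
w = jnp.zeros((depth, 4 * (n_qubits - 1)))
q = JosephsonSampler(op, q, tuple(range(n_qubits)), w)
print(op.to_state(q).shape)   # (2**n_qubits,) state-vector amplitudes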