diff --git a/src/qibojit/backends/cpu.py b/src/qibojit/backends/cpu.py index 4ae8dea..ce823ba 100644 --- a/src/qibojit/backends/cpu.py +++ b/src/qibojit/backends/cpu.py @@ -152,14 +152,21 @@ def _create_qubits_tensor(gate, nqubits): def _as_custom_matrix(self, gate): name = gate.__class__.__name__ + _matrix = getattr(self.custom_matrices, name) + if isinstance(gate, ParametrizedGate): - return getattr(self.custom_matrices, name)(*gate.parameters) - elif isinstance(gate, FusedGate): # pragma: no cover + if name == "GeneralizedRBS": # pragma: no cover + # this is tested in qibo tests + theta = gate.init_kwargs["theta"] + phi = gate.init_kwargs["phi"] + return _matrix(gate.init_args[0], gate.init_args[1], theta, phi) + return _matrix(*gate.parameters) + + if isinstance(gate, FusedGate): # pragma: no cover # fusion is tested in qibo tests return self.matrix_fused(gate) - else: - matrix = getattr(self.custom_matrices, name) - return matrix(2 ** len(gate.target_qubits)) if callable(matrix) else matrix + + return _matrix(2 ** len(gate.target_qubits)) if callable(_matrix) else _matrix def apply_gate(self, gate, state, nqubits): matrix = self._as_custom_matrix(gate)