Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor IonTrapQubitPass to use Visitor pattern #562

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 23 additions & 117 deletions quantum/plugins/iontrap/transformations/IonTrapTwoQubitPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "xacc_service.hpp"
#include "IonTrapTwoQubitPass.hpp"
#include "Accelerator.hpp"
#include "IonTrapTwoQubitPassVisitor.hpp"

namespace xacc {
namespace quantum {
Expand All @@ -23,6 +24,10 @@ namespace quantum {
// Two-qubit decompositions
//

IonTrapTwoQubitPass::IonTrapTwoQubitPass()
{
}

std::pair<double, double> IonTrapTwoQubitPass::findMSPhases(IonTrapMSPhaseMap *msPhases, InstPtr cnot) {
std::size_t leftIdx = std::min(cnot->bits()[0], cnot->bits()[1]);
std::size_t rightIdx = std::max(cnot->bits()[0], cnot->bits()[1]);
Expand Down Expand Up @@ -53,134 +58,35 @@ void IonTrapTwoQubitPass::apply(std::shared_ptr<CompositeInstruction> program,
logTransCallback = options.get<IonTrapLogTransformCallback>("log-trans-cb");
}

auto gateRegistry = xacc::getService<IRProvider>("quantum");
auto _twoQubitPassVisitor = std::make_shared<IonTrapTwoQubitPassVisitor>();

iontrapFlattenComposite(program);

HeterogeneousMap paramsMap{std::make_pair("composite", program),
std::make_pair("options", options)};

_twoQubitPassVisitor->_paramsMap = paramsMap;

for (std::size_t instIdx = 0; instIdx < program->nInstructions();) {
InstPtr inst = program->getInstruction(instIdx);
if (!inst->isEnabled()) {
instIdx++;
continue;
}

if (inst->name() == "CNOT") {
auto [controlMSPhase, targetMSPhase] = findMSPhases(msPhases, inst);
InstPtr ry1 = gateRegistry->createInstruction("Ry", {inst->bits()[0]}, {-M_PI/2.0});
InstPtr xx = gateRegistry->createInstruction("XX", inst->bits(), {M_PI/4.0});
InstPtr ry2 = gateRegistry->createInstruction("Ry", {inst->bits()[0]}, {M_PI/2.0});
InstPtr rz = gateRegistry->createInstruction("Rz", {inst->bits()[0]}, {M_PI/2.0});
InstPtr rx = gateRegistry->createInstruction("Rx", {inst->bits()[1]}, {M_PI/2.0});

std::size_t i = instIdx;
program->insertInstruction(i++, ry1);
// TODO: Note that this is kind of incorrect: really, the combination of these Rz gates
// and an MS gate is actually an XX gate (see https://doi.org/10.1088/1367-2630/18/2/023048)
// but we are surrounding an XX instruction with Rz instructions. But this will
// work for now
if (controlMSPhase) {
InstPtr msRz1 = gateRegistry->createInstruction("Rz", {inst->bits()[0]}, {controlMSPhase});
program->insertInstruction(i++, msRz1);
}
if (targetMSPhase) {
InstPtr msRz2 = gateRegistry->createInstruction("Rz", {inst->bits()[1]}, {targetMSPhase});
program->insertInstruction(i++, msRz2);
}
program->insertInstruction(i++, xx);
if (controlMSPhase) {
InstPtr msRz3 = gateRegistry->createInstruction("Rz", {inst->bits()[0]}, {-controlMSPhase});
program->insertInstruction(i++, msRz3);
}
if (targetMSPhase) {
InstPtr msRz4 = gateRegistry->createInstruction("Rz", {inst->bits()[1]}, {-targetMSPhase});
program->insertInstruction(i++, msRz4);
}
program->insertInstruction(i++, ry2);
program->insertInstruction(i++, rz);
program->insertInstruction(i++, rx);

if (logTransCallback) {
std::vector<InstPtr> newInsts;
for (std::size_t j = instIdx; j < i; j++) {
newInsts.push_back(program->getInstruction(j));
}
logTransCallback({inst}, newInsts);
}
} else if (inst->name() == "CH") {
InstPtr s = gateRegistry->createInstruction("S", {inst->bits()[1]});
InstPtr h = gateRegistry->createInstruction("H", {inst->bits()[1]});
InstPtr t = gateRegistry->createInstruction("T", {inst->bits()[1]});
InstPtr cx = gateRegistry->createInstruction("CNOT", inst->bits());
InstPtr tdg = gateRegistry->createInstruction("Tdg", {inst->bits()[1]});
InstPtr h2 = gateRegistry->createInstruction("H", {inst->bits()[1]});
InstPtr sdg = gateRegistry->createInstruction("Sdg", {inst->bits()[1]});

program->insertInstruction(instIdx, s);
program->insertInstruction(instIdx+1, h);
program->insertInstruction(instIdx+2, t);
program->insertInstruction(instIdx+3, cx);
program->insertInstruction(instIdx+4, tdg);
program->insertInstruction(instIdx+5, h2);
program->insertInstruction(instIdx+6, sdg);

if (logTransCallback) {
logTransCallback({program->getInstruction(instIdx+7)},
{program->getInstruction(instIdx),
program->getInstruction(instIdx+1),
program->getInstruction(instIdx+2),
program->getInstruction(instIdx+3),
program->getInstruction(instIdx+4),
program->getInstruction(instIdx+5),
program->getInstruction(instIdx+6)});
}
} else if (inst->name() == "CY") {
InstPtr sdg = gateRegistry->createInstruction("Sdg", {inst->bits()[1]});
InstPtr cx = gateRegistry->createInstruction("CNOT", inst->bits());
InstPtr s = gateRegistry->createInstruction("S", {inst->bits()[1]});

program->insertInstruction(instIdx, sdg);
program->insertInstruction(instIdx+1, cx);
program->insertInstruction(instIdx+2, s);

if (logTransCallback) {
logTransCallback({program->getInstruction(instIdx+3)},
{program->getInstruction(instIdx),
program->getInstruction(instIdx+1),
program->getInstruction(instIdx+2)});
}
} else if (inst->name() == "CZ") {
InstPtr h = gateRegistry->createInstruction("H", {inst->bits()[1]});
InstPtr cx = gateRegistry->createInstruction("CNOT", inst->bits());
InstPtr h2 = gateRegistry->createInstruction("H", {inst->bits()[1]});

program->insertInstruction(instIdx, h);
program->insertInstruction(instIdx+1, cx);
program->insertInstruction(instIdx+2, h2);

if (logTransCallback) {
logTransCallback({program->getInstruction(instIdx+3)},
{program->getInstruction(instIdx),
program->getInstruction(instIdx+1),
program->getInstruction(instIdx+2)});
}
} else if (inst->name() == "Swap") {
InstPtr cx1 = gateRegistry->createInstruction("CNOT", inst->bits());
InstPtr cx2 = gateRegistry->createInstruction("CNOT", {inst->bits()[1], inst->bits()[0]});
InstPtr cx3 = gateRegistry->createInstruction("CNOT", inst->bits());

program->insertInstruction(instIdx, cx1);
program->insertInstruction(instIdx+1, cx2);
program->insertInstruction(instIdx+2, cx3);

if (logTransCallback) {
logTransCallback({program->getInstruction(instIdx+3)},
{program->getInstruction(instIdx),
program->getInstruction(instIdx+1),
program->getInstruction(instIdx+2)});
}
} else {
_twoQubitPassVisitor->initializeInstructionVisitor(instIdx);

inst->attachMetadata({{"composite", program},
{"options", options}});

//std::cout << "instruction index is " << instIdx << ", name is " << inst->name() << std::endl;
inst->accept(_twoQubitPassVisitor);
//std::cout << "instruction index is " << instIdx << ", name is " << inst->name() << std::endl;

if (!_twoQubitPassVisitor->instructionVisited())
{
instIdx++;
continue;
continue;
}

inst->disable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ namespace xacc {
namespace quantum {

typedef std::map<std::pair<std::size_t, std::size_t>, std::pair<double, double>> IonTrapMSPhaseMap;
class IonTrapTwoQubitPassVisitor;

class IonTrapTwoQubitPass : public IRTransformation {
public:
IonTrapTwoQubitPass() {}
IonTrapTwoQubitPass();
void apply(std::shared_ptr<CompositeInstruction> program,
const std::shared_ptr<Accelerator> accelerator,
const HeterogeneousMap &options = {}) override;
Expand Down
Loading