🐛 Contracting in some orders on a circuit with a ring structure may give incorrect multiplication counts. #99

Open
Yonv1943 opened this issue Apr 8, 2023 · 2 comments
Labels
discussion code understanding help wanted Extra attention is needed

Comments

@Yonv1943
Collaborator

Yonv1943 commented Apr 8, 2023

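Contracting Sycamore circuits, as well as tensor grid and tensor ring circuits (anything with a ring structure), triggers a bug that makes the computed multiplication count wrong.
(The tensor train and tensor tree circuits we tested happen to have no ring structure.)

The bug is triggered only when the circuit has a ring structure and the tensor nodes are contracted in certain orders.

The code below was run on a small Sycamore circuit, NodesSycamoreN12M14; I found the bug by checking the output line by line.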

num_nodes             51
num_edges             99
ban_edges              0

Rough notes for now.

This is the print code:

'''calculate the multiple and avoid repeat'''
contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1]  # dimensions of the contracted node's adjacent tensors, and where they come from
contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1]  # which original nodes the contracted node is composed of
# assert contract_dims.shape == (num_nodes, )
# assert contract_bool.shape == (num_nodes, )

print(';;;', i, node_i0, node_i1)
print(node_dims_arys[node_i0].numpy().astype(int))
print(node_dims_arys[node_i1].numpy().astype(int))
print(contract_dims.numpy().astype(int))
print(contract_bool.numpy().astype(int))

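This is the printed output. It shows that an impossible contraction was performed on nodes that had already been contracted together, which produced extra multiplications.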

;;; 52 tensor(3) tensor(9)
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 352   0   0   0   0  96   0   0   0 128   0   0   0   0 320   0 128   0   0   0  64   0  64  64   0   0   0   0 192   0  64  64   0]
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 352   0   0   0   0  96   0   0   0 128   0   0   0   0 320   0 128   0   0   0  64   0  64  64   0   0   0   0 192   0  64  64   0]
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 704   0   0   0   0 192   0   0   0 256   0   0   0   0 640   0 256   0   0   0 128   0 128 128   0   0   0   0 384   0 128 128   0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 0 1 1 1 0 1 1 1 1 0 1 0 1 0 1 0 1 0 0 1 1 0 0 0 1 0 0 1]
@Yonv1943 Yonv1943 changed the title from "🐛 Shrinking in a certain order on a circuit with a loop can result in a incorrect multiplication count." to "🐛 contract in some orders on a circuit with ring struction may get incorrect multiplication counts." Apr 8, 2023
@Yonv1943
Collaborator Author

Yonv1943 commented Apr 8, 2023

The following code fixes this bug in Vanilla (single) mode.

The if_diff flag is what fixes it: it is False when node_i0 and node_i1 have already been merged into the same node.

First, choose a reasonable data type. It is safer to use float32 instead of int/long for node_dims_tens.

        node_dims_tens = th.stack([self.node_dims_ten.clone() for _ in range(num_envs)]).type(th.float32)
        node_bool_tens = th.stack([self.node_bool_ten.clone() for _ in range(num_envs)]).type(th.bool)
"""Vanilla (single)"""
for j in range(num_envs):
    edge_i = edge_is[j]
    node_dims_arys = node_dims_tens[j]
    node_bool_arys = node_bool_tens[j]

    '''find two nodes of an edge_i'''
    node_i0, node_i1 = th.where(edges_ary == edge_i)[0]  # find the two nodes at the ends of this edge
    # assert isinstance(node_i0.item(), int)
    # assert isinstance(node_i1.item(), int)

    '''whether node_i0 and node_i1 are different'''
    if_diff = th.logical_not(node_bool_arys[node_i0, node_i1])

    '''calculate the multiple and avoid repeat'''
    contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1] * if_diff  # dimensions of the contracted node's adjacent tensors, and where they come from
    contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1]  # which original nodes the contracted node is composed of
    # assert contract_dims.shape == (num_nodes, )
    # assert contract_bool.shape == (num_nodes, )

    # A contracted edge only needs to be counted once, so the exponents that were summed twice above are multiplied by 0.5
    mult_pow_time = contract_dims.sum(dim=0) - (contract_dims * contract_bool).sum(dim=0) * 0.5
    mult_pow_timess[j, i] = mult_pow_time * if_diff

    '''adjust two list: node_dims_arys, node_bool_arys'''
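    # If the two nodes are already the same node, `contract_bool & if_diff` is all False, so the next line changes nothing
    contract_dims[contract_bool & if_diff] = 0  # set the contracted edges' exponents to 0 (i.e. 2**0) so they no longer contribute to the multiplication count
    node_dims_tens[j, contract_bool] = contract_dims.repeat(1, 1)  # use the bool mask to refresh every merged node with the same information
    node_bool_tens[j, contract_bool] = contract_bool.repeat(1, 1)  # use the bool mask to refresh every merged node with the same information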
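
To make the effect of if_diff easier to see, here is a minimal pure-Python sketch of the same bookkeeping on a toy graph: a triangle 0-1-2 (the ring) with a tail edge 2-3. It is not the project's code. The helper names `make_graph`, `contract_edge` and `run` are hypothetical, every bond is assumed to have dimension 2**1, and node_bool is assumed to start as the identity matrix.

```python
def make_graph():
    """Toy graph: triangle 0-1-2 (a ring) plus a tail edge 2-3."""
    edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
    n = 4
    node_dims = [[0] * n for _ in range(n)]
    for a, b in edges:
        node_dims[a][b] = node_dims[b][a] = 1  # assumption: every bond has dimension 2**1
    # assumption: each node initially "contains" only itself
    node_bool = [[i == k for k in range(n)] for i in range(n)]
    return edges, node_dims, node_bool


def contract_edge(node_dims, node_bool, i0, i1, guard=True):
    """Contract the edge (i0, i1) and return log2 of the multiplication count."""
    n = len(node_dims)
    # if_diff is False when i0 and i1 were already merged into the same node
    if_diff = (not node_bool[i0][i1]) if guard else True

    dims = [node_dims[i0][k] + node_dims[i1][k] * if_diff for k in range(n)]
    merged = [node_bool[i0][k] or node_bool[i1][k] for k in range(n)]

    # bonds inside the merged group were added from both sides, so halve them
    inner = sum(d for d, m in zip(dims, merged) if m)
    mult_pow = (sum(dims) - 0.5 * inner) * if_diff

    if if_diff:  # contracted bonds no longer contribute
        dims = [0 if m else d for d, m in zip(dims, merged)]
    for k in range(n):  # refresh every member of the merged group
        if merged[k]:
            node_dims[k] = list(dims)
            node_bool[k] = list(merged)
    return mult_pow


def run(guard):
    edges, node_dims, node_bool = make_graph()
    return [contract_edge(node_dims, node_bool, a, b, guard) for a, b in edges]


print(run(guard=True))   # third edge closes the ring and is skipped
print(run(guard=False))  # third edge adds a spurious count and corrupts later steps
```

In this toy order, the third edge (0, 2) joins two nodes that already belong to the same merged node. Without the guard, the shared row is added to itself, which both reports extra multiplications and doubles the stored dimension of the tail bond, so the following contraction is wrong too.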

Vectorized version

"""Vectorized"""
'''find two nodes of an edge_i'''
vec_edges_ary: TEN = edges_ary[None, :, :]
vec_edges_is: TEN = edge_is[:, None, None]
res = th.where(vec_edges_ary == vec_edges_is)[1]
res = res.reshape((num_envs, 2))
node_i0s, node_i1s = res[:, 0], res[:, 1]
# assert node_i0s.shape == (num_envs, )
# assert node_i1s.shape == (num_envs, )

'''whether node_i0 and node_i1 are different'''
if_diffs = th.logical_not(node_bool_tens[vec_env_is, node_i0s, node_i1s])

'''calculate the multiple and avoid repeat'''
contract_dimss = node_dims_tens[vec_env_is, node_i0s] + node_dims_tens[
    vec_env_is, node_i1s] * if_diffs.unsqueeze(1)
contract_bools = node_bool_tens[vec_env_is, node_i0s] | node_bool_tens[vec_env_is, node_i1s]
# assert contract_dimss.shape == (num_envs, num_nodes)
# assert contract_bools.shape == (num_envs, num_nodes)

mult_pow_times = contract_dimss.sum(dim=1) - (contract_dimss * contract_bools).sum(dim=1) * 0.5
# assert mult_pow_times.shape == (num_envs, )
mult_pow_timess[:, i] = mult_pow_times * if_diffs

'''adjust two list: node_dims_arys, node_bool_arys'''
for j in range(num_envs):  # use the bool mask to refresh every merged node with the same information
    contract_dims = contract_dimss[j]
    contract_bool = contract_bools[j]

    contract_dims[contract_bool & if_diffs[j]] = 0  # set the contracted edges' exponents to 0 (i.e. 2**0) so they no longer contribute to the multiplication count
    node_dims_tens[j, contract_bool] = contract_dims.repeat(1, 1)
    node_bool_tens[j, contract_bool] = contract_bool.repeat(1, 1)

@YangletLiu YangletLiu added help wanted Extra attention is needed discussion code understanding labels Apr 8, 2023
@Yonv1943
Collaborator Author

Yonv1943 commented Apr 9, 2023

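After the fix, I also used the change of base formula, max_tmp_power / th.log2(th.tensor((10, ), device=device)), to deal with the overflow problem that appeared once this bug was fixed.

        # This multiplication count occasionally overflows even in float64, so divide by 2**max_tmp_power first, take log10, then restore it
        max_tmp_power = mult_pow_timess.max(dim=1)[0] - 960  # automatically set to `max - 960`; 960 is below the float64 exponent limit of 1024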
        multiple_times = (2 ** (mult_pow_timess - max_tmp_power.unsqueeze(1))).sum(dim=1)
        multiple_times = multiple_times.log10() + max_tmp_power / th.log2(th.tensor((10,), device=device))
        # max_tmp_power / th.log2(th.tensor((10, ), device=device))  # Change of Base Formula
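
The rescaling above can be illustrated without PyTorch. The sketch below uses only the standard library; the function name `log10_sum_pow2` and the `headroom` parameter are hypothetical, chosen to mirror the `max - 960` shift. It computes log10(sum of 2**p) by factoring out 2**shift and adding it back as shift / log2(10), which is the change of base step.

```python
import math


def log10_sum_pow2(powers, headroom=960):
    """Return log10(sum(2**p for p in powers)) without overflowing float64."""
    shift = max(powers) - headroom  # the largest term becomes 2**headroom
    scaled = sum(2.0 ** (p - shift) for p in powers)  # tiny terms underflow to 0.0
    # log10(2**shift) = shift * log10(2) = shift / log2(10)
    return math.log10(scaled) + shift / math.log2(10)


powers = [1500, 1499, 10]
try:
    sum(2.0 ** p for p in powers)  # the direct sum does not fit in float64
except OverflowError as err:
    print("direct sum failed:", err)
print(log10_sum_pow2(powers))
```

Terms far below the maximum underflow to zero after the shift, which is harmless here because they would not change the float64 sum anyway.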
