Reconsider checking for input alias during function calls #1026

Open
ricardoV94 opened this issue Oct 10, 2024 · 0 comments · May be fixed by #1024

ricardoV94 commented Oct 10, 2024

Description

While trying to simplify Function.__call__ (see #1024 and #222), I noticed some complicated logic that checks whether inputs marked as mutable (or borrowable) alias the same memory as one another.

if (
    not self.trust_input
    and
    # The getattr is only needed for old pickle
    getattr(self, "_check_for_aliased_inputs", True)
):
    # Collect aliased inputs among the storage space
    args_share_memory = []
    for i in range(len(self.input_storage)):
        i_var = self.maker.inputs[i].variable
        i_val = self.input_storage[i].storage[0]
        if hasattr(i_var.type, "may_share_memory"):
            is_aliased = False
            for j in range(len(args_share_memory)):
                group_j = zip(
                    [
                        self.maker.inputs[k].variable
                        for k in args_share_memory[j]
                    ],
                    [
                        self.input_storage[k].storage[0]
                        for k in args_share_memory[j]
                    ],
                )
                if any(
                    (
                        var.type is i_var.type
                        and var.type.may_share_memory(val, i_val)
                    )
                    for (var, val) in group_j
                ):
                    is_aliased = True
                    args_share_memory[j].append(i)
                    break
            if not is_aliased:
                args_share_memory.append([i])
    # Check for groups of more than one argument that share memory
    for group in args_share_memory:
        if len(group) > 1:
            # copy all but the first
            for j in group[1:]:
                self.input_storage[j].storage[0] = copy.copy(
                    self.input_storage[j].storage[0]
                )

To avoid erroneous computation, __call__ tries to copy aliased inputs. However, this logic is wrong: it assumes that only variables with the same type can be aliased, which doesn't make sense. In the example below, a matrix and a vector are aliased. The check misses the alias, so the function returns wrong values and corrupts the input y, which was not marked as mutable.

import pytensor
import pytensor.tensor as pt
from pytensor import In
import numpy as np

x = pt.vector()
y = pt.matrix()

fn = pytensor.function([In(x, mutable=True), In(y, mutable=False)], [x * 2, y * 2])
fn.dprint(print_destroy_map=True)
# Mul [id A] d={0: [1]} 0
#  ├─ [2.] [id B]
#  └─ <Vector(float64, shape=(?,))> [id C]
# Mul [id D] d={0: [1]} 1
#  ├─ [[2.]] [id E]
#  └─ <Matrix(float64, shape=(?, ?))> [id F]

y_val = np.ones((2, 5))
x_val = y_val[0]  # x is an alias of y
res1, res2 = fn(x_val, y_val)
print(res1)
# [2. 2. 2. 2. 2.]

print(res2)  # Wrong
# [[4. 4. 4. 4. 4.]
#  [2. 2. 2. 2. 2.]]

print(y_val)  # Corrupted
# [[2. 2. 2. 2. 2.]
#  [1. 1. 1. 1. 1.]]
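
The failure mode can be reproduced without PyTensor. The following is a minimal NumPy-only sketch, not the actual Function.__call__ code path: it assumes that comparing ndim is a fair stand-in for the `var.type is i_var.type` guard, and it emulates the in-place Mul on x with a plain `*=`.

```python
import numpy as np

y_val = np.ones((2, 5))
x_val = y_val[0]  # a view into the first row of y_val

# NumPy reports the overlap even though the arrays have different ndim.
alias_detected = np.may_share_memory(x_val, y_val)

# Stand-in (assumption) for the "same type" guard in the existing check:
# a vector and a matrix never pass it, so the overlap test is skipped.
same_ndim = x_val.ndim == y_val.ndim

# Emulate the in-place doubling of the mutable input x ...
x_val *= 2
# ... followed by the computation that reads y.
res2 = y_val * 2

print(alias_detected, same_ndim)
print(res2)
print(y_val)
```

The overlap is detectable, but a guard that first requires identical types never gets to ask.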

My suggestion is not to make the alias check more robust (and therefore increase the Function call overhead), but instead to forgo it completely. If users indicate that an input is mutable, it shouldn't be too surprising that views of that input (or other variables sharing the same underlying memory) would also be corrupted.
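
If the check is dropped, callers who pass views alongside a mutable input could guard against this themselves. Here is a rough sketch of such a guard; `copy_if_aliased` is a hypothetical helper written for illustration, not part of PyTensor. It relies on `np.may_share_memory`, which only compares memory bounds and can therefore report false positives, at the cost of an occasional unnecessary copy.

```python
import numpy as np


def copy_if_aliased(mutable_arg, *other_args):
    """Hypothetical helper: copy mutable_arg if it may overlap any other argument."""
    if any(np.may_share_memory(mutable_arg, other) for other in other_args):
        return mutable_arg.copy()
    return mutable_arg


y_val = np.ones((2, 5))
x_val = y_val[0]  # x is an alias of y

# The copy breaks the alias, so in-place work on it cannot touch y_val.
safe_x = copy_if_aliased(x_val, y_val)
safe_x *= 2

# An unrelated array is passed through untouched (no copy is made).
z_val = np.zeros(5)
same_z = copy_if_aliased(z_val, y_val)
```

This keeps the cost of the check with the callers who actually pass overlapping arrays, rather than adding it to every function call.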

@ricardoV94 ricardoV94 changed the title Reconsider checking for input alias Reconsider checking for input alias during function calls Oct 10, 2024