random_categorical returns float when it should return int #2337
Labels: bug (Unexpected behaviour that should be corrected (type))
Comments
Hi @0seba, could you please share a minimal reproducer?
Hi, it seems I rushed: the error is not caused by the cast but by the `greater` op. The following example crashes:

```python
import numpy as np
import coremltools as ct
import coremltools.converters.mil as mil
from coremltools.converters.mil import Builder as mb

BSZ = 1
QLEN = 1
VOCAB_SIZE = 151936

input_specs = [
    # mb.TensorSpec((1, VOCAB_SIZE, QLEN), mil.input_types.types.fp16),
    mb.TensorSpec((1, QLEN, VOCAB_SIZE), mil.input_types.types.fp32),  # logits
    mb.TensorSpec((1,), mil.input_types.types.fp32),  # temp
    mb.TensorSpec((1,), mil.input_types.types.fp32),  # top_p
]

@mb.program(
    input_specs=input_specs,
    opset_version=mil.builder.AvailableTarget.iOS18
)
def top_p_sample(logits, temp, top_p):
    # Temperature scaling: logits * (1 / temp)
    factor = mb.real_div(x=np.float32(1), y=temp)
    logits = mb.mul(x=logits, y=factor)
    probs = mb.softmax(x=logits, axis=-1)
    # Sort probabilities in ascending order, then accumulate them
    sorted_indices = mb.argsort(x=probs, axis=-1, ascending=True)
    sorted_probs = mb.gather(x=probs, indices=sorted_indices, batch_dims=2, axis=2)
    cumulative_probs = mb.cumsum(x=sorted_probs, axis=-1)
    # Boolean mask: positions whose cumulative probability exceeds 1 - top_p
    top_p_inv = mb.sub(x=np.float32(1), y=top_p)
    selection = mb.greater(x=cumulative_probs, y=top_p_inv)
    return selection

print(top_p_sample)

cml_model = ct.convert(
    top_p_sample,
    compute_units=ct.ComputeUnit.ALL,
    compute_precision=ct.precision.FLOAT32,
    minimum_deployment_target=ct.target.iOS18,
)

print(cml_model.predict({
    'logits': np.random.randn(1, 1, VOCAB_SIZE).astype(np.float32),
    'temp': np.float32([1]),
    'top_p': np.float32([0.1]),
}))
```

Whereas if I use
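For comparison, here is a plain-Python sketch of what the program above is meant to compute for a single row of logits, assuming the intended semantics are ordinary top-p (nucleus) masking. `top_p_mask` is a hypothetical helper, not part of coremltools; it only mirrors the op sequence (scale, softmax, ascending sort, cumsum, compare) so the expected mask can be checked independently of the converted model. Like `mb.greater`, it yields booleans, one per sorted position.

```python
import math

def top_p_mask(logits, temp, top_p):
    """Hypothetical reference: boolean top-p mask over ascending-sorted probabilities."""
    # Temperature scaling (mirrors real_div + mul in the MIL program)
    scaled = [x / temp for x in logits]
    # Numerically stable softmax over the vocabulary axis
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Ascending sort (mirrors argsort(ascending=True) followed by gather)
    sorted_probs = sorted(probs)
    # Running total (mirrors cumsum along the last axis)
    cumulative = []
    running = 0.0
    for p in sorted_probs:
        running += p
        cumulative.append(running)
    # Elementwise comparison against 1 - top_p; the result is boolean, like mb.greater
    threshold = 1.0 - top_p
    return [c > threshold for c in cumulative]

# Probabilities 0.1, 0.2, 0.7 give cumulative sums 0.1, 0.3, 1.0
print(top_p_mask([0.0, math.log(2), math.log(7)], temp=1.0, top_p=0.75))
```

With `top_p=0.75` the threshold is 0.25, so only the last two sorted positions are selected.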
I tried casting the output to int, but that fails.
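The expectation in the title follows from what categorical sampling produces: each draw selects a category index, so an integer output is the natural type. The plain-Python inverse-CDF sketch below is only an illustration of that idea, not coremltools' implementation of `random_categorical`; `sample_categorical` and its uniform-draw parameter `u` are hypothetical names.

```python
import bisect
import itertools

def sample_categorical(probs, u):
    """Hypothetical sketch: map a uniform draw u in [0, 1) to a category index."""
    # Cumulative distribution over category indices
    cdf = list(itertools.accumulate(probs))
    # Inverse-CDF sampling: first index whose cumulative probability exceeds u
    idx = bisect.bisect_right(cdf, u * cdf[-1])
    # Guard against rounding pushing the index past the last category
    return min(idx, len(probs) - 1)

probs = [0.1, 0.2, 0.7]
print([sample_categorical(probs, u) for u in (0.05, 0.25, 0.95)])
```

Each result is a Python `int` index into `probs`, never a float.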