GBTRegressor to ONNX output looks like Classifier #703

Open
MelihAkdag opened this issue Oct 21, 2024 · 0 comments

I ran into an issue when using "convert_sparkml" to export a GBTRegressor model to ONNX.
When I predict with the ONNX model, the output dtype is int64. I wonder whether the conversion step turns the regression model into a classifier.

Here are the simplified steps to reproduce the issue:

import numpy as np
from pyspark.ml.regression import GBTRegressor
from onnxmltools import convert_sparkml
from onnxmltools.convert.common.data_types import FloatTensorType

# Create the GBT regressor object
gbm = GBTRegressor(featuresCol='features', labelCol='label')

# Train the model on the training data
gbm_model = gbm.fit(train_df)

# Four feature columns of float type
initial_types = [('features', FloatTensorType([None, 4]))]

# Convert the trained model (gbm_model) to ONNX
onnx_model = convert_sparkml(gbm_model, 'GBT Regressor Model', initial_types, spark_session=spark)

# Save the ONNX model to a file
with open("gbt_model.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

import onnxruntime as rt

# Load the ONNX model
sess = rt.InferenceSession("gbt_model.onnx", providers=["CPUExecutionProvider"])

# Prepare the input as a numpy array
input_data = np.array([[1.0, 0.0, 5.3, 259.9]], dtype=np.float32)

# Run the model
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
predictions = sess.run([label_name], {input_name: input_data})

print("Predicted value:", predictions)

And the output is:
Predicted value: [array([1], dtype=int64)]
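
The run above only fetches the first output of the session. To see whether the converted model exposes other outputs (for example, a label next to a float score), it may help to list every output with its declared type. This is a minimal sketch: the helper name `describe_outputs` is hypothetical, and it only relies on onnxruntime's `get_outputs()`, whose entries carry `name`, `type` and `shape`.

```python
def describe_outputs(session):
    """Return (name, type, shape) for every output of an onnxruntime session.

    Hypothetical helper for this issue; `session` is expected to be an
    onnxruntime.InferenceSession (anything with a get_outputs() method works).
    """
    return [(out.name, out.type, out.shape) for out in session.get_outputs()]


# Usage with the session from the steps above (assumes `sess` already exists):
# for name, dtype, shape in describe_outputs(sess):
#     print(name, dtype, shape)
```

If more than one output is listed, passing `None` instead of `[label_name]` to `sess.run` returns all of them.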

I tested the same steps with RFRegressor instead of GBTRegressor, and the output was what I expected.
I would appreciate it if you could check convert_sparkml for GBTRegressor.
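
One way to check the classifier suspicion directly might be to count the operator types in the exported graph. The sketch below is an assumption about how to inspect it, not a statement about what the converter emits: `count_op_types` is a hypothetical helper, and `TreeEnsembleRegressor` / `TreeEnsembleClassifier` are the ONNX-ML tree operators I would look for in the result.

```python
from collections import Counter


def count_op_types(model):
    """Count how often each operator type appears in an ONNX model graph.

    Hypothetical helper; `model` is expected to be an onnx.ModelProto
    (anything exposing model.graph.node with an op_type per node works).
    """
    return Counter(node.op_type for node in model.graph.node)


# Usage with the file saved above (requires the `onnx` package):
# import onnx
# counts = count_op_types(onnx.load("gbt_model.onnx"))
# print(counts)
```

Seeing a `TreeEnsembleClassifier` node and no `TreeEnsembleRegressor` node would support the idea that the model is being converted as a classifier.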
