Error in Spark Pipeline

I am trying to run a multinomial logistic regression model. This is the code that prepares the data and builds the feature pipeline:

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('prepare_data').getOrCreate()

from pyspark.sql.types import *
spark.sql("DROP TABLE IF EXISTS customers")
spark.sql("CREATE TABLE customers (
            Customer_ID DOUBLE, 
            Name STRING, 
            Gender STRING, 
            Address STRING, 
            Nationality DOUBLE, 
            Account_Type STRING, 
            Age DOUBLE, 
            Education STRING, 
            Employment STRING, 
            Salary DOUBLE, 
            Employer_Stability STRING, 
            Customer_Loyalty DOUBLE, 
            Balance DOUBLE, 
            Residential_Status STRING, 
            Service_Level STRING)")
spark.sql("LOAD DATA LOCAL INPATH '../datasets/dummyTrain.csv' INTO TABLE 
            customers")

dataset = spark.table("customers")
cols = dataset.columns
display(dataset)

from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

categoricalColumns = ["Education", "Employment", "Employer_Stability", 
                      "Residential_Status"]
stages = [] 

for categoricalCol in categoricalColumns:
    # Index the string values, then one-hot encode the resulting indices
    stringIndexer = StringIndexer(inputCol=categoricalCol,
                                  outputCol=categoricalCol + "Index")
    encoder = OneHotEncoder(inputCol=categoricalCol + "Index",
                            outputCol=categoricalCol + "classVec")
    stages += [stringIndexer, encoder]

# Index the target column as the label
label_stringIdx = StringIndexer(inputCol="Service_Level", outputCol="label")
stages += [label_stringIdx]

numericCols = ["Age", "Salary", "Customer_Loyalty", "Balance"]
# Combine the encoded categorical vectors and the numeric columns into one feature vector
assemblerInputs = [c + "classVec" for c in categoricalColumns] + numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
stages += [assembler]

pipeline = Pipeline(stages=stages)
pipelineModel = pipeline.fit(dataset)
dataset = pipelineModel.transform(dataset)
selectedcols = ["label", "features"] + cols
dataset = dataset.select(selectedcols)
display(dataset)
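
For reference, this is the modeling step I plan to run once the pipeline works (a minimal sketch; the hyperparameter values are placeholders I have not tuned):

from pyspark.ml.classification import LogisticRegression

# Multinomial logistic regression on the assembled features
# (maxIter/regParam are placeholder values, not tuned)
lr = LogisticRegression(featuresCol="features", labelCol="label",
                        family="multinomial", maxIter=10, regParam=0.01)
lrModel = lr.fit(dataset)
predictions = lrModel.transform(dataset)
display(predictions.select("label", "prediction", "probability"))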

I am getting the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-31-07d2fb5cecc8> in <module>()
      4 # - fit() computes feature statistics as needed
      5 # - transform() actually transforms the features
----> 6 pipelineModel = pipeline.fit(dataset)
      7 dataset = pipelineModel.transform(dataset)
      8 

/srv/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
     62                 return self.copy(params)._fit(dataset)
     63             else:
---> 64                 return self._fit(dataset)
     65         else:
     66                 raise ValueError("Params must be either a param map or a list/tuple of param maps, "

/srv/spark/python/pyspark/ml/pipeline.py in _fit(self, dataset)
    109                     transformers.append(model)
    110                     if i < indexOfLastEstimator:
--> 111                         dataset = model.transform(dataset)
    112             else:
    113                 transformers.append(stage)

/srv/spark/python/pyspark/ml/base.py in transform(self, dataset, params)
    103                 return self.copy(params)._transform(dataset)
    104             else:
--> 105                 return self._transform(dataset)
    106         else:
    107             raise ValueError("Params must be a param map but got %s." % type(params))

/srv/spark/python/pyspark/ml/wrapper.py in _transform(self, dataset)
    250     def _transform(self, dataset):
    251         self._transfer_params_to_java()
--> 252         return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
    253 
    254 

/srv/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1131         answer = self.gateway_client.send_command(command)
   1132         return_value = get_return_value(
-> 1133             answer, self.gateway_client, self.target_id, self.name)
   1134 
   1135         for temp_arg in temp_args:

/srv/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/srv/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    317                 raise Py4JJavaError(
    318                     "An error occurred while calling {0}{1}{2}.\n".
--> 319                     format(target_id, ".", name), value)
    320             else:
    321                 raise Py4JError(

Py4JJavaError: An error occurred while calling o798.transform.
: java.lang.NullPointerException at 

I haven't been able to figure out what I did wrong, and it looks like the problem may be in the transform() method. Any help would be appreciated.
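
One thing I wondered about is whether NULL values in the table could be the cause (for example, if LOAD DATA pulled the CSV header row in as data, the DOUBLE columns would come out NULL). This sketch is how I would count NULLs per column to check that, assuming the table loaded at all:

from pyspark.sql.functions import col, count, when

# Count NULLs in each column of the raw table; a CSV header row loaded
# as data would show up as NULLs in the DOUBLE columns (sketch only)
raw = spark.table("customers")
null_counts = raw.select(
    [count(when(col(c).isNull(), c)).alias(c) for c in raw.columns])
display(null_counts)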
