Russell Jurney created DATAFU-150:
-------------------------------------

             Summary: Add MultiLabelOneHotEncoder
                 Key: DATAFU-150
                 URL: https://issues.apache.org/jira/browse/DATAFU-150
             Project: DataFu
          Issue Type: Improvement
            Reporter: Russell Jurney
            Assignee: Russell Jurney


I have created the following code in Python to one-hot encode multilabel data 
and would like to add it to DataFu:

{{
# Turn each (x[0], x[1]) pair into a Row with _Body and _Tags columns.
questions_tags = filtered_lists.map(
    lambda x: Row(_Body=x[0], _Tags=x[1])
).toDF()

# One-hot-encode the multilabel tags: every tag gets a column index equal to
# its position in the sorted list of tags.
# NOTE(review): assumes remaining_tags_df holds one row per distinct tag; a
# repeated tag would be assigned several indices -- confirm upstream.
enumerated_labels = list(
    enumerate(
        sorted(
            # A plain map extracts the tag of each row; grouping every row
            # under one key first only adds a shuffle and yields the same
            # values once sorted.
            remaining_tags_df.rdd
            .map(lambda row: row.tag)
            .collect()
        )
    )
)
# Lookups in both directions: tag -> index and index -> tag.
tag_index = {x: i for i, x in enumerated_labels}
index_tag = {i: x for i, x in enumerated_labels}

def one_hot_encode(tag_list, enumerated_labels):
    """Encode the tags of one record as a 0/1 vector.

    PySpark can't one-hot-encode multilabel data, so we do it ourselves.

    Args:
        tag_list: Container of the tags attached to one record; each label
            is checked against it with ``in``.
        enumerated_labels: Iterable of ``(index, label)`` pairs. The output
            has one entry per pair, in the same order.

    Returns:
        A list of ints with ``1`` where the pair's label is in ``tag_list``
        and ``0`` otherwise.
    """
    # Use the label carried by each pair rather than a module-level lookup
    # table, so the result depends only on the arguments. The comprehension
    # emits exactly one entry per pair, so no length check is needed.
    return [1 if label in tag_list else 0 for _, label in enumerated_labels]

# Replace each question's tag list with its one-hot vector, keeping _Body.
one_hot_questions = questions_tags.rdd.map(
    lambda question: Row(
        _Body=question._Body,
        _Tags=one_hot_encode(question._Tags, enumerated_labels),
    )
)

# Explicit schema for the encoded rows: _Body is declared as an array of
# strings and _Tags as an array of integers.
schema = T.StructType(
    [
        T.StructField("_Body", T.ArrayType(T.StringType())),
        T.StructField("_Tags", T.ArrayType(T.IntegerType())),
    ]
)

# Build the DataFrame from the encoded rows and print a sample of it.
one_hot_df = spark.createDataFrame(one_hot_questions, schema)
one_hot_df.show()
}}





--
This message was sent by Atlassian Jira
(v8.3.4#803005)

Reply via email to