from pyspark.sql import Row

# Build a small lookup DataFrame of (id, name) rows.
# NOTE(review): `sc` (SparkContext) is assumed to be provided by the
# surrounding environment (e.g. a pyspark shell/notebook) — confirm.
A_DF = sc.parallelize([
    Row(id='A123', name='a1'),
    Row(id='A234', name='a2'),
]).toDF()
B_DF = sc.parallelize([Row(id='A123', pid='A234', ename='e1')]).toDF()
join_df = B_DF.join(A_DF, B_DF.id==A_DF.id).drop(B_DF.id)
final_join = join_df.join(A_DF, join_df.pid==A_DF.id)
final_join.collect()

What I expect is: `A123, a1, A234, a2, e1`.

It works if I simply change one line, as in the following version:

A_DF = sc.parallelize([Row(id='A123', name='a1'), Row(id='A234', name='a2')]).toDF()
B_DF = sc.parallelize([Row(id='A123', pid='A234', ename='e1')]).toDF()
join_df = B_DF.join(A_DF, B_DF.id==A_DF.id).drop(B_DF.id)
final_join = A_DF.join(join_df, join_df.pid==A_DF.id)
final_join.collect()

Here the only difference is that the second join is written as `A_DF.join(join_df, ...)` instead of `join_df.join(A_DF, ...)`. Does anyone have any idea why that makes it work?