[ https://issues.apache.org/jira/browse/SPARK-40549?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17813780#comment-17813780 ]
Nicholas Chammas commented on SPARK-40549:
------------------------------------------

I think this is just a consequence of floating-point arithmetic being imprecise.

{code:python}
>>> for i in range(10):
...     o = Observation(f"test_{i}")
...     df_o = df.observe(o, F.corr("id", "id2"))
...     df_o.count()
...     print(o.get)
...
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0000000000000002}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0000000000000002}
{'corr(id, id2)': 0.9999999999999999}
{'corr(id, id2)': 1.0}
{code}

Unfortunately, {{corr}} appears to convert its inputs to floating point internally, so even if you pass it decimals you will get a similar result:

{code:python}
>>> from decimal import Decimal
>>> import pyspark.sql.functions as F
>>>
>>> df = spark.createDataFrame(
...     [(Decimal(i), Decimal(i * 10)) for i in range(10)],
...     schema="id decimal, id2 decimal",
... )
>>>
>>> for i in range(10):
...     o = Observation(f"test_{i}")
...     df_o = df.observe(o, F.corr("id", "id2"))
...     df_o.count()
...     print(o.get)
...
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 0.9999999999999999}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0000000000000002}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{'corr(id, id2)': 1.0}
{code}

I don't think there is anything that can be done here.

> PYSPARK: Observation computes the wrong results when using `corr` function
> ---------------------------------------------------------------------------
>
>                 Key: SPARK-40549
>                 URL: https://issues.apache.org/jira/browse/SPARK-40549
>             Project: Spark
>          Issue Type: Bug
>          Components: PySpark
>    Affects Versions: 3.3.0
>         Environment: {code:java}
> // lsb_release -a
> No LSB modules are available.
> Distributor ID: Ubuntu
> Description:    Ubuntu 22.04.1 LTS
> Release:        22.04
> Codename:       jammy
> {code}
> {code:java}
> // python -V
> python 3.10.4
> {code}
> {code:java}
> // lshw -class cpu
>   *-cpu
>        description: CPU
>        product: AMD Ryzen 9 3900X 12-Core Processor
>        vendor: Advanced Micro Devices [AMD]
>        physical id: f
>        bus info: cpu@0
>        version: 23.113.0
>        serial: Unknown
>        slot: AM4
>        size: 2194MHz
>        capacity: 4672MHz
>        width: 64 bits
>        clock: 100MHz
>        capabilities: lm fpu fpu_exception wp vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp x86-64 constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es cpufreq
>        configuration: cores=12 enabledcores=12 microcode=141561875 threads=24
> {code}
>            Reporter: Herminio Vazquez
>            Priority: Major
>              Labels: correctness
>
> A minimal description of the odd computation results: when creating a new `Observation` object and computing a simple correlation between two columns, the results appear to be non-deterministic.
> {code:python}
> # Init
> from pyspark.sql import SparkSession, Observation
> import pyspark.sql.functions as F
>
> df = spark.createDataFrame(
>     [(float(i), float(i * 10)) for i in range(10)],
>     schema="id double, id2 double",
> )
>
> for i in range(10):
>     o = Observation(f"test_{i}")
>     df_o = df.observe(o, F.corr("id", "id2").eqNullSafe(1.0))
>     df_o.count()
>     print(o.get)
>
> # Results
> {'(corr(id, id2) <=> 1.0)': False}
> {'(corr(id, id2) <=> 1.0)': False}
> {'(corr(id, id2) <=> 1.0)': False}
> {'(corr(id, id2) <=> 1.0)': True}
> {'(corr(id, id2) <=> 1.0)': True}
> {'(corr(id, id2) <=> 1.0)': True}
> {'(corr(id, id2) <=> 1.0)': True}
> {'(corr(id, id2) <=> 1.0)': True}
> {'(corr(id, id2) <=> 1.0)': True}
> {'(corr(id, id2) <=> 1.0)': False}
> {code}

--
This message was sent by Atlassian Jira (v8.20.10#820010)
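A note on reproducing the effect outside Spark: the non-determinism reported above is consistent with plain IEEE-754 double arithmetic, where the result of a sum depends on evaluation order, and Spark's aggregation order can vary with partitioning from run to run. The sketch below is a minimal plain-Python illustration (standard library only, not part of the original report); the `math.isclose` check is a hypothetical workaround for the fragile exact `eqNullSafe(1.0)` comparison:

```python
import math

# Floating-point addition is not associative: the same numbers summed in a
# different order can round differently in the last bit. Spark aggregates
# per partition, so the evaluation order (and hence the final ULP of the
# corr result) can change between runs.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)
print(left_to_right == right_to_left)  # False

# The observed corr values (1.0000000000000002, 0.9999999999999999) differ
# from 1.0 only in the last unit of precision, so an exact comparison such
# as eqNullSafe(1.0) is fragile. A tolerance-based check accepts all three:
for observed in (1.0, 1.0000000000000002, 0.9999999999999999):
    print(math.isclose(observed, 1.0, rel_tol=1e-9))  # True for all three
```

In PySpark terms this would mean observing the raw `F.corr("id", "id2")` value and applying the tolerance comparison on the driver, rather than pushing an exact `<=>` comparison into the observation itself.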