For reference, below is what I'd hoped to be able to do with
copy_from_user(). [Obviously it can't just blindly replace that
ALTERNATIVE_2 setup in _copy_from_user ... we would have to invent
an ALTERNATIVE_3 so that people willing to sacrifice speed for
recoverability could pick the new function.]
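
Expressed in C for clarity, the dispatch I have in mind would be
something like the sketch below. This is only a sketch: the real
thing would be ALTERNATIVE_* patching in copy_user_64.S, and
X86_FEATURE_MCE_RECOVERY is a made-up feature bit standing in for
"willing to sacrifice speed for recoverability".

unsigned long _copy_from_user_sketch(void *to, const void __user *from,
				     unsigned n)
{
	/* made-up feature bit -- would need to be defined for real */
	if (static_cpu_has(X86_FEATURE_MCE_RECOVERY))
		return copy_from_user_with_mce_check(to, from, n);

	/* otherwise take the existing fast path */
	return copy_user_generic(to, (__force void *)from, n);
}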

BUT ... there is a maze of twisty little other places that
also need to be fixed:

1) There is iov_iter_fault_in_readable(), which uses fault_in_pages_readable()
   to pre-fault the first (and last) page with a raw __get_user() call
   (sketched below).
2) I tried to avoid that by injecting my error into the middle of a page,
   but still ended up in a "rep movs" copy function instead of mine. Not
   sure which one, because there was no output from the core that died. :-(
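
For reference, fault_in_pages_readable() (include/linux/pagemap.h)
looks roughly like this -- note the raw __get_user() on the first
byte, and on the last byte when the range crosses a page boundary.
A machine check taken in either of those bypasses my new copy
function entirely:

static inline int fault_in_pages_readable(const char __user *uaddr, int size)
{
	volatile char c;
	int ret;

	if (unlikely(size == 0))
		return 0;

	/* raw access to the first byte of the range */
	ret = __get_user(c, uaddr);
	if (ret == 0) {
		const char __user *end = uaddr + size - 1;

		/* and to the last byte, if it lands on a different page */
		if (((unsigned long)uaddr & PAGE_MASK) !=
				((unsigned long)end & PAGE_MASK))
			ret = __get_user(c, end);
	}
	return ret;
}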

-Tony

diff --git a/arch/x86/lib/copy_user_64.S b/arch/x86/lib/copy_user_64.S
index 982ce34f4a9b..e31d8964ac09 100644
--- a/arch/x86/lib/copy_user_64.S
+++ b/arch/x86/lib/copy_user_64.S
@@ -38,11 +38,15 @@ ENTRY(_copy_from_user)
        jc bad_from_user
        cmpq TI_addr_limit(%rax),%rcx
        ja bad_from_user
+#if 1
+       jmp copy_from_user_with_mce_check
+#else
        ALTERNATIVE_2 "jmp copy_user_generic_unrolled",         \
                      "jmp copy_user_generic_string",           \
                      X86_FEATURE_REP_GOOD,                     \
                      "jmp copy_user_enhanced_fast_string",     \
                      X86_FEATURE_ERMS
+#endif
 ENDPROC(_copy_from_user)
 
        .section .fixup,"ax"
diff --git a/arch/x86/lib/usercopy_64.c b/arch/x86/lib/usercopy_64.c
index 0a42327a59d7..c377a70474e0 100644
--- a/arch/x86/lib/usercopy_64.c
+++ b/arch/x86/lib/usercopy_64.c
@@ -6,6 +6,8 @@
  * Copyright 2002 Andi Kleen <[email protected]>
  */
 #include <linux/module.h>
+#include <linux/mm.h>
+#include <asm/traps.h>
 #include <asm/uaccess.h>
 
 /*
@@ -86,3 +88,25 @@ copy_user_handle_tail(char *to, char *from, unsigned len)
                memset(to, 0, len);
        return len;
 }
+
+__visible unsigned long
+copy_from_user_with_mce_check(void *to, const void __user *from, unsigned n)
+{
+       struct memcpy_trap_ret r;
+
+       stac();
+       r = memcpy_trap(to, (__force void *)from, n);
+       clac();
+
+       if (likely(r.bytes_left == 0))
+               return 0;
+
+       if (r.trap_nr == X86_TRAP_MC) {
+               volatile void *fault_addr = (volatile void *)from + n - r.bytes_left;
+               phys_addr_t p = virt_to_phys(fault_addr);
+
+               memory_failure(p >> PAGE_SHIFT, MCE_VECTOR, 0);
+       }
+
+       return r.bytes_left;
+}
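
For anyone without the rest of the series handy: memcpy_trap() is
the primitive added by an earlier patch, not something in mainline.
The shape it needs to have for the code above to work is roughly
(exact types aside):

/* assumed shape -- see the earlier patch for the real definition */
struct memcpy_trap_ret {
	u64	trap_nr;	/* trap that ended the copy: X86_TRAP_MC, X86_TRAP_PF, ... */
	u64	bytes_left;	/* bytes NOT copied; 0 means the whole copy succeeded */
};

/* like memcpy(), but recovers from traps via fixup entries instead of oopsing */
struct memcpy_trap_ret memcpy_trap(void *dst, const void *src, size_t len);

That lets bytes_left be returned directly as the usual
copy_from_user() remainder when the trap was an ordinary #PF.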
