commit d3780956a991d64056a4bade210fffca53c531a3
Author:     Roberto E. Vargas Caballero <k...@shike2.com>
AuthorDate: Fri Sep 22 23:10:30 2023 +0200
Commit:     Roberto E. Vargas Caballero <k...@shike2.com>
CommitDate: Tue Sep 26 11:33:45 2023 +0200

    ed: Deal with signals in a reliable way
    
    The signal handlers were calling longjmp(), but because the interrupted
    code could be in the middle of functions that are not async-signal-safe,
    the behaviour was very unpredictable, producing segmentation faults and
    deadlocks. This commit changes the signal handlers to only set a flag
    that is checked at safe places, such as inside long-running loops.

diff --git a/TODO b/TODO
index 7a21d8f..3b60785 100644
--- a/TODO
+++ b/TODO
@@ -32,7 +32,6 @@ ed
     g/^line/a \
     line1
     .
-* Signal handling is broken.
 * cat <<EOF | ed
     0a
     int radix = 16;
diff --git a/ed.c b/ed.c
index b7ab16f..b430e74 100644
--- a/ed.c
+++ b/ed.c
@@ -71,6 +71,8 @@ static struct undo udata;
 static int newcmd;
 int eol, bol;
 
+static volatile sig_atomic_t intr, hup;       /* volatile: written by signal handlers */
+
 static void
 discard(void)
 {
@@ -119,6 +121,17 @@ prevln(int line)
        return (line < 0) ? lastln : line;
 }
 
+/* string: release s's storage and reset it to the empty string; returns s */
+static String *
+string(String *s)
+{
+       s->cap = s->siz = 0;
+       free(s->str);
+       s->str = NULL;
+
+       return s;
+}
+
 static char *
 addchar(char c, String *s)
 {
@@ -136,6 +149,8 @@ addchar(char c, String *s)
        return t;
 }
 
+static void chksignals(void);  /* defined further down; input() and the loops below call it first */
+
 static int
 input(void)
 {
@@ -146,6 +161,9 @@ input(void)
 
        if ((c = getchar()) != EOF)
                addchar(c, &cmdline);
+
+       chksignals();
+
        return c;
 }
 
@@ -455,6 +473,8 @@ search(int way)
 
        i = curln;
        do {
+               chksignals();
+
                i = (way == '?') ? prevln(i) : nextln(i);
                if (i > 0 && match(i))
                        return i;
@@ -636,12 +656,66 @@ deflines(int def1, int def2)
                error("invalid address");
 }
 
+static void
+quit(void)     /* clearbuf(), then exit with the current exstatus */
+{
+       clearbuf();
+       exit(exstatus);
+}
+
+static void dowrite(const char *, int);
+/* dump: write line1..lastln to ./ed.hup, falling back to $HOME/ed.hup */
+static void
+dump(void)
+{
+       char *home;
+
+       line1 = nextln(0);
+       line2 = lastln;
+       /* NOTE(review): assumes an error inside dowrite() longjmps back to savesp */
+       if (!setjmp(savesp)) {
+               dowrite("ed.hup", 1);
+               return;
+       }
+       /* first write failed: chdir to $HOME (changes the process cwd) and retry */
+       home = getenv("HOME");
+       if (!home || chdir(home) < 0)
+               return;
+
+       if (!setjmp(savesp))    /* a second failure lands here and dump() returns */
+               dowrite("ed.hup", 1);
+}
+/* chksignals: act, at a safe point, on signals recorded by sigintr()/sighup() */
+static void
+chksignals(void)
+{
+       if (hup) {
+               hup = 0;        /* clear first: dump() -> dowrite() calls chksignals() again */
+               if (modflag)
+                       dump();
+               exstatus = 1;
+               quit();
+       }
+       if (intr) {
+               intr = 0;
+               clearerr(stdin);        /* a getchar() interrupted without SA_RESTART may set the error flag */
+               error("Interrupt");
+       }
+}
+
 static void
 dowrite(const char *fname, int trunc)
 {
-       FILE *fp;
        size_t bytecount = 0;
-       int i, r, line, sh;
+       int i, r, line;
+       FILE *aux;
+       static int sh;
+       static FILE *fp;
+
+       if (fp) {
+               sh ? pclose(fp) : fclose(fp);
+               fp = NULL;
+       }
 
        if(fname[0] == '!') {
                sh = 1;
@@ -656,6 +730,8 @@ dowrite(const char *fname, int trunc)
 
        line = curln;
        for (i = line1; i <= line2; ++i) {
+               chksignals();
+
                gettxt(i);
                bytecount += text.siz - 1;
                fwrite(text.str, 1, text.siz - 1, fp);
@@ -663,7 +739,9 @@ dowrite(const char *fname, int trunc)
 
        curln = line2;
 
-       r = sh ? pclose(fp) : fclose(fp);
+       aux = fp;
+       fp = NULL;
+       r = sh ? pclose(aux) : fclose(aux);
        if (r)
                error("input/output error");
        strcpy(savfname, fname);
@@ -691,6 +769,7 @@ doread(const char *fname)
 
        curln = line2;
        for (cnt = 0; (n = getline(&s, &len, fp)) > 0; cnt += (size_t)n) {
+               chksignals();
                if (s[n-1] != '\n') {
                        if (len == SIZE_MAX || !(p = realloc(s, ++len)))
                                error("out of memory");
@@ -718,6 +797,7 @@ doprint(void)
        if (line1 <= 0 || line2 > lastln)
                error("incorrect address");
        for (i = line1; i <= line2; ++i) {
+               chksignals();
                if (pflag == 'n')
                        printf("%d\t", i);
                for (s = gettxt(i); (c = *s) != '\n'; ++s) {
@@ -867,11 +947,11 @@ join(void)
 {
        int i;
        char *t, c;
-       String s;
+       static String s;
 
-       s.str = NULL;
-       s.siz = s.cap = 0;
+       string(&s);
        for (i = line1;; i = nextln(i)) {
+               chksignals();
                for (t = gettxt(i); (c = *t) != '\n'; ++t)
                        addchar(*t, &s);
                if (i == line2)
@@ -898,6 +978,7 @@ scroll(int num)
        if (max > lastln)
                max = lastln;
        for (cnt = line1; cnt < max; cnt++) {
+               chksignals();
                fputs(gettxt(ln), stdout);
                ln = nextln(ln);
        }
@@ -913,6 +994,7 @@ copy(int where)
        curln = where;
 
        while (line1 <= line2) {
+               chksignals();
                inject(gettxt(line1), AFTER);
                if (line2 >= curln)
                        line2 = nextln(line2);
@@ -922,13 +1004,6 @@ copy(int where)
        }
 }
 
-static void
-quit(void)
-{
-       clearbuf();
-       exit(exstatus);
-}
-
 static void
 execsh(void)
 {
@@ -939,7 +1014,7 @@ execsh(void)
        skipblank();
        if ((c = input()) != '!') {
                back(c);
-               cmd.siz = 0;
+               string(&cmd);
        } else if (cmd.siz) {
                --cmd.siz;
                repl = 1;
@@ -973,9 +1048,7 @@ getrhs(int delim)
        int c;
        static String s;
 
-       free(s.str);
-       s.str = NULL;
-       s.siz = s.cap = 0;
+       string(&s);
        while ((c = input()) != '\n' && c != EOF && c != delim)
                addchar(c, &s);
        addchar('\0', &s);
@@ -1079,8 +1152,10 @@ subline(int num, int nth)
        int i, m, changed;
        static String s;
 
-       i = changed = s.siz = 0;
+       string(&s);
+       i = changed = 0;
        for (m = match(num); m; m = rematch(num)) {
+               chksignals();
                addpre(&s);
                changed |= addsub(&s, nth, ++i);
                if (eol || bol)
@@ -1099,8 +1174,10 @@ subst(int nth)
 {
        int i;
 
-       for (i = line1; i <= line2; ++i)
+       for (i = line1; i <= line2; ++i) {
+               chksignals();
                subline(i, nth);
+       }
 }
 
 static void
@@ -1362,6 +1439,7 @@ chkglobal(void)
        compile(delim);
 
        for (i = 1; i <= lastln; ++i) {
+               chksignals();
                if (i >= line1 && i <= line2)
                        v = match(i) == dir;
                else
@@ -1378,13 +1456,14 @@ doglobal(void)
        int cnt, ln, k;
 
        skipblank();
-       cmdline.siz = 0;
+       string(&cmdline);
        gflag = 1;
        if (uflag)
                chkprint(0);
 
        ln = line1;
        for (cnt = 0; cnt < lastln; ) {
+               chksignals();
                k = getindex(ln);
                if (zero[k].global) {
                        zero[k].global = 0;
@@ -1414,30 +1493,13 @@ usage(void)
 static void
 sigintr(int n)
 {
-       signal(SIGINT, sigintr);
-       error("interrupt");
+       intr = 1;       /* only record it; chksignals() raises the error outside the handler */
 }

 static void
 sighup(int dummy)
 {
-       int n;
-       char *home = getenv("HOME"), fname[FILENAME_MAX];
-
-       if (modflag) {
-               line1 = nextln(0);
-               line2 = lastln;
-               if (!setjmp(savesp)) {
-                       dowrite("ed.hup", 1);
-               } else if (home && !setjmp(savesp)) {
-                       n = snprintf(fname,
-                                    sizeof(fname), "%s/%s", home, "ed.hup");
-                       if (n < sizeof(fname) && n > 0)
-                               dowrite(fname, 1);
-               }
-       }
-       exstatus = 1;
-       quit();
+       hup = 1;        /* only record it; chksignals() dumps the buffer and exits */
 }
 
 static void
@@ -1492,9 +1554,15 @@ main(int argc, char *argv[])
                usage();
 
        if (!setjmp(savesp)) {
-               signal(SIGINT, sigintr);
-               signal(SIGHUP, sighup);
-               signal(SIGQUIT, SIG_IGN);
+               sigaction(SIGINT,
+                         &(struct sigaction) {.sa_handler = sigintr},
+                         NULL);
+               sigaction(SIGHUP,
+                         &(struct sigaction) {.sa_handler = sighup},
+                         NULL);
+               sigaction(SIGQUIT,
+                         &(struct sigaction) {.sa_handler = SIG_IGN},
+                         NULL);
                init(*argv);
        }
        edit();

Reply via email to