Repository: incubator-zeppelin Updated Branches: refs/heads/master c2f9a0174 -> d5ab911bf
Fixing issue with ZEPPELIN-173: Zeppelin websocket server is vulnerab… Fixing the socket cross-origin vulnerability as described in the Jira. Overrode the checkOrigin in the WebSocketServlet class implemented by NotebookServer so that a list of all seen socket Get requests is kept and only Upgrade requests from the same origin will be accepted. Otherwise unauthorized will be returned. Included basic unit tests. Author: joelz <[email protected]> Author: djoelz <[email protected]> Closes #205 from djoelz/master and squashes the following commits: 08ff369 [djoelz] unecessary file 013f22d [joelz] Fixing issue with ZEPPELIN-173: Zeppelin websocket server is vulnerable to Cross-Site WebSocket Hijacking ea54b55 [joelz] Fixing issue with ZEPPELIN-173: Zeppelin websocket server is vulnerable to Cross-Site WebSocket Hijacking Project: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/commit/d5ab911b Tree: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/tree/d5ab911b Diff: http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/diff/d5ab911b Branch: refs/heads/master Commit: d5ab911bf4419fa7c6f38945c6c8ad4946f8abf6 Parents: c2f9a01 Author: joelz <[email protected]> Authored: Thu Aug 13 11:31:15 2015 -0700 Committer: Lee moon soo <[email protected]> Committed: Sat Aug 15 09:14:09 2015 -0700 ---------------------------------------------------------------------- .../apache/zeppelin/socket/NotebookServer.java | 111 +++--- .../zeppelin/socket/NotebookServerTests.java | 53 +++ .../zeppelin/socket/TestHttpServletRequest.java | 372 +++++++++++++++++++ 3 files changed, 486 insertions(+), 50 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/d5ab911b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java ---------------------------------------------------------------------- diff 
--git a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java index 07265d5..8c8b600 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java @@ -14,19 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.zeppelin.socket; - import java.io.IOException; -import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; - import javax.servlet.http.HttpServletRequest; - import org.apache.zeppelin.display.AngularObject; import org.apache.zeppelin.display.AngularObjectRegistry; import org.apache.zeppelin.display.AngularObjectRegistryListener; @@ -46,28 +44,46 @@ import org.eclipse.jetty.websocket.WebSocketServlet; import org.quartz.SchedulerException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import com.google.common.base.Strings; import com.google.gson.Gson; - /** * Zeppelin websocket service. 
* * @author anthonycorbacho */ public class NotebookServer extends WebSocketServlet implements - NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener { - + NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener { private static final Logger LOG = LoggerFactory - .getLogger(NotebookServer.class); - + .getLogger(NotebookServer.class); Gson gson = new Gson(); - Map<String, List<NotebookSocket>> noteSocketMap = new HashMap<String, List<NotebookSocket>>(); - List<NotebookSocket> connectedSockets = new LinkedList<NotebookSocket>(); + final Map<String, List<NotebookSocket>> noteSocketMap = new HashMap<>(); + final List<NotebookSocket> connectedSockets = new LinkedList<>(); private Notebook notebook() { return ZeppelinServer.notebook; } + @Override + public boolean checkOrigin(HttpServletRequest request, String origin) { + URI sourceUri = null; + String currentHost = null; + + try { + sourceUri = new URI(origin); + currentHost = java.net.InetAddress.getLocalHost().getHostName(); + } catch (UnknownHostException e) { + e.printStackTrace(); + } + catch (URISyntaxException e) { + e.printStackTrace(); + } + + String sourceHost = sourceUri.getHost(); + if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) { + return true; + } + + return false; + } @Override public WebSocket doWebSocketConnect(HttpServletRequest req, String protocol) { @@ -153,8 +169,7 @@ public class NotebookServer extends WebSocketServlet implements } private Message deserializeMessage(String msg) { - Message m = gson.fromJson(msg, Message.class); - return m; + return gson.fromJson(msg, Message.class); } private String serializeMessage(Message m) { @@ -164,14 +179,13 @@ public class NotebookServer extends WebSocketServlet implements private void addConnectionToNote(String noteId, NotebookSocket socket) { synchronized (noteSocketMap) { removeConnectionFromAllNote(socket); // make sure a socket relates only a - // single note. + // single note. 
List<NotebookSocket> socketList = noteSocketMap.get(noteId); if (socketList == null) { - socketList = new LinkedList<NotebookSocket>(); + socketList = new LinkedList<>(); noteSocketMap.put(noteId, socketList); } - - if (socketList.contains(socket) == false) { + if (!socketList.contains(socket)) { socketList.add(socket); } } @@ -212,6 +226,7 @@ public class NotebookServer extends WebSocketServlet implements } } } + return id; } @@ -235,9 +250,7 @@ public class NotebookServer extends WebSocketServlet implements if (socketLists == null || socketLists.size() == 0) { return; } - LOG.info("SEND >> " + m.op); - for (NotebookSocket conn : socketLists) { try { conn.send(serializeMessage(m)); @@ -267,13 +280,14 @@ public class NotebookServer extends WebSocketServlet implements private void broadcastNoteList() { Notebook notebook = notebook(); List<Note> notes = notebook.getAllNotes(); - List<Map<String, String>> notesInfo = new LinkedList<Map<String, String>>(); + List<Map<String, String>> notesInfo = new LinkedList<>(); for (Note note : notes) { - Map<String, String> info = new HashMap<String, String>(); + Map<String, String> info = new HashMap<>(); info.put("id", note.id()); info.put("name", note.getName()); notesInfo.add(info); } + broadcastAll(new Message(OP.NOTES_INFO).put("notes", notesInfo)); } @@ -283,8 +297,8 @@ public class NotebookServer extends WebSocketServlet implements if (noteId == null) { return; } - Note note = notebook.getNote(noteId); + Note note = notebook.getNote(noteId); if (note != null) { addConnectionToNote(note.id(), conn); conn.send(serializeMessage(new Message(OP.NOTE).put("note", note))); @@ -304,17 +318,17 @@ public class NotebookServer extends WebSocketServlet implements if (config == null) { return; } + Note note = notebook.getNote(noteId); if (note != null) { boolean cronUpdated = isCronUpdated(config, note.getConfig()); note.setName(name); note.setConfig(config); - if (cronUpdated) { notebook.refreshCron(note.id()); } - note.persist(); + 
note.persist(); broadcastNote(note); broadcastNoteList(); } @@ -331,17 +345,19 @@ public class NotebookServer extends WebSocketServlet implements } else if (configA.get("cron") != null || configB.get("cron") != null) { cronUpdated = true; } + return cronUpdated; } - private void createNote(WebSocket conn, Notebook notebook, Message message) throws IOException { Note note = notebook.createNote(); note.addParagraph(); // it's an empty note. so add one paragraph if (message != null) { String noteName = (String) message.get("name"); - if (noteName != null && !noteName.isEmpty()) + if (noteName != null && !noteName.isEmpty()){ note.setName(noteName); + } } + note.persist(); broadcastNote(note); broadcastNoteList(); @@ -353,6 +369,7 @@ public class NotebookServer extends WebSocketServlet implements if (noteId == null) { return; } + Note note = notebook.getNote(noteId); notebook.removeNote(noteId); removeNote(noteId); @@ -365,6 +382,7 @@ public class NotebookServer extends WebSocketServlet implements if (paragraphId == null) { return; } + Map<String, Object> params = (Map<String, Object>) fromMessage .get("params"); Map<String, Object> config = (Map<String, Object>) fromMessage @@ -385,6 +403,7 @@ public class NotebookServer extends WebSocketServlet implements if (paragraphId == null) { return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); /** We dont want to remove the last paragraph */ if (!note.isLastParagraph(paragraphId)) { @@ -400,7 +419,6 @@ public class NotebookServer extends WebSocketServlet implements String buffer = (String) fromMessage.get("buf"); int cursor = (int) Double.parseDouble(fromMessage.get("cursor").toString()); Message resp = new Message(OP.COMPLETION_LIST).put("id", paragraphId); - if (paragraphId == null) { conn.send(serializeMessage(resp)); return; @@ -414,10 +432,10 @@ public class NotebookServer extends WebSocketServlet implements /** * When angular object updated from client - * - * @param conn - * @param notebook - * @param 
fromMessage + * + * @param conn the web socket. + * @param notebook the notebook. + * @param fromMessage the message. */ private void angularObjectUpdated(WebSocket conn, Notebook notebook, Message fromMessage) { @@ -425,10 +443,8 @@ public class NotebookServer extends WebSocketServlet implements String interpreterGroupId = (String) fromMessage.get("interpreterGroupId"); String varName = (String) fromMessage.get("name"); Object varValue = fromMessage.get("value"); - AngularObject ao = null; boolean global = false; - // propagate change to (Remote) AngularObjectRegistry Note note = notebook.getNote(noteId); if (note != null) { @@ -438,11 +454,9 @@ public class NotebookServer extends WebSocketServlet implements if (setting.getInterpreterGroup() == null) { continue; } - if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) { AngularObjectRegistry angularObjectRegistry = setting .getInterpreterGroup().getAngularObjectRegistry(); - // first trying to get local registry ao = angularObjectRegistry.get(varName, noteId); if (ao == null) { @@ -460,14 +474,13 @@ public class NotebookServer extends WebSocketServlet implements ao.set(varValue, false); global = false; } - break; } } } if (global) { // broadcast change to all web session that uses related - // interpreter. + // interpreter. 
for (Note n : notebook.getAllNotes()) { List<InterpreterSetting> settings = note.getNoteReplLoader() .getInterpreterSettings(); @@ -475,7 +488,6 @@ public class NotebookServer extends WebSocketServlet implements if (setting.getInterpreterGroup() == null) { continue; } - if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) { AngularObjectRegistry angularObjectRegistry = setting .getInterpreterGroup().getAngularObjectRegistry(); @@ -514,8 +526,7 @@ public class NotebookServer extends WebSocketServlet implements private void insertParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final int index = (int) Double.parseDouble(fromMessage.get("index") - .toString()); - + .toString()); final Note note = notebook.getNote(getOpenNoteId(conn)); note.insertParagraph(index); note.persist(); @@ -540,24 +551,25 @@ public class NotebookServer extends WebSocketServlet implements if (paragraphId == null) { return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); Paragraph p = note.getParagraph(paragraphId); String text = (String) fromMessage.get("paragraph"); p.setText(text); p.setTitle((String) fromMessage.get("title")); Map<String, Object> params = (Map<String, Object>) fromMessage - .get("params"); + .get("params"); p.settings.setParams(params); Map<String, Object> config = (Map<String, Object>) fromMessage - .get("config"); + .get("config"); p.setConfig(config); - // if it's the last paragraph, let's add a new one boolean isTheLastParagraph = note.getLastParagraph().getId() .equals(p.getId()); if (!Strings.isNullOrEmpty(text) && isTheLastParagraph) { note.addParagraph(); } + note.persist(); broadcastNote(note); try { @@ -580,7 +592,6 @@ public class NotebookServer extends WebSocketServlet implements public static class ParagraphJobListener implements JobListener { private NotebookServer notebookServer; private Note note; - public ParagraphJobListener(NotebookServer notebookServer, Note note) { 
this.notebookServer = notebookServer; this.note = note; @@ -605,6 +616,7 @@ public class NotebookServer extends WebSocketServlet implements LOG.error("Error", job.getException()); } } + if (job.isTerminated()) { LOG.info("Job {} is finished", job.getId()); try { @@ -613,6 +625,7 @@ public class NotebookServer extends WebSocketServlet implements e.printStackTrace(); } } + notebookServer.broadcastNote(note); } } @@ -621,7 +634,6 @@ public class NotebookServer extends WebSocketServlet implements public JobListener getParagraphJobListener(Note note) { return new ParagraphJobListener(this, note); } - private void pong() { } @@ -666,10 +678,8 @@ public class NotebookServer extends WebSocketServlet implements List<InterpreterSetting> intpSettings = note.getNoteReplLoader() .getInterpreterSettings(); - if (intpSettings.isEmpty()) continue; - for (InterpreterSetting setting : intpSettings) { if (setting.getInterpreterGroup().getId().equals(interpreterGroupId)) { broadcast( @@ -698,9 +708,10 @@ public class NotebookServer extends WebSocketServlet implements broadcast( note.id(), new Message(OP.ANGULAR_OBJECT_REMOVE).put("name", name).put( - "noteId", noteId)); + "noteId", noteId)); } } } } } + http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/d5ab911b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java ---------------------------------------------------------------------- diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java new file mode 100644 index 0000000..3ab06f0 --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java @@ -0,0 +1,53 @@ +/** + * Created by joelz on 8/6/15. + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.socket; + +import org.apache.zeppelin.notebook.Note; +import org.apache.zeppelin.server.ZeppelinServer; +import org.junit.Assert; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.IOException; +import java.net.UnknownHostException; + +/** + * BASIC Zeppelin rest api tests + * + * + * @author joelz + * + */ + public class NotebookServerTests { + + @Test + public void CheckOrigin() throws UnknownHostException { + NotebookServer server = new NotebookServer(); + Assert.assertTrue(server.checkOrigin(new TestHttpServletRequest(), + "http://" + java.net.InetAddress.getLocalHost().getHostName() + ":8080")); + } + + @Test + public void CheckInvalidOrigin(){ + NotebookServer server = new NotebookServer(); + Assert.assertFalse(server.checkOrigin(new TestHttpServletRequest(), "http://evillocalhost:8080")); + } +} http://git-wip-us.apache.org/repos/asf/incubator-zeppelin/blob/d5ab911b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java ---------------------------------------------------------------------- diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java 
b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java new file mode 100644 index 0000000..9ec54ba --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java @@ -0,0 +1,372 @@ +/** + * Created by joelz on 8/6/15. + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.socket; + +import javax.servlet.*; +import javax.servlet.http.*; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.security.Principal; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Locale; +import java.util.Map; + +/** + * Created by joelz on 8/6/15. 
+ * Helps mocking a http servlet request + */ +public class TestHttpServletRequest implements HttpServletRequest { + @Override + public boolean authenticate(HttpServletResponse httpServletResponse) throws IOException, ServletException { + return false; + } + + @Override + public String getAuthType() { + return null; + } + + @Override + public String getContextPath() { + return null; + } + + @Override + public Cookie[] getCookies() { + return new Cookie[0]; + } + + @Override + public long getDateHeader(String s) { + return 0; + } + + @Override + public String getHeader(String s) { + switch (s) { + case "Origin": + return "http://localhost:8080"; + } + + return null; + } + + @Override + public Enumeration<String> getHeaderNames() { + return null; + } + + @Override + public Enumeration<String> getHeaders(String s) { + return null; + } + + @Override + public int getIntHeader(String s) { + return 0; + } + + @Override + public String getMethod() { + return null; + } + + @Override + public Part getPart(String s) throws IOException, ServletException { + return null; + } + + @Override + public Collection<Part> getParts() throws IOException, ServletException { + return null; + } + + @Override + public String getPathInfo() { + return null; + } + + @Override + public String getPathTranslated() { + return null; + } + + @Override + public String getQueryString() { + return null; + } + + @Override + public String getRemoteUser() { + return null; + } + + @Override + public String getRequestedSessionId() { + return null; + } + + @Override + public String getRequestURI() { + return null; + } + + @Override + public StringBuffer getRequestURL() { + return null; + } + + @Override + public String getServletPath() { + return null; + } + + @Override + public HttpSession getSession() { + return null; + } + + @Override + public HttpSession getSession(boolean b) { + return null; + } + + @Override + public Principal getUserPrincipal() { + return null; + } + + @Override + public boolean 
isRequestedSessionIdFromCookie() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromUrl() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromURL() { + return false; + } + + @Override + public boolean isRequestedSessionIdValid() { + return false; + } + + @Override + public boolean isUserInRole(String s) { + return false; + } + + @Override + public void login(String s, String s1) throws ServletException { + + } + + @Override + public void logout() throws ServletException { + + } + + @Override + public AsyncContext getAsyncContext() { + return null; + } + + @Override + public Object getAttribute(String s) { + return null; + } + + @Override + public Enumeration<String> getAttributeNames() { + return null; + } + + @Override + public String getCharacterEncoding() { + return null; + } + + @Override + public int getContentLength() { + return 0; + } + + @Override + public String getContentType() { + return null; + } + + @Override + public DispatcherType getDispatcherType() { + return null; + } + + @Override + public ServletInputStream getInputStream() throws IOException { + return null; + } + + @Override + public String getLocalAddr() { + return null; + } + + @Override + public Locale getLocale() { + return null; + } + + @Override + public Enumeration<Locale> getLocales() { + return null; + } + + @Override + public String getLocalName() { + return null; + } + + @Override + public int getLocalPort() { + return 0; + } + + @Override + public String getParameter(String s) { + return null; + } + + @Override + public Map<String, String[]> getParameterMap() { + return null; + } + + @Override + public Enumeration<String> getParameterNames() { + return null; + } + + @Override + public String[] getParameterValues(String s) { + return new String[0]; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public BufferedReader getReader() throws IOException { + return null; + } + + @Override + public 
String getRealPath(String s) { + return null; + } + + @Override + public String getRemoteAddr() { + return null; + } + + @Override + public String getRemoteHost() { + return null; + } + + @Override + public int getRemotePort() { + return 0; + } + + @Override + public RequestDispatcher getRequestDispatcher(String s) { + return null; + } + + @Override + public String getScheme() { + return null; + } + + @Override + public String getServerName() { + return null; + } + + @Override + public int getServerPort() { + return 0; + } + + @Override + public ServletContext getServletContext() { + return null; + } + + @Override + public boolean isAsyncStarted() { + return false; + } + + @Override + public boolean isAsyncSupported() { + return false; + } + + @Override + public boolean isSecure() { + return false; + } + + @Override + public void removeAttribute(String s) { + + } + + @Override + public void setAttribute(String s, Object o) { + + } + + @Override + public void setCharacterEncoding(String s) throws UnsupportedEncodingException { + + } + + @Override + public AsyncContext startAsync() { + return null; + } + + @Override + public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) { + return null; + } +}
