Signed-off-by: Michael Hanselmann <[email protected]>
---
 daemons/import-export           |    9 ++++-
 lib/impexpd/__init__.py         |   22 +++++++++++--
 test/ganeti.impexpd_unittest.py |   68 +++++++++++++++++++++++++++++++++++++-
 3 files changed, 93 insertions(+), 6 deletions(-)

diff --git a/daemons/import-export b/daemons/import-export
index 0700339..f2bf874 100755
--- a/daemons/import-export
+++ b/daemons/import-export
@@ -362,6 +362,10 @@ def ParseOptions():
                     help="X509 CA file")
   parser.add_option("--bind", dest="bind", action="store", type="string",
                     help="Bind address")
+  parser.add_option("--ipv4", dest="ipv4", action="store_true",
+                    help="Use IPv4 only")
+  parser.add_option("--ipv6", dest="ipv6", action="store_true",
+                    help="Use IPv6 only")
   parser.add_option("--host", dest="host", action="store", type="string",
                     help="Remote hostname")
   parser.add_option("--port", dest="port", action="store", type="int",
@@ -401,7 +405,7 @@ def ParseOptions():
     parser.error("Invalid mode: %s" % mode)
 
   # Normalize and check parameters
-  if options.host is not None:
+  if options.host is not None and not netutils.IPAddress.IsValid(options.host):
     try:
       options.host = netutils.Hostname.GetNormalizedName(options.host)
     except errors.OpPrereqError, err:
@@ -423,6 +427,9 @@ def ParseOptions():
     parser.error("Magic must match regular expression %s" %
                  constants.IE_MAGIC_RE.pattern)
 
+  if options.ipv4 and options.ipv6:
+    parser.error("Can only use one of --ipv4 and --ipv6")
+
   return (status_file_path, mode)
 
 
diff --git a/lib/impexpd/__init__.py b/lib/impexpd/__init__.py
index 983640e..cf3a8ca 100644
--- a/lib/impexpd/__init__.py
+++ b/lib/impexpd/__init__.py
@@ -35,6 +35,7 @@ from cStringIO import StringIO
 from ganeti import constants
 from ganeti import errors
 from ganeti import utils
+from ganeti import netutils
 
 
 #: Used to recognize point at which socat(1) starts to listen on its socket.
@@ -144,6 +145,13 @@ class CommandBuilder(object):
     if self._opts.bind is not None:
       common_addr_opts.append("bind=%s" % self._opts.bind)
 
+    assert not (self._opts.ipv4 and self._opts.ipv6)
+
+    if self._opts.ipv4:
+      common_addr_opts.append("pf=ipv4")
+    elif self._opts.ipv6:
+      common_addr_opts.append("pf=ipv6")
+
     if self._mode == constants.IEM_IMPORT:
       if self._opts.port is None:
         port = 0
@@ -162,9 +170,14 @@ class CommandBuilder(object):
       addr2 = ["stdout"]
 
     elif self._mode == constants.IEM_EXPORT:
+      if self._opts.host and netutils.IP6Address.IsValid(self._opts.host):
+        host = "[%s]" % self._opts.host
+      else:
+        host = self._opts.host
+
       addr1 = ["stdin"]
       addr2 = [
-        "OPENSSL:%s:%s" % (self._opts.host, self._opts.port),
+        "OPENSSL:%s:%s" % (host, self._opts.port),
 
         # How long to wait per connection attempt
         "connect-timeout=%s" % self._opts.connect_timeout,
@@ -329,10 +342,13 @@ def _VerifyListening(family, address, port):
   """Verify address given as listening address by socat.
 
   """
-  # TODO: Implement IPv6 support
-  if family != socket.AF_INET:
+  if family not in (socket.AF_INET, socket.AF_INET6):
     raise errors.GenericError("Address family %r not supported" % family)
 
+  if (family == socket.AF_INET6 and address.startswith("[") and
+      address.endswith("]")):
+    address = address[1:-1]  # drop exactly one enclosing bracket pair
+
   try:
     packed_address = socket.inet_pton(family, address)
   except socket.error:
diff --git a/test/ganeti.impexpd_unittest.py b/test/ganeti.impexpd_unittest.py
index e9e3391..17c2489 100755
--- a/test/ganeti.impexpd_unittest.py
+++ b/test/ganeti.impexpd_unittest.py
@@ -25,6 +25,7 @@ import os
 import sys
 import re
 import unittest
+import socket
 
 from ganeti import constants
 from ganeti import objects
@@ -44,6 +45,8 @@ class CmdBuilderConfig(objects.ConfigObject):
     "ca",
     "host",
     "port",
+    "ipv4",
+    "ipv6",
     "compress",
     "magic",
     "connect_timeout",
@@ -101,10 +104,10 @@ class TestCommandBuilder(unittest.TestCase):
                   self.assert_(CheckCmdWord(cmd, comprcmd))
 
                 if cmd_prefix is not None:
-                  self.assert_(cmd_prefix in i for i in cmd)
+                  self.assert_(compat.any(cmd_prefix in i for i in cmd))
 
                 if cmd_suffix is not None:
-                  self.assert_(cmd_suffix in i for i in cmd)
+                  self.assert_(compat.any(cmd_suffix in i for i in cmd))
 
                 # Check socat command
                 socat_cmd = builder._GetSocatCommand()
@@ -118,6 +121,34 @@ class TestCommandBuilder(unittest.TestCase):
 
                 self.assert_("verify=1" in ssl_addr)
 
+  def testIPv6(self):  # "pf=" option selection in the socat command
+    for mode in [constants.IEM_IMPORT, constants.IEM_EXPORT]:
+      opts = CmdBuilderConfig(host="localhost", port=6789,
+                              ipv4=False, ipv6=False)  # expect no "pf="
+      builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+      cmd = builder._GetSocatCommand()
+      self.assert_(compat.all("pf=" not in i for i in cmd))
+
+      # IPv4 only: some socat argument must contain ",pf=ipv4"
+      opts = CmdBuilderConfig(host="localhost", port=6789,
+                              ipv4=True, ipv6=False)
+      builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+      cmd = builder._GetSocatCommand()
+      self.assert_(compat.any(",pf=ipv4" in i for i in cmd))
+
+      # IPv6 only: some socat argument must contain ",pf=ipv6"
+      opts = CmdBuilderConfig(host="localhost", port=6789,
+                              ipv4=False, ipv6=True)
+      builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+      cmd = builder._GetSocatCommand()
+      self.assert_(compat.any(",pf=ipv6" in i for i in cmd))
+
+      # IPv4 and IPv6 together must trip the builder's assertion
+      opts = CmdBuilderConfig(host="localhost", port=6789,
+                              ipv4=True, ipv6=True)
+      builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+      self.assertRaises(AssertionError, builder._GetSocatCommand)
+
   def testCommaError(self):
     opts = CmdBuilderConfig(host="localhost", port=1234,
                             ca="/some/path/with,a/,comma")
@@ -155,6 +186,39 @@ class TestCommandBuilder(unittest.TestCase):
     self.assertRaises(errors.GenericError, builder.GetCommand)
 
 
+class TestVerifyListening(unittest.TestCase):
+  def test(self):  # accepted input comes back as (address, port)
+    self.assertEqual(impexpd._VerifyListening(socket.AF_INET,
+                                              "192.0.2.7", 1234),
+                     ("192.0.2.7", 1234))
+    self.assertEqual(impexpd._VerifyListening(socket.AF_INET6, "::1", 9876),
+                     ("::1", 9876))
+    self.assertEqual(impexpd._VerifyListening(socket.AF_INET6, "[::1]", 4563),
+                     ("::1", 4563))  # brackets removed
+    self.assertEqual(impexpd._VerifyListening(socket.AF_INET6,
+                                              "[2001:db8::1:4563]", 4563),
+                     ("2001:db8::1:4563", 4563))  # brackets removed
+
+  def testError(self):  # every case must raise errors.GenericError
+    for family in [socket.AF_UNIX, socket.AF_INET, socket.AF_INET6]:
+      self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+                        family, "", 1234)  # empty address
+      self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+                        family, "192", 999)  # not a full address
+
+    for family in [socket.AF_UNIX, socket.AF_INET6]:
+      self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+                        family, "192.0.2.7", 1234)  # IPv4 literal
+      self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+                        family, "[2001:db8::1", 1234)  # unbalanced bracket
+      self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+                        family, "2001:db8::1]", 1234)  # unbalanced bracket
+
+    for family in [socket.AF_UNIX, socket.AF_INET]:
+      self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+                        family, "::1", 1234)  # IPv6 literal
+
+
 class TestCalcThroughput(unittest.TestCase):
   def test(self):
     self.assertEqual(impexpd._CalcThroughput([]), None)
-- 
1.7.3.1

Reply via email to