Signed-off-by: Michael Hanselmann <[email protected]>
---
daemons/import-export | 9 ++++-
lib/impexpd/__init__.py | 22 +++++++++++--
test/ganeti.impexpd_unittest.py | 68 +++++++++++++++++++++++++++++++++++++-
3 files changed, 93 insertions(+), 6 deletions(-)
diff --git a/daemons/import-export b/daemons/import-export
index 0700339..f2bf874 100755
--- a/daemons/import-export
+++ b/daemons/import-export
@@ -362,6 +362,10 @@ def ParseOptions():
help="X509 CA file")
parser.add_option("--bind", dest="bind", action="store", type="string",
help="Bind address")
+ parser.add_option("--ipv4", dest="ipv4", action="store_true",
+ help="Use IPv4 only")
+ parser.add_option("--ipv6", dest="ipv6", action="store_true",
+ help="Use IPv6 only")
parser.add_option("--host", dest="host", action="store", type="string",
help="Remote hostname")
parser.add_option("--port", dest="port", action="store", type="int",
@@ -401,7 +405,7 @@ def ParseOptions():
parser.error("Invalid mode: %s" % mode)
# Normalize and check parameters
- if options.host is not None:
+ if options.host is not None and not netutils.IPAddress.IsValid(options.host):
try:
options.host = netutils.Hostname.GetNormalizedName(options.host)
except errors.OpPrereqError, err:
@@ -423,6 +427,9 @@ def ParseOptions():
parser.error("Magic must match regular expression %s" %
constants.IE_MAGIC_RE.pattern)
+ if options.ipv4 and options.ipv6:
+ parser.error("Can only use one of --ipv4 and --ipv6")
+
return (status_file_path, mode)
diff --git a/lib/impexpd/__init__.py b/lib/impexpd/__init__.py
index 983640e..cf3a8ca 100644
--- a/lib/impexpd/__init__.py
+++ b/lib/impexpd/__init__.py
@@ -35,6 +35,7 @@ from cStringIO import StringIO
from ganeti import constants
from ganeti import errors
from ganeti import utils
+from ganeti import netutils
#: Used to recognize point at which socat(1) starts to listen on its socket.
@@ -144,6 +145,13 @@ class CommandBuilder(object):
if self._opts.bind is not None:
common_addr_opts.append("bind=%s" % self._opts.bind)
+ assert not (self._opts.ipv4 and self._opts.ipv6)
+
+ if self._opts.ipv4:
+ common_addr_opts.append("pf=ipv4")
+ elif self._opts.ipv6:
+ common_addr_opts.append("pf=ipv6")
+
if self._mode == constants.IEM_IMPORT:
if self._opts.port is None:
port = 0
@@ -162,9 +170,14 @@ class CommandBuilder(object):
addr2 = ["stdout"]
elif self._mode == constants.IEM_EXPORT:
+ if self._opts.host and netutils.IP6Address.IsValid(self._opts.host):
+ host = "[%s]" % self._opts.host
+ else:
+ host = self._opts.host
+
addr1 = ["stdin"]
addr2 = [
- "OPENSSL:%s:%s" % (self._opts.host, self._opts.port),
+ "OPENSSL:%s:%s" % (host, self._opts.port),
# How long to wait per connection attempt
"connect-timeout=%s" % self._opts.connect_timeout,
@@ -329,10 +342,13 @@ def _VerifyListening(family, address, port):
"""Verify address given as listening address by socat.
"""
- # TODO: Implement IPv6 support
- if family != socket.AF_INET:
+ if family not in (socket.AF_INET, socket.AF_INET6):
raise errors.GenericError("Address family %r not supported" % family)
+ if (family == socket.AF_INET6 and address.startswith("[") and
+ address.endswith("]")):
+ address = address[1:-1] # strip exactly one pair; lstrip/rstrip would also eat "[[::1]]"
+
try:
packed_address = socket.inet_pton(family, address)
except socket.error:
diff --git a/test/ganeti.impexpd_unittest.py b/test/ganeti.impexpd_unittest.py
index e9e3391..17c2489 100755
--- a/test/ganeti.impexpd_unittest.py
+++ b/test/ganeti.impexpd_unittest.py
@@ -25,6 +25,7 @@ import os
import sys
import re
import unittest
+import socket
from ganeti import constants
from ganeti import objects
@@ -44,6 +45,8 @@ class CmdBuilderConfig(objects.ConfigObject):
"ca",
"host",
"port",
+ "ipv4",
+ "ipv6",
"compress",
"magic",
"connect_timeout",
@@ -101,10 +104,10 @@ class TestCommandBuilder(unittest.TestCase):
self.assert_(CheckCmdWord(cmd, comprcmd))
if cmd_prefix is not None:
- self.assert_(cmd_prefix in i for i in cmd)
+ self.assert_(compat.any(cmd_prefix in i for i in cmd))
if cmd_suffix is not None:
- self.assert_(cmd_suffix in i for i in cmd)
+ self.assert_(compat.any(cmd_suffix in i for i in cmd))
# Check socat command
socat_cmd = builder._GetSocatCommand()
@@ -118,6 +121,34 @@ class TestCommandBuilder(unittest.TestCase):
self.assert_("verify=1" in ssl_addr)
+ def testIPv6(self): # socat addresses get pf=ipv4/pf=ipv6 only when requested
+ for mode in [constants.IEM_IMPORT, constants.IEM_EXPORT]: # same checks for import and export
+ opts = CmdBuilderConfig(host="localhost", port=6789,
+ ipv4=False, ipv6=False)
+ builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+ cmd = builder._GetSocatCommand()
+ self.assert_(compat.all("pf=" not in i for i in cmd)) # neither flag: no pf= option at all
+
+ # IPv4 only: some socat argument must carry ",pf=ipv4"
+ opts = CmdBuilderConfig(host="localhost", port=6789,
+ ipv4=True, ipv6=False)
+ builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+ cmd = builder._GetSocatCommand()
+ self.assert_(compat.any(",pf=ipv4" in i for i in cmd))
+
+ # IPv6 only: some socat argument must carry ",pf=ipv6"
+ opts = CmdBuilderConfig(host="localhost", port=6789,
+ ipv4=False, ipv6=True)
+ builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+ cmd = builder._GetSocatCommand()
+ self.assert_(compat.any(",pf=ipv6" in i for i in cmd))
+
+ # IPv4 and IPv6 together: CommandBuilder asserts the flags are exclusive
+ opts = CmdBuilderConfig(host="localhost", port=6789,
+ ipv4=True, ipv6=True)
+ builder = impexpd.CommandBuilder(mode, opts, 1, 2, 3)
+ self.assertRaises(AssertionError, builder._GetSocatCommand) # NOTE(review): relies on a bare assert, so fails under python -O
+
def testCommaError(self):
opts = CmdBuilderConfig(host="localhost", port=1234,
ca="/some/path/with,a/,comma")
@@ -155,6 +186,39 @@ class TestCommandBuilder(unittest.TestCase):
self.assertRaises(errors.GenericError, builder.GetCommand)
+class TestVerifyListening(unittest.TestCase): # accepted/rejected (family, address, port) inputs of _VerifyListening
+ def test(self): # valid inputs come back as an (address, port) tuple
+ self.assertEqual(impexpd._VerifyListening(socket.AF_INET,
+ "192.0.2.7", 1234),
+ ("192.0.2.7", 1234))
+ self.assertEqual(impexpd._VerifyListening(socket.AF_INET6, "::1", 9876),
+ ("::1", 9876))
+ self.assertEqual(impexpd._VerifyListening(socket.AF_INET6, "[::1]", 4563),
+ ("::1", 4563)) # surrounding brackets are removed
+ self.assertEqual(impexpd._VerifyListening(socket.AF_INET6,
+ "[2001:db8::1:4563]", 4563),
+ ("2001:db8::1:4563", 4563))
+
+ def testError(self): # every case below must raise errors.GenericError
+ for family in [socket.AF_UNIX, socket.AF_INET, socket.AF_INET6]: # empty/truncated address: invalid everywhere
+ self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+ family, "", 1234)
+ self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+ family, "192", 999)
+
+ for family in [socket.AF_UNIX, socket.AF_INET6]: # IPv4 literal or unbalanced brackets; AF_UNIX is unsupported
+ self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+ family, "192.0.2.7", 1234)
+ self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+ family, "[2001:db8::1", 1234)
+ self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+ family, "2001:db8::1]", 1234)
+
+ for family in [socket.AF_UNIX, socket.AF_INET]: # IPv6 literal outside AF_INET6
+ self.assertRaises(errors.GenericError, impexpd._VerifyListening,
+ family, "::1", 1234)
+
+
class TestCalcThroughput(unittest.TestCase):
def test(self):
self.assertEqual(impexpd._CalcThroughput([]), None)
--
1.7.3.1