assertAlmostEqual w testach jednostkowych Pythona dla kolekcji pływaków

81

Metoda assertAlmostEqual (x, y) w ramach testów jednostkowych Pythona sprawdza , czy xi ysą w przybliżeniu równe, zakładając, że są zmiennoprzecinkowe.

Problem assertAlmostEqual()polega na tym, że działa tylko na pływakach. Szukam metody takiej, assertAlmostEqual()która działa na listach pływaków, zestawach pływaków, słownikach pływaków, krotkach pływaków, listach krotek pływaków, zestawach list pływaków itp.

Na przykład, powiedzmy x = 0.1234567890, y = 0.1234567891. xi ysą prawie równe, ponieważ zgadzają się co do każdej cyfry z wyjątkiem ostatniej. Dlatego self.assertAlmostEqual(x, y)to Truedlatego, że assertAlmostEqual()działa na pływaki.

Szukam bardziej ogólnego, assertAlmostEquals()który ocenia również następujące wywołania True:

  • self.assertAlmostEqual_generic([x, x, x], [y, y, y]).
  • self.assertAlmostEqual_generic({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}).
  • self.assertAlmostEqual_generic([(x,x)], [(y,y)]).

Czy jest taka metoda, czy muszę ją wdrożyć samodzielnie?

Wyjaśnienia:

  • assertAlmostEquals()ma opcjonalny parametr o nazwie, placesa liczby są porównywane poprzez obliczenie różnicy zaokrąglonej do liczby dziesiętnej places. Domyślnie places=7więc self.assertAlmostEqual(0.5, 0.4)jest False, podczas gdy self.assertAlmostEqual(0.12345678, 0.12345679)jest True. Mój spekulant assertAlmostEqual_generic()powinien mieć taką samą funkcjonalność.

  • Dwie listy są uważane za prawie równe, jeśli mają prawie równe liczby w dokładnie tej samej kolejności. formalnie for i in range(n): self.assertAlmostEqual(list1[i], list2[i]).

  • Podobnie dwa zestawy są uważane za prawie równe, jeśli można je przekształcić w prawie równe listy (przez przypisanie kolejności do każdego zestawu).

  • Podobnie, dwa słowniki są uważane za prawie równe, jeśli zestaw kluczy każdego słownika jest prawie równy zestawowi kluczy drugiego słownika, a dla każdej takiej prawie równej pary kluczy odpowiada prawie równa wartość.

  • Ogólnie: uważam, że dwie kolekcje są prawie równe, jeśli są równe, z wyjątkiem kilku odpowiadających im elementów zmiennoprzecinkowych, które są prawie równe sobie. Innymi słowy, chciałbym naprawdę porównać obiekty, ale z małą (dostosowaną) precyzją podczas porównywania elementów pływających po drodze.

wąż
źródło
Jaki jest sens używania floatkluczy w słowniku? Ponieważ nie możesz być pewien, że uzyskasz dokładnie ten sam zmiennoprzecinkowy, nigdy nie znajdziesz swoich przedmiotów za pomocą wyszukiwania. A jeśli nie używasz wyszukiwania, dlaczego nie użyć po prostu listy krotek zamiast słownika? Ten sam argument dotyczy zbiorów.
maksymalnie
Tylko link do źródła o assertAlmostEqual.
djvg

Odpowiedzi:

71

jeśli nie masz nic przeciwko używaniu NumPy (który jest dostarczany z twoim Pythonem (x, y)), możesz spojrzeć na np.testingmoduł, który definiuje między innymi assert_almost_equalfunkcję.

Podpis jest np.testing.assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True)

>>> x = 1.000001
>>> y = 1.000002
>>> np.testing.assert_almost_equal(x, y)
AssertionError: 
Arrays are not almost equal to 7 decimals
ACTUAL: 1.000001
DESIRED: 1.000002
>>> np.testing.assert_almost_equal(x, y, 5)
>>> np.testing.assert_almost_equal([x, x, x], [y, y, y], 5)
>>> np.testing.assert_almost_equal((x, x, x), (y, y, y), 5)
Pierre GM
źródło
4
To blisko, ale numpy.testingprawie równe metody działają tylko na liczbach, tablicach, krotkach i listach. Nie działają na słownikach, zestawach i zbiorach zbiorów.
snakile
Rzeczywiście, ale to dopiero początek. Poza tym masz dostęp do kodu źródłowego, który możesz zmodyfikować, aby umożliwić porównywanie słowników, kolekcji i tak dalej. np.testing.assert_equalna przykład rozpoznaje słowniki jako argumenty (nawet jeśli porównanie jest dokonywane za pomocą elementu, ==który nie zadziała).
Pierre GM
Oczywiście nadal będziesz mieć problemy podczas porównywania zestawów, jak wspomniał @BrenBarn.
Pierre GM
Zauważ, że obecna dokumentacja assert_array_almost_equalzaleca stosowanie assert_allclose, assert_array_almost_equal_nulplub assert_array_max_ulpzamiast.
phunehehe
10

Począwszy od Pythona 3.5 możesz porównać użycie

math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)

Jak opisano w pep-0485 . Implementacja powinna być równoważna z

abs(a-b) <= max( rel_tol * max(abs(a), abs(b)), abs_tol )
Maximiliano Ramirez Mellado
źródło
7
W jaki sposób pomaga to porównać pojemniki z pływakami, o które pytano?
max
9

Oto, jak I zostały wdrożone rodzajowe is_almost_equal(first, second)funkcji :

Najpierw skopiuj obiekty, które chcesz porównać ( firsti second), ale nie rób dokładnej kopii: wytnij nieznaczące cyfry dziesiętne każdego pływaka, który napotkasz wewnątrz obiektu.

Teraz, gdy masz kopie firsti seconddla których zniknęły nieznaczące cyfry dziesiętne, po prostu porównaj firsti secondużyj ==operatora.

Załóżmy, że mamy cut_insignificant_digits_recursively(obj, places)funkcję, która powiela, objale pozostawia placesw oryginale tylko najbardziej znaczące cyfry dziesiętne każdej liczby zmiennoprzecinkowej obj. Oto działająca implementacja is_almost_equals(first, second, places):

from insignificant_digit_cutter import cut_insignificant_digits_recursively

def is_almost_equal(first, second, places):
    '''returns True if first and second equal. 
    returns true if first and second aren't equal but have exactly the same
    structure and values except for a bunch of floats which are just almost
    equal (floats are almost equal if they're equal when we consider only the
    [places] most significant digits of each).'''
    if first == second: return True
    cut_first = cut_insignificant_digits_recursively(first, places)
    cut_second = cut_insignificant_digits_recursively(second, places)
    return cut_first == cut_second

A oto działająca implementacja cut_insignificant_digits_recursively(obj, places):

def cut_insignificant_digits(number, places):
    '''cut the least significant decimal digits of a number, 
    leave only [places] decimal digits'''
    if  type(number) != float: return number
    number_as_str = str(number)
    end_of_number = number_as_str.find('.')+places+1
    if end_of_number > len(number_as_str): return number
    return float(number_as_str[:end_of_number])

def cut_insignificant_digits_lazy(iterable, places):
    for obj in iterable:
        yield cut_insignificant_digits_recursively(obj, places)

def cut_insignificant_digits_recursively(obj, places):
    '''return a copy of obj except that every float loses its least significant 
    decimal digits remaining only [places] decimal digits'''
    t = type(obj)
    if t == float: return cut_insignificant_digits(obj, places)
    if t in (list, tuple, set):
        return t(cut_insignificant_digits_lazy(obj, places))
    if t == dict:
        return {cut_insignificant_digits_recursively(key, places):
                cut_insignificant_digits_recursively(val, places)
                for key,val in obj.items()}
    return obj

Kod i jego testy jednostkowe są dostępne tutaj: https://github.com/snakile/approximate_comparator . Z zadowoleniem przyjmuję wszelkie ulepszenia i poprawki błędów.

wąż
źródło
Zamiast porównywać zmiennoprzecinkowe, porównujesz łańcuchy? OK ... Ale czy w takim razie nie byłoby łatwiej ustawić wspólny format? Podoba Ci się fmt="{{0:{0}f}}".format(decimals)i użyć tego fmtformatu do „stringify” swoich pływaków?
Pierre GM
1
Wygląda to ładnie, ale mała placeskropka : podaje liczbę miejsc po przecinku, a nie liczbę cyfr znaczących. Na przykład porównywanie 1024.123i 1023.999do 3 znaczących powinno zwracać równe, ale do 3 miejsc po przecinku tak nie jest.
Rodney Richardson,
1
@pir, licencja jest rzeczywiście niezdefiniowana. Zobacz odpowiedź Snalile w tym numerze, w którym mówi, że nie ma czasu na wybranie / dodanie licencji, ale przyznaje uprawnienia do używania / modyfikacji. Dzięki za udostępnienie tego, BTW.
Jérôme
1
@RodneyRichardson, tak, to są miejsca dziesiętne, jak w assertAlmostEqual : "Zauważ, że te metody zaokrągla wartości do podanej liczby miejsc dziesiętnych (tj. Jak funkcja round ()) i nie znaczących cyfr."
Jérôme
2
@ Jérôme, dzięki za komentarz. Właśnie dodałem licencję MIT.
snakile
5

Jeśli nie przeszkadza przy użyciu numpypakietu następnie numpy.testingma assert_array_almost_equalmetody.

Działa to w przypadku array_likeobiektów, więc jest dobre dla tablic, list i krotek elementów zmiennoprzecinkowych, ale nie działa w przypadku zestawów i słowników.

Dokumentacja jest tutaj .

DJCowley
źródło
4

Nie ma takiej metody, musiałbyś to zrobić sam.

W przypadku list i krotek definicja jest oczywista, ale pamiętaj, że inne wspomniane przypadki nie są oczywiste, więc nic dziwnego, że taka funkcja nie jest dostarczana. Na przykład jest {1.00001: 1.00002}prawie równa {1.00002: 1.00001}? Obsługa takich przypadków wymaga dokonania wyboru, czy bliskość zależy od kluczy czy wartości, czy też obu. W przypadku zestawów jest mało prawdopodobne, abyś znalazł sensowną definicję, ponieważ zestawy są nieuporządkowane, więc nie ma pojęcia „odpowiadających” elementów.

BrenBarn
źródło
BrenBarn: Dodałem wyjaśnienia do pytania. Odpowiedź na twoje pytanie jest taka, że {1.00001: 1.00002}prawie równa się {1.00002: 1.00001}wtedy i tylko wtedy, gdy 1,00001 prawie równa się 1,00002. Domyślnie nie są prawie równe (ponieważ domyślna dokładność to 7 miejsc po przecinku), ale dla wystarczająco małej wartości, placessą prawie równe.
snakile
1
@BrenBarn: IMO, używanie kluczy typu floatw dict powinno być odradzane (a może nawet zabronione) z oczywistych powodów. Przybliżona równość dyktowania powinna opierać się tylko na wartościach; Framework testowy nie musi martwić się nieprawidłowym użyciem floatkluczy dla. W przypadku zestawów można je sortować przed porównaniem, a posortowane listy można porównać.
maksymalnie
2

Być może będziesz musiał to zaimplementować samodzielnie, podczas gdy to prawda, że ​​listy i zestawy można iterować w ten sam sposób, słowniki to inna historia, iterujesz ich klucze, a nie wartości, a trzeci przykład wydaje mi się nieco niejednoznaczny, czy masz na myśli porównaj każdą wartość w zestawie lub każdą wartość z każdego zestawu.

oto prosty fragment kodu.

def almost_equal(value_1, value_2, accuracy = 10**-8):
    return abs(value_1 - value_2) < accuracy

x = [1,2,3,4]
y = [1,2,4,5]
assert all(almost_equal(*values) for values in zip(x, y))
Samy Vilar
źródło
Dzięki, rozwiązanie jest poprawne w przypadku list i krotek, ale nie w przypadku innych typów kolekcji (lub kolekcji zagnieżdżonych). Zobacz wyjaśnienia, które dodałem do pytania. Mam nadzieję, że mój zamiar jest teraz jasny. Dwa zbiory są prawie równe, gdyby były uważane za równe w świecie, w którym liczby nie są mierzone bardzo dokładnie.
snakile
1

Żadna z tych odpowiedzi nie działa dla mnie. Poniższy kod powinien działać dla kolekcji, klas, klas danych i nazwanych krotek języka Python. Mogłem o czymś zapomnieć, ale na razie działa to dla mnie.

import unittest
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any


def are_almost_equal(o1: Any, o2: Any, max_abs_ratio_diff: float, max_abs_diff: float) -> bool:
    """
    Compares two objects by recursively walking them trough. Equality is as usual except for floats.
    Floats are compared according to the two measures defined below.

    :param o1: The first object.
    :param o2: The second object.
    :param max_abs_ratio_diff: The maximum allowed absolute value of the difference.
    `abs(1 - (o1 / o2)` and vice-versa if o2 == 0.0. Ignored if < 0.
    :param max_abs_diff: The maximum allowed absolute difference `abs(o1 - o2)`. Ignored if < 0.
    :return: Whether the two objects are almost equal.
    """
    if type(o1) != type(o2):
        return False

    composite_type_passed = False

    if hasattr(o1, '__slots__'):
        if len(o1.__slots__) != len(o2.__slots__):
            return False
        if any(not are_almost_equal(getattr(o1, s1), getattr(o2, s2),
                                    max_abs_ratio_diff, max_abs_diff)
            for s1, s2 in zip(sorted(o1.__slots__), sorted(o2.__slots__))):
            return False
        else:
            composite_type_passed = True

    if hasattr(o1, '__dict__'):
        if len(o1.__dict__) != len(o2.__dict__):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2))
            in zip(sorted(o1.__dict__.items()), sorted(o2.__dict__.items()))
            if not k1.startswith('__')):  # avoid infinite loops
            return False
        else:
            composite_type_passed = True

    if isinstance(o1, dict):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2)) in zip(sorted(o1.items()), sorted(o2.items()))):
            return False

    elif any(issubclass(o1.__class__, c) for c in (list, tuple, set)):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for v1, v2 in zip(o1, o2)):
            return False

    elif isinstance(o1, float):
        if o1 == o2:
            return True
        else:
            if max_abs_ratio_diff > 0:  # if max_abs_ratio_diff < 0, max_abs_ratio_diff is ignored
                if o2 != 0:
                    if abs(1.0 - (o1 / o2)) > max_abs_ratio_diff:
                        return False
                else:  # if both == 0, we already returned True
                    if abs(1.0 - (o2 / o1)) > max_abs_ratio_diff:
                        return False
            if 0 < max_abs_diff < abs(o1 - o2):  # if max_abs_diff < 0, max_abs_diff is ignored
                return False
            return True

    else:
        if not composite_type_passed:
            return o1 == o2

    return True


class EqualityTest(unittest.TestCase):

    def test_floats(self) -> None:
        o1 = ('hi', 3, 3.4)
        o2 = ('hi', 3, 3.400001)
        self.assertTrue(are_almost_equal(o1, o2, 0.0001, 0.0001))
        self.assertFalse(are_almost_equal(o1, o2, 0.00000001, 0.00000001))

    def test_ratio_only(self):
        o1 = ['hey', 10000, 123.12]
        o2 = ['hey', 10000, 123.80]
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, -1))

    def test_diff_only(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 1234567890.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, 1))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.1))

    def test_both_ignored(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 0.80]
        o3 = ['hi', 10000, 0.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, -1))
        self.assertFalse(are_almost_equal(o1, o3, -1, -1))

    def test_different_lengths(self):
        o1 = ['hey', 1234567890.12, 10000]
        o2 = ['hey', 1234567890.80]
        self.assertFalse(are_almost_equal(o1, o2, 1, 1))

    def test_classes(self):
        class A:
            d = 12.3

            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        o1 = A(2.34, 'str', {1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = A(2.34, 'str', {1: 'hey', 345.231: [123, 'hi', 890.121]})
        self.assertTrue(are_almost_equal(o1, o2, 0.1, 0.1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, 0.0001))

        o2.hello = 'hello'
        self.assertFalse(are_almost_equal(o1, o2, -1, -1))

    def test_namedtuples(self):
        B = namedtuple('B', ['x', 'y'])
        o1 = B(3.3, 4.4)
        o2 = B(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.2, 0.2))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, 0.001))

    def test_classes_with_slots(self):
        class C(object):
            __slots__ = ['a', 'b']

            def __init__(self, a, b):
                self.a = a
                self.b = b

        o1 = C(3.3, 4.4)
        o2 = C(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.3, 0.3))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.01))

    def test_dataclasses(self):
        @dataclass
        class D:
            s: str
            i: int
            f: float

        @dataclass
        class E:
            f2: float
            f4: str
            d: D

        o1 = E(12.3, 'hi', D('hello', 34, 20.01))
        o2 = E(12.1, 'hi', D('hello', 34, 20.0))
        self.assertTrue(are_almost_equal(o1, o2, -1, 0.4))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.001))

        o3 = E(12.1, 'hi', D('ciao', 34, 20.0))
        self.assertFalse(are_almost_equal(o2, o3, -1, -1))

    def test_ordereddict(self):
        o1 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.0]})
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, -1))
redsk
źródło
0

Nadal bym używał, self.assertEqual()ponieważ pozostaje najbardziej pouczający, gdy gówno uderza w wentylator. Możesz to zrobić, zaokrąglając np.

self.assertEqual(round_tuple((13.949999999999999, 1.121212), 2), (13.95, 1.12))

gdzie round_tuplejest

def round_tuple(t: tuple, ndigits: int) -> tuple:
    return tuple(round(e, ndigits=ndigits) for e in t)

def round_list(l: list, ndigits: int) -> list:
    return [round(e, ndigits=ndigits) for e in l]

Zgodnie z dokumentacją Pythona (patrz https://stackoverflow.com/a/41407651/1031191 ) możesz uciec z zaokrągleniami, takimi jak 13.94999999, ponieważ 13.94999999 == 13.95jest True.

Barney Szabolcs
źródło
-1

Alternatywnym podejściem jest konwersja danych do porównywalnej formy, np. Poprzez zamianę każdego elementu zmiennoprzecinkowego na ciąg o ustalonej precyzji.

def comparable(data):
    """Converts `data` to a comparable structure by converting any floats to a string with fixed precision."""
    if isinstance(data, (int, str)):
        return data
    if isinstance(data, float):
        return '{:.4f}'.format(data)
    if isinstance(data, list):
        return [comparable(el) for el in data]
    if isinstance(data, tuple):
        return tuple([comparable(el) for el in data])
    if isinstance(data, dict):
        return {k: comparable(v) for k, v in data.items()}

Następnie możesz:

self.assertEquals(comparable(value1), comparable(value2))
Karl Rosaen
źródło