Control collation behavior with a method table.
author: Jeff Davis <[email protected]>
Wed, 8 Jan 2025 22:26:33 +0000 (14:26 -0800)
committer: Jeff Davis <[email protected]>
Wed, 8 Jan 2025 22:26:46 +0000 (14:26 -0800)
Previously, behavior branched based on the provider. A method table is
less error-prone and more flexible.

The ctype behavior will be addressed in an upcoming commit.

Reviewed-by: Andreas Karlsson
Discussion: https://postgr.es/m/2830211e1b6e6a2e26d845780b03e125281ea17b.camel%40j-davis.com

src/backend/utils/adt/pg_locale.c
src/backend/utils/adt/pg_locale_icu.c
src/backend/utils/adt/pg_locale_libc.c
src/include/utils/pg_locale.h

index dc8248fb26967039f3667e017edb629d65054cf9..875cca6efc858cf68ea7413bd666ab7839cbfd27 100644 (file)
@@ -92,27 +92,12 @@ extern char *get_collation_actual_version_builtin(const char *collcollate);
 /* pg_locale_icu.c */
 #ifdef USE_ICU
 extern UCollator *pg_ucol_open(const char *loc_str);
-extern int     strncoll_icu(const char *arg1, ssize_t len1,
-                                                const char *arg2, ssize_t len2,
-                                                pg_locale_t locale);
-extern size_t strnxfrm_icu(char *dest, size_t destsize,
-                                                  const char *src, ssize_t srclen,
-                                                  pg_locale_t locale);
-extern size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
-                                                                 const char *src, ssize_t srclen,
-                                                                 pg_locale_t locale);
 extern char *get_collation_actual_version_icu(const char *collcollate);
 #endif
 extern pg_locale_t create_pg_locale_icu(Oid collid, MemoryContext context);
 
 /* pg_locale_libc.c */
 extern pg_locale_t create_pg_locale_libc(Oid collid, MemoryContext context);
-extern int     strncoll_libc(const char *arg1, ssize_t len1,
-                                                 const char *arg2, ssize_t len2,
-                                                 pg_locale_t locale);
-extern size_t strnxfrm_libc(char *dest, size_t destsize,
-                                                       const char *src, ssize_t srclen,
-                                                       pg_locale_t locale);
 extern char *get_collation_actual_version_libc(const char *collcollate);
 
 extern size_t strlower_builtin(char *dst, size_t dstsize, const char *src,
@@ -1244,6 +1229,9 @@ create_pg_locale(Oid collid, MemoryContext context)
 
        result->is_default = false;
 
+       Assert((result->collate_is_c && result->collate == NULL) ||
+                  (!result->collate_is_c && result->collate != NULL));
+
        datum = SysCacheGetAttr(COLLOID, tp, Anum_pg_collation_collversion,
                                                        &isnull);
        if (!isnull)
@@ -1467,19 +1455,7 @@ pg_strupper(char *dst, size_t dstsize, const char *src, ssize_t srclen,
 int
 pg_strcoll(const char *arg1, const char *arg2, pg_locale_t locale)
 {
-       int                     result;
-
-       if (locale->provider == COLLPROVIDER_LIBC)
-               result = strncoll_libc(arg1, -1, arg2, -1, locale);
-#ifdef USE_ICU
-       else if (locale->provider == COLLPROVIDER_ICU)
-               result = strncoll_icu(arg1, -1, arg2, -1, locale);
-#endif
-       else
-               /* shouldn't happen */
-               PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-       return result;
+       return locale->collate->strncoll(arg1, -1, arg2, -1, locale);
 }
 
 /*
@@ -1500,51 +1476,25 @@ int
 pg_strncoll(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
                        pg_locale_t locale)
 {
-       int                     result;
-
-       if (locale->provider == COLLPROVIDER_LIBC)
-               result = strncoll_libc(arg1, len1, arg2, len2, locale);
-#ifdef USE_ICU
-       else if (locale->provider == COLLPROVIDER_ICU)
-               result = strncoll_icu(arg1, len1, arg2, len2, locale);
-#endif
-       else
-               /* shouldn't happen */
-               PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-       return result;
+       return locale->collate->strncoll(arg1, len1, arg2, len2, locale);
 }
 
 /*
  * Return true if the collation provider supports pg_strxfrm() and
  * pg_strnxfrm(); otherwise false.
  *
- * Unfortunately, it seems that strxfrm() for non-C collations is broken on
- * many common platforms; testing of multiple versions of glibc reveals that,
- * for many locales, strcoll() and strxfrm() do not return consistent
- * results. While no other libc other than Cygwin has so far been shown to
- * have a problem, we take the conservative course of action for right now and
- * disable this categorically.  (Users who are certain this isn't a problem on
- * their system can define TRUST_STRXFRM.)
  *
  * No similar problem is known for the ICU provider.
  */
 bool
 pg_strxfrm_enabled(pg_locale_t locale)
 {
-       if (locale->provider == COLLPROVIDER_LIBC)
-#ifdef TRUST_STRXFRM
-               return true;
-#else
-               return false;
-#endif
-       else if (locale->provider == COLLPROVIDER_ICU)
-               return true;
-       else
-               /* shouldn't happen */
-               PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-       return false;                           /* keep compiler quiet */
+       /*
+        * locale->collate->strnxfrm is still a required method, even if it may
+        * have the wrong behavior, because the planner uses it for estimates in
+        * some cases.
+        */
+       return locale->collate->strxfrm_is_safe;
 }
 
 /*
@@ -1555,19 +1505,7 @@ pg_strxfrm_enabled(pg_locale_t locale)
 size_t
 pg_strxfrm(char *dest, const char *src, size_t destsize, pg_locale_t locale)
 {
-       size_t          result = 0;             /* keep compiler quiet */
-
-       if (locale->provider == COLLPROVIDER_LIBC)
-               result = strnxfrm_libc(dest, destsize, src, -1, locale);
-#ifdef USE_ICU
-       else if (locale->provider == COLLPROVIDER_ICU)
-               result = strnxfrm_icu(dest, destsize, src, -1, locale);
-#endif
-       else
-               /* shouldn't happen */
-               PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-       return result;
+       return locale->collate->strnxfrm(dest, destsize, src, -1, locale);
 }
 
 /*
@@ -1593,19 +1531,7 @@ size_t
 pg_strnxfrm(char *dest, size_t destsize, const char *src, ssize_t srclen,
                        pg_locale_t locale)
 {
-       size_t          result = 0;             /* keep compiler quiet */
-
-       if (locale->provider == COLLPROVIDER_LIBC)
-               result = strnxfrm_libc(dest, destsize, src, srclen, locale);
-#ifdef USE_ICU
-       else if (locale->provider == COLLPROVIDER_ICU)
-               result = strnxfrm_icu(dest, destsize, src, srclen, locale);
-#endif
-       else
-               /* shouldn't happen */
-               PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-       return result;
+       return locale->collate->strnxfrm(dest, destsize, src, srclen, locale);
 }
 
 /*
@@ -1615,15 +1541,7 @@ pg_strnxfrm(char *dest, size_t destsize, const char *src, ssize_t srclen,
 bool
 pg_strxfrm_prefix_enabled(pg_locale_t locale)
 {
-       if (locale->provider == COLLPROVIDER_LIBC)
-               return false;
-       else if (locale->provider == COLLPROVIDER_ICU)
-               return true;
-       else
-               /* shouldn't happen */
-               PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-       return false;                           /* keep compiler quiet */
+       return (locale->collate->strnxfrm_prefix != NULL);
 }
 
 /*
@@ -1635,7 +1553,7 @@ size_t
 pg_strxfrm_prefix(char *dest, const char *src, size_t destsize,
                                  pg_locale_t locale)
 {
-       return pg_strnxfrm_prefix(dest, destsize, src, -1, locale);
+       return locale->collate->strnxfrm_prefix(dest, destsize, src, -1, locale);
 }
 
 /*
@@ -1660,16 +1578,7 @@ size_t
 pg_strnxfrm_prefix(char *dest, size_t destsize, const char *src,
                                   ssize_t srclen, pg_locale_t locale)
 {
-       size_t          result = 0;             /* keep compiler quiet */
-
-#ifdef USE_ICU
-       if (locale->provider == COLLPROVIDER_ICU)
-               result = strnxfrm_prefix_icu(dest, destsize, src, -1, locale);
-       else
-#endif
-               PGLOCALE_SUPPORT_ERROR(locale->provider);
-
-       return result;
+       return locale->collate->strnxfrm_prefix(dest, destsize, src, srclen, locale);
 }
 
 /*
index 6e1fb78bbf36f314eb106e66374e912d22b01b69..5185b0f728911fd3ddf4d452360028677301957d 100644 (file)
@@ -58,13 +58,14 @@ extern size_t strupper_icu(char *dst, size_t dstsize, const char *src,
 #ifdef USE_ICU
 
 extern UCollator *pg_ucol_open(const char *loc_str);
-extern int     strncoll_icu(const char *arg1, ssize_t len1,
+
+static int     strncoll_icu(const char *arg1, ssize_t len1,
                                                 const char *arg2, ssize_t len2,
                                                 pg_locale_t locale);
-extern size_t strnxfrm_icu(char *dest, size_t destsize,
+static size_t strnxfrm_icu(char *dest, size_t destsize,
                                                   const char *src, ssize_t srclen,
                                                   pg_locale_t locale);
-extern size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
+static size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
                                                                  const char *src, ssize_t srclen,
                                                                  pg_locale_t locale);
 extern char *get_collation_actual_version_icu(const char *collcollate);
@@ -83,12 +84,20 @@ static UConverter *icu_converter = NULL;
 
 static UCollator *make_icu_collator(const char *iculocstr,
                                                                        const char *icurules);
-static int     strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
-                                                                const char *arg2, ssize_t len2,
-                                                                pg_locale_t locale);
-static size_t strnxfrm_prefix_icu_no_utf8(char *dest, size_t destsize,
-                                                                                 const char *src, ssize_t srclen,
-                                                                                 pg_locale_t locale);
+static int     strncoll_icu(const char *arg1, ssize_t len1,
+                                                const char *arg2, ssize_t len2,
+                                                pg_locale_t locale);
+static size_t strnxfrm_prefix_icu(char *dest, size_t destsize,
+                                                                 const char *src, ssize_t srclen,
+                                                                 pg_locale_t locale);
+#ifdef HAVE_UCOL_STRCOLLUTF8
+static int     strncoll_icu_utf8(const char *arg1, ssize_t len1,
+                                                         const char *arg2, ssize_t len2,
+                                                         pg_locale_t locale);
+#endif
+static size_t strnxfrm_prefix_icu_utf8(char *dest, size_t destsize,
+                                                                          const char *src, ssize_t srclen,
+                                                                          pg_locale_t locale);
 static void init_icu_converter(void);
 static size_t uchar_length(UConverter *converter,
                                                   const char *str, int32_t len);
@@ -108,6 +117,25 @@ static int32_t u_strToTitle_default_BI(UChar *dest, int32_t destCapacity,
                                                                           const UChar *src, int32_t srcLength,
                                                                           const char *locale,
                                                                           UErrorCode *pErrorCode);
+
+static const struct collate_methods collate_methods_icu = {
+       .strncoll = strncoll_icu,
+       .strnxfrm = strnxfrm_icu,
+       .strnxfrm_prefix = strnxfrm_prefix_icu,
+       .strxfrm_is_safe = true,
+};
+
+static const struct collate_methods collate_methods_icu_utf8 = {
+#ifdef HAVE_UCOL_STRCOLLUTF8
+       .strncoll = strncoll_icu_utf8,
+#else
+       .strncoll = strncoll_icu,
+#endif
+       .strnxfrm = strnxfrm_icu,
+       .strnxfrm_prefix = strnxfrm_prefix_icu_utf8,
+       .strxfrm_is_safe = true,
+};
+
 #endif
 
 pg_locale_t
@@ -174,6 +202,10 @@ create_pg_locale_icu(Oid collid, MemoryContext context)
        result->deterministic = deterministic;
        result->collate_is_c = false;
        result->ctype_is_c = false;
+       if (GetDatabaseEncoding() == PG_UTF8)
+               result->collate = &collate_methods_icu_utf8;
+       else
+               result->collate = &collate_methods_icu;
 
        return result;
 #else
@@ -408,42 +440,36 @@ strupper_icu(char *dest, size_t destsize, const char *src, ssize_t srclen,
 }
 
 /*
- * strncoll_icu
+ * strncoll_icu_utf8
  *
  * Call ucol_strcollUTF8() or ucol_strcoll() as appropriate for the given
  * database encoding. An argument length of -1 means the string is
  * NUL-terminated.
  */
+#ifdef HAVE_UCOL_STRCOLLUTF8
 int
-strncoll_icu(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
-                        pg_locale_t locale)
+strncoll_icu_utf8(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
+                                 pg_locale_t locale)
 {
        int                     result;
+       UErrorCode      status;
 
        Assert(locale->provider == COLLPROVIDER_ICU);
 
-#ifdef HAVE_UCOL_STRCOLLUTF8
-       if (GetDatabaseEncoding() == PG_UTF8)
-       {
-               UErrorCode      status;
+       Assert(GetDatabaseEncoding() == PG_UTF8);
 
-               status = U_ZERO_ERROR;
-               result = ucol_strcollUTF8(locale->info.icu.ucol,
-                                                                 arg1, len1,
-                                                                 arg2, len2,
-                                                                 &status);
-               if (U_FAILURE(status))
-                       ereport(ERROR,
-                                       (errmsg("collation failed: %s", u_errorName(status))));
-       }
-       else
-#endif
-       {
-               result = strncoll_icu_no_utf8(arg1, len1, arg2, len2, locale);
-       }
+       status = U_ZERO_ERROR;
+       result = ucol_strcollUTF8(locale->info.icu.ucol,
+                                                         arg1, len1,
+                                                         arg2, len2,
+                                                         &status);
+       if (U_FAILURE(status))
+               ereport(ERROR,
+                               (errmsg("collation failed: %s", u_errorName(status))));
 
        return result;
 }
+#endif
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 size_t
@@ -494,37 +520,32 @@ strnxfrm_icu(char *dest, size_t destsize, const char *src, ssize_t srclen,
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 size_t
-strnxfrm_prefix_icu(char *dest, size_t destsize,
-                                       const char *src, ssize_t srclen,
-                                       pg_locale_t locale)
+strnxfrm_prefix_icu_utf8(char *dest, size_t destsize,
+                                                const char *src, ssize_t srclen,
+                                                pg_locale_t locale)
 {
        size_t          result;
+       UCharIterator iter;
+       uint32_t        state[2];
+       UErrorCode      status;
 
        Assert(locale->provider == COLLPROVIDER_ICU);
 
-       if (GetDatabaseEncoding() == PG_UTF8)
-       {
-               UCharIterator iter;
-               uint32_t        state[2];
-               UErrorCode      status;
+       Assert(GetDatabaseEncoding() == PG_UTF8);
 
-               uiter_setUTF8(&iter, src, srclen);
-               state[0] = state[1] = 0;        /* won't need that again */
-               status = U_ZERO_ERROR;
-               result = ucol_nextSortKeyPart(locale->info.icu.ucol,
-                                                                         &iter,
-                                                                         state,
-                                                                         (uint8_t *) dest,
-                                                                         destsize,
-                                                                         &status);
-               if (U_FAILURE(status))
-                       ereport(ERROR,
-                                       (errmsg("sort key generation failed: %s",
-                                                       u_errorName(status))));
-       }
-       else
-               result = strnxfrm_prefix_icu_no_utf8(dest, destsize, src, srclen,
-                                                                                        locale);
+       uiter_setUTF8(&iter, src, srclen);
+       state[0] = state[1] = 0;        /* won't need that again */
+       status = U_ZERO_ERROR;
+       result = ucol_nextSortKeyPart(locale->info.icu.ucol,
+                                                                 &iter,
+                                                                 state,
+                                                                 (uint8_t *) dest,
+                                                                 destsize,
+                                                                 &status);
+       if (U_FAILURE(status))
+               ereport(ERROR,
+                               (errmsg("sort key generation failed: %s",
+                                               u_errorName(status))));
 
        return result;
 }
@@ -653,7 +674,7 @@ u_strToTitle_default_BI(UChar *dest, int32_t destCapacity,
 }
 
 /*
- * strncoll_icu_no_utf8
+ * strncoll_icu
  *
  * Convert the arguments from the database encoding to UChar strings, then
  * call ucol_strcoll(). An argument length of -1 means that the string is
@@ -663,8 +684,8 @@ u_strToTitle_default_BI(UChar *dest, int32_t destCapacity,
  * caller should call that instead.
  */
 static int
-strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
-                                        const char *arg2, ssize_t len2, pg_locale_t locale)
+strncoll_icu(const char *arg1, ssize_t len1,
+                        const char *arg2, ssize_t len2, pg_locale_t locale)
 {
        char            sbuf[TEXTBUFLEN];
        char       *buf = sbuf;
@@ -677,6 +698,8 @@ strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
        int                     result;
 
        Assert(locale->provider == COLLPROVIDER_ICU);
+
+       /* if encoding is UTF8, use more efficient strncoll_icu_utf8 */
 #ifdef HAVE_UCOL_STRCOLLUTF8
        Assert(GetDatabaseEncoding() != PG_UTF8);
 #endif
@@ -710,9 +733,9 @@ strncoll_icu_no_utf8(const char *arg1, ssize_t len1,
 
 /* 'srclen' of -1 means the strings are NUL-terminated */
 static size_t
-strnxfrm_prefix_icu_no_utf8(char *dest, size_t destsize,
-                                                       const char *src, ssize_t srclen,
-                                                       pg_locale_t locale)
+strnxfrm_prefix_icu(char *dest, size_t destsize,
+                                       const char *src, ssize_t srclen,
+                                       pg_locale_t locale)
 {
        char            sbuf[TEXTBUFLEN];
        char       *buf = sbuf;
@@ -725,6 +748,8 @@ strnxfrm_prefix_icu_no_utf8(char *dest, size_t destsize,
        Size            result_bsize;
 
        Assert(locale->provider == COLLPROVIDER_ICU);
+
+       /* if encoding is UTF8, use more efficient strnxfrm_prefix_icu_utf8 */
        Assert(GetDatabaseEncoding() != PG_UTF8);
 
        init_icu_converter();
index 81120061b503efbea3e0b989013c9889fdaa8afb..8f9a86378971b4fa32a5d638b4a2aae172d0365c 100644 (file)
@@ -50,10 +50,10 @@ extern size_t strtitle_libc(char *dst, size_t dstsize, const char *src,
 extern size_t strupper_libc(char *dst, size_t dstsize, const char *src,
                                                        ssize_t srclen, pg_locale_t locale);
 
-extern int     strncoll_libc(const char *arg1, ssize_t len1,
+static int     strncoll_libc(const char *arg1, ssize_t len1,
                                                  const char *arg2, ssize_t len2,
                                                  pg_locale_t locale);
-extern size_t strnxfrm_libc(char *dest, size_t destsize,
+static size_t strnxfrm_libc(char *dest, size_t destsize,
                                                        const char *src, ssize_t srclen,
                                                        pg_locale_t locale);
 extern char *get_collation_actual_version_libc(const char *collcollate);
@@ -86,6 +86,40 @@ static size_t strupper_libc_mb(char *dest, size_t destsize,
                                                           const char *src, ssize_t srclen,
                                                           pg_locale_t locale);
 
+static const struct collate_methods collate_methods_libc = {
+       .strncoll = strncoll_libc,
+       .strnxfrm = strnxfrm_libc,
+       .strnxfrm_prefix = NULL,
+
+       /*
+        * Unfortunately, it seems that strxfrm() for non-C collations is broken
+        * on many common platforms; testing of multiple versions of glibc reveals
+        * that, for many locales, strcoll() and strxfrm() do not return
+        * consistent results. While no other libc other than Cygwin has so far
+        * been shown to have a problem, we take the conservative course of action
+        * for right now and disable this categorically.  (Users who are certain
+        * this isn't a problem on their system can define TRUST_STRXFRM.)
+        */
+#ifdef TRUST_STRXFRM
+       .strxfrm_is_safe = true,
+#else
+       .strxfrm_is_safe = false,
+#endif
+};
+
+#ifdef WIN32
+static const struct collate_methods collate_methods_libc_win32_utf8 = {
+       .strncoll = strncoll_libc_win32_utf8,
+       .strnxfrm = strnxfrm_libc,
+       .strnxfrm_prefix = NULL,
+#ifdef TRUST_STRXFRM
+       .strxfrm_is_safe = true,
+#else
+       .strxfrm_is_safe = false,
+#endif
+};
+#endif
+
 size_t
 strlower_libc(char *dst, size_t dstsize, const char *src,
                          ssize_t srclen, pg_locale_t locale)
@@ -439,6 +473,15 @@ create_pg_locale_libc(Oid collid, MemoryContext context)
        result->ctype_is_c = (strcmp(ctype, "C") == 0) ||
                (strcmp(ctype, "POSIX") == 0);
        result->info.lt = loc;
+       if (!result->collate_is_c)
+       {
+#ifdef WIN32
+               if (GetDatabaseEncoding() == PG_UTF8)
+                       result->collate = &collate_methods_libc_win32_utf8;
+               else
+#endif
+                       result->collate = &collate_methods_libc;
+       }
 
        return result;
 }
@@ -536,12 +579,6 @@ strncoll_libc(const char *arg1, ssize_t len1, const char *arg2, ssize_t len2,
 
        Assert(locale->provider == COLLPROVIDER_LIBC);
 
-#ifdef WIN32
-       /* check for this case before doing the work for nul-termination */
-       if (GetDatabaseEncoding() == PG_UTF8)
-               return strncoll_libc_win32_utf8(arg1, len1, arg2, len2, locale);
-#endif                                                 /* WIN32 */
-
        if (bufsize1 + bufsize2 > TEXTBUFLEN)
                buf = palloc(bufsize1 + bufsize2);
 
index 7a70ec1173a32734d4696dc6362562d796bc68c8..97b866b344470f2738811121aa5d905caec661e4 100644 (file)
@@ -47,6 +47,36 @@ extern struct lconv *PGLC_localeconv(void);
 extern void cache_locale_time(void);
 
 
+struct pg_locale_struct;
+typedef struct pg_locale_struct *pg_locale_t;
+
+/* methods that define collation behavior */
+struct collate_methods
+{
+       /* required */
+       int                     (*strncoll) (const char *arg1, ssize_t len1,
+                                                        const char *arg2, ssize_t len2,
+                                                        pg_locale_t locale);
+
+       /* required */
+       size_t          (*strnxfrm) (char *dest, size_t destsize,
+                                                        const char *src, ssize_t srclen,
+                                                        pg_locale_t locale);
+
+       /* optional */
+       size_t          (*strnxfrm_prefix) (char *dest, size_t destsize,
+                                                                       const char *src, ssize_t srclen,
+                                                                       pg_locale_t locale);
+
+       /*
+        * If the strnxfrm method is not trusted to return the correct results,
+        * set strxfrm_is_safe to false. If set to false, the method will not be
+        * used in most cases, but the planner still expects it to be there for
+        * estimation purposes (where incorrect results are acceptable).
+        */
+       bool            strxfrm_is_safe;
+};
+
 /*
  * We use a discriminated union to hold either a locale_t or an ICU collator.
  * pg_locale_t is occasionally checked for truth, so make it a pointer.
@@ -70,6 +100,9 @@ struct pg_locale_struct
        bool            collate_is_c;
        bool            ctype_is_c;
        bool            is_default;
+
+       const struct collate_methods *collate;  /* NULL if collate_is_c */
+
        union
        {
                struct