summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXavier Leroy <xavier.leroy@inria.fr>2003-11-07 07:59:10 +0000
committerXavier Leroy <xavier.leroy@inria.fr>2003-11-07 07:59:10 +0000
commit099f19538061e848ba060bbb792c2372e4f0b082 (patch)
tree7cb248d2ec8fcb1fdd7caff29ad4f6be11995021
parent5ed1c19bb26f9ea836279c530ffa4297f8a59e1f (diff)
downloadocaml-099f19538061e848ba060bbb792c2372e4f0b082.tar.gz
Primitive C pour l'elevation au carre
git-svn-id: http://caml.inria.fr/svn/ocaml/trunk@5900 f963ae5c-01c2-4b8c-9fe0-0dff7051ff02
-rw-r--r--otherlibs/num/bng.c39
-rw-r--r--otherlibs/num/bng.h7
-rw-r--r--otherlibs/num/nat.ml6
-rw-r--r--otherlibs/num/nat.mli2
-rw-r--r--otherlibs/num/nat_stubs.c14
-rw-r--r--otherlibs/num/test/test_bng.c14
6 files changed, 81 insertions, 1 deletions
diff --git a/otherlibs/num/bng.c b/otherlibs/num/bng.c
index 6033b6f8b1..c96af2cf28 100644
--- a/otherlibs/num/bng.c
+++ b/otherlibs/num/bng.c
@@ -271,6 +271,44 @@ static bngcarry bng_generic_mult_add
return carry;
}
+/* {a,alen} := 2 * {a,alen} + {b,blen}^2. Return carry out.
+ Require alen >= 2 * blen. */
+static bngcarry bng_generic_square_add
+ (bng a/*[alen]*/, bngsize alen,
+ bng b/*[blen]*/, bngsize blen)
+{
+ bngcarry carry1, carry2;
+ bngsize i, aofs;
+ bngdigit ph, pl, d;
+
+ /* Double products */
+ for (carry1 = 0, i = 1; i < blen; i++) {
+ aofs = 2 * i - 1;
+ carry1 += bng_mult_add_digit(a + aofs, alen - aofs,
+ b + i, blen - i, b[i - 1]);
+ }
+ /* Multiply by two */
+ carry1 = (carry1 << 1) | bng_shift_left(a, alen, 1);
+ /* Add square of digits */
+ carry2 = 0;
+ for (i = 0; i < blen; i++) {
+ d = b[i];
+ BngMult(ph, pl, d, d);
+ BngAdd2Carry(*a, carry2, *a, pl, carry2);
+ a++;
+ BngAdd2Carry(*a, carry2, *a, ph, carry2);
+ a++;
+ }
+ alen -= 2 * blen;
+ if (alen > 0 && carry2 != 0) {
+ do {
+ if (++(*a) != 0) { carry2 = 0; break; }
+ a++;
+ } while (--alen);
+ }
+ return carry1 + carry2;
+}
+
/* {a,len-1} := {b,len} / d. Return {b,len} modulo d.
Require MSD of b < d.
If BngDivNeedsNormalization is defined, require d normalized. */
@@ -378,6 +416,7 @@ struct bng_operations bng_ops = {
bng_generic_mult_add_digit,
bng_generic_mult_sub_digit,
bng_generic_mult_add,
+ bng_generic_square_add,
bng_generic_div_rem_norm_digit,
#ifdef BngDivNeedsNormalization
bng_generic_div_rem_digit,
diff --git a/otherlibs/num/bng.h b/otherlibs/num/bng.h
index f887e8c9fe..28c6b2d105 100644
--- a/otherlibs/num/bng.h
+++ b/otherlibs/num/bng.h
@@ -95,6 +95,13 @@ struct bng_operations {
bng c/*[clen]*/, bngsize clen);
#define bng_mult_add bng_ops.mult_add
+ /* {a,alen} := 2 * {a,alen} + {b,blen}^2. Return carry out.
+ Require alen >= 2 * blen. */
+ bngcarry (*square_add)
+ (bng a/*[alen]*/, bngsize alen,
+ bng b/*[blen]*/, bngsize blen);
+#define bng_square_add bng_ops.square_add
+
/* {a,len-1} := {b,len} / d. Return {b,len} modulo d.
Require d is normalized and MSD of b < d.
See div_rem_digit for a function that does not require d
diff --git a/otherlibs/num/nat.ml b/otherlibs/num/nat.ml
index 02f1b1b7ad..c8c0c2fc7c 100644
--- a/otherlibs/num/nat.ml
+++ b/otherlibs/num/nat.ml
@@ -35,6 +35,7 @@ external decr_nat: nat -> int -> int -> int -> int = "decr_nat"
external sub_nat: nat -> int -> int -> nat -> int -> int -> int -> int = "sub_nat" "sub_nat_native"
external mult_digit_nat: nat -> int -> int -> nat -> int -> int -> nat -> int -> int = "mult_digit_nat" "mult_digit_nat_native"
external mult_nat: nat -> int -> int -> nat -> int -> int -> nat -> int -> int -> int = "mult_nat" "mult_nat_native"
+external square_nat: nat -> int -> int -> nat -> int -> int -> int = "square_nat" "square_nat_native"
external shift_left_nat: nat -> int -> int -> nat -> int -> int -> unit = "shift_left_nat" "shift_left_nat_native"
external div_digit_nat: nat -> int -> nat -> int -> nat -> int -> int -> nat -> int -> unit = "div_digit_nat" "div_digit_nat_native"
external div_nat: nat -> int -> int -> nat -> int -> int -> unit = "div_nat" "div_nat_native"
@@ -101,6 +102,10 @@ and gt_nat nat1 off1 len1 nat2 off2 len2 =
compare_nat nat1 off1 (num_digits_nat nat1 off1 len1)
nat2 off2 (num_digits_nat nat2 off2 len2) > 0
+(* XL: now implemented in C for better performance.
+ The code below doesn't handle carries correctly.
+ Fortunately, the carry is never used. *)
+(***
let square_nat nat1 off1 len1 nat2 off2 len2 =
let c = ref 0
and trash = make_nat 1 in
@@ -130,6 +135,7 @@ let square_nat nat1 off1 len1 nat2 off2 len2 =
(off2 + i)
done;
!c
+***)
let gcd_int_nat i nat off len =
if i = 0 then 1 else
diff --git a/otherlibs/num/nat.mli b/otherlibs/num/nat.mli
index 4fadf39fc8..18cd812011 100644
--- a/otherlibs/num/nat.mli
+++ b/otherlibs/num/nat.mli
@@ -45,6 +45,7 @@ external decr_nat: nat -> int -> int -> int -> int = "decr_nat"
external sub_nat: nat -> int -> int -> nat -> int -> int -> int -> int = "sub_nat" "sub_nat_native"
external mult_digit_nat: nat -> int -> int -> nat -> int -> int -> nat -> int -> int = "mult_digit_nat" "mult_digit_nat_native"
external mult_nat: nat -> int -> int -> nat -> int -> int -> nat -> int -> int -> int = "mult_nat" "mult_nat_native"
+external square_nat: nat -> int -> int -> nat -> int -> int -> int = "square_nat" "square_nat_native"
external shift_left_nat: nat -> int -> int -> nat -> int -> int -> unit = "shift_left_nat" "shift_left_nat_native"
external div_digit_nat: nat -> int -> nat -> int -> nat -> int -> int -> nat -> int -> unit = "div_digit_nat" "div_digit_nat_native"
external div_nat: nat -> int -> int -> nat -> int -> int -> unit = "div_nat" "div_nat_native"
@@ -59,7 +60,6 @@ val gt_nat : nat -> int -> int -> nat -> int -> int -> bool
external land_digit_nat: nat -> int -> nat -> int -> unit = "land_digit_nat"
external lor_digit_nat: nat -> int -> nat -> int -> unit = "lor_digit_nat"
external lxor_digit_nat: nat -> int -> nat -> int -> unit = "lxor_digit_nat"
-val square_nat : nat -> int -> int -> nat -> int -> int -> int
val gcd_nat : nat -> int -> int -> nat -> int -> int -> int
val sqrt_nat : nat -> int -> int -> nat
val string_of_nat : nat -> string
diff --git a/otherlibs/num/nat_stubs.c b/otherlibs/num/nat_stubs.c
index bf3fa31326..a7fb7dcfe0 100644
--- a/otherlibs/num/nat_stubs.c
+++ b/otherlibs/num/nat_stubs.c
@@ -195,6 +195,20 @@ CAMLprim value mult_nat(value *argv, int argn)
argv[4], argv[5], argv[6], argv[7], argv[8]);
}
+value square_nat_native(value nat1, value ofs1, value len1,
+ value nat2, value ofs2, value len2)
+{
+ return
+ Val_long(bng_square_add(&Digit_val(nat1, Long_val(ofs1)), Long_val(len1),
+ &Digit_val(nat2, Long_val(ofs2)), Long_val(len2)));
+}
+
+CAMLprim value square_nat(value *argv, int argn)
+{
+ return square_nat_native(argv[0], argv[1], argv[2],
+ argv[3], argv[4], argv[5]);
+}
+
value shift_left_nat_native(value nat1, value ofs1, value len1,
value nat2, value ofs2, value nbits)
{
diff --git a/otherlibs/num/test/test_bng.c b/otherlibs/num/test/test_bng.c
index 4d92e43eaa..4fedcdfd56 100644
--- a/otherlibs/num/test/test_bng.c
+++ b/otherlibs/num/test/test_bng.c
@@ -344,6 +344,20 @@ int test_bng_ops(int i)
bng2string(g, 2*p), co);
return 1;
}
+ /* square_add */
+ randbng(f, 2*p);
+ bng_assign(g, f, 2*p);
+ co = bng_square_add(g, 2*p, b, q);
+ BnnAssign(h, f, 2*p);
+ cp = BnnAdd(h, 2*p, h, 2*p);
+ cp += BnnMultiply(h, 2*p, b, q, b, q);
+ if (co != cp || !bngsame(g, h, 2*p)) {
+ printf("Round %d, bng_square_add(%s, %ld, %s, %ld) -> %s, %d\n",
+ i, bng2string(f, 2*p), 2*p,
+ bng2string(b, q), q,
+ bng2string(g, 2*p), co);
+ return 1;
+ }
/* div_rem_digit */
if (a[p - 1] < dg) {
do_ = bng_div_rem_digit(c, a, p, dg);