feat: random walks and embeddings #752
Merged
SemyonSinchenko merged 87 commits into graphframes:main from SemyonSinchenko:726-sampling-api on Mar 10, 2026.

Changes from all commits (87 commits):
All commits authored by SemyonSinchenko:

3dcd569 edges sampling API (scala)
06f6af6 add seed to z-estimation
e4ede62 wip
08ee0eb Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
d050a7a WIP
4a0a14f scalfix
579a5c3 docstrings to RandomWalkBase and RandomWalkWithRestart
16dd46c Fix RandomWalk implementation bugs and add example
965dea8 Add Word2VecHashingTrick implementation for graph embeddings
3260cc9 Implement reservoir sampling for neighbor selection in random walks
aafcbc6 fix scalastyle?
f27462d fix reservoir
28b8ca5 add hash2vec
b8255ac fixes
2113376 docstrings + scalfix
6244ec2 remove sampling as not needed
a9ce665 fixes in build and code
0d5da92 Big update
84306b6 Tests and updates
5084d88 workaround scala 2.13 deprecation of Searching.search
0a9dc90 Fixes
f7a9aeb Fix access
acc5a2c fallback to java Serialization
8b20f00 Fix some bugs
538a72d Sampling Convolution tests and docstrings
3548bb1 docstrings for RW Emebddings
3df9282 Python API
af97cac hash2vec tests
5d724b5 Explicit types
526330c hash2vec and random walks with restart tests
061fab3 Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
7e77bc9 performance
02c22f8 ignore unused nowarn
97d529a performance + cached walks support + continous mode
cc7f908 Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
48e981a fix rw and update the branch
de63a45 fix
e1cf2a4 protobuf & connect
d2780a4 public API for embeddings and small refactoring of methods
a294c03 initial Py API for embeddings
60ca7e5 fix
6583c20 tests + docs
54ee3d7 fix connect tests
867ae47 Merge remote-tracking branch 'graphframes/main' into 726-sampling-api
c7ef2d9 decrease GC pressure
ff107a2 further optimizations
7f01b9c refactor: optimize Hash2Vec string hashing and partitioning logic
b7d95f8 refactor: replace generic hash function with type-specific implementa…
5def505 refactor: simplify hash function logic and improve performance in Has…
5f27b92 refactor: inline generic processPartitionGeneric into specialized Str…
ed07856 test: add tests for PagedMatrixDouble helper covering page extension,…
5cc9f1a refactor: replace unsafe hash functions with MurmurHash3 and optimize…
895a319 refactor: update processStringPartition to use PagedMatrixDouble for …
66809bd docs: add internal documentation for PagedMatrixDouble explaining mem…
548a17a chore: remove commented helper section from Hash2Vec
70d96cb refactor: replace case class with class for PagedMatrixDouble and rem…
842cb7e fix: correct LongMap type parameter in Hash2Vec vocabIndex initializa…
c2c6ff4 refactor: reduce PAGE_BITS from 16 to 12 and update related constants…
54c32b3 feat: add max vectors per partition limit and batched processing for …
e689b86 refactor: process long partitions in batches respecting max vectors l…
6d92b2c test: add Hash2Vec tests for co-occurrence patterns and cosine simila…
ae60afd chore: clean up code formatting and improve test readability in Hash2Vec
c0f3ca3 fix: correct typo in error message from 'gor' to 'got' in Hash2Vec ex…
b0dc568 docs: add docstrings for Hash2Vec setters setDoNormalization and setM…
95d73c3 fix: correct typo in NOTICE and KMinSampling, update Hash2Vec default…
de2d6a3 fix: skip seeds for previous batches to maintain consistency when sta…
3a0a302 fix: add overwrite mode when writing batch results to allow re-runnin…
af7f8f4 feat: add cleanUp method to remove temporary files for a walk ID usin…
579a2b5 docs: improve documentation for cleanUp method in RandomWalkBase trait
5e2a462 test: add cleanUp call in RandomWalkWithRestart test
d3cb1a2 test: move walks execution inside try block in RandomWalkWithRestartS…
f5bb5ca refactor: set walkID default to UUID and remove redundant runID variable
4b25649 docs: update comment to clarify walkID retrieval method behavior
0c9262d refactor: rename walkID to runID for clarity in random walk operations
c3b8d86 refactor: remove runId parameter from cleanUp method in RandomWalkBase
3309269 refactor: move cleanUp method to companion object with parameters and…
a9ff184 refactor: improve documentation formatting and remove redundant log s…
9eb4dea test: verify temporary files are deleted after RandomWalkWithRestart …
5b349c8 refactor: use numBatches variable and fix run path in RandomWalkWithR…
5c12113 test: add test for RandomWalkWithRestart resuming from middle iteration
483c3be style: Format RandomWalkWithRestartSuite with consistent spacing and …
5efc5dd refactor: use getSeq instead of getAs[Seq[String]] in RandomWalkWithR…
001a2bf fix: correct typo in error message and parameter name for gaussian sigma
64aee65 feat: add clean-up option for temporary random walk files in embeddin…
e45381d docs: improve formatting and consistency in graph-ml documentation ta…
d4f9f50 feat: add clean_up_after_run field to RandomWalkEmbeddings proto defi…
79be29d chore: add clean_up_after_run parameter to _RandomWalksEmbeddingsPara…
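Per the commit messages above, the PR's core primitive is a random walk with restart over the graph's adjacency structure. As a rough, language-agnostic illustration of the idea (a toy Python model on an in-memory adjacency list, not the GraphFrames implementation, which runs distributed over DataFrames):

```python
import random

def random_walk_with_restart(adj, start, walk_length, restart_prob, rng):
    """One walk: at each step, jump back to `start` with probability
    `restart_prob` (or when stuck at a sink), otherwise hop to a
    uniformly sampled out-neighbor."""
    walk = [start]
    current = start
    for _ in range(walk_length - 1):
        if rng.random() < restart_prob or not adj[current]:
            current = start  # restart (also used as the sink fallback)
        else:
            current = rng.choice(adj[current])
        walk.append(current)
    return walk

# Toy directed graph as an adjacency list.
adj = {0: [1, 2], 1: [0, 2], 2: [0]}
rng = random.Random(42)
walk = random_walk_with_restart(adj, start=0, walk_length=5,
                                restart_prob=0.2, rng=rng)
print(walk)  # a length-5 walk beginning at vertex 0
```

Collections of such walks are then fed to an embedding stage (the PR's Hash2Vec, a hashing-trick variant of word2vec-style training) as if each walk were a sentence of vertex "tokens".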
@@ -79,3 +79,8 @@ spark-*
+
+# Zed
+.zed
+
+# Emacs
+.dir-locals.el
+*~
+.aider*
core/src/main/scala/org/apache/spark/sql/graphframes/expressions/KMinSampling.scala
165 changes: 165 additions & 0 deletions
package org.apache.spark.sql.graphframes.expressions

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.types.*
import org.apache.spark.sql.types.DataType
import org.graphframes.GraphFramesUnsupportedVertexTypeException

import scala.annotation.nowarn
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

case class KMinAccum[T](values: Array[T], weights: Array[Long], var cnt: Int) extends Serializable

case class KMinSampling[T: ClassTag](size: Int)(implicit
    @nowarn tag: TypeTag[T],
    ord: Ordering[T])
    extends Aggregator[Row, KMinAccum[T], Seq[T]]
    with Serializable {

  override def zero: KMinAccum[T] = KMinAccum(Array.ofDim[T](size), Array.ofDim[Long](size), 0)

  override def reduce(b: KMinAccum[T], a: Row): KMinAccum[T] = {
    val newWeight = a.getLong(1)
    val newValue = a.getAs[T](0)
    // fast-path: buffer is already full of "strong" elements
    // the case of "influencer" vertex
    if (b.cnt == size) {
      val lastWeight = b.weights.last
      if ((lastWeight < newWeight) || ((lastWeight == newWeight) && (ord.compare(
          newValue,
          b.values.last) >= 0))) {
        return b
      }
    }

    // slow-path: custom binary search for (Weight, Value)
    // We want to find the first index where (b.w, b.v) > (newWeight, newValue)
    var low = 0
    var high = b.cnt - 1
    var idx = b.cnt // Default insertion point is at the end

    while (low <= high) {
      val mid = (low + high) / 2
      val midWeight = b.weights(mid)

      // Compare (midWeight, midValue) vs (newWeight, newValue)
      val res =
        if (midWeight < newWeight) -1
        else if (midWeight > newWeight) 1
        else ord.compare(b.values(mid), newValue)

      if (res <= 0) {
        // mid is smaller or equal: we must insert after mid
        low = mid + 1
      } else {
        // mid is larger: potential insertion point here
        idx = mid
        high = mid - 1
      }
    }

    if (idx < size) {
      val newCount = math.min(b.cnt + 1, size)
      if (idx < newCount - 1) {
        // shift to the right if needed
        System.arraycopy(b.weights, idx, b.weights, idx + 1, newCount - idx - 1)
        System.arraycopy(b.values, idx, b.values, idx + 1, newCount - idx - 1)
      }

      b.weights(idx) = newWeight
      b.values(idx) = newValue
      b.cnt = newCount
    }

    b
  }

  override def merge(b1: KMinAccum[T], b2: KMinAccum[T]): KMinAccum[T] = {
    if (b1.cnt == 0) {
      return b2
    }

    if (b2.cnt == 0) {
      return b1
    }

    val resultSize = math.min(b1.cnt + b2.cnt, size)
    val newValues = Array.ofDim[T](resultSize)
    val newWeights = Array.ofDim[Long](resultSize)

    var i = 0
    var j = 0
    var r = 0

    while (r < resultSize) {
      val useLeft = if (i >= b1.cnt) {
        false
      } else if (j >= b2.cnt) {
        true
      } else {
        val wLeft = b1.weights(i)
        val wRight = b2.weights(j)

        if (wLeft < wRight) {
          true
        } else if (wLeft > wRight) {
          false
        } else {
          ord.compare(b1.values(i), b2.values(j)) <= 0
        }
      }

      if (useLeft) {
        newWeights(r) = b1.weights(i)
        newValues(r) = b1.values(i)
        i += 1
      } else {
        newWeights(r) = b2.weights(j)
        newValues(r) = b2.values(j)
        j += 1
      }

      r += 1
    }

    KMinAccum(newValues, newWeights, resultSize)
  }

  override def finish(reduction: KMinAccum[T]): Seq[T] =
    reduction.values.slice(0, reduction.cnt).toSeq

  // TODO: replace by Kryo after 4.0.2 is released, see SPARK-52819
  override def bufferEncoder: Encoder[KMinAccum[T]] = Encoders.product
  override def outputEncoder: Encoder[Seq[T]] = ExpressionEncoder[Seq[T]]()
}

object KMinSampling extends Serializable {
  def getEncoder(spark: SparkSession, dataType: DataType, colNames: Seq[String]): Encoder[Row] = {
    // That is very stupid way actually. But it is the only way with public API
    spark
      .createDataFrame(
        java.util.List.of[Row](),
        StructType(
          StructField(colNames(0), dataType) :: StructField(colNames(1), LongType) :: Nil))
      .encoder
  }

  def fromSparkType(dataType: DataType, size: Int, encoder: Encoder[Row]): UserDefinedFunction = {
    dataType match {
      case StringType => udaf(KMinSampling[java.lang.String](size), encoder)
      case ShortType => udaf(KMinSampling[java.lang.Short](size), encoder)
      case ByteType => udaf(KMinSampling[java.lang.Byte](size), encoder)
      case IntegerType => udaf(KMinSampling[java.lang.Integer](size), encoder)
      case LongType => udaf(KMinSampling[java.lang.Long](size), encoder)
      case _ => throw new GraphFramesUnsupportedVertexTypeException("unsupported vertex type")
    }
  }
}
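The aggregator above keeps, per group, only the `size` lexicographically smallest (weight, value) pairs in a sorted bounded buffer: `reduce` inserts a new pair via binary search (with a fast path that rejects heavy elements when the buffer is full), and `merge` combines two sorted buffers while keeping only the k smallest pairs overall. A minimal Python model of that same logic (an illustration only, not the GraphFrames code path, which runs as a Spark UDAF over flat arrays):

```python
import bisect

def kmin_reduce(buf, weight, value, size):
    """Insert (weight, value) into a sorted, bounded buffer (smallest first)."""
    if len(buf) == size and (weight, value) >= buf[-1]:
        return buf  # fast path: new pair is too "heavy" for a full buffer
    bisect.insort(buf, (weight, value))  # binary-search insertion
    del buf[size:]  # keep only the k smallest pairs
    return buf

def kmin_merge(b1, b2, size):
    """Merge two sorted buffers, keeping only the k smallest pairs overall."""
    # The Scala version does a linear two-pointer merge; sorting the
    # concatenation gives the same result for small buffers.
    return sorted(b1 + b2)[:size]

buf = []
for w, v in [(5, "a"), (1, "b"), (3, "c"), (2, "d")]:
    kmin_reduce(buf, w, v, size=3)
print(buf)                                   # [(1, 'b'), (2, 'd'), (3, 'c')]
print(kmin_merge(buf, [(0, "e")], size=3))   # [(0, 'e'), (1, 'b'), (2, 'd')]
```

With edge weights drawn per-row as random longs, keeping the k minimal weights per source vertex yields a uniform k-sample of its neighbors, which is presumably why the UDAF takes a `(value, weight: Long)` row and why ties break on the value ordering.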